In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from einops import rearrange
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from models import MHSA, MSA_Stack, Outer_Product_Mean, Pair_Stack, Triangular_Multiplicative_Model, IPA_Module



In [3]:
# Model dimensions shared by every shape test below.
# Inputs are built as msa_rep: (B, S, R, C_m) and prw_rep: (B, R, R, C_z).
S = 16      # number of MSA rows (presumably aligned sequences) - second axis of msa_rep
B = 5       # batch size - leading axis of every input tensor
R = 64      # sequence length (residues) - third axis of msa_rep, both middle axes of prw_rep
C_m = 128   # channel width of the MSA representation
C_z = 64    # channel width of the pair representation
H = 12      # attention heads (same value the IPA module is tested with)
C = 16      # hidden / per-head channel width, passed as `dim_head` or `c`
N_qp = 4    # IPA query points (`n_qp`)
N_pv = 8    # IPA point values (`n_pv`)

# Seed the default generator so Restart & Run All reproduces the same random
# inputs and module initialisations.
SEED = 0
torch.manual_seed(SEED)

In [15]:
# Random stand-in inputs: only their shapes matter for the checks below.
# Draw order (MSA, pair, rotations, translations) is fixed so the RNG stream stays the same.
msa_rep = torch.rand((B, S, R, C_m)).cuda()  # MSA representation, (B, S, R, C_m)
prw_rep = torch.rand((B, R, R, C_z)).cuda()  # pairwise representation, (B, R, R, C_z)
bbr = torch.rand((B, R, 3, 3)).cuda()        # presumably backbone rotations; uniform noise, not orthonormal
bbt = torch.rand((B, R, 3)).cuda()           # presumably backbone translations, (B, R, 3)

Test MHSA (works only on unbatched input, so a single example is passed)

In [16]:
# MHSA has no batch axis, so it is fed example 0 of the batch.
mhsa = MHSA(c_m=C_m, c_z=C_z, heads=8, dim_head=C, bias=True).cuda()
print(f'expected: {msa_rep[0].shape}')
# no_grad: the output is kept in a global, so don't also keep its autograd graph alive.
with torch.no_grad():
    mhsa_out = mhsa(msa_rep[0], prw_rep[0])
print(f'actual:   {mhsa_out.shape}')
# Fail loudly on Run All instead of relying on a visual comparison of the prints.
assert mhsa_out.shape == msa_rep[0].shape, 'MHSA must preserve the MSA representation shape'

expected: torch.Size([16, 64, 128])
actual:   torch.Size([16, 64, 128])


Test MSA Stack

In [17]:
# MSA_Stack updates the batched MSA representation using the pair representation.
msa_stack = MSA_Stack(c_m=C_m, c_z=C_z, heads=8, dim_head=C).cuda()
print(f'expected: {msa_rep.shape}')
# no_grad: the output is kept in a global, so don't also keep its autograd graph alive.
with torch.no_grad():
    msa_stack_out = msa_stack(msa_rep, prw_rep)
print(f'actual:   {msa_stack_out.shape}')
assert msa_stack_out.shape == msa_rep.shape, 'MSA_Stack must preserve the MSA representation shape'

expected: torch.Size([5, 16, 64, 128])
actual:   torch.Size([5, 16, 64, 128])


Test Outer Product Mean

In [18]:
# Outer_Product_Mean maps the MSA representation (B, S, R, C_m) to a
# pair-shaped tensor, expected to match prw_rep: (B, R, R, C_z).
opm = Outer_Product_Mean(c_m=C_m, c_z=C_z, c=C).cuda()
print(f'expected: {prw_rep.shape}')
# no_grad: the output is kept in a global, so don't also keep its autograd graph alive.
with torch.no_grad():
    opm_out = opm(msa_rep)
print(f'actual:   {opm_out.shape}')
assert opm_out.shape == prw_rep.shape, 'Outer_Product_Mean must produce a pair-shaped tensor'

expected: torch.Size([5, 64, 64, 64])
actual:   torch.Size([5, 64, 64, 64])


Test Pair Stack

In [19]:
# Pair_Stack updates the pair representation in a shape-preserving way.
pair_stack = Pair_Stack(c_z=C_z, heads=8, dim_head=C).cuda()
print(f'expected: {prw_rep.shape}')
# no_grad: the output is kept in a global, so don't also keep its autograd graph alive.
with torch.no_grad():
    pair_stack_out = pair_stack(prw_rep)
print(f'actual:   {pair_stack_out.shape}')
assert pair_stack_out.shape == prw_rep.shape, 'Pair_Stack must preserve the pair representation shape'

expected: torch.Size([5, 64, 64, 64])
actual:   torch.Size([5, 64, 64, 64])


Test Triangular Multiplicative Model

In [20]:
# Triangular multiplicative update, 'incoming' variant.
tmm = Triangular_Multiplicative_Model('incoming', c_z=C_z, c=C).cuda()
print(f'expected: {prw_rep.shape}')
# no_grad: the output is kept in a global, so don't also keep its autograd graph alive.
with torch.no_grad():
    tmm_incoming_out = tmm(prw_rep)
print(f'actual:   {tmm_incoming_out.shape}')
assert tmm_incoming_out.shape == prw_rep.shape, "incoming TMM must preserve the pair representation shape"

expected: torch.Size([5, 64, 64, 64])
actual:   torch.Size([5, 64, 64, 64])


In [21]:
# Triangular multiplicative update with mode=None.
# NOTE(review): None presumably selects the non-'incoming' (outgoing) variant - confirm in models.py.
tmm = Triangular_Multiplicative_Model(None, c_z=C_z, c=C).cuda()
print(f'expected: {prw_rep.shape}')
# no_grad: the output is kept in a global, so don't also keep its autograd graph alive.
with torch.no_grad():
    tmm_none_out = tmm(prw_rep)
print(f'actual:   {tmm_none_out.shape}')
assert tmm_none_out.shape == prw_rep.shape, "TMM (mode=None) must preserve the pair representation shape"

expected: torch.Size([5, 64, 64, 64])
actual:   torch.Size([5, 64, 64, 64])


Test IPA Module

In [23]:
# Head and point counts come from the config cell instead of repeated magic numbers.
# dim_head=None presumably lets IPA_Module choose its own per-head width - confirm in models.py.
ipa = IPA_Module(c_m=C_m, c_z=C_z, heads=H, dim_head=None, n_qp=N_qp, n_pv=N_pv).cuda()
single_rep = msa_rep[:, 0]  # first MSA row, (B, R, C_m); presumably the single representation
print(f'expected: {single_rep.shape}')
# no_grad: the output is kept in a global, so don't also keep its autograd graph alive.
with torch.no_grad():
    ipa_out = ipa(prw_rep, single_rep, bbr, bbt)
print(f'actual:   {ipa_out.shape}')
assert ipa_out.shape == single_rep.shape, 'IPA_Module must preserve the single representation shape'

expected: torch.Size([5, 64, 128])
actual:   torch.Size([5, 64, 128])
