In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from einops import rearrange
from einops import repeat
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from models import MHSA, MSA_Stack, Outer_Product_Mean, Pair_Stack, Triangular_Multiplicative_Model, IPA_Module, Structure_Module



In [3]:
# Sizes of the random test inputs built in the next cell.
B = 64   # leading dimension of every input tensor (presumably the batch size)
S = 16   # second dimension of msa_rep (presumably the number of MSA sequences)
R = 64   # shared by msa_rep and both axes of prw_rep (presumably residue count)

# Channel widths handed to the modules as c_m / c_z.
C_m, C_z = 128, 64

# Attention / hidden sizes.
C = 16   # passed as dim_head or c by the module tests
H = 12   # head count; same value as the IPA test's heads argument
N_qp, N_pv = 4, 8   # same values as the IPA test's n_qp / n_pv arguments

In [4]:
# Random inputs shared by all module tests below.
# Seed the generator so a Restart & Run All reproduces the same tensors.
torch.manual_seed(0)

# Allocate directly on the GPU rather than building on the host and copying.
msa_rep = torch.rand(B, S, R, C_m, device='cuda')  # (B, S, R, C_m)
prw_rep = torch.rand(B, R, R, C_z, device='cuda')  # (B, R, R, C_z)
# Passed to the IPA module together with prw_rep and msa_rep[:, 0].
# NOTE(review): presumably backbone rotations / translations; the values are
# uniform noise, so bbr is not a set of valid rotation matrices - confirm that
# is acceptable for a shape-only test.
bbr = torch.rand(B, R, 3, 3, device='cuda')
bbt = torch.rand(B, R, 3, device='cuda')

Test MHSA (only works without a batch dimension)

In [5]:
# Shape check for MHSA on a single (unbatched) example: the output is expected
# to have the same shape as the MSA input slice.
mhsa = MHSA(c_m=C_m, c_z=C_z, heads=8, dim_head=C, bias=True).cuda()
msa_single, prw_single = msa_rep[0], prw_rep[0]  # drop the leading dimension
print(f'expected: {msa_single.shape}')
mhsa_out = mhsa(msa_single, prw_single)
print(f'actual:   {mhsa_out.shape}')

expected: torch.Size([16, 64, 128])
actual:   torch.Size([16, 64, 128])


Test MSA Stack

In [6]:
%%timeit
msa_stack = MSA_Stack(c_m=C_m, c_z=C_z, heads=8, dim_head=C).cuda()
print(f'expected: {msa_rep.shape}')
print(f'actual:   {msa_stack(msa_rep, prw_rep).shape}')

expected: torch.Size([64, 16, 64, 128])
actual:   torch.Size([64, 16, 64, 128])
expected: torch.Size([64, 16, 64, 128])
actual:   torch.Size([64, 16, 64, 128])
expected: torch.Size([64, 16, 64, 128])
actual:   torch.Size([64, 16, 64, 128])
expected: torch.Size([64, 16, 64, 128])
actual:   torch.Size([64, 16, 64, 128])
expected: torch.Size([64, 16, 64, 128])
actual:   torch.Size([64, 16, 64, 128])
expected: torch.Size([64, 16, 64, 128])
actual:   torch.Size([64, 16, 64, 128])
expected: torch.Size([64, 16, 64, 128])
actual:   torch.Size([64, 16, 64, 128])
expected: torch.Size([64, 16, 64, 128])
actual:   torch.Size([64, 16, 64, 128])
106 ms ± 229 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


Test Outer Product Mean

In [7]:
%%timeit
opm = Outer_Product_Mean(c_m=C_m, c_z=C_z, c=C).cuda()
print(f'expected: {prw_rep.shape}')
print(f'actual:   {opm(msa_rep).shape}')

expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
636 ms ± 2.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Test Pair Stack

In [8]:
%%timeit
pair_stack = Pair_Stack(c_z=C_z, heads=8, dim_head=C).cuda()
print(f'expected: {prw_rep.shape}')
print(f'actual:   {pair_stack(prw_rep).shape}')

expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
119 ms ± 401 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


Test Triangular Multiplicative Model

In [9]:
%%timeit
tmm = Triangular_Multiplicative_Model('incoming', c_z=C_z, c=C).cuda()
print(f'expected: {prw_rep.shape}')
print(f'actual:   {tmm(prw_rep).shape}')

expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
475 ms ± 9.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
%%timeit
tmm = Triangular_Multiplicative_Model(None, c_z=C_z, c=C).cuda()
print(f'expected: {prw_rep.shape}')
print(f'actual:   {tmm(prw_rep).shape}')

expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
expected: torch.Size([64, 64, 64, 64])
actual:   torch.Size([64, 64, 64, 64])
464 ms ± 2.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Test IPA Module

In [11]:
%%timeit
ipa = IPA_Module(c_m=C_m, c_z=C_z, heads=12, dim_head=None, n_qp=4, n_pv=8).cuda()
print(f'expected: {msa_rep[:, 0].shape}')
print(f'actual:   {ipa(prw_rep, msa_rep[:, 0], bbr, bbt).shape}')

expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64

actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64

expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64

actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64

actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64

actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64

actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64, 128])
actual:   torch.Size([64, 64, 128])
expected: torch.Size([64, 64

Test Structure Module

In [12]:
# Random label tensors passed to Structure_Module in the next cell.
# They are allocated directly on the GPU rather than built on the host and copied.
a_labels = torch.rand(B, R, 4, device='cuda')  # (B, R, 4)
# Pair of (B, R, 3, 3) and (B, R, 3) tensors - the same shapes as bbr / bbt
# (presumably rotation and translation labels - TODO confirm).
T_labels = (
    torch.rand(B, R, 3, 3, device='cuda'),
    torch.rand(B, R, 3, device='cuda'),
)
x_labels = torch.rand(B, R, 3, device='cuda')  # (B, R, 3)

In [14]:
%%timeit
structure_module = Structure_Module(R, C_m, C_z, c=C).cuda()
x, L_fape, L_aux = structure_module(prw_rep, msa_rep[:, 0], a_labels, T_labels, x_labels)

59.9 ms ± 2.78 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
