In [6]:
import pickle

In [36]:
with open("../weights/epoch_92400_sample_23500000.pkl", "rb") as f:
    n_annotations, pretrained_model_weights, pretrained_optimizer_weights = pickle.load(f)

In [190]:
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops.layers.torch import Rearrange, Reduce
from einops import rearrange, repeat

# helpers

def exists(val):
    return val is not None

def max_neg_value(t):
    return -torch.finfo(t.dtype).max

# helper classes

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

class GlobalLinearSelfAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head,
        heads
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, feats, mask = None):
        h = self.heads
        q, k, v = self.to_qkv(feats).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        if exists(mask):
            mask = rearrange(mask, 'b n -> b () n ()')
            k = k.masked_fill(~mask, -torch.finfo(k.dtype).max)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        if exists(mask):
            v = v.masked_fill(~mask, 0.)

        context = einsum('b h n d, b h n e -> b h d e', k, v)
        out = einsum('b h d e, b h n d -> b h n e', context, q)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class CrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_keys,
        dim_out,
        heads,
        dim_head = 64,
        qk_activation = nn.Tanh()
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.qk_activation = qk_activation

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim_keys, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim_out)

        self.null_key = nn.Parameter(torch.randn(dim_head))
        self.null_value = nn.Parameter(torch.randn(dim_head))

    def forward(self, x, context, mask = None, context_mask = None):
        b, h, device = x.shape[0], self.heads, x.device

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        null_k, null_v = map(lambda t: repeat(t, 'd -> b h () d', b = b, h = h), (self.null_key, self.null_value))
        k = torch.cat((null_k, k), dim = -2)
        v = torch.cat((null_v, v), dim = -2)

        q, k = map(lambda t: self.qk_activation(t), (q, k))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if exists(mask) or exists(context_mask):
            i, j = sim.shape[-2:]

            if not exists(mask):
                mask = torch.ones(b, i, dtype = torch.bool, device = device)

            if exists(context_mask):
                context_mask = F.pad(context_mask, (1, 0), value = True)
            else:
                context_mask = torch.ones(b, j, dtype = torch.bool, device = device)

            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')
            sim.masked_fill_(~mask, max_neg_value(sim))

        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Layer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_global,
        narrow_conv_kernel = 9,
        wide_conv_kernel = 9,
        wide_conv_dilation = 5,
        attn_heads = 8,
        attn_dim_head = 64,
        attn_qk_activation = nn.Tanh(),
        local_to_global_attn = False,
        local_self_attn = False,
        glu_conv = False
    ):
        super().__init__()

        self.seq_self_attn = GlobalLinearSelfAttention(dim = dim, dim_head = attn_dim_head, heads = attn_heads) if local_self_attn else None

        conv_mult = 2 if glu_conv else 1

        self.narrow_conv = nn.Sequential(
            nn.Conv1d(dim, dim * conv_mult, narrow_conv_kernel, padding = narrow_conv_kernel // 2),
            nn.GELU() if not glu_conv else nn.GLU(dim = 1)
        )

        wide_conv_padding = (wide_conv_kernel + (wide_conv_kernel - 1) * (wide_conv_dilation - 1)) // 2

        self.wide_conv = nn.Sequential(
            nn.Conv1d(dim, dim * conv_mult, wide_conv_kernel, dilation = wide_conv_dilation, padding = wide_conv_padding),
            nn.GELU() if not glu_conv else nn.GLU(dim = 1)
        )

        self.local_to_global_attn = local_to_global_attn

        if local_to_global_attn:
            self.extract_global_info = CrossAttention(
                dim = dim,
                dim_keys = dim_global,
                dim_out = dim,
                heads = attn_heads,
                dim_head = attn_dim_head
            )
        else:
            self.extract_global_info = nn.Sequential(
                Reduce('b n d -> b d', 'mean'),
                nn.Linear(dim_global, dim),
                nn.GELU(),
                Rearrange('b d -> b () d')
            )

        self.local_norm = nn.LayerNorm(dim)

        self.local_feedforward = nn.Sequential(
            Residual(nn.Sequential(
                nn.Linear(dim, dim),
                nn.GELU(),
            )),
            nn.LayerNorm(dim)
        )

        self.global_attend_local = CrossAttention(dim = dim_global, dim_out = dim_global, dim_keys = dim, heads = attn_heads, dim_head = attn_dim_head, qk_activation = attn_qk_activation)

        self.global_dense = nn.Sequential(
            nn.Linear(dim_global, dim_global),
            nn.GELU()
        )

        self.global_norm = nn.LayerNorm(dim_global)

        self.global_feedforward = nn.Sequential(
            Residual(nn.Sequential(
                nn.Linear(dim_global, dim_global),
                nn.GELU()
            )),
            nn.LayerNorm(dim_global),
        )

    def forward(self, tokens, annotation, mask = None):
        if self.local_to_global_attn:
            global_info = self.extract_global_info(tokens, annotation, mask = mask)
        else:
            global_info = self.extract_global_info(annotation)

        # process local (protein sequence)

        global_linear_attn = self.seq_self_attn(tokens) if exists(self.seq_self_attn) else 0

        conv_input = rearrange(tokens, 'b n d -> b d n')

        if exists(mask):
            conv_input_mask = rearrange(mask, 'b n -> b () n')
            conv_input = conv_input.masked_fill(~conv_input_mask, 0.)

        narrow_out = self.narrow_conv(conv_input)
        narrow_out = rearrange(narrow_out, 'b d n -> b n d')
        wide_out = self.wide_conv(conv_input)
        wide_out = rearrange(wide_out, 'b d n -> b n d')

        tokens = tokens + narrow_out + wide_out + global_info + global_linear_attn
        tokens = self.local_norm(tokens)

        tokens = self.local_feedforward(tokens)

        # process global (annotations)

        annotation = self.global_attend_local(annotation, tokens, context_mask = mask)
        annotation = self.global_dense(annotation)
        annotation = self.global_norm(annotation)
        annotation = self.global_feedforward(annotation)

        return tokens, annotation

# main model

class ProteinBERT(nn.Module):
    def __init__(
        self,
        *,
        num_tokens = 26,
        num_annotation = 8943,
        dim = 512,
        dim_global = 256,
        depth = 6,
        narrow_conv_kernel = 9,
        wide_conv_kernel = 9,
        wide_conv_dilation = 5,
        attn_heads = 8,
        attn_dim_head = 64,
        attn_qk_activation = nn.Tanh(),
        local_to_global_attn = False,
        local_self_attn = False,
        num_global_tokens = 1,
        glu_conv = False
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, dim)

        self.num_global_tokens = num_global_tokens
        self.to_global_emb = nn.Linear(num_annotation, num_global_tokens * dim_global)

        self.layers = nn.ModuleList([
            Layer(
                dim = dim,
                dim_global = dim_global,
                narrow_conv_kernel = narrow_conv_kernel,
                wide_conv_dilation = wide_conv_dilation,
                wide_conv_kernel = wide_conv_kernel,
                attn_qk_activation = attn_qk_activation,
                local_to_global_attn = local_to_global_attn,
                local_self_attn = local_self_attn,
                glu_conv = glu_conv
            )
            for layer in range(depth)
        ])
        # print(self.layers)

        self.to_token_logits = nn.Linear(dim, num_tokens)

        self.to_annotation_logits = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),
            nn.Linear(dim_global, num_annotation)
        )

    def forward(self, seq, annotation, mask = None):
        tokens = self.token_emb(seq)

        annotation = self.to_global_emb(annotation)
        annotation = rearrange(annotation, 'b (n d) -> b n d', n = self.num_global_tokens)

        for layer in self.layers:
            tokens, annotation = layer(tokens, annotation, mask = mask)

        tokens = self.to_token_logits(tokens)
        annotation = self.to_annotation_logits(annotation)
        return tokens, annotation

print("Done")

Done


In [191]:
import torch

model = ProteinBERT(
    num_tokens = 21,
    num_annotation = 8943,
    dim = 512,
    dim_global = 256,
    depth = 6,
    narrow_conv_kernel = 9,
    wide_conv_kernel = 9,
    wide_conv_dilation = 5,
    attn_heads = 8,
    attn_dim_head = 64
)

seq = torch.randint(0, 21, (2, 2048))
mask = torch.ones(2, 2048).bool()
annotation = torch.randint(0, 1, (2, 8943)).float()

seq_logits, annotation_logits = model(seq, annotation, mask = mask) # (2, 2048, 21), (2, 8943)
print(seq_logits, annotation_logits)

tensor([[[ 0.5929, -0.9098,  0.9188,  ..., -0.1676, -0.7180, -0.4313],
         [ 1.1562,  0.2398, -0.0485,  ..., -0.6763, -0.7170, -0.8447],
         [ 0.2704,  0.0859, -0.0987,  ..., -0.8367, -0.3104, -0.2539],
         ...,
         [ 0.5719, -0.4015, -0.4394,  ...,  0.8981, -0.8608, -0.4114],
         [ 0.1337, -0.0078, -0.0260,  ..., -0.2980,  0.3003, -0.5458],
         [ 0.8318, -0.6286,  0.4875,  ..., -0.1456, -0.1285, -0.3256]],

        [[ 1.0427,  0.2485, -0.6645,  ..., -0.4918,  0.3169,  0.1762],
         [ 0.5184,  0.0329,  0.8701,  ..., -0.2437, -0.1188, -0.2436],
         [ 0.2250, -0.6192,  1.0777,  ..., -0.6017,  0.6816, -0.3975],
         ...,
         [ 0.5297,  0.4673, -0.2527,  ..., -0.3983, -0.1364, -0.4662],
         [ 0.8680,  0.3956, -0.3994,  ..., -0.5084, -0.1851, -0.5838],
         [-0.3677,  0.2939, -0.2362,  ..., -0.5652, -0.1340, -1.0860]]],
       grad_fn=<ViewBackward0>) tensor([[ 0.4354, -0.5147, -0.9464,  ...,  0.1424, -0.7635,  0.4639],
        [ 0.45

In [209]:
pretrained_model = ProteinBERT(
    num_tokens = 26,
    num_annotation = 8943,
    dim = 128,
    dim_global = 512,
    depth = 6,
    narrow_conv_kernel = 9,
    wide_conv_kernel = 9,
    wide_conv_dilation = 5,
    attn_heads = 4,
    attn_dim_head = 64
)
print(pretrained_model)

ProteinBERT(
  (token_emb): Embedding(26, 128)
  (to_global_emb): Linear(in_features=8943, out_features=512, bias=True)
  (layers): ModuleList(
    (0-5): 6 x Layer(
      (narrow_conv): Sequential(
        (0): Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(4,))
        (1): GELU(approximate='none')
      )
      (wide_conv): Sequential(
        (0): Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(20,), dilation=(5,))
        (1): GELU(approximate='none')
      )
      (extract_global_info): Sequential(
        (0): Reduce('b n d -> b d', 'mean')
        (1): Linear(in_features=512, out_features=128, bias=True)
        (2): GELU(approximate='none')
        (3): Rearrange('b d -> b () d')
      )
      (local_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (local_feedforward): Sequential(
        (0): Residual(
          (fn): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): GELU(approximate='none')

In [210]:
count = 0

for name, param in pretrained_model.to_global_emb.named_parameters():
    print(f"{str(tuple(param.shape)):<20} {name}")
    count += 1
    
for name, param in pretrained_model.token_emb.named_parameters():
    print(f"{str(tuple(param.shape)):<20} {name}")
    count += 1

print("---")
for layer in pretrained_model.layers:
    for idx, (name, param) in enumerate(layer.named_parameters()):
        print(f"{idx:<4} {str(tuple(param.shape)):<20} {name}")
        count += 1
    print("---")

for name, param in pretrained_model.to_token_logits.named_parameters():
    print(f"{str(tuple(param.shape)):<20} {name}")
    count += 1

for name, param in pretrained_model.to_annotation_logits.named_parameters():
    print(f"{str(tuple(param.shape)):<20} {name}")
    count += 1

print(count)

(512, 8943)          weight
(512,)               bias
(26, 128)            weight
---
0    (128, 128, 9)        narrow_conv.0.weight
1    (128,)               narrow_conv.0.bias
2    (128, 128, 9)        wide_conv.0.weight
3    (128,)               wide_conv.0.bias
4    (128, 512)           extract_global_info.1.weight
5    (128,)               extract_global_info.1.bias
6    (128,)               local_norm.weight
7    (128,)               local_norm.bias
8    (128, 128)           local_feedforward.0.fn.0.weight
9    (128,)               local_feedforward.0.fn.0.bias
10   (128,)               local_feedforward.1.weight
11   (128,)               local_feedforward.1.bias
12   (64,)                global_attend_local.null_key
13   (64,)                global_attend_local.null_value
14   (512, 512)           global_attend_local.to_q.weight
15   (1024, 128)          global_attend_local.to_kv.weight
16   (512, 512)           global_attend_local.to_out.weight
17   (512,)               global_

In [211]:
for i, w in enumerate(pretrained_model_weights):

    print(i, w.shape) # 23 weight objects per transformer layer in keras model?

    if i == 2:
        print("---")
    if i > 3 and (i - 2) % 23 == 0:
        print("---")

0 (8943, 512)
1 (512,)
2 (26, 128)
---
3 (512, 128)
4 (128,)
5 (9, 128, 128)
6 (128,)
7 (9, 128, 128)
8 (128,)
9 (128,)
10 (128,)
11 (128, 128)
12 (128,)
13 (128,)
14 (128,)
15 (512, 512)
16 (512,)
17 (4, 512, 64)
18 (4, 128, 64)
19 (4, 128, 128)
20 (512,)
21 (512,)
22 (512, 512)
23 (512,)
24 (512,)
25 (512,)
---
26 (512, 128)
27 (128,)
28 (9, 128, 128)
29 (128,)
30 (9, 128, 128)
31 (128,)
32 (128,)
33 (128,)
34 (128, 128)
35 (128,)
36 (128,)
37 (128,)
38 (512, 512)
39 (512,)
40 (4, 512, 64)
41 (4, 128, 64)
42 (4, 128, 128)
43 (512,)
44 (512,)
45 (512, 512)
46 (512,)
47 (512,)
48 (512,)
---
49 (512, 128)
50 (128,)
51 (9, 128, 128)
52 (128,)
53 (9, 128, 128)
54 (128,)
55 (128,)
56 (128,)
57 (128, 128)
58 (128,)
59 (128,)
60 (128,)
61 (512, 512)
62 (512,)
63 (4, 512, 64)
64 (4, 128, 64)
65 (4, 128, 128)
66 (512,)
67 (512,)
68 (512, 512)
69 (512,)
70 (512,)
71 (512,)
---
72 (512, 128)
73 (128,)
74 (9, 128, 128)
75 (128,)
76 (9, 128, 128)
77 (128,)
78 (128,)
79 (128,)
80 (128, 128)
81 (128