In [1]:
# from https://github.com/dohlee/proteinbert-pytorch/blob/master/proteinbert_pytorch/proteinbert_pytorch.py
# from https://gist.github.com/chirag1992m/4c1f2cb27d7c138a4dc76aeddfe940c2

In [2]:
import numpy as np
import pickle

In [3]:
import torch
import torch.nn as nn
import math

from torch import einsum
from einops import rearrange
from einops.layers.torch import Rearrange

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    
    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super().__init__()
        self.conv_narrow = nn.Sequential(
            Rearrange('b l d -> b d l'),
            nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding='same', dilation=1),
            nn.GELU(),
            Rearrange('b d l -> b l d')
        )
        self.conv_wide = nn.Sequential(
            Rearrange('b l d -> b d l'),
            nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding='same', dilation=5),
            nn.GELU(),
            Rearrange('b d l -> b l d')
        )
    
    def forward(self, x):
        return self.conv_narrow(x) + self.conv_wide(x)

class GlobalAttention(nn.Module):
    def __init__(self, d_local, d_global, n_heads, d_key):
        super().__init__()
        d_value = d_global // n_heads

        self.to_q = nn.Sequential(
            nn.Linear(d_global, d_key * n_heads, bias=False),
            nn.Tanh()
        )
        self.to_k = nn.Sequential(
            nn.Linear(d_local, d_key * n_heads, bias=False),
            nn.Tanh()
        )
        self.to_v = nn.Sequential(
            nn.Linear(d_local, d_value * n_heads, bias=False),
            nn.GELU()
        )

        self.n_heads = n_heads
        self.d_key = d_key
    
    def forward(self, x_local, x_global):
        q = self.to_q(x_global)
        k = self.to_k(x_local)
        v = self.to_v(x_local)

        q = rearrange(q, 'b (h d) -> b h d', h=self.n_heads)
        k = rearrange(k, 'b l (h d) -> b l h d', h=self.n_heads)
        v = rearrange(v, 'b l (h d) -> b l h d', h=self.n_heads)

        att = einsum('b h d, b l h d -> b h l', q, k) / math.sqrt(self.d_key)
        att = att.softmax(dim=-1)

        x_global = einsum('b h l, b l h d -> b h d', att, v)
        x_global = rearrange(x_global, 'b h d -> b (h d)')
        return x_global


class TransformerLikeBlock(nn.Module):
    def __init__(self, d_local, d_global):
        super().__init__()

        self.wide_and_narrow_conv1d = ConvBlock(d_local, d_local)
        self.dense_and_broadcast = nn.Sequential(
            nn.Linear(d_global, d_local),
            nn.GELU(),
            Rearrange('b d -> b () d')
        )
        self.local_ln1 = nn.LayerNorm(d_local)
        self.local_dense = nn.Sequential(
            Residual(nn.Sequential(nn.Linear(d_local, d_local), nn.GELU())),
            nn.LayerNorm(d_local),
        )

        self.global_dense1 = nn.Sequential(nn.Linear(d_global, d_global), nn.GELU())
        self.global_attention = GlobalAttention(d_local, d_global, n_heads=4, d_key=64)
        self.global_ln1 = nn.LayerNorm(d_global)
        self.global_dense2 = nn.Sequential(
            Residual(nn.Sequential(nn.Linear(d_global, d_global), nn.GELU())),
            nn.LayerNorm(d_global),
        )
    
    def forward(self, x_local, x_global):
        x_local = self.local_ln1(
            x_local + self.wide_and_narrow_conv1d(x_local) + self.dense_and_broadcast(x_global)
        )
        x_local = self.local_dense(x_local)

        x_global = self.global_ln1(
            x_global + self.global_dense1(x_global) + self.global_attention(x_local, x_global)
        )
        x_global = self.global_dense2(x_global)

        return x_local, x_global

class ProteinBERT(nn.Module):
    def __init__(
            self,
            vocab_size,
            ann_size,
            d_local=128,
            d_global=512,
        ):
        super().__init__()

        self.embed_local = nn.Embedding(vocab_size, d_local)
        self.embed_global = nn.Sequential(nn.Linear(ann_size, d_global), nn.GELU())

        self.blocks = nn.ModuleList([TransformerLikeBlock(d_local, d_global) for _ in range(6)])

        self.local_head = nn.Sequential(nn.Linear(d_local, vocab_size))  # NOTE: logits are returned
        self.global_head = nn.Sequential(nn.Linear(d_global, ann_size), nn.Sigmoid())
    
    def forward(self, x_local, x_global):
        x_local = self.embed_local(x_local)
        x_global = self.embed_global(x_global)

        for block in self.blocks:
            x_local, x_global = block(x_local, x_global)

        return self.local_head(x_local), self.global_head(x_global)

In [7]:
def load_pretrained_weights(model, weights):
    # reorganize weights
    for i, weight in enumerate(pretrained_model_weights):
        if i == 2:
            continue
        if len(weight.shape) == 2:
            pretrained_model_weights[i] = np.transpose(weight, (1, 0))
            continue
        
        if (3 <= i <= 140):
            if i % 23 == 17:
                pretrained_model_weights[i] = np.transpose(weight, (0, 2, 1)).reshape(4 * 64, 512)
                continue
            if i % 23 == 18:
                pretrained_model_weights[i] = np.transpose(weight, (0, 2, 1)).reshape(4 * 64, 128)
                continue
            if i % 23 == 19:
                pretrained_model_weights[i] = np.transpose(weight, (0, 2, 1)).reshape(4 * 128, 128)
                continue
    
        if len(weight.shape) == 3:
            pretrained_model_weights[i] = np.transpose(weight, (2, 1, 0))

    # convert all to tensors
    for i, weight in enumerate(pretrained_model_weights):
        pretrained_model_weights[i] = torch.from_numpy(weight)
        
    # load weights
    state = model.state_dict()
    state["embed_local.weight"] = pretrained_model_weights[2]
    state["embed_global.0.weight"] = pretrained_model_weights[0]
    state["embed_global.0.bias"] = pretrained_model_weights[1]
    
    for block in range(6):
        idx = 3 + block * 23
        state[f"blocks.{block}.wide_and_narrow_conv1d.conv_narrow.1.weight"] = pretrained_model_weights[idx + 2]
        state[f"blocks.{block}.wide_and_narrow_conv1d.conv_narrow.1.bias"] = pretrained_model_weights[idx + 3]
        state[f"blocks.{block}.wide_and_narrow_conv1d.conv_wide.1.weight"] = pretrained_model_weights[idx + 4]
        state[f"blocks.{block}.wide_and_narrow_conv1d.conv_wide.1.bias"] = pretrained_model_weights[idx + 5]
        state[f"blocks.{block}.dense_and_broadcast.0.weight"] = pretrained_model_weights[idx]
        state[f"blocks.{block}.dense_and_broadcast.0.bias"] = pretrained_model_weights[idx + 1]
        state[f"blocks.{block}.local_ln1.weight"] = pretrained_model_weights[idx + 6]
        state[f"blocks.{block}.local_ln1.bias"] = pretrained_model_weights[idx + 7]
        state[f"blocks.{block}.local_dense.0.fn.0.weight"] = pretrained_model_weights[idx + 8]
        state[f"blocks.{block}.local_dense.0.fn.0.bias"] = pretrained_model_weights[idx + 9]
        state[f"blocks.{block}.local_dense.1.weight"] = pretrained_model_weights[idx + 10]
        state[f"blocks.{block}.local_dense.1.bias"] = pretrained_model_weights[idx + 11]
        state[f"blocks.{block}.global_dense1.0.weight"] = pretrained_model_weights[idx + 12]
        state[f"blocks.{block}.global_dense1.0.bias"] = pretrained_model_weights[idx + 13]
        state[f"blocks.{block}.global_attention.to_q.0.weight"] = pretrained_model_weights[idx + 14]
        state[f"blocks.{block}.global_attention.to_k.0.weight"] = pretrained_model_weights[idx + 15]
        state[f"blocks.{block}.global_attention.to_v.0.weight"] = pretrained_model_weights[idx + 16]
        state[f"blocks.{block}.global_ln1.weight"] = pretrained_model_weights[idx + 17]
        state[f"blocks.{block}.global_ln1.bias"] = pretrained_model_weights[idx + 18]
        state[f"blocks.{block}.global_dense2.0.fn.0.weight"] = pretrained_model_weights[idx + 19]
        state[f"blocks.{block}.global_dense2.0.fn.0.bias"] = pretrained_model_weights[idx + 20]
        state[f"blocks.{block}.global_dense2.1.weight"] = pretrained_model_weights[idx + 21]
        state[f"blocks.{block}.global_dense2.1.bias"] = pretrained_model_weights[idx + 22]

    state["local_head.0.weight"] = pretrained_model_weights[141]
    state["local_head.0.bias"] = pretrained_model_weights[142]
    state["global_head.0.weight"] = pretrained_model_weights[143]
    state["global_head.0.bias"] = pretrained_model_weights[144]

    model.load_state_dict(state)

with open("../weights/epoch_92400_sample_23500000.pkl", "rb") as f:
    n_annotations, pretrained_model_weights, pretrained_optimizer_weights = pickle.load(f)


In [9]:
vocab_size = 26
ann_size = 8943
bsz = 1

model = ProteinBERT(vocab_size, ann_size)
load_pretrained_weights(model, pretrained_model_weights)

TypeError: expected np.ndarray (got Tensor)

In [10]:
x_local = torch.randint(0, vocab_size, (bsz, 52))
x_global = torch.rand(bsz, ann_size)

x_local, x_global = model(x_local, x_global)
print(x_local.shape, x_global.shape)

# Print the number of parameters in the model.
# NOTE: Must have ~16M parameters according to the paper.
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

torch.Size([1, 52, 26]) torch.Size([1, 8943])
15981321


In [None]:
ALL_AAS = 'ACDEFGHIKLMNPQRSTUVWXY'
ADDITIONAL_TOKENS = ['<OTHER>', '<START>', '<END>', '<PAD>']

# Each sequence is added <START> and <END> tokens
ADDED_TOKENS_PER_SEQ = 2

n_aas = len(ALL_AAS)
aa_to_token_index = {aa: i for i, aa in enumerate(ALL_AAS)}
additional_token_to_index = {token: i + n_aas for i, token in enumerate(ADDITIONAL_TOKENS)}
token_to_index = {**aa_to_token_index, **additional_token_to_index}
index_to_token = {index: token for token, index in token_to_index.items()}
n_tokens = len(token_to_index)

In [None]:
class InputEncoder:
    def __init__(self, n_annotations):
        self.n_annotations = n_annotations

    def parse_seq(self, seq):
        if isinstance(seq, str):
            return seq
        elif isinstance(seq, bytes):
            return seq.decode('utf8')
        else:
            raise TypeError('Unexpected sequence type: %s' % type(seq))
        
    def tokenize_seq(self, seq):
        other_token_index = additional_token_to_index['<OTHER>']
        return [additional_token_to_index['<START>']] + \
            [aa_to_token_index.get(aa, other_token_index) for aa in self.parse_seq(seq)] + \
            [additional_token_to_index['<END>']]
    
    def tokenize_seqs(self, seqs, seq_len):
        # Note that tokenize_seq already adds <START> and <END> tokens.
        return np.array(
            [
                seq_tokens + (seq_len - len(seq_tokens)) * [additional_token_to_index['<PAD>']]
                for seq_tokens in map(self.tokenize_seq, seqs)
            ],
            dtype = np.int32
            )
    
    def encode_X(self, seqs, seq_len):
        return [
            self.tokenize_seqs(seqs, seq_len),
            np.zeros((len(seqs), self.n_annotations), dtype = np.int8)
        ]

    def decode_output(self, pred):
        pred = torch.argmax(pred, 2)
        return pred

In [None]:
def pred_to_str(pred):
    idx_to_tok = list("ACDEFGHIKLMNPQRSTUVWXY?^$_")
    ret = []
    for i in pred:
        out = "".join([idx_to_tok[j] for j in i])
        ret.append(out)

    return ret

In [None]:
input_enc = InputEncoder(n_annotations)

In [None]:
prots = {
    "test": "XLGMIRNSLFGSVETWPWQVLSTGGKEDVSYEERACEGGKFATVEVTDKPVDEALREAMPKIMKYVGGTN",
    "insulin_correct": "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
    "chained": "YRGDXXXRGDXXXRGDXXXRGDXXXRGD"
}

prot_seqs = []
for name, seq in prots.items():
    seq_tok, global_anno = input_enc.encode_X([seq], 200)
    seq_tok = torch.from_numpy(seq_tok).long()
    global_anno = torch.from_numpy(global_anno).float()
    pred_local, pred_global = model(seq_tok, global_anno)
    print(name)
    print(pred_to_str(input_enc.decode_output(pred_local))[0])
    print("---")