In [1]:
# Standard library
import os
import sys
from collections import Counter
from pathlib import Path

# Third-party
import pandas as pd
import matplotlib.pyplot as plt

# Put the parent of the current working directory on sys.path so that the
# `src` package below can be imported. This has to run before the `src` imports.
curdir = Path.cwd()
sys.path.append(str(curdir.parent.absolute()))

# Project-local
from src.utils.data import read_fasta
from src.data.datasets import ProteinDataset

In [2]:
# Read the three FASTA files of the "random" split (train / dev / test), in
# that order, through the project's read_fasta helper.
train, val, test = (
    read_fasta(fasta_path)
    for fasta_path in (
        '../data/swissprot/proteinfer_splits/random/train_GO.fasta',
        '../data/swissprot/proteinfer_splits/random/dev_GO.fasta',
        '../data/swissprot/proteinfer_splits/random/test_GO.fasta',
    )
)

In [3]:
def _to_records(entries):
    """Convert (sequence, header_fields) pairs into (id, sequence, labels) tuples.

    The first header field becomes the id; the remaining fields are joined
    with single spaces into one label string.
    """
    records = []
    for seq, fields in entries:
        records.append((fields[0], seq, " ".join(fields[1:])))
    return records


# NOTE: train/test/val are rebound to the converted records, so re-running
# this cell without first re-running the loading cell fails on the unpacking.
train = _to_records(train)
test = _to_records(test)
val = _to_records(val)

# Concatenate the splits in train, val, test order.
df = train + val + test

In [4]:
# Turn the combined records into a table with one named column per tuple field.
record_columns = ['id', 'sequence', 'labels']
df = pd.DataFrame(data=df, columns=record_columns)

In [5]:
# Row count of the combined table.
num_sequences = df.shape[0]
print('number of sequences:',num_sequences)

number of sequences: 522607


In [6]:
# Tally label and residue frequencies over every row of `df`.
labels = Counter()      # label string -> number of occurrences
amino_freq = Counter()  # residue character -> number of occurrences

# Walk the three columns in lockstep instead of df.iterrows(), which builds a
# Series object for every row and is very slow on a table of this size.
for seq_id, sequence, row_labels in zip(df['id'], df['sequence'], df['labels']):
    if row_labels == '':
        # Report rows that carry no labels at all.
        print(seq_id, row_labels)
    # Counter.update on a string counts its individual characters.
    amino_freq.update(sequence)
    # split() without an argument returns [] for an empty string, whereas
    # split(" ") returns [''] and would register a bogus empty label.
    labels.update(row_labels.split())

# The vocabulary is exactly the set of residue characters that were counted.
vocab = set(amino_freq)

In [7]:
# Number of distinct keys in the label counter.
print('# GO Terms:',len(labels))

# GO Terms: 32102


In [8]:
# Summary statistics of the per-label occurrence counts.
print('GO Terms distribution')
term_counts = pd.Series(labels.values())
term_counts.describe()

GO Terms distribution


count     32102.000000
mean        777.250545
std        9114.786603
min           1.000000
25%           4.000000
50%          17.000000
75%          84.000000
max      462356.000000
dtype: float64

In [9]:
# Summary statistics of the length of each entry in the 'sequence' column.
print('Sequence length distribution')

sequence_lengths = df['sequence'].map(len)
sequence_lengths.describe()

Sequence length distribution


count    522607.000000
mean        368.042215
std         334.721845
min           2.000000
25%         179.000000
50%         303.000000
75%         456.000000
max       35213.000000
Name: sequence, dtype: float64

In [3]:
# Build the dataset object for the training split, pointing it at the
# amino-acid and GO-label vocabulary files.
dataset_kwargs = dict(
    data_path='../data/swissprot/proteinfer_splits/random/train_GO.fasta',
    sequence_vocabulary_path='../data/vocabularies/amino_acid_vocab.json',
    label_vocabulary_path='../data/vocabularies/GO_label_vocab.json',
)
PD = ProteinDataset(**dataset_kwargs)

In [4]:
# Dataset length alongside the value returned by get_max_seq_len().
dataset_size = len(PD)
max_seq_len = PD.get_max_seq_len()
dataset_size, max_seq_len

(418015, 35213)

In [112]:
import torch
from src.models.protein_encoders import Residual,ProteInfer

# Smoke test: run an all-ones dummy batch through a freshly constructed
# ProteInfer model and inspect the shape of what comes back.
# presumably the input is (batch, channels, length), matching input_channels=20 — TODO confirm
i = torch.ones((8, 20, 100))

model_kwargs = dict(
    num_labels=32102,
    input_channels=20,
    output_channels=1100,
    kernel_size=9,
    activation=torch.nn.ReLU,
    dilation_base=3,
    num_resnet_blocks=5,
    bottleneck_factor=0.5,
)
r = ProteInfer(**model_kwargs)

# One length value per batch item, passed as the model's second argument.
seqs_lengths = torch.tensor([80, 20, 5, 100, 95, 80, 20, 5])
o = r(i, seqs_lengths)
o.shape

torch.Size([8, 32102])

In [119]:
from src.utils.data import read_pickle
import numpy as np

# NOTE(review): read_pickle presumably unpickles this file; only load weight
# files that come from a trusted source.
tf_weights = read_pickle('../models/proteinfer/GO_model_weights.pkl')

# Dry run of the TF -> PyTorch weight transfer: model parameters and TF tensors
# are paired by position and only their shapes are compared. Nothing is copied
# into the model in this cell.
torch_params = list(r.named_parameters())

# zip() stops at the shorter input, so a count mismatch would otherwise be
# silently ignored and the final check could still report success.
if len(torch_params) != len(tf_weights):
    print(f'WARNING: {len(torch_params)} torch parameters vs '
          f'{len(tf_weights)} TF tensors; unpaired entries are not compared')

matched_shapes = []
for (name, param), (tf_name, tf_param) in zip(torch_params, tf_weights.items()):

    if tf_param.ndim >= 2:
        # With no `axes` argument np.transpose reverses the axis order, which
        # is what the explicit (ndim-1, ..., 0) tuple did before.
        tf_param = np.transpose(tf_param)

    print(f'{name}:{param.shape}','<-->',f'{tf_name}:{tf_param.shape}')
    # Compare as plain tuples; no need to convert the parameter to numpy.
    matched_shapes.append(tuple(param.shape) == tuple(tf_param.shape))

    # TODO: once every shape matches, copy the values inside torch.no_grad(),
    # e.g. param.copy_(torch.tensor(tf_param, dtype=param.dtype)).
print('matched all shapes =',all(matched_shapes))

conv1.weight:torch.Size([1100, 20, 9]) <--> inferrer/conv1d/kernel:0:(1100, 20, 9)
conv1.bias:torch.Size([1100]) <--> inferrer/conv1d/bias:0:(1100,)
resnet_blocks.0.bn_activation_1.0.weight:torch.Size([1100]) <--> inferrer/residual_block_0/batch_normalization/gamma:0:(1100,)
resnet_blocks.0.bn_activation_1.0.bias:torch.Size([1100]) <--> inferrer/residual_block_0/batch_normalization/beta:0:(1100,)
resnet_blocks.0.masked_conv1.weight:torch.Size([550, 1100, 9]) <--> inferrer/residual_block_0/conv1d/kernel:0:(550, 1100, 9)
resnet_blocks.0.masked_conv1.bias:torch.Size([550]) <--> inferrer/residual_block_0/conv1d/bias:0:(550,)
resnet_blocks.0.bn_activation_2.0.weight:torch.Size([550]) <--> inferrer/residual_block_0/batch_normalization_1/gamma:0:(550,)
resnet_blocks.0.bn_activation_2.0.bias:torch.Size([550]) <--> inferrer/residual_block_0/batch_normalization_1/beta:0:(550,)
resnet_blocks.0.masked_conv2.weight:torch.Size([1100, 550, 1]) <--> inferrer/residual_block_0/conv1d_1/kernel:0:(1100, 5

In [117]:
# NOTE(review): `a` is not assigned in any cell visible in this notebook, so
# this cell presumably fails on a fresh kernel (Restart & Run All). Restore the
# cell that defines `a` or remove this one.
all(a)

True

In [111]:
# NOTE(review): depends on `a`, which no visible cell defines — presumably
# leftover kernel state from a deleted cell; confirm before relying on it.
float(a[0][0][0]),float(a[0][1][0])

(1110.0, 1100.0)

In [14]:
# Shape of a single TF tensor, looked up by its full variable name.
gamma_key = 'inferrer/residual_block_2/batch_normalization_1/gamma:0'
tf_weights[gamma_key].shape

(550,)

In [24]:
# Number of axes of the first convolution kernel stored in the TF weights.
conv_kernel_key = 'inferrer/conv1d/kernel:0'
tf_weights[conv_kernel_key].ndim

3

In [5]:
# Count the (name, parameter) pairs the model exposes.
sum(1 for _ in r.named_parameters())

44

In [16]:
# NOTE(review): the recorded output for this cell, torch.Size([250]), differs
# from the (550,) recorded for the TF tensor
# 'inferrer/residual_block_2/batch_normalization_1/gamma:0'; presumably this
# output came from a different model configuration — re-run to confirm.
bn_weight_name = 'resnet_blocks.2.bn_activation_2.0.weight'
r.get_parameter(bn_weight_name).shape

torch.Size([250])

In [21]:
import numpy as np

# The original cell had a dangling `import` (a SyntaxError) and called
# `np.permue`, which does not exist in NumPy. np.transpose with no `axes`
# argument reverses the axis order; a 1-D array comes back unchanged.
# `tf_param` is the loop variable left over from the weight-matching cell.
np.transpose(tf_param)

array([-0.49530053, -0.55358905, -0.41502154, ..., -0.07069803,
       -0.12090866, -0.12089201], dtype=float32)