In [1]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%config InlineBackend.figure_format = 'retina'

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# mito_df = pd.read_csv('./data/mitochondria_targeting.csv')
# mito_df = mito_df[['Sequence','Mitochondrial Targeting Signal']].drop_duplicates(keep='first')
# cdc28_df = pd.read_csv('./data/cdc28_binding.csv')
# cdc28_df = cdc28_df[['Sequence','Cdc28 Binding']].drop_duplicates(keep='first')
# # seem to be same sequences.
# df = mito_df.merge(cdc28_df,on='Sequence',how='inner')
# # many of these are way longer!
# df = df.loc[df['Sequence'].map(len)<=100] # just b/c this is the longest positional encoding in xformer
# df

In [5]:
with open('./data/train_set.fasta','r') as f:
    fasta = [l.strip() for l in f]
    df = pd.Series(fasta[::3]).str.extract('>(?P<uniprot>[^\|]+)\|(?P<kingdom>[^|]+)\|(?P<type>[^|]+)\|(?P<partition>[^|]+)')
    df['sequence'] = fasta[1::3]
    df['annotation'] = fasta[2::3]
df = df.sample(frac=1.,random_state=0)
df

Unnamed: 0,uniprot,kingdom,type,partition,sequence,annotation
18174,P62320,EUKARYA,NO_SP,0,MSIGVPIKVLHEAEGHIVTCETNTGEVYRGKLIEAEDNMNCQMSNI...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
13361,O94622,EUKARYA,NO_SP,2,MRWYSYVIPAVILSIIAISGVWWNATLGTRLDQKVQLFLNEHSSIL...,IIMMMMMMMMMMMMMMMMMMMMMOOOOOOOOOOOOOOOOOOOOOOO...
13635,Q9V2T0,ARCHAEA,SP,1,MSKKKFVIVSILTILLVQAIYFVEKYHTSEDKSTSNTSSTPPQTTL...,SSSSSSSSSSSSSSSSSSSOOOOOOOOOOOOOOOOOOOOOOOOOOO...
771,Q6ZPV2,EUKARYA,NO_SP,2,MASELGAGDDGSSTELAKPLYLQYLERALRLDHFLRQTSAIFNRNI...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
17289,P40066,EUKARYA,NO_SP,0,MSFFNRSNTTSALGTSTAMANEKDLANDIVINSPAEDSISDIAFSP...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
...,...,...,...,...,...,...
13123,Q6UVW9,EUKARYA,NO_SP,2,MINPELRDGRADGFIHRIVPKLIQNWKIGLMCFLSIIITTVCIIMI...,IIIIIIIIIIIIIIIIIIIIIIIIIIIMMMMMMMMMMMMMMMMMMM...
19648,P31448,NEGATIVE,NO_SP,2,MNSLQILSFVGFTLLVAVITWWKVRKTDTGSQQGYFLAGRSLKAPV...,OOOOOOOOOMMMMMMMMMMMMMMMMMMMMMMMMMMMMIIIIIIIII...
9845,Q9P3W1,EUKARYA,NO_SP,1,MSNAPEIVQRLIKMIMRAFYETRHIIFMDAILRHSALTDEQTALLM...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
10799,Q80VJ6,EUKARYA,NO_SP,2,MASQQAPAKDLQTNNLEFTPTDSSGVQWAEDISNSPSAQLNFSPSN...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...


In [6]:
df['type'].value_counts()

NO_SP      15625
SP          2582
LIPO        1615
TAT          365
PILIN         70
TATLIPO       33
Name: type, dtype: int64

In [7]:
from src.torch_helpers import NamedTensorDataset
from src.datamodule import PeptideDataModule
from src.constants import MSConstants

C = MSConstants()

# dataset = NamedTensorDataset(
#     sequence=df['Sequence'],
#     x=df['Sequence'].map(lambda s: [C.alphabet.index(c) for c in s]),
#     x_mask=df['Sequence'].map(lambda s: [1]*len(s)),
#     y=df[['Mitochondrial Targeting Signal','Cdc28 Binding']].astype(np.int32).values
# )

dataset = NamedTensorDataset(
    sequence=df['sequence'],
    x=df['sequence'].map(lambda s: [C.alphabet.index(c) for c in s]),
    x_mask=df['sequence'].map(lambda s: [1]*len(s)),
    y=(df['type']!='NO_SP').values[:,None].astype(int)
)

dm = PeptideDataModule(
    dataset,
    batch_size=256,
    #val_batch_size=-1,
    train_val_split=0.8,
    cdhit_threshold=0.5,
    cdhit_word_length=3,
    num_workers=4
)
dm.setup()

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
import torch
from src.baselines import CNNModel, CARPModel, MSModel

torch.manual_seed(0)

# model = CNNModel(
#     output_dim = len(dm.dataset[0]['y']),
#     model_dim = 128,
#     model_depth = 3,
#     kernel_size = 3,
#     num_residues = len(C.alphabet),
#     dropout = 0.1,
#     lr = 5e-4,
#     output_weights = [1,]
# )

# model = CNNModel(
#     output_dim = len(dm.dataset[0]['y']),
#     model_dim = 64,
#     model_depth = 3,
#     kernel_size = 3,
#     num_residues = len(C.alphabet),
#     dropout = 0.1,
#     lr = 5e-4,
#     output_weights = [0.5,0.5]
# )

# gotta go back to the old transformer, again
[last_ckpt] = !ls -t1 ./old_logs/large/checkpoints/*.ckpt | head -n1
model = MSModel(
    checkpoint = last_ckpt,
    model_dim = 512,
    output_dim = len(dm.dataset[0]['y']),,
    fixed_weights = False,
    lr = 5e-4,
    output_weights = [1,],
)

# model = CARPModel(
#     output_dim = len(dm.dataset[0]['y']),
#     fixed_weights = True,
#     max_length = 100,
#     lr = 5e-4,
#     output_weights = [1],
# )

In [14]:
!rm -rf ./lightning_logs/version_$SLURM_JOBID

In [15]:
from pytorch_lightning import Trainer
from src.torch_helpers import NoValProgressBar

trainer = Trainer(
    gpus=1,
    precision=32,
    max_epochs=100,
    callbacks=[NoValProgressBar()]
)

trainer.fit(model, dm)

Multiprocessing is handled by SLURM.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-34c7fac6-49db-54e7-42d2-e2206a453ab6]

  | Name       | Type       | Params
------------------------------------------
0 | embedding  | Embedding  | 3.1 K 
1 | encoder    | Sequential | 31.0 K
2 | classifier | Linear     | 33    
------------------------------------------
34.1 K    Trainable params
0         Non-trainable params
34.1 K    Total params
0.136     Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 99: 100%|██████████| 80/80 [02:25<00:00,  1.82s/it, loss=0.16, v_num=1.67e+7]   


In [None]:
from sklearn.metrics import confusion_matrix

model = model.cpu()
model.eval()

for batch in dm.val_dataloader():
    y_pred = model.predict_step(batch, 0).detach().cpu().numpy()
    y = batch['y'].cpu().numpy()

for k in range(y.shape[1]):
    plt.figure(figsize=(4,4))
    sns.heatmap(
        confusion_matrix(y[:,k], y_pred[:,k]>0.5),
        annot=True, fmt='d', cmap='Blues'
    )