In [13]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local pLMtrainer package) are picked up without
# restarting the kernel.
# NOTE(review): the recorded output shows "already loaded" because this cell
# was re-run; `%reload_ext autoreload` would avoid that message — confirm.
%load_ext autoreload
%autoreload 2

import sys
import tqdm
import torch
import numpy as np
import pandas as pd
import scanpy as sc
import torch.nn as nn
import pyarrow.parquet as pq
import seaborn as sns
import matplotlib.pyplot as plt

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from transformers import T5Tokenizer, T5EncoderModel
from pytorch_lightning import Trainer
from lightning.pytorch.tuner import Tuner

sys.path.append('..')
from pLMtrainer.dataloader import FrustrationDataset, FrustrationDataModule
from pLMtrainer.models import FrustrationCNN

# Global precision setting for float32 matrix multiplications in PyTorch.
# NOTE(review): 'medium' presumably trades some accuracy for speed — confirm
# against the PyTorch documentation for the target hardware.
torch.set_float32_matmul_precision('medium')
# Relative path to the gzip-compressed parquet file with the frustration data.
parquet_path = "../data/frustration/v3_frustration.parquet.gzip"
# Rows drawn per `cath_T_id` group when sub-sampling the dataset below.
cath_sampling_n = 100  # number of samples per CATH superfamily

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Load the parquet file via pyarrow and materialise it as a pandas DataFrame.
df = (
    pq.read_table(parquet_path)
    .to_pandas()
)
# Show the first two rows as a quick sanity check of the columns.
df.head(n=2)

Unnamed: 0,proteinID,full_seq,res_seq,res_idx,frst_idx,frst_class,set,cath_T_id
0,AF-A0A009EQP3-F1-model_v4_TED02,MKESLRLRLDQLSDRHEELTALLADVEVISDNKRFRQLSREHNDLT...,"[Y, L, E, I, R, A, G, T, G, G, D, E, A, A, I, ...","[111, 112, 113, 114, 115, 116, 117, 118, 119, ...","[0.235, 1.632, -0.844, 1.365, 0.282, -0.384, 0...","[9, 6, 11, 6, 9, 10, 9, 9, 9, 11, 12, 13, 9, 8...",train,3.30.70
1,AF-A0A009F754-F1-model_v4_TED02,MANPAQLVRHKLLNTFFSRHSVWFACITIAVIFTIFHIGYEPRYIY...,"[R, I, L, I, G, N, E, Q, C, T, Q, P, Y, S, A, ...","[164, 165, 166, 167, 168, 169, 170, 171, 172, ...","[-1.068, 1.469, -0.317, 1.098, -0.663, -0.771,...","[12, 6, 10, 7, 11, 11, 7, 10, 3, 9, 13, 11, 9,...",train,3.20.20


In [14]:
# Group rows by their `cath_T_id` value.
grouped = df.groupby("cath_T_id")

# Draw `cath_sampling_n` rows from every group. Sampling is done with
# replacement, so groups with fewer rows than requested still yield the full
# count; the fixed seed keeps the draw repeatable. The index is renumbered
# from zero afterwards.
df_sub = (
    grouped
    .sample(n=cath_sampling_n, replace=True, random_state=42)
    .reset_index(drop=True)
)

In [15]:
# Display the sub-sampled frame (pandas truncates the rendering to the first
# and last rows).
df_sub

Unnamed: 0,proteinID,full_seq,res_seq,res_idx,frst_idx,frst_class,set,cath_T_id
0,AF-A0A7R9RFN3-F1-model_v4_TED01,LALLKMIDKTDGSCEESPSVRVSRYPVQRKDTSSSISSVGTDDSYE...,"[E, L, V, Q, E, I, N, E, Q, V, E, S, Y, F, A, ...","[111, 112, 113, 114, 115, 116, 117, 118, 119, ...","[-0.962, 1.174, 0.953, -1.313, -1.273, 1.204, ...","[12, 7, 7, 12, 12, 7, 11, 13, 11, 7, 12, 11, 1...",train,1.10.10
1,AF-A0A0R3X2J2-F1-model_v4_TED01,MDTANSESICYDSSALSMDLFRRLTSYIQSQQHQQQQQQQEQLNPS...,"[T, L, H, K, P, P, Y, S, Y, I, A, L, I, A, M, ...","[104, 105, 106, 107, 108, 109, 110, 111, 112, ...","[0.218, 0.427, 1.397, -1.159, -0.17, 1.177, -0...","[9, 9, 6, 12, 10, 7, 10, 10, 10, 6, 8, 6, 6, 7...",train,1.10.10
2,AF-A0A536Z0C6-F1-model_v4_TED01,MERKLNSAKSRTRLSSVANSLRLIKAFSEDEYEIGISDLAKRLGLA...,"[S, V, A, N, S, L, R, L, I, K, A, F, S, E, D, ...","[15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 2...","[0.005, 0.643, 0.494, -1.044, -0.8, 1.36, -1.1...","[9, 8, 8, 12, 11, 6, 12, 7, 6, 13, 8, 9, 7, 12...",train,1.10.10
3,AF-A0A4Y8IBB4-F1-model_v4_TED02,MEKTWYLKRINIFQGMSEAEMKQLDGITHMKHYDKKSLIYLPGDVS...,"[V, H, E, R, I, A, L, L, L, L, R, L, A, D, R, ...","[148, 149, 150, 151, 152, 153, 154, 155, 156, ...","[1.205, -0.233, -0.518, -0.452, 1.321, 0.54, 1...","[7, 10, 11, 11, 7, 8, 7, 7, 6, 7, 11, 7, 8, 7,...",train,1.10.10
4,AF-A0A5N0UUB5-F1-model_v4_TED01,MGPAPKLDDQVCFALYAASRAVTALYRPLLDEMGLTYPQYLVMLVL...,"[Q, V, C, F, A, L, Y, A, A, S, R, A, V, T, A, ...","[9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20...","[-1.511, 1.074, 1.257, 0.662, 0.945, 1.322, 0....","[13, 7, 7, 8, 7, 7, 9, 8, 7, 11, 11, 8, 6, 11,...",train,1.10.10
...,...,...,...,...,...,...,...,...
56195,AF-A0A134CBI3-F1-model_v4_TED02,MSKFFDDDIIEQVKDANDIVSVISEHIPLKKKGKNYWGCCPFHNEK...,"[A, R, M, E, S, M, Y, E, V, N, E, L, A, T, N, ...","[110, 111, 112, 113, 114, 115, 116, 117, 118, ...","[-0.108, 1.898, 0.234, -0.471, -0.629, 1.057, ...","[10, 5, 9, 11, 11, 7, 9, 13, 7, 11, 11, 7, 9, ...",train,3.90.980
56196,AF-A0A2M8S4A9-F1-model_v4_TED02,MKGQIPRTFIDELLSKTDIVDVVNTRVKLKKAGRDYQACCPFHHEK...,"[Q, T, K, R, N, L, Y, E, L, M, Q, D, I, A, Q, ...","[113, 114, 115, 116, 117, 118, 119, 120, 121, ...","[0.682, -0.022, -0.939, 1.035, -0.63, 1.293, -...","[8, 10, 12, 7, 11, 7, 14, 8, 7, 9, 11, 13, 5, ...",train,3.90.980
56197,AF-A0A661D206-F1-model_v4_TED01,ELLDQVSNYYQYQINHHSQSHQVQQYVKQRGLSPEIIKAFGLGFSP...,"[L, L, D, Q, V, S, N, Y, Y, Q, Y, Q, I, N, H, ...","[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...","[1.261, 1.164, 2.057, -0.779, 1.129, -0.32, -0...","[7, 7, 5, 11, 7, 10, 10, 11, 12, 12, 11, 12, 5...",train,3.90.980
56198,AF-A0A7X9S5I5-F1-model_v4_TED02,MRIQEEVIEKIKEQNDIVDIISENVRLKKAGRNFTGLCPFHNDKSP...,"[K, K, R, D, L, L, F, K, V, N, V, E, A, A, R, ...","[106, 107, 108, 109, 110, 111, 112, 113, 114, ...","[-0.073, -0.28, -0.707, 0.286, 0.474, 1.245, 0...","[10, 10, 11, 9, 8, 7, 9, 11, 7, 12, 6, 11, 8, ...",train,3.90.980


In [9]:
# Count how many rows of the full dataset fall into each `cath_T_id` value,
# to see how unevenly the groups are populated.
df["cath_T_id"].value_counts()

cath_T_id
3.40.50      225209
3.20.20       52550
2.60.40       43757
2.60.120      25416
3.30.70       24062
              ...  
1.20.1370        30
3.40.1790        28
3.10.520         20
3.30.1010        20
3.40.420         17
Name: count, Length: 562, dtype: int64

In [5]:
# Group the full dataset by `cath_T_id` and echo the resulting groupby object.
# NOTE(review): `grouped` is already defined by an earlier cell with the same
# expression, and this cell's recorded execution count is lower than that
# cell's — looks like leftover exploration; confirm whether it is still needed.
grouped = df.groupby("cath_T_id")
grouped

<pandas.core.groupby.generic.DataFrameGroupBy object at 0x174b225a0>

In [None]:
# Exploratory draw of `cath_sampling_n` rows per `cath_T_id` group, with
# replacement. The notebook-level constant is used instead of a repeated
# literal so this cell stays in sync with the configuration cell. No seed is
# given, so every run of this cell returns a different sample.
grouped.sample(n=cath_sampling_n, replace=True).reset_index(drop=True)

Unnamed: 0,proteinID,full_seq,res_seq,res_idx,frst_idx,frst_class,set,cath_T_id
0,AF-A0A6A5VA62-F1-model_v4_TED01,MASDDASDPVPRYSGFSRFELELEFVQCLANPAYLNYLAQQKILDK...,"[F, S, R, F, E, L, E, L, E, F, V, Q, C, L, A, ...","[15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 2...","[0.434, -0.179, 0.764, 0.807, -0.775, 0.966, -...","[9, 10, 8, 8, 11, 7, 12, 7, 13, 8, 7, 11, 6, 6...",train,1.10.10
1,AF-A0A1C5A311-F1-model_v4_TED01,MTLGAAPSSSVARARAVIGLNTAYFRTKALQSAVELGVFDLLADGP...,"[V, A, R, A, R, A, V, I, G, L, N, T, A, Y, F, ...","[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 2...","[0.93, -0.429, -0.824, 0.775, -0.778, -0.196, ...","[7, 10, 11, 8, 11, 10, 7, 8, 11, 7, 13, 11, 8,...",train,1.10.10
2,AF-A0A4S8VZ13-F1-model_v4_TED03,DWERFNKIQTQTFNALYASDDNVFVGSSTGSGKTVCAEFALLRHWS...,"[Q, A, F, L, H, D, V, F, V, T, E, I, S, T, K, ...","[405, 406, 407, 408, 409, 410, 411, 412, 413, ...","[-1.017, -1.549, -0.089, 1.295, -0.215, -1.343...","[12, 13, 10, 7, 10, 13, 6, 7, 7, 9, 11, 6, 11,...",train,1.10.10
3,AF-A0A5E7FPM6-F1-model_v4_TED01,MLDAIRWDTDLIRRYDLVGPRYTSYPTAVQFNSQVGTFDLLHALRD...,"[T, D, L, I, R, R, Y, D, L, V, G, P, R, Y, D, ...","[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,...","[-0.423, 2.217, 1.205, 1.061, 1.631, 0.804, -0...","[10, 5, 7, 7, 6, 8, 11, 9, 10, 9, 11, 9, 7, 7,...",train,1.10.10
4,AF-A0A3Q2HIX3-F1-model_v4_TED01,MPRGHKSKLRACEKRQQARGETQGLQDAQATTAAEEESTPSPSPQP...,"[D, S, L, S, K, K, A, V, L, L, E, Q, Y, L, L, ...","[110, 111, 112, 113, 114, 115, 116, 117, 118, ...","[2.344, 0.111, 0.821, -0.272, -1.053, -0.963, ...","[4, 9, 8, 10, 12, 12, 8, 10, 6, 7, 11, 11, 9, ...",train,1.10.10
...,...,...,...,...,...,...,...,...
56195,AF-A0A3E0H9E9-F1-model_v4_TED02,MAAGRIPQSFIDEVLLRTDVVDLIDARVKLKRAGKNYSACCPFHQE...,"[Q, P, L, F, D, A, L, E, R, A, A, G, F, F, E, ...","[112, 113, 114, 115, 116, 117, 118, 119, 120, ...","[-1.087, -1.365, 1.397, 0.323, 0.112, 0.802, 1...","[12, 13, 6, 9, 9, 8, 6, 13, 11, 8, 9, 12, 7, 7...",train,3.90.980
56196,AF-A0A1Z3HPX8-F1-model_v4_TED02,MTERAIVRHLRLAMTTPRLHPDTIEAVRERADIVDIVSQHVVLKKQ...,"[A, Q, R, Q, Q, L, Q, R, Q, L, S, L, R, Q, Q, ...","[116, 117, 118, 119, 120, 121, 122, 123, 124, ...","[-0.662, -0.492, -0.106, -0.498, 0.435, 0.425,...","[11, 11, 10, 11, 9, 9, 10, 11, 10, 8, 9, 7, 10...",train,3.90.980
56197,AF-A0A661KJF8-F1-model_v4_TED02,MGGLIPEAVLEEIRQRADIVEVISDYVTLKKAGRNYKGLCPFHQEK...,"[K, D, N, R, K, E, N, L, F, K, I, N, T, L, A, ...","[104, 105, 106, 107, 108, 109, 110, 111, 112, ...","[-0.429, 1.691, 0.778, 1.75, -0.006, -0.692, 0...","[10, 6, 8, 6, 10, 11, 8, 7, 11, 10, 6, 11, 10,...",train,3.90.980
56198,AF-A0A3S0LMH6-F1-model_v4_TED02,MAIPRDFINELVARIDIVDLIDAKVPLKKAGKNHSACCPFHSEKSP...,"[G, L, S, R, D, L, Y, Q, L, M, E, E, A, S, L, ...","[107, 108, 109, 110, 111, 112, 113, 114, 115, ...","[0.117, 1.027, -0.517, -0.446, -1.199, 1.476, ...","[9, 7, 11, 11, 12, 6, 10, 10, 6, 7, 11, 12, 8,...",train,3.90.980


In [None]:
# Build the data module straight from the parquet file (no pre-loaded frame is
# passed, hence df=None). The per-group sample count comes from the
# notebook-level `cath_sampling_n` constant rather than a repeated literal, so
# changing the configuration cell also changes what the data module receives.
data_module = FrustrationDataModule(df=None,
                                    parquet_path=parquet_path,
                                    batch_size=10,
                                    max_seq_length=512,
                                    num_workers=1,
                                    persistent_workers=True,
                                    sample_size=None,
                                    cath_sampling_n=cath_sampling_n)

In [3]:
# Early stopping: watch the logged "val_loss" metric (lower is better) with a
# patience of 5, and report stop decisions verbosely.
early_stop = EarlyStopping(monitor="val_loss",
                           patience=5,
                           mode='min',
                           verbose=True)

# Checkpointing: keep only the single best checkpoint by "val_loss", saving
# model weights only, under ./checkpoints with the file stem "debug".
# (The stem is a plain string literal; the previous f-string had no
# placeholders.)
checkpoint = ModelCheckpoint(monitor="val_loss",
                             dirpath="./checkpoints",
                             filename="debug",
                             save_top_k=1,
                             mode='min',
                             save_weights_only=True)

# CSV metrics logger writing below ./checkpoints in a "debug_logs" sub-folder.
logger = CSVLogger("./checkpoints", name="debug_logs")

In [None]:
# Instantiate the model that is later passed to `trainer.fit`.
# NOTE(review): `FrustrationBaseline` is not imported anywhere in this
# notebook — the import cell only brings in `FrustrationCNN` from
# pLMtrainer.models — so this cell would raise NameError on a fresh kernel.
# Confirm which class is intended and import it in the first cell.
# NOTE(review): max_seq_length=512 matches the value given to the data module;
# presumably the two must agree — verify against the model implementation.
model = FrustrationBaseline(input_dim=1024, 
                            output_dim=4,
                            max_seq_length=512,
                            precision="half",
                            pLM_model="../data/ProtT5", 
                            prefix_prostT5="<AA2fold>",
                            no_label_token=-100)

Using half precision


In [None]:
# Trainer settings for a short local debugging run, collected in one mapping
# and unpacked into the constructor below.
trainer_kwargs = dict(
    accelerator='mps',  # local run; the original note says "gpu" as the alternative
    devices=-1,  # original note: 4 for one node on haicore
    # strategy='ddp',
    max_epochs=2,
    logger=logger,
    log_every_n_steps=10,  # original note: 50 for haicore default
    callbacks=[early_stop, checkpoint],
    precision="16-mixed",
    gradient_clip_val=1,
    enable_progress_bar=True,
    deterministic=False,  # original note: "for reproducibility disable on cluster"
    # num_sanity_val_steps=0,
    # accumulate_grad_batches=2,  # if batch size gets too small --> test on H100/A100
)

trainer = Trainer(**trainer_kwargs)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [9]:
# Wrap the trainer in a Tuner; no tuner method is called in the visible cells.
# NOTE(review): `Tuner` is imported from `lightning.pytorch` while `Trainer`
# comes from `pytorch_lightning`; mixing the two package namespaces may not be
# supported — consider importing Tuner from `pytorch_lightning.tuner`; confirm.
tuner = Tuner(trainer)

In [None]:
# Start training: the trainer runs `model` on the loaders provided by
# `data_module`, using the callbacks and logger configured on the trainer.
trainer.fit(model, datamodule=data_module)