In [1]:
%load_ext autoreload
%autoreload 2

import sys
import tqdm
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import pyarrow.parquet as pq
import seaborn as sns
import matplotlib.pyplot as plt

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from transformers import T5Tokenizer, T5EncoderModel
from pytorch_lightning import Trainer

# Put the parent directory on the import path so the local `pLMtrainer`
# package imported below can be resolved.
# NOTE(review): presumably '..' is the repository root relative to this notebook — confirm.
sys.path.append('..')
from pLMtrainer.dataloader import FrustrationDataset, FrustrationDataModule
from pLMtrainer.models import FrustrationFNN

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Relative path to the frustration parquet file; it is read directly below
# and also handed to the data module.
parquet_path = "../data/frustration/v3_frustration.parquet.gzip"

In [3]:
# Read the parquet file with pyarrow, convert it to a pandas DataFrame,
# and preview the first two rows.
df = (
    pq.read_table(parquet_path)
    .to_pandas()
)
df.head(2)

Unnamed: 0,proteinID,full_seq,res_seq,res_idx,frst_idx,frst_class,set,cath_T_id
0,AF-A0A009EQP3-F1-model_v4_TED02,MKESLRLRLDQLSDRHEELTALLADVEVISDNKRFRQLSREHNDLT...,"[Y, L, E, I, R, A, G, T, G, G, D, E, A, A, I, ...","[111, 112, 113, 114, 115, 116, 117, 118, 119, ...","[0.235, 1.632, -0.844, 1.365, 0.282, -0.384, 0...","[9, 6, 11, 6, 9, 10, 9, 9, 9, 11, 12, 13, 9, 8...",train,3.30.70
1,AF-A0A009F754-F1-model_v4_TED02,MANPAQLVRHKLLNTFFSRHSVWFACITIAVIFTIFHIGYEPRYIY...,"[R, I, L, I, G, N, E, Q, C, T, Q, P, Y, S, A, ...","[164, 165, 166, 167, 168, 169, 170, 171, 172, ...","[-1.068, 1.469, -0.317, 1.098, -0.663, -0.771,...","[12, 6, 10, 7, 11, 11, 7, 10, 3, 9, 13, 11, 9,...",train,3.20.20


In [3]:
# Data module for the classification setup (regression disabled), reading
# from the same parquet file as above.
data_module = FrustrationDataModule(
    parquet_path=parquet_path,
    regression=False,
    batch_size=10,
    max_seq_length=100,
    num_workers=1,
    persistent_workers=True,
)

In [4]:
# Hyperparameters for the classification model, collected in one place and
# expanded as keyword arguments. max_seq_length matches the value given to
# the data module.
model_config = dict(
    input_dim=1024,
    hidden_dim=32,
    output_dim=20,
    dropout=0.15,
    max_seq_length=100,
    regression=False,
    pLM_model="../data/ProstT5",
    pLM_precision="half",
    prefix_prostT5="<AA2fold>",
)
model = FrustrationFNN(**model_config)

Using half precision for the pLM encoder


In [5]:
# Set PyTorch's float32 matmul precision mode to 'medium' before training starts.
torch.set_float32_matmul_precision('medium')

In [None]:
# Callbacks and logger for the debug run. Both callbacks monitor "val_loss"
# in 'min' mode; checkpoints and CSV logs share the same output directory.
checkpoint_dir = "./checkpoints"

# NOTE(review): the Trainer cell below sets max_epochs=5, so with patience=5
# early stopping is unlikely to trigger in this debug run — confirm intended.
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=5,
    mode='min',
    verbose=True,
)
checkpoint = ModelCheckpoint(
    monitor="val_loss",
    dirpath=checkpoint_dir,
    filename="debug",  # plain literal; the former f-string had no placeholders
    save_top_k=1,
    mode='min',
    save_weights_only=True,
)
logger = CSVLogger(checkpoint_dir, name="debug_logs")

In [7]:
# Trainer settings for the local debug run, gathered in a dict and expanded
# as keyword arguments.
trainer_kwargs = dict(
    accelerator='auto',  # gpu
    devices=-1,  # 4 for one node on haicore
    max_epochs=5,
    logger=logger,
    log_every_n_steps=10,  # 50 for haicore default
    callbacks=[early_stop, checkpoint],
    precision="16-mixed",
    gradient_clip_val=1,
    enable_progress_bar=True,
    deterministic=False,  # for reproducibility disable on cluster
)
# Options left disabled for now:
#   strategy='ddp'
#   num_sanity_val_steps=0
#   accumulate_grad_batches=2  # if batch size gets too small --> test on H100/A100
trainer = Trainer(**trainer_kwargs)

trainer.fit(model, datamodule=data_module)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Loaded 1000 samples from ../data/frustration/v3_frustration.parquet.gzip
Created train/val/test masks
Initialized res_idx_mask and frst_vals tensors
Populated res_idx_mask and frst_vals tensors
Created train dataset
Created val dataset
Created test dataset
Train/Val/Test split: 899/35/66 samples


/Users/janleusch/anaconda3/envs/biotrainer/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:751: Checkpoint directory /Users/janleusch/Documents/phd/pLMtrainer/pLMtrainer/notebooks/checkpoints exists and is not empty.
/Users/janleusch/anaconda3/envs/biotrainer/lib/python3.12/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | encoder | T5EncoderModel   | 1.2 B  | eval 
1 | loss_fn | CrossEntropyLoss | 0      | train
2 | FNN     | Sequential       | 33.5 K | train
-----------------------------------------------------
1.2 B     Trainable params
0         Non-trainable params
1.2 B     Total params
4,832.791 Total estimated model params size (MB)
6         Modules in train mode
439       Modules in eva

Configuring optimizers
Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/janleusch/anaconda3/envs/biotrainer/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]



Validation loss: 2.9851627349853516
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:01<00:01,  0.83it/s]Validation loss: 2.9802629947662354
Validation loss: 2.9802629947662354                                        
                                                                           

/Users/janleusch/anaconda3/envs/biotrainer/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 90/90 [01:28<00:00,  1.02it/s, v_num=6, train_loss_step=1.890]Validation loss: 1.8892680406570435
Validation loss: 1.8892680406570435
Validation loss: 1.9481663703918457
Validation loss: 1.9481663703918457
Validation loss: 1.902886152267456
Validation batch with no valid residues - skipping
Validation loss: 1.902886152267456
Validation batch with no valid residues - skipping
Epoch 0: 100%|██████████| 90/90 [01:31<00:00,  0.98it/s, v_num=6, train_loss_step=1.890, val_loss=1.910, train_loss_epoch=2.310]

Metric val_loss improved. New best score: 1.913


Epoch 1: 100%|██████████| 90/90 [01:35<00:00,  0.94it/s, v_num=6, train_loss_step=1.810, val_loss=1.910, train_loss_epoch=2.310]Validation loss: 1.753843069076538
Validation loss: 1.753843069076538
Validation loss: 1.8118500709533691
Validation loss: 1.8118500709533691
Validation loss: 1.7686442136764526
Validation batch with no valid residues - skipping
Epoch 1: 100%|██████████| 90/90 [01:38<00:00,  0.91it/s, v_num=6, train_loss_step=1.810, val_loss=1.780, train_loss_epoch=1.850]Validation loss: 1.7686442136764526
Validation batch with no valid residues - skipping
Epoch 1: 100%|██████████| 90/90 [01:38<00:00,  0.91it/s, v_num=6, train_loss_step=1.810, val_loss=1.780, train_loss_epoch=1.850]

Metric val_loss improved by 0.135 >= min_delta = 0.0. New best score: 1.778


Epoch 2: 100%|██████████| 90/90 [01:46<00:00,  0.85it/s, v_num=6, train_loss_step=1.740, val_loss=1.780, train_loss_epoch=1.850]Validation loss: 1.7067312002182007
Validation loss: 1.7067312002182007
Validation loss: 1.7602545022964478
Validation loss: 1.7602545022964478
Validation loss: 1.7078973054885864
Validation batch with no valid residues - skipping
Validation loss: 1.7078973054885864
Validation batch with no valid residues - skipping
Epoch 2: 100%|██████████| 90/90 [01:50<00:00,  0.82it/s, v_num=6, train_loss_step=1.740, val_loss=1.720, train_loss_epoch=1.760]

Metric val_loss improved by 0.053 >= min_delta = 0.0. New best score: 1.725


Epoch 3: 100%|██████████| 90/90 [01:48<00:00,  0.83it/s, v_num=6, train_loss_step=1.730, val_loss=1.720, train_loss_epoch=1.760]Validation loss: 1.6805704832077026
Validation loss: 1.6805704832077026
Validation loss: 1.7211381196975708
Validation loss: 1.7211381196975708
Validation loss: 1.675698161125183
Validation batch with no valid residues - skipping

Validation loss: 1.675698161125183
Epoch 3: 100%|██████████| 90/90 [01:52<00:00,  0.80it/s, v_num=6, train_loss_step=1.730, val_loss=1.690, train_loss_epoch=1.710]

Metric val_loss improved by 0.032 >= min_delta = 0.0. New best score: 1.692


Epoch 4: 100%|██████████| 90/90 [01:50<00:00,  0.81it/s, v_num=6, train_loss_step=1.790, val_loss=1.690, train_loss_epoch=1.710]Validation loss: 1.6630829572677612
Validation loss: 1.6630829572677612
Validation loss: 1.6932507753372192
Validation loss: 1.6932507753372192
Validation loss: 1.659364938735962
Validation batch with no valid residues - skipping
Validation loss: 1.659364938735962
Validation batch with no valid residues - skipping
Epoch 4: 100%|██████████| 90/90 [01:54<00:00,  0.78it/s, v_num=6, train_loss_step=1.790, val_loss=1.670, train_loss_epoch=1.670]

Metric val_loss improved by 0.021 >= min_delta = 0.0. New best score: 1.672
`Trainer.fit` stopped: `max_epochs=5` reached.
`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 90/90 [01:56<00:00,  0.77it/s, v_num=6, train_loss_step=1.790, val_loss=1.670, train_loss_epoch=1.670]



In [6]:
# Pull a single batch from the validation dataloader for manual inspection.
data_module.setup()
# next(iter(...)) replaces the loop-with-break; an empty loader now fails
# loudly here instead of leaving the names below undefined.
batch = next(iter(data_module.val_dataloader()), None)
if batch is None:
    raise RuntimeError("Validation dataloader yielded no batches")
full_seq, res_mask, frst_vals = batch

Loaded 982852 samples from ../data/frustration/v3_frustration.parquet.gzip
Train/Val/Test split: 896034/29926/56892 samples


In [7]:
# Display the sequence entry of the sampled validation batch.
full_seq

('MIDQIKRHGLFDIDIHCDGDLEIDDHHTVEDCGITLGQAFAQALGDKKGLRRYGHFYAPLDEALSRVVVDLSGRPGLFMDIPFTRARIGTFDVDLFSEFFQGFVNHALMTLHIDNLKGKNSHHQIESVFKALARALRMACEIDPRAENTIASTKGSL',)

In [20]:
# Short toy sequence wrapped in a 1-tuple, used below to check the
# prefixing and tokenization by hand.
eg_seq = ("SEQVE",)

In [41]:
# Prepend the "<AA2fold>" prefix (the same string passed as prefix_prostT5
# to the model) and separate the residues of each example with spaces.
seq = [
    "<AA2fold>" + " " + " ".join(residues)
    for residues in eg_seq
]
seq

['<AA2fold> S E Q V E']

In [42]:
# Tokenize the prefixed example sequences, padded/truncated to 10 tokens.
# The directory casing now matches the pLM_model path used for the model
# ("../data/ProstT5"); the former "../data/prostT5" spelling only resolves on
# case-insensitive file systems.
# NOTE(review): `max_length` given to from_pretrained is kept from the
# original; it may be ignored by the tokenizer constructor — confirm.
tokenizer = T5Tokenizer.from_pretrained("../data/ProstT5", do_lower_case=False, max_length=10)
# Calling the tokenizer directly replaces the deprecated batch_encode_plus.
ids = tokenizer(
    seq,
    add_special_tokens=True,
    padding="max_length",
    truncation="longest_first",
    max_length=10,
    return_tensors='pt',
)

In [43]:
# Show the encoded token ids for the example sequence.
ids["input_ids"]

tensor([[149,   7,   9,  16,   6,   9,   1,   0,   0,   0]])