In [1]:
%load_ext autoreload
%autoreload 2

import sys
import tqdm
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import pyarrow.parquet as pq
import seaborn as sns
import matplotlib.pyplot as plt

from transformers import T5Tokenizer, T5EncoderModel
from pytorch_lightning import Trainer, seed_everything

# Make the repo-level `pLMtrainer` package importable
# (assumes the kernel's working directory is one level below the repo root).
sys.path.append('..')
from pLMtrainer.dataloader import FrustrationDataset, FrustrationDataModule
from pLMtrainer.models import FrustrationFNN

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Input data and reproducibility settings shared by the cells below.
parquet_path = "../data/frustration/v3_frustration.parquet.gzip"
SEED = 42
# Seed Python, NumPy and torch RNGs (and dataloader workers) before the model and
# datamodule are created. Trainer(deterministic=True) only selects deterministic
# algorithms; it does not fix weight init or any random split/shuffle.
seed_everything(SEED, workers=True);

In [3]:
# Peek at the first rows of the parquet file without materialising the whole table
# (the datamodule reports ~1M samples when it loads it). Only the first 2-row record
# batch is decoded into pandas, instead of `pq.read_table(...).to_pandas()`.
df_preview = next(pq.ParquetFile(parquet_path).iter_batches(batch_size=2)).to_pandas()
df_preview

In [4]:
# Lightning datamodule over the frustration parquet file.
# NOTE(review): num_workers / persistent_workers are presumably forwarded to the
# underlying torch DataLoaders — confirm in pLMtrainer.dataloader.
data_module = FrustrationDataModule(
    parquet_path=parquet_path,
    batch_size=64,
    num_workers=1,
    persistent_workers=True,
)

In [5]:
# Feed-forward frustration head on top of the ProstT5 protein language model.
model = FrustrationFNN(
    input_dim=1024,  # presumably the pLM embedding width — TODO confirm against ProstT5
    hidden_dim=32,
    output_dim=1,
    dropout=0.15,
    max_seq_length=700,
    # Same directory the tokenizer cell loads from successfully. It was previously
    # spelled "../data/ProstT5", which only matches on case-insensitive filesystems.
    pLM_model="../data/prostT5",
    pLM_precision="half",
    prefix_prostT5="<AA2fold>",
)

In [6]:
# Trade some float32 matmul accuracy for speed on supporting hardware
# (see the torch docs for exactly what "medium" permits).
torch.set_float32_matmul_precision(precision="medium")

In [None]:
# Train the model for a fixed number of epochs.
# `deterministic=True` makes torch pick deterministic algorithms, but it does not
# seed any RNG. Call `seed_everything(...)` before building the model for repeatable runs.
# Optional knobs: logger=..., callbacks=..., num_sanity_val_steps=0, and for
# multi-GPU use strategy="ddp" (the old `distributed_backend=` argument was removed
# from Lightning and would now raise a TypeError).
trainer = Trainer(
    accelerator="auto",
    max_epochs=5,
    gradient_clip_val=1,
    enable_progress_bar=True,
    deterministic=True,
)

trainer.fit(model, datamodule=data_module)

In [6]:
# Fetch a single validation batch to inspect what the model receives.
# NOTE(review): setup() is called without a stage; assumes FrustrationDataModule.setup
# builds the validation split by default — confirm in pLMtrainer.dataloader.
data_module.setup()
# next(iter(...)) takes only the first batch. Unlike for/break, it raises StopIteration
# on an empty loader instead of silently leaving the names below undefined.
full_seq, res_mask, frst_vals = next(iter(data_module.val_dataloader()))

Loaded 982852 samples from ../data/frustration/v3_frustration.parquet.gzip
Train/Val/Test split: 896034/29926/56892 samples


In [7]:
# Show the raw sequence(s) of the first validation batch (a tuple of strings).
full_seq

('MIDQIKRHGLFDIDIHCDGDLEIDDHHTVEDCGITLGQAFAQALGDKKGLRRYGHFYAPLDEALSRVVVDLSGRPGLFMDIPFTRARIGTFDVDLFSEFFQGFVNHALMTLHIDNLKGKNSHHQIESVFKALARALRMACEIDPRAENTIASTKGSL',)

In [20]:
# Toy one-element batch, shaped like `full_seq` (a tuple of sequence strings),
# for experimenting with the tokenizer.
eg_seq = tuple(["SEQVE"])

In [41]:
# Build the tokenizer input for each sequence: the "<AA2fold>" prefix (the same one
# passed to the model as prefix_prostT5), then the residues separated by single spaces.
seq = [f"<AA2fold> {' '.join(residues)}" for residues in eg_seq]
seq

['<AA2fold> S E Q V E']

In [42]:
# Tokenize the formatted sequence(s), padding/truncating every row to exactly
# `max_length` tokens. max_length only matters for the encode call: passing it to
# from_pretrained (as before) does not change the padding length.
max_length = 10
tokenizer = T5Tokenizer.from_pretrained("../data/prostT5", do_lower_case=False)
# Calling the tokenizer directly is the supported replacement for the deprecated
# batch_encode_plus; a list input yields the same batched BatchEncoding.
ids = tokenizer(
    seq,
    add_special_tokens=True,
    padding="max_length",
    truncation="longest_first",
    max_length=max_length,
    return_tensors="pt",
)

In [43]:
# Token ids of the encoded batch: prefix + 5 residues, one appended special token,
# then padding up to length 10 (see the output below).
ids.input_ids

tensor([[149,   7,   9,  16,   6,   9,   1,   0,   0,   0]])