In [125]:
import torch
from torch import nn
import torch.utils.data
import torch.utils.data.distributed
from torch.utils.data import Dataset, DataLoader, RandomSampler, TensorDataset
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torchmetrics
#model_def.py
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup

## Data Creation

In [9]:
blind_path = "../data/blind.fasta.txt"
cyto_path = "../data/cyto.fasta.txt"
mito_path = "../data/mito.fasta.txt"
nucleus_path = "../data/nucleus.fasta.txt"
other_path = "../data/other.fasta.txt"
secreted_path = "../data/secreted.fasta.txt"

In [10]:
def read_fasta(file):
    """
    This function takes an unstructured fasta file and outputs a dictionary of the sequences
    Input: - fasta file
    Output: - dict with keys (sequence header) and values (sequence)
    """
    sequences = {}
    with open(file, 'r') as f:
        header = ""
        sequence = ""
        for line in f:
            #in a fasta file the first character is a > sign
            if line[0] == ">":
                if header:
                    sequences[header] = sequence
                header = line[1:].strip()
                sequence = ""
            else:
                sequence += line.strip()
        sequences[header] = sequence
    return sequences

In [11]:
# This creates the dictionary of sequences for each location
blind_sequences = read_fasta(blind_path)
cyto_sequences = read_fasta(cyto_path)
mito_sequences = read_fasta(mito_path)
nucleus_sequences = read_fasta(nucleus_path)
other_sequences = read_fasta(other_path)
secreted_sequences = read_fasta(secreted_path)

df_cyto = pd.DataFrame.from_dict(cyto_sequences, orient='index', columns=['Sequences'])
df_cyto = df_cyto.reset_index().rename(columns={'index':'Label'})
df_cyto['Label'] = 'cyto'

df_mito = pd.DataFrame.from_dict(mito_sequences, orient='index', columns=['Sequences'])
df_mito = df_mito.reset_index().rename(columns={'index':'Label'})
df_mito['Label'] = 'mito'

df_nucleus = pd.DataFrame.from_dict(nucleus_sequences, orient='index', columns=['Sequences'])
df_nucleus = df_nucleus.reset_index().rename(columns={'index':'Label'})
df_nucleus['Label'] = 'nucleus'

df_other = pd.DataFrame.from_dict(other_sequences, orient='index', columns=['Sequences'])
df_other = df_other.reset_index().rename(columns={'index':'Label'})
df_other['Label'] = 'other'

df_secreted = pd.DataFrame.from_dict(secreted_sequences, orient='index', columns=['Sequences'])
df_secreted = df_secreted.reset_index().rename(columns={'index':'Label'})
df_secreted['Label'] = 'secreted'

df = pd.concat([df_cyto, df_mito, df_nucleus, df_other, df_secreted], axis=0).reset_index()
# Display the DataFrame
#df['encoded_cat'] = df['Label'].astype('category').cat.codes
#df.drop(columns={'index', 'Label'}, inplace=True)

#result = df.to_dict('records')

one_hot = pd.get_dummies(df['Label'])
df = pd.concat([df, one_hot], axis=1)
df.drop(columns={'index', 'Label'}, inplace=True)
df

Unnamed: 0,Sequences,cyto,mito,nucleus,other,secreted
0,MGQQVGRVGEAPGLQQPQPRGIRGSSAARPSGRRRDPAGRTADAGF...,1,0,0,0,0
1,MALEPIDYTTHSREIDAEYLKIVRGSDPDTTWLIISPNAKKEYEPE...,1,0,0,0,0
2,MNQIEPGVQYNYVYDEDEYMIQEEEWDRDLLLDPAWEKQQRKTFTA...,1,0,0,0,0
3,MSEEPTPVSGNDKQLLNKAWEITQKKTFTAWCNSHLRKLGSSIEQI...,1,0,0,0,0
4,MGDWMTVTDPGLSSESKTISQYTSETKMSPSSLYSQQVLCSSIPLS...,1,0,0,0,0
...,...,...,...,...,...,...
11219,MIPNITQLKTAALVMLFAGQALSGPVESRQASESIDAKFKAHGKKY...,0,0,0,0,1
11220,MLRKLVTGALAAALLLSGQSNAQNACQQTQQLSGGRTINNKNETGN...,0,0,0,0,1
11221,MIFHQFYSILILCLIFPNQVVQSDKERQDWIPSDYGGYMNPAGRSD...,0,0,0,0,1
11222,MKFQVVLSALLACSSAVVASPIENLFKYRAVKASHSKNINSTLPAW...,0,0,0,0,1


In [12]:
target_list = ['cyto', 'mito', 'nucleus','other', 'secreted']

In [48]:
def train_test_split(df, validation_set = False):  
    
    #train test split of dataset 
    train_size = 0.8
    train_df=df.sample(frac=train_size,random_state=200)
    test_df=df.drop(train_df.index).reset_index(drop=True)
    train_df = train_df.reset_index(drop=True)
    
    if validation_set:
        #split the train further with a validation dataset
        train_size = 0.8
        train2_df = train_df.sample(frac=train_size, random_state=200)
        val_df = train_df.drop(train2_df.index).reset_index(drop=True)
        train2_df = train2_df.reset_index(drop=True)
    
        return train2_df, test_df, val_df
   
    return train_df, test_df

train2_df, test_df, val_df = train_test_split(df, validation_set = True)

## Model

In [93]:
PRE_TRAINED_MODEL_NAME = 'Rostlab/prot_bert_bfd_localization'
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME, do_lower_case=False)

In [136]:
class ProteinSequenceDataset(Dataset):
    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        
        single_row = self.data.iloc[item]
        sequence = single_row['Sequences']
        target = single_row[target_list]
        #target[['cyto', 'mito', 'nucleus','other', 'secreted']] = target[['cyto', 'mito', 'nucleus','other', 'secreted']].astype(int)

        
        encoding = self.tokenizer.encode_plus(
            sequence,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
          'protein_sequence': sequence,
          'input_ids': encoding['input_ids'].flatten(),
          'attention_mask': encoding['attention_mask'].flatten(),
          'targets': torch.tensor(target, dtype=torch.long)
        }

In [137]:
train_dataset = ProteinSequenceDataset(train_df, tokenizer = tokenizer, max_len = 1500)

In [138]:
train_dataset[100]

{'protein_sequence': 'MLPGLAAAAAHRCSWSSLCRLRLRCRAAACNPSDRQEWQNLVTFGSFSNMVPCSHPYIGTLSQVKLYSTNVQKEGQGSQTLRVEKVPSFETAEGIGTELKAPLKQEPLQVRVKAVLKKREYGSKYTQNNFITGVRAINEFCLKSSDLEQLRKIRRRSPHEDTESFTVYLRSDVEAKSLEVWGSPEALAREKKLRKEAEIEYRERLFRNQKILREYRDFLGNTKPRSRTASVFFKGPGKVVMVAICINGLNCFFKFLAWIYTGSASMFSEAIHSLSDTCNQGLLALGISKSVQTPDPSHPYGFSNMRYISSLISGVGIFMMGAGLSWYHGVMGLLHPQPIESLLWAYCILAGSLVSEGATLLVAVNELRRNARAKGMSFYKYVMESRDPSTNVILLEDTAAVLGVIIAATCMGLTSITGNPLYDSLGSLGVGTLLGMVSAFLIYTNTEALLGRSIQPEQVQRLTELLENDPSVRAIHDVKATDLGLGKVRFKAEVDFDGRVVTRSYLEKQDFDQMLQEIQEVKTPEELETFMLKHGENIIDTLGAEVDRLEKELKKRNPEVRHVDLEIL',
 'input_ids': tensor([2, 1, 3,  ..., 0, 0, 0]),
 'attention_mask': tensor([1, 1, 1,  ..., 0, 0, 0]),
 'targets': tensor([0, 0, 1, 0, 0])}

In [166]:
class ProteinDataModule(pl.LightningDataModule):
    def __init__(self, train_df, test_df,val_df,  tokenizer, batch_size=32, max_len=1500):
        super().__init__()
        
        self.train_df = train_df
        self.test_df = test_df
        self.val_df = val_df

        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_len = max_len
    
    def setup(self, stage=None):
        self.train_dataset = ProteinSequenceDataset(self.train_df, self.tokenizer, self.max_len)
        self.test_dataset = ProteinSequenceDataset(self.test_df, self.tokenizer, self.max_len)
        self.val_dataset = ProteinSequenceDataset(self.val_df, self.tokenizer, self.max_len)

    def train_dataloader(self):
        return DataLoader(self.train_dataset,batch_size=self.batch_size,shuffle=True,num_workers=4)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset,batch_size=1,num_workers=4)    
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset,batch_size=1,num_workers=4)

In [177]:
EPOCHS = 2
BATCH_SIZE = 32
MAX_LENGTH = 1500

data_module = ProteinDataModule(
    train_df, 
    test_df,
    val_df,
    tokenizer, 
    batch_size=BATCH_SIZE,
    max_len = MAX_LENGTH
)

data_module.setup() 

In [159]:
PRE_TRAINED_MODEL_NAME = 'Rostlab/prot_bert_bfd_localization'

class ProteinClassifier(pl.LightningModule):
    def __init__(self, n_classes: int, steps_per_epoch=None, n_epochs=None):
        super().__init__()

        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        self.classifier = nn.Sequential(nn.Dropout(p=0.2),
                                        nn.Linear(self.bert.config.hidden_size, n_classes),
                                        nn.Tanh())
        self.steps_per_epoch = steps_per_epoch
        self.n_epochs = n_epochs
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask, targets=None):
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = self.classifier(output.pooler_output)
        output = torch.sigmoid(output)
        loss = 0
        if targets is not None:
            loss = self.criterion(output, targets)
        return loss, output
    
    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        targets = batch["targets"]
        loss, outputs = self(input_ids, attention_mask, targets)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {
            "loss": loss,
            "predictions": outputs,
            "targets": targets
        }

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        targets = batch["targets"]
        loss, outputs = self(input_ids, attention_mask, targets)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        targets = batch["targets"]
        loss, outputs = self(input_ids, attention_mask, targets)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss
    
    def training_epoch_end(self, outputs):
        targets = []
        predictions = []
        
        for output in outputs:
            for out_labels in output["targets"].detach().cpu():
                targets.append(out_labels)

            for out_predictions in output["predictions"].detach().cpu():
                predictions.append(out_predictions)
        print("#####")
        targets = torch.stack(targets).int()
        predictions = torch.stack(predictions)

        for i, name in enumerate(Classes):
            roc_score = torchmetrics.AUROC(predictions[:, i], labels[:, i])
            self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", roc_score, self.current_epoch)
    
    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=2e-5)
        warmup_steps = self.steps_per_epoch // 3
        total_steps = self.steps_per_epoch * self.n_epochs - warmup_steps
        scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)
        return [optimizer], [scheduler]

In [178]:
dm = ProteinDataModule(
    train_df, 
    test_df,
    val_df,
    tokenizer, 
    batch_size=BATCH_SIZE,
    max_len = MAX_LENGTH
)

model = ProteinClassifier(
    n_classes=len(target_list), 
    steps_per_epoch=len(train_df)//BATCH_SIZE, 
    n_epochs=EPOCHS
)

Some weights of the model checkpoint at Rostlab/prot_bert_bfd_localization were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [179]:
trainer = pl.Trainer(max_epochs=EPOCHS)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [180]:
trainer.fit(model, dm)


  | Name       | Type       | Params
------------------------------------------
0 | bert       | BertModel  | 419 M 
1 | classifier | Sequential | 5.1 K 
2 | criterion  | BCELoss    | 0     
------------------------------------------
419 M     Trainable params
0         Non-trainable params
419 M     Total params
1,679.745 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/pierredemetz/.pyenv/versions/3.8.12/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/pierredemetz/.pyenv/versions/3.8.12/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'ProteinSequenceDataset' on <module '__main__' (built-in)>


In [None]:
PRE_TRAINED_MODEL_NAME = 'Rostlab/prot_bert_bfd_localization'
class ProteinClassifier(nn.Module):
    def __init__(self, n_classes):
        super(ProteinClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        self.classifier = nn.Sequential(nn.Dropout(p=0.2),
                                        nn.Linear(self.bert.config.hidden_size, n_classes),
                                        nn.Tanh())
        
    def forward(self, input_ids, attention_mask):
        output = self.bert(
          input_ids=input_ids,
          attention_mask=attention_mask
        )
        return self.classifier(output.pooler_output)