In [None]:
# Install packages and clone repo
# `lightning` is imported by the training cell further down; the clone creates
# the `phla-prediction` directory that the next cell changes into.
# NOTE(review): the install is unpinned, so the Lightning version (and its API)
# depends on the day this cell is run.
!pip install lightning
!git clone https://github.com/mhbakalar/phla-prediction.git

# Restart runtime after package installation
# Sends signal 9 (SIGKILL on Linux) to this kernel's own process, so the cell
# ends abruptly by design. Presumably Colab then brings up a fresh runtime that
# sees the newly installed package -- confirm; later cells must be run after that.
import os
os.kill(os.getpid(), 9)

In [None]:
# Change the working directory to the repository cloned by the first cell
# (presumably so the `models.*` imports below resolve -- confirm).
# A bare `cd` is not Python; it only works while IPython's automagic is enabled.
cd phla-prediction

In [None]:
# Mount Google Drive at /content/drive (requires the Colab runtime).
from google.colab import drive
drive.mount('/content/drive')
# Root directory under which PeptidePrediction.run looks for `data/` and writes
# `logs`; keep the trailing slash so paths built from it by string
# concatenation stay valid.
# NOTE(review): this points at the cloned repository, not at the mounted Drive;
# nothing in the visible cells reads from /content/drive -- confirm the mount
# is still needed.
data_root = '/content/phla-prediction/'

In [None]:
import os
from sklearn.model_selection import ParameterGrid

import lightning as L
import torch

from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint, Callback

import models.datasets.phla_data
import models.modules.transformer
import models.modules.split_transformer

class PeptidePrediction():
    """Hyperparameter sweep that trains the peptide-HLA split transformer.

    The class keeps no state of its own. `run` reads the notebook-level
    `data_root` variable to locate the input files and to decide where
    TensorBoard logs and model checkpoints are written.
    """

    def __init__(self):
        super().__init__()

    def run(self):
        """Train one model for every parameter combination in the sweep grid.

        For each combination a fresh data module, model, logger, checkpoint
        callback and trainer are created, and the model is fitted for up to
        10 epochs on a GPU. The two best checkpoints by "val_loss" are kept in
        the logger's run directory.

        Reads the global `data_root`; input files are expected under
        `<data_root>/data` and logs are written under `<data_root>/logs`.
        """
        # os.path.join gives the same paths as plain concatenation when
        # data_root ends with a separator, and still works when it does not.
        save_dir = os.path.join(data_root, "logs")

        # Input files do not depend on the swept parameters, so resolve them once.
        data_dir = os.path.join(data_root, "data")
        hits_file = os.path.join(data_dir, "hits_95.txt")
        decoys_file = os.path.join(data_dir, "decoys.txt")
        aa_order_file = os.path.join(data_dir, "amino_acid_ordering.txt")
        allele_sequence_file = os.path.join(data_dir, "alleles_95_variable.txt")

        # Define parameters for sweep
        parameter_dict = {
            'embedding_dim': [256],
            'heads': [16],
            'transformer_layers': [3],
        }

        # Parameter sweep
        for params in ParameterGrid(parameter_dict):
            # Extract parameters
            embedding_dim = params['embedding_dim']
            heads = params['heads']
            layers = params['transformer_layers']

            # Configure data (rebuilt per run so every run starts from a
            # freshly prepared data module)
            data = models.datasets.phla_data.PeptideHLADataModule(
                hits_file=hits_file,
                decoys_file=decoys_file,
                aa_order_file=aa_order_file,
                allele_sequence_file=allele_sequence_file,
                decoy_mul=1,
                decoy_pool_mul=10,
                train_test_split=0.2,
                batch_size=32,
                predict_mode=False
            )
            data.prepare_data()

            # Configure the model
            model = models.modules.split_transformer.PeptideHLATransformer(
                peptide_length=12,
                allele_length=60,
                dropout_rate=0.3,
                embedding_dim=embedding_dim,
                transformer_heads=heads,
                transformer_layers=layers,
                learning_rate=1e-4
            )

            # Create a logger
            logger = TensorBoardLogger(
                save_dir=save_dir,
            )

            # Checkpoints live next to the TensorBoard event files of this run.
            checkpoint_callback = ModelCheckpoint(
                dirpath=logger.log_dir,
                save_top_k=2,
                monitor="val_loss",
            )
            trainer = L.Trainer(
                max_epochs=10,
                logger=logger,
                callbacks=[checkpoint_callback],
                accelerator="gpu",
                reload_dataloaders_every_n_epochs=1
            )
            # `Trainer.tune` is deliberately not called. No tuner feature
            # (learning-rate finder / batch-size scaler) is enabled on this
            # Trainer, so on Lightning 1.x the call did no tuning, and in
            # Lightning >= 2.0 the method no longer exists, so calling it
            # raises AttributeError before training starts.
            trainer.fit(model, datamodule=data)


In [None]:
# Build the runner and launch the sweep defined in PeptidePrediction.run.
app = PeptidePrediction()
app.run()