<a href="https://colab.research.google.com/github/katyachemistry/Repository/blob/main/reproduce_the_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import dependencies and load data

In [7]:
import pandas as pd
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output
import os
import h5py

In [None]:
!gdown 1-3R1BTp5QmL4O87QuxgQG5bagAdClpGl
df = pd.read_pickle('dataframe.pkl')

Please, use your account or create one to be able to proceed with W&B for training.

In [55]:
!pip install -q --upgrade tbb
!pip install wandb -qU
import wandb
wandb.login()

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m33.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m16.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m289.6/289.6 kB[0m [31m24.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[?25h

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

# Create double-blind dataset

In [56]:
np.random.seed(100)

prots_val_test = np.random.choice(df.STITCH_protein_ID.unique(), 1600, replace=False) # 1600 is an empirical number of
                                                                                      # protein IDs to follow 1:4 ratio of val+test:train

val_test_df = df.loc[df['STITCH_protein_ID'].isin(prots_val_test)]

df_dupl = pd.concat([df, val_test_df])
df_dupl['Duplicated'] = df_dupl.duplicated(['STITCH_protein_ID', 'SMILES'], keep=False)
train = df_dupl[~df_dupl['Duplicated']]

val = val_test_df.sample(frac=0.5)
test = val_test_df.drop(val.index)

In [57]:
if sum([val.SMILES.isin(train.STITCH_protein_ID).any(), test.SMILES.isin(train.STITCH_protein_ID).any(),
      val.STITCH_protein_ID.isin(train.STITCH_protein_ID).any(), val.STITCH_protein_ID.isin(train.STITCH_protein_ID).any()]) == 0:
      print('No proteins or molecules from train are in test & val')

No proteins or molecules from train are in test & val


In [58]:
if sum([val.STITCH_protein_ID.isin(test.STITCH_protein_ID).any(), val.STITCH_protein_ID.isin(test.STITCH_protein_ID).any()]) >0:
  print('Some proteins and molecules appear in both val and test')

Some proteins and molecules appear in both val and test


# Train with ProtT5 embeddings

In [61]:
#---------------------------------------------

def get_ProtT5_data(dataset, morgan_fp=True):
    """
    Extracts molecular and protein features, and labels from the dataset.

    Args:
        dataset (pd.DataFrame): The input dataset containing the features and labels.
        use_moltr (bool, optional): If True, use MolTr embeddings for molecular representations. If False, use Morgan fingerprints. Default is False.

    Returns:
        tuple: Containing arrays of molecular features, protein features, and labels.
    """
    if morgan_fp:
        mol = dataset.MorganFP.values
    else:
        mol = dataset.MolTr.values

    prots = dataset.ProtT5.values
    labels = dataset.label.values

    return mol, prots, labels

#@markdown - Name of **your W&B project**. Project with this name should be created beforehand in W&B
project = "MolTransf-and-ProtTrans_per_prot_embs" #@param {type:"string"}

#@markdown - Untick **morgan_fp** to use MolecularTransformer embeddings instead of Morgan fingerprints for ligands
morgan_fp = True #@param {type:"boolean"}
learning_rate = 1e-5 #@param {type:"number"}
epochs = 1 #@param {type:"number"}
weight_decay = 1e-5 #@param {type:"number"}
batch_size = 32 #@param {type:"number"}
input_size_protein = 1024
if morgan_fp:
  input_size_molecule = 1024
else:
  input_size_molecule = 512
fc1_layer_size_factor = 2
fc2_layer_size_factor = 2
dropout_rate = 0.0 #@param {type:"number"}

# Extract features for test set
test_mols, test_prots, test_labels = get_ProtT5_data(test, morgan_fp=morgan_fp)

# Extract features for validation set
val_mols, val_prots, val_labels = get_ProtT5_data(val, morgan_fp=morgan_fp)

# Extract features for training set
train_mols, train_prots, train_labels = get_ProtT5_data(train, morgan_fp=morgan_fp)


#---------------------------------------------
# Dataset class for our data

class ProteinMoleculeDataset(Dataset):
    def __init__(self, proteins, molecules, labels):
        self.proteins = proteins
        self.molecules = molecules
        self.labels = torch.tensor(np.vstack(labels), dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        protein = self.proteins[idx]
        molecule = self.molecules[idx]
        label = self.labels[idx]
        return protein, molecule, label


#---------------------------------------------
# Define class for our model

class InteractionClassifier_ProtT5_based(nn.Module):
    '''
    Interaction/non-interaction classification model for using ProtT5 protein embeddings

    Args:
        input_size_protein (int): Size of the input feature vector for proteins.
        input_size_molecule (int): Size of the input feature vector for molecules.
        fc1_layer_size_factor (int): Factor to reduce the size of the first fully connected layer.
        fc2_layer_size_factor (int): Factor to reduce the size of the second fully connected layer.
        dropout_rate (float): Dropout rate to apply after each layer. Default is 0.

    Attributes:
        protein_fc1 (nn.Linear): First fully connected layer for protein features.
        protein_fc2 (nn.Linear): Second fully connected layer for protein features.
        molecule_fc1 (nn.Linear): First fully connected layer for molecule features.
        molecule_fc2 (nn.Linear): Second fully connected layer for molecule features.
        dropout (nn.Dropout): Dropout layer.
        fc1 (nn.Linear): Fully connected layer combining protein and molecule features.
        fc2 (nn.Linear): Output layer.
        norm_prot1 (nn.BatchNorm1d): Batch normalization for the first protein layer.
        norm_prot2 (nn.BatchNorm1d): Batch normalization for the second protein layer.
        norm_mol1 (nn.BatchNorm1d): Batch normalization for the first molecule layer.
        norm_mol2 (nn.BatchNorm1d): Batch normalization for the second molecule layer.
        norm_all (nn.BatchNorm1d): Batch normalization for the combined features layer.
    '''

    def __init__(self, input_size_protein, input_size_molecule, fc1_layer_size_factor, fc2_layer_size_factor, dropout_rate=0):
        super().__init__()

        output_size_protein_1 = int(input_size_protein / fc1_layer_size_factor)
        self.protein_fc1 = nn.Linear(input_size_protein, output_size_protein_1)

        output_size_protein_2 = int(output_size_protein_1 / fc2_layer_size_factor)
        self.protein_fc2 = nn.Linear(output_size_protein_1, output_size_protein_2)

        output_size_molecule_1 = int(input_size_molecule / fc1_layer_size_factor)
        self.molecule_fc1 = nn.Linear(input_size_molecule, output_size_molecule_1)

        output_size_molecule_2 = int(output_size_molecule_1 / fc2_layer_size_factor)
        self.molecule_fc2 = nn.Linear(output_size_molecule_1, output_size_molecule_2)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_rate)

        self.fc1 = nn.Linear(output_size_protein_2 + output_size_molecule_2, 64)
        self.fc2 = nn.Linear(64, 1)

        self.norm_prot1 = nn.BatchNorm1d(output_size_protein_1)
        self.norm_prot2 = nn.BatchNorm1d(output_size_protein_2)
        self.norm_mol1 = nn.BatchNorm1d(output_size_molecule_1)
        self.norm_mol2 = nn.BatchNorm1d(output_size_molecule_2)
        self.norm_all = nn.BatchNorm1d(64)

    def forward(self, protein, molecule):
        molecule = molecule.view(molecule.size(0), -1).to(torch.float32)

        protein = self.relu(self.norm_prot1(self.protein_fc1(protein)))
        protein = self.dropout(protein)
        protein = self.relu(self.norm_prot2(self.protein_fc2(protein)))
        protein = self.dropout(protein)

        molecule = self.relu(self.norm_mol1(self.molecule_fc1(molecule)))
        molecule = self.dropout(molecule)
        molecule = self.relu(self.norm_mol2(self.molecule_fc2(molecule)))
        molecule = self.dropout(molecule)

        combined = torch.cat((protein, molecule), dim=1)

        x = self.relu(self.norm_all(self.fc1(combined)))
        x = self.dropout(x)

        x = self.fc2(x)

        return x


#---------------------------------------------
# Install some more dependencies

!pip install -q lightning
import pytorch_lightning as L
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
import torchmetrics

#---------------------------------------------
# Lightning will take care of the training process. Define Lit class

class Lit(L.LightningModule):
    def __init__(
        self,
        model,
        optimizer_kwargs,
        exp_name="MyClassifier",
        criterion=nn.BCEWithLogitsLoss(pos_weight = torch.tensor([1.25])),
        optimizer_class=torch.optim.AdamW,

    ) -> None:
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.optimizer_class = optimizer_class
        self.optimizer_kwargs = optimizer_kwargs

        self.train_auroc = torchmetrics.AUROC(task="binary")
        self.valid_auroc = torchmetrics.AUROC(task="binary")
        self.test_auroc = torchmetrics.AUROC(task="binary")

        self.train_accuracy = torchmetrics.Accuracy(task="binary")
        self.valid_accuracy = torchmetrics.Accuracy(task="binary")
        self.test_accuracy = torchmetrics.Accuracy(task="binary")

        self.train_recall = torchmetrics.Recall(task="binary")
        self.valid_recall = torchmetrics.Recall(task="binary")
        self.test_recall = torchmetrics.Recall(task="binary")

        self.train_precision = torchmetrics.Precision(task="binary")
        self.valid_precision = torchmetrics.Precision(task="binary")
        self.test_precision = torchmetrics.Precision(task="binary")

        self.train_f1 = torchmetrics.F1Score(task="binary")
        self.valid_f1 = torchmetrics.F1Score(task="binary")
        self.test_f1 = torchmetrics.F1Score(task="binary")

    def configure_optimizers(self):
        optimizer = self.optimizer_class(
            self.model.parameters(), **self.optimizer_kwargs
        )

        return optimizer

    def training_step(self, batch, batch_idx):
        prots, mols, labels = batch
        out = self.model(prots, mols)
        loss = self.criterion(out, labels)
        self.log("loss_on_train", loss, prog_bar=True)
        self.train_auroc.update(out, labels)
        self.train_accuracy.update(out, labels)
        self.train_recall.update(out, labels)
        self.train_precision.update(out, labels)
        self.train_f1.update(out, labels)

        return loss

    def validation_step(self, batch, batch_idx):
        prots, mols, labels = batch
        out = self.model(prots, mols)
        loss = self.criterion(out, labels)
        self.log("loss_on_val", loss, prog_bar=True)
        self.valid_auroc.update(out, labels)
        self.valid_accuracy.update(out, labels)
        self.valid_recall.update(out, labels)
        self.valid_precision.update(out, labels)
        self.valid_f1.update(out, labels)

    def on_train_epoch_end(self):
        self.log("AUROC/train", self.train_auroc.compute(), prog_bar = True)
        self.log("Accuracy/train", self.train_accuracy.compute(), prog_bar = True)
        self.log("Recall/train", self.train_recall.compute(), prog_bar = True)
        self.log("Precision/train", self.train_precision.compute(), prog_bar = True)
        self.log("F1/train", self.train_f1.compute(), prog_bar = True)

        self.train_auroc.reset()
        self.train_accuracy.reset()
        self.train_recall.reset()
        self.train_precision.reset()
        self.train_f1.reset()

    def on_validation_epoch_end(self):
        self.log("AUROC/valid", self.valid_auroc.compute(), prog_bar = True)
        self.log("Accuracy/valid", self.valid_accuracy.compute(), prog_bar = True)
        self.log("Recall/valid", self.valid_recall.compute(), prog_bar = True)
        self.log("Precision/valid", self.valid_precision.compute(), prog_bar = True)
        self.log("F1/valid", self.valid_f1.compute(), prog_bar = True)

        self.valid_auroc.reset()
        self.valid_accuracy.reset()
        self.valid_recall.reset()
        self.valid_precision.reset()
        self.valid_f1.reset()

    def test_step(self, batch, batch_idx):
        prots, mols, labels = batch
        out = self.model(prots, mols)
        loss = self.criterion(out, labels)
        self.log("loss_on_test", loss, prog_bar=True)
        self.test_auroc.update(out, labels)
        self.test_accuracy.update(out, labels)
        self.test_recall.update(out, labels)
        self.test_precision.update(out, labels)
        self.test_f1.update(out, labels)

    def on_test_epoch_end(self):
        self.log("AUROC/test", self.test_auroc.compute(), prog_bar = True)
        self.log("Accuracy/test", self.test_accuracy.compute(), prog_bar = True)
        self.log("Recall/test", self.test_recall.compute(), prog_bar = True)
        self.log("Precision/test", self.test_precision.compute(), prog_bar = True)
        self.log("F1/test", self.test_f1.compute(), prog_bar = True)

        self.test_auroc.reset()
        self.test_accuracy.reset()
        self.test_recall.reset()
        self.test_precision.reset()
        self.test_f1.reset()

#---------------------------------------------
# And some more dependencies

from pytorch_lightning.loggers import WandbLogger

#---------------------------------------------
# Finally. Let's train and track experiment with W&B.

config={
      "learning_rate": learning_rate,
      "epochs": epochs,
      "project": project,
      "weight_decay": weight_decay,
      "batch_size": batch_size,
      "input_size_protein": input_size_protein,
      "input_size_molecule": input_size_molecule,
      "fc1_layer_size_factor": fc1_layer_size_factor,
      "fc2_layer_size_factor": fc2_layer_size_factor,
      "dropout_rate": dropout_rate
      }

batch_size = config['batch_size']

train_dataset = ProteinMoleculeDataset(train_prots, train_mols, train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataset = ProteinMoleculeDataset(val_prots, val_mols, val_labels)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataset = ProteinMoleculeDataset(test_prots, test_mols, test_labels)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

wandb.init(
      project=project,
      config=config)

model_name = 'MyClassifier'
model = InteractionClassifier_ProtT5_based(config['input_size_protein'], config['input_size_molecule'],
                                    config["fc1_layer_size_factor"], config["fc2_layer_size_factor"],
                                    dropout_rate=config["dropout_rate"])


lr_monitor = LearningRateMonitor(logging_interval='step')

trainer = L.Trainer(
    max_epochs=config["epochs"],
    logger=WandbLogger(log_model="all", project=config["project"]),
    num_sanity_val_steps=0,
    callbacks=[lr_monitor]
)

pipeline = Lit(model=model, exp_name=model_name, optimizer_kwargs={'lr':config['learning_rate'],
                                                                   'weight_decay':config['weight_decay']})

trainer.fit(
      model=pipeline,
      train_dataloaders=train_dataloader,
      val_dataloaders=val_dataloader
  )
trainer.test(dataloaders=test_dataloader)

wandb.finish()


INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
   | Name            | Type                               | Params | Mode 
--------------------------------------------------------------------------------
0  | model           | InteractionClassifier_ProtT5_based | 1.3 M  | train
1  | criterion       | BCEWithLogitsLoss                  | 0      | train
2  | train_auroc     | BinaryAUROC                        | 0      | train
3  | valid_auroc     | BinaryAUROC                        | 0      | train
4  | test_auroc      | BinaryAUROC                        | 0      | train
5  | train_accuracy  | BinaryAccuracy                     | 0      | train
6  | valid_accuracy  | BinaryAccuracy                     | 0      | train
7  | test_accuracy   | BinaryAc

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at ./MolTransf-and-ProtTrans_per_prot_embs/5ql2uekl/checkpoints/epoch=0-step=11584.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at ./MolTransf-and-ProtTrans_per_prot_embs/5ql2uekl/checkpoints/epoch=0-step=11584.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

VBox(children=(Label(value='15.475 MB of 15.475 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
AUROC/test,▁
AUROC/train,▁
AUROC/valid,▁
Accuracy/test,▁
Accuracy/train,▁
Accuracy/valid,▁
F1/test,▁
F1/train,▁
F1/valid,▁
Precision/test,▁

0,1
AUROC/test,0.92474
AUROC/train,0.94557
AUROC/valid,0.92472
Accuracy/test,0.84828
Accuracy/train,0.87818
Accuracy/valid,0.84832
F1/test,0.83218
F1/train,0.86573
F1/valid,0.8333
Precision/test,0.86616


# Train with AlphaFold2 embeddings

In [60]:
#---------------------------------------------

def get_AF_data(dataset, morgan_fp=True):
    """
    Converts dataset values to tensors with optional padding and choice of molecular representations.

    Args:
        dataset: The input dataset containing values to be converted.
        morgan_fp_embeddings (bool, optional): If True, use Morgan fingerprint embeddings for molecular representations. If False, use MolTr embeddings. Default is True.

    Returns:
        tuple: Containing lists of tensors for single, pair, msa, mol, and label values.
    """

    # Choose the appropriate molecular representation
    if morgan_fp:
        mol = [torch.tensor(x, dtype=torch.float32) for x in dataset.MorganFP.values]
    else:
        mol = [torch.tensor(x, dtype=torch.float32) for x in dataset.MolTr.values]

    # Convert label values to tensors
    label = [torch.tensor(x, dtype=torch.float32) for x in dataset.label.values]

    pad_fn = lambda x: torch.tensor(x, dtype=torch.float32)

    # Apply the padding function to single, pair, and msa values
    single = [pad_fn(x) for x in dataset.AF_single.values]
    pair = [pad_fn(x) for x in dataset.AF_pair.values]
    msa = [pad_fn(x) for x in dataset.AF_MSA.values]

    return single, pair, msa, mol, label


#@markdown - Name of **your W&B project**. Project with this name should be created beforehand in W&B
project = "AF_trial" #@param {type:"string"}
#@markdown - Untick **morgan_fp** to use MolecularTransformer embeddings instead of Morgan fingerprints for ligands
morgan_fp = True #@param {type:"boolean"}
learning_rate = 1e-5 #@param {type:"number"}
epochs = 1 #@param {type:"number"}
weight_decay = 1e-5 #@param {type:"number"}
batch_size = 32 #@param {type:"number"}
single_size = 256
pair_size = 128
msa_size = 23
if morgan_fp:
  molecule_size = 1024
else:
  molecule_size = 512
fc1_layer_size_factor = 2
fc2_layer_size_factor = 2
dropout_rate = 0.0 #@param {type:"number"}

test_single, test_pair, test_msa, test_mol, test_label = get_AF_data(test, morgan_fp=morgan_fp)
val_single, val_pair, val_msa, val_mol, val_label = get_AF_data(val, morgan_fp=morgan_fp)
train_single, train_pair, train_msa, train_mol, train_label = get_AF_data(train, morgan_fp=morgan_fp)

#---------------------------------------------
# Dataset class for our data

class ProteinMoleculeDataset_AF(Dataset):
    def __init__(self, single, pair, msa, molecules, labels):
        self.single = single
        self.pair = pair
        self.msa = msa
        self.molecules = molecules
        self.labels = torch.tensor(np.vstack(labels), dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        single = self.single[idx]
        pair = self.pair[idx]
        msa = self.msa[idx]
        molecule = self.molecules[idx]
        label = self.labels[idx]
        return single, pair, msa, molecule, label


#---------------------------------------------
# Define class for our model

class InteractionClassifier_AF2_based(nn.Module):

    def __init__(self, single_size, pair_size, msa_size, molecule_size, fc1_layer_size_factor, fc2_layer_size_factor, dropout_rate=0):
        super().__init__()

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_rate)

        output_size_single_1 = int(single_size / fc1_layer_size_factor)
        self.single_fc1 = nn.Linear(single_size, output_size_single_1)
        self.norm_single_1 = nn.BatchNorm1d(output_size_single_1)

        output_size_pair_1 = int(pair_size / fc1_layer_size_factor)
        self.pair_fc1 = nn.Linear(pair_size, output_size_pair_1)
        self.norm_pair_1 = nn.BatchNorm1d(output_size_pair_1)

        output_size_msa_1 = int(msa_size / fc1_layer_size_factor)
        self.msa_fc1 = nn.Linear(msa_size, output_size_msa_1)
        self.norm_msa_1 = nn.BatchNorm1d(output_size_msa_1)

        output_size_single_2 = int(output_size_single_1 / fc2_layer_size_factor)
        self.single_fc2 = nn.Linear(output_size_single_1, output_size_single_2)
        self.norm_single_2 = nn.BatchNorm1d(output_size_single_2)

        output_size_pair_2 = int(output_size_pair_1 / fc2_layer_size_factor)
        self.pair_fc2 = nn.Linear(output_size_pair_1, output_size_pair_2)
        self.norm_pair_2 = nn.BatchNorm1d(output_size_pair_2)

        output_size_msa_2 = int(output_size_msa_1 / fc2_layer_size_factor)
        self.msa_fc2 = nn.Linear(output_size_msa_1, output_size_msa_2)
        self.norm_msa_2 = nn.BatchNorm1d(output_size_msa_2)

        output_size_molecule_1 = int(molecule_size / fc1_layer_size_factor)
        self.molecule_fc1 = nn.Linear(molecule_size, output_size_molecule_1)
        self.norm_mol1 = nn.BatchNorm1d(output_size_molecule_1)

        output_size_molecule_2 = int(output_size_molecule_1 / fc2_layer_size_factor)
        self.molecule_fc2 = nn.Linear(output_size_molecule_1, output_size_molecule_2)
        self.norm_mol2 = nn.BatchNorm1d(output_size_molecule_2)

        self.fc1 = nn.Linear(output_size_single_2 + output_size_pair_2 + output_size_msa_2 + output_size_molecule_2, 64)
        self.norm_all = nn.BatchNorm1d(64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, single, pair, msa, molecule):
        molecule = molecule.view(molecule.size(0), -1).to(torch.float32)

        single = self.relu(self.norm_single_1(self.single_fc1(single)))
        single = self.dropout(single)
        single = self.relu(self.norm_single_2(self.single_fc2(single)))
        single = self.dropout(single)

        pair = self.relu(self.norm_pair_1(self.pair_fc1(pair)))
        pair = self.dropout(pair)
        pair = self.relu(self.norm_pair_2(self.pair_fc2(pair)))
        pair = self.dropout(pair)

        msa = self.relu(self.norm_msa_1(self.msa_fc1(msa)))
        msa = self.dropout(msa)
        msa = self.relu(self.norm_msa_2(self.msa_fc2(msa)))
        msa = self.dropout(msa)

        molecule = self.relu(self.norm_mol1(self.molecule_fc1(molecule)))
        molecule = self.dropout(molecule)
        molecule = self.relu(self.norm_mol2(self.molecule_fc2(molecule)))
        molecule = self.dropout(molecule)

        combined = torch.cat((single, pair, msa, molecule), dim=1)

        x = self.relu(self.norm_all(self.fc1(combined)))
        x = self.dropout(x)

        x = self.fc2(x)

        return x


#---------------------------------------------
# Install some more dependencies

!pip install -q lightning
import pytorch_lightning as L
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
import torchmetrics

#---------------------------------------------
# Lightning will take care of the training process. Define Lit class

class Lit_AF(L.LightningModule):
    def __init__(
        self,
        model,
        optimizer_kwargs,
        exp_name="MyClassifier",
        criterion=nn.BCEWithLogitsLoss(pos_weight = torch.tensor([1.25])),
        optimizer_class=torch.optim.AdamW,

    ) -> None:
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.optimizer_class = optimizer_class
        self.optimizer_kwargs = optimizer_kwargs

        self.train_auroc = torchmetrics.AUROC(task="binary")
        self.valid_auroc = torchmetrics.AUROC(task="binary")
        self.test_auroc = torchmetrics.AUROC(task="binary")

        self.train_accuracy = torchmetrics.Accuracy(task="binary")
        self.valid_accuracy = torchmetrics.Accuracy(task="binary")
        self.test_accuracy = torchmetrics.Accuracy(task="binary")

        self.train_recall = torchmetrics.Recall(task="binary")
        self.valid_recall = torchmetrics.Recall(task="binary")
        self.test_recall = torchmetrics.Recall(task="binary")

        self.train_precision = torchmetrics.Precision(task="binary")
        self.valid_precision = torchmetrics.Precision(task="binary")
        self.test_precision = torchmetrics.Precision(task="binary")

        self.train_f1 = torchmetrics.F1Score(task="binary")
        self.valid_f1 = torchmetrics.F1Score(task="binary")
        self.test_f1 = torchmetrics.F1Score(task="binary")

    def configure_optimizers(self):
        optimizer = self.optimizer_class(
            self.model.parameters(), **self.optimizer_kwargs
        )

        return optimizer


    def training_step(self, batch, batch_idx):
        single, pair, msa, mol, labels = batch
        out = self.model(single, pair, msa, mol)
        loss = self.criterion(out, labels)
        self.log("loss_on_train", loss, prog_bar=True)
        self.train_auroc.update(out, labels)
        self.train_accuracy.update(out, labels)
        self.train_recall.update(out, labels)
        self.train_precision.update(out, labels)
        self.train_f1.update(out, labels)

        return loss

    def validation_step(self, batch, batch_idx):
        single, pair, msa, mol, labels = batch
        out = self.model(single, pair, msa, mol)
        loss = self.criterion(out, labels)
        self.log("loss_on_val", loss, prog_bar=True)
        self.valid_auroc.update(out, labels)
        self.valid_accuracy.update(out, labels)
        self.valid_recall.update(out, labels)
        self.valid_precision.update(out, labels)
        self.valid_f1.update(out, labels)

    def on_train_epoch_end(self):
        self.log("AUROC/train", self.train_auroc.compute(), prog_bar = True)
        self.log("Accuracy/train", self.train_accuracy.compute(), prog_bar = True)
        self.log("Recall/train", self.train_recall.compute(), prog_bar = True)
        self.log("Precision/train", self.train_precision.compute(), prog_bar = True)
        self.log("F1/train", self.train_f1.compute(), prog_bar = True)

        self.train_auroc.reset()
        self.train_accuracy.reset()
        self.train_recall.reset()
        self.train_precision.reset()
        self.train_f1.reset()

    def on_validation_epoch_end(self):
        self.log("AUROC/valid", self.valid_auroc.compute(), prog_bar = True)
        self.log("Accuracy/valid", self.valid_accuracy.compute(), prog_bar = True)
        self.log("Recall/valid", self.valid_recall.compute(), prog_bar = True)
        self.log("Precision/valid", self.valid_precision.compute(), prog_bar = True)
        self.log("F1/valid", self.valid_f1.compute(), prog_bar = True)

        self.valid_auroc.reset()
        self.valid_accuracy.reset()
        self.valid_recall.reset()
        self.valid_precision.reset()
        self.valid_f1.reset()

    def test_step(self, batch, batch_idx):
        single, pair, msa, mol, labels = batch
        out = self.model(single, pair, msa, mol)
        loss = self.criterion(out, labels)
        self.log("loss_on_test", loss, prog_bar=True)
        self.test_auroc.update(out, labels)
        self.test_accuracy.update(out, labels)
        self.test_recall.update(out, labels)
        self.test_precision.update(out, labels)
        self.test_f1.update(out, labels)

    def on_test_epoch_end(self):
        self.log("AUROC/test", self.test_auroc.compute(), prog_bar = True)
        self.log("Accuracy/test", self.test_accuracy.compute(), prog_bar = True)
        self.log("Recall/test", self.test_recall.compute(), prog_bar = True)
        self.log("Precision/test", self.test_precision.compute(), prog_bar = True)
        self.log("F1/test", self.test_f1.compute(), prog_bar = True)

        self.test_auroc.reset()
        self.test_accuracy.reset()
        self.test_recall.reset()
        self.test_precision.reset()
        self.test_f1.reset()

#---------------------------------------------
# And some more dependencies

from pytorch_lightning.loggers import WandbLogger


config={
      "learning_rate": learning_rate,
      "epochs": epochs,
      "project": project,
      "weight_decay": weight_decay,
      "batch_size": batch_size,
      "single size": single_size,
      "pair size": pair_size,
      "masked msa size": msa_size,
      "molecule size": molecule_size,
      "fc1_layer_size_factor": fc1_layer_size_factor,
      "fc2_layer_size_factor": fc2_layer_size_factor,
      "dropout_rate": dropout_rate
      }

batch_size = config['batch_size']
train_dataset_AF = ProteinMoleculeDataset_AF(train_single, train_pair, train_msa, train_mol, train_label)
train_dataloader_AF = DataLoader(train_dataset_AF, batch_size=batch_size, shuffle=True)
val_dataset_AF = ProteinMoleculeDataset_AF(val_single, val_pair, val_msa, val_mol, val_label)
val_dataloader_AF = DataLoader(val_dataset_AF, batch_size=batch_size, shuffle=False)
test_dataset_AF = ProteinMoleculeDataset_AF(test_single, test_pair, test_msa, test_mol, test_label)
test_dataloader_AF = DataLoader(test_dataset_AF, batch_size=batch_size, shuffle=False)

wandb.init(
    project=config["project"],
    config=config)

model_name = 'MyClassifier'
model = InteractionClassifier_AF2_based(single_size=config['single size'], pair_size=config['pair size'],
                                       msa_size=config['masked msa size'], molecule_size=config['molecule size'],
                                            fc1_layer_size_factor=config['fc1_layer_size_factor'],
                                         fc2_layer_size_factor=config['fc2_layer_size_factor'], dropout_rate=config['dropout_rate'])

lr_monitor = LearningRateMonitor(logging_interval='step')

trainer = L.Trainer(
    max_epochs=config["epochs"],
    logger=WandbLogger(log_model="all", project=config["project"]),
    num_sanity_val_steps=0,
    callbacks=[lr_monitor]
)

pipeline = Lit_AF(model=model, exp_name=model_name, optimizer_kwargs={'lr':config['learning_rate'],
                                                                   'weight_decay':config['weight_decay']
                                                                   })

trainer.fit(
      model=pipeline,
      train_dataloaders=train_dataloader_AF,
      val_dataloaders=val_dataloader_AF
  )
trainer.test(dataloaders=test_dataloader_AF)

wandb.finish()




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▁▁▁
loss_on_train,█▇▆▁▇▄█▃
lr-AdamW,▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▂▂▃▃▄▄▅▅▆▆▇▇██

0,1
epoch,0.0
loss_on_train,0.57497
lr-AdamW,1e-05
trainer/global_step,399.0


INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
INFO:pytorch_lightning.callbacks.model_summary:
   | Name            | Type                            | Params | Mode 
-----------------------------------------------------------------------------
0  | model           | InteractionClassifier_AF2_based | 733 K  | train
1  | criterion       | BCEWithLogitsLoss               | 0      | train
2  | train_auroc     | BinaryAUROC                     | 0      | train
3  | valid_auroc     | BinaryAUROC                     | 0      | train
4  | 

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at ./AF_trial/pvgpz56e/checkpoints/epoch=0-step=11584.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at ./AF_trial/pvgpz56e/checkpoints/epoch=0-step=11584.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

VBox(children=(Label(value='8.451 MB of 8.465 MB uploaded\r'), FloatProgress(value=0.9983768677231128, max=1.0…

0,1
AUROC/test,▁
AUROC/train,▁
AUROC/valid,▁
Accuracy/test,▁
Accuracy/train,▁
Accuracy/valid,▁
F1/test,▁
F1/train,▁
F1/valid,▁
Precision/test,▁

0,1
AUROC/test,0.89199
AUROC/train,0.91021
AUROC/valid,0.89063
Accuracy/test,0.80818
Accuracy/train,0.83196
Accuracy/valid,0.80691
F1/test,0.78658
F1/train,0.80799
F1/valid,0.78645
Precision/test,0.82397
