In [1]:
import sys
from pathlib import Path

# Put the parent of the current working directory on the import path
# (presumably the project root that holds the local `tbprop` package imported
# below). The guard keeps re-running this cell from adding duplicates.
package_path = str(Path.cwd().parent)
if sys.path.count(package_path) == 0:
    sys.path.append(package_path)


import torch
import numpy as np
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from collections import OrderedDict
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

from tbprop.deep_learning.modules import construct_fc_layers, MultiLayerPerceptron

In [2]:
def max_f1_score(y_true, y_score):
    """ 
    Calculates the maximum F1 score over a fixed grid of decision thresholds.

    The scores are binarized as ``y_score > t`` for every threshold ``t`` in
    ``np.arange(0., 1.01, 0.02)`` (0.00, 0.02, ..., 1.00) and the best F1 score
    obtained across that grid is returned.
    
    Parameters
    ----------
    y_true: List[int] or np.array
        True (binary) labels.
    y_score: List[float] or np.array
        Prob scores from the model. The threshold grid only spans [0, 1], so
        raw logits should be mapped to probabilities before calling this.

    Returns
    -------
    float
        The largest F1 score found over the threshold grid.
    """
    # Convert once instead of once per threshold.
    scores = np.asarray(y_score)
    thresholds = np.arange(0., 1.01, 0.02)

    # zero_division=0 keeps the value sklearn already returns when nothing is
    # predicted positive (e.g. at t = 1.0) but without emitting a warning for
    # every call.
    return max(
        f1_score(y_true, (scores > t).astype(int), zero_division=0)
        for t in thresholds
    )

In [3]:
class LightningBinaryClassifier(L.LightningModule):
    """
    Base Lightning module for binary classifiers that output one logit per
    sample and are trained with ``torch.nn.BCEWithLogitsLoss``.
    """

    def __init__(self, config=None):
        """
        Constructor. Config should at least have the following params: 
        'learning_rate': float

        Optional params read from config:
            'pos_weight': float (default 1.), positive-class weight of the loss.
            'learning_rate': float (default 1e-1), SGD learning rate.
            'scale_loss': bool (default False), if True the loss is multiplied
                by 1 / batch_size.

        To construct a lightning classifier, inherit this class and override the 
        following functions:
            __init__()
            forward()
            predict_step() (Optional): use to create embeddings.
            configure_optimizers (Optional): to use an alternate optimizer.
        """
        super().__init__()
        # Avoid a shared mutable default argument; None means "all defaults".
        config = {} if config is None else config

        self.pos_weight = config.get('pos_weight', 1.)
        self.learning_rate = config.get('learning_rate', 1e-1)
        self.scale_loss = config.get('scale_loss', False)
        self.loss_fn = torch.nn.BCEWithLogitsLoss(
            pos_weight=torch.FloatTensor([self.pos_weight])
        )

    def forward(self, x):
        """ Forward function connecting torch.nn.Modules. """
        return x

    def _compute_loss(self, inputs, targets):
        """ Run the model on `inputs` and return (flattened logits, loss). """
        logits = self(inputs).reshape(-1)

        multiplier = 1/inputs.shape[0] if self.scale_loss else 1
        loss = multiplier*self.loss_fn(logits, targets.to(torch.float32))
        return logits, loss

    def _evaluate_batch(self, batch, stage):
        """ Compute and log loss, AUROC, AP and max-F1 for one batch under `ptl/{stage}_*`. """
        inputs, targets = batch
        logits, loss = self._compute_loss(inputs, targets)

        targets = targets.cpu().detach().numpy()
        scores = logits.cpu().detach().numpy()
        # max_f1_score sweeps thresholds in [0, 1], so it needs probabilities
        # rather than raw logits. AUROC and AP are rank-based and keep using
        # the logits, exactly as before.
        probs = torch.sigmoid(logits).cpu().detach().numpy()

        # NOTE: these sklearn metrics are computed per batch; roc_auc_score
        # raises ValueError when a batch contains only one class.
        auroc = roc_auc_score(targets, scores)
        ap = average_precision_score(targets, scores)
        max_f1 = max_f1_score(targets, probs)

        self.log(f"ptl/{stage}_loss", loss)
        self.log(f"ptl/{stage}_auroc", auroc)
        self.log(f"ptl/{stage}_ap", ap)
        self.log(f"ptl/{stage}_f1_score", max_f1)

    def training_step(self, batch, batch_idx):
        """ Batch-wise training function. Logs and returns the training loss. """
        inputs, targets = batch
        _, train_loss = self._compute_loss(inputs, targets)

        self.log("ptl/train_loss", train_loss, on_step=True, 
                 on_epoch=True, prog_bar=True, logger=True)
        return train_loss
        
    def validation_step(self, batch, batch_idx):
        """ Calculate and log batch-wise validation metrics. """
        self._evaluate_batch(batch, "val")
        
    def test_step(self, batch, batch_idx):
        """ Calculate and log batch-wise test metrics. """
        self._evaluate_batch(batch, "test")

    def configure_optimizers(self):
        """ Configure optimizer: plain SGD with the configured learning rate. """
        optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
        return optimizer


class FullyConnectedBinaryClassifier(LightningBinaryClassifier):
    def __init__(self, config=None):
        """
        Structure: Embedding -> Flatten -> MLP
        Hyperparameters (read from config, with defaults):
            seq_len: int (512)
            embed_size: int (64)
            vocab_size: int (591)
            fc_num_layers: int (5)
            dropout_rate: float (0.25)
            learning_rate: float (handled by the parent class)
        """
        # Avoid a shared mutable default argument; None means "all defaults".
        config = {} if config is None else config
        super().__init__(config=config)

        self.seq_len = config.get('seq_len', 512) # Based on ChemBERTa
        self.embed_size = config.get('embed_size', 64)
        self.vocab_size = config.get('vocab_size', 591) # Based on ChemBERTa
        self.dropout_rate = config.get('dropout_rate', 0.25)
        self.fc_num_layers = config.get('fc_num_layers', 5)
        self.fc_end_dim = 1
        self.pad_index = 0

        # Dummy batch of one random token-id sequence.
        self.example_input_array = torch.randint(low=0, 
                                                 high=self.vocab_size, 
                                                 size=(1, self.seq_len))

        self.embedding = torch.nn.Embedding(self.vocab_size, 
                                            self.embed_size, 
                                            padding_idx=self.pad_index)
        self.feature_extractor = torch.nn.Sequential(OrderedDict([('flatten', torch.nn.Flatten())]))

        # Probe the feature extractor with a dummy embedded sequence to find
        # the width of the flattened features fed to the MLP.
        self.fc_start_dim = self.feature_extractor(torch.zeros(1, self.seq_len, self.embed_size)).shape[1]

        # Layer sizes from fc_start_dim to fc_end_dim (= 1 output);
        # exact progression is decided by construct_fc_layers.
        self.fc_layer_sizes = construct_fc_layers(self.fc_start_dim, 
                                                  self.fc_num_layers, 
                                                  self.fc_end_dim)

        self.mlp = MultiLayerPerceptron(self.fc_layer_sizes, 
                                        dropout_rate=self.dropout_rate)
        
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """ Map a batch of token-id sequences to one logit per sequence. """
        x = self.embedding(x)
        x = self.feature_extractor(x)
        x = self.mlp(x)
        # Drop only the trailing singleton output dim (assumes the MLP emits
        # (batch, 1)). A bare squeeze() would also drop the batch dim when the
        # batch holds a single sample, yielding a 0-d tensor that cannot be
        # concatenated with other batches.
        return x.squeeze(-1)
    
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Return the sigmoid probability of the positive class for each sequence.

        `batch` may be a tensor of token ids or an (inputs, targets) pair as
        used by the training/validation steps; only the inputs are used.
        """
        inputs = batch[0] if isinstance(batch, (tuple, list)) else batch
        return self.sigmoid(self(inputs))

In [6]:
# Seed the global RNGs so the randomly initialised weights and the random
# token ids below are reproducible across kernel restarts.
L.seed_everything(42)

model = FullyConnectedBinaryClassifier({'seq_len': 128, 'fc_num_layers': 1})

# Random token ids shaped (3, 10, seq_len). The tensor is handed to
# trainer.predict in place of a dataloader; the recorded run shows it being
# consumed as 3 batches of 10 sequences. Sizes are taken from the model so
# they cannot drift from its configuration.
batch = torch.randint(0, model.vocab_size, (3, 10, model.seq_len))
trainer = L.Trainer(callbacks=[
    ModelCheckpoint(dirpath="../data/checkpoints/pytorch_lightning",
                    filename="fcnn_best_val_loss_{epoch}",
                    monitor="ptl/val_loss",
                    mode="min")
])

# Concatenate the per-batch predictions into one flat tensor.
test_predictions = torch.cat(trainer.predict(model, batch))

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Predicting DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 80.27it/s]


In [8]:
# Show the concatenated predict_step outputs (sigmoid scores) as a NumPy array.
test_predictions.detach().cpu().numpy()

array([0.42967474, 0.45906004, 0.43000174, 0.4293967 , 0.65963835,
       0.36255294, 0.43719235, 0.48620018, 0.32182297, 0.50228196,
       0.36033338, 0.6073804 , 0.4058569 , 0.5296254 , 0.46441895,
       0.43309945, 0.60983056, 0.57576287, 0.40438625, 0.5171854 ,
       0.6859662 , 0.6818627 , 0.60502887, 0.57953286, 0.4866189 ,
       0.5859223 , 0.62696993, 0.5145324 , 0.6307675 , 0.39782944],
      dtype=float32)