In [1]:
import os
from datetime import date
import timm
import wandb
from collections import Counter
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import HTML, display
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, DeviceStatsMonitor, ModelSummary
from PIL import Image
from torchvision import transforms

from torchvision.transforms import ToTensor

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.cuda.amp import GradScaler

import lightning.pytorch as pl
from lightning.fabric import Fabric

import torchmetrics

In [2]:

# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "./saved_models")

In [3]:
CHECKPOINT_PATH

'./saved_models'

In [4]:
# Function for setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

Global seed set to 42


device(type='cuda', index=0)

In [5]:
class CFG:
  MODEL_NAME = 'maxvit_tiny_tf_224'
  START_EPOCH = 0
  EPOCHS = 30

  TRAIN_BATCH = 32
  VAL_BATCH = 32
  
  PRINT_FREQ = 200
  
  WORKERS = 4

  DATADIR = "/data/home/ec2-user/broad/training_images/BBBC037/"
  TRAINDIR = DATADIR+"train_27k"
  VALDIR = DATADIR+"val_27k"
  TESTDIR = DATADIR+"test_27k"

  PRETRAINED = False
  IMAGE_SIZE = 224
  IN_CHANS = 5
  NUM_CLASSES = 13 #13 for 27k subset# 45

  ### optimizer
  LR = 0.01
  MOMENTUM = 0.9
  ADAM_EPSILON = 1e-6
  WEIGHT_DECAY = 1e-8 # for AdamW

  RANDOM_SEED = 42

  OUTPUT_DIR = '/home/ubuntu' + '/saved_models/' + str(date.today())
  CHECKPOINT_LAST = OUTPUT_DIR + '/' + MODEL_NAME + '/checkpoint-last'
  CHECKPOINT_BEST = OUTPUT_DIR + '/' + MODEL_NAME + '/checkpoint-best'

  WANDB_NOTEBOOK_NAME = str(date.today()) + '_' + MODEL_NAME + '_cjdonahoe'


In [6]:
# class_labels = [
#     'AKT1_E17K',
#     'AKT1_WT',
#     'BRAF_V600E',
#     'BRAF_WT',
#     'CDC42_Q61L',
#     'CDC42_T17N',
#     'CDC42_WT',
#     'KRAS_G12V',
#     'KRAS_WT',
#     'RAF1_L613V',
#     'RAF1_WT',
#     'RHOA_Q63L',
#     'RHOA_WT'
# ]

# for class_label in class_labels:
#     for i in ["train", "val", "test"]:
#         print(f"mkdir -p {i}_27k/{class_label}")

# # print one row for every class label and format the output like cp -R test/{class_label}/* ./test_27k/{class_label}/
# for class_label in class_labels:
#     for i in ["train", "val", "test"]:
#         print(f"cp -R {i}/{class_label}/* {i}_27k/{class_label}/")


## Weights & Biases

In [7]:
logger = pl.loggers.WandbLogger(project='cjdonahoe--cellvit/')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcjdonahoe[0m ([33mcellvit[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Define Functions

In [8]:
def get_class_weights(dataset):
    ''' Get class weights for a dataset
    Args:
        dataset: torch.utils.data.Dataset
    Returns:
        class_weights: torch.FloatTensor
    '''
    
    class_counts = Counter(dataset.targets)
    n_classes = len(class_counts.keys())
    total_count = len(dataset.targets)
    class_weights = list({class_id: total_count/(n_classes * class_counts) for class_id, class_counts in class_counts.items()}.values())
    class_weights = torch.FloatTensor(class_weights).cuda()
    return class_weights

In [9]:
class SplitTensorToFiveChannels(object):
    """Convert images in Pytorch Dataset to Tensors with one channel
    for each discrete fluerecent image in a Cell Painting sample."""
    def __call__(self, img):
        # select the first channel since the image is grayscale
        img = img[0,:,:]
        # split the image into the 6 channels and remove the last channel
        img = torch.tensor_split(img,6,dim=1)[:-1]
        # concatenate the 5 channels into a single tensor
        img = torch.stack(img, dim=0)
        return img

# Load Data

In [10]:
transform_train = transforms.Compose([
    transforms.RandomVerticalFlip(),
    transforms.Resize((CFG.IMAGE_SIZE, CFG.IMAGE_SIZE*6)),
    transforms.ToTensor(),
    SplitTensorToFiveChannels(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225, 0.225, 0.225]),
])

transform_val = transforms.Compose([
    transforms.Resize((CFG.IMAGE_SIZE, CFG.IMAGE_SIZE*6)),
    transforms.ToTensor(),
    SplitTensorToFiveChannels(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225, 0.225, 0.225]),
])

In [11]:
train_dataset = datasets.ImageFolder(
    CFG.TRAINDIR, transform=transform_train)

val_dataset = datasets.ImageFolder(
    CFG.VALDIR, transform=transform_val)

In [12]:
train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=CFG.TRAIN_BATCH, shuffle=True,
        num_workers=CFG.WORKERS, pin_memory=True, sampler=None)

val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=CFG.VAL_BATCH, shuffle=False,
        num_workers=CFG.WORKERS, pin_memory=True, sampler=None)



## Helpful Dataset Debug Functions

In [13]:
## Visualize an image from a batch
# trainBatch = next(iter(train_loader))

# fig = plt.figure(figsize= (15, 15))

# for label, minibatch in enumerate(trainBatch):
#     print(f'Label: {list(train_dataset.class_to_idx)[label]}') 
#     print(f'Label index: {label}')
#     for i in range(5):
#         img = minibatch[0][i].unsqueeze(0)
#         img = np.array(img.mul(255), dtype=np.uint8)
#         ax = fig.add_subplot(1, 5, i+1)
#         ax.imshow(img[0], cmap='gray')
#     break

In [14]:
# train_loader.dataset.class_to_idx

In [15]:
# print(dict(Counter(train_dataset.targets)))

In [16]:
# weights = get_class_weights(train_dataset)
# weights

## nn - MaxViT

In [17]:
class MaxViTModule(nn.Module):
    def __init__(self, checkpoint=None):
        super().__init__()
        self.model_name = CFG.MODEL_NAME
        self.model = timm.create_model(
            CFG.MODEL_NAME,
            in_chans=CFG.IN_CHANS,
            pretrained=CFG.PRETRAINED, 
            num_classes=CFG.NUM_CLASSES)

    def forward(self, x):
        x = self.model(x)
        return x
    
    def freeze(self):
        # To freeze the residual layers
        for param in self.model.parameters():
            param.requires_grad = False

        for param in self.model.head.parameters():
            param.requires_grad = True
    
    def unfreeze(self):
        # Unfreeze all layers
        for param in self.model.parameters():
            param.requires_grad = True

# Pytorch Lightning

## Define a Lightening Module

In [23]:
# define the LightningModule
class LitMaxViT(pl.LightningModule):
    def __init__(self, model_name, optimizer_name, optimizer_hparams):
        """
        Inputs:
            model_name - Name of the model/CNN to run. Used for creating the model (see function below)
            model_hparams - Hyperparameters for the model, as dictionary.
            optimizer_name - Name of the optimizer to use. Currently supported: Adam, SGD
            optimizer_hparams - Hyperparameters for the optimizer, as dictionary. This includes learning rate, weight decay, etc.
        """
        super().__init__()
        # Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
        self.save_hyperparameters()
        # Create model
        self.model = timm.create_model(model_name=CFG.MODEL_NAME, in_chans=CFG.IN_CHANS, pretrained=CFG.PRETRAINED, num_classes=CFG.NUM_CLASSES)
        # Create loss module
        self.loss_module = nn.CrossEntropyLoss()
        # Example input for visualizing the graph in Tensorboard
        self.example_input_array = torch.zeros((1, 5, 224, 224), dtype=torch.float32)
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=CFG.NUM_CLASSES)
        self.f1 = torchmetrics.classification.MulticlassF1Score(num_classes=CFG.NUM_CLASSES)
        # self.multiclassconfusionmatrix = torchmetrics.classification.MulticlassConfusionMatrix(num_classes=CFG.NUM_CLASSES, normalize='true')

    def forward(self, imgs):
        # Forward function that is run when visualizing the graph
        return self.model(imgs)

    def configure_optimizers(self):
        # We will support Adam or SGD as optimizers.
        if self.hparams.optimizer_name == "Adam":
            # AdamW is Adam with a correct implementation of weight decay (see here
            # for details: https://arxiv.org/pdf/1711.05101.pdf)
            optimizer = optim.AdamW(self.parameters(), **self.hparams.optimizer_hparams)
        elif self.hparams.optimizer_name == "SGD":
            optimizer = optim.SGD(self.parameters(), **self.hparams.optimizer_hparams)
        else:
            assert False, f'Unknown optimizer: "{self.hparams.optimizer_name}"'

        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        # "batch" is the output of the training data loader.
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)
        acc = self.accuracy(preds, labels)
        f1 = self.f1(preds, labels)
        # mc_conf_matrix = self.multiclassconfusionmatrix(preds, labels)

        # Logs the accuracy per epoch to tensorboard (weighted average over batches)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_f1", f1, on_epoch=True)
        # self.log("train_mc_conf_matrix", mc_conf_matrix, on_epoch=True)
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss  # Return tensor to call ".backward" on

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        acc = self.accuracy(preds, labels)
        f1 = self.f1(preds, labels)
        # mc_conf_matrix = self.multiclassconfusionmatrix(preds, labels)
        # By default logs it per epoch (weighted average over batches)
        self.log("val_acc", acc, on_epoch=True)
        self.log("val_f1", f1, on_epoch=True)
        # self.log("val_mc_conf_matrix", mc_conf_matrix, on_epoch=True)


In [19]:
# model = LitMaxViT(MaxViTModule(CFG))
# compiled_model = torch.compile(model)

In [20]:
model_dict = {}


def create_model(model_name, model_hparams):
    if model_name in model_dict:
        return model_dict[model_name](**model_hparams)
    else:
        assert False, f'Unknown model name "{model_name}". Available models are: {str(model_dict.keys())}'

In [21]:
def train_model(model_name, save_name=None, **kwargs):
    """
    Inputs:
        model_name - Name of the model you want to run. Is used to look up the class in "model_dict"
        save_name (optional) - If specified, this name will be used for creating the checkpoint and logging directory.
    """
    if save_name is None:
        save_name = model_name

    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(
        logger=logger,  # We log to wandb
        profiler="simple", 
        default_root_dir=os.path.join(CHECKPOINT_PATH, save_name),  # Where to save models
        # We run on a single GPU (if possible)
        # accelerator="auto",
        # devices=-1,
        # How many epochs to train for if no patience is set
        max_epochs=CFG.EPOCHS,
        callbacks=[
            ModelCheckpoint(
                save_weights_only=True, mode="max", monitor="val_acc"
            ),  # Save the best checkpoint based on the maximum val_acc recorded. Saves only weights and not optimizer
            LearningRateMonitor("epoch"),
            ModelSummary(max_depth=1),
        ],  # Log learning rate every epoch
    )  # In case your notebook crashes due to the progress bar, consider increasing the refresh rate
    trainer.logger._log_graph = True  # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, save_name + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        # Automatically loads the model with the saved hyperparameters
        model = LitMaxViT.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)  # To be reproducable
        model = LitMaxViT(model_name=model_name, **kwargs)
        trainer.fit(model, train_loader, val_loader)
        model = LitMaxViT.load_from_checkpoint(
            trainer.checkpoint_callback.best_model_path
        )  # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    # test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    result = {"val": val_result[0]["test_acc"]}

    return model, result

In [22]:
torch.set_float32_matmul_precision('high')
model_dict["MaxViT"] = MaxViTModule

maxvit_model, maxvit_results = train_model(
    model_name=CFG.MODEL_NAME,
    # model_hparams={"num_classes": CFG.NUM_CLASSES},
    optimizer_name="Adam",
    optimizer_hparams={"lr": CFG.LR, "weight_decay": CFG.WEIGHT_DECAY},
)

Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 42
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type               | Params | In sizes         | Out sizes
----------------------------------------------------------------------------------
0 | model       | MaxxVit            | 30.4 M | [1, 5, 224, 224] | [1, 13]  
1 | loss_module | CrossEntropyLoss   | 0      | ?                | ?        
2 | accuracy    | MulticlassAccuracy | 0      | ?                | ?        
3 | f1          | MulticlassF1Score  | 0      | ?                | ?        
----------------------------------------------------------------------------------
30.4 M    Trainable params
0         Non-trainable para

Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


MisconfigurationException: No `test_step()` method defined to run `Trainer.test`.