# init

## Import Packages

In [1]:
import os
import time
# import warnings
import pandas as pd
import numpy as np
import cv2
import copy 
import matplotlib.pyplot as plt
import wandb
import uuid
import tempfile
from datetime import date

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader

# import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import timm
from tqdm import tqdm

# Config

In [2]:
cudnn.deterministic = True
cudnn.benchmark = True


In [3]:
# Print num GPUs available
print(f"GPU(s) available: {torch.cuda.device_count()}") 
# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")


GPU(s) available: 1
Device: cuda:0


In [4]:
print(
    len(next(os.walk("/data/home/ec2-user/broad/training_images/BBBC037"+"/train"))[1]),
    len(next(os.walk("/data/home/ec2-user/broad/training_images/BBBC037"+"/val"))[1]),
    len(next(os.walk("/data/home/ec2-user/broad/training_images/BBBC037/"+"/test"))[1]))

47 47 47


In [5]:
class CFG:
  # Set up data directories
  data_dir="/data/home/ec2-user/broad/training_images/BBBC037/"
  train_dir=data_dir+"/train"
  val_dir=data_dir+"/val"
  test_dir=data_dir+"/test"
  debug = False
  n_gpu = 1
  # device = "cpu" # ['cpu', 'mps']
  img_size = 224
  ### total # of classes in this dataset
  num_classes = 47
  ### model
  model_name = 'maxvit_small_224'
  checkpoint = 'maxvit_small_224'
  pretrained = False
  batch_size = 64
  num_epochs = 20
  in_chans = 5

  ### set only one to True
  save_best_loss = False
  save_best_accuracy = True
  adam_epsilon = 1e-6
  initial_lr = 0.1

  verbose = True

  ### train and validation DataLoaders
  num_workers = 8

  random_seed = 42

  output_dir = '/home/ubuntu' + '/saved_models_cj/' + str(date.today())
  checkpoint_last = output_dir + '/' + model_name + '/checkpoint-last'
  checkpoint_best = output_dir + '/' + model_name + '/checkpoint-best'

## Weights & Biases

In [6]:
os.environ['WANDB_API_KEY']='e2b77d7240d4c1ceee8264dbfbea27d2f30d5331'

class WandBLogger(object):
    def __init__(self, variant, project, prefix=''):
      """
      Args:
        variant: dictionary of hyperparameters
        project: name of project
      """
      log_dir = tempfile.mkdtemp()
      if prefix != '':
          project = '{}--{}'.format(prefix, project)

      wandb.init(
          config=variant,
          project=project,
          dir=log_dir,
          id=uuid.uuid4().hex,
      )

    def log(self, *args, **kwargs):
      wandb.log(*args, **kwargs)

wblogger = WandBLogger(
    variant={
      'initial_learning_rate': CFG.initial_lr,
      'adam_epsilon': CFG.adam_epsilon,
      'num_epochs': CFG.num_epochs,
      'batch_size': CFG.batch_size
    },
    project=f'cellvit',
    prefix='cjdonahoe'
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcjdonahoe[0m ([33mcellvit[0m). Use [1m`wandb login --relogin`[0m to force relogin


# MaxVitClassifier

In [7]:
class MaxVitClassifier(nn.Module):
    def __init__(self, cfg, checkpoint=None):
        super().__init__()
        self.model_name = cfg.model_name
        self.model = timm.create_model(
            cfg.model_name,
            in_chans=cfg.in_chans,
            pretrained=cfg.pretrained, 
            num_classes=cfg.num_classes)
        # n_features = self.model.head.in_features
        # self.model.head = nn.Linear(n_features, num_classes)
        # self.model.fc = nn.Linear(n_features, num_classes)
        if checkpoint:
          self.model.load_state_dict(torch.load(checkpoint), strict=False)

    def forward(self, x):
        x = self.model(x)
        return x
    
    def freeze(self):
        # To freeze the residual layers
        for param in self.model.parameters():
            param.requires_grad = False

        for param in self.model.head.parameters():
            param.requires_grad = True
    
    def unfreeze(self):
        # Unfreeze all layers
        for param in self.model.parameters():
            param.requires_grad = True

# Load Data

In [8]:
class SplitTensorToFiveChannels(object):
    """Convert images in Pytorch Dataset to Tensors with one channel
    for each discrete fluerecent image in a Cell Painting sample."""

    def __call__(self, img):
        # select the first channel since the image is grayscale
        img = img[0,:,:]
        # split the image into the 6 channels and remove the last channel
        img = torch.tensor_split(img,6,dim=1)[:-1]
        # concatenate the 5 channels into a single tensor
        img = torch.stack(img, dim=0)
        return img


In [9]:
# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Resize(CFG.img_size),
        transforms.ToTensor(),
        SplitTensorToFiveChannels()#,
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(CFG.img_size),
        transforms.ToTensor(),
        SplitTensorToFiveChannels()
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(CFG.img_size),
        transforms.ToTensor(),
        SplitTensorToFiveChannels()
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [10]:
train_dataset = datasets.ImageFolder(CFG.train_dir, data_transforms['train'])
val_dataset = datasets.ImageFolder(CFG.val_dir, data_transforms['val'])
test_dataset = datasets.ImageFolder(CFG.test_dir, data_transforms['test'])

dataloaders = {}
dataloaders['train'] = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers, pin_memory=True)
dataloaders['val'] = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers, pin_memory=True)
dataloaders['test'] = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers, pin_memory=True)

In [11]:
print(
    f"Train dataset size: {len(train_dataset)}",
    f"Val dataset size: {len(val_dataset)}",
    f"Test dataset size: {len(test_dataset)}",
    sep = "\n"
)

Train dataset size: 179758
Val dataset size: 51348
Test dataset size: 25718


In [12]:
class_names = train_dataset.classes
print(class_names)


['AKT1_E17K', 'AKT1_WT', 'ARAF_WT', 'ATF2_WT', 'ATF6_1-373', 'BCL2L11_WT', 'BRAF_V600E', 'BRAF_WT', 'CASP8_WT', 'CCND1_WT', 'CDC42_Q61L', 'CDC42_T17N', 'CDC42_WT', 'CDKN1A_WT', 'CEBPA_WT', 'CSNK1E_WT', 'CTNNB1_WT', 'CXXC4_WT', 'E2F1_WT', 'ELK1_WT', 'EMPTY', 'ERBB2_WT', 'GSK3B_WT', 'HRAS_G12V', 'JUN_WT', 'KRAS_G12V', 'KRAS_WT', 'MAP2K1_WT', 'MAP3K2_WT', 'MAP3K9_WT', 'MAPK1_WT', 'MYD88_WT', 'NOTCH1_ICN1', 'PIK3CA_WT', 'PPARGC1A_WT', 'PRKACA_WT', 'PRKCE_WT', 'PTEN_WT', 'RAC1_Q61L', 'RAF1_L613V', 'RAF1_WT', 'RB1_WT', 'RHOA_Q63L', 'RHOA_WT', 'SMAD4_WT', 'STK11_WT', 'XBP1_WT']


# Train Model

In [13]:
import random
def set_seed(cfg):
    random.seed(cfg.random_seed)
    np.random.seed(cfg.random_seed)
    torch.manual_seed(cfg.random_seed)
    if cfg.n_gpu > 0:
        torch.cuda.manual_seed_all(cfg.random_seed)

def train_model(cfg, model, dataloaders, criterion, optimizer, scheduler):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0


    last_checkpoint_path = CFG.checkpoint_last
    last_scheduler_path = os.path.join(last_checkpoint_path, 'scheduler.pt')
    last_optimizer_path = os.path.join(last_checkpoint_path, 'optimizer.pt')
    best_checkpoint_path = CFG.checkpoint_best
    best_scheduler_path = os.path.join(best_checkpoint_path, 'scheduler.pt')
    best_optimizer_path = os.path.join(best_checkpoint_path, 'optimizer.pt')

    for epoch in range(cfg.num_epochs):
        print('Epoch {}/{}'.format(epoch, cfg.num_epochs - 1))
        print('-' * 10)

        wblogdict = {}

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.autocast(device_type='cuda'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)


            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            wblogdict[f'{phase}/loss'] = np.round(epoch_loss, 4)
            wblogdict[f'{phase}/acc'] = np.round(epoch_acc.cpu(), 3)

            # if phase == "train":
            #   wblogdict['train/learning_rate'] = CFG.learning_rate

            if not os.path.exists(last_checkpoint_path):
                os.makedirs(last_checkpoint_path)
            
            torch.save(model.state_dict(), last_checkpoint_path + f"/MaxVitModel_ep{epoch_acc}.pth")
            torch.save(optimizer.state_dict(), last_optimizer_path)

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            
                if not os.path.exists(best_checkpoint_path):
                    os.makedirs(best_checkpoint_path)

                torch.save(model.state_dict(), best_checkpoint_path + f"/MaxVitModel_ep{best_acc}.pth")
                torch.save(optimizer.state_dict(), best_optimizer_path)
  
            if phase == 'val':
                val_acc_history.append(epoch_acc)

        wblogger.log(wblogdict)
        print()

        scheduler.step()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

In [14]:
GPU = 0

# model.cuda(GPU)

criterion = nn.CrossEntropyLoss().cuda(GPU)

model_ft = MaxVitClassifier(CFG)
model_ft = model_ft.to(device)

params_to_update = model_ft.parameters()

optimizer = torch.optim.AdamW(
    model_ft.parameters(), 
    lr=CFG.initial_lr, 
    eps=CFG.adam_epsilon)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.num_epochs, eta_min=1e-8)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [15]:
model_ft, hist = train_model(CFG, model_ft, dataloaders, criterion, optimizer, scheduler)

Epoch 0/19
----------


100%|██████████| 2809/2809 [22:53<00:00,  2.04it/s]


train Loss: nan Acc: 0.0146


  0%|          | 1/803 [00:03<47:21,  3.54s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 148.00 MiB (GPU 0; 22.06 GiB total capacity; 20.28 GiB already allocated; 40.38 MiB free; 20.66 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF