In [1]:
# !pip install torch torchvision
!pip install datasets
!pip install transformers
!pip install --pre timm
!pip install wandb
!pip install opencv-python

Collecting transformers
  Downloading transformers-4.27.3-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m65.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m116.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting regex!=2019.12.17
  Downloading regex-2023.3.23-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (768 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m769.0/769.0 kB[0m [31m84.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, regex, transformers
Successfully installed regex-2023.3.23 tokenizers-0.13.2 transformers-4.27.3
Collecting timm
  Downloading timm-0.8.17.dev0-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━

In [1]:
from __future__ import print_function, division
import os
import torch
import timm
import pandas as pd
# from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch.optim import Adam, AdamW

from datasets import Dataset
# from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import time
import copy
import cv2
import wandb
import uuid
import tempfile
from datetime import datetime, date

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

<matplotlib.pyplot._IonContext at 0x7f974842a2e0>

In [2]:
# Filesystem layout for the notebook. The defaults are the original hard-coded
# locations; set the CELL_PAINTING_DIR / MODEL_OUTPUT_DIR environment variables
# to run on a machine with a different layout without editing this cell.
CELL_PAINTING_DIR = os.environ.get('CELL_PAINTING_DIR', '/home/ubuntu/src')
# NOTE(review): /dev/shm is normally a RAM-backed filesystem, so checkpoints
# written here presumably do not survive a reboot — confirm this is intended.
MODEL_OUTPUT_DIR = os.environ.get('MODEL_OUTPUT_DIR', '/dev/shm')
# Parquet metadata and the train/dev splits are read from / written to this directory.
CACHE_DIR = CELL_PAINTING_DIR + "/cache"
# Root under which the image paths built by add_image_paths() live.
BROAD_DIR = CELL_PAINTING_DIR + '/data/cpg0019-moshkov-deepprofiler/broad'

In [3]:
# Map every treatment label to its integer class id. The id of a label is its
# position in the list below, so the order of the entries is significant.
class_labels_to_int = {
    label: class_id
    for class_id, label in enumerate([
        'AKT1_E17K',
        'AKT1_WT',
        'BRAF_V600E',
        'BRAF_WT',
        'CDC42_Q61L',
        'CDC42_T17N',
        'CDC42_WT',
        'EMPTY',
        'KRAS_G12V',
        'KRAS_WT',
        'RAF1_L613V',
        'RAF1_WT',
        'RHOA_Q63L',
        'RHOA_WT',
    ])
}
print(len(class_labels_to_int))

14


### Prepare train/dev dataset (Optional)

In [5]:
def add_image_paths(row, delim='../../', prefix=BROAD_DIR):
    """Build the on-disk image path for one metadata row.

    Everything in ``row.Image_Name`` after the first occurrence of ``delim`` is
    appended to ``prefix`` (joined with "/").

    Args:
        row: object exposing an ``Image_Name`` string attribute (e.g. a
            DataFrame row passed by ``DataFrame.apply(..., axis=1)``).
        delim: marker separating the relative prefix from the path suffix.
        prefix: directory the suffix is re-rooted under.

    Returns:
        The re-rooted path as a string.

    Raises:
        ValueError: if ``delim`` does not occur in ``row.Image_Name``.
    """
    # partition() splits on the first occurrence only, so a name containing the
    # delimiter more than once no longer fails with an opaque unpacking error.
    _, found, path_suffix = row.Image_Name.partition(delim)
    if not found:
        raise ValueError(
            f"Image_Name {row.Image_Name!r} does not contain the delimiter {delim!r}"
        )
    return prefix + "/" + path_suffix

def do_the_thing(dataset_name: str):
    """Prepare the cached train/dev parquet files for one metadata dataset.

    Steps, in order:
      1. Read ``{CACHE_DIR}/{dataset_name}_metadata.parquet``.
      2. Add a ``Path`` column (via ``add_image_paths``) and an integer
         ``Labels`` column (``Treatment`` mapped through ``class_labels_to_int``).
      3. Pass the frame through ``DataFetcher.clean_data`` and
         ``DataFetcher.create_train_dev_split`` (which is given CACHE_DIR).
      4. For the train and dev parquet files under CACHE_DIR, apply
         ``DataFetcher.repeat_channels``, drop the ``__index_level_0__``
         column and overwrite the file.

    Args:
        dataset_name: stem of the metadata parquet file inside CACHE_DIR.

    Returns:
        None. The results are the files written under CACHE_DIR.
    """
    # Project-local helper, imported lazily so the notebook can be opened
    # without it when this optional preparation step is skipped.
    from data_fetcher import DataFetcher

    dataset = pd.read_parquet(CACHE_DIR + f"/{dataset_name}_metadata.parquet", engine="pyarrow")

    # The original body loaded into `dataset` but then used an undefined
    # `c037_27k_dataset`; the loaded frame is now used throughout.
    dataset['Path'] = dataset.apply(add_image_paths, axis=1)
    dataset['Labels'] = dataset["Treatment"].replace(to_replace=class_labels_to_int)

    dataset_clean = DataFetcher.clean_data(dataset)

    DataFetcher.create_train_dev_split(dataset_clean, CACHE_DIR)

    # Add channel information to each train and dev dataset
    for dataset_t in ["train", "dev"]:
        dataset_file = CACHE_DIR + f"/{dataset_t}/data.parquet"
        dataset_df = pd.read_parquet(dataset_file, engine="pyarrow")
        dataset_df = DataFetcher.repeat_channels(dataset_df)
        # NOTE(review): raises KeyError if the column is absent, e.g. when this
        # is re-run on a file that was already rewritten — confirm desired.
        dataset_df = dataset_df.drop(columns=['__index_level_0__'])
        dataset_df.to_parquet(dataset_file, engine="pyarrow")
        print(f"Added channels to {dataset_file}, shape is: {dataset_df.shape}")


In [8]:
# NOTE(review): `c037_27k_dataset` is not assigned anywhere in the visible
# notebook, so this cell fails on a fresh kernel; it presumably relied on a
# loading cell that has since been removed — confirm and restore it.
# Resolve each row's Image_Name to a path under BROAD_DIR.
c037_27k_dataset['Path'] = c037_27k_dataset.apply(lambda x: add_image_paths(x), axis=1)
# Encode the Treatment strings as integer class ids using class_labels_to_int.
c037_27k_dataset['Labels'] = c037_27k_dataset["Treatment"].replace(
          to_replace=class_labels_to_int)
print(c037_27k_dataset.shape)
c037_27k_dataset.head()

(27448, 16)


Unnamed: 0,Collection,Metadata_Plate,Metadata_Well,Metadata_Site,Nuclei_Location_Center_X,Nuclei_Location_Center_Y,Image_Name,Treatment,Treatment_Type,Control,Cell_line,LeaveReplicatesOut,LeaveCellsOut,PathId,Path,Labels
8255825,BBBC037,41754,l09,2,776,538,../../training_images/BBBC037/41754/l09/2/105@...,EMPTY,ORF,Control,U2OS,Training,NotUsed,105@776x538.png,/home/ubuntu/src/data/cpg0019-moshkov-deepprof...,7
8314697,BBBC037,41755,m16,2,522,724,../../training_images/BBBC037/41755/m16/2/105@...,EMPTY,ORF,Control,U2OS,Training,NotUsed,105@522x724.png,/home/ubuntu/src/data/cpg0019-moshkov-deepprof...,7
8321958,BBBC037,41755,p23,5,953,186,../../training_images/BBBC037/41755/p23/5/105@...,EMPTY,ORF,Control,U2OS,Training,NotUsed,105@953x186.png,/home/ubuntu/src/data/cpg0019-moshkov-deepprof...,7
8264069,BBBC037,41754,o21,3,525,56,../../training_images/BBBC037/41754/o21/3/105@...,EMPTY,ORF,Control,U2OS,Training,NotUsed,105@525x56.png,/home/ubuntu/src/data/cpg0019-moshkov-deepprof...,7
8360445,BBBC037,41756,k08,8,995,81,../../training_images/BBBC037/41756/k08/8/105@...,EMPTY,ORF,Control,U2OS,NotUsed,NotUsed,105@995x81.png,/home/ubuntu/src/data/cpg0019-moshkov-deepprof...,7


In [9]:
# NOTE(review): `DataFetcher` is only imported inside do_the_thing(), so it is
# not in scope here on a fresh kernel — confirm where this cell gets it from.
# The recorded output shows the same shape before and after clean_data.
c037_27k_dataset_clean = DataFetcher.clean_data(c037_27k_dataset)
c037_27k_dataset_clean.shape

(27448, 16)

In [10]:
# Write the train/dev split of the cleaned frame; the recorded output reports
# the files being saved under CACHE_DIR/dev and CACHE_DIR/train.
DataFetcher.create_train_dev_split(c037_27k_dataset_clean, CACHE_DIR)

Saved dev data with shape (5490, 17) to /home/ubuntu/src/cache/dev
Saved train data with shape (21958, 17) to /home/ubuntu/src/cache/train


In [11]:
# Add channel information to each train and dev dataset
# NOTE(review): this loop rewrites the parquet files in place, so running the
# cell twice operates on already-processed files; the recorded output shows the
# row count growing 5x after repeat_channels. `DataFetcher` must already be in
# scope (it is only imported inside do_the_thing()).

for dataset_t in ["train", "dev"]:
  dataset_file = CACHE_DIR + f"/{dataset_t}/data.parquet"
  dataset_df = pd.read_parquet(dataset_file, engine="pyarrow")
  dataset_df = DataFetcher.repeat_channels(dataset_df)
  # Drop the index column that the parquet file carries.
  dataset_df = dataset_df.drop(columns=['__index_level_0__'])
  dataset_df.to_parquet(dataset_file, engine="pyarrow")
  print(f"Added channels to {dataset_file}, shape is: {dataset_df.shape}")

Added channels to /home/ubuntu/src/cache/train/data.parquet, shape is: (109790, 17)
Added channels to /home/ubuntu/src/cache/dev/data.parquet, shape is: (27450, 17)


### Create a custom pytorch Dataset

### Check to see if custom Dataset is working

In [None]:
# training_data = MaxVitDataset(CFG, "train")
# dev_data = MaxVitDataset(CFG, "dev")

In [None]:
# print("Training and dev data sizes")
# print(len(training_data))
# print(len(dev_data))
# print("Training and dev data at idx")
# print(training_data[0][0].shape)
# print(training_data[0][0].dtype)
# print(training_data[0][0])
# print(training_data[0][1])

### Inspect a subset of images

In [None]:
import random
from typing import List, Optional

def display_random_images(dataset: torch.utils.data.dataset.Dataset,
                          classes: Optional[List[str]] = None,
                          n: int = 10,
                          display_shape: bool = True,
                          seed: Optional[int] = None):
    """Plot a row of randomly chosen samples from ``dataset``.

    Args:
        dataset: indexable, sized collection whose items start with
            ``(image, label)``; the image must support ``permute(1, 2, 0)``.
        classes: optional list of class names. When given, the title shows
            ``classes[label]`` instead of the raw label.
        n: number of samples to draw. Capped at 10 (which also turns the shape
            display off) and at ``len(dataset)``.
        display_shape: whether to append the permuted image shape to the title.
        seed: seed for the ``random`` module; ``None`` leaves the RNG untouched.

    Returns:
        None. The images are drawn on a new matplotlib figure.
    """
    # Adjust display if n too high
    if n > 10:
        n = 10
        display_shape = False
        print(f"For display purposes, n shouldn't be larger than 10, setting to 10 and removing shape display.")

    # random.sample raises if asked for more items than the dataset holds.
    n = min(n, len(dataset))

    # `is not None` so that seed=0 is honoured (a truthiness test skipped it).
    if seed is not None:
        random.seed(seed)

    # Get random sample indexes
    random_samples_idx = random.sample(range(len(dataset)), k=n)

    # Setup plot
    plt.figure(figsize=(16, 8))

    # Loop through samples and display them
    for i, targ_sample in enumerate(random_samples_idx):
        # Index the dataset once: every __getitem__ call may load and
        # transform the image again.
        sample = dataset[targ_sample]
        targ_image, targ_label = sample[0], sample[1]

        # Adjust image tensor shape for plotting:
        # [color_channels, height, width] -> [height, width, color_channels]
        targ_image_adjust = targ_image.permute(1, 2, 0)

        # Plot adjusted samples
        plt.subplot(1, n, i+1)
        plt.imshow(targ_image_adjust)
        plt.axis("off")
        shown_label = classes[targ_label] if classes else targ_label
        title = f"class: {shown_label}"
        if display_shape:
            title = title + f"\nshape: {targ_image_adjust.shape}"
        plt.title(title)

In [None]:
# display_random_images(training_data, n=5)

### Fine Tuning With Timm

In [12]:
# List the timm model names matching the "**maxvit**" wildcard, restricted to
# entries with pretrained weights; used to pick CFG.model_name below.
timm.list_models("**maxvit**", pretrained=True)

['maxvit_base_tf_224.in1k',
 'maxvit_base_tf_384.in1k',
 'maxvit_base_tf_384.in21k_ft_in1k',
 'maxvit_base_tf_512.in1k',
 'maxvit_base_tf_512.in21k_ft_in1k',
 'maxvit_large_tf_224.in1k',
 'maxvit_large_tf_384.in1k',
 'maxvit_large_tf_384.in21k_ft_in1k',
 'maxvit_large_tf_512.in1k',
 'maxvit_large_tf_512.in21k_ft_in1k',
 'maxvit_nano_rw_256.sw_in1k',
 'maxvit_rmlp_base_rw_224.sw_in12k',
 'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k',
 'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k',
 'maxvit_rmlp_nano_rw_256.sw_in1k',
 'maxvit_rmlp_pico_rw_256.sw_in1k',
 'maxvit_rmlp_small_rw_224.sw_in1k',
 'maxvit_rmlp_tiny_rw_256.sw_in1k',
 'maxvit_small_tf_224.in1k',
 'maxvit_small_tf_384.in1k',
 'maxvit_small_tf_512.in1k',
 'maxvit_tiny_rw_224.sw_in1k',
 'maxvit_tiny_tf_224.in1k',
 'maxvit_tiny_tf_384.in1k',
 'maxvit_tiny_tf_512.in1k',
 'maxvit_xlarge_tf_384.in21k_ft_in1k',
 'maxvit_xlarge_tf_512.in21k_ft_in1k']

In [4]:
class CFG:
  """Static configuration namespace for the fine-tuning run.

  All values are class attributes, so the class itself is passed around
  (e.g. ``train_model(CFG, ...)``) and never instantiated.
  """
  data_dir = CACHE_DIR
  debug = False
  n_gpu = 1
  # Fallback device for machines without CUDA. This attribute was commented
  # out, yet the data-loading cell reads CFG.device when CUDA is unavailable,
  # which raised AttributeError.
  device = "cpu" # ['cpu', 'mps']
  img_size = 224
  ### total # of classes in this dataset
  num_classes = len(class_labels_to_int)
  ### model
  model_name = 'maxvit_large_tf_224'
  checkpoint = 'maxvit_large_tf_224'
  pretrained=True
  batch_size = 20
  num_epochs = 30
  num_workers = 16

  ### set only one to True
  save_best_loss = False
  save_best_accuracy = True

  optimizer = 'adamw' # ["rmsprop", "adam"]
  learning_rate = 5e-5
  adam_epsilon = 1e-6
  weight_decay = 1e-8 # for adamw
  l2_penalty = 0.01 # for RMSprop
  rms_momentum = 0 # for RMSprop

  ### learning rate scheduler (LRS)
  scheduler = 'ReduceLROnPlateau' # []
  # scheduler = 'CosineAnnealingLR'
  plateau_factor = 0.5
  plateau_patience = 3
  cosine_T_max = 4
  cosine_eta_min = 1e-8
  verbose = True

  ### train and validation DataLoaders
  shuffle = True

  random_seed = 42

  # Checkpoints go under <MODEL_OUTPUT_DIR>/<date>/<model_name>/. The date is
  # evaluated once, when this class body runs.
  output_dir = MODEL_OUTPUT_DIR + '/' + str(date.today())
  checkpoint_last = output_dir + '/' + model_name + '/checkpoint-last'
  checkpoint_best = output_dir + '/' + model_name + '/checkpoint-best'

In [5]:
# SECURITY: a Weights & Biases API key used to be hard-coded on this line.
# Secrets must never be stored in a notebook; the key that was published here
# should be revoked and replaced.
# wandb.login() uses the WANDB_API_KEY environment variable or previously
# stored credentials when available, and otherwise prompts for the key.
wandb.login()

class WandBLogger(object):
    """Starts a Weights & Biases run on construction and forwards log calls."""

    def __init__(self, variant, project, prefix=''):
        """
        Args:
          variant: dictionary of hyperparameters, stored as the run config
          project: name of project
          prefix: optional string; when non-empty the project becomes
            "<prefix>--<project>"
        """
        full_project = project if prefix == '' else '{}--{}'.format(prefix, project)

        # Each run gets a fresh temporary directory and a random hex id.
        wandb.init(
            config=variant,
            project=full_project,
            dir=tempfile.mkdtemp(),
            id=uuid.uuid4().hex,
        )

    def log(self, *args, **kwargs):
        """Pass every argument through unchanged to ``wandb.log``."""
        wandb.log(*args, **kwargs)

# Module-level logger used by the training loop; the run config records the
# main hyperparameters taken from CFG.
wblogger = WandBLogger(
    variant=dict(
        initial_learning_rate=CFG.learning_rate,
        adam_epsilon=CFG.adam_epsilon,
        num_epochs=CFG.num_epochs,
        batch_size=CFG.batch_size,
    ),
    project=f'cellvit-{CFG.model_name}',
    prefix='Can-c037_27k',
)

[34m[1mwandb[0m: Currently logged in as: [33mcankoc[0m ([33mcellvit[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
from maxvit_dataset import MaxVitDataset

In [7]:
class MaxVitClassifier(nn.Module):
    """Wraps a timm model, with optional checkpoint loading and head-only freezing."""

    def __init__(self, cfg, checkpoint=None):
        """Create the timm backbone and optionally restore weights.

        Args:
            cfg: object exposing ``model_name``, ``pretrained`` and ``num_classes``.
            checkpoint: optional path to a state dict saved with ``torch.save``.
                Accepts both a state dict of the inner timm model and one saved
                from this wrapper (keys prefixed with "model.").
        """
        super().__init__()
        self.model_name = cfg.model_name
        self.model = timm.create_model(cfg.model_name,
                                       pretrained=cfg.pretrained,
                                       num_classes=cfg.num_classes)
        if checkpoint:
            # map_location lets a checkpoint written on GPU load on any machine;
            # the caller moves the module to its device afterwards.
            state_dict = torch.load(checkpoint, map_location="cpu")

            # state_dict() of this wrapper prefixes every key with "model.".
            # Loading such a file into self.model with strict=False used to
            # match nothing and silently keep the initial weights.
            prefix = "model."
            if state_dict and all(key.startswith(prefix) for key in state_dict):
                state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}

            result = self.model.load_state_dict(state_dict, strict=False)
            # strict=False is kept for backward compatibility, but mismatches
            # are now reported instead of passing unnoticed.
            if result.missing_keys or result.unexpected_keys:
                print(f"Checkpoint {checkpoint}: "
                      f"{len(result.missing_keys)} missing and "
                      f"{len(result.unexpected_keys)} unexpected keys")

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def freeze(self):
        """Disable gradients everywhere except ``self.model.head``."""
        for param in self.model.parameters():
            param.requires_grad = False

        # Assumes the timm model exposes a `head` submodule — TODO confirm for
        # model names other than the MaxViT family.
        for param in self.model.head.parameters():
            param.requires_grad = True

    def unfreeze(self):
        """Enable gradients for all parameters of the wrapped model."""
        for param in self.model.parameters():
            param.requires_grad = True

In [8]:
# Data augmentation and normalization for training
# Just normalization for validation
# The mean/std triples are the channel statistics commonly used with
# ImageNet-pretrained models.

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'dev': transforms.Compose([
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

image_datasets = {x: MaxVitDataset(CFG, split=x,
                                   transform=data_transforms[x])
                  for x in ['train', 'dev']}
# Only the training loader is shuffled, and it honours CFG.shuffle (previously
# ignored). Dev metrics are sums over the whole split, so evaluation order does
# not affect them.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=CFG.batch_size,
                                              num_workers=CFG.num_workers,
                                              shuffle=CFG.shuffle if x == 'train' else False)
              for x in ['train', 'dev']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'dev']}
# class_names = image_datasets['train'].classes
print(f"Dataset sizes: {dataset_sizes}")

# Fall back to "cpu" when CFG defines no `device`; reading CFG.device directly
# raised AttributeError on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else getattr(CFG, "device", "cpu"))
print(device)

Dataset sizes: {'train': 109790, 'dev': 27450}
cuda:0


In [9]:
import random
def set_seed(cfg):
    """Seed Python's ``random``, NumPy and PyTorch from ``cfg.random_seed``.

    When ``cfg.n_gpu`` is positive, the CUDA generators are seeded as well.
    """
    seed_value = cfg.random_seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    if cfg.n_gpu > 0:
        torch.cuda.manual_seed_all(seed_value)

def train_model(cfg, model, dataloaders, criterion, optimizer):
    """Train ``model`` for ``cfg.num_epochs`` epochs, evaluating on 'dev' each epoch.

    Relies on three module-level names: ``device`` (batches are moved to it),
    ``tqdm`` (progress bar) and ``wblogger`` (one ``log`` call per epoch).

    Args:
        cfg: object exposing ``num_epochs``, ``checkpoint_last`` and
            ``checkpoint_best``.
        model: the ``nn.Module`` to train; modified in place.
        dataloaders: dict with 'train' and 'dev' loaders yielding
            ``(inputs, labels)`` batches and exposing ``.dataset``.
        criterion: loss callable taking ``(outputs, labels)``.
        optimizer: optimizer stepping ``model``'s parameters.

    Returns:
        ``(model, val_acc_history)``: the model with the best dev-accuracy
        weights loaded, and the list of per-epoch dev accuracies.

    Side effects:
        Whenever dev accuracy improves, the model and optimizer state dicts
        are saved under ``cfg.checkpoint_best``.
    """
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # Paths come from the `cfg` argument; the body previously read the global
    # CFG, silently ignoring whatever configuration the caller passed in.
    last_checkpoint_path = cfg.checkpoint_last
    best_checkpoint_path = cfg.checkpoint_best
    best_optimizer_path = os.path.join(best_checkpoint_path, 'optimizer.pt')

    for epoch in range(cfg.num_epochs):
        print('Epoch {}/{}'.format(epoch, cfg.num_epochs - 1))
        print('-' * 10)

        wblogdict = {}

        # Each epoch has a training and validation phase
        for phase in ['train', 'dev']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; gradients are tracked only in the training phase
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics: weight the batch loss by batch size so the epoch
                # loss is a per-sample mean even with a smaller last batch
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            num_samples = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects.double() / num_samples

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Log plain Python floats rather than a tensor / NumPy scalar.
            wblogdict[f'{phase}/loss'] = round(float(epoch_loss), 4)
            wblogdict[f'{phase}/acc'] = round(float(epoch_acc), 4)

            if is_train:
                # Report the rate the optimizer is actually using (first param
                # group), so the value stays correct if a scheduler is enabled.
                wblogdict['train/learning_rate'] = optimizer.param_groups[0]['lr']

            # NOTE(review): the directory is created but nothing is written to
            # it — the "last checkpoint" saves are currently disabled.
            os.makedirs(last_checkpoint_path, exist_ok=True)

            # torch.save(model.state_dict(), last_checkpoint_path + f"/MaxVitModel_ep{epoch_acc}.pth")
            # torch.save(optimizer.state_dict(), last_optimizer_path)

            # deep copy the model
            if phase == 'dev' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

                os.makedirs(best_checkpoint_path, exist_ok=True)

                # NOTE(review): the "ep" in the file name is followed by the
                # dev accuracy, not the epoch number.
                torch.save(model.state_dict(), best_checkpoint_path + f"/MaxVitModel_ep{best_acc}.pth")
                torch.save(optimizer.state_dict(), best_optimizer_path)

            if phase == 'dev':
                val_acc_history.append(epoch_acc)
                # scheduler.step(epoch_loss)

        wblogger.log(wblogdict)
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # '{:4f}' set a field width of 4, not a precision; '.4f' matches the
    # per-epoch accuracy prints.
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

In [10]:
# model_ft = models.resnet18(pretrained=True)
# num_ftrs = model_ft.fc.in_features
# # Here the size of each output sample is set to 2.
# # Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
# model_ft.fc = nn.Linear(num_ftrs, 2)

# Seed the RNGs before the model is created.
set_seed(CFG)

# Start from a checkpoint file expected directly under CELL_PAINTING_DIR.
# NOTE(review): the path is hard-coded, so the cell fails if the file is absent.
checkpoint = CELL_PAINTING_DIR + '/MaxVitModel_ep0.669.pth'
model_ft = MaxVitClassifier(CFG, checkpoint=checkpoint)
model_ft = model_ft.to(device)

# NOTE(review): `params_to_update` is not used in the visible code; the
# optimizer below is given model_ft.parameters() directly.
params_to_update = model_ft.parameters()
# print("Params to learn:")

# for name,param in model_ft.named_parameters():
#     if param.requires_grad == True:
#             print("\t",name)

# Observe that all parameters are being optimized
# optimizer_ft = optim.SGD(model_ft.parameters(), lr=5e-5, momentum=0.9)

# AdamW over all model parameters, with hyperparameters taken from CFG.
optimizer_ft = AdamW(model_ft.parameters(), lr=CFG.learning_rate, eps=CFG.adam_epsilon, weight_decay=CFG.weight_decay)
# scheduler = lr_scheduler.ReduceLROnPlateau(optimizer_ft, factor=CFG.plateau_factor, patience=CFG.plateau_patience)

criterion = nn.CrossEntropyLoss()

In [11]:
# Run the training loop; `hist` receives the per-epoch dev accuracies.
# The recorded output below shows this run being interrupted during epoch 0.
model_ft, hist = train_model(CFG, model_ft, dataloaders, criterion, optimizer_ft)

Epoch 0/29
----------


  2%|█▌                                                                                                        | 104/6862 [01:09<1:14:59,  1.50it/s]


KeyboardInterrupt: 

wandb: Waiting for W&B process to finish... (success).
wandb: - 0.003 MB of 0.003 MB uploaded (0.000 MB deduped)