# init

In [1]:
import os
import random
import shutil
import time
import warnings
from datetime import date
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
import gc
import wandb
import uuid
import tempfile

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim

import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.cuda.amp import GradScaler
# from torchmetrics.classification import MulticlassAccuracy, F1

import timm
from tqdm import tqdm, trange

from IPython.display import Image 

In [2]:
GPU=0
SEED=1

random.seed(SEED)
torch.manual_seed(SEED)
cudnn.deterministic = True

In [3]:
# Print num GPUs available
print(f"GPU(s) available: {torch.cuda.device_count()}") 
# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

GPU(s) available: 1
Device: cuda:0


## Config

In [4]:
class CFG:
  ARCH = 'maxvit_large_tf_224'
  START_EPOCH = 0
  EPOCHS = 20
  LR = 0.1
  MOMENTUM = 0.9
  WEIGHT_DECAY = 1e-4
  ADAM_EPSILON = 1e-6
  PRINT_FREQ = 200
  TRAIN_BATCH = 22
  VAL_BATCH = 22
  WORKERS = 4
  DATADIR = "/data/home/ec2-user/broad/training_images/BBBC037/"
  TRAINDIR = DATADIR+"train"
  VALDIR = DATADIR+"val"
  TESTDIR = DATADIR+"test"

  PRETRAINED = False
  IMAGE_SIZE = 224
  IN_CHANS = 5
  NUM_CLASSES = 46

  LEARNING_RATE = 0.01
  ADAM_EPSILON = 1e-6
  # WEIGHT_DECAY = 0.01 # for adamw
  L2_PENALTY = 0.01 # for RMSprop
  RMS_MOMENTEM = 0 # for RMSprop

  ### learning rate scheduler (LRS)
  # scheduler = 'ReduceLROnPlateau' # []
  # scheduler = 'CosineAnnealingLR'
  PLATEAU_FACTOR = 0.5
  PLATEAU_PATIENCE = 3

  RANDOM_SEED = 42

  OUTPUT_DIR = '/home/ubuntu' + '/saved_models/' + str(date.today())
  CHECKPOINT_LAST = OUTPUT_DIR + '/' + ARCH + '/checkpoint-last'
  CHECKPOINT_BEST = OUTPUT_DIR + '/' + ARCH + '/checkpoint-best'

  WANDB_NOTEBOOK_NAME = str(date.today()) + '_' + ARCH + '_cjdonahoe'


## W&B

In [5]:
os.environ['WANDB_API_KEY']='0f110ed76bb0d63c6552597f89bfc99bf2469e43'

class WandBLogger(object):
    def __init__(self, variant, project, prefix=''):
      """
      Args:
        variant: dictionary of hyperparameters
        project: name of project
      """
      log_dir = tempfile.mkdtemp()
      if prefix != '':
          project = '{}--{}'.format(prefix, project)

      wandb.init(
          config=variant,
          project=project,
          dir=log_dir,
          id=uuid.uuid4().hex,
      )

    def log(self, *args, **kwargs):
      wandb.log(*args, **kwargs)

wblogger = WandBLogger(
    variant={
      'initial_learning_rate': CFG.LR,
      'adam_epsilon': CFG.ADAM_EPSILON,
      'num_epochs': CFG.EPOCHS,
      'batch_size': CFG.TRAIN_BATCH,
      'weight_decay': CFG.WEIGHT_DECAY,
      'architecture': CFG.ARCH,
      'pretrained': CFG.PRETRAINED,
    },
    project=f'cellvit',
    prefix='cjdonahoe'
)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcjdonahoe[0m ([33mcellvit[0m). Use [1m`wandb login --relogin`[0m to force relogin


# Functions

In [6]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

In [7]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

In [8]:
class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

In [9]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [10]:
def get_class_weights(dataset):
    ''' Get class weights for a dataset
    Args:
        dataset: torch.utils.data.Dataset
    Returns:
        class_weights: torch.FloatTensor
    '''
    
    class_counts = Counter(dataset.targets)
    n_classes = len(class_counts.keys())
    total_count = len(dataset.targets)
    class_weights = list({class_id: total_count/(n_classes * class_counts) for class_id, class_counts in class_counts.items()}.values())
    class_weights = torch.FloatTensor(class_weights).cuda()
    return class_weights

In [11]:
class SplitTensorToFiveChannels(object):
    """Convert images in Pytorch Dataset to Tensors with one channel
    for each discrete fluerecent image in a Cell Painting sample."""

    def __call__(self, img):
        # select the first channel since the image is grayscale
        img = img[0,:,:]
        # split the image into the 6 channels and remove the last channel
        img = torch.tensor_split(img,6,dim=1)[:-1]
        # concatenate the 5 channels into a single tensor
        img = torch.stack(img, dim=0)
        return img


## MaxVitClassifier

In [12]:
class MaxVitClassifier(nn.Module):
    def __init__(self, checkpoint=None):
        super().__init__()
        self.model_name = CFG.ARCH
        self.model = timm.create_model(
            CFG.ARCH,
            in_chans=CFG.IN_CHANS,
            pretrained=CFG.PRETRAINED, 
            num_classes=CFG.NUM_CLASSES)
        # n_features = self.model.head.in_features
        # self.model.head = nn.Linear(n_features, num_classes)
        # self.model.fc = nn.Linear(n_features, num_classes)
        # if checkpoint:
        #   self.model.load_state_dict(torch.load(checkpoint), strict=False)

    def forward(self, x):
        x = self.model(x)
        return x
    
    def freeze(self):
        # To freeze the residual layers
        for param in self.model.parameters():
            param.requires_grad = False

        for param in self.model.head.parameters():
            param.requires_grad = True
    
    def unfreeze(self):
        # Unfreeze all layers
        for param in self.model.parameters():
            param.requires_grad = True

# Train & Validation Functions

In [26]:
device = torch.device("cuda:0" if torch.cuda.is_available() else CFG.device)

In [13]:
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # Grad Scaler
    scaler = GradScaler()
    # switch to train mode
    model.train()

    running_loss = 0.0
    running_corrects = 0

    end = time.time()
    for images, target in enumerate(train_loader):
        optimizer.zero_grad()
        # measure data loading time
        data_time.update(time.time() - end)

        # move data to GPU
        images = images.to(device)
        target = target.to(device)

        # compute output
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            output = model(images)
            loss = criterion(output, target)

            _, preds = torch.max(output, 1)
        
        running_loss += loss.item() * images.size(0)
        running_corrects += torch.sum(preds == target.data)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))
        
        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # if i % CFG.PRINT_FREQ == 0:
        #     progress.display(i)

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = running_corrects.double() / len(train_loader.dataset)
    
    # print('{} Loss: {:.4f} Acc: {:.4f}'.format("train", epoch_loss, epoch_acc))
    wandb.log({"train/loss": losses.avg, 'train/acc1': top1.avg, 'train/acc5': top5.avg})

    wblogdict[f'{"train"}/loss'] = np.round(epoch_loss, 4)
    wblogdict[f'{"train"}/acc'] = np.round(epoch_acc.cpu(), 3)
    wblogdict['train/learning_rate'] = CFG.LEARNING_RATE

In [14]:
def validate(val_loader, model, criterion):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Validation: ')

    # switch to evaluate mode
    model.eval()

    running_loss = 0.0
    running_corrects = 0

    with torch.no_grad():
        end = time.time()
        for images, target in enumerate(val_loader):
            # move data to GPU
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            _, preds = torch.max(output, 1)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # if i % CFG.PRINT_FREQ == 0:
            #     progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
        
        running_loss += loss.item() * images.size(0)
        running_corrects += torch.sum(preds == target.data)

    epoch_loss = running_loss / len(val_loader.dataset)
    epoch_acc = running_corrects.double() / len(val_loader.dataset)
    
    print('{} Loss: {:.4f} Acc: {:.4f}'.format("Validation", epoch_loss, epoch_acc))

    wblogdict[f'{"val"}/loss'] = np.round(epoch_loss, 4)
    wblogdict[f'{"val"}/acc'] = np.round(epoch_acc.cpu(), 3)

    scheduler.step(epoch_loss)

    wandb.log({"val/loss": losses.avg, 'val/acc1': top1.avg, 'val/acc5': top5.avg})
    # wandb.log({'lr': scheduler.get_last_lr()[0]})
    return top1.avg, top5.avg

# Transformations, Datasets, & Dataloaders

In [15]:
transform_train = transforms.Compose([
    transforms.RandomVerticalFlip(),
    transforms.Resize((CFG.IMAGE_SIZE, CFG.IMAGE_SIZE*6)),
    transforms.ToTensor(),
    SplitTensorToFiveChannels(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225, 0.225, 0.225]),
])

transform_val = transforms.Compose([
    transforms.Resize((CFG.IMAGE_SIZE, CFG.IMAGE_SIZE*6)),
    transforms.ToTensor(),
    SplitTensorToFiveChannels(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225, 0.225, 0.225]),
])


In [16]:
train_dataset = datasets.ImageFolder(
    CFG.TRAINDIR, transform=transform_train)

val_dataset = datasets.ImageFolder(
    CFG.VALDIR, transform=transform_val)

In [17]:
# sample_weight = get_class_weights(train_dataset)
# # sampler = data.WeightedRandomSampler(sample_weight, len(train_dataset))
# sampler = data.WeightedRandomSampler(sample_weight, 1000*CFG.BATCH_SIZE)

In [18]:
train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=CFG.TRAIN_BATCH, shuffle=True,
        num_workers=CFG.WORKERS, pin_memory=True, sampler=None)

In [19]:
val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=CFG.VAL_BATCH, shuffle=True,
        num_workers=CFG.WORKERS, pin_memory=True, sampler=None)

# Training Loop

In [20]:
model = torch.compile(MaxVitClassifier(CFG))
model = model.cuda(GPU)


In [21]:
# define loss function (criterion) and optimizer
# get the class weights from the validation set weight=get_class_weights(train_dataset)
criterion = nn.CrossEntropyLoss(weight=get_class_weights(train_dataset)).cuda(GPU)

# optimizer = torch.optim.Adam(
#   model.parameters(), 
#   lr=CFG.LR, 
#   # momentum=MOMENTUM, 
#   weight_decay=CFG.WEIGHT_DECAY
#   )

# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.EPOCHS)

# optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.LEARNING_RATE)

optimizer = torch.optim.Adam(model.parameters(), lr=CFG.LEARNING_RATE)
# optimizer = torch.optim.SGD(
#     model.parameters(),
#     lr=CFG.LEARNING_RATE,
#     momentum=CFG.MOMENTUM,
#     weight_decay=CFG.WEIGHT_DECAY
#     )
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max = CFG.EPOCHS, eta_min = 1e-4)
    # T_0 = 8,# Number of iterations for the first restart
    # T_mult = 1, # A factor increases TiTi​ after a restart
    # eta_min = 1e-4) # Minimum learning rate

torch.set_float32_matmul_precision('high')

In [22]:
best_acc1 = 0

In [25]:
%%time
for epoch in tqdm(range(CFG.START_EPOCH, CFG.EPOCHS)):
    print('Epoch {}/{}'.format(epoch, CFG.EPOCHS - 1))
    print('-' * 20)
    
    wblogdict = {}
    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch)

    # evaluate on validation set
    acc1, acc5 = validate(val_loader, model, criterion)

    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)

    save_checkpoint({
        'epoch': epoch + 1,
        'arch': CFG.ARCH,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'acc5': acc5,
        'optimizer' : optimizer.state_dict(),
    }, is_best)

    wblogger.log(wblogdict, step=epoch)

    scheduler.step()
    print('lr: ' + str(scheduler.get_last_lr()[0]))
    wandb.log({'lr': scheduler.get_last_lr()[0]})
    
    # scheduler.step(epoch_loss)
    # print('lr: ' + str(scheduler.get_last_lr()[0]))

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 0/19
--------------------


  0%|          | 0/20 [00:01<?, ?it/s]


AttributeError: 'int' object has no attribute 'cuda'