# init

In [1]:
import os
import random
import shutil
import time
import warnings
from datetime import date
from collections import Counter
import numpy as np
import gc
import wandb
import uuid
import tempfile

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim

import torch.utils.data
import torch.distributed as dist
# import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# import torchvision.models as models
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast

import timm
from tqdm import tqdm, trange

In [None]:
def setup(rank, world_size):
  os.environ['MASTER_ADDR'] = 'localhost'
  os.environ['MASTER_PORT'] = '12355'
  dist.init_process_group("nccl", rank=rank, world_size=world_size)

In [2]:
GPU=0
SEED=1

random.seed(SEED)
torch.manual_seed(SEED)
cudnn.deterministic = True

In [3]:
# Print num GPUs available
print(f"GPU(s) available: {torch.cuda.device_count()}") 
# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

GPU(s) available: 1
Device: cuda:0


## Config

In [4]:
class CFG:
  ARCH = 'maxvit_small_224'
  START_EPOCH = 0
  EPOCHS = 20
  LR = 0.1
  MOMENTUM = 0.9
  WEIGHT_DECAY = 1e-4
  ADAM_EPSILON = 1e-7
  PRINT_FREQ = 50
  TRAIN_BATCH = 64
  VAL_BATCH = 64
  WORKERS = 2
  DATADIR = "/data/home/ec2-user/broad/training_images/BBBC037/"
  TRAINDIR = DATADIR+"train"
  VALDIR = DATADIR+"val"
  TESTDIR = DATADIR+"test"

  PRETRAINED = False
  IMAGE_SIZE = 224
  IN_CHANS = 5
  NUM_CLASSES = 47

  RANDOM_SEED = 42

  OUTPUT_DIR = '/home/ubuntu' + '/saved_models/' + str(date.today())
  CHECKPOINT_LAST = OUTPUT_DIR + '/' + ARCH + '/checkpoint-last'
  CHECKPOINT_BEST = OUTPUT_DIR + '/' + ARCH + '/checkpoint-best'

  WANDB_NOTEBOOK_NAME = str(date.today()) + '_' + ARCH + '_cjdonahoe'


## W&B

In [8]:
os.environ['WANDB_API_KEY']='e2b77d7240d4c1ceee8264dbfbea27d2f30d5331'

class WandBLogger(object):
    def __init__(self, variant, project, prefix=''):
      """
      Args:
        variant: dictionary of hyperparameters
        project: name of project
      """
      log_dir = tempfile.mkdtemp()
      if prefix != '':
          project = '{}--{}'.format(prefix, project)

      wandb.init(
          config=variant,
          project=project,
          dir=log_dir,
          id=uuid.uuid4().hex,
      )

    def log(self, *args, **kwargs):
      wandb.log(*args, **kwargs)

wblogger = WandBLogger(
    variant={
      'initial_learning_rate': CFG.LR,
      'adam_epsilon': CFG.ADAM_EPSILON,
      'num_epochs': CFG.EPOCHS,
      'batch_size': CFG.TRAIN_BATCH,
      'weight_decay': CFG.WEIGHT_DECAY,
      'architecture': CFG.ARCH,
    },
    project=f'cellvit',
    prefix='cjdonahoe'
)


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668506549982944, max=1.0â€¦

# Functions

In [9]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

In [10]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

In [11]:
class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

In [12]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [13]:
def get_class_weights(dataset):
    ''' Get class weights for a dataset
    Args:
        dataset: torch.utils.data.Dataset
    Returns:
        class_weights: torch.FloatTensor
    '''
    
    class_counts = Counter(dataset.targets)
    total_count = len(dataset.targets)
    class_weights = list({class_id: class_counts/total_count for class_id, class_counts in class_counts.items()}.values())
    class_weights = torch.FloatTensor(class_weights).cuda()
    return class_weights

In [14]:
class SplitTensorToFiveChannels(object):
    """Convert images in Pytorch Dataset to Tensors with one channel
    for each discrete fluerecent image in a Cell Painting sample."""

    def __call__(self, img):
        # select the first channel since the image is grayscale
        img = img[0,:,:]
        # split the image into the 6 channels and remove the last channel
        img = torch.tensor_split(img,6,dim=1)[:-1]
        # concatenate the 5 channels into a single tensor
        img = torch.stack(img, dim=0)
        return img


## MaxVitClassifier

In [15]:
class MaxVitClassifier(nn.Module):
    def __init__(self, checkpoint=None):
        super().__init__()
        self.model_name = CFG.ARCH
        self.model = timm.create_model(
            CFG.ARCH,
            in_chans=CFG.IN_CHANS,
            pretrained=CFG.PRETRAINED, 
            num_classes=CFG.NUM_CLASSES)
        # n_features = self.model.head.in_features
        # self.model.head = nn.Linear(n_features, num_classes)
        # self.model.fc = nn.Linear(n_features, num_classes)
        if checkpoint:
          self.model.load_state_dict(torch.load(checkpoint), strict=False)

    def forward(self, x):
        x = self.model(x)
        return x
    
    def freeze(self):
        # To freeze the residual layers
        for param in self.model.parameters():
            param.requires_grad = False

        for param in self.model.head.parameters():
            param.requires_grad = True
    
    def unfreeze(self):
        # Unfreeze all layers
        for param in self.model.parameters():
            param.requires_grad = True

# Train & Validation Functions

In [16]:
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # Grad Scaler
    scaler = GradScaler()
    # switch to train mode
    model.train()

    running_loss = 0.0
    running_corrects = 0

    end = time.time()
    for i, (images, target) in tqdm(enumerate(train_loader)):
        # measure data loading time
        data_time.update(time.time() - end)

        # move data to GPU
        if GPU is not None:
            images = images.cuda(GPU, non_blocking=True)
        if torch.cuda.is_available():
            target = target.cuda(GPU, non_blocking=True)
        
        optimizer.zero_grad()

        # compute output
        with torch.cuda.amp.autocast(dtype=torch.float16):
            output = model(images)
            loss = criterion(output, target)

            _, preds = torch.max(output, 1)
        
        running_loss += loss.item() * images.size(0)
        running_corrects += torch.sum(preds == target.data)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))
        
        # compute gradient and do SGD step
        # use the scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # wandb.log({"Loss/train": loss, 'acc1/train': top1.avg, 'acc5/train': top5.avg})

        if i % CFG.PRINT_FREQ == 0:
            progress.display(i)

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = running_corrects.double() / len(train_loader.dataset)
    
    print('{} Loss: {:.4f} Acc: {:.4f}'.format("Training:", epoch_loss, epoch_acc))

    wblogdict[f'{"train"}/loss'] = np.round(epoch_loss, 4)
    wblogdict[f'{"train"}/acc'] = np.round(epoch_acc.cpu(), 3)

    wblogdict['train/learning_rate'] = CFG.learning_rate

    torch.cuda.empty_cache()
    gc.collect()

In [17]:
def validate(val_loader, model, criterion):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Validation: ')

    # switch to evaluate mode
    model.eval()

    running_loss = 0.0
    running_corrects = 0

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in tqdm(enumerate(val_loader)):
            if GPU is not None:
                images = images.cuda(GPU, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(GPU, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            _, preds = torch.max(output, 1)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % CFG.PRINT_FREQ == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
        
        running_loss += loss.item() * images.size(0)
        running_corrects += torch.sum(preds == target.data)

    epoch_loss = running_loss / len(val_loader.dataset)
    epoch_acc = running_corrects.double() / len(val_loader.dataset)
    
    print('{} Loss: {:.4f} Acc: {:.4f}'.format("Validation", epoch_loss, epoch_acc))

    wblogdict[f'{"val"}/loss'] = np.round(epoch_loss, 4)
    wblogdict[f'{"val"}/acc'] = np.round(epoch_acc.cpu(), 3)

    # wandb.log({"Loss/val": losses.avg, 'acc1/val': top1.avg, 'acc5/val': top5.avg})
    return top1.avg, top5.avg

# Transformations, Datasets, & Dataloaders

In [18]:
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(CFG.IMAGE_SIZE),
    transforms.ToTensor(),
    SplitTensorToFiveChannels(),
])

transform_val = transforms.Compose([
    transforms.Resize(CFG.IMAGE_SIZE),
    transforms.ToTensor(),
    SplitTensorToFiveChannels(),
])


In [19]:
train_dataset = datasets.ImageFolder(
    CFG.TRAINDIR, transform=transform_train)

val_dataset = datasets.ImageFolder(
    CFG.VALDIR, transform=transform_val)

In [20]:
train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=CFG.TRAIN_BATCH, shuffle=True,
        num_workers=CFG.WORKERS, pin_memory=True, sampler=None)

In [21]:
val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=CFG.VAL_BATCH, shuffle=False,
        num_workers=CFG.WORKERS, pin_memory=True, sampler=None) 

# Training Loop

In [22]:
model = MaxVitClassifier()
model = model.cuda(GPU)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [23]:
# define loss function (criterion) and optimizer
# get the class weights from the validation set
criterion = nn.CrossEntropyLoss(weight=get_class_weights(val_dataset)).cuda(GPU)

optimizer = torch.optim.Adam(
  model.parameters(), 
  lr=CFG.LR, 
  # momentum=MOMENTUM, 
  weight_decay=CFG.WEIGHT_DECAY
  )

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.EPOCHS)

In [24]:
best_acc1 = 0

In [25]:
for epoch in range(CFG.START_EPOCH, CFG.EPOCHS):
    print('Epoch {}/{}'.format(epoch, CFG.EPOCHS - 1))
    print('-' * 10)
    
    wblogdict = {}
    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch)

    # evaluate on validation set
    acc1, acc5 = validate(val_loader, model, criterion)

    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)

    save_checkpoint({
        'epoch': epoch + 1,
        'arch': CFG.ARCH,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'acc5': acc5,
        'optimizer' : optimizer.state_dict(),
    }, is_best)

    wblogger.log(wblogdict, step=epoch)
    
    scheduler.step()
    print('lr: ' + str(scheduler.get_last_lr()[0]))
    # wandb.log({'lr': scheduler.get_last_lr()[0]})

Epoch 0/19
----------


1it [00:03,  3.10s/it]

Epoch: [0][   0/2809]	Time  3.187 ( 3.187)	Data  0.654 ( 0.654)	Loss 4.8018e+00 (4.8018e+00)	Acc@1   1.56 (  1.56)	Acc@5   6.25 (  6.25)


51it [00:56,  1.06s/it]

Epoch: [0][  50/2809]	Time  1.058 ( 1.100)	Data  0.001 ( 0.014)	Loss nan (nan)	Acc@1   0.00 (  2.14)	Acc@5   1.56 (  6.59)


101it [01:48,  1.06s/it]

Epoch: [0][ 100/2809]	Time  1.058 ( 1.080)	Data  0.002 ( 0.008)	Loss nan (nan)	Acc@1   0.00 (  1.58)	Acc@5   3.12 (  5.86)


151it [02:41,  1.06s/it]

Epoch: [0][ 150/2809]	Time  1.055 ( 1.072)	Data  0.001 ( 0.006)	Loss nan (nan)	Acc@1   1.56 (  1.40)	Acc@5   3.12 (  5.84)


201it [03:34,  1.06s/it]

Epoch: [0][ 200/2809]	Time  1.060 ( 1.069)	Data  0.001 ( 0.005)	Loss nan (nan)	Acc@1   0.00 (  1.27)	Acc@5   7.81 (  5.60)


251it [04:27,  1.06s/it]

Epoch: [0][ 250/2809]	Time  1.060 ( 1.067)	Data  0.001 ( 0.004)	Loss nan (nan)	Acc@1   1.56 (  1.26)	Acc@5   6.25 (  5.61)


301it [05:20,  1.06s/it]

Epoch: [0][ 300/2809]	Time  1.060 ( 1.066)	Data  0.001 ( 0.003)	Loss nan (nan)	Acc@1   0.00 (  1.23)	Acc@5   9.38 (  5.72)


351it [06:13,  1.06s/it]

Epoch: [0][ 350/2809]	Time  1.059 ( 1.065)	Data  0.001 ( 0.003)	Loss nan (nan)	Acc@1   3.12 (  1.20)	Acc@5   6.25 (  5.65)


401it [07:06,  1.06s/it]

Epoch: [0][ 400/2809]	Time  1.060 ( 1.064)	Data  0.001 ( 0.003)	Loss nan (nan)	Acc@1   0.00 (  1.15)	Acc@5   6.25 (  5.65)


451it [07:59,  1.06s/it]

Epoch: [0][ 450/2809]	Time  1.059 ( 1.064)	Data  0.001 ( 0.003)	Loss nan (nan)	Acc@1   0.00 (  1.08)	Acc@5   6.25 (  5.55)


501it [08:52,  1.06s/it]

Epoch: [0][ 500/2809]	Time  1.059 ( 1.063)	Data  0.002 ( 0.003)	Loss nan (nan)	Acc@1   0.00 (  1.06)	Acc@5   7.81 (  5.55)


551it [09:45,  1.06s/it]

Epoch: [0][ 550/2809]	Time  1.060 ( 1.063)	Data  0.001 ( 0.003)	Loss nan (nan)	Acc@1   1.56 (  1.08)	Acc@5   4.69 (  5.56)


601it [10:38,  1.06s/it]

Epoch: [0][ 600/2809]	Time  1.059 ( 1.063)	Data  0.001 ( 0.002)	Loss nan (nan)	Acc@1   0.00 (  1.08)	Acc@5   3.12 (  5.52)


651it [11:31,  1.06s/it]

Epoch: [0][ 650/2809]	Time  1.059 ( 1.062)	Data  0.001 ( 0.002)	Loss nan (nan)	Acc@1   0.00 (  1.07)	Acc@5   6.25 (  5.53)


701it [12:24,  1.06s/it]

Epoch: [0][ 700/2809]	Time  1.058 ( 1.062)	Data  0.001 ( 0.002)	Loss nan (nan)	Acc@1   0.00 (  1.06)	Acc@5   6.25 (  5.50)


751it [13:17,  1.06s/it]

Epoch: [0][ 750/2809]	Time  1.060 ( 1.062)	Data  0.001 ( 0.002)	Loss nan (nan)	Acc@1   0.00 (  1.04)	Acc@5   9.38 (  5.46)


801it [14:10,  1.06s/it]

Epoch: [0][ 800/2809]	Time  1.060 ( 1.062)	Data  0.001 ( 0.002)	Loss nan (nan)	Acc@1   1.56 (  1.04)	Acc@5   7.81 (  5.48)


851it [15:03,  1.06s/it]

Epoch: [0][ 850/2809]	Time  1.058 ( 1.062)	Data  0.001 ( 0.002)	Loss nan (nan)	Acc@1   1.56 (  1.05)	Acc@5   1.56 (  5.45)


901it [15:56,  1.06s/it]

Epoch: [0][ 900/2809]	Time  1.058 ( 1.062)	Data  0.001 ( 0.002)	Loss nan (nan)	Acc@1   1.56 (  1.06)	Acc@5   6.25 (  5.41)


929it [16:26,  1.06s/it]