In [None]:
%env CUDA_VISIBLE_DEVICES=0,1,2,3
%matplotlib notebook 

# Import resources
import argparse
import os
import random
import shutil
import time
import warnings

import cv2
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import pretrainedmodels
from efficientnet_pytorch import EfficientNet
import segmentation_models_pytorch as smp
import albumentations as A


In [None]:
# helper function for data visualization
def visualize(**images):
    """Plot the given images side by side in a single row.

    Every keyword argument becomes one panel: the keyword (underscores turned
    into spaces, title-cased) is used as the panel title and the value is drawn
    with ``plt.imshow``. Axis ticks are hidden on every panel.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, label in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(images[label])
    plt.show()

# dot dict for args
class dotdict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    Reading an attribute that is not a key returns ``None`` instead of raising;
    the training code relies on this for optional settings that are never put
    in the dict (for example ``args.seed``).

    Double-underscore names are the exception: they raise ``AttributeError``
    like a normal object. Otherwise protocol probes such as
    ``getattr(obj, '__deepcopy__', None)`` would receive ``None`` for a hook
    that does not exist, and ``hasattr`` would report every dunder as present.
    """

    def __getattr__(self, name):
        # __getattr__ only runs after regular attribute lookup has failed, so
        # real dict methods and real dunder methods never reach this point.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return self.get(name)

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

## Data Loaders and Augmentation

In [None]:
#https://machinelearningmastery.com/best-practices-for-preparing-and-augmenting-image-data-for-convolutional-neural-networks/
def get_training_augmentation(input_size=224):
    """Build the albumentations pipeline applied to every training image.

    The stages run in this order:
      1. a random square crop resized to ``input_size`` with one of four
         interpolation methods,
      2. a random 90-degree rotation plus random horizontal/vertical flips,
      3. photometric distortions,
      4. additive noise followed by an optional sharpen/blur.

    Args:
        input_size (int): side length, in pixels, of the square output image.
    Return:
        transform: albumentations.Compose
    """
    # NOTE(review): RandomBrightness, RandomContrast and the IAA* transforms
    # appear to be deprecated in newer albumentations releases - confirm
    # against the installed version before upgrading.

    # 1. Get a random square crop of the image
    # and rescale (preserve aspect ratio) to input size with random interpolation method
    # REF: GoogLeNet (Inception) - Going Deeper with Convolutions, 2014
    interpolation_modes = (cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA)
    random_square_crop = A.OneOf(
        [
            A.RandomResizedCrop(height=input_size, width=input_size, ratio=(1.0, 1.0), interpolation=mode)
            for mode in interpolation_modes
        ],
        p=1,
    )

    # 2. random flipping / rotating
    flips_and_rotations = [
        A.RandomRotate90(p=1),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ]

    # 3. photometric distortions
    # REF: GoogLeNet (Inception) - Going Deeper with Convolutions, 2014
    tone_shift = A.OneOf([A.CLAHE(p=1), A.RandomBrightness(p=1), A.RandomGamma(p=1)], p=0.5)
    colour_shift = A.OneOf([A.RandomContrast(p=1), A.HueSaturationValue(p=1)], p=0.5)

    # 4. Add noise
    additive_noise = A.IAAAdditiveGaussianNoise(p=0.33)
    sharpen_or_blur = A.OneOf(
        [A.IAASharpen(p=1), A.Blur(blur_limit=3, p=1), A.MotionBlur(blur_limit=3, p=1)],
        p=0.33,
    )

    return A.Compose(
        [random_square_crop]
        + flips_and_rotations
        + [tone_shift, colour_shift, additive_noise, sharpen_or_blur]
    )


def get_validation_augmentation(input_size=224):
    """Build the deterministic pipeline applied to validation images.

    The image is rescaled so that its shorter side is 256 px, centre-cropped to
    256x256 and finally resized to ``input_size`` x ``input_size``.

    Args:
        input_size (int): side length, in pixels, of the square output image.
    Return:
        transform: albumentations.Compose
    """
    return A.Compose([
        A.SmallestMaxSize(256),
        A.CenterCrop(256, 256),
        A.Resize(input_size, input_size),
    ])


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and return a float32 tensor.

    For images read with OpenCV this turns height x width x channels into
    channels x height x width. Extra keyword arguments are accepted and ignored
    so the function can be used as an ``A.Lambda`` callback.
    """
    channels_first = x.transpose(2, 0, 1)
    as_float32 = channels_first.astype('float32')
    return torch.tensor(as_float32)


def get_preprocessing(preprocessing_fn):
    """Construct the preprocessing transform.

    The returned pipeline first applies ``preprocessing_fn`` to the image and
    then converts the result to a channels-first float32 tensor via
    ``to_tensor``.

    Args:
        preprocessing_fn (callable): data normalization function
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    """
    normalize = A.Lambda(image=preprocessing_fn)
    channels_first = A.Lambda(image=to_tensor)
    return A.Compose([normalize, channels_first])

In [None]:
class Dataset(datasets.folder.ImageFolder):
    """Image-folder dataset that reads with OpenCV and augments with albumentations.

    The images are expected to be arranged in this way: ::
        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png
    Args:
        root (string): Root directory path.
        transform (callable, optional): albumentations-style transform called as
            ``transform(image=array)['image']`` on the RGB image array.
        target_transform (callable, optional): function applied to the class index.
        is_valid_file (callable, optional): predicate on a file path; when given
            it replaces the default image-extension filter.
        preprocessing (callable, optional): albumentations-style transform applied
            after ``transform`` (normalization and tensor conversion).
     Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(self, root, transform=None, target_transform=None, is_valid_file=None, preprocessing=None):
        # DatasetFolder rejects being given both an extension list and an
        # is_valid_file predicate, so only pass the extensions when no
        # predicate was supplied.
        extensions = datasets.folder.IMG_EXTENSIONS if is_valid_file is None else None
        # Deliberately skips ImageFolder.__init__ and calls its parent directly.
        super(datasets.folder.ImageFolder, self).__init__(root, datasets.folder.default_loader,
                                                          extensions,
                                                          transform=transform,
                                                          target_transform=target_transform,
                                                          is_valid_file=is_valid_file)
        self.imgs = self.samples
        self.preprocessing = preprocessing
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        Raises:
            IOError: if OpenCV cannot read the image file.
        """
        path, target = self.samples[index]
        sample = cv2.imread(path)
        if sample is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError("could not read image file: {}".format(path))
        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)

        # apply augmentations
        if self.transform is not None:
            sample = self.transform(image=sample)['image']

        # apply preprocessing
        if self.preprocessing is not None:
            sample = self.preprocessing(image=sample)['image']

        # honour the target_transform accepted by the constructor
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target
        

In [None]:
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Train ``model`` for one epoch and return the sample-weighted mean loss.

    Args:
        train_loader: iterable yielding ``(images, target)`` batches.
        model: network to optimise; switched to train mode here.
        criterion: loss called as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.
        epoch (int): epoch number, used only in the progress prefix.
        args: settings object; reads ``args.gpu`` and ``args.print_freq``.
    Returns:
        Average loss over the epoch (``losses.avg``).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)

        # as_tensor leaves an existing tensor untouched (no copy, no
        # copy-construct warning) and still converts list/array targets.
        target = torch.as_tensor(target)
        target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

    return losses.avg


def validate(val_loader, model, criterion, args):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: network to evaluate; switched to eval mode here.
        criterion: loss called as ``criterion(output, target)``.
        args: settings object; reads ``args.gpu`` and ``args.print_freq``.
    Returns:
        tuple: ``(top1.avg, top5.avg, losses.avg)`` - mean top-1 accuracy,
        mean top-5 accuracy and mean loss over the loader.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)

            # as_tensor leaves an existing tensor untouched (no copy, no
            # copy-construct warning) and still converts list/array targets.
            target = torch.as_tensor(target)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg, top5.avg, losses.avg

In [None]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """Serialize ``state`` to ``filename`` and optionally keep a copy as the best model.

    Args:
        state: object passed to ``torch.save`` (the training loop passes a dict).
        is_best (bool): when true, the file just written is also copied to
            ``best_filename``.
        filename (str): path of the rolling checkpoint.
        best_filename (str): path of the best-model copy; defaults to the
            previously hard-coded ``'model_best.pth.tar'``.
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)


class AverageMeter(object):
    """Keeps the most recent value of a metric and its running weighted mean.

    ``name`` labels the metric when printed; ``fmt`` is a format-spec suffix
    (for example ``':6.2f'``) applied to both the current value and the mean.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the current value, mean, weighted sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Doubled braces survive the first format() call and become the
        # placeholders that are filled from the instance attributes.
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**vars(self))


class ProgressMeter(object):
    """Prints one tab-separated progress line made of a batch counter and meters.

    ``num_batches`` fixes the counter width and the total shown after the
    slash, ``meters`` are objects rendered with ``str()``, and ``prefix`` is
    prepended to the counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the line for batch index ``batch``."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header] + [str(meter) for meter in self.meters]))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as ``'[{:3d}/100]'`` sized to ``num_batches``."""
        width = len(str(num_batches // 1))
        counter = '{:' + str(width) + 'd}'
        return '[{}/{}]'.format(counter, counter.format(num_batches))


def adjust_learning_rate(optimizer, epoch, args):
    """Apply a step schedule: ``args.lr`` divided by 10 for every 30 completed epochs.

    The resulting rate is written to the ``'lr'`` entry of every parameter
    group of ``optimizer``.
    """
    completed_steps = epoch // 30
    lr = args.lr * (0.1 ** completed_steps)
    for group in optimizer.param_groups:
        group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output (Tensor): scores indexed as ``[sample, class]``.
        target (Tensor): true class index of every sample.
        topk (tuple of int): the k values to evaluate.
    Returns:
        list of Tensor: one single-element tensor per entry of ``topk`` holding
        the percentage of samples whose target is among the k highest scores.
        A k larger than the number of classes is treated as "all classes".
    """
    with torch.no_grad():
        batch_size = target.size(0)
        # topk() raises if asked for more entries than there are classes.
        maxk = min(max(topk), output.size(1))

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` comes from a transposed tensor and may be
            # non-contiguous, which view(-1) rejects; reshape copies if needed.
            correct_k = correct[:min(k, maxk)].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [None]:
def main(args):
    """Prepare the run settings and start training.

    Seeds the random number generators when ``args.seed`` is set, derives
    ``args.distributed`` and the world size, and then either spawns one
    ``main_worker`` process per visible GPU (``args.multiprocessing_distributed``)
    or calls ``main_worker`` directly in this process.

    Args:
        args: settings object; it is modified in place (``world_size`` and
            ``distributed`` are written here).
    """

    if args.seed is not None:
        random.seed(args.seed)
        # Seed NumPy's global RNG too, so NumPy-based randomness is repeatable.
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)

In [None]:
def _build_model(args):
    """Instantiate ``args.arch`` from torchvision, pretrainedmodels or efficientnet_pytorch.

    Raises:
        ValueError: if none of the three libraries provides ``args.arch``.
    """
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        if args.arch in models.__dict__:
            return models.__dict__[args.arch](pretrained=True)
        if args.arch in pretrainedmodels.model_names:
            try:
                return pretrainedmodels.__dict__[args.arch](pretrained='imagenet')
            except KeyError:
                # no 'imagenet' weights registered for this model; try the
                # 'imagenet+5k' set instead
                return pretrainedmodels.__dict__[args.arch](pretrained='imagenet+5k')
        if args.arch.startswith('efficientnet'):
            return EfficientNet.from_pretrained(args.arch)
    else:
        print("=> creating model '{}'".format(args.arch))
        if args.arch in models.__dict__:
            return models.__dict__[args.arch]()
        if args.arch in pretrainedmodels.model_names:
            # pretrainedmodels constructors default to loading ImageNet
            # weights; pretrained=None is required for a random initialisation.
            return pretrainedmodels.__dict__[args.arch](pretrained=None)
        if args.arch.startswith('efficientnet'):
            return EfficientNet.from_name(args.arch)

    # Fail here with a clear message instead of continuing without a model.
    raise ValueError(
        "cannot load architecture '{}': not provided by torchvision, "
        "pretrainedmodels or efficientnet_pytorch".format(args.arch))


def main_worker(gpu, ngpus_per_node, args):
    """Build the model and data loaders, then run the train/validate loop.

    Reads and updates the module-level ``best_acc1``; the caller must set it
    before calling. A checkpoint is written after every epoch, and training
    stops early when ``args.patience`` epochs pass without a new best top-1
    accuracy.

    Args:
        gpu: GPU index for this process, or ``None`` to use DataParallel over
            all visible GPUs.
        ngpus_per_node (int): number of GPUs on this node.
        args: settings object; ``gpu``, ``rank``, ``batch_size``, ``workers``
            and ``start_epoch`` may be rewritten here.
    Raises:
        ValueError: if ``args.arch`` cannot be instantiated.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    model = _build_model(args)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None and torch.is_tensor(best_acc1):
                # best_acc1 may be from a checkpoint from a different GPU
                # (a plain number needs no move)
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    input_size = 299 if args.arch.startswith("inception") else 224
    try:
        preprocessing_fn = smp.encoders.get_preprocessing_fn(args.arch)
    except ValueError:
        preprocessing_fn = smp.encoders.get_preprocessing_fn(args.arch, pretrained='imagenet+5k')

    train_dataset = Dataset(
        traindir,
        transform=get_training_augmentation(input_size),
        preprocessing=get_preprocessing(preprocessing_fn)
    )

    valid_dataset = Dataset(
        valdir,
        transform=get_validation_augmentation(input_size),
        preprocessing=get_preprocessing(preprocessing_fn)
    )

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    train_losses = []
    val_losses = []
    patience = 0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer, epoch, args)
        train_losses.append(train_loss)

        # evaluate on validation set
        acc1, acc5, val_loss = validate(val_loader, model, criterion, args)
        val_losses.append(val_loss)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'acc5': acc5,
                'optimizer' : optimizer.state_dict(),
                'val_loss': val_losses,
                'train_loss': train_losses
            }, is_best)

        # Early stopping
        if args.patience is not None:
            patience = patience + 1 if not is_best else 0
            if patience >= args.patience:
                print('\n\nEarly stopping.  No improvement in %d epochs.\n\n' % patience)
                break

## Train All

In [None]:
# Architecture buckets used by the training loops to choose a batch size.
# The loops test the buckets in the order models512, models256, models128,
# models_no_run, so the first list containing a name wins; any architecture
# listed nowhere is trained with the default batch size.
models512 = ['resnet101', 'dpn92', 'se_resnet101', 'se_resnext50_32x4d', 'inceptionv4',
             'efficientnet-b1', 'efficientnet-b2',
             'vgg13_bn', 'vgg16_bn', 'densenet121', 'xception'] # more than 26M params
models256 = ['resnet152', 'resnext101_32x8d', 'dpn98', 'dpn107', 'dpn131', 'se_resnet152', 'efficientnet-b3',
             'efficientnet-b4', 'se_resnext101_32x4d',
             'densenet169', 'densenet201', 'densenet161'] # more than 50M params
models128 = ['senet154', 'inceptionresnetv2', 'efficientnet-b6', 'efficientnet-b7', 'efficientnet-b5'] # more than 100M params
# Don't run, more than 200M params or no pretrained version
# NOTE(review): 'efficientnet-b6' and 'efficientnet-b7' are also in models128,
# which is tested first, so they are trained with batch size 128 rather than
# skipped - confirm which behaviour is intended.
models_no_run = ['resnext101_32x32d', 'resnext101_32x48d', 'efficientnet-b6',
                 'efficientnet-b7', 'resnext101_32x16d',
                 'vgg16', 'vgg19', 'vgg19_bn', 'vgg13', 'vgg11']

In [None]:
# Training params
# Train every selected encoder architecture with ImageNet-pretrained weights
# and keep a renamed copy of each run's best checkpoint.
data_dir = r'./data/microscopy/data_v4/'
architecture = 'se_resnet50'
pretrained = True
resume = ''
learning_rate = 0.1
momentum = 0.9
weight_decay = 1e-4
batch_size = 1024
epochs = 90
start_epoch = 0
num_workers = 16

# Settings that are not listed here (seed, dist_url, dist_backend) read as
# None because dotdict returns None for missing keys.
args = dotdict({'data': data_dir,
                'arch': architecture,
               'workers': num_workers,
               'epochs': epochs,
               'start_epoch': start_epoch,
               'batch_size': batch_size,
               'lr': learning_rate,
               'learning_rate': learning_rate,
               'momentum': momentum,
               'weight_decay': weight_decay,
               'print_freq': 10,
               'resume': resume,
               'pretrained': pretrained,
               'evaluate': False,
               'world_size': -1,
               'rank': -1,
               'gpu': None,
               'multiprocessing_distributed': False,
               'patience': None})

# NOTE(review): [45:] skips the first 45 encoder names - presumably those were
# trained in an earlier session; the offset depends on the installed
# segmentation_models_pytorch version, confirm before re-running.
for arch in smp.encoders.get_encoder_names()[45:]:
    t0 = time.time()
    # best_acc1 is the module-level value that main_worker reads and updates
    best_acc1 = 0 # reset the best accuracy
    args.arch = arch
    
    # set the batch size based on the number of model params
    if arch in models512:
        args.batch_size = 512
    elif arch in models256:
        args.batch_size = 256
    elif arch in models128:
        args.batch_size = 128
    elif arch in models_no_run: #skip this architecture
        print('\n\nSkipping', arch, '\n\n')
        continue
    else:
        args.batch_size = 1024
        
    print('\n\nTraining', arch, 'with batch size', args.batch_size, '\n\n')
    # train the model
    main(args)
    
    #save the model
    # model_best.pth.tar is only rewritten when an epoch improves on best_acc1;
    # the copy is named after the values stored inside the checkpoint itself.
    best_path = r'./model_best.pth.tar'
    bm = torch.load(best_path)
    dest_path = r'./{}_microscopy_epochs_{}_acc1_{:.3f}_acc5_{:.3f}.pth.tar'.format(
        bm['arch'], bm['epoch'], bm['best_acc1'].item(), bm['acc5'].item())
    shutil.copy(best_path, dest_path) 
    print(arch, 'took', round((time.time()-t0)/60, 2), 'minutes to train.')
    

## Continue Training with Early Stopping

In [None]:
# Training params
# Resume the saved runs from the working directory and keep training them, up
# to `epochs` total epochs, with early stopping.
data_dir = r'./data/microscopy/data_v4/'
architecture = 'se_resnet50'
pretrained = True
resume = ''
learning_rate = 0.1
momentum = 0.9
weight_decay = 1e-4
batch_size = 1024
epochs = 200
start_epoch = 0
num_workers = 16
patience = 30 # early stopping patience

args = dotdict({'data': data_dir,
                'arch': architecture,
               'workers': num_workers,
               'epochs': epochs,
               'start_epoch': start_epoch,
               'batch_size': batch_size,
               'lr': learning_rate,
               'learning_rate': learning_rate,
               'momentum': momentum,
               'weight_decay': weight_decay,
               'print_freq': 10,
               'resume': resume,
               'pretrained': pretrained,
               'evaluate': False,
               'world_size': -1,
               'rank': -1,
               'gpu': None,
               'multiprocessing_distributed': False,
               'patience': patience})


# Candidate checkpoints: every *.tar file in the current working directory.
trained_models = [f for f in os.listdir() if f.endswith(".tar")]
for f in trained_models:
    # skip the rolling/best checkpoints and the from-scratch runs
    if f.startswith('checkpoint') or f.startswith('model_best') or 'fromscratch' in f:
        print('skipping {}'.format(f))
        continue
    # The epoch count is parsed from the file name. This rebinds the
    # module-level `epochs`; args.epochs was already fixed above and is not
    # affected.
    epochs = int(f.split('epochs_')[1].split('_acc1_')[0])
    # only resume runs whose best epoch was between 81 and 90
    if epochs < 81 or epochs > 90:
        print('skipping {}'.format(f))
        continue
        
    args.resume = f
    
    t0 = time.time()
    # main_worker replaces this with the checkpoint's best_acc1 when resuming
    best_acc1 = 0 # reset the best accuracy
    # the architecture name is the file-name part before '_microscopy_'
    arch = f.split('_microscopy_')[0]
    args.arch = arch
    
    
    # set the batch size based on the number of model params
    if arch in models512:
        args.batch_size = 512
    elif arch in models256:
        args.batch_size = 256
    elif arch in models128:
        args.batch_size = 128
    elif arch in models_no_run: #skip this architecture
        print('\n\nSkipping', arch, '\n\n')
        continue
    else:
        args.batch_size = 1024
        
    print('\n\nTraining', arch, 'with batch size', args.batch_size, '\n\n')
    # train the model
    main(args)
    
    #save the model
    # NOTE(review): if the resumed run never beats the restored best_acc1,
    # model_best.pth.tar is not rewritten and still holds an earlier run's
    # checkpoint - confirm this is acceptable.
    best_path = r'./model_best.pth.tar'
    bm = torch.load(best_path)
    dest_path = r'./{}_microscopy_epochs_{}_acc1_{:.3f}_acc5_{:.3f}.pth.tar'.format(
        bm['arch'], bm['epoch'], bm['best_acc1'].item(), bm['acc5'].item())
    shutil.copy(best_path, dest_path) 
    print(arch, 'took', round((time.time()-t0)/60, 2), 'minutes to train.')    

## Train them all without pretraining

In [None]:
# Architectures that the following cell trains with args.pretrained set to False.
from_scratch_models = ['se_resnet50', 'se_resnext50_32x4d', 'efficientnet-b3', 'inceptionresnetv2', 'inceptionv4']

In [None]:
# Training params
# Settings for training without pretrained weights: up to 500 epochs with
# early stopping after `patience` epochs without improvement.
data_dir = r'./data/microscopy/data_v4/'
architecture = 'se_resnet50'
pretrained = False  # this section trains from randomly initialised weights
resume = ''
learning_rate = 0.1
momentum = 0.9
weight_decay = 1e-4
batch_size = 1024
epochs = 500
start_epoch = 0
num_workers = 16
patience = 30

# Settings that are not listed here (seed, dist_url, dist_backend) read as
# None because dotdict returns None for missing keys.
args = dotdict({'data': data_dir,
                'arch': architecture,
               'workers': num_workers,
               'epochs': epochs,
               'start_epoch': start_epoch,
               'batch_size': batch_size,
               'lr': learning_rate,
               'learning_rate': learning_rate,
               'momentum': momentum,
               'weight_decay': weight_decay,
               'print_freq': 10,
               'resume': resume,
               'pretrained': pretrained,
               'evaluate': False,
               'world_size': -1,
               'rank': -1,
               'gpu': None,
               'multiprocessing_distributed': False,
               'patience': patience})

# Train each listed architecture and keep a renamed copy of its best
# checkpoint; the file name is tagged 'fromscratch'.
for arch in from_scratch_models[0:]:
    t0 = time.time()
    # best_acc1 is the module-level value that main_worker reads and updates
    best_acc1 = 0 # reset the best accuracy
    args.arch = arch
    
    # set the batch size based on the number of model params
    if arch in models512:
        args.batch_size = 512
    elif arch in models256:
        args.batch_size = 256
    elif arch in models128:
        args.batch_size = 128
    elif arch in models_no_run: #skip this architecture
        print('\n\nSkipping', arch, '\n\n')
        continue
    else:
        args.batch_size = 1024
        
    print('\n\nTraining', arch, 'with batch size', args.batch_size, '\n\n')
    # train the model
    main(args)
    
    #save the model
    # model_best.pth.tar is only rewritten when an epoch improves on best_acc1;
    # the copy is named after the values stored inside the checkpoint itself.
    best_path = r'./model_best.pth.tar'
    bm = torch.load(best_path)
    dest_path = r'./{}_microscopy_fromscratch_epochs_{}_acc1_{:.3f}_acc5_{:.3f}.pth.tar'.format(
        bm['arch'], bm['epoch'], bm['best_acc1'].item(), bm['acc5'].item())
    shutil.copy(best_path, dest_path) 
    print(arch, 'took', round((time.time()-t0)/60, 2), 'minutes to train.')

## Evaluate Performance

In [None]:
from collections import OrderedDict

In [None]:
# Plot the loss history stored in a saved checkpoint. The training loop
# appends one train and one validation loss per epoch.
best_path = r'./resnet18_microscopy_epochs_75_acc1_81.185_acc5_96.926.pth.tar'
# map_location='cpu' lets the checkpoint be opened on a machine without the
# GPU it was saved from.
bm = torch.load(best_path, map_location='cpu')
plt.figure()  # start a fresh figure instead of drawing on the current one
plt.plot(range(len(bm['train_loss'])), bm['train_loss'], label='train')
plt.plot(range(len(bm['val_loss'])), bm['val_loss'], label='validation')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend();

In [None]:
# load model
best_path = r'./model_best.pth.tar'
# map_location='cpu' lets the checkpoint be opened on a machine without the
# GPU it was saved from.
checkpoint = torch.load(best_path, map_location='cpu')
# pretrained=None: every weight is taken from the checkpoint below, so there
# is no need to fetch the ImageNet weights first.
model = pretrainedmodels.__dict__[args.arch](pretrained=None)

# Keys saved from a DataParallel / DistributedDataParallel wrapper start with
# `module.`; strip that prefix only from the keys that actually carry it.
prefix = 'module.'
state_dict = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v)
    for k, v in checkpoint['state_dict'].items()
)

model.load_state_dict(state_dict)
model.eval();

In [None]:
# The model above was built for args.arch, so the preprocessing and the input
# size are derived from args.arch too, following the same rules as
# main_worker.
try:
    preprocessing_fn = smp.encoders.get_preprocessing_fn(args.arch)
except ValueError:
    preprocessing_fn = smp.encoders.get_preprocessing_fn(args.arch, pretrained='imagenet+5k')
input_size = 299 if args.arch.startswith("inception") else 224
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')
valid_dataset = Dataset(
        valdir,
        transform=get_validation_augmentation(input_size),
        preprocessing=get_preprocessing(preprocessing_fn)
    )

In [None]:
@torch.no_grad()
def get_all_preds(model, loader):
    """Run ``model`` over every batch of ``loader`` and stack the outputs.

    Args:
        model: callable applied to the images of each batch.
        loader: iterable yielding ``(images, labels)`` pairs; labels are ignored.
    Returns:
        Tensor: all outputs concatenated along dimension 0, on the CPU. An
        empty loader yields an empty tensor.
    """
    batch_outputs = []
    for images, _labels in loader:
        # Move each result to the CPU right away so that the final result does
        # not depend on the device the model runs on.
        batch_outputs.append(model(images).cpu())
    if not batch_outputs:
        return torch.tensor([])
    # One concatenation at the end instead of re-copying the growing result
    # after every batch.
    return torch.cat(batch_outputs, dim=0)

In [None]:
# Predict the whole validation set in batches of 1000, in dataset order.
prediction_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1000)
with torch.no_grad():
    valid_preds = get_all_preds(model, prediction_loader)

In [None]:
%matplotlib notebook 
from sklearn.metrics import confusion_matrix
import itertools



In [None]:
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues,
                          filename=None):
    """Draw a confusion matrix as an annotated heat map and save it to disk.

    Args:
        cm (ndarray): square matrix of counts, rows = true class,
            columns = predicted class.
        classes (list): class names used as tick labels on both axes.
        normalize (bool): when true, each row is divided by its sum; rows that
            sum to zero are shown as zeros.
        title (str): figure title.
        cmap: matplotlib colour map.
        filename (str, optional): where to save the figure. Defaults to the
            module-level ``architecture`` name plus ``'.jpg'``.
    """
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class with no samples has a zero row sum; leave that row at zero
        # instead of dividing by zero.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype='float'), where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)
    plt.figure(figsize = (20,20))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)
    # Cell centres sit on integer coordinates, so the full grid spans
    # -0.5 .. n-0.5. Setting the limits explicitly stops the outer cells from
    # being clipped; x runs left to right, and y runs top to bottom so that
    # class 0 is in the top-left corner.
    plt.xlim((-0.5, len(classes) - 0.5))
    plt.ylim((len(classes) - 0.5, -0.5))

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # tight_layout runs after the axis labels are set so they are included
    plt.tight_layout()
    plt.savefig(filename if filename is not None else architecture+'.jpg')
    #plt.show()

In [None]:
# True labels come from the dataset, predicted labels are the arg-max over the
# class dimension; both are in dataset order because the prediction loader
# does not shuffle.
cm = confusion_matrix(valid_dataset.targets, valid_preds.argmax(dim=1))

In [None]:
# Draw and save the un-normalized confusion matrix, labelled with the dataset's class names.
plot_confusion_matrix(cm, valid_dataset.classes)

In [None]:
# Display the figure created by the previous cell.
plt.show()