# Setup

In [1]:
#GPU runtime required, should give CUDA version
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0


In [2]:
!pip install faiss-gpu
!pip install fuzzy-c-means



In [3]:
from google.colab import drive

drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [4]:
!cp -r /content/gdrive/MyDrive/Explainable_Wound_Classification/CMSF/self_supervised/* /content
root_path = 'gdrive/MyDrive/Explainable_Wound_Classification/'

In [5]:
import builtins
import os
import sys
import time
import argparse
import random

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets, models

from PIL import ImageFilter, Image
from util import adjust_learning_rate, AverageMeter as AvgMeter, subset_classes
import models.resnet as resnet

import pdb
import faiss
from fcmeans import FCM

import numpy as np
import pandas as pd
import re
from collections import namedtuple

import shutil
import warnings

import torch.nn.parallel
import torch.optim
import torch.utils.data
from torch.utils.data import DataLoader
import torch.nn.functional as F

from tools import *

#CMSF-KM

### Misc Functions

In [6]:
def get_mlp(inp_dim, hidden_dim, out_dim):
    mlp = nn.Sequential(
        nn.Linear(inp_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, out_dim),
    )
    return mlp

In [7]:
def faiss_kmeans(feats, nmb_clusters):
    feats = feats.numpy()
    d = feats.shape[-1]
    clus = faiss.Clustering(d, nmb_clusters)
    clus.niter = 20
    clus.max_points_per_centroid = 10000000

    index = faiss.IndexFlatL2(d)
    co = faiss.GpuMultipleClonerOptions()
    co.useFloat16 = True
    co.shard = True
    index = faiss.index_cpu_to_all_gpus(index, co)

    # perform the training
    clus.train(feats, index)
    _, train_a = index.search(feats, 1)

    return list(train_a[:, 0])

In [8]:
def fuzzy_c_means(feats, nmb_clusters):
    feats = feats.numpy()
    
    fcm = FCM(n_clusters=nmb_clusters)
    fcm.fit(feats)

    return fcm.predict(feats)

In [9]:
def get_shuffle_ids(bsz):
    """generate shuffle ids for ShuffleBN"""
    forward_inds = torch.randperm(bsz).long().cuda()
    backward_inds = torch.zeros(bsz).long().cuda()
    value = torch.arange(bsz).long().cuda()
    backward_inds.index_copy_(0, forward_inds, value)
    return forward_inds, backward_inds


### Model Architecture

In [10]:
class ConstrainedMeanShiftKM(nn.Module):
    def __init__(self, arch, m=0.99, mem_bank_size=128000, topk=5, dataset_size=100, num_clusters=50000):
        super(ConstrainedMeanShiftKM, self).__init__()

        # save parameters
        self.m = m
        self.mem_bank_size = mem_bank_size
        self.topk = topk
        self.dataset_size = dataset_size
        self.num_clusters = num_clusters

        # create encoders and projection layers
        # both encoders should have same arch
        if 'resnet' in arch:
            self.encoder_q = resnet.__dict__[arch]()
            self.encoder_t = resnet.__dict__[arch]()

        # save output embedding dimensions
        # assuming that both encoders have same dim
        feat_dim = self.encoder_q.fc.in_features
        hidden_dim = feat_dim * 2
        proj_dim = feat_dim // 4

        # projection layers
        self.encoder_t.fc = get_mlp(feat_dim, hidden_dim, proj_dim)
        self.encoder_q.fc = get_mlp(feat_dim, hidden_dim, proj_dim)

        # prediction layer
        self.predict_q = get_mlp(proj_dim, hidden_dim, proj_dim)

        # copy query encoder weights to target encoder
        for param_q, param_t in zip(self.encoder_q.parameters(), self.encoder_t.parameters()):
            param_t.data.copy_(param_q.data)
            param_t.requires_grad = False

        print("using mem-bank size {}".format(self.mem_bank_size))
        # setup queue (For Storing Random Targets)
        self.register_buffer('queue', torch.randn(self.mem_bank_size, proj_dim))
        self.register_buffer('pool', torch.randn(self.dataset_size, proj_dim))
        self.register_buffer('pseudo_labels', 0*torch.ones(self.dataset_size).long())
        # normalize the queue embeddings
        self.queue = nn.functional.normalize(self.queue, dim=1)
        # initialize the labels queue (For Purity measurement)
        self.register_buffer('labels', -1*torch.ones(self.mem_bank_size).long())
        self.register_buffer('index_queue', -1 * torch.ones(self.mem_bank_size).long())
        # setup the queue pointer
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_target_encoder(self):
        for param_q, param_t in zip(self.encoder_q.parameters(), self.encoder_t.parameters()):
            param_t.data = param_t.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def data_parallel(self):
        self.encoder_q = torch.nn.DataParallel(self.encoder_q)
        self.encoder_t = torch.nn.DataParallel(self.encoder_t)
        self.predict_q = torch.nn.DataParallel(self.predict_q)

    @torch.no_grad()
    def cluster(self):
        print('start clustering ... num clusters: {}'.format(self.num_clusters))
        cluster_assignment = fuzzy_c_means(self.pool.clone().cpu(), self.num_clusters)
        self.pseudo_labels = torch.tensor(cluster_assignment).cuda()

    @torch.no_grad()
    def _dequeue_and_enqueue(self, targets, labels, indices):
        batch_size = targets.shape[0]

        ptr = int(self.queue_ptr)
        assert self.mem_bank_size % batch_size == 0 

        # replace the targets at ptr (dequeue and enqueue)
        self.pool[indices, :] = targets
        self.queue[ptr:ptr + batch_size] = targets
        self.labels[ptr:ptr + batch_size] = labels
        self.index_queue[ptr:ptr + batch_size] = indices
        ptr = (ptr + batch_size) % self.mem_bank_size  # move pointer

        self.queue_ptr[0] = ptr

    def forward(self, im_q, im_t, labels, indices):
        # compute query features
        feat_q = self.encoder_q(im_q)
        # compute predictions for instance level regression loss
        query = self.predict_q(feat_q)
        query = nn.functional.normalize(query, dim=1)

        # compute target features
        with torch.no_grad():
            # update the target encoder
            self._momentum_update_target_encoder()

            # shuffle targets
            shuffle_ids, reverse_ids = get_shuffle_ids(im_t.shape[0])
            im_t = im_t[shuffle_ids]

            # forward through the target encoder
            current_target = self.encoder_t(im_t)
            current_target = nn.functional.normalize(current_target, dim=1)

            # undo shuffle
            current_target = current_target[reverse_ids].detach()

            # update the memory-bank
            self._dequeue_and_enqueue(current_target, labels, indices)

        targets = self.queue.clone().detach()

        # get pseudo of target and memory bank samples
        current_target_pseudo_labels = self.pseudo_labels[indices]
        targets_pseudo_labels = self.pseudo_labels[self.index_queue]

        # create a mask to constrain the search space
        b = current_target_pseudo_labels.shape[0]
        m = targets_pseudo_labels.shape[0]
        lx = current_target_pseudo_labels.unsqueeze(1).expand((b, m))
        lm = targets_pseudo_labels.unsqueeze(0).expand((b, m))
        msk = lx != lm

        # calculate distances between vectors
        dist_t = 2 - 2 * torch.einsum('bc,kc->bk', [current_target, targets])
        dist_q = 2 - 2 * torch.einsum('bc,kc->bk', [query, targets])

        # select the k nearest neighbors [with smallest distance (largest=False)] based on current target
        _, unconstrained_nn_index = dist_t.topk(self.topk, dim=1, largest=False)

        # select the k nearest neighbors based on constrained memory bank
        dist_t[torch.where(msk)] = 5.0
        _, constrained_nn_index = dist_t.topk(self.topk, dim=1, largest=False)

        # calculate mean shift regression loss
        nn_dist_q_constrained = torch.gather(dist_q, 1, constrained_nn_index)
        nn_dist_q_unconstrained = torch.gather(dist_q, 1, unconstrained_nn_index)

        # purity based on memory bank
        labels = labels.unsqueeze(1).expand(nn_dist_q_unconstrained.shape[0], nn_dist_q_unconstrained.shape[1])
        labels_queue = self.labels.clone().detach()
        labels_queue = labels_queue.unsqueeze(0).expand((nn_dist_q_unconstrained.shape[0], self.mem_bank_size))
        labels_queue = torch.gather(labels_queue, dim=1, index=unconstrained_nn_index)
# TODO: Change matches here
        matches = (labels_queue == labels).float()
        purity = (matches.sum(dim=1) / self.topk).mean()

        loss = ((nn_dist_q_constrained.sum(dim=1) / self.topk).mean()
                + (nn_dist_q_unconstrained.sum(dim=1) / self.topk).mean()) / 2.0

        return loss, purity

### Transformations/Data Loading

In [11]:
class TwoCropsTransform:
    """Take two random crops of one image as the query and target."""
    def __init__(self, weak_transform, strong_transform):
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform
        print(self.weak_transform)
        print(self.strong_transform)

    def __call__(self, x):
        q = self.strong_transform(x)
        t = self.weak_transform(x)
        return [q, t]


class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x

In [12]:
class Image_Dataset(torch.utils.data.Dataset):

    def __init__(self, root_dir, label_fn, transform=None):
        """
        Image dataset. Returns tensorized images and labels with index
        Args:
            root_dir: path to a cropped mouse image dataset.
            label_fn: function that returns the correct label given an image name
        """
        self.root_dir = root_dir
        self.label_fn = label_fn
        self.transform = transform

        samples = []
        targets = []
        for f in os.listdir(root_dir):
            samples.append(os.path.join(root_dir, f))
            targets.append(label_fn(f))
        
        self.samples = samples
        self.targets = targets

    def pil_loader(self, path):
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def __getitem__(self, index: int):
            """
            Returns index, tensor data, and tensorized label.
            """
            img = self.pil_loader(self.samples[index])
            target = self.targets[index]

            if self.transform:
                img = self.transform(img)

            return index, img, torch.tensor(target)

    def __len__(self):
        return len(self.samples)

    def __str__(self):
        return "Image_Dataset:\n" + "Found " + str(len(self)) + " images in " + self.root_dir + "\n"

In [13]:
# Create train loader
def get_train_loader(datapath, label_fn, batch_size, num_workers, weak_strong=True):
    traindir = os.path.join(datapath, 'train')
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    augmentation_strong = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]

    augmentation_weak = [
        #transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]

    if weak_strong:
        train_dataset = Image_Dataset(
            traindir, label_fn, 
            TwoCropsTransform(transforms.Compose(augmentation_weak), transforms.Compose(augmentation_strong))
        )
    else:
        train_dataset = Image_Dataset(
            traindir, label_fn, 
            TwoCropsTransform(transforms.Compose(augmentation_weak), transforms.Compose(augmentation_weak))
        )

    print('==> train dataset')
    print(train_dataset)

    # NOTE: remove drop_last
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, drop_last=True)

    return train_loader

### Training Functions

In [14]:
def train(epoch, train_loader, mean_shift, optimizer, print_freq):
    """
    one epoch training
    """
    mean_shift.train()

    batch_time = AvgMeter()
    data_time = AvgMeter()
    loss_meter = AvgMeter()
    purity_meter = AvgMeter()

    end = time.time()
    for idx, (indices, (im_q, im_t), labels) in enumerate(train_loader):
        data_time.update(time.time() - end)
        im_q = im_q.cuda(non_blocking=True)
        im_t = im_t.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        # ===================forward=====================
        loss, purity = mean_shift(im_q=im_q, im_t=im_t, labels=labels, indices=indices)

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ===================meters=====================
        loss_meter.update(loss.item(), im_q.size(0))
        purity_meter.update(purity.item(), im_q.size(0))

        torch.cuda.synchronize()
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'purity {purity.val:.3f} ({purity.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time,
                   purity=purity_meter,
                   loss=loss_meter))
            sys.stdout.flush()
            sys.stdout.flush()

    return loss_meter.avg

In [15]:
#workaround struct to pass args to util script
LRArgs = namedtuple('LRArgs', ['cos', 'learning_rate', 'epochs', 'lr_decay_rate'])

In [16]:
def cmsf_km_main(data_path, checkpoint_path, label_fn, batch_size=16, num_workers=2, 
                 epochs=200, print_freq=10, save_freq=10, weak_strong=True, 
                 debug=False, arch='resnet50', momentum=0.99, mem_bank_size=128000, 
                 topk=5, num_clusters=4, learning_rate=0.05, sgd_momentum=0.9, 
                 weight_decay=1e-4, weights=None, resume=None, cos=True, 
                 lr_decay_rate=0.2):
    """
    todo: docstring
    """
    opt = locals()
    del opt['label_fn']

    os.makedirs(checkpoint_path, exist_ok=True)

    train_loader = get_train_loader(data_path, label_fn, batch_size, num_workers, weak_strong)

    mean_shift = ConstrainedMeanShiftKM(
        arch,
        m=momentum,
        mem_bank_size=mem_bank_size,
        topk=topk,
        dataset_size=len(train_loader.dataset),
        num_clusters=num_clusters
    )

    mean_shift.data_parallel()
    mean_shift = mean_shift.cuda()
    print(mean_shift)

    params = [p for p in mean_shift.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params,
                                lr=learning_rate,
                                momentum=sgd_momentum,
                                weight_decay=weight_decay)

    cudnn.benchmark = True
    start_epoch = 1

    if weights:
        print('==> load weights from checkpoint: {}'.format(weights))
        ckpt = torch.load(weights)
        print('==> resume from epoch: {}'.format(ckpt['epoch']))
        if 'model' in ckpt:
            sd = ckpt['model']
        else:
            sd = ckpt['state_dict']
        msg = mean_shift.load_state_dict(sd, strict=False)
        optimizer.load_state_dict(ckpt['optimizer'])
        start_epoch = ckpt['epoch'] + 1
        print(msg)

    if resume:
        print('==> resume from checkpoint: {}'.format(resume))
        ckpt = torch.load(resume, map_location='cpu')
        # sd = ckpt['state_dict']
        # sd = {k.replace('module.', ''): v for k, v in sd.items()}
        print('==> resume from epoch: {}'.format(ckpt['epoch']))
        mean_shift.load_state_dict(ckpt['state_dict'], strict=True)
        optimizer.load_state_dict(ckpt['optimizer'])
        start_epoch = ckpt['epoch'] + 1

    lr_args = LRArgs(cos, learning_rate, epochs, lr_decay_rate)

    for epoch in range(start_epoch, epochs + 1):

        adjust_learning_rate(epoch, lr_args, optimizer)
        print("==> training...")

        time1 = time.time()

        train(epoch, train_loader, mean_shift, optimizer, print_freq)
        mean_shift.cluster()
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # saving the model
        if epoch % save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': opt,
                'state_dict': mean_shift.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }

            save_file = os.path.join(checkpoint_path, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

            # help release GPU memory
            del state
            torch.cuda.empty_cache()

# CMSF-KM Training

In [17]:
cmsf_km_main(
    data_path=root_path + 'Split_images', 
    checkpoint_path=root_path + 'outputs/fuzzy_c_1', 
    label_fn=lambda x: np.array(0),
    batch_size=16,
    num_workers=2,
    epochs=100,
    arch='resnet50',
    topk=6,
    num_clusters=4,
    learning_rate=0.05,
    mem_bank_size=128000,
    weak_strong=False)

Compose(
    Resize(size=224, interpolation=bilinear, max_size=None, antialias=None)
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
Compose(
    Resize(size=224, interpolation=bilinear, max_size=None, antialias=None)
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
==> train dataset
Image_Dataset:
Found 191 images in gdrive/MyDrive/Explainable_Wound_Classification/Split_images/train

using mem-bank size 128000
ConstrainedMeanShiftKM(
  (encoder_q): DataParallel(
    (module): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
        

# Linear Evaluation

### Misc Setup

In [18]:
def load_weights(model, wts_path):
    wts = torch.load(wts_path)
    # pdb.set_trace()
    if 'state_dict' in wts:
        ckpt = wts['state_dict']
    elif 'model' in wts:
        ckpt = wts['model']
    else:
        ckpt = wts

    ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
    ckpt = {k: v for k, v in ckpt.items() if 'encoder_t' not in k}
    ckpt = {k.replace('encoder_q.', ''): v for k, v in ckpt.items()}
    state_dict = {}

    for m_key, m_val in model.state_dict().items():
        if m_key in ckpt:
            state_dict[m_key] = ckpt[m_key]
        else:
            state_dict[m_key] = m_val
            print('not copied => ' + m_key)

    model.load_state_dict(state_dict)

In [19]:
def get_model(arch, wts_path):
    if arch == 'alexnet':
        model = AlexNet()
        model.fc = nn.Sequential()
        load_weights(model, wts_path)
    elif arch == 'pt_alexnet':
        model = models.alexnet()
        classif = list(model.classifier.children())[:5]
        model.classifier = nn.Sequential(*classif)
        load_weights(model, wts_path)
    elif arch == 'mobilenet':
        model = MobileNetV2()
        model.fc = nn.Sequential()
        load_weights(model, wts_path)
    elif 'sup_resnet' in arch:
        model = models.__dict__[arch.replace('sup_', '')](pretrained=True)
        model.fc = nn.Sequential()
    elif 'resnet' in arch:
        model = models.__dict__[arch]()
        model.fc = nn.Sequential()
        load_weights(model, wts_path)
    else:
        raise ValueError('arch not found: ' + arch)

    for p in model.parameters():
        p.requires_grad = False

    return model

In [20]:
class Normalize(nn.Module):
    def forward(self, x):
        return x / x.norm(2, dim=1, keepdim=True)


class FullBatchNorm(nn.Module):
    def __init__(self, var, mean):
        super(FullBatchNorm, self).__init__()
        self.register_buffer('inv_std', (1.0 / torch.sqrt(var + 1e-5)))
        self.register_buffer('mean', mean)

    def forward(self, x):
        return (x - self.mean) * self.inv_std

In [21]:
def get_channels(arch):
    if arch == 'alexnet':
        c = 4096
    elif arch == 'pt_alexnet':
        c = 4096
    elif 'resnet50' in arch:
        c = 2048
    elif arch == 'resnet18':
        c = 512
    elif arch == 'mobilenet':
        c = 1280
    else:
        raise ValueError('arch not found: ' + arch)
    return c

In [22]:
def normalize(x):
    return x / x.norm(2, dim=1, keepdim=True)

In [23]:
def get_feats(loader, model, print_freq, logger):
    batch_time = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(
        len(loader),
        [batch_time],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    feats, labels, ptr = None, None, 0

    with torch.no_grad():
        end = time.time()
        for i, (indices, images, target) in enumerate(loader):
            images = images.cuda(non_blocking=True)
            cur_targets = target.cpu()
            cur_feats = normalize(model(images)).cpu()
            B, D = cur_feats.shape
            inds = torch.arange(B) + ptr

            if not ptr:
                feats = torch.zeros((len(loader.dataset), D)).float()
                labels = torch.zeros(len(loader.dataset)).long()

            feats.index_copy_(0, inds, cur_feats)
            labels.index_copy_(0, inds, cur_targets.argmax(axis=1))
            ptr += B

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                logger.info(progress.display(i))

    return feats, labels

### Training Functions

In [24]:
def train(train_loader, backbone, linear, optimizer, epoch, print_freq, logger):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top2 = AverageMeter('Acc@2', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top2],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    backbone.eval()
    linear.train()

    end = time.time()
    for i, (indices, images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        with torch.no_grad():
            output = backbone(images)
        output = linear(output)
        loss = F.binary_cross_entropy_with_logits(output, target)

        # measure accuracy and record loss
        acc1, acc2 = accuracy(output, target.argmax(axis=1), topk=(1, 2))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top2.update(acc2[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            logger.info(progress.display(i))

In [25]:
def validate(val_loader, backbone, linear, print_freq, logger):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top2 = AverageMeter('Acc@2', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top2],
        prefix='Test: ')

    backbone.eval()
    linear.eval()

    with torch.no_grad():
        end = time.time()
        for i, (indices, images, target) in enumerate(val_loader):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = backbone(images)
            output = linear(output)
            loss = F.binary_cross_entropy_with_logits(output, target)

            # measure accuracy and record loss
            
            acc1, acc2 = accuracy(output, target.argmax(axis=1), topk=(1, 2))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top2.update(acc2[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                logger.info(progress.display(i))

        # TODO: this should also be done with the ProgressMeter
        logger.info(' * Acc@1 {top1.avg:.3f} Acc@2 {top2.avg:.3f}'
              .format(top1=top1, top2=top2))

    return top1.avg

In [26]:
def main_worker(data, label_fn, weights, save, logger, batch_size=16, workers=2, 
                epochs=40, arch='resnet50', print_freq=10, mlp=True, lr=0.01, 
                momentum=0.9, weight_decay=1e-4, lr_schedule='15,30,40', 
                resume=None, evaluate=False, n_classes=4):
    best_acc1 = 0

    # Data loading code
    traindir = os.path.join(data, 'train')
    valdir = os.path.join(data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        #transforms.RandomResizedCrop(224),
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_transform = transforms.Compose([
        #transforms.Resize(256),
        #transforms.CenterCrop(224),
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = Image_Dataset(traindir, label_fn, train_transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=True,
    )

    val_loader = torch.utils.data.DataLoader(
        Image_Dataset(valdir, label_fn, val_transform),
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True,
    )

    train_val_loader = torch.utils.data.DataLoader(
        Image_Dataset(traindir, label_fn, val_transform),
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True,
    )

    backbone = get_model(arch, weights)
    backbone = nn.DataParallel(backbone).cuda()
    backbone.eval()


    cached_feats = '%s/var_mean.pth.tar' % save
    if not os.path.exists(cached_feats):
        train_feats, _ = get_feats(train_val_loader, backbone, print_freq, logger)
        train_var, train_mean = torch.var_mean(train_feats, dim=0)
        torch.save((train_var, train_mean), cached_feats)
    else:
        train_var, train_mean = torch.load(cached_feats)
    if mlp:
        c = get_channels(arch)
        linear = nn.Sequential(
            Normalize(),
            FullBatchNorm(train_var, train_mean),
            nn.Linear(c, c),
            nn.BatchNorm1d(c),
            nn.ReLU(inplace=True),
            nn.Linear(c, n_classes),
            nn.Softmax()
        )
    else:
        linear = nn.Sequential(
            Normalize(),
            FullBatchNorm(train_var, train_mean),
            nn.Linear(get_channels(arch), n_classes),
        )

    print(backbone)
    print(linear)

    linear = linear.cuda()

    optimizer = torch.optim.SGD(linear.parameters(),
                                lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

    sched = [int(x) for x in lr_schedule.split(',')]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=sched
    )

    start_epoch = 0
    # optionally resume from a checkpoint
    if resume:
        if os.path.isfile(resume):
            logger.info("=> loading checkpoint '{}'".format(resume))
            checkpoint = torch.load(resume)
            start_epoch = checkpoint['epoch']
            linear.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            logger.info("=> loaded checkpoint '{}' (epoch {})"
                  .format(resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(resume))

    cudnn.benchmark = True

    if evaluate:
        validate(val_loader, backbone, linear, print_freq, logger)
        return

    for epoch in range(start_epoch, epochs):
        # train for one epoch
        train(train_loader, backbone, linear, optimizer, epoch, print_freq, logger)

        # evaluate on validation set
        acc1 = validate(val_loader, backbone, linear, print_freq, logger)

        # modify lr
        lr_scheduler.step()
        # logger.info('LR: {:f}'.format(lr_scheduler.get_last_lr()[-1]))

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': linear.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
        }, is_best, save)

    return backbone, linear

In [27]:
!touch logger_init
def main_linear_eval(data, label_fn, weights, save, batch_size=16, workers=2, 
                     epochs=40, arch='resnet50', print_freq=10, mlp=True, lr=0.01, 
                     momentum=0.9, weight_decay=1e-4, lr_schedule='15,30,40', 
                     resume=None, evaluate=False, seed=None, n_classes=4):
    args = locals()
    del args['label_fn']

    makedirs(save)
    logger = get_logger(logpath=os.path.join(save, 'logs'), filepath='logger_init')
    logger.info(args)

    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    return main_worker(data, label_fn, weights, save, logger, batch_size, workers, 
                epochs, arch, print_freq, mlp, lr, momentum, weight_decay, 
                lr_schedule, resume, evaluate)

# Linear Evaluation Training

In [28]:
labels_df = pd.read_csv(root_path + 'Cropped_Images_Wound_Stage_Probabilities.csv', index_col='Image')
labels_df.head()

Unnamed: 0_level_0,hemostasis,inflammatory,proliferative,maturation
Image,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Day 8_A8-4-L.png,0.181818,0.090909,0.545455,0.181818
Day 4_A8-3-R.png,0.090909,0.909091,0.0,0.0
Day 14_Y8-4-L.png,0.0,0.0,0.090909,0.909091
Day 7_Y8-4-L.png,0.0,0.0,0.454545,0.545455
Day 2_A8-1-L.png,0.181818,0.727273,0.090909,0.0


In [29]:
labels_df.loc['Day 8_A8-4-L.png'].to_numpy()

array([0.18181818, 0.09090909, 0.54545455, 0.18181818])

In [30]:
backbone, linear = main_linear_eval(
    data='/content/gdrive/MyDrive/Explainable_Wound_Classification/Split_images', 
    label_fn=lambda x: labels_df.loc[x].to_numpy(), 
    weights='/content/gdrive/MyDrive/Explainable_Wound_Classification/outputs/fuzzy_c_1/ckpt_epoch_100.pth',
    save='/content/gdrive/MyDrive/Explainable_Wound_Classification/outputs/fuzzy_c_1/eval_100/', 
    batch_size=16, 
    workers=2, 
    epochs=40, 
    arch='resnet50', 
    print_freq=10, 
    mlp=True, 
    lr=0.01)

logger_init

{'data': '/content/gdrive/MyDrive/Explainable_Wound_Classification/Split_images', 'weights': '/content/gdrive/MyDrive/Explainable_Wound_Classification/outputs/fuzzy_c_1/ckpt_epoch_100.pth', 'save': '/content/gdrive/MyDrive/Explainable_Wound_Classification/outputs/fuzzy_c_1/eval_100/', 'batch_size': 16, 'workers': 2, 'epochs': 40, 'arch': 'resnet50', 'print_freq': 10, 'mlp': True, 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.0001, 'lr_schedule': '15,30,40', 'resume': None, 'evaluate': False, 'seed': None, 'n_classes': 4}


DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
      

  input = module(input)
Epoch: [0][ 0/12]	Time  0.446 ( 0.446)	Data  0.365 ( 0.365)	Loss 7.6223e-01 (7.6223e-01)	Acc@1  12.50 ( 12.50)	Acc@2  43.75 ( 43.75)
Epoch: [0][10/12]	Time  0.059 ( 0.098)	Data  0.014 ( 0.043)	Loss 7.3550e-01 (7.5555e-01)	Acc@1  56.25 ( 36.36)	Acc@2  87.50 ( 63.07)
Test: [0/2]	Time  0.309 ( 0.309)	Loss 7.5002e-01 (7.5002e-01)	Acc@1  56.25 ( 56.25)	Acc@2  62.50 ( 62.50)
 * Acc@1 53.125 Acc@2 68.750
Epoch: [1][ 0/12]	Time  0.307 ( 0.307)	Data  0.261 ( 0.261)	Loss 7.4393e-01 (7.4393e-01)	Acc@1  50.00 ( 50.00)	Acc@2  68.75 ( 68.75)
Epoch: [1][10/12]	Time  0.063 ( 0.086)	Data  0.018 ( 0.039)	Loss 7.3785e-01 (7.3593e-01)	Acc@1  43.75 ( 53.98)	Acc@2  68.75 ( 74.43)
Test: [0/2]	Time  0.292 ( 0.292)	Loss 7.5363e-01 (7.5363e-01)	Acc@1  50.00 ( 50.00)	Acc@2  62.50 ( 62.50)
 * Acc@1 50.000 Acc@2 71.875
Epoch: [2][ 0/12]	Time  0.325 ( 0.325)	Data  0.279 ( 0.279)	Loss 7.0074e-01 (7.0074e-01)	Acc@1  68.75 ( 68.75)	Acc@2  81.25 ( 81.25)
Epoch: [2][10/12]	Time  0.109 ( 0.094)	Da

# Test Set

In [31]:
test_dataset = Image_Dataset(
    '/content/gdrive/MyDrive/Explainable_Wound_Classification/Split_images/test', 
    lambda x: labels_df.loc[x].to_numpy(), 
    transforms.Compose(
        [transforms.Resize(224),
         transforms.ToTensor(),
         normalize])
)

In [32]:
backbone.eval()
linear.eval()

Sequential(
  (0): Normalize()
  (1): FullBatchNorm()
  (2): Linear(in_features=2048, out_features=2048, bias=True)
  (3): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (4): ReLU(inplace=True)
  (5): Linear(in_features=2048, out_features=4, bias=True)
  (6): Softmax(dim=None)
)

In [33]:
with torch.no_grad():
    results = {}
    for i in range(len(test_dataset)):
        img_name = test_dataset.samples[i].split('/')[-1]
        results[img_name] = {} 
        results[img_name]['target'] = test_dataset.targets[i]
        results[img_name]['pred'] = linear(backbone(test_dataset[i][1].expand(1, 3, 224, 224))).cpu().numpy()

  input = module(input)


In [34]:
df = pd.DataFrame(results).T
props = df.index.map(lambda x: re.match('^Day (\d+)_(Y|A)8-(\d)-(L|R)', x).groups())
df['Day'] = props.map(lambda x: int(x[0]))
df['Age'] = props.map(lambda x: x[1])
df['Mouse'] = props.map(lambda x: int(x[2]))
df['Side'] = props.map(lambda x: x[3])
df.head()

Unnamed: 0,target,pred,Day,Age,Mouse,Side
Day 14_Y8-4-L.png,"[0.0, 0.0, 0.0909090909090909, 0.9090909090909...","[[0.00077371113, 0.0022693924, 0.98832035, 0.0...",14,Y,4,L
Day 7_Y8-4-L.png,"[0.0, 0.0, 0.4545454545454545, 0.5454545454545...","[[0.00076753687, 0.0022484525, 0.9883583, 0.00...",7,Y,4,L
Day 9_A8-1-R.png,"[0.0909090909090909, 0.4545454545454545, 0.454...","[[0.0007661236, 0.002261453, 0.98830634, 0.008...",9,A,1,R
Day 4_A8-1-R.png,"[0.3636363636363636, 0.6363636363636364, 0.0, ...","[[0.0007682852, 0.0022748422, 0.9882572, 0.008...",4,A,1,R
Day 12_A8-1-R.png,"[0.0, 0.2, 0.5, 0.3]","[[0.0007665131, 0.002246504, 0.98840904, 0.008...",12,A,1,R


In [35]:
young_df = df[df.Age == 'Y']
young_df.sort_values('Day')

Unnamed: 0,target,pred,Day,Age,Mouse,Side
Day 0_Y8-4-L.png,"[1.0, 0.0, 0.0, 0.0]","[[0.00077199383, 0.0023037447, 0.9880725, 0.00...",0,Y,4,L
Day 1_Y8-4-L.png,"[0.3, 0.4, 0.3, 0.0]","[[0.0007662006, 0.0022762865, 0.988245, 0.0087...",1,Y,4,L
Day 2_Y8-4-L.png,"[0.8, 0.2, 0.0, 0.0]","[[0.00076841813, 0.0022748837, 0.9882653, 0.00...",2,Y,4,L
Day 3_Y8-4-L.png,"[0.2, 0.8, 0.0, 0.0]","[[0.00076816865, 0.0022626803, 0.98832417, 0.0...",3,Y,4,L
Day 4_Y8-4-L.png,"[0.4, 0.5, 0.1, 0.0]","[[0.0007702087, 0.0022798984, 0.9881912, 0.008...",4,Y,4,L
Day 5_Y8-4-L.png,"[0.3, 0.5, 0.2, 0.0]","[[0.0007680087, 0.0022626277, 0.98828226, 0.00...",5,Y,4,L
Day 6_Y8-4-L.png,"[0.1, 0.3, 0.6, 0.0]","[[0.00076703215, 0.002252201, 0.9883349, 0.008...",6,Y,4,L
Day 7_Y8-4-L.png,"[0.0, 0.0, 0.4545454545454545, 0.5454545454545...","[[0.00076753687, 0.0022484525, 0.9883583, 0.00...",7,Y,4,L
Day 8_Y8-4-L.png,"[0.0, 0.0, 0.5, 0.5]","[[0.0007683486, 0.0022554114, 0.9883284, 0.008...",8,Y,4,L
Day 9_Y8-4-L.png,"[0.0, 0.0, 0.6, 0.4]","[[0.000767051, 0.0022412816, 0.9884106, 0.0085...",9,Y,4,L


In [36]:
aged_df = df[df.Age == 'A']
aged_df.sort_values('Day')

Unnamed: 0,target,pred,Day,Age,Mouse,Side
Day 0_A8-1-R.png,"[0.9, 0.1, 0.0, 0.0]","[[0.0007662976, 0.002264746, 0.9882834, 0.0086...",0,A,1,R
Day 1_A8-1-R.png,"[0.1, 0.8, 0.1, 0.0]","[[0.0007634308, 0.0022500332, 0.988376, 0.0086...",1,A,1,R
Day 2_A8-1-R.png,"[0.4, 0.5, 0.1, 0.0]","[[0.0007654989, 0.0022612421, 0.9883159, 0.008...",2,A,1,R
Day 3_A8-1-R.png,"[0.4, 0.5, 0.1, 0.0]","[[0.0007688016, 0.002287011, 0.9882036, 0.0087...",3,A,1,R
Day 4_A8-1-R.png,"[0.3636363636363636, 0.6363636363636364, 0.0, ...","[[0.0007682852, 0.0022748422, 0.9882572, 0.008...",4,A,1,R
Day 5_A8-1-R.png,"[0.2, 0.5, 0.3, 0.0]","[[0.0007663971, 0.0022515245, 0.98835796, 0.00...",5,A,1,R
Day 6_A8-1-R.png,"[0.2, 0.7, 0.1, 0.0]","[[0.0007645951, 0.002251848, 0.9883652, 0.0086...",6,A,1,R
Day 7_A8-1-R.png,"[0.1, 0.7, 0.1, 0.1]","[[0.00076561735, 0.0022552274, 0.9883435, 0.00...",7,A,1,R
Day 8_A8-1-R.png,"[0.2, 0.6, 0.2, 0.0]","[[0.0007667092, 0.0022641725, 0.98829377, 0.00...",8,A,1,R
Day 9_A8-1-R.png,"[0.0909090909090909, 0.4545454545454545, 0.454...","[[0.0007661236, 0.002261453, 0.98830634, 0.008...",9,A,1,R
