In [1]:
!python -m jupytools export -nb "16a_densenet_tricks_pretrained.ipynb" -o .
!mv densenet_tricks_pretrained.py densenet121_aftertrain_holdout.py

Exported: 16a_densenet_tricks_pretrained.ipynb -> densenet_tricks_pretrained.py
1 notebook(s) exported into folder: .


In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
# https://www.kaggle.com/tanlikesmath/rcic-fastai-starter

## Imports 

In [3]:
#export
from collections import OrderedDict
import json
import os
from os.path import dirname, join
from functools import reduce
from pdb import set_trace

import cv2 as cv
import jupytools
import jupytools.syspath
import numpy as np
import pandas as pd
import PIL.Image
import matplotlib.pyplot as plt

from catalyst.utils import get_one_hot
from imageio import imread
import pretrainedmodels
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from visdom import Visdom

jupytools.syspath.add(join(dirname(os.getcwd()), 'protein_project'))
jupytools.syspath.add('rxrx1-utils')
if jupytools.is_notebook():
    from tqdm import tqdm_notebook as tqdm
else:
    from tqdm import tqdm as tdqm
    
from basedir import ROOT, NUM_CLASSES
from dataset import build_stats_index
from lookahead import Lookahead

In [4]:
#export
torch.set_default_tensor_type(torch.FloatTensor)

## Dataset Reader

In [5]:
#export
from augmentation import JoinChannels, SwapChannels, Resize, ToFloat, Rescale
from augmentation import VerticalFlip, HorizontalFlip, PixelStatsNorm, composer
from augmentation import UnconditionalTransform
from augmentation import AugmentedImages, bernoulli

In [6]:
#export
default_open_fn = imread  # PIL.Image.open

In [7]:
#export
class RxRxDataset(Dataset):
    
    def __init__(self, meta_df, img_dir, sites=(1, 2), channels=(1, 2, 3, 4, 5, 6),
                open_image=default_open_fn, n_classes=NUM_CLASSES, train=True,
                flip_v=0.1, flip_h=0.1, resize=512, norm=True, rescale=False,
                crop=False, hat=False, sigma_clip=False, label_smoothing=0.0):
        
        # data
        self.records = meta_df.to_records(index=False)
        self.img_dir = img_dir
        self.sites = sites
        self.channels = channels
        self.n = len(self.records)
        self.open_image = open_image
        self.n_classes = n_classes
        self.train = train
        
        # options
        self.flip_v = flip_v
        self.flip_h = flip_h
        self.resize = resize
        self.norm = norm
        self.rescale = rescale
        self.crop = crop
        self.hat = hat
        self.sigma_clip = sigma_clip
        self.label_smoothing = label_smoothing
    
    def __getitem__(self, index):
        samples = {f'site{site}': self._get_site_image(index, site) for site in self.sites}
        return samples
    
    def __len__(self):
        return self.n
    
    def _get_image_path(self, index, channel, site):
        r = self.records[index]
        exp, plate, well = r.experiment, r.plate, r.well
        subdir = 'train' if self.train else 'test'
        path = f'{self.img_dir}/{subdir}/{exp}/Plate{plate}/{well}_s{site}_w{channel}.png'
        return path
        
    def _get_site_image(self, index, site):
        paths = [self._get_image_path(index, ch, site) for ch in self.channels]
        images = [self.open_image(p) for p in paths]
        image = self._concat(images)
        image = self._augment(image)
        sample = self._wrap_with_meta(image, self.records[index])
        sample['site'] = site
        return sample
        
    def _wrap_with_meta(self, image, meta):
        if self.train:
            sirna = meta.sirna
            target = int(sirna)
            onehot = get_one_hot(target, num_classes=self.n_classes,
                                 smoothing=self.label_smoothing)
            return {'features': image, 'targets': target, 'targets_one_hot': onehot,
                    'id_code': meta.id_code}
        else:
            return {'features': image, 'id_code': meta.id_code}
        
    def _concat(self, images):
        try:
            img = np.stack(images)
        except (TypeError, ValueError) as e:
            print(f'Warning: cannot concatenate images! {e.__class__.__name__}: {e}')
            for filename, image in zip(paths, images):
                print(f'\tpath={filename}, size={image.size}')
            index = (index + 1) % len(self)
            print(f'Skipping instance {index} and trying another one...')
            return self[index]
        finally:
            for image in images:
                if hasattr(image, 'close'):
                    image.close()
        return img
                    
    def _augment(self, image):
        # OpenCV channels ordering
        image = image.transpose(1, 2, 0)  # W x H x C
        if self.resize:
            image = cv.resize(image, (self.resize, self.resize))
        if self.crop:
            assert isinstance(self.crop, int), 'If crop provided, it should be integer'
            assert self.crop < self.resize, 'Crop should be smaller than image size'
            shift = np.random.randint(0, self.resize - self.crop - 1)
            image = image[shift:(shift+self.crop), shift:(shift+self.crop), :]
        if self.flip_v:
            if bernoulli(self.flip_v) == 1:
                image = cv.flip(image, 0)
        if self.flip_h:
            if bernoulli(self.flip_h) == 1:
                image = cv.flip(image, 1)
        if self.hat:
            kernel = np.ones(self.hat, np.uint8)
            image = cv.morphologyEx(image, cv.MORPH_TOPHAT, self.hat)
        image = image.astype(np.float32)
        if self.rescale:
            image /= 255
        if self.norm:
            mean, std = image.mean(axis=(0, 1)), image.std(axis=(0, 1))
            image = (image - mean)/(std + 1e-8)
        if self.sigma_clip:
            image = np.clip(image, -self.sigma_clip, self.sigma_clip)
        image = image.transpose(2, 0, 1)  # C x W x H
        return image

## Pipeline Preparation

In [8]:
# from split import StratifiedSplit
# splitter = StratifiedSplit()
# trn_df, val_df = splitter(pd.read_csv(ROOT/'train.csv')) 
# tst_df = pd.read_csv(ROOT/'test.csv')

In [9]:
#export
from collections import defaultdict

def get_cell_type(x): return x.split('-')[0]

class HoldoutSplit:
    def __call__(self, data_frame):
        return self.apply(data_frame)
    def apply(self, data_frame):
        experiments = data_frame.experiment.unique()
        cell_types = defaultdict(list)
        for exp in experiments:
            cell_type, _ = exp.split('-')
            cell_types[cell_type].append(exp)
        training, holdout = [], []
        for cell_type, subset in cell_types.items():
            valid = np.random.choice(subset)
            subset.remove(valid)
            training.extend(subset)
            holdout.append(valid)
        training_df = data_frame[data_frame.experiment.isin(training)].copy()
        valid_df = data_frame[data_frame.experiment.isin(holdout)].copy()
        return training_df, valid_df

sub_df = pd.read_csv(ROOT/'train.csv')
tst_df = pd.read_csv(ROOT/'test.csv')

splitter = HoldoutSplit()
trn_df, val_df = splitter(sub_df)
print(trn_df.shape, val_df.shape, len(val_df)/len(trn_df))

(32085, 5) (4430, 5) 0.13807074957145082


In [10]:
trn_df.experiment.unique()

array(['HEPG2-01', 'HEPG2-02', 'HEPG2-03', 'HEPG2-04', 'HEPG2-05',
       'HEPG2-07', 'HUVEC-02', 'HUVEC-03', 'HUVEC-04', 'HUVEC-05',
       'HUVEC-06', 'HUVEC-07', 'HUVEC-08', 'HUVEC-09', 'HUVEC-10',
       'HUVEC-11', 'HUVEC-12', 'HUVEC-13', 'HUVEC-14', 'HUVEC-15',
       'HUVEC-16', 'RPE-01', 'RPE-02', 'RPE-03', 'RPE-04', 'RPE-05',
       'RPE-07', 'U2OS-01', 'U2OS-02'], dtype=object)

In [11]:
val_df.experiment.unique()

array(['HEPG2-06', 'HUVEC-01', 'RPE-06', 'U2OS-03'], dtype=object)

In [12]:
# trn_df = trn_df.head(100)
# val_df = val_df.head(100)

In [13]:
#export
sz = 512

trn_ds = RxRxDataset(
    trn_df, ROOT,
    crop=384, resize=sz, hat=(128, 128), norm=True, sigma_clip=5.0, label_smoothing=0.1)

val_ds = RxRxDataset(
    val_df, ROOT,
    crop=384, resize=sz, hat=(128, 128), norm=True, sigma_clip=5.0, label_smoothing=None)

In [14]:
#export
def new_loader(ds, bs, drop_last=False, shuffle=True, num_workers=12):
    return DataLoader(ds, batch_size=bs, drop_last=drop_last, 
                      shuffle=shuffle, num_workers=num_workers)

## Model

In [15]:
#export
def densenet(name='densenet121', n_classes=NUM_CLASSES):
    model_fn = pretrainedmodels.__dict__[name]
    model = model_fn(num_classes=1000, pretrained='imagenet')
    new_conv = nn.Conv2d(6, 64, 7, 2, 3, bias=False)
    conv0 = model.features.conv0.weight
    with torch.no_grad():
        new_conv.weight[:, :] = torch.stack([torch.mean(conv0, 1)]*6, dim=1)
    model.features.conv0 = new_conv
    return model

In [16]:
#export
from catalyst.contrib.modules import GlobalConcatPool2d
class DenseNet_TwoSites(nn.Module):
    def __init__(self, name, n_classes=NUM_CLASSES):
        super().__init__()
        
        base = densenet(name=name, n_classes=n_classes)
        feat_dim = base.last_linear.in_features
        
        self.base = base 
        self.pool = GlobalConcatPool2d()
        self.head = nn.Sequential(
            nn.Linear(feat_dim * 2, feat_dim * 2),
            nn.BatchNorm1d(feat_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.25),
            nn.Linear(feat_dim * 2, n_classes)
        )
        
    def forward(self, s1, s2):
        f1 = self.base.features(s1)
        f2 = self.base.features(s2)
        f_merged = self.pool(f1 + f2)
        out = self.head(f_merged.squeeze())
        return out

In [17]:
#export
def freeze_all(model):
    for name, child in model.named_children():
        print('Freezing layer:', name)
        for param in child.parameters():
            param.requires_grad = False

In [18]:
#export
def unfreeze_all(model):
    for name, child in model.named_children():
        print('Un-freezing layer:', name)
        for param in child.parameters():
            param.requires_grad = True

In [19]:
#export
def unfreeze_layers(model, names):
    for name, child in model.named_children():
        if name not in names:
            continue
        print('Un-freezing layer:', name)
        for param in child.parameters():
            param.requires_grad = True

In [20]:
#export
from torch.optim.lr_scheduler import _LRScheduler
class CosineDecay(_LRScheduler):
    def __init__(self, optimizer, total_steps,
                 linear_start=0,
                 linear_frac=0.1, min_lr=1e-6,
                 last_epoch=-1):
        
        self.optimizer = optimizer
        self.total_steps = total_steps
        self.linear_start = linear_start
        self.linear_frac = linear_frac
        self.min_lr = min_lr
        self.linear_steps = total_steps * linear_frac
        self.cosine_steps = total_steps - self.linear_steps
        super().__init__(optimizer, last_epoch)
        
    def get_lr(self):
        step = self.last_epoch
        if step <= self.linear_steps:
            b = self.linear_start
            return [(step/self.linear_steps) * (base_lr - b) + b for base_lr in self.base_lrs]
        else:
            t = self.last_epoch - self.linear_steps
            T = self.cosine_steps
            return [self.min_lr + (base_lr - self.min_lr)*(1 + np.cos(t*np.pi/T))/2
                    for base_lr in self.base_lrs]

## Train

In [None]:
#export
model = DenseNet_TwoSites('densenet121')
state = torch.load('densenet121_long_training/train.29.pth', map_location=lambda l, s: l)
model.load_state_dict(state['model'])
unfreeze_all(model)

In [None]:
#export
from visdom import Visdom

In [None]:
#export
class RollingLoss:
    def __init__(self, smooth=0.98):
        self.smooth = smooth
        self.prev = 0
    def __call__(self, curr, batch_no):
        a = self.smooth
        avg_loss = a*self.prev + (1 - a)*curr
        debias_loss = avg_loss/(1 - a**batch_no)
        self.prev = avg_loss
        return debias_loss

In [None]:
#export
def create_loaders(batch_size, drop_last=False):
    trn_dl = new_loader(trn_ds, bs=batch_size, drop_last=drop_last, shuffle=True)
    val_dl = new_loader(val_ds, bs=batch_size, drop_last=drop_last, shuffle=False)
    return OrderedDict([('train', trn_dl), ('valid', val_dl)])

In [None]:
#export
class Checkpoint:
    def __init__(self, output_dir):
        if os.path.exists(output_dir):
            print('Warning! Output folder already exists.')
        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir
    
    def __call__(self, epoch, **objects):
        filename = os.path.join(self.output_dir, f'train.{epoch}.pth')
        checkpoint = {}
        for k, v in objects.items():
            if hasattr(v, 'state_dict'):
                v = v.state_dict()
            checkpoint[k] = v
        torch.save(checkpoint, filename)
        return filename

In [None]:
#export
class LabelSmoothingLoss(nn.Module):
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, preds, one_hot_target):
        preds = preds.log_softmax(dim=self.dim)
        return torch.mean(torch.sum(-one_hot_target * preds, dim=self.dim))

In [None]:
#export
# loss_fn = nn.CrossEntropyLoss()
loss_fn = LabelSmoothingLoss()
device = torch.device('cuda:1')

In [None]:
#export
epochs = 50
patience = 15
base_lr = 3e-4

opt = Lookahead(torch.optim.AdamW(params=[
    {'params': model.head.parameters(),                      'lr': base_lr     },
    {'params': model.base.features.denseblock4.parameters(), 'lr': base_lr / 3 },
    {'params': model.base.features.denseblock3.parameters(), 'lr': base_lr / 5 },
    {'params': model.base.features.denseblock2.parameters(), 'lr': base_lr / 5 },
    {'params': model.base.features.denseblock1.parameters(), 'lr': base_lr / 10},
]))
opt = Lookahead(opt)
loaders = create_loaders(batch_size=12, drop_last=True)
sched = CosineDecay(
    optimizer=opt,
    total_steps=len(loaders['train']) * epochs,
    linear_start=base_lr / 100,
    linear_frac=0.1, 
    min_lr=base_lr / 300)
model = model.to(device)

vis = Visdom(server='0.0.0.0', port=9090,
             username=os.environ['VISDOM_USERNAME'],
             password=os.environ['VISDOM_PASSWORD'])

trials = 0
best_metric = -np.inf
history = []
stop = False
rolling_loss = dict(train=RollingLoss(), valid=RollingLoss())
steps = dict(train=0, valid=0)
checkpoint = Checkpoint('densenet121_pretrained')

for epoch in range(1, epochs+1):
    print(f'Epoch [{epoch}/{epochs}]')

    iteration = dict(epoch=epoch, train_loss=list(), valid_loss=list())
    
    for name, loader in loaders.items():
        is_training = name == 'train'
        count = 0
        metric = 0.0
        
        with torch.set_grad_enabled(is_training):
            for batch_no, batch in enumerate(loader):
                steps[name] += 1
                opt.zero_grad()

                y = batch['site1']['targets_one_hot'].to(device)
                
                out = model(
                    batch['site1']['features'].to(device),
                    batch['site2']['features'].to(device)
                )
                
                if is_training:
                    loss = loss_fn(out, y)
                    loss.backward()
                    opt.step()
                    sched.step()
                    
                    curr_lr = opt.param_groups[0]['lr']
                    vis.line(X=[steps[name]], Y=[curr_lr], win='lr', name='lr', update='append')    
                
                avg_loss = rolling_loss[name](loss.item(), steps[name])
                iteration[f'{name}_loss'].append(avg_loss)
                y_pred = out.softmax(dim=1).argmax(dim=1)
                y_true = batch['site1']['targets'].to(device)
                acc = (y_pred == y_true).float().mean().item()
                metric += acc
                count += len(batch)
                vis.line(X=[steps[name]], Y=[avg_loss], name=f'{name}_loss', 
                         win=f'{name}_loss', update='append', 
                         opts=dict(title=f'Running Loss [{name}]'))
        
        metric /= count
        iteration[f'{name}_acc'] = metric
        vis.line(X=[epoch], Y=[avg_loss], name=f'{name}', win='avg_loss',
                 update='append', opts=dict(title='Average Epoch Loss'))
        vis.line(X=[epoch], Y=[metric], name=f'{name}', win='accuracy', 
                 update='append', opts=dict(title=f'Accuracy'))
        
        last_loss = iteration[f'{name}_loss'][-1]
        
        print(f'{name} metrics: accuracy={metric:2.3%}, loss={last_loss:.4f}')
          
        if is_training:
            pass
          
        else:
            if metric > best_metric:
                trials = 0
                best_metric = metric
                print('Score improved!')
                checkpoint(epoch, model=model, opt=opt)

            else:
                trials += 1
                if trials >= patience:
                    stop = True
                    break
    
    history.append(iteration)
    
    print('-' * 80)
    
    if stop:
        print(f'Early stopping on epoch: {epoch}')
        break

torch.save(history, f'{checkpoint.output_dir}/history.pth')

## Test

In [21]:
device = torch.device('cuda:0')
model = DenseNet_TwoSites('densenet121')
model = model.to(device)
# state = torch.load('densenet121_15_cw/train.14.pth', map_location=lambda loc, storage: loc)
state = torch.load('densenet121_pretrained/train.37.pth',
                   map_location=lambda loc, storage: loc)
model.load_state_dict(state['model'])
freeze_all(model)
_ = model.eval()

Freezing layer: base
Freezing layer: pool
Freezing layer: head


In [54]:
sz = 512

tst_ds = RxRxDataset(
    tst_df, ROOT,
    crop=384, resize=sz, hat=(128, 128), norm=True, sigma_clip=5.0, 
    label_smoothing=None, train=False, flip_v=False, flip_h=False)

with torch.no_grad():
    test_dl = new_loader(tst_ds, shuffle=False, bs=64)
    probs = {}
    for batch in tqdm(test_dl):
        s1 = batch['site1']['features']
        s2 = batch['site2']['features']
        out = model(s1.to(device), s2.to(device))
        y_prob = out.softmax(dim=-1).cpu().numpy()
        probs.update(dict(zip(batch['site1']['id_code'], y_prob)))

probs_df = pd.DataFrame([
    {'id_code': id_code, 'prob_sirna': prob_sirna} 
    for id_code, prob_sirna in probs.items()])

HBox(children=(IntProgress(value=0, max=311), HTML(value='')))




In [55]:
preds_df = pd.DataFrame()
preds_df['id_code'] = probs_df.id_code
preds_df['sirna'] = probs_df.prob_sirna.map(lambda x: x.argmax())
preds_df.to_csv('densenet121_after_holdout.csv', index=False)

In [None]:
# preds_df = pd.DataFrame([
#     {'id_code': id_code, 'sirna': sirna} 
#     for id_code, sirna in preds.items()])
# preds_df.head(5)

In [None]:
# len(preds_df)

In [None]:
# filename = 'densenet121_with_tricks.csv'
# preds_df.to_csv(filename, index=False, columns=['id_code', 'sirna'])
# from IPython.display import FileLink
# FileLink(filename)

## Leak

In [56]:
trn_csv = pd.read_csv(ROOT/'train.csv')
tst_csv = pd.read_csv(ROOT/'test.csv')

plate_groups = np.zeros((1108,4), int)
for sirna in range(1108):
    grp = trn_csv.loc[trn_csv.sirna==sirna,:].plate.value_counts().index.values
    assert len(grp) == 3
    plate_groups[sirna,0:3] = grp
    plate_groups[sirna,3] = 10 - grp.sum()

subfile = 'densenet121_after_holdout'

sub = pd.read_csv(f'{subfile}.csv')

all_test_exp = tst_csv.experiment.unique()

group_plate_probs = np.zeros((len(all_test_exp),4))

for idx in range(len(all_test_exp)):
    preds = sub.loc[tst_csv.experiment == all_test_exp[idx],'sirna'].values
    pp_mult = np.zeros((len(preds),1108))
    pp_mult[range(len(preds)),preds] = 1
    
    sub_test = tst_csv.loc[tst_csv.experiment == all_test_exp[idx],:]
    assert len(pp_mult) == len(sub_test)
    
    for j in range(4):
        mask = np.repeat(plate_groups[np.newaxis, :, j], len(pp_mult), axis=0) == \
               np.repeat(sub_test.plate.values[:, np.newaxis], 1108, axis=1)
        
        group_plate_probs[idx,j] = np.array(pp_mult)[mask].sum()/len(pp_mult)

In [57]:
pd.DataFrame(group_plate_probs, index=all_test_exp)

Unnamed: 0,0,1,2,3
HEPG2-08,0.227642,0.202349,0.200542,0.369467
HEPG2-09,0.202166,0.304152,0.279783,0.213899
HEPG2-10,0.412455,0.185018,0.188628,0.213899
HEPG2-11,0.406872,0.169078,0.209765,0.214286
HUVEC-17,0.376354,0.203971,0.208484,0.211191
HUVEC-18,0.31346,0.239386,0.224932,0.222222
HUVEC-19,0.222022,0.196751,0.381769,0.199458
HUVEC-20,0.158845,0.200361,0.450361,0.190433
HUVEC-21,0.208484,0.232852,0.236462,0.322202
HUVEC-22,0.392599,0.183213,0.232852,0.191336


In [58]:
exp_to_group = group_plate_probs.argmax(1)
print(exp_to_group)

[3 1 0 0 0 0 2 2 3 0 0 3 1 0 0 0 2 3]


In [59]:
stacked = np.row_stack(probs_df.prob_sirna.values)

In [60]:
stacked.shape

(19897, 1108)

In [61]:
def select_plate_group(pp_mult, idx):
    sub_test = tst_csv.loc[tst_csv.experiment == all_test_exp[idx],:]
    assert len(pp_mult) == len(sub_test)
    mask = np.repeat(plate_groups[np.newaxis, :, exp_to_group[idx]], len(pp_mult), axis=0) != \
           np.repeat(sub_test.plate.values[:, np.newaxis], 1108, axis=1)
    pp_mult[mask] = 0
    return pp_mult

In [62]:
sub = sub.set_index('id_code')

In [63]:
for idx in range(len(all_test_exp)):
    indexes = tst_csv.experiment == all_test_exp[idx]
    preds = stacked[indexes, :].copy()
    preds = select_plate_group(preds, idx)
    sub.loc[tst_csv.id_code[indexes], 'sirna'] = preds.argmax(1)

In [64]:
sub = sub.reset_index()

In [65]:
(sub.sirna == pd.read_csv(f'{subfile}.csv').sirna).mean()

0.35638538473136655

In [66]:
from IPython.display import FileLink
sub.to_csv(f'{subfile}_same_size_leak.csv', index=False, columns=['id_code', 'sirna'])
FileLink(f'{subfile}_same_size_leak.csv')