In [None]:
# Shell out to jupytools' `export` command for the notebook 03_cell_type_refactor.ipynb,
# passing `.` as the -o argument (presumably the output directory — confirm in jupytools docs).
!python -m jupytools export -nb "03_cell_type_refactor.ipynb" -o .

In [1]:
#export
from collections import Counter, defaultdict, OrderedDict
from functools import partial
import json
from operator import itemgetter
import os
from pdb import set_trace

import cv2
from imageio import imread
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.Image
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import Dataset, DataLoader

import albumentations as T
from catalyst.contrib.schedulers import OneCycleLR
from catalyst.contrib.modules import GlobalConcatPool2d
from catalyst.dl import SupervisedRunner
from catalyst.dl.callbacks import AccuracyCallback
from catalyst.dl.core import Callback
from catalyst.utils import get_one_hot
import pretrainedmodels
from visdom import Visdom

from data_bunch import rio

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
W0908 17:27:18.908218 139690980988736 compression.py:14] lz4 not available, disabling compression. To install lz4, run `pip install lz4`.


In [2]:
#export
# Restrict CUDA to physical GPU 1. This has to run before the first CUDA call;
# NOTE(review): once applied, the only visible GPU is addressed as 'cuda:0' —
# later cells that name 'cuda:1' should be checked against this setting.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
def six(rec):
    """Read every image referenced by a record and stack them into one array.

    Args:
        rec: mapping whose 'images' entry is an iterable of 2-item pairs; the
            second item of each pair is handed to ``imread`` and the first is
            ignored.

    Returns:
        ``np.stack`` of the loaded images, in the order they appear in the record.
    """
    planes = []
    for pair in rec['images']:
        _, path = pair
        planes.append(imread(path))
    return np.stack(planes)

def rgb(img):
    """Convert an image to RGB through ``rio.convert_tensor_to_rgb``.

    When the first axis has length 6 the array is transposed with
    ``(1, 2, 0)`` before conversion; any other input is passed through as is.
    """
    channels_last = img.transpose(1, 2, 0) if img.shape[0] == 6 else img
    return rio.convert_tensor_to_rgb(channels_last)

def show_1(img, ax=None):
    """Draw a single image on ``ax``, or on the current pyplot figure if ``ax`` is None.

    A first axis of length 6 is converted with ``rgb``; a first axis of length 3
    is moved to the last position. Anything else must already end in an axis of
    length 3.

    Raises:
        ValueError: if the image is neither 3-channel-first nor 3-channel-last
            after the optional conversion.
    """
    if img.shape[0] == 6:
        img = rgb(img)

    if img.shape[0] == 3:
        img = img.transpose(1, 2, 0)
    elif img.shape[-1] != 3:
        raise ValueError(f'wrong image shape: {img.shape}')

    # Both `plt` and an Axes expose imshow, so pick the target once.
    canvas = plt if ax is None else ax
    canvas.imshow(img)
    
def show(img, *imgs, titles=None, ncols=4):
    """Plot one or more images on a grid of subplots.

    Args:
        img, *imgs: images accepted by ``show_1``.
        titles: optional list of titles, matched to images by position.
        ncols: number of grid columns; rows are derived from the image count.

    Returns:
        The matplotlib figure that holds the grid.
    """
    from itertools import zip_longest
    imgs = [img] + list(imgs)
    titles = titles or []
    nrows = int(np.ceil(len(imgs) / ncols))
    sz = max(nrows, ncols)*2
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat`
    # is always available (a bare Axes object has no `.flat`).
    f, axes = plt.subplots(nrows, ncols, figsize=(sz, sz), squeeze=False)
    for ax, img, t in zip_longest(axes.flat, imgs, titles):
        if ax is None:
            # More titles than grid cells: nothing left to draw on.
            break
        if img is not None:
            show_1(img, ax)
        ax.set_title(t)
        ax.set_aspect('equal')
        ax.axis('off')
    f.subplots_adjust(wspace=0, hspace=0.5)
    return f

def sample(records, count=1):
    """Pick ``count`` random records using ``np.random.choice`` over their positions.

    ``np.random.choice`` is called with its defaults, so the same record can be
    returned more than once.
    """
    picks = np.random.choice(len(records), count)
    chosen = []
    for position in picks:
        chosen.append(records[position])
    return chosen

def visualize(records, n=25):
    """Show ``n`` randomly sampled records, each titled with its cell and well type."""
    picked = sample(records, n)
    captions = []
    for rec in picked:
        captions.append("{}\n{}".format(rec['cell_type'], rec['well_type']))
    stacks = [six(rec) for rec in picked]
    show(*stacks, titles=captions)

In [3]:
#export
def list_files(folder):
    """Return the entries of ``folder`` as paths joined onto the expanded folder.

    ``~`` in ``folder`` is expanded. Entries come back in ``os.listdir`` order.
    """
    root = os.path.expanduser(folder)
    paths = []
    for entry in os.listdir(root):
        paths.append(os.path.join(root, entry))
    return paths

def load_data(filenames=None):
    """Load record files, keep treatment wells and label-encode their siRNA.

    Args:
        filenames: paths of JSON files, each holding a list of record dicts.
            Defaults to ``['train.json', 'test.json']``.

    Returns:
        dict with
          * 'data': one list of records per file; every record gains an
            'enc_sirna' entry with its encoded label,
          * 'encoders': the fitted ``LabelEncoder`` of each file (each file is
            encoded independently),
          * 'num_classes': number of distinct labels in the first file.
    """

    def treatment_only(records):
        return [r for r in records if r['well_type'] == 'treatment']

    def re_encode(records):
        encoder = LabelEncoder()
        labels = [r['sirna'] for r in records]
        encoded = encoder.fit_transform(labels)
        # Use a separate loop variable: reusing the encoder's name here would
        # make the function return the last encoded label instead of the encoder.
        for code, r in zip(encoded, records):
            r['enc_sirna'] = code
        return encoder

    if filenames is None:
        filenames = [f'{fn}.json' for fn in ('train', 'test')]

    data, encoders = [], []
    for filename in filenames:
        with open(filename) as f:
            records = treatment_only(json.load(f))
        encoder = re_encode(records)
        data.append(records)
        encoders.append(encoder)
    # Codes are 0..n_classes-1, so the number of fitted classes equals max code + 1
    # and stays well-defined (0) when the first file has no treatment records.
    num_classes = len(encoders[0].classes_)
    return {'data': data, 'encoders': encoders, 'num_classes': num_classes}

class RxRxDataset(Dataset):
    """Dataset over record dicts that loads each record's channel images.

    Every record is expected to hold, under ``features_key``, a list of
    ``(channel_index, filename)`` pairs and, under ``targets_key``, its label.
    """

    def __init__(self, items, num_classes, onehot=True, label_smoothing=None,
                 features_key='images', targets_key='enc_sirna',
                 channels_mode='six', drop_meta=True, open_fn=PIL.Image.open, tr=None):
        """
        Args:
            items: sequence of record dicts.
            num_classes: number of classes used for one-hot encoding.
            onehot: if True, add a 'targets_one_hot' entry to each sample.
            label_smoothing: forwarded to ``get_one_hot`` as ``smoothing``.
            features_key: record key holding the (index, filename) pairs.
            targets_key: record key holding the label.
            channels_mode: 'six' to keep every channel, 'rgb' to convert the
                stack with ``rio.convert_tensor_to_rgb``.
            drop_meta: see the NOTE in ``__getitem__``.
            open_fn: callable used to open one image file.
            tr: stored on the instance; not used by this class itself.
        """
        assert channels_mode in ('six', 'rgb')
        super().__init__()
        self.items = items
        self.onehot = onehot
        self.label_smoothing = label_smoothing
        self.features_key = features_key
        self.targets_key = targets_key
        self.channels_mode = channels_mode
        self.drop_meta = drop_meta
        self.num_classes = num_classes
        self.open_fn = open_fn
        self.tr = tr

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        """Build the sample dict ('features', 'targets', optional one-hot) for one record."""
        # Shallow copy so popping the features entry leaves the stored record intact.
        item = self.items[index].copy()
        bunch = sorted(item.pop(self.features_key), key=itemgetter(0))
        channels = OrderedDict()
        if self.channels_mode == 'six':
            for i, filename in bunch:
                channels[f'chan_{i}'] = np.array(self.open_fn(filename))
        elif self.channels_mode == 'rgb':
            # `bunch` holds (index, filename) pairs, so unpack the filename.
            img = np.stack([np.array(self.open_fn(filename)) for _, filename in bunch])
            # NOTE(review): the `rgb` helper in this notebook transposes to
            # channels-last before calling convert_tensor_to_rgb, while here a
            # channels-first stack is passed and the result is indexed as
            # channels-first — confirm the layout rio expects and returns.
            img = rio.convert_tensor_to_rgb(img)
            for i in range(3):
                channels[f'chan_{i}'] = img[i,:,:]
        y = item[self.targets_key]
        # NOTE(review): the flag reads inverted — with drop_meta=True the
        # record's remaining fields are KEPT in the sample, with False only
        # features/targets are returned. Left as is because existing callers
        # may rely on the current output.
        if self.drop_meta:
            sample = item
            sample['features'] = channels
            sample['targets'] = y
        else:
            sample = dict(features=channels, targets=y)
        if self.onehot:
            y_enc = get_one_hot(
                y, smoothing=self.label_smoothing,
                num_classes=self.num_classes)
            sample['targets_one_hot'] = y_enc
        return sample

    def join_channels(self, x):
        """Stack the per-channel arrays of a sample into a single array."""
        return np.stack([channel for _, channel in x['features'].items()])

In [4]:
#export
def augment(x, tr):
    """Apply one transform jointly to every channel of a sample.

    The first channel is passed as ``image`` and the remaining ones as
    ``image0``, ``image1``, ...; the transform results are stacked back in the
    same order.

    Args:
        x: sample dict whose 'features' entry maps channel names to arrays.
        tr: callable taking the channels as keyword arguments and returning a
            mapping with the same keys.

    Returns:
        ``np.stack`` of the transformed channels.
    """
    main, *rest = list(x['features'].values())
    aug_input = dict(image=main)
    aug_input.update({f'image{i}': image for i, image in enumerate(rest)})
    augmented = tr(**aug_input)
    # Read back exactly as many extra targets as were passed in, instead of
    # assuming six channels.
    unpacked = [augmented['image']] + [augmented[f'image{i}'] for i in range(len(rest))]
    return np.stack(unpacked)

class Augmented(Dataset):
    """Wraps a dataset and replaces each sample's 'features' with ``augment`` output."""

    def __init__(self, ds, tr):
        # ds: the wrapped dataset; tr: the transform handed to `augment`.
        self.ds, self.tr = ds, tr

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index):
        item = self.ds[index]
        item['features'] = augment(item, self.tr)
        return item

In [5]:
#export
class VisdomCallback(Callback):
    """Base callback that opens a Visdom client and keeps it on ``self.vis``."""

    def __init__(self,
                 username='username', password='password',
                 host='0.0.0.0', port=9090, use_env_creds=False):
        """
        Args:
            username (str): Visdom server username.
            password (str): Visdom server password.
            host (str): Visdom server address.
            port (int): Visdom server port.
            use_env_creds (bool): If True, the username/password arguments are
                ignored and VISDOM_USERNAME / VISDOM_PASSWORD are read from the
                environment instead (a missing variable raises KeyError).
        """
        super().__init__()

        if use_env_creds:
            username, password = (
                os.environ['VISDOM_USERNAME'], os.environ['VISDOM_PASSWORD'])

        connection = dict(username=username, password=password,
                          server=host, port=port)
        self.vis = Visdom(**connection)
        
class BatchMetricsPlotCallback(VisdomCallback):
    """Appends each entry of ``state.metrics.batch_values`` to its own Visdom line plot."""

    def on_batch_end(self, state):
        for metric, value in state.metrics.batch_values.items():
            # One window per metric name; points are appended at the current step.
            self.vis.line(
                X=[state.step], Y=[value],
                win=metric, name=metric,
                update='append', opts={'title': metric})

class EpochMetricsPlotCallback(VisdomCallback):
    """Appends metric values to Visdom line plots at the end of every epoch.

    NOTE(review): this reads ``state.metrics.batch_values`` and ``state.step``
    and writes to the same window names as the per-batch callback; presumably
    epoch-level values were intended — confirm against the catalyst version in use.
    """

    def on_epoch_end(self, state):
        for metric, value in state.metrics.batch_values.items():
            self.vis.line(
                X=[state.step], Y=[value],
                win=metric, name=metric,
                update='append', opts={'title': metric})

## Model

In [6]:
def get_model(model_name, num_classes, pretrained='imagenet'):
    """Build a ``pretrainedmodels`` network with a 6-channel stem and a new linear head.

    Args:
        model_name: name of a constructor exposed by ``pretrainedmodels``.
        num_classes: output size of the replacement classification head.
        pretrained: forwarded to the constructor.

    Returns:
        The model with ``conv1`` replaced by a 6-input-channel convolution
        (the original 3-channel weights copied into both halves) and
        ``last_linear`` replaced by ``nn.Linear(in_features, num_classes)``.
    """
    model_fn = pretrainedmodels.__dict__[model_name]
    model = model_fn(num_classes=1000, pretrained=pretrained)
    dim_feats = model.last_linear.in_features
    # The head must map features to `num_classes` logits; an empty
    # nn.Sequential here would leave the model emitting raw features and give
    # the training cell's 'last_linear' parameter group nothing to train.
    model.last_linear = nn.Linear(dim_feats, num_classes)
    conv1 = model.conv1
    # Hard-coded stem geometry: 64 filters, 7x7 kernel, stride 2, padding 3.
    new_conv = nn.Conv2d(6, 64, 7, 2, 3, bias=False)
    new_conv.weight.data[:,0:3,:] = conv1.weight.data.clone()
    new_conv.weight.data[:,3:6,:] = conv1.weight.data.clone()
    model.conv1 = new_conv
    del conv1
    return model
    
def freeze_all(model):
    """Turn off gradient tracking for every parameter of ``model`` (in place)."""
    for weight in model.parameters():
        weight.requires_grad_(False)
        
def unfreeze_head(model):
    """Turn gradient tracking back on for the parameters of ``model.last_linear``."""
    head = model.last_linear
    for weight in head.parameters():
        weight.requires_grad_(True)
        
def get_layer(model, key):
    """Look up a (possibly nested) attribute of ``model``.

    ``key`` is a dot-separated attribute path such as ``first.second.third``;
    each segment is resolved with ``getattr`` on the result of the previous one.
    """
    head, dot, remainder = key.partition('.')
    child = getattr(model, head)
    # Recurse whenever a dot was present, even if nothing follows it, so an
    # empty trailing segment still goes through getattr.
    if dot:
        return get_layer(child, remainder)
    return child

def unfreeze_layers(model, names):
    """Enable gradients for every parameter of the layers named in ``names``.

    Each name is resolved with ``get_layer``; a message is printed per layer.
    """
    for layer_name in names:
        target = get_layer(model, layer_name)
        print(f'Unfreezing layer {layer_name}')
        for weight in target.parameters():
            weight.requires_grad_(True)

## Data

In [8]:
#export
# Load both record files once, then pull out the pieces used below.
data_dict = load_data()
num_classes = data_dict['num_classes']
trn_rec, tst_rec = data_dict['data']

In [None]:
# Display the class count returned by load_data().
num_classes

In [None]:
#export
# Group the training records by cell type so every type is split on its own.
cell_types = defaultdict(list)
for record in trn_rec:
    cell_types[record['cell_type']].append(record)

# Fixed seed so the train/valid partition is identical on every run.
SPLIT_SEED = 42

train, valid = [], []
for ct, records in cell_types.items():
    print(f'Splitting train/valid for type: {ct}')
    labels = np.array([r['enc_sirna'] for r in records])
    # Stratify on the encoded label so both parts keep the class balance.
    ct_train, ct_valid = train_test_split(
        records, stratify=labels, test_size=0.2, random_state=SPLIT_SEED)
    train.extend(ct_train)
    valid.extend(ct_valid)
    print(f'... split counts: {len(ct_train)}/{len(ct_valid)} [total: {len(records)}]')

# `augment` passes the first channel as `image` and the other five as
# image0..image4; registering them makes every transform apply to all channels.
extra_targets = {f'image{i}': 'image' for i in range(5)}

# Training pipeline: resize, random flips, normalisation.
transforms = T.Compose(
    additional_targets=extra_targets,
    transforms=[
        T.Resize(224, 224),
        T.VerticalFlip(p=0.25),
        T.HorizontalFlip(p=0.25),
        T.Normalize(mean=(0.5,), std=(0.5,))
    ]
)

# Validation pipeline: same resize/normalisation but no random flips, so the
# validation metrics are deterministic.
valid_transforms = T.Compose(
    additional_targets=extra_targets,
    transforms=[
        T.Resize(224, 224),
        T.Normalize(mean=(0.5,), std=(0.5,))
    ]
)

batch_size = 650
trn_ds = Augmented(RxRxDataset(train, num_classes, onehot=True), tr=transforms)
val_ds = Augmented(RxRxDataset(valid, num_classes, onehot=True), tr=valid_transforms)
loaders = dict(
    train=DataLoader(trn_ds, batch_size=batch_size, num_workers=12, shuffle=True),
    valid=DataLoader(val_ds, batch_size=batch_size, num_workers=12, shuffle=False)
)

## Loss

In [7]:
#export
class FocalLoss(nn.Module):
    """Focal loss over class scores: cross-entropy scaled by ``(1 - p) ** gamma``.

    Args:
        gamma: exponent of the ``(1 - p)`` modulating factor.
        eps: probabilities are clamped to ``[eps, 1 - eps]`` before the log.
    """

    def __init__(self, gamma=2, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.eps = eps

    def forward(self, x, y):
        """Return the mean focal loss.

        Args:
            x: raw class scores; softmax is taken over dim 1 and the last
                dimension gives the number of classes.
            y: integer class indices used to select one-hot rows.
        """
        # Build the one-hot targets directly on x's device so the whole
        # computation stays there (no CPU round trip, no mixed-device indexing).
        identity = torch.eye(x.size(-1), dtype=x.dtype, device=x.device)
        y_onehot = identity[y.to(x.device)]
        probs = F.softmax(x, dim=1).clamp(self.eps, 1 - self.eps)
        ce_loss = -1 * y_onehot * torch.log(probs)
        focal_loss = ce_loss * (1 - probs)**self.gamma
        return focal_loss.sum(dim=1).mean()

## Train

In [None]:
#export
epochs = 50
lrs = 1e-5, 1e-4, 1e-3
conv_lr, layer_lr, head_lr = lrs

# Start from a fully frozen network and re-enable only the last block and the head.
model = get_model('resnet50', num_classes)
freeze_all(model)
unfreeze_layers(model, ['layer4', 'last_linear'])

# One optimizer group per trainable part, each with its own learning rate.
# conv1 has no group (it stays frozen), so conv_lr is currently unused.
param_groups = [
    dict(params=model.layer4.parameters(), lr=layer_lr),
    dict(params=model.last_linear.parameters(), lr=head_lr),
]
opt = torch.optim.AdamW(params=param_groups, weight_decay=0.01)

# First restart period equals the number of training batches per epoch.
batches_per_epoch = len(loaders['train'])
sched = CosineAnnealingWarmRestarts(opt, T_0=batches_per_epoch, T_mult=2, eta_min=1e-6)

loss_fn = nn.CrossEntropyLoss()
runner = SupervisedRunner()

In [None]:
#export
# Accuracy tracking plus the two Visdom plotting callbacks (credentials come
# from the VISDOM_* environment variables).
metric_callbacks = [
    AccuracyCallback(num_classes=num_classes),
    BatchMetricsPlotCallback(use_env_creds=True),
    EpochMetricsPlotCallback(use_env_creds=True),
]

train_config = dict(
    model=model,
    num_epochs=epochs,
    criterion=loss_fn,
    optimizer=opt,
    scheduler=sched,
    logdir='/tmp/cells_split/',
    loaders=loaders,
    callbacks=metric_callbacks,
    verbose=True,
)
runner.train(**train_config)

## Test

In [7]:
# NUM_CLASSES is provided by the project-local `basedir` module (not shown in this notebook).
from basedir import NUM_CLASSES
# Folder with the saved training runs; the listing is displayed as the cell output.
logdir = os.path.expanduser('~/logs/protein/runs')
os.listdir(logdir)

['resnet34_224_ce_1linear_adamw_coswr_difflr_e25_bs650',
 'resnet34_224_focal_1linear_adamw_coswr_difflr_e25_bs650',
 'resnet101_224_ce_1linear_adamw_coswr_difflr_e25_bs650']

In [8]:
# Load the best checkpoint of the selected run. map_location='cpu' keeps the
# load independent of the GPU the checkpoint was saved from; the weights are
# moved to the target device together with the model afterwards.
checkpoint = torch.load(
    os.path.join(
        logdir,
        # 'resnet34_224_ce_1linear_adamw_coswr_difflr_e25_bs650',
        'resnet34_224_focal_1linear_adamw_coswr_difflr_e25_bs650',
        'checkpoints',
        'best.pth'),
    map_location='cpu')

In [9]:
def get_model_with_simple_custom_head(model_name, num_classes, pretrained='imagenet'):
    """Create a ``pretrainedmodels`` network with a linear head and a 6-channel stem.

    The head ``last_linear`` becomes ``nn.Linear(in_features, num_classes)``.
    The stem ``conv1`` becomes a 6-input-channel convolution whose two 3-channel
    halves are both initialised from the original stem weights.
    """
    factory = pretrainedmodels.__dict__[model_name]
    model = factory(num_classes=1000, pretrained=pretrained)

    # Replace the classifier, keeping its input width.
    in_features = model.last_linear.in_features
    model.last_linear = nn.Linear(in_features, num_classes)

    # Replace the stem; copy the 3-channel kernel into channels 0-2 and 3-5.
    stem_weights = model.conv1.weight.data
    six_channel_stem = nn.Conv2d(6, 64, 7, 2, 3, bias=False)
    for start in (0, 3):
        six_channel_stem.weight.data[:, start:start + 3, :] = stem_weights.clone()
    model.conv1 = six_channel_stem
    return model

In [10]:
# The notebook restricts CUDA_VISIBLE_DEVICES to a single GPU, which is then
# addressed as index 0 inside this process; fall back to CPU if CUDA is absent.
dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = get_model_with_simple_custom_head('resnet34', NUM_CLASSES)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(dev)
# Inference mode: batch-norm layers must use their running statistics rather
# than the statistics of each test batch.
model.eval()
for param in model.parameters():
    param.requires_grad = False

In [11]:
# Reload the records and keep only the second (test) list.
data_dict = load_data()
_, tst_rec = data_dict['data']

# Test-time pipeline: resize and normalise only, applied to all six channels
# through the image0..image4 additional targets.
test_tr = T.Compose(
    additional_targets={f'image{i}': 'image' for i in range(5)},
    transforms=[T.Resize(224, 224), T.Normalize((0.5,), (0.5))],
)
tst_ds = Augmented(ds=RxRxDataset(tst_rec, NUM_CLASSES, onehot=False), tr=test_tr)
tst_dl = DataLoader(tst_ds, batch_size=1024, shuffle=False, num_workers=12)

In [12]:
# Collect per-class probabilities for every test sample, in loader order.
preds = []
for batch in tqdm(tst_dl):
    features = batch['features'].to(dev)
    probs = model(features).softmax(dim=1)
    preds += probs.tolist()

HBox(children=(IntProgress(value=0, max=39), HTML(value='')))




In [None]:
# NOTE(review): `build_model` is not defined in any cell visible in this
# notebook — confirm where it comes from; on a fresh kernel this line would
# raise NameError.
model = build_model(num_classes)
checkpoints = []  # not referenced again in this cell
best = '/tmp/cells_split/checkpoints/train.2.pth'
checkpoint = torch.load(best)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# NOTE(review): the model goes to 'cuda:0' here while the prediction loop moves
# batches to `dev` — confirm both name the same device.
model.to('cuda:0')
# Inference only: stop tracking gradients for every parameter.
for param in model.parameters():
    param.requires_grad = False

In [12]:
# Reload the records and keep only the second (test) list.
data_dict = load_data()
_, tst_rec = data_dict['data']
# `augment` feeds the channels as image, image0..image4. They must be declared
# as additional image targets, otherwise Normalize is not applied to
# image0..image4 (and some albumentations versions reject the unknown keys).
# NOTE(review): unlike the other test pipeline there is no Resize(224, 224)
# here — confirm that full-resolution inference is intended.
tst_ds = Augmented(
    ds=RxRxDataset(tst_rec, NUM_CLASSES, onehot=False),
    tr=T.Compose(
        additional_targets={f'image{i}': 'image' for i in range(5)},
        transforms=[
            T.Normalize(mean=(0.5,), std=(0.5,))
        ]
    )
)
tst_dl = DataLoader(tst_ds, batch_size=512, shuffle=False, num_workers=12)

In [13]:
# Collect per-class probabilities for every test sample, in loader order.
preds = []
for batch in tqdm(tst_dl):
    features = batch['features'].to(dev)
    probs = model(features).softmax(dim=1)
    preds += probs.tolist()

HBox(children=(IntProgress(value=0, max=311), HTML(value='')))

Process Process-10:
Process Process-11:
Process Process-5:
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f1ca4ed5d08>
Traceback (most recent call last):
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 926, in __del__
    self._shutdown_workers()
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 906, in _shutdown_workers
    w.join()
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/multiprocessing/popen_fork.py", line 48, in wait
    return self.poll(os.WNOHANG if timeout == 0.0 else 0)
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/multiprocessing/popen_fork.py", line 28, in poll
    pid, sts = os.waitpid(self.pid, flag)
KeyboardInterrupt: 


KeyboardInterrupt: 

Process Process-9:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/multiprocessing/process.py", line 300, in _bootstrap
    util._exit_function()
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/multiprocessing/process.py", line 300, in _bootstrap
    util._exit_function()
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/multiprocessing/process.py", line 300, in _bootstrap
    util._exit_function()
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/multiprocessing/process.py", line 300, in _bootstrap
    util._exit_function()
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/multiprocessing/util.py", line 325, in _exit_function
    _run_finalizers()
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/multiprocessing/util.py", line 325, in _exit_function
    _run_finalizers()
  File "/home/ck/anaconda3/envs/fasta

## Submit

In [13]:
def include_into_test(filename):
    """Tell whether a file belongs to the test set.

    The file's base name (without extension) must end in ``_<integer>``, or be
    an integer on its own; the file is included when that integer equals 0.
    """
    stem = os.path.splitext(os.path.basename(filename))[0]
    trailing = stem.rsplit('_', 1)[-1]
    return int(trailing) == 0

filenames = sorted(list_files('~/data/protein/tmp/test'))
filenames = [fn for fn in filenames if include_into_test(fn)]

# `filenames` now holds only files accepted by include_into_test, so the
# predictions can be split purely by position — no need to parse and filter
# every name a second time. zip stops at the shorter of the two sequences.
paired = list(zip(filenames, preds))

# positions 0, 2, 4, ... of the sorted files
site1 = [pred for _, pred in paired[::2]]

# positions 1, 3, 5, ... of the sorted files
site2 = [pred for _, pred in paired[1::2]]

In [14]:
# Average the class probabilities of the two sites, then take the most likely class.
t1, t2 = (torch.tensor(site) for site in (site1, site2))
site_mean = (t1 + t2) / 2
avg_pred = site_mean.argmax(dim=1)
print(avg_pred.shape)

torch.Size([19897])


In [15]:
from IPython.display import FileLink

# Fill the sample submission with the predicted class of every row.
# The path is resolved from the user's home directory, matching the other data
# paths in this notebook. The frame is named `submission` so it does not
# shadow the `sample()` helper defined earlier.
submission = pd.read_csv(os.path.expanduser('~/data/protein/sample_submission.csv'))
submission['sirna'] = avg_pred.tolist()
submission.to_csv('submit.csv', index=False)
FileLink('submit.csv')