In [31]:
from imageio import imread
import matplotlib.pyplot as plt
from data_bunch import rio
from basedir import NUM_CLASSES

from collections import defaultdict, OrderedDict
from functools import partial
import json
from operator import itemgetter
import os

from catalyst.utils import get_one_hot
import cv2 as cv
import numpy as np
import PIL.Image
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader

In [5]:
def six(rec):
    return np.stack([imread(filename) for _, filename in rec['images']])

def rgb(img):
    if img.shape[0] == 6:
        img = img.transpose(1, 2, 0)
    return rio.convert_tensor_to_rgb(img)

def show_1(img, ax=None):
    if img.shape[0] == 6:
        img = rgb(img)
    if img.shape[0] == 3:
        img = img.transpose(1, 2, 0)
    elif img.shape[-1] != 3:
        raise ValueError(f'wrong image shape: {img.shape}')
    if ax is not None:
        ax.imshow(img)
    else:
        plt.imshow(img)
    
def show(img, *imgs, titles=None, ncols=4):
    from itertools import zip_longest
    imgs = [img] + list(imgs)
    titles = titles or []
    nrows = int(np.ceil(len(imgs) / ncols))
    sz = max(nrows, ncols)*2
    f, axes = plt.subplots(nrows, ncols, figsize=(sz, sz))
    for ax, img, t in zip_longest(axes.flat, imgs, titles):
        if img is not None:
            show_1(img, ax)
        ax.set_title(t)
        ax.set_aspect('equal')
        ax.axis('off')
    f.subplots_adjust(wspace=0, hspace=0.5)
    return f

def sample(records, count=1):
    n = len(records)
    idx = np.random.choice(n, count)
    return [records[i] for i in idx]

def visualize(records, n=25):
    samples = sample(records, n)
    titles = ["{}\n{}".format(s['cell_type'], s['well_type']) for s in samples]
    show(*map(six, samples), titles=titles)

In [6]:
def list_files(folder):
    dirname = os.path.expanduser(folder)
    return [os.path.join(dirname, x) for x in os.listdir(dirname)]

def load_data(filenames=None):
    if filenames is None:
        filenames = [f'{fn}.json' for fn in ('train', 'test')]
    content = []
    for filename in filenames:
        with open(filename) as f:
            content.append(json.load(f))
    return content

class RxRxDataset(Dataset):
    def __init__(self, items, onehot=True, label_smoothing=None,
                 features_key='images', targets_key='sirna',
                 channels_mode='six', drop_meta=True, 
                 num_classes=NUM_CLASSES, open_fn=PIL.Image.open, tr=None):
        
        assert channels_mode in ('six', 'rgb')
        super().__init__()      
        self.items = items
        self.onehot = onehot
        self.label_smoothing = label_smoothing
        self.features_key = features_key
        self.targets_key = targets_key
        self.channels_mode = channels_mode
        self.drop_meta = drop_meta
        self.num_classes = num_classes
        self.open_fn = open_fn
        self.tr = tr
        
    def __len__(self):
        return len(self.items)
    
    def __getitem__(self, index):
        item = self.items[index].copy()
        bunch = sorted(item.pop(self.features_key), key=itemgetter(0))
        channels = OrderedDict()
        if self.channels_mode == 'six':
            for i, filename in bunch:
                channels[f'chan_{i}'] = np.array(self.open_fn(filename))
        elif self.channels_mode == 'rgb':
            img = np.stack([self.open_fn(filename) for filename in bunch])
            img = rio.convert_tensor_to_rgb(img)
            for i in range(3):
                channes[f'chan_{i}'] = img[i,:,:]
        y = item[self.targets_key]
        if self.drop_meta:
            sample = item
            sample['features'] = channels
            sample['targets'] = y
        else:
            sample = dict(features=channels, targets=y)
        if self.onehot:
            y_enc = get_one_hot(
                y, smoothing=self.label_smoothing,
                num_classes=self.num_classes)
            sample['targets_one_hot'] = y_enc
        return sample
    
    def join_channels(self, x):
        return np.stack([channel for _, channel in x['features'].items()])

In [7]:
trn_rec, _ = load_data()

In [None]:
# visualize(trn_rec)

In [8]:
rxrx_ds = RxRxDataset(trn_rec, onehot=True)
x = rxrx_ds[9090]
img = rxrx_ds.join_channels(x)

In [49]:
import albumentations as T

In [50]:
transforms = T.Compose(
    additional_targets={f'image{i}': 'image' for i in range(5)},
    transforms=[
        T.VerticalFlip(p=0.5),
        T.HorizontalFlip(p=0.5),
        T.Rotate(p=0.25)
    ]
)

In [51]:
def augment(x, tr):
    main, *rest = list(x['features'].values())
    aug_input = dict(image=main)
    aug_input.update({f'image{i}': image for i, image in enumerate(rest)})
    augmented = tr(**aug_input)
    unpacked = [augmented['image']] + [augmented[f'image{i}'] for i in range(5)]
    sample = np.stack(unpacked)
    return sample

In [52]:
class Augmented(Dataset):
    def __init__(self, ds, tr):
        self.ds = ds
        self.tr = tr
    def __len__(self):
        return len(self.ds)
    def __getitem__(self, index):
        return augment(self.ds[index], self.tr)

In [53]:
trn_rec[:1]

[{'images': [[1, '/home/ck/data/protein/train/HEPG2-01/Plate1/B02_s1_w1.png'],
   [2, '/home/ck/data/protein/train/HEPG2-01/Plate1/B02_s1_w2.png'],
   [3, '/home/ck/data/protein/train/HEPG2-01/Plate1/B02_s1_w3.png'],
   [4, '/home/ck/data/protein/train/HEPG2-01/Plate1/B02_s1_w4.png'],
   [5, '/home/ck/data/protein/train/HEPG2-01/Plate1/B02_s1_w5.png'],
   [6, '/home/ck/data/protein/train/HEPG2-01/Plate1/B02_s1_w6.png']],
  'sirna': 1138,
  'site': 1,
  'cell_type': 'HEPG2',
  'experiment': 'HEPG2-01',
  'well_type': 'negative_control',
  'plate': 1}]

In [34]:
cell_types = defaultdict(list)
for record in trn_rec:
    cell_types[record['cell_type']].append(record)    

In [None]:
for ct, records in cell_types.items():
    [r for r in records]

In [13]:
aug_ds = Augmented(rxrx_ds, transforms)

In [14]:
import torch
import torch.nn as nn 
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader
import torchvision.transforms as T
import pretrainedmodels
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm_notebook as tqdm

from catalyst.dl import SupervisedRunner
from catalyst.contrib.schedulers import OneCycleLR
from catalyst.contrib.modules import GlobalConcatPool2d
from catalyst.dl.callbacks import AccuracyCallback, AUCCallback

In [15]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [23]:
#export
from catalyst.dl.core import Callback
from visdom import Visdom
from pdb import set_trace


class VisdomCallback(Callback):
    
    def __init__(self, 
                 username='username', password='password', 
                 host='0.0.0.0', port=9090, use_env_creds=False):
        """
        Args:
            username (str): Visdom server username.
            password (str): Visdom server password.
            host (str): Visdom server address.
            port (int): Visdom server port.
            use_env_creds (bool): If True, then ignore credentials
                passed as __init__ parameters and use Visdom 
                environment variables instead.
        """
        
        super().__init__()

        if use_env_creds:
            username = os.environ['VISDOM_USERNAME']
            password = os.environ['VISDOM_PASSWORD']
            
        self.vis = Visdom(
            username=username, password=password,
            server=host, port=port)        

        
class BatchMetricsPlotCallback(VisdomCallback):

    def on_batch_end(self, state):
        for k, v in state.metrics.batch_values.items():
            self.vis.line(X=[state.step], Y=[v], win=k, name=k, 
                          update='append', opts=dict(title=k))
            
            
class EpochMetricsPlotCallback(VisdomCallback):
    
    def on_epoch_end(self, state):
        for k, v in state.metrics.batch_values.items():
            self.vis.line(X=[state.step], Y=[v], win=k, name=k,
                          update='append', opts=dict(title=k))

In [24]:
#export
def build_model():
    model = pretrainedmodels.resnet50()
    model.avgpool = GlobalConcatPool2d()
    model.last_linear = nn.Sequential(
        nn.Linear(4096, 2048),
        nn.Dropout(0.5),
        nn.ReLU(inplace=True),
        nn.Linear(2048, 1024),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(1024, 4)
    )
    conv1 = model.conv1
    new_conv = nn.Conv2d(6, 64, 7, 2, 3, bias=False)
    new_conv.weight.data[:,0:3,:] = conv1.weight.data.clone()
    new_conv.weight.data[:,3:6,:] = conv1.weight.data.clone()
    model.conv1 = new_conv
    del conv1
    return model
    
def freeze_all(model):
    for param in model.parameters():
        param.requires_grad = False
        
def unfreeze_head(model):
    for param in model.last_linear.parameters():
        param.requires_grad = True

In [25]:
#export
epochs = 3
model = build_model()
freeze_all(model)
unfreeze_head(model)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)
sched = CosineAnnealingWarmRestarts(opt, T_0=len(aug_ds), T_mult=2, eta_min=1e-6)
loss_fn = nn.CrossEntropyLoss()
runner = SupervisedRunner()

In [26]:
loaders = dict(train=aug_ds)

In [27]:
#export
runner.train(
    model=model,
    num_epochs=epochs,
    criterion=loss_fn,
    optimizer=opt,
    scheduler=sched,
    logdir='/tmp/cells_split/',
    loaders=loaders,
    callbacks=[
        AccuracyCallback(num_classes=4),
        BatchMetricsPlotCallback(use_env_creds=True),
        EpochMetricsPlotCallback(use_env_creds=True)
    ],
    verbose=True
)

W0906 23:50:23.888358 139970826073920 __init__.py:505] Setting up a new session...
W0906 23:50:23.996115 139970826073920 __init__.py:505] Setting up a new session...


AssertionError: 'valid' should be in provided loaders: ['train']