In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import jupytools.syspath
import numpy as np
import pandas as pd
from basedir import ROOT, NUM_CLASSES

In [3]:
# Register the project checkout with jupytools (presumably appends it to sys.path
# so project modules can be imported — TODO confirm against jupytools docs).
# NOTE(review): this is an absolute, machine-specific path, and the `basedir`
# import in the previous cell runs before this call — confirm that import does
# not depend on it on a fresh kernel.
jupytools.syspath.add('/home/ck/code/tasks/protein_project')

In [4]:
# Train / test metadata tables read from ROOT.
# Columns used later in this notebook: `sirna` and `plate` from the train table;
# `experiment`, `plate` and `id_code` from the test table.
trn_csv = pd.read_csv(ROOT/'train.csv')
tst_csv = pd.read_csv(ROOT/'test.csv')

In [5]:
# plate_groups[s] lists the plates associated with sirna `s`: the first three
# columns are the plates it occurs on in the training table (in value_counts
# order), the fourth is derived from them.
plate_groups = np.zeros((1108, 4), int)
for sirna in range(1108):
    plates_seen = trn_csv.loc[trn_csv.sirna == sirna, 'plate'].value_counts().index.values
    assert len(plates_seen) == 3
    plate_groups[sirna, :3] = plates_seen
    # Presumably plates are numbered 1..4 (1+2+3+4 == 10), so the remainder is
    # the one plate this sirna never appears on in training — TODO confirm.
    plate_groups[sirna, 3] = 10 - plates_seen.sum()

In [24]:
# Existing submission whose hard `sirna` predictions are used below to infer the
# plate group of each test experiment; it is later overwritten with the
# plate-masked predictions.
# NOTE(review): this cell's execution count (24) is higher than that of the cells
# below that read `sub` — confirm which file produced the table shown further down.
# sub = pd.read_csv('densenet121_two_way_512.csv')
sub = pd.read_csv('densenet121_long_training_e29.csv')

In [7]:
all_test_exp = tst_csv.experiment.unique()

# group_plate_probs[e, j] is the fraction of wells in test experiment `e` whose
# predicted sirna is mapped by column `j` of plate_groups to the plate the well
# actually sits on.
group_plate_probs = np.zeros((len(all_test_exp), plate_groups.shape[1]))
for idx, experiment in enumerate(all_test_exp):
    in_experiment = tst_csv.experiment == experiment
    # Relies on `sub` having the same row order / index as `tst_csv`.
    preds = sub.loc[in_experiment, 'sirna'].values
    sub_test = tst_csv.loc[in_experiment, :]
    assert len(preds) == len(sub_test)

    # plate_groups[preds] has shape (n_wells, n_groups): for every well, the
    # plate each group column assigns to its predicted sirna. Comparing it with
    # the well's real plate and averaging per column gives the same number as
    # summing a one-hot prediction matrix under a repeated mask, without
    # materialising any (n_wells, n_sirna) arrays.
    hits = plate_groups[preds] == sub_test.plate.values[:, np.newaxis]
    group_plate_probs[idx] = hits.mean(axis=0)

In [8]:
# One row per test experiment, one column per plate_groups column; the largest
# value in each row is what the argmax cell below uses to pick the group.
pd.DataFrame(group_plate_probs, index=all_test_exp)

Unnamed: 0,0,1,2,3
HEPG2-08,0.109304,0.112918,0.122855,0.654923
HEPG2-09,0.147112,0.481949,0.199458,0.17148
HEPG2-10,0.711191,0.094765,0.093863,0.100181
HEPG2-11,0.765823,0.072333,0.079566,0.082278
HUVEC-17,0.791516,0.064079,0.07491,0.069495
HUVEC-18,0.620596,0.144535,0.109304,0.125565
HUVEC-19,0.08935,0.086643,0.725632,0.098375
HUVEC-20,0.037004,0.035199,0.882671,0.045126
HUVEC-21,0.08574,0.09296,0.110108,0.711191
HUVEC-22,0.812274,0.055054,0.075812,0.056859


In [9]:
# reference
# exp_to_group = group_plate_probs.argmax(1)
# print(exp_to_group)

In [10]:
# For every test experiment, keep the plate_groups column with the highest agreement.
exp_to_group = np.argmax(group_plate_probs, axis=1)
print(exp_to_group)

[3 1 0 0 0 0 2 2 3 0 0 3 1 0 0 0 2 3]


In [11]:
from augmentation import JoinChannels, SwapChannels, Resize, ToFloat, Rescale
from augmentation import VerticalFlip, HorizontalFlip, PixelStatsNorm, composer
from augmentation import AugmentedImages, bernoulli
from imageio import imread
from torch.utils.data import Dataset

# Default reader RxRxImages uses to load a single channel file.
default_open_fn = imread  # PIL.Image.open

class RxRxImages(Dataset):
    """Dataset yielding one multi-channel image per row of a metadata frame.

    Each channel is read from
    ``{img_dir}/{train|test}/{experiment}/Plate{plate}/{well}_s{site}_w{channel}.png``
    and the channels are stacked into a float32 array with the channel axis last.

    Parameters
    ----------
    meta_df : DataFrame with ``experiment``, ``plate``, ``well`` and ``id_code``
        columns (plus ``sirna`` when ``train=True``).
    img_dir : root directory containing the ``train`` / ``test`` sub-folders.
    site : site number inserted into the file name.
    channels : channel numbers to load, in stacking order.
    open_image : callable mapping a path to an image object.
    n_classes : number of classes passed to the one-hot encoder.
    train : selects the ``train`` sub-folder and adds targets to each sample.
    """

    def __init__(self, meta_df, img_dir, site=1, channels=(1, 2, 3, 4, 5, 6), 
                 open_image=default_open_fn, n_classes=NUM_CLASSES, train=True):
        
        self.records = meta_df.to_records(index=False)
        self.img_dir = img_dir
        self.site = site
        self.channels = channels
        self.n = len(self.records)
        self.open_image = open_image
        self.n_classes = n_classes
        self.train = train
        
    def _get_image_path(self, index, channel):
        """Return the file path of one channel of the sample at `index`."""
        r = self.records[index]
        exp, plate, well = r.experiment, r.plate, r.well
        subdir = 'train' if self.train else 'test'
        path = f'{self.img_dir}/{subdir}/{exp}/Plate{plate}/{well}_s{self.site}_w{channel}.png'
        return path
    
    def __getitem__(self, index):
        """Load the sample at `index` as a dict.

        If the channel images cannot be stacked, a warning is printed and the
        next sample (wrapping around at the end) is returned instead.
        """
        paths = [self._get_image_path(index, ch) for ch in self.channels]
        images = [self.open_image(p) for p in paths]

        try:
            img = np.stack(images)
        except (TypeError, ValueError) as e:
            print(f'Warning: cannot concatenate images! {e.__class__.__name__}: {e}')
            for filename, image in zip(paths, images):
                print(f'\tpath={filename}, size={image.size}')
            # Report the index that is actually being skipped (the message used
            # to show the replacement index), then fall back to the next sample.
            print(f'Skipping instance {index} and trying another one...')
            return self[(index + 1) % len(self)]
        finally:
            # Release file handles when the reader returns closable objects.
            for image in images:
                if hasattr(image, 'close'):
                    image.close()
            
        img = img.astype(np.float32)
        # (channels, H, W) -> (H, W, channels)
        img = img.transpose(1, 2, 0)
        r = self.records[index]
        if self.train:
            sirna = r.sirna
            target = int(sirna)
            # NOTE(review): `get_one_hot` is not imported in the visible cells —
            # confirm it is in scope before using train=True.
            onehot = get_one_hot(target, num_classes=self.n_classes)
            return {'features': img, 'targets': target, 
                    'targets_one_hot': onehot, 'id_code': r.id_code,
                    'site': self.site}
        else:
            id_code = r.id_code
            return {'features': img, 'id_code': id_code, 'site': self.site}
    
    def __len__(self): 
        """Number of metadata rows."""
        return self.n
    
class TwoSiteImages(Dataset):
    """Pair two equally sized datasets so each item carries both sites.

    Items are dicts ``{'site1': ..., 'site2': ...}``. When `swap` is truthy the
    two samples are exchanged whenever ``bernoulli(swap)`` returns 1.
    """

    def __init__(self, ds1, ds2, swap=0.0):
        assert len(ds1) == len(ds2)
        self.ds1 = ds1
        self.ds2 = ds2
        self.swap = swap
        self.size = len(ds1)

    def __getitem__(self, index):
        pair = (self.ds1[index], self.ds2[index])
        if self.swap and bernoulli(self.swap) == 1:
            pair = pair[::-1]
        return dict(zip(('site1', 'site2'), pair))

    def __len__(self):
        return self.size

In [12]:
from catalyst.contrib.modules import GlobalConcatPool2d
import pretrainedmodels
import torch.nn as nn

def densenet(name='densenet121', n_classes=NUM_CLASSES):
    """Build an ImageNet-pretrained `pretrainedmodels` network with a 6-channel stem.

    The model named `name` is created with its 1000-class ImageNet head, then
    ``features.conv0`` is replaced by a 6-input-channel convolution whose
    weights are the input-channel mean of the pretrained kernel, repeated for
    each of the six channels.

    `n_classes` is accepted but not used here; the returned model keeps the
    1000-class head.
    """
    factory = pretrainedmodels.__dict__[name]
    model = factory(num_classes=1000, pretrained='imagenet')
    stem = nn.Conv2d(in_channels=6, out_channels=64, kernel_size=7,
                     stride=2, padding=3, bias=False)
    pretrained_kernel = model.features.conv0.weight
    with torch.no_grad():
        channel_mean = torch.mean(pretrained_kernel, 1)
        stem.weight[:, :] = torch.stack([channel_mean for _ in range(6)], dim=1)
    model.features.conv0 = stem
    return model

class DenseNet_TwoSites(nn.Module):
    """Classifier that shares one backbone across two image sites.

    Both inputs go through ``base.features``; the two feature maps are summed,
    pooled with ``GlobalConcatPool2d`` and classified by a small MLP head that
    takes ``2 * feat_dim`` features, where ``feat_dim`` is the input width of
    the backbone's ``last_linear`` layer.
    """

    def __init__(self, name, n_classes=NUM_CLASSES):
        super().__init__()
        
        base = densenet(name=name, n_classes=n_classes)
        feat_dim = base.last_linear.in_features
        
        self.base = base 
        self.pool = GlobalConcatPool2d()
        self.head = nn.Sequential(
            nn.Linear(feat_dim * 2, feat_dim * 2),
            nn.BatchNorm1d(feat_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.25),
            nn.Linear(feat_dim * 2, n_classes)
        )
        
    def forward(self, s1, s2):
        """Return class logits for a batch of (site 1, site 2) image pairs."""
        f1 = self.base.features(s1)
        f2 = self.base.features(s2)
        f_merged = self.pool(f1 + f2)
        # Collapse everything but the batch axis. A bare squeeze() also removes
        # the batch axis when the batch holds a single sample, which hands
        # BatchNorm1d a 1-D tensor and fails; for larger batches the result is
        # identical to the previous squeeze().
        out = self.head(f_merged.reshape(f_merged.size(0), -1))
        return out
    
def freeze_all(model):
    """Turn off gradients for the parameters of every direct child of `model`.

    Prints one line per child module visited.
    """
    for child_name, module in model.named_children():
        print('Freezing layer:', child_name)
        for weight in module.parameters():
            weight.requires_grad_(False)
            
def unfreeze_all(model):
    """Turn gradients back on for the parameters of every direct child of `model`.

    Prints one line per child module visited.
    """
    for child_name, module in model.named_children():
        print('Un-freezing layer:', child_name)
        for weight in module.parameters():
            weight.requires_grad_(True)
            
def unfreeze_layers(model, names):
    """Turn gradients on only for the direct children of `model` listed in `names`.

    Children whose name is not in `names` are left untouched. Prints one line
    per child that is un-frozen.
    """
    for child_name, module in model.named_children():
        if child_name in names:
            print('Un-freezing layer:', child_name)
            for weight in module.parameters():
                weight.requires_grad_(True)

In [13]:
from dataset import build_stats_index

sz = 512

tst_df = pd.read_csv(ROOT/'test.csv')
stats = build_stats_index(ROOT/'pixel_stats.csv')

# The same pipeline is built once per imaging site (1, then 2) and the two
# wrapped datasets are paired as ds1 / ds2. No swap probability is passed, so
# the site order is never exchanged.
tst_ds = TwoSiteImages(**{
    f'ds{site}': AugmentedImages(
        ds=RxRxImages(tst_df, ROOT, site=site, train=False),
        tr=composer([
            PixelStatsNorm(stats, channels_first=False)
        ], resize=sz, rescale=False),
    )
    for site in (1, 2)
})

In [14]:
import torch

In [15]:
# Use the second GPU when there is one (as in the original run); otherwise fall
# back to the first GPU or the CPU so the cell still runs on other machines.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

model = DenseNet_TwoSites('densenet121')
freeze_all(model)

model.to(device)
# state = torch.load('densenet121_15_cw/train.14.pth', map_location=lambda storage, loc: storage)
# model.load_state_dict(state)
# A map_location callable receives (storage, location); returning the storage
# unchanged keeps the checkpoint tensors on the CPU, and load_state_dict then
# copies them into the model on `device`.
state = torch.load('densenet121_long_training/train.29.pth', 
                   map_location=lambda storage, loc: storage)
model.load_state_dict(state['model'])
_ = model.eval()

# `tst_df`, `stats` and `tst_ds` are already built by the dataset cell just
# above (which also defines `sz`), so the identical copy-pasted construction
# that used to follow here is dropped.

Freezing layer: base
Freezing layer: pool
Freezing layer: head


In [16]:
from torch.utils.data import DataLoader

In [17]:
def new_loader(ds, bs, drop_last=False, shuffle=True, num_workers=12):
    """Wrap `ds` in a torch DataLoader.

    `bs` is the batch size; `drop_last`, `shuffle` and `num_workers` are passed
    straight through to DataLoader.
    """
    loader_options = dict(
        batch_size=bs,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )
    return DataLoader(ds, **loader_options)

In [18]:
from tqdm import tqdm_notebook as tqdm

In [19]:
# Run the model over the test set without shuffling, so the collected rows
# follow the dataset order, and keep the softmax probabilities of each batch.
test_dl = new_loader(tst_ds, bs=64, shuffle=False)
preds = []
with torch.no_grad():
    for batch in tqdm(test_dl):
        site1 = batch['site1']['features'].to(device)
        site2 = batch['site2']['features'].to(device)
        logits = model(site1, site2)
        preds.append(torch.softmax(logits, dim=-1).cpu().numpy())

HBox(children=(IntProgress(value=0, max=311), HTML(value='')))




In [20]:
# Concatenate the per-batch probability arrays row-wise into one 2-D array.
# np.vstack is the supported name for the deprecated np.row_stack alias. The
# former trailing .squeeze() is dropped: vstack already returns a 2-D array,
# and squeezing would collapse a single-row result to 1-D.
stacked = np.vstack(preds)

In [21]:
# Sanity check: (number of predicted test rows, number of classes).
stacked.shape

(19897, 1108)

In [25]:
def select_plate_group(pp_mult, idx):
    """Zero out the classes that the inferred plate group rules out.

    Parameters
    ----------
    pp_mult : array of shape (n_wells, n_sirna) with per-class scores for the
        wells of test experiment ``all_test_exp[idx]``, in `tst_csv` row order.
        It is modified in place.
    idx : position of the experiment in ``all_test_exp`` / ``exp_to_group``.

    Returns
    -------
    The same `pp_mult` array: an entry is kept only where the plate that column
    ``exp_to_group[idx]`` of ``plate_groups`` assigns to the sirna equals the
    plate of the well.

    Raises
    ------
    AssertionError if `pp_mult` does not have one row per well of the experiment.
    """
    sub_test = tst_csv.loc[tst_csv.experiment == all_test_exp[idx], :]
    assert len(pp_mult) == len(sub_test)
    # (1, n_sirna) against (n_wells, 1) broadcasts to the full (n_wells, n_sirna)
    # mask without materialising two repeated copies, and works for any number
    # of sirna classes.
    assigned_plate = plate_groups[np.newaxis, :, exp_to_group[idx]]
    actual_plate = sub_test.plate.values[:, np.newaxis]
    mask = assigned_plate != actual_plate
    pp_mult[mask] = 0
    return pp_mult

In [26]:
# Index the submission by id_code so predictions can be written back by label.
# NOTE: not idempotent — once id_code is the index, re-running this cell fails
# because the column no longer exists.
sub = sub.set_index('id_code')

In [27]:
# Per test experiment: take its rows of the probability matrix, zero the
# classes ruled out by the inferred plate group, and store the best remaining
# class in the submission under the matching id_codes.
for idx, experiment in enumerate(all_test_exp):
    indexes = tst_csv.experiment == experiment
    masked_probs = select_plate_group(stacked[indexes, :].copy(), idx)
    sub.loc[tst_csv.id_code[indexes], 'sirna'] = masked_probs.argmax(1)

In [28]:
# Turn id_code back into a regular column so it can be written to the CSV.
sub = sub.reset_index()

In [29]:
# Fraction of rows whose sirna equals the one in the densenet121_two_way_512.csv
# submission. The comparison is positional — assumes both files list the same
# id_codes in the same order; TODO confirm.
(sub.sirna == pd.read_csv('densenet121_two_way_512.csv').sirna).mean()

0.46027039252148566

In [30]:
from IPython.display import FileLink
# Write the two submission columns and show a download link for the file.
leak_csv = 'densenet121_long_training_e29_leak.csv'
sub.to_csv(leak_csv, columns=['id_code', 'sirna'], index=False)
FileLink(leak_csv)