In [1]:
from pdb import set_trace

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from basedir import ROOT

In [2]:
f1 = pd.read_csv('stacked/train/densenet121_long_training.csv')
f2 = pd.read_csv('stacked/train/densenet121_with_tricks.csv')
f3 = pd.read_csv('stacked/train/densenet169.csv')

In [3]:
trn_df = pd.read_csv(ROOT/'train.csv')

In [4]:
X = np.column_stack([f.drop(columns='id_code').values for f in (f1, f2, f3)])
y = trn_df['sirna'].values

In [5]:
X.shape, y.shape

((36515, 3324), (36515,))

In [6]:
X = torch.tensor(X).float()
y = torch.tensor(y)
print(X.shape, X.type(), y.shape, y.type())

torch.Size([36515, 3324]) torch.FloatTensor torch.Size([36515]) torch.LongTensor


In [7]:
dataset = TensorDataset(X, y)

In [13]:
class LogisticClassifier(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.fc1 = nn.Linear(3324, 2048)
        self.fc2 = nn.Linear(2048, 1024)
        self.logreg = nn.Linear(1024, n_classes)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.logreg(x)
        return x

In [15]:
device = torch.device('cuda:0')
logreg = LogisticClassifier(1108)
logreg.to(device)
opt = torch.optim.AdamW(logreg.parameters())

loss_fn = nn.CrossEntropyLoss()

for i in range(10):
    for j, (xb, yb) in enumerate(DataLoader(dataset, shuffle=True, batch_size=1024)):
        opt.zero_grad()
        xb, yb = xb.to(device), yb.to(device)
        out = logreg(xb)
        loss = loss_fn(out, yb)
        loss.backward()
        opt.step()
        print(f'[{i}][{j}] Loss: {loss.item():.4f}')

[0][0] Loss: 7.0111
[0][1] Loss: 7.0039
[0][2] Loss: 6.9970
[0][3] Loss: 6.9851
[0][4] Loss: 6.9763
[0][5] Loss: 6.9659
[0][6] Loss: 6.9517
[0][7] Loss: 6.9359
[0][8] Loss: 6.9148
[0][9] Loss: 6.9015
[0][10] Loss: 6.8792
[0][11] Loss: 6.8568
[0][12] Loss: 6.8362
[0][13] Loss: 6.8100
[0][14] Loss: 6.7754
[0][15] Loss: 6.7296
[0][16] Loss: 6.7049
[0][17] Loss: 6.6776
[0][18] Loss: 6.6192
[0][19] Loss: 6.5915
[0][20] Loss: 6.5343
[0][21] Loss: 6.4677
[0][22] Loss: 6.3976
[0][23] Loss: 6.3470
[0][24] Loss: 6.2673
[0][25] Loss: 6.1761
[0][26] Loss: 6.1197
[0][27] Loss: 6.0255
[0][28] Loss: 5.9288
[0][29] Loss: 5.8081
[0][30] Loss: 5.6807
[0][31] Loss: 5.5962
[0][32] Loss: 5.4307
[0][33] Loss: 5.2785
[0][34] Loss: 5.1475
[0][35] Loss: 5.0174
[1][0] Loss: 4.5792
[1][1] Loss: 4.4304
[1][2] Loss: 4.1691
[1][3] Loss: 3.9974
[1][4] Loss: 3.6870
[1][5] Loss: 3.5483
[1][6] Loss: 3.3617
[1][7] Loss: 3.1556
[1][8] Loss: 2.8102
[1][9] Loss: 2.5237
[1][10] Loss: 2.2964
[1][11] Loss: 1.9804
[1][12] Loss

In [16]:
f1 = pd.read_csv('stacked/test/densenet121_long_training.csv')
f2 = pd.read_csv('stacked/test/densenet121_with_tricks.csv')
f3 = pd.read_csv('stacked/test/densenet169.csv')
X_test = np.column_stack([f.drop(columns='id_code').values for f in (f1, f2, f3)])

In [18]:
X_test = torch.tensor(X_test).to(device).float()

In [27]:
with torch.no_grad():
    logreg.eval()
    probs = []
    for xb in tqdm(DataLoader(X_test, batch_size=128, shuffle=False)):
        out = logreg(xb)
        out = out.softmax(dim=-1).cpu().numpy()
        probs.extend(out)

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




In [50]:
stacked.shape

(19897, 1108)

In [28]:
stacked = np.row_stack(preds)

In [31]:
sub = pd.DataFrame({'id_code': f1.id_code, 'sirna': stacked.argmax(axis=-1)})

In [33]:
trn_csv = pd.read_csv(ROOT/'train.csv')
tst_csv = pd.read_csv(ROOT/'test.csv')

In [34]:
plate_groups = np.zeros((1108,4), int)
for sirna in range(1108):
    grp = trn_csv.loc[trn_csv.sirna==sirna,:].plate.value_counts().index.values
    assert len(grp) == 3
    plate_groups[sirna,0:3] = grp
    plate_groups[sirna,3] = 10 - grp.sum()

In [35]:
all_test_exp = tst_csv.experiment.unique()

In [36]:
group_plate_probs = np.zeros((len(all_test_exp),4))
for idx in range(len(all_test_exp)):
    preds = sub.loc[tst_csv.experiment == all_test_exp[idx],'sirna'].values
    pp_mult = np.zeros((len(preds),1108))
    pp_mult[range(len(preds)),preds] = 1
    
    sub_test = tst_csv.loc[tst_csv.experiment == all_test_exp[idx],:]
    assert len(pp_mult) == len(sub_test)
    
    for j in range(4):
        mask = np.repeat(plate_groups[np.newaxis, :, j], len(pp_mult), axis=0) == \
               np.repeat(sub_test.plate.values[:, np.newaxis], 1108, axis=1)
        
        group_plate_probs[idx,j] = np.array(pp_mult)[mask].sum()/len(pp_mult)

In [37]:
pd.DataFrame(group_plate_probs, index=all_test_exp)

Unnamed: 0,0,1,2,3
HEPG2-08,0.084914,0.093044,0.112918,0.709124
HEPG2-09,0.16065,0.488267,0.181408,0.169675
HEPG2-10,0.769856,0.072202,0.074007,0.083935
HEPG2-11,0.82821,0.04792,0.055154,0.068716
HUVEC-17,0.830325,0.050542,0.057762,0.061372
HUVEC-18,0.65131,0.130985,0.105691,0.112014
HUVEC-19,0.069495,0.079422,0.76444,0.086643
HUVEC-20,0.025271,0.031588,0.912455,0.030686
HUVEC-21,0.080325,0.080325,0.086643,0.752708
HUVEC-22,0.843863,0.041516,0.065884,0.048736


In [38]:
exp_to_group = group_plate_probs.argmax(1)
print(exp_to_group)

[3 1 0 0 0 0 2 2 3 0 0 3 1 0 0 0 2 3]


In [39]:
def select_plate_group(pp_mult, idx):
    sub_test = tst_csv.loc[tst_csv.experiment == all_test_exp[idx],:]
    assert len(pp_mult) == len(sub_test)
    mask = np.repeat(plate_groups[np.newaxis, :, exp_to_group[idx]], len(pp_mult), axis=0) != \
           np.repeat(sub_test.plate.values[:, np.newaxis], 1108, axis=1)
    pp_mult[mask] = 0
    return pp_mult

In [40]:
sub = sub.set_index('id_code')

In [41]:
for idx in range(len(all_test_exp)):
    indexes = tst_csv.experiment == all_test_exp[idx]
    preds = stacked[indexes, :].copy()
    preds = select_plate_group(preds, idx)
    sub.loc[tst_csv.id_code[indexes], 'sirna'] = preds.argmax(1)

In [42]:
sub = sub.reset_index()

In [45]:
from IPython.display import FileLink
sub.to_csv('stack_leak.csv', index=False, columns=['id_code', 'sirna'])
FileLink('stack_leak.csv')