In [1]:
# Shell escape: runs the external `trash-empty` command before anything else in the notebook.
# NOTE(review): presumably this empties the user's trash to free disk space — it is a destructive
# side effect unrelated to the analysis; confirm it is still wanted.
!trash-empty

In [2]:
from radam import RAdam, PlainRAdam, AdamW
from am_softmax import AMSoftmaxLoss, AngleSimpleLinear

In [3]:
import os
import gc
import numpy as np 
import pandas as pd
from PIL import Image
from tqdm import tqdm_notebook
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.utils.data as D
from torch.optim.lr_scheduler import ExponentialLR
from torchvision import models, transforms as T
import torch.nn.functional as F

from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.metrics import Loss, Accuracy
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from ignite.handlers import  EarlyStopping, ModelCheckpoint

# from am_softmax import AMSoftmaxLoss

import warnings
warnings.filterwarnings('ignore')

torch.cuda.empty_cache()

In [4]:
# Experiment settings shared by the cells below.
config = dict(
    SEED=42,                                         # passed to seed_torch and train_test_split
    CLASSES=1108,                                    # size of the classifier output layer
    PATH_DATA='/home/tienen/kaggle_dataset_drugs/',  # folder holding train.csv / test.csv and the images
    DEVICE='cuda',
    BATCH_SIZE=8,
    VAL_SIZE=0.05,                                   # share of each category held out for validation
    MODEL_NAME='DenseNet201_AMSLoss',                # also used as the checkpoint directory name
    USE_ANGULAR=True,
    USE_BN=True,
    LR=1e-4,
    LR_STR='1e-4',                                   # textual copy of LR used in file names; keep in sync with LR
    TURN_OFF_ON_N_EPOCHS=1,                          # not referenced by the visible cells
)

# Epoch number of the all-experiments checkpoint that per-category fine-tuning starts from.
best_epoch = 8

In [5]:
def seed_torch(seed=42):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic mode.

    Args:
        seed: integer seed applied to all generators.
    """
    # Local imports keep the function self-contained (`random` is not imported
    # at the top of the notebook).
    import os
    import random

    random.seed(seed)
    # Does not change hashing in the already-running interpreter; only
    # processes launched afterwards pick this value up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # manual_seed only covers the current device; seed the others as well.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different kernels from run to run, which
    # defeats the deterministic flag above, so switch it off explicitly.
    torch.backends.cudnn.benchmark = False
    
seed_torch(config['SEED'])

## Phase 2: train on each category

In [7]:
class ImagesDS(D.Dataset):
    """Dataset that assembles one multi-channel image per dataframe row.

    Every channel of an image is stored as a separate file named
    ``<img_dir>/<mode>/<experiment>/Plate<plate>/<well>_s<site>_w<channel>.png``;
    the per-channel tensors are concatenated along the first dimension.

    Args:
        df: dataframe with ``experiment``, ``well`` and ``plate`` columns, plus
            ``sirna`` (when ``mode == 'train'``) or ``id_code`` (any other mode).
        img_dir: root directory of the image files.
        mode: sub-directory to read from; ``'train'`` makes ``__getitem__``
            return the integer label, any other value returns the ``id_code``.
        site: imaging site number used in the file name.
        channels: channel numbers to load, in the order they are concatenated.
    """

    def __init__(self, df, img_dir, mode='train', site=1, channels=(1, 2, 3, 4, 5, 6)):
        self.records = df.to_records(index=False)
        # Keep a private list: the previous mutable default was shared by every
        # instance, so editing one dataset's channels changed all the others.
        self.channels = list(channels)
        self.site = site
        self.mode = mode
        self.img_dir = img_dir
        self.len = df.shape[0]

    @staticmethod
    def _load_img_as_tensor(file_name):
        """Open one image file and convert it with ``T.ToTensor()``."""
        with Image.open(file_name) as img:
            return T.ToTensor()(img)

    def _get_img_path(self, index, channel):
        """Build the file path of one channel of the image at row ``index``."""
        record = self.records[index]
        file_name = f'{record.well}_s{self.site}_w{channel}.png'
        return '/'.join([self.img_dir, self.mode, record.experiment, f'Plate{record.plate}', file_name])

    def __getitem__(self, index):
        """Return ``(image, sirna)`` in train mode, ``(image, id_code)`` otherwise."""
        paths = [self._get_img_path(index, ch) for ch in self.channels]
        img = torch.cat([self._load_img_as_tensor(img_path) for img_path in paths])
        record = self.records[index]
        if self.mode == 'train':
            return img, int(record.sirna)
        return img, record.id_code

    def __len__(self):
        """Number of dataframe rows."""
        return self.len

In [8]:
class DenseNet(nn.Module):
    """DenseNet-201 feature extractor adapted to ``num_channels``-channel input.

    The pretrained first convolution is replaced by one that accepts
    ``num_channels`` input channels; each new input channel is initialised with
    the channel-wise mean of the pretrained kernel. The pooled 1920-dimensional
    features feed either an ``AngleSimpleLinear`` head or a plain ``nn.Linear``.

    Args:
        num_classes: number of outputs of the classification head.
        num_channels: number of channels of the input images.
        use_bn: when True, apply a ``BatchNorm2d`` to the raw input first.
        use_angular: when True, use ``AngleSimpleLinear`` as the head
            (registered as ``fc_angular``), otherwise ``nn.Linear`` (``fc``).
    """

    def __init__(self, num_classes=1000, num_channels=6, use_bn=False, use_angular=False):
        super().__init__()
        self.use_angular = use_angular
        self.use_bn = use_bn
        if self.use_bn:
            # Was hard-coded to 6, which broke every num_channels != 6.
            self.bn = nn.BatchNorm2d(num_channels)

        preloaded = models.densenet201(pretrained=True)
        self.features = preloaded.features

        new_conv = nn.Conv2d(num_channels, 64, 7, 2, 3, bias=False)
        trained_kernel = self.features.conv0.weight
        with torch.no_grad():
            # Average the pretrained kernel over its input channels and repeat
            # it for every new input channel (previously always 6 copies).
            mean_kernel = trained_kernel.mean(dim=1, keepdim=True)
            new_conv.weight.copy_(mean_kernel.expand(-1, num_channels, -1, -1))

        self.features.conv0 = new_conv

        if self.use_angular:
            self.fc_angular = AngleSimpleLinear(1920, num_classes)
        else:
            self.fc = nn.Linear(in_features=1920, out_features=num_classes, bias=True)

    def forward(self, x):
        """Return the head's output for a batch of ``num_channels``-channel images."""
        if self.use_bn:
            x = self.bn(x)
        x = self.features(x)
        # NOTE(review): no ReLU is applied between the feature extractor and the
        # pooling; the checkpoints loaded elsewhere in this notebook were produced
        # with this exact forward pass, so it is kept unchanged.
        x = F.adaptive_avg_pool2d(x, (1, 1)).view(x.size(0), -1)

        if self.use_angular:
            return self.fc_angular(x)
        return self.fc(x)

In [9]:
# Load the train / test metadata tables.
df = pd.read_csv(config['PATH_DATA']+'/train.csv')
df_test = pd.read_csv(config['PATH_DATA']+'/test.csv')

# The category is the experiment name up to its first '-' (e.g. 'HEPG2-08' -> 'HEPG2').
for frame in (df, df_test):
    frame['category'] = frame['experiment'].map(lambda name: name.split('-')[0])

## Training

In [10]:
categories = df['category'].unique()
# 23.08.19 (ES by LOSS + RAdam)
# Fine-tune the all-experiments checkpoint separately on every category.
for category in categories:
    category_df = df[df['category'] == category]
    category_df_train, category_df_val = train_test_split(category_df,
                                                          test_size=config['VAL_SIZE'],
                                                          # stratify=category_df.sirna,
                                                          random_state=config['SEED'])

    print('\n' + '=' * 40)
    print("CURRENT CATEGORY:", category)
    print('-' * 40)

    # LOAD MODEL
    # Every category starts from the same checkpoint (epoch `best_epoch` of the all-experiments run).
    model = DenseNet(num_classes=config['CLASSES'], use_bn=config['USE_BN'], use_angular=config['USE_ANGULAR'])
    checkpoint = torch.load('{}/all_exps_{}_lr{}_{}.pth'.format(config['MODEL_NAME'], config['MODEL_NAME'],
                                                                config['LR_STR'], best_epoch))
    model.load_state_dict(checkpoint)
    model.to(config['DEVICE'])
    model.train()

    criterion = AMSoftmaxLoss(margin_type='cos')
    optimizer = RAdam(model.parameters(), lr=config['LR'])
    metrics = {'loss': Loss(criterion), 'accuracy': Accuracy()}

    trainer = create_supervised_trainer(model, optimizer, criterion, device=config['DEVICE'])
    val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=config['DEVICE'])

    # HELPERS
    # The EPOCH_COMPLETED handlers are registered in this order on purpose:
    # validation (which also drives early stopping), LR decay, then checkpointing.
    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_and_display_val_metrics(engine):
        # `val_loader` is created further down, before trainer.run() fires this handler.
        val_metrics = val_evaluator.run(val_loader).metrics
        print("Validation Results - Epoch: {}  Average Loss: {:.4f} | Accuracy: {:.4f} "
              .format(engine.state.epoch,
                      val_metrics['loss'],
                      val_metrics['accuracy']))

    lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
    @trainer.on(Events.EPOCH_COMPLETED)
    def update_lr_scheduler(engine):
        lr_scheduler.step()
        lr = float(optimizer.param_groups[0]['lr'])
        print("Learning rate: {}".format(lr))

    # Stop when the validation loss has not improved for 4 consecutive evaluations
    # (the score is the negated loss because EarlyStopping maximises the score).
    handler = EarlyStopping(patience=4, score_function=lambda engine: - engine.state.metrics['loss'], trainer=trainer)
    val_evaluator.add_event_handler(Events.COMPLETED, handler)

    checkpoints = ModelCheckpoint(config['MODEL_NAME'], category,
                                  save_interval=1, n_saved=10, create_dir=True, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoints,
                              {config['MODEL_NAME']+'_lr{}'.format(config['LR_STR']): model})

    pbar = ProgressBar(bar_format='')
    pbar.attach(trainer, output_transform=lambda x: {'loss': x})

    tb_logger = None
    if 'KAGGLE_WORKING_DIR' not in os.environ:  #  If we are not on kaggle server
        # Explicit names instead of `import *`, so it is clear where they come from.
        from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler
        tb_logger = TensorboardLogger("board/"+config['MODEL_NAME']+'_'+category)

        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", output_transform=lambda loss: {'loss': loss}),
                         event_name=Events.ITERATION_COMPLETED)

        tb_logger.attach(val_evaluator, log_handler=OutputHandler(tag="validation", metric_names=["accuracy", "loss"],
                         another_engine=trainer),event_name=Events.EPOCH_COMPLETED)

    # DATA
    # Both imaging sites of every well are used as independent samples.
    ds_1 = ImagesDS(category_df_train, config['PATH_DATA'], site=1, mode='train')
    ds_2 = ImagesDS(category_df_train, config['PATH_DATA'], site=2, mode='train')
    ds = D.ConcatDataset([ds_1, ds_2])

    # The validation rows also live under the 'train' directory, hence mode='train'.
    ds_val_1 = ImagesDS(category_df_val, config['PATH_DATA'], site=1, mode='train')
    ds_val_2 = ImagesDS(category_df_val, config['PATH_DATA'], site=2, mode='train')
    ds_val = D.ConcatDataset([ds_val_1, ds_val_2])

    train_loader = D.DataLoader(ds, batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=4)
    val_loader = D.DataLoader(ds_val, batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=4)

    # TRAINING
    try:
        trainer.run(train_loader, max_epochs=15)
    finally:
        # The logger used to be closed right after attaching its handlers, i.e. before
        # any event had fired; close it only once training has ended (or was interrupted).
        if tb_logger is not None:
            tb_logger.close()

    # Release GPU memory before the next category loads a fresh model.
    del model, trainer, ds, ds_val
    gc.collect()
    torch.cuda.empty_cache()


CURRENT CATEGORY: HEPG2
----------------------------------------


HBox(children=(IntProgress(value=0, max=1841), HTML(value='')))

Validation Results - Epoch: 1  Average Loss: 13.6390 | Accuracy: 0.8389 
Learning rate: 9.5e-05


HBox(children=(IntProgress(value=0, max=1841), HTML(value='')))

Validation Results - Epoch: 2  Average Loss: 13.7600 | Accuracy: 0.8041 
Learning rate: 9.025e-05


HBox(children=(IntProgress(value=0, max=1841), HTML(value='')))

Validation Results - Epoch: 3  Average Loss: 13.3609 | Accuracy: 0.8106 
Learning rate: 8.573749999999999e-05


HBox(children=(IntProgress(value=0, max=1841), HTML(value='')))

Validation Results - Epoch: 4  Average Loss: 13.5271 | Accuracy: 0.7668 
Learning rate: 8.145062499999998e-05


HBox(children=(IntProgress(value=0, max=1841), HTML(value='')))

Validation Results - Epoch: 5  Average Loss: 13.4008 | Accuracy: 0.7912 
Learning rate: 7.737809374999998e-05


HBox(children=(IntProgress(value=0, max=1841), HTML(value='')))

Validation Results - Epoch: 6  Average Loss: 13.2512 | Accuracy: 0.7822 
Learning rate: 7.350918906249998e-05


HBox(children=(IntProgress(value=0, max=1841), HTML(value='')))

Validation Results - Epoch: 7  Average Loss: 13.6991 | Accuracy: 0.7358 
Learning rate: 6.983372960937497e-05


HBox(children=(IntProgress(value=0, max=1841), HTML(value='')))

Validation Results - Epoch: 8  Average Loss: 13.4104 | Accuracy: 0.7732 
Learning rate: 6.634204312890622e-05


HBox(children=(IntProgress(value=0, max=1841), HTML(value='')))

Validation Results - Epoch: 9  Average Loss: 13.6742 | Accuracy: 0.7461 
Learning rate: 6.30249409724609e-05


HBox(children=(IntProgress(value=0, max=1841), HTML(value='')))

Validation Results - Epoch: 10  Average Loss: 13.4362 | Accuracy: 0.7680 
Learning rate: 5.987369392383786e-05

CURRENT CATEGORY: HUVEC
----------------------------------------


HBox(children=(IntProgress(value=0, max=4201), HTML(value='')))

Validation Results - Epoch: 1  Average Loss: 9.6993 | Accuracy: 0.9345 
Learning rate: 9.5e-05


HBox(children=(IntProgress(value=0, max=4201), HTML(value='')))

Validation Results - Epoch: 2  Average Loss: 9.1346 | Accuracy: 0.9328 
Learning rate: 9.025e-05


HBox(children=(IntProgress(value=0, max=4201), HTML(value='')))

Validation Results - Epoch: 3  Average Loss: 9.1990 | Accuracy: 0.9260 
Learning rate: 8.573749999999999e-05


HBox(children=(IntProgress(value=0, max=4201), HTML(value='')))

Validation Results - Epoch: 4  Average Loss: 9.0000 | Accuracy: 0.9288 
Learning rate: 8.145062499999998e-05


HBox(children=(IntProgress(value=0, max=4201), HTML(value='')))

Validation Results - Epoch: 5  Average Loss: 8.8905 | Accuracy: 0.9209 
Learning rate: 7.737809374999998e-05


HBox(children=(IntProgress(value=0, max=4201), HTML(value='')))

Validation Results - Epoch: 6  Average Loss: 8.8218 | Accuracy: 0.9266 
Learning rate: 7.350918906249998e-05


HBox(children=(IntProgress(value=0, max=4201), HTML(value='')))

Validation Results - Epoch: 7  Average Loss: 9.2237 | Accuracy: 0.9079 
Learning rate: 6.983372960937497e-05


HBox(children=(IntProgress(value=0, max=4201), HTML(value='')))

Validation Results - Epoch: 8  Average Loss: 9.1535 | Accuracy: 0.9023 
Learning rate: 6.634204312890622e-05


HBox(children=(IntProgress(value=0, max=4201), HTML(value='')))

Validation Results - Epoch: 9  Average Loss: 8.9864 | Accuracy: 0.9073 
Learning rate: 6.30249409724609e-05


HBox(children=(IntProgress(value=0, max=4201), HTML(value='')))

Validation Results - Epoch: 10  Average Loss: 9.1156 | Accuracy: 0.9056 
Learning rate: 5.987369392383786e-05

CURRENT CATEGORY: RPE
----------------------------------------


HBox(children=(IntProgress(value=0, max=1842), HTML(value='')))

Validation Results - Epoch: 1  Average Loss: 11.9584 | Accuracy: 0.9291 
Learning rate: 9.5e-05


HBox(children=(IntProgress(value=0, max=1842), HTML(value='')))

Validation Results - Epoch: 2  Average Loss: 12.1056 | Accuracy: 0.9008 
Learning rate: 9.025e-05


HBox(children=(IntProgress(value=0, max=1842), HTML(value='')))

Validation Results - Epoch: 3  Average Loss: 11.8402 | Accuracy: 0.8982 
Learning rate: 8.573749999999999e-05


HBox(children=(IntProgress(value=0, max=1842), HTML(value='')))

Validation Results - Epoch: 4  Average Loss: 11.8767 | Accuracy: 0.8763 
Learning rate: 8.145062499999998e-05


HBox(children=(IntProgress(value=0, max=1842), HTML(value='')))

Validation Results - Epoch: 5  Average Loss: 11.8704 | Accuracy: 0.8582 
Learning rate: 7.737809374999998e-05


HBox(children=(IntProgress(value=0, max=1842), HTML(value='')))

Validation Results - Epoch: 6  Average Loss: 11.7897 | Accuracy: 0.8673 
Learning rate: 7.350918906249998e-05


HBox(children=(IntProgress(value=0, max=1842), HTML(value='')))

Validation Results - Epoch: 7  Average Loss: 12.0652 | Accuracy: 0.8531 
Learning rate: 6.983372960937497e-05


HBox(children=(IntProgress(value=0, max=1842), HTML(value='')))

Validation Results - Epoch: 8  Average Loss: 12.1486 | Accuracy: 0.8325 
Learning rate: 6.634204312890622e-05


HBox(children=(IntProgress(value=0, max=1842), HTML(value='')))

Validation Results - Epoch: 9  Average Loss: 12.3095 | Accuracy: 0.8312 
Learning rate: 6.30249409724609e-05


HBox(children=(IntProgress(value=0, max=1842), HTML(value='')))

Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
  File "/opt/anaconda3/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    se

KeyboardInterrupt: 

In [10]:
categories = df['category'].unique()
# 23.08.19 (ES by LOSS + RAdam) (U2OS)
# Same fine-tuning procedure as the all-category cell, restricted to U2OS.
for category in categories:
    if category != 'U2OS':
        continue

    category_df = df[df['category'] == category]
    category_df_train, category_df_val = train_test_split(category_df,
                                                          test_size=config['VAL_SIZE'],
                                                          # stratify=category_df.sirna,
                                                          random_state=config['SEED'])

    print('\n' + '=' * 40)
    print("CURRENT CATEGORY:", category)
    print('-' * 40)

    # LOAD MODEL
    # Start from the all-experiments checkpoint (epoch `best_epoch`).
    model = DenseNet(num_classes=config['CLASSES'], use_bn=config['USE_BN'], use_angular=config['USE_ANGULAR'])
    checkpoint = torch.load('{}/all_exps_{}_lr{}_{}.pth'.format(config['MODEL_NAME'], config['MODEL_NAME'],
                                                                config['LR_STR'], best_epoch))
    model.load_state_dict(checkpoint)
    model.to(config['DEVICE'])
    model.train()

    criterion = AMSoftmaxLoss(margin_type='cos')
    optimizer = RAdam(model.parameters(), lr=config['LR'])
    metrics = {'loss': Loss(criterion), 'accuracy': Accuracy()}

    trainer = create_supervised_trainer(model, optimizer, criterion, device=config['DEVICE'])
    val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=config['DEVICE'])

    # HELPERS
    # The EPOCH_COMPLETED handlers are registered in this order on purpose:
    # validation (which also drives early stopping), LR decay, then checkpointing.
    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_and_display_val_metrics(engine):
        # `val_loader` is created further down, before trainer.run() fires this handler.
        val_metrics = val_evaluator.run(val_loader).metrics
        print("Validation Results - Epoch: {}  Average Loss: {:.4f} | Accuracy: {:.4f} "
              .format(engine.state.epoch,
                      val_metrics['loss'],
                      val_metrics['accuracy']))

    lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
    @trainer.on(Events.EPOCH_COMPLETED)
    def update_lr_scheduler(engine):
        lr_scheduler.step()
        lr = float(optimizer.param_groups[0]['lr'])
        print("Learning rate: {}".format(lr))

    # Stop when the validation loss has not improved for 4 consecutive evaluations
    # (the score is the negated loss because EarlyStopping maximises the score).
    handler = EarlyStopping(patience=4, score_function=lambda engine: - engine.state.metrics['loss'], trainer=trainer)
    val_evaluator.add_event_handler(Events.COMPLETED, handler)

    checkpoints = ModelCheckpoint(config['MODEL_NAME'], category,
                                  save_interval=1, n_saved=10, create_dir=True, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoints,
                              {config['MODEL_NAME']+'_lr{}'.format(config['LR_STR']): model})

    pbar = ProgressBar(bar_format='')
    pbar.attach(trainer, output_transform=lambda x: {'loss': x})

    tb_logger = None
    if 'KAGGLE_WORKING_DIR' not in os.environ:  #  If we are not on kaggle server
        # Explicit names instead of `import *`, so it is clear where they come from.
        from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler
        tb_logger = TensorboardLogger("board/"+config['MODEL_NAME']+'_'+category)

        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", output_transform=lambda loss: {'loss': loss}),
                         event_name=Events.ITERATION_COMPLETED)

        tb_logger.attach(val_evaluator, log_handler=OutputHandler(tag="validation", metric_names=["accuracy", "loss"],
                         another_engine=trainer),event_name=Events.EPOCH_COMPLETED)

    # DATA
    # Both imaging sites of every well are used as independent samples.
    ds_1 = ImagesDS(category_df_train, config['PATH_DATA'], site=1, mode='train')
    ds_2 = ImagesDS(category_df_train, config['PATH_DATA'], site=2, mode='train')
    ds = D.ConcatDataset([ds_1, ds_2])

    # The validation rows also live under the 'train' directory, hence mode='train'.
    ds_val_1 = ImagesDS(category_df_val, config['PATH_DATA'], site=1, mode='train')
    ds_val_2 = ImagesDS(category_df_val, config['PATH_DATA'], site=2, mode='train')
    ds_val = D.ConcatDataset([ds_val_1, ds_val_2])

    train_loader = D.DataLoader(ds, batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=4)
    val_loader = D.DataLoader(ds_val, batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=4)

    # TRAINING
    try:
        trainer.run(train_loader, max_epochs=15)
    finally:
        # The logger used to be closed right after attaching its handlers, i.e. before
        # any event had fired; close it only once training has ended (or was interrupted).
        if tb_logger is not None:
            tb_logger.close()

    # Release GPU memory once this category is done.
    del model, trainer, ds, ds_val
    gc.collect()
    torch.cuda.empty_cache()


CURRENT CATEGORY: U2OS
----------------------------------------


HBox(children=(IntProgress(value=0, max=790), HTML(value='')))

Validation Results - Epoch: 1  Average Loss: 14.0493 | Accuracy: 0.8144 
Learning rate: 9.5e-05


HBox(children=(IntProgress(value=0, max=790), HTML(value='')))

Validation Results - Epoch: 2  Average Loss: 14.1675 | Accuracy: 0.7934 
Learning rate: 9.025e-05


HBox(children=(IntProgress(value=0, max=790), HTML(value='')))

Validation Results - Epoch: 3  Average Loss: 14.1393 | Accuracy: 0.7695 
Learning rate: 8.573749999999999e-05


HBox(children=(IntProgress(value=0, max=790), HTML(value='')))

Validation Results - Epoch: 4  Average Loss: 14.3611 | Accuracy: 0.7814 
Learning rate: 8.145062499999998e-05


HBox(children=(IntProgress(value=0, max=790), HTML(value='')))

Validation Results - Epoch: 5  Average Loss: 14.3295 | Accuracy: 0.7754 
Learning rate: 7.737809374999998e-05


## Prediction on the test set

In [11]:
categories = df['category'].unique()
output_df = []
output_predicted = []

# Epoch of the fine-tuned checkpoint to load for each category. A category missing
# from this table now raises KeyError instead of silently reusing the epoch of the
# previously processed category.
best_epoch_per_category = {
    'HEPG2': 6,
    'HUVEC': 6,
    'RPE': 6,
    'U2OS': 1,
}

for category in categories:
    cat_test_df = df_test[df_test['category'] == category].copy()

    print('\n' + '=' * 40)
    print("CURRENT CATEGORY:", category)
    print('-' * 40)

    # LOAD MODEL
    model = DenseNet(num_classes=config['CLASSES'], use_bn=config['USE_BN'], use_angular=config['USE_ANGULAR'])
    best_epoch_on_category = best_epoch_per_category[category]

    checkpoint = torch.load('{0}/{2}_{0}_lr{1}_{3}.pth'.format(config['MODEL_NAME'], config['LR_STR'], category, best_epoch_on_category))
    model.load_state_dict(checkpoint)
    model.to(config['DEVICE'])
    model.eval()

    # DATA
    # One loader per imaging site; both iterate the same rows in the same order.
    ds_test_1 = ImagesDS(cat_test_df, config['PATH_DATA'], site=1, mode='test')
    ds_test_2 = ImagesDS(cat_test_df, config['PATH_DATA'], site=2, mode='test')
    test_loader_1 = D.DataLoader(ds_test_1, batch_size=1, shuffle=False, num_workers=4)
    test_loader_2 = D.DataLoader(ds_test_2, batch_size=1, shuffle=False, num_workers=4)

    # PREDICTION
    with torch.no_grad():
        predicted = []
        # zip() has no length, so give tqdm the total explicitly to get a real progress bar.
        for (x1, id1), (x2, id2) in tqdm_notebook(zip(test_loader_1, test_loader_2),
                                                  total=len(test_loader_1)):
            assert id1 == id2

            x1 = x1.to(config['DEVICE'])
            output1 = model(x1)

            x2 = x2.to(config['DEVICE'])
            output2 = model(x2)

            # Average the raw model outputs of the two sites of the same well.
            result = 0.5*(output1 + output2)
            predicted.append(result.cpu().numpy())

    # Concatenate the per-batch arrays along the row axis. Unlike stack().squeeze(),
    # this keeps a 2-D (rows, classes) array even when the category has a single row.
    predicted = np.concatenate(predicted, axis=0)
    cat_test_df['sirna'] = np.argmax(predicted, axis=1).astype(int)
    output_df.append(cat_test_df[['id_code', 'experiment', 'sirna']])
    output_predicted.append(predicted)

    del model, ds_test_1, ds_test_2
    gc.collect()
    torch.cuda.empty_cache()


CURRENT CATEGORY: HEPG2
----------------------------------------


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))



CURRENT CATEGORY: HUVEC
----------------------------------------


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))



CURRENT CATEGORY: RPE
----------------------------------------


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))



CURRENT CATEGORY: U2OS
----------------------------------------


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [12]:
# Merge the per-category frames into one table. The guard makes the cell safe to
# re-run: after the first run `output_df` is already a DataFrame, and pd.concat
# would reject it.
if isinstance(output_df, list):
    output_df = pd.concat(output_df)
submission = output_df[['id_code', 'sirna']]

# Make sure the target directory exists before writing.
os.makedirs('submits', exist_ok=True)
submission.to_csv('submits/{}_each_exps_best_epoch_lr{}.csv'.format(config['MODEL_NAME'], config['LR_STR']),
                  index=False, columns=['id_code','sirna'])

In [13]:
# Preview the first rows of the plain (argmax) submission.
submission.head()

Unnamed: 0,id_code,sirna
0,HEPG2-08_1_B03,855
1,HEPG2-08_1_B04,979
2,HEPG2-08_1_B05,836
3,HEPG2-08_1_B06,584
4,HEPG2-08_1_B07,878


## Use the plate leak

In [14]:
# Stack the per-category score arrays into one (rows, classes) array. The guard makes
# the cell safe to re-run: concatenating the already-merged 2-D array again would
# silently flatten it into one long vector.
if isinstance(output_predicted, list):
    output_predicted = np.concatenate(output_predicted)

# Make sure the target directory exists before writing.
os.makedirs('predictions', exist_ok=True)
np.save('predictions/{}_each_exps_best_epoch_by_loss_lr{}'.format(config['MODEL_NAME'], config['LR_STR']),
        output_predicted)

In [15]:
# plate_groups[s, 0:3] holds the three plates on which siRNA `s` occurs in the training
# data (in value_counts order); column 3 holds the remaining plate.
plate_groups = np.zeros((1108, 4), dtype=int)
for sirna_id in range(1108):
    seen_plates = df.loc[df.sirna == sirna_id, 'plate'].value_counts().index.values
    assert len(seen_plates) == 3
    plate_groups[sirna_id, :3] = seen_plates
    # assumes plates are numbered 1..4 (sum 10), so the missing one is 10 minus the rest
    plate_groups[sirna_id, 3] = 10 - seen_plates.sum()

print(plate_groups[:10,:])

[[4 2 3 1]
 [1 3 4 2]
 [2 4 1 3]
 [1 3 4 2]
 [3 1 2 4]
 [1 3 4 2]
 [1 3 4 2]
 [2 4 1 3]
 [1 3 4 2]
 [4 2 3 1]]


In [16]:
all_test_exp = df_test.experiment.unique()

# group_plate_probs[e, j] is the share of experiment e's predictions whose siRNA is
# expected (per column j of plate_groups) on the plate the well actually sits on.
# Looking plate_groups up at the predicted label gives the same count as the earlier
# one-hot matrix combined with two repeated (rows x classes) masks, without building them.
group_plate_probs = np.zeros((len(all_test_exp), plate_groups.shape[1]))
for idx, experiment in enumerate(all_test_exp):
    in_experiment = df_test.experiment == experiment
    preds = submission.loc[in_experiment, 'sirna'].values
    plates = df_test.loc[in_experiment, 'plate'].values
    assert len(preds) == len(plates)

    for j in range(plate_groups.shape[1]):
        group_plate_probs[idx, j] = (plate_groups[preds, j] == plates).mean()

In [17]:
# For every test experiment, pick the plate assignment (column of plate_groups)
# that agrees most often with the current predictions.
exp_to_group = np.argmax(group_plate_probs, axis=1)
print(exp_to_group)

[3 1 0 0 0 0 2 2 3 0 0 3 1 0 0 0 2 3]


In [18]:
# Masks out the siRNAs (about 75% of them) that the selected plate assignment rules
# out for each well's plate, so that a later argmax can only pick an allowed class.

def select_plate_group(pp_mult, idx, fill_value=-np.inf):
    """Exclude the classes that cannot occur on each row's plate.

    Reads the module-level ``df_test``, ``all_test_exp``, ``plate_groups`` and
    ``exp_to_group``.

    Args:
        pp_mult: float array of shape (rows, classes) with the scores of the rows
            of experiment ``all_test_exp[idx]``, in ``df_test`` order. It is
            modified in place.
        idx: position of the experiment in ``all_test_exp``.
        fill_value: value written into excluded entries. The default is ``-inf``
            because the prediction cell stores raw model outputs (no softmax is
            applied there), which are not guaranteed to be non-negative; with
            the previous fill value of 0 an excluded class could out-score every
            allowed class in the argmax. Pass ``0`` to get the old behaviour.

    Returns:
        The same ``pp_mult`` array, with excluded entries set to ``fill_value``.
    """
    sub_test = df_test.loc[df_test.experiment == all_test_exp[idx], :]
    assert len(pp_mult) == len(sub_test)
    # Broadcast (1, classes) against (rows, 1) instead of materialising two
    # repeated matrices with a hard-coded class count of 1108.
    expected_plate = plate_groups[np.newaxis, :, exp_to_group[idx]]
    actual_plate = sub_test.plate.values[:, np.newaxis]
    mask = expected_plate != actual_plate
    pp_mult[mask] = fill_value
    return pp_mult

In [19]:
# Preview the merged per-category predictions (keeps the `experiment` column used below).
output_df.head()

Unnamed: 0,id_code,experiment,sirna
0,HEPG2-08_1_B03,HEPG2-08,855
1,HEPG2-08_1_B04,HEPG2-08,979
2,HEPG2-08_1_B05,HEPG2-08,836
3,HEPG2-08_1_B06,HEPG2-08,584
4,HEPG2-08_1_B07,HEPG2-08,878


In [20]:
# Recompute the predicted label of every test row after restricting the candidate
# classes to those allowed by the plate assignment chosen for its experiment.
sub = submission.copy()

for idx, experiment in enumerate(all_test_exp):
    indices = output_df.experiment == experiment
    # Pass a copy: select_plate_group overwrites the array it is given.
    masked_scores = select_plate_group(output_predicted[indices, :].copy(), idx)
    sub.loc[indices, 'sirna'] = masked_scores.argmax(1)

In [21]:
# Write the plate-leak-corrected submission next to the plain one.
leak_submission_path = 'submits/{}_each_exps_best_epoch_lr{}_plates_leak.csv'.format(
    config['MODEL_NAME'], config['LR_STR'])
sub.to_csv(leak_submission_path, index=False, columns=['id_code','sirna'])

In [22]:
# Fraction of test rows whose predicted siRNA is unchanged by the plate-leak step.
print((sub.sirna == submission.sirna).mean())

0.7248831482132985


In [23]:
# Number of distinct siRNA labels predicted before and after the plate-leak step.
len(submission.sirna.unique()), len(sub.sirna.unique())

(1105, 1107)

In [24]:
# Preview the first rows of the plate-leak-corrected submission.
sub.head()

Unnamed: 0,id_code,sirna
0,HEPG2-08_1_B03,855
1,HEPG2-08_1_B04,979
2,HEPG2-08_1_B05,836
3,HEPG2-08_1_B06,609
4,HEPG2-08_1_B07,878
