In [1]:
from se_resnet import se_resnet152
from torchvision.models.resnet import resnet50
import torch
import torch.nn as nn
from training.training import Trainer
import os.path as osp
import cv2
import numpy as np

In [2]:
import torch.nn.functional as F

In [3]:
from albumentations import (
    VerticalFlip,
    HorizontalFlip,
    Compose,
    RandomRotate90,
    ElasticTransform,
    GridDistortion,
    OpticalDistortion,
    OneOf,
    CLAHE,
    RandomContrast,
    RandomGamma,
    RandomBrightness,
    Resize)
torch.manual_seed(42)
np.random.seed(42)

2018-12-05 10:52:08,340 [MainThread  ] [DEBUG]  $HOME=/root
2018-12-05 10:52:08,341 [MainThread  ] [DEBUG]  matplotlib data path /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data
2018-12-05 10:52:08,347 [MainThread  ] [DEBUG]  loaded rc file /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/matplotlibrc
2018-12-05 10:52:08,349 [MainThread  ] [DEBUG]  matplotlib version 2.2.2
2018-12-05 10:52:08,350 [MainThread  ] [DEBUG]  interactive is False
2018-12-05 10:52:08,351 [MainThread  ] [DEBUG]  platform is linux


In [4]:
# from fastai.conv_learner import *
# from fastai.dataset import *

import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import scipy.optimize as opt

In [5]:
name_label_dict = {
0:  'Nucleoplasm',
1:  'Nuclear membrane',
2:  'Nucleoli',   
3:  'Nucleoli fibrillar center',
4:  'Nuclear speckles',
5:  'Nuclear bodies',
6:  'Endoplasmic reticulum',   
7:  'Golgi apparatus',
8:  'Peroxisomes',
9:  'Endosomes',
10:  'Lysosomes',
11:  'Intermediate filaments',
12:  'Actin filaments',
13:  'Focal adhesion sites',   
14:  'Microtubules',
15:  'Microtubule ends',  
16:  'Cytokinetic bridge',   
17:  'Mitotic spindle',
18:  'Microtubule organizing center',  
19:  'Centrosome',
20:  'Lipid droplets',
21:  'Plasma membrane',   
22:  'Cell junctions', 
23:  'Mitochondria',
24:  'Aggresome',
25:  'Cytosol',
26:  'Cytoplasmic bodies',   
27:  'Rods & rings' }

In [6]:
PATH = './'
TRAIN = '/root/data/protein/train/'
TEST = '/root/data/protein/test/'
LABELS = '/root/data/protein/train.csv'
SAMPLE = '/root/data/protein/sample_submission.csv'


In [7]:
train_names = list({f[:36] for f in os.listdir(TRAIN)})
test_names = list({f[:36] for f in os.listdir(TEST)})
tr_n, val_n = train_test_split(train_names, test_size=0.1, random_state=42)


In [8]:
def open_rgby(path,id): #a function that reads RGBY image
    colors = ['red','green','blue','yellow']
    flags = cv2.IMREAD_GRAYSCALE
    img = [cv2.imread(os.path.join(path, id+'_'+color+'.png'), flags).astype(np.float32)/255
           for color in colors]
    return np.stack(img, axis=-1)

In [9]:
TARGET_SIZE=512

In [10]:
aug = Compose([
    HorizontalFlip(p=0.7),
    RandomGamma(p=0.7),
    #GridDistortion(p=0.6),
    #OpticalDistortion(p=0.6),
    #ElasticTransform(p=0.6),
    Resize(height=TARGET_SIZE, width=TARGET_SIZE)
])

val_aug=Resize(height=TARGET_SIZE, width=TARGET_SIZE)

In [11]:
class ProteinDataset:
    def __init__(self, names, path,aug=aug):
        self.names=names
        self.aug=aug
        self.path=path
        self.labels = pd.read_csv(LABELS).set_index('Id')
        self.labels['Target'] = [[int(i) for i in s.split()] for s in self.labels['Target']]
        
    def __len__(self):
        return len(self.names)
    
    def __getitem__(self, idx):
        
        if(self.path == TEST): label= np.zeros(len(name_label_dict),dtype=np.int)
        else:
            labels = self.labels.loc[self.names[idx]]['Target']
            label=np.eye(len(name_label_dict),dtype=np.float)[labels].sum(axis=0)
        
        img = open_rgby(self.path, self.names[idx])
        img = aug(image=img)['image']
        
        return torch.from_numpy(
            img
        ).permute([2,0,1]), torch.from_numpy(label).float()

In [12]:
train_names, val_names = train_test_split(train_names)

In [13]:

def get_resnet152():    
    model = resnet152(pretrained=True)
    w = model.conv1.weight
    model.conv1 = nn.Conv2d(4,64,kernel_size=(7,7),stride=(2,2),padding=(3, 3), bias=False)
    model.conv1.weight = torch.nn.Parameter(torch.cat((w,torch.mean(w,dim=1).unsqueeze(1)),dim=1))

    model.avgpool = nn.Sequential(
        nn.MaxPool2d(kernel_size=6, stride=2,padding=0),
        nn.AvgPool2d(kernel_size=5, stride=2,padding=0)
    )
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 28))

    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load('resnet152_best.pth.tar'))
    return model

def get_se_resnet152():
    model = se_resnet152(num_classes=1000)
    w = model.conv1.weight
    model.conv1 = nn.Conv2d(4,64,kernel_size=(7,7),stride=(2,2),padding=(3, 3), bias=False)
    model.conv1.weight = torch.nn.Parameter(torch.cat((w,torch.mean(w,dim=1).unsqueeze(1)),dim=1))

    model.avgpool = nn.Sequential(
        nn.MaxPool2d(kernel_size=6, stride=2,padding=0),
        nn.AvgPool2d(kernel_size=5, stride=2,padding=0)
    )
    model.fc = nn.Linear(model.fc.in_features, 28)
    
    model = nn.DataParallel(model)
    #model.load_state_dict(torch.load('se_resnet152_best.pth.tar'))

    return model

def get_model(name):
    if name == 'resnet152':
        return get_resnet152()
    elif name == 'se_resnet152':
        return get_se_resnet152()
    else: raise Exception('not supported model')


MODEL_NAME='se_resnet152'
BATCH_SIZE=10
DEVICE=0
EPOCHS=100

model = get_model(MODEL_NAME)

### train stage

In [18]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma
        
    def forward(self, input, target):
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))

        max_val = (-input).clamp(min=0)
        loss = input - input * target + max_val + \
            ((-max_val).exp() + (-input - max_val).exp()).log()

        invprobs = F.logsigmoid(-input * (target * 2.0 - 1.0))
        loss = (invprobs * self.gamma).exp() * loss
        
        return loss.sum(dim=1).mean()


In [19]:
THRESHOLD=0.0

In [20]:
loss = FocalLoss()

In [21]:
def mymetric(pred, target):
    preds = (pred > THRESHOLD).int()
    targs = target.int()
    return (preds==targs).float().mean()

def myloss(pred, target):
    return loss(pred, target)



In [20]:
train_ds = ProteinDataset(train_names, TRAIN)
val_ds = ProteinDataset(val_names, TRAIN, val_aug)

In [21]:

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE,shuffle=True)
val_loader = torch.utils.data.DataLoader(val_ds,batch_size=BATCH_SIZE)


In [23]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
trainer = Trainer(myloss, mymetric, optimizer, MODEL_NAME, None, DEVICE)


2018-11-12 06:16:33,809 [Visdom-Socke] [INFO ]  Visdom successfully connected to server


In [24]:
trainer.output_watcher = None

In [25]:
ct = 0
for child in model.children():
    ct += 1
    if ct < 7:
        for param in child.parameters():
            param.requires_grad = False

In [56]:
model.conv1.weight.shape

torch.Size([64, 4, 7, 7])

In [63]:
torch.mean(model.conv1.weight,dim=1).shape

torch.Size([64, 7, 7])

In [64]:
torch.mean(model.conv1.weight,dim=1).unsqueeze(1).shape

torch.Size([64, 1, 7, 7])

In [70]:
w = model.conv1.weight

In [71]:
w.shape

torch.Size([64, 4, 7, 7])

Parameter containing:
tensor([[[[ 3.8840e-02,  1.4477e-01,  8.1565e-02,  ..., -8.3733e-02,
           -5.0097e-02, -9.8249e-02],
          [-5.0880e-02,  1.8462e-02, -1.0346e-02,  ..., -3.6859e-02,
           -1.1928e-01, -7.1972e-02],
          [-1.7249e-01,  6.2043e-02, -3.6353e-02,  ..., -5.0005e-02,
           -8.3304e-02,  3.0776e-02],
          ...,
          [-1.7829e-02, -5.5378e-03,  5.7718e-02,  ..., -7.3080e-03,
           -4.5801e-02, -7.6981e-02],
          [ 7.4269e-02,  2.7448e-02,  2.6845e-02,  ...,  2.1313e-02,
           -7.8128e-02,  3.0925e-02],
          [ 1.0649e-02,  1.5577e-02, -1.8829e-02,  ...,  5.4947e-02,
            3.3654e-02,  9.3044e-02]],

         [[-8.3563e-02,  9.4394e-02,  2.3049e-02,  ..., -1.1455e-01,
           -6.8636e-02, -1.2724e-01],
          [-8.5396e-02,  2.5232e-02, -9.5578e-02,  ...,  1.2756e-01,
           -4.5774e-02,  3.9743e-02],
          [-4.9588e-02,  1.0899e-01, -1.4501e-01,  ...,  1.6635e-01,
            1.3625e-01,  2.0005e-01]

In [26]:
model.to(DEVICE)

ResNet(
  (conv1): Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=F

In [27]:

for i in range(1):
    trainer.train(train_loader, model, i)
    trainer.validate(val_loader, model)

train loss:1.3131424188613892, train metric: tensor(0.9375, device='cuda:0'): : 236it [11:56,  3.04s/it]

KeyboardInterrupt: 

In [None]:

for child in model.children():
    for param in child.parameters():
        param.requires_grad = True

In [None]:

for i in range(1):
    trainer.train(train_loader, model, i)
    trainer.validate(val_loader, model)

### inference stage

In [14]:
from tqdm import *

In [15]:
model.to(DEVICE)

DataParallel(
  (module): ResNet(
    (conv1): Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): SEBottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (se): SELayer(
          (avg_pool): Ada

In [16]:
subm = pd.read_csv(SAMPLE)
test_ds = ProteinDataset(subm.Id.values, TEST, val_aug)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=10)

In [23]:
result,target_class = [],[]
for batch_idx, (input, target) in tqdm(enumerate(val_loader)):
    target_class.append(target)
    input = input.to(DEVICE)
    result.append(model(input).detach().cpu())
    


0it [00:00, ?it/s][A
1it [00:00,  1.01it/s][A
2it [00:01,  1.02it/s][A
3it [00:02,  1.00it/s][A
4it [00:03,  1.05it/s][A
5it [00:04,  1.00it/s][A
6it [00:05,  1.01s/it][A
7it [00:06,  1.01s/it][A
8it [00:07,  1.06it/s][A
9it [00:08,  1.07it/s][A
10it [00:09,  1.07it/s][A
11it [00:10,  1.01it/s][A
12it [00:12,  1.09s/it][A
13it [00:13,  1.04s/it][A
14it [00:13,  1.01it/s][A
15it [00:14,  1.01s/it][A
16it [00:15,  1.07it/s][A
17it [00:16,  1.02it/s][A
18it [00:17,  1.04it/s][A
19it [00:18,  1.12it/s][A
20it [00:19,  1.13it/s][A
21it [00:20,  1.13it/s][A
22it [00:21,  1.15it/s][A
23it [00:21,  1.12it/s][A
24it [00:23,  1.04it/s][A
25it [00:23,  1.06it/s][A
26it [00:24,  1.06it/s][A
27it [00:25,  1.06it/s][A
28it [00:26,  1.02it/s][A
29it [00:28,  1.03s/it][A
30it [00:28,  1.01it/s][A
31it [00:29,  1.07it/s][A
32it [00:30,  1.04it/s][A
33it [00:31,  1.02it/s][A
34it [00:32,  1.04it/s][A
35it [00:33,  1.05it/s][A
36it [00:34,  1.07it/s][A
37it [00:35,  

296it [04:38,  1.16it/s][A
297it [04:39,  1.14it/s][A
298it [04:40,  1.10it/s][A
299it [04:41,  1.10it/s][A
300it [04:42,  1.06it/s][A
301it [04:43,  1.06it/s][A
302it [04:44,  1.10it/s][A
303it [04:45,  1.06it/s][A
304it [04:46,  1.03it/s][A
305it [04:47,  1.06it/s][A
306it [04:48,  1.02it/s][A
307it [04:49,  1.03it/s][A
308it [04:50,  1.06it/s][A
309it [04:50,  1.08it/s][A
310it [04:51,  1.09it/s][A
311it [04:52,  1.00it/s][A
312it [04:54,  1.04s/it][A
313it [04:55,  1.01s/it][A
314it [04:56,  1.02s/it][A
315it [04:56,  1.06it/s][A
316it [04:57,  1.03it/s][A
317it [04:58,  1.04it/s][A
318it [04:59,  1.11it/s][A
319it [05:00,  1.13it/s][A
320it [05:01,  1.15it/s][A
321it [05:02,  1.18it/s][A
322it [05:02,  1.19it/s][A
323it [05:03,  1.10it/s][A
324it [05:05,  1.03it/s][A
325it [05:06,  1.03it/s][A
326it [05:07,  1.01s/it][A
327it [05:07,  1.10it/s][A
328it [05:08,  1.11it/s][A
329it [05:09,  1.07it/s][A
330it [05:10,  1.07it/s][A
331it [05:11,  1.07i

588it [09:07,  1.05it/s][A
589it [09:08,  1.03it/s][A
590it [09:08,  1.11it/s][A
591it [09:10,  1.04it/s][A
592it [09:10,  1.11it/s][A
593it [09:11,  1.06it/s][A
594it [09:12,  1.09it/s][A
595it [09:13,  1.14it/s][A
596it [09:14,  1.14it/s][A
597it [09:15,  1.05it/s][A
598it [09:16,  1.05it/s][A
599it [09:17,  1.11it/s][A
600it [09:18,  1.11it/s][A
601it [09:19,  1.11it/s][A
602it [09:19,  1.09it/s][A
603it [09:20,  1.10it/s][A
604it [09:21,  1.13it/s][A
605it [09:22,  1.16it/s][A
606it [09:23,  1.17it/s][A
607it [09:24,  1.23it/s][A
608it [09:24,  1.19it/s][A
609it [09:25,  1.18it/s][A
610it [09:26,  1.07it/s][A
611it [09:27,  1.07it/s][A
612it [09:28,  1.13it/s][A
613it [09:29,  1.13it/s][A
614it [09:30,  1.18it/s][A
615it [09:31,  1.18it/s][A
616it [09:32,  1.13it/s][A
617it [09:33,  1.12it/s][A
618it [09:33,  1.18it/s][A
619it [09:34,  1.13it/s][A
620it [09:35,  1.27it/s][A
621it [09:36,  1.26it/s][A
622it [09:37,  1.22it/s][A
623it [09:38,  1.08i

In [26]:
preds, ans = [],[]
for n,r in enumerate(result):
    for nu,t in enumerate(r):
        preds.append(torch.sigmoid(t).numpy())
        ans.append(target_class[n][nu].numpy())

In [28]:
from sklearn.metrics import f1_score as off1
rng = np.arange(0, 1, 0.01)
f1s = np.zeros((rng.shape[0], 28))
for j,t in enumerate(tqdm(rng)):
    for i in range(28):
        p = np.array(np.asarray(preds)[:,i]>t, dtype=np.int8)
        scoref1 = off1(np.asarray(ans)[:,i], p, average='binary')
        f1s[j,i] = scoref1



  'recall', 'true', average, warn_for)


  1%|          | 1/100 [00:00<00:28,  3.48it/s][A[A

  2%|▏         | 2/100 [00:00<00:28,  3.45it/s][A[A

  3%|▎         | 3/100 [00:00<00:27,  3.56it/s][A[A

  4%|▍         | 4/100 [00:01<00:25,  3.78it/s][A[A

  5%|▌         | 5/100 [00:01<00:23,  4.11it/s][A[A

  6%|▌         | 6/100 [00:01<00:21,  4.40it/s][A[A

  7%|▋         | 7/100 [00:01<00:20,  4.64it/s][A[A

  8%|▊         | 8/100 [00:01<00:19,  4.78it/s][A[A

  9%|▉         | 9/100 [00:02<00:18,  4.92it/s][A[A

 10%|█         | 10/100 [00:02<00:17,  5.03it/s][A[A

 11%|█         | 11/100 [00:02<00:17,  5.11it/s][A[A

 12%|█▏        | 12/100 [00:02<00:17,  5.15it/s][A[A

 13%|█▎        | 13/100 [00:02<00:16,  5.20it/s][A[A

 14%|█▍        | 14/100 [00:02<00:16,  5.23it/s][A[A

 15%|█▌        | 15/100 [00:03<00:16,  5.25it/s][A[A

 16%|█▌        | 16/100 [00:03<00:15,  5.26it/s][A[A

 17%|█▋        | 17/100 [00:03<00:15,  5.27it/s][A[A

 18%|█▊       

In [101]:
import matplotlib.pyplot as plt

2018-12-05 14:42:23,260 [MainThread  ] [DEBUG]  CACHEDIR=/root/.cache/matplotlib
2018-12-05 14:42:23,262 [MainThread  ] [INFO ]  font search path ['/opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']
2018-12-05 14:42:23,537 [MainThread  ] [DEBUG]  trying fontname /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneral.ttf
2018-12-05 14:42:23,539 [MainThread  ] [DEBUG]  trying fontname /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerifDisplay.ttf
2018-12-05 14:42:23,540 [MainThread  ] [DEBUG]  trying fontname /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneralItalic.ttf
2018-12-05 14:42:23,542 [MainThread  ] [DEBUG]  trying fontname /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif-BoldItalic.ttf
2018-12-

2018-12-05 14:42:23,608 [MainThread  ] [DEBUG]  createFontDict: /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono-BoldOblique.ttf
2018-12-05 14:42:23,610 [MainThread  ] [DEBUG]  createFontDict: /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/STIXSizFourSymReg.ttf
2018-12-05 14:42:23,611 [MainThread  ] [DEBUG]  createFontDict: /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/STIXSizThreeSymReg.ttf
2018-12-05 14:42:23,613 [MainThread  ] [DEBUG]  createFontDict: /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/STIXSizOneSymReg.ttf
2018-12-05 14:42:23,615 [MainThread  ] [DEBUG]  createFontDict: /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-Oblique.ttf
2018-12-05 14:42:23,617 [MainThread  ] [DEBUG]  createFontDict: /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/STIXSizTwoSymReg.ttf
2018-12-05 14:42:23,619 [MainThread  ] [DEBUG]  createFontDic

2018-12-05 14:42:23,810 [MainThread  ] [DEBUG]  createFontDict: /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/afm/phvr8an.afm
2018-12-05 14:42:23,816 [MainThread  ] [DEBUG]  createFontDict: /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/afm/putri8a.afm
2018-12-05 14:42:23,821 [MainThread  ] [DEBUG]  createFontDict: /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/afm/phvro8an.afm
2018-12-05 14:42:23,827 [MainThread  ] [DEBUG]  createFontDict: /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts/Times-BoldItalic.afm
2018-12-05 14:42:23,837 [MainThread  ] [DEBUG]  createFontDict: /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/afm/pplr8a.afm
2018-12-05 14:42:23,841 [MainThread  ] [DEBUG]  createFontDict: /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/afm/pncb8a.afm
2018-12-05 14:42:23,846 [MainThread  ] [DEBUG]  createFontDict: /opt/conda/lib/python3.6/site-packages/matplotlib

In [118]:
thresh = np.array([np.max(f1s[:,i]) for i in range(28)])

In [130]:
def F1_soft(preds,targs,th=0.5,d=50.0):
    preds = d*(preds - th)
    targs = targs.astype(np.float)
    score = 2.0*(preds*targs).sum(axis=0)/((preds+targs).sum(axis=0) + 1e-6)
    return score

In [131]:
def fit_val(x,y):
    params = 0.5*np.ones(len(name_label_dict))
    wd = 1e-5
    error = lambda p: np.concatenate((F1_soft(x,y,p) - 1.0,
                                      wd*(p - 0.5)), axis=None)
    p, success = opt.leastsq(error, params)
    return p

In [121]:
print('Individual F1-scores for each class:')
print(thresh)
print('Macro F1-score CV =', np.mean(np.max(f1s, axis=0)))

Individual F1-scores for each class:
[0.58364482 0.11037344 0.22123785 0.0918904  0.12696042 0.1487791
 0.07876231 0.16317348 0.00378843 0.00261746 0.00702988 0.07583643
 0.05728728 0.05785124 0.11614731 0.00169827 0.05032823 0.02766355
 0.05351759 0.08865202 0.01304225 0.21375978 0.05001761 0.17606944
 0.02487797 0.41337697 0.02474281 0.        ]
Macro F1-score CV = 0.1065402254837565


In [None]:
[0.54841 0.59038 0.58134 0.55174 0.58924 0.61733 0.50403 0.57952 0.50263 0.44694 0.37854 0.60698 0.58683
 0.57907 0.5511  0.50602,0.48044,0.48919,0.51825,0.5246,0.40525,0.50145,0.51568,0.57143,0.68486,0.51439,0.49912,0.50548]

In [60]:
result = []
for batch_idx, (input, target) in tqdm(enumerate(test_loader)):
    input = input.to(DEVICE)
    result.append(model(input).detach().cpu())

1171it [17:20,  1.34it/s]


In [63]:
preds = []
for r in result:
    for t in r:
        preds.append(torch.sigmoid(t).numpy())

In [90]:
THRESHOLD=0.95

In [91]:
thresh[thresh < THRESHOLD] = THRESHOLD

In [122]:
for i in tqdm_notebook(range(subm.Id.size)):
    subm.iloc[i,1] = ' '.join(np.where(preds[i] > thresh)[0].astype(str))

HBox(children=(IntProgress(value=0, max=11702), HTML(value='')))

In [123]:
np.array([len(k) for k in subm.Predicted.str.split()]).mean()

26.15262348316527

In [95]:
subm

Unnamed: 0,Id,Predicted
0,00008af0-bad0-11e8-b2b8-ac1f6b6435d0,2 15 19 26
1,0000a892-bacf-11e8-b2b8-ac1f6b6435d0,
2,0006faa6-bac7-11e8-b2b7-ac1f6b6435d0,2 4 14 15 19 26
3,0008baca-bad7-11e8-b2b9-ac1f6b6435d0,
4,000cce7e-bad4-11e8-b2b8-ac1f6b6435d0,
5,00109f6a-bac8-11e8-b2b7-ac1f6b6435d0,
6,001765de-bacd-11e8-b2b8-ac1f6b6435d0,2 4 13 14 15 16 19 26
7,0018641a-bac9-11e8-b2b8-ac1f6b6435d0,
8,00200f22-bad7-11e8-b2b9-ac1f6b6435d0,
9,0026f154-bac6-11e8-b2b7-ac1f6b6435d0,


In [94]:
subm.to_csv('my_subm.csv',index=False)