In [1]:
import numpy as np, pandas as pd
import os
import time

In [2]:
import resource
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))

In [3]:
import torch
import torchvision.models as models
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter

In [4]:
import random
seed = 34
random.seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed) 
torch.cuda.manual_seed_all(seed) 
torch.backends.cudnn.deterministic=True

In [5]:
from helpers import Imagefolder_multilabel_inference as myImagefolder_val

In [6]:
def bnwd_optim_params(model, model_params, master_params):
    bn_params, remaining_params = split_bn_params(model, model_params, master_params)
    return [{'params':bn_params,'weight_decay':0}, {'params':remaining_params}]

def split_bn_params(model, model_params, master_params):
    def get_bn_params(module):
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): return module.parameters()
        accum = set()
        for child in module.children(): [accum.add(p) for p in get_bn_params(child)]
        return accum
    
    mod_bn_params = get_bn_params(model)
    zipped_params = list(zip(model_params, master_params))

    mas_bn_params = [p_mast for p_mod,p_mast in zipped_params if p_mod in mod_bn_params]
    mas_rem_params = [p_mast for p_mod,p_mast in zipped_params if p_mod not in mod_bn_params]
    return mas_bn_params, mas_rem_params

def save_checkpoint(state, is_best, filename):
    if is_best:
        print("-> Saving a new best ...")
        torch.save(state, filename)
    else:
        print("-> Validation accuracy did not improve ...")
        
def load_checkpoint(load_path, model, optimizer=None, warmup=False):
    if os.path.isfile(load_path):
        print("-> Loading checkpoint '{}'".format(load_path))
        checkpoint = torch.load(load_path)
        epoch = checkpoint['epoch'] if not warmup else -1
        acc_valid = checkpoint['acc_valid']
        acc_train = checkpoint['acc_train']
        loss_valid = checkpoint['loss_valid']
        loss_train = checkpoint['loss_train']
        state_dict = checkpoint['state_dict']
        itrn_chkpt = checkpoint['step'] if not warmup else 0
        
#         model.load_state_dict(state_dict)
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] # remove 'module.' of dataparallel
            new_state_dict[name]=v
        model.load_state_dict(new_state_dict)
            
        print("-> Loaded checkpoint at epoch {} step {} ".format(epoch, itrn_chkpt))
        if warmup:
            return epoch, acc_valid, acc_train, loss_train, loss_valid, itrn_chkpt
        else:
            if optimizer != None:
                optimizer.load_state_dict(checkpoint['optimizer'])
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()
            return epoch-1, acc_valid, acc_train, loss_train, loss_valid, itrn_chkpt
    else:
        print("-> No checkpoint found at '{}'".format(load_path))
        return None

def labels_loop(list_class, label):
    for ilabel in label:
        for j in ilabel:
            list_class[j] += 1
    return list_class
    
def idx_to_desc(idx):
    return imgset_train.class_to_desc[imgset_train.idx_to_class[idx]]

class SoftF2Loss(torch.nn.Module):

    def __init__(self):
        super(SoftF2Loss,self).__init__()
        
    def forward(self, logits, labels):
        __small_value=1e-6
        beta = 2
        batch_size = logits.size()[0]
        p = torch.nn.functional.sigmoid(logits)
        l = labels
        num_pos = torch.sum(p, 1) + __small_value
        num_pos_hat = torch.sum(l, 1) + __small_value
        tp = torch.sum(l * p, 1)
        precise = tp / num_pos
        recall = tp / num_pos_hat
        fs = (1 + beta * beta) * precise * recall / (beta * beta * precise + recall + __small_value)
        loss = fs.sum() / batch_size
        return (1 - loss)
    
def add_to_result(result, valid, img, path, pred, labels_val):
    batch_size = pred.shape[0]
    for ii in range(batch_size):
        # result is dict, keys are classes, values are probs, pred is probs
        img.append(path[ii].split('/')[-1][:-4])
        for kk, p in enumerate(pred[ii]):
            result[kk].append(float(p)) # result is probs
        if labels_val is not None:
            for kk, l in enumerate(labels_val[ii]):
                valid[kk].append(int(l))

In [7]:
image_resize = (224,224)
batch_size_valid = 32
num_workers_valid = 12

chkpt = True
warmup = False
chkpt_file = 'accval-0.1413_lossval-56.2363_epoch-9_step-10440_checkpoint.pth'

In [8]:
img_transform_valid = transforms.Compose([transforms.Resize(image_resize),transforms.ToTensor()])

In [9]:
imgset_valid = myImagefolder_val.DatasetFolder(root='input/train_images/train_all/', label_file='input/label_hmn_mch_valid.csv', desc_file='input/label_hmn_mch_desc.csv', transform=img_transform_valid)
loader_valid = torch.utils.data.DataLoader(imgset_valid, batch_size=batch_size_valid, num_workers=num_workers_valid, shuffle=False)
n_class = len(imgset_valid.classes)

In [10]:
from senet import se_resnext101_32x4d
device = torch.device('cuda')
model = se_resnext101_32x4d(pretrained=None, num_classes=553 if (chkpt&warmup) else n_class, bn0=True)
model.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
# model = nn.DataParallel(model)

In [11]:
class_monitor = ['/m/01g317', '/m/09j2d', '/m/0dzct', '/m/07j7r', '/m/05s2s'] # 5 max
# criterion = nn.BCEWithLogitsLoss().to(device)
criterion = SoftF2Loss().to(device)

In [12]:
if chkpt:
    epoch, acc_valid, acc_train, loss_valid, loss_train, itrn_chkpt = load_checkpoint('chkpt/'+chkpt_file, 
                                                   model, None, warmup=warmup)
    if warmup:
        model.module.last_linear = nn.Linear(model.module.last_linear.in_features, n_class).to(device)
    print(epoch, acc_valid, acc_train, loss_valid, loss_train, itrn_chkpt)

-> Loading checkpoint 'chkpt/accval-0.1413_lossval-56.2363_epoch-9_step-10440_checkpoint.pth'
-> Loaded checkpoint at epoch 8 step 10440 
7 0.14126008805833443 0.1410138576230247 tensor(119.0210, device='cuda:0', requires_grad=True) tensor(56.2363, device='cuda:0') 10440


In [13]:
model = model.to(device)
writer = SummaryWriter()

# Train the model
total_step = len(loader_valid)
# Evaluate validation
with torch.no_grad():
    cutoffs = torch.ones(n_class)*0.05
    correct_val = 0
    total_val = 0
    loss_valid = 0
    icnt_val = 0
    correct_class_val = torch.zeros(n_class)
    correct_class_prob_val = torch.zeros(n_class)
    total_class_val = torch.zeros(n_class)
    result = {}
    valid = {}
    img = []
    for ii in range(n_class):
        result[ii] = []
        valid[ii] = []

    model.eval()
    t77 = time.time()
    for ival, (images_val, labels_val, path) in enumerate(loader_valid):
        print("Val_Step:{}".format(ival))
        t7 = time.time()
        print("Loader:{}".format(t7-t77))
        images_val = images_val.to(device)
        labels_val = labels_val.to(device)
        outputs_val = model(images_val)
        probs_val = torch.nn.Sigmoid()(outputs_val).cpu()
        iloss_val = criterion(outputs_val, labels_val.float())
        if ~torch.isnan(iloss_val):
            icnt_val += 1
            loss_valid += iloss_val
        labels_val = labels_val.cpu()
        predicted_val = [ torch.topk(outputs_val.data[ii], np.where(labels_val[ii])[0].shape[0])[1].cpu().sort()[0].numpy() for ii in range(labels_val.shape[0]) ]
        total_val += int(labels_val.sum())
        correct_val += np.sum([ (predicted_val[ii] == np.sort(np.where(labels_val[ii])[0])).sum() for ii in range(labels_val.shape[0]) ])
        total_class_val += labels_val.sum(0).cpu().float()
        correct_class_val = labels_loop(correct_class_val, [ np.where(labels_val[ii])[0][predicted_val[ii] == np.where(labels_val[ii])[0]] for ii in range(labels_val.shape[0]) ])
        correct_class_prob_val = np.sum( np.array([ np.array(labels_val[ii].float()*probs_val[ii]) for ii in range(labels_val.shape[0]) ]), axis=0 )
        t8 = time.time()
        t77 = t8
        print("Inference:{}".format(t8-t7))
        
        pred = probs_val.data

        add_to_result(result, valid, img, path, pred, labels_val)
        writer.add_scalar('loss/valid_per_batch', iloss_val.item(), ival)

    acc_valid = correct_val / total_val
    acc_class_valid = correct_class_val / total_class_val
    acc_class_valid[torch.isnan(acc_class_valid)] = 0.
    avg_class_prob_valid = correct_class_prob_val / correct_val
    avg_class_prob_valid[np.isnan(avg_class_prob_valid)] = 0.
    
    writer.add_scalar('loss/valid', loss_valid/icnt_val, 1 )
    writer.add_scalars('accuracy', { 'val_avg': acc_valid,
                                     'val_max0'+imgset_valid.class_to_desc[class_monitor[0]]: acc_class_valid[imgset_valid.class_to_idx[class_monitor[0]]], 
                                     'val_max1'+imgset_valid.class_to_desc[class_monitor[1]]: acc_class_valid[imgset_valid.class_to_idx[class_monitor[1]]], 
                                     'val_max2'+imgset_valid.class_to_desc[class_monitor[2]]: acc_class_valid[imgset_valid.class_to_idx[class_monitor[2]]], 
                                     'val_max3'+imgset_valid.class_to_desc[class_monitor[3]]: acc_class_valid[imgset_valid.class_to_idx[class_monitor[3]]], 
                                     'val_max4'+imgset_valid.class_to_desc[class_monitor[4]]: acc_class_valid[imgset_valid.class_to_idx[class_monitor[4]]] 
                                    },
                                     1 )

writer.close()

Val_Step:0
Loader:1.7382323741912842




Inference:0.2518174648284912
Val_Step:1
Loader:2.5158321857452393
Inference:0.13790082931518555
Val_Step:2
Loader:2.217898368835449
Inference:0.13523340225219727
Val_Step:3
Loader:2.228196620941162
Inference:0.12877345085144043
Val_Step:4
Loader:2.2754130363464355
Inference:0.13294148445129395
Val_Step:5
Loader:2.2647502422332764
Inference:0.13446879386901855
Val_Step:6
Loader:2.266200065612793
Inference:0.13393688201904297
Val_Step:7
Loader:2.3155713081359863
Inference:0.13216710090637207
Val_Step:8
Loader:2.338313579559326
Inference:0.13028717041015625
Val_Step:9
Loader:2.2750110626220703
Inference:0.13248085975646973
Val_Step:10
Loader:2.2235140800476074
Inference:0.14042139053344727
Val_Step:11
Loader:2.279700517654419
Inference:0.13158440589904785
Val_Step:12
Loader:2.3318958282470703
Inference:0.13856267929077148
Val_Step:13
Loader:2.3270578384399414
Inference:0.13584136962890625
Val_Step:14
Loader:2.2776334285736084
Inference:0.12972664833068848
Val_Step:15
Loader:2.327033281326

Val_Step:123
Loader:2.2444775104522705
Inference:0.1276555061340332
Val_Step:124
Loader:2.3295400142669678
Inference:0.1278529167175293
Val_Step:125
Loader:2.30521297454834
Inference:0.12796616554260254
Val_Step:126
Loader:2.3302507400512695
Inference:0.12740707397460938
Val_Step:127
Loader:2.372565984725952
Inference:0.1277458667755127
Val_Step:128
Loader:2.2641353607177734
Inference:0.12756943702697754
Val_Step:129
Loader:2.337062358856201
Inference:0.12817955017089844
Val_Step:130
Loader:2.348383903503418
Inference:0.1398906707763672
Val_Step:131
Loader:2.3334200382232666
Inference:0.12828445434570312
Val_Step:132
Loader:2.299520969390869
Inference:0.1282496452331543
Val_Step:133
Loader:2.4689226150512695
Inference:0.12945771217346191
Val_Step:134
Loader:2.2428150177001953
Inference:0.12850069999694824
Val_Step:135
Loader:2.3433690071105957
Inference:0.12762999534606934
Val_Step:136
Loader:2.317922830581665
Inference:0.1411445140838623
Val_Step:137
Loader:2.348388433456421
Inference

Val_Step:244
Loader:2.397791862487793
Inference:0.12806057929992676
Val_Step:245
Loader:2.3733677864074707
Inference:0.12889838218688965
Val_Step:246
Loader:2.317652463912964
Inference:0.12892651557922363
Val_Step:247
Loader:2.3856070041656494
Inference:0.12845945358276367
Val_Step:248
Loader:2.4097516536712646
Inference:0.12889504432678223
Val_Step:249
Loader:2.3627374172210693
Inference:0.1296093463897705
Val_Step:250
Loader:2.316096305847168
Inference:0.12725567817687988
Val_Step:251
Loader:2.3450238704681396
Inference:0.12951135635375977
Val_Step:252
Loader:2.3501689434051514
Inference:0.12587785720825195
Val_Step:253
Loader:2.245908498764038
Inference:0.12634944915771484
Val_Step:254
Loader:2.336360454559326
Inference:0.12917351722717285
Val_Step:255
Loader:2.354452610015869
Inference:0.1285536289215088
Val_Step:256
Loader:2.3590362071990967
Inference:0.1257617473602295
Val_Step:257
Loader:2.2936625480651855
Inference:0.1257493495941162
Val_Step:258
Loader:2.218602180480957
Infere

Val_Step:365
Loader:2.342832088470459
Inference:0.1289536952972412
Val_Step:366
Loader:2.3452041149139404
Inference:0.12710309028625488
Val_Step:367
Loader:2.237861394882202
Inference:0.12840485572814941
Val_Step:368
Loader:2.362537145614624
Inference:0.12865734100341797
Val_Step:369
Loader:2.3611958026885986
Inference:0.12833046913146973
Val_Step:370
Loader:2.3271796703338623
Inference:0.12923598289489746
Val_Step:371
Loader:2.3656084537506104
Inference:0.12671971321105957
Val_Step:372
Loader:2.3272831439971924
Inference:0.12866640090942383
Val_Step:373
Loader:2.3939576148986816
Inference:0.12840485572814941
Val_Step:374
Loader:2.3523988723754883
Inference:0.1286771297454834
Val_Step:375
Loader:2.4237756729125977
Inference:0.1286172866821289
Val_Step:376
Loader:2.351949691772461
Inference:0.1415095329284668
Val_Step:377
Loader:2.2749440670013428
Inference:0.1289505958557129
Val_Step:378
Loader:2.3960132598876953
Inference:0.12864375114440918
Val_Step:379
Loader:2.318366289138794
Infer

Val_Step:486
Loader:2.370054244995117
Inference:0.12909364700317383
Val_Step:487
Loader:2.359950542449951
Inference:0.12997865676879883
Val_Step:488
Loader:2.802483081817627
Inference:0.1301898956298828
Val_Step:489
Loader:2.2535362243652344
Inference:0.12697577476501465
Val_Step:490
Loader:2.258793592453003
Inference:0.12936162948608398
Val_Step:491
Loader:2.3240275382995605
Inference:0.12818622589111328
Val_Step:492
Loader:2.368767023086548
Inference:0.1289205551147461
Val_Step:493
Loader:2.387807607650757
Inference:0.12872815132141113
Val_Step:494
Loader:2.3193511962890625
Inference:0.12801313400268555
Val_Step:495
Loader:2.341111183166504
Inference:0.12858223915100098
Val_Step:496
Loader:2.2934255599975586
Inference:0.12667226791381836
Val_Step:497
Loader:2.192715644836426
Inference:0.12711501121520996
Val_Step:498
Loader:2.403306722640991
Inference:0.1289231777191162
Val_Step:499
Loader:2.3406596183776855
Inference:0.12642645835876465
Val_Step:500
Loader:2.2085464000701904
Inferen

Val_Step:607
Loader:2.3103859424591064
Inference:0.12864470481872559
Val_Step:608
Loader:2.39681339263916
Inference:0.12929391860961914
Val_Step:609
Loader:2.2746059894561768
Inference:0.1289222240447998
Val_Step:610
Loader:2.317007303237915
Inference:0.12838220596313477
Val_Step:611
Loader:2.2826333045959473
Inference:0.14212703704833984
Val_Step:612
Loader:2.408947467803955
Inference:0.1282637119293213
Val_Step:613
Loader:2.3557167053222656
Inference:0.12714695930480957
Val_Step:614
Loader:2.171302556991577
Inference:0.12662649154663086
Val_Step:615
Loader:2.3424978256225586
Inference:0.12902450561523438
Val_Step:616
Loader:2.4209938049316406
Inference:0.12909340858459473
Val_Step:617
Loader:2.3243772983551025
Inference:0.12894630432128906
Val_Step:618
Loader:2.8297464847564697
Inference:0.1414196491241455
Val_Step:619
Loader:2.3045475482940674
Inference:0.14048027992248535
Val_Step:620
Loader:2.3683154582977295
Inference:0.13195562362670898
Val_Step:621
Loader:2.3470935821533203
Inf

Val_Step:728
Loader:2.4077963829040527
Inference:0.1270465850830078
Val_Step:729
Loader:2.3096604347229004
Inference:0.12781643867492676
Val_Step:730
Loader:2.3880667686462402
Inference:0.12949800491333008
Val_Step:731
Loader:2.386444091796875
Inference:0.1267104148864746
Val_Step:732
Loader:2.3130273818969727
Inference:0.12848424911499023
Val_Step:733
Loader:2.3904366493225098
Inference:0.12804555892944336
Val_Step:734
Loader:2.376763105392456
Inference:0.12919068336486816
Val_Step:735
Loader:2.4068026542663574
Inference:0.12713027000427246
Val_Step:736
Loader:2.320298194885254
Inference:0.12992572784423828
Val_Step:737
Loader:2.396390199661255
Inference:0.12944793701171875
Val_Step:738
Loader:2.40722393989563
Inference:0.19600820541381836
Val_Step:739
Loader:2.4025790691375732
Inference:0.12695074081420898
Val_Step:740
Loader:2.3357183933258057
Inference:0.12783193588256836
Val_Step:741
Loader:2.4537699222564697
Inference:0.1270456314086914
Val_Step:742
Loader:2.3437843322753906
Infe

Val_Step:849
Loader:2.371845245361328
Inference:0.12926888465881348
Val_Step:850
Loader:2.339340925216675
Inference:0.12917685508728027
Val_Step:851
Loader:2.354607105255127
Inference:0.12904858589172363
Val_Step:852
Loader:2.4040579795837402
Inference:0.12832427024841309
Val_Step:853
Loader:2.40437388420105
Inference:0.12935209274291992
Val_Step:854
Loader:2.3756632804870605
Inference:0.1283247470855713
Val_Step:855
Loader:2.365267038345337
Inference:0.1286623477935791
Val_Step:856
Loader:2.4119465351104736
Inference:0.1282360553741455
Val_Step:857
Loader:2.3932852745056152
Inference:0.1279911994934082
Val_Step:858
Loader:2.3719239234924316
Inference:0.12883973121643066
Val_Step:859
Loader:2.4240589141845703
Inference:0.12824273109436035
Val_Step:860
Loader:2.3525784015655518
Inference:0.1289973258972168
Val_Step:861
Loader:2.3646936416625977
Inference:0.1293025016784668
Val_Step:862
Loader:2.4071826934814453
Inference:0.1285228729248047
Val_Step:863
Loader:2.3881232738494873
Inferenc

Val_Step:970
Loader:2.3767073154449463
Inference:0.12906360626220703
Val_Step:971
Loader:2.406345844268799
Inference:0.12748932838439941
Val_Step:972
Loader:2.314669370651245
Inference:0.12668776512145996
Val_Step:973
Loader:2.4171152114868164
Inference:0.12823963165283203
Val_Step:974
Loader:2.417536973953247
Inference:0.12795448303222656
Val_Step:975
Loader:2.4021260738372803
Inference:0.12795376777648926
Val_Step:976
Loader:2.4134654998779297
Inference:0.12845659255981445
Val_Step:977
Loader:2.3592681884765625
Inference:0.1265115737915039
Val_Step:978
Loader:2.374971389770508
Inference:0.26619553565979004
Val_Step:979
Loader:2.4072978496551514
Inference:0.12821078300476074
Val_Step:980
Loader:2.4109954833984375
Inference:0.15511488914489746
Val_Step:981
Loader:2.3825738430023193
Inference:0.12885522842407227
Val_Step:982
Loader:2.3444957733154297
Inference:0.1283423900604248
Val_Step:983
Loader:2.3638899326324463
Inference:0.1280515193939209
Val_Step:984
Loader:2.406264305114746
Inf

Val_Step:1089
Loader:2.3377318382263184
Inference:0.13125133514404297
Val_Step:1090
Loader:2.371525526046753
Inference:0.35184621810913086
Val_Step:1091
Loader:2.4031240940093994
Inference:0.127899169921875
Val_Step:1092
Loader:2.3323018550872803
Inference:0.1450953483581543
Val_Step:1093
Loader:2.349452257156372
Inference:0.12859296798706055
Val_Step:1094
Loader:2.448770761489868
Inference:0.1275780200958252
Val_Step:1095
Loader:2.3949577808380127
Inference:0.12672996520996094
Val_Step:1096
Loader:2.354865550994873
Inference:0.12799930572509766
Val_Step:1097
Loader:2.333188056945801
Inference:0.12795019149780273
Val_Step:1098
Loader:2.3141071796417236
Inference:0.12736177444458008
Val_Step:1099
Loader:2.344158887863159
Inference:0.12684178352355957
Val_Step:1100
Loader:2.364964485168457
Inference:0.30222034454345703
Val_Step:1101
Loader:2.3246119022369385
Inference:0.12810134887695312
Val_Step:1102
Loader:2.3620200157165527
Inference:0.12799572944641113
Val_Step:1103
Loader:2.41400074

Val_Step:1208
Loader:2.345193862915039
Inference:0.12873244285583496
Val_Step:1209
Loader:2.3352017402648926
Inference:0.1374659538269043
Val_Step:1210
Loader:2.345461845397949
Inference:0.13091278076171875
Val_Step:1211
Loader:2.3683576583862305
Inference:0.12975668907165527
Val_Step:1212
Loader:2.338261127471924
Inference:0.1277916431427002
Val_Step:1213
Loader:2.3264358043670654
Inference:0.12766027450561523
Val_Step:1214
Loader:2.3606209754943848
Inference:0.12907028198242188
Val_Step:1215
Loader:2.3464865684509277
Inference:0.12832188606262207
Val_Step:1216
Loader:2.383838653564453
Inference:0.129197359085083
Val_Step:1217
Loader:2.3678629398345947
Inference:0.12831974029541016
Val_Step:1218
Loader:2.266996383666992
Inference:0.12688207626342773
Val_Step:1219
Loader:2.3508076667785645
Inference:0.1268925666809082
Val_Step:1220
Loader:2.384944438934326
Inference:0.12755107879638672
Val_Step:1221
Loader:2.287734270095825
Inference:0.1269397735595703
Val_Step:1222
Loader:2.3374152183

Val_Step:1327
Loader:2.3342947959899902
Inference:0.12666606903076172
Val_Step:1328
Loader:2.2202394008636475
Inference:0.12711334228515625
Val_Step:1329
Loader:2.289886236190796
Inference:0.1269848346710205
Val_Step:1330
Loader:2.286111831665039
Inference:0.12805724143981934
Val_Step:1331
Loader:2.3295090198516846
Inference:0.12783408164978027
Val_Step:1332
Loader:2.3490142822265625
Inference:0.1280503273010254
Val_Step:1333
Loader:2.35601806640625
Inference:0.12860846519470215
Val_Step:1334
Loader:2.3700032234191895
Inference:0.12747788429260254
Val_Step:1335
Loader:2.2924561500549316
Inference:0.12871885299682617
Val_Step:1336
Loader:2.389253854751587
Inference:0.127913236618042
Val_Step:1337
Loader:2.2533037662506104
Inference:0.12757658958435059
Val_Step:1338
Loader:2.321063756942749
Inference:0.1271677017211914
Val_Step:1339
Loader:2.232940435409546
Inference:0.12662076950073242
Val_Step:1340
Loader:2.312784433364868
Inference:0.12783026695251465
Val_Step:1341
Loader:2.3389124870

Val_Step:1446
Loader:2.3023340702056885
Inference:0.12854266166687012
Val_Step:1447
Loader:2.3744056224823
Inference:0.12800812721252441
Val_Step:1448
Loader:2.3654816150665283
Inference:0.1285252571105957
Val_Step:1449
Loader:2.3407866954803467
Inference:0.12801218032836914
Val_Step:1450
Loader:2.297839879989624
Inference:0.12794828414916992
Val_Step:1451
Loader:2.3548543453216553
Inference:0.12795805931091309
Val_Step:1452
Loader:2.327221632003784
Inference:0.1277632713317871
Val_Step:1453
Loader:2.342017412185669
Inference:0.12803030014038086
Val_Step:1454
Loader:2.3863558769226074
Inference:0.12834620475769043
Val_Step:1455
Loader:2.3832526206970215
Inference:0.1278524398803711
Val_Step:1456
Loader:2.3991072177886963
Inference:0.12828969955444336
Val_Step:1457
Loader:2.3587124347686768
Inference:0.12854504585266113
Val_Step:1458
Loader:2.3691840171813965
Inference:0.1273813247680664
Val_Step:1459
Loader:2.3314085006713867
Inference:0.12767434120178223
Val_Step:1460
Loader:2.4154508

In [14]:
df_prob_val = pd.DataFrame(result)
df_label_val = pd.DataFrame(valid)

In [15]:
# df_prob_val.to_csv('val_probs.csv', index=False)
# df_label_val.to_csv('val_label.csv', index=False)

In [16]:
df_prob_val.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,7163,7164,7165,7166,7167,7168,7169,7170,7171,7172
0,9e-06,9e-06,1e-05,9e-06,9e-06,9e-06,9e-06,9e-06,9e-06,1e-05,...,9e-06,1.2e-05,9e-06,9e-06,9e-06,1e-05,9e-06,9e-06,9e-06,9e-06
1,1.6e-05,1.6e-05,1.6e-05,1.6e-05,1.6e-05,1.6e-05,1.6e-05,1.6e-05,1.6e-05,1.7e-05,...,1.6e-05,2.1e-05,1.6e-05,1.6e-05,1.6e-05,1.9e-05,1.6e-05,1.6e-05,1.6e-05,1.7e-05
2,1.7e-05,1.7e-05,1.8e-05,1.7e-05,1.7e-05,1.7e-05,1.7e-05,1.7e-05,1.7e-05,1.8e-05,...,1.7e-05,2.1e-05,1.7e-05,1.8e-05,1.7e-05,1.8e-05,1.7e-05,1.7e-05,1.7e-05,1.8e-05
3,7e-06,7e-06,7e-06,7e-06,7e-06,7e-06,7e-06,7e-06,7e-06,7e-06,...,7e-06,1e-05,7e-06,7e-06,7e-06,7e-06,7e-06,7e-06,7e-06,8e-06
4,6e-06,6e-06,6e-06,6e-06,6e-06,6e-06,6e-06,6e-06,6e-06,6e-06,...,6e-06,1e-05,6e-06,6e-06,6e-06,6e-06,6e-06,6e-06,6e-06,6e-06


In [17]:
df_bool_val = df_prob_val >=0.1
(df_bool_val.sum()!=0).sum()

148

In [18]:
from sklearn.metrics import fbeta_score

def fbeta(true_label, prediction):
   return fbeta_score(true_label, prediction, beta=2, average='samples')
   
def get_optimal_threshhold(true_label, prediction, iterations = 100):
    best_threshhold = [0.2]*n_class    
    for t in range(n_class):
        best_fbeta = 0
        temp_threshhold = [0.2]*n_class
        for i in range(iterations):
            temp_value = i / float(iterations)
            temp_threshhold[t] = temp_value
            temp_fbeta = fbeta(true_label, prediction >= temp_threshhold)
            if  temp_fbeta >= best_fbeta:
                best_fbeta = temp_fbeta
                best_threshhold[t] = temp_value
    return best_threshhold

In [19]:
# t1 = time.time()
# cutoffs = get_optimal_threshhold(df_label.loc[:2].values, df_prob.loc[:2].values, iterations=10)
# print(time.time()-t1)

In [20]:
t1 = time.time()
cutoffs = [0.85]*n_class
f2 = fbeta(df_label_val.values, df_prob_val.values >= cutoffs)
print("Time taken: {}".format(time.time()-t1))
print("F2 score: {}".format(f2))

Time taken: 41.91221356391907
F2 score: 0.4277894307631372


In [21]:
[[0.1, 0.3733349231333955],
 [0.2, 0.3791458752884131],
 [0.3, 0.3815689878019582],
 [0.4, 0.38292944161524584],
 [0.5, 0.38353335667175453],
 [0.6, 0.3836194407872708],
 [0.7, 0.38291952404157054],
 [0.8, 0.3811532538022081],
 [0.9, 0.3763843884836092]
]

[[0.1, 0.3733349231333955],
 [0.2, 0.3791458752884131],
 [0.3, 0.3815689878019582],
 [0.4, 0.38292944161524584],
 [0.5, 0.38353335667175453],
 [0.6, 0.3836194407872708],
 [0.7, 0.38291952404157054],
 [0.8, 0.3811532538022081],
 [0.9, 0.3763843884836092]]