In [1]:
import numpy as np, pandas as pd
import os
import time

In [2]:
import resource
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))

In [3]:
import torch
import torchvision.models as models
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter

In [4]:
import random
seed = 100
random.seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed) 
torch.cuda.manual_seed_all(seed) 
torch.backends.cudnn.deterministic=True

In [5]:
from helpers import Imagefolder_multilabel as myImagefolder

In [6]:
import Augmentor

In [7]:
image_resize = (224,224)
batch_size_train = 64*8
num_workers_train = 12*8
batch_size_valid = 64*8
num_workers_valid = 12*8

In [8]:
p = Augmentor.Pipeline()
# p.skew(probability=0.45, magnitude=0.3)
p.rotate_random_90(probability=0.45)
p.rotate(probability=0.45, max_left_rotation=10, max_right_rotation=10)
p.flip_random(probability=0.45)
# p.crop_random(probability=0.45, percentage_area=0.5)
img_transform = transforms.Compose([p.torch_transform(), 
                                    transforms.Resize(image_resize), transforms.ToTensor()])

In [9]:
imgset_train = myImagefolder.DatasetFolder(root='input/train_images/train_all/', label_file='input/label_hmn_mch_train.csv', desc_file='input/label_hmn_mch_desc.csv', transform=img_transform)
sampler_train = torch.utils.data.sampler.WeightedRandomSampler(weights=imgset_train.weights, num_samples=int(imgset_train.num_samples), replacement=True)
loader_train = torch.utils.data.DataLoader(imgset_train, batch_size=batch_size_train, num_workers=num_workers_train, sampler=sampler_train)
# loader_train = torch.utils.data.DataLoader(imgset_train, batch_size=batch_size_train, num_workers=num_workers_train, shuffle=True)

In [10]:
img_transform_valid = transforms.Compose([transforms.Resize(image_resize),transforms.ToTensor()])

In [11]:
imgset_valid = myImagefolder.DatasetFolder(root='input/train_images/train_all/', label_file='input/label_hmn_mch_valid.csv', desc_file='input/label_hmn_mch_desc.csv', transform=img_transform_valid)
loader_valid = torch.utils.data.DataLoader(imgset_valid, batch_size=batch_size_valid, num_workers=num_workers_valid, shuffle=False)

In [12]:
def bnwd_optim_params(model, model_params, master_params):
    bn_params, remaining_params = split_bn_params(model, model_params, master_params)
    return [{'params':bn_params,'weight_decay':0}, {'params':remaining_params}]

def split_bn_params(model, model_params, master_params):
    def get_bn_params(module):
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): return module.parameters()
        accum = set()
        for child in module.children(): [accum.add(p) for p in get_bn_params(child)]
        return accum
    
    mod_bn_params = get_bn_params(model)
    zipped_params = list(zip(model_params, master_params))

    mas_bn_params = [p_mast for p_mod,p_mast in zipped_params if p_mod in mod_bn_params]
    mas_rem_params = [p_mast for p_mod,p_mast in zipped_params if p_mod not in mod_bn_params]
    return mas_bn_params, mas_rem_params

def save_checkpoint(state, is_best, filename):
    if is_best:
        print("-> Saving a new best ...")
        torch.save(state, filename)
    else:
        print("-> Validation accuracy did not improve ...")
        
def load_checkpoint(load_path, model, optimizer=None, warmup=False):
    if os.path.isfile(load_path):
        print("-> Loading checkpoint '{}'".format(load_path))
        checkpoint = torch.load(load_path)
        epoch = checkpoint['epoch'] if not warmup else -1
        acc_valid = checkpoint['acc_valid']
        acc_train = checkpoint['acc_train']
        loss_valid = checkpoint['loss_valid']
        loss_train = checkpoint['loss_train']
        state_dict = checkpoint['state_dict']
        itrn_chkpt = checkpoint['step'] if not warmup else 0
        
        model.load_state_dict(state_dict)
#         from collections import OrderedDict
#         new_state_dict = OrderedDict()
#         for k, v in state_dict.items():
#             name = k[7:] # remove 'module.' of dataparallel
#             new_state_dict[name]=v
#         model.load_state_dict(new_state_dict)
#         model = nn.DataParallel(model)
            
        print("-> Loaded checkpoint at epoch {} step {} ".format(epoch, itrn_chkpt))
        if warmup:
            return epoch, acc_valid, acc_train, loss_train, loss_valid, itrn_chkpt
        else:
            if optimizer != None:
                optimizer.load_state_dict(checkpoint['optimizer'])
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()
            return epoch-1, acc_valid, acc_train, loss_train, loss_valid, itrn_chkpt
    else:
        print("-> No checkpoint found at '{}'".format(load_path))
        return None
    
def update_lr(optimizer, lr):    
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
        
def update_wd(optimizer, wd):
    optimizer.param_groups[1]['weight_decay'] = wd

def labels_loop(list_class, label):
    for ilabel in label:
        for j in ilabel:
            list_class[j] += 1
    return list_class
    
def idx_to_desc(idx):
    return imgset_train.class_to_desc[imgset_train.idx_to_class[idx]]

class SoftF2Loss(torch.nn.Module):

    def __init__(self):
        super(SoftF2Loss,self).__init__()
        
    def forward(self, logits, labels):
        __small_value=1e-6
        beta = 2
        batch_size = logits.size()[0]
        p = torch.nn.functional.sigmoid(logits)
        l = labels
        num_pos = torch.sum(p, 1) + __small_value
        num_pos_hat = torch.sum(l, 1) + __small_value
        tp = torch.sum(l * p, 1)
        precise = tp / num_pos
        recall = tp / num_pos_hat
        fs = (1 + beta * beta) * precise * recall / (beta * beta * precise + recall + __small_value)
        loss = fs.sum() / batch_size
        return (1 - loss)

In [13]:
device = torch.device('cuda')
num_epochs = 30
lr = 5e-6
wd = 1e-4

In [14]:
n_class = len(imgset_train.classes)

In [20]:
chkpt = True
warmup = False
chkpt_file = 'accval-0.1413_lossval-56.2363_epoch-9_step-10440_checkpoint.pth'

In [21]:
from senet import se_resnext101_32x4d

model = se_resnext101_32x4d(pretrained=None, num_classes=553 if (chkpt&warmup) else n_class, bn0=True)
model.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)

In [22]:
# from resnet import resnet50

# model = resnet50(pretrained=False, num_classes=553 if (chkpt&warmup) else n_class, bn0=True).to(device)

In [23]:
model = nn.DataParallel(model)
class_monitor = ['/m/01g317', '/m/09j2d', '/m/0dzct', '/m/07j7r', '/m/05s2s'] # 5 max
# criterion = nn.CrossEntropyLoss(weight=torch.Tensor(imgset_train.weights_).to(device))
# criterion = nn.BCEWithLogitsLoss().to(device)
criterion = SoftF2Loss().to(device)
optim_params = bnwd_optim_params(model, model.parameters(), model.parameters())
# optimizer = torch.optim.Adam(model.parameters(), lr=lr)
optimizer = torch.optim.Adam(optim_params, lr=lr, weight_decay=wd)
# optimizer = torch.optim.SGD(optim_params, lr=lr, momentum=0.9, weight_decay=wd)
epoch = -1
itrn_chkpt = 0

In [24]:
if chkpt:
    epoch, acc_valid, acc_train, loss_valid, loss_train, itrn_chkpt = load_checkpoint('chkpt/'+chkpt_file, 
                                                   model, optimizer, warmup=warmup)
#                                                    model, optimizer, warmup=warmup)
    if warmup:
        model.module.last_linear = nn.Linear(model.module.last_linear.in_features, n_class).to(device)
    print(epoch, acc_valid, acc_train, loss_valid, loss_train, itrn_chkpt)
    
update_lr(optimizer, lr)
update_wd(optimizer, wd)

-> Loading checkpoint 'chkpt/accval-0.1413_lossval-56.2363_epoch-9_step-10440_checkpoint.pth'
-> Loaded checkpoint at epoch 8 step 10440 
7 0.14126008805833443 0.1410138576230247 tensor(119.0210, device='cuda:0', requires_grad=True) tensor(56.2363, device='cuda:0') 10440


In [25]:
log_freq = 1800000//3//batch_size_train

In [26]:
model = model.to(device)
writer = SummaryWriter()

# Train the model
total_step = len(loader_train)
for iepoch in range(epoch+1, num_epochs):
    correct = 0
    total = 0
    loss = 0.
    icnt = 0
    correct_class = torch.zeros(n_class)
    total_class = torch.zeros(n_class)
    
    model.train()
    t00 = time.time()
    for itrn, (images, labels) in enumerate(loader_train, itrn_chkpt+1):
        print("Trn_Step:{}".format(iepoch*total_step+itrn))
        t0 = time.time()
        print("Loader:{}".format(t0-t00))
        images = images.to(device)
        t1 = time.time()
        print("Image transfer:{}".format(t1-t0))
        labels = labels.to(device)
        t2 = time.time()
        print("Label transfer:{}".format(t2-t1))
        
        # Forward pass
        outputs = model(images)
        t3 = time.time()
        print("Out:{}".format(t3-t2))
        iloss_trn = criterion(outputs, labels.float())
        loss += iloss_trn
        icnt += 1
        t4 = time.time()
        print("Loss:{}".format(t4-t3))
        
        # evaluate training
        labels = labels.cpu()
        predicted = [ torch.topk(outputs.data[ii], np.where(labels[ii])[0].shape[0])[1].cpu().sort()[0].numpy() for ii in range(labels.shape[0]) ]
        total += int(labels.sum())
        correct += np.sum([ (predicted[ii] == np.sort(np.where(labels[ii])[0])).sum() for ii in range(labels.shape[0]) ])
        total_class += labels.sum(0).cpu().float()
        correct_class = labels_loop(correct_class, [ np.where(labels[ii])[0][predicted[ii] == np.where(labels[ii])[0]] for ii in range(labels.shape[0]) ])
        t5 = time.time()
        print("Eval:{}".format(t5-t4))
        
        # Backward and optimize
        optimizer.zero_grad()
        iloss_trn.backward()
        optimizer.step()
        t6 = time.time()
        t00 = t6
        print("Opt:{}".format(t6-t5))
        
        writer.add_scalar('loss/train_per_batch', iloss_trn.item(), iepoch*total_step+itrn)
        writer.add_scalar('lr', lr, iepoch+(1+itrn)/log_freq )
        
        if (iepoch*total_step+itrn+1) % log_freq == 0:
            # Evaluate validation
            with torch.no_grad():
                correct_val = 0
                total_val = 0
                loss_valid = 0
                icnt_val = 0
                correct_class_val = torch.zeros(n_class)
                total_class_val = torch.zeros(n_class)

                model.eval()
                t77 = time.time()
                for ival, (images_val, labels_val) in enumerate(loader_valid):
                    print("Val_Step:{}".format(iepoch*total_step+ival))
                    t7 = time.time()
                    print("Loader:{}".format(t7-t77))
                    images_val = images_val.to(device)
                    labels_val = labels_val.to(device)
                    outputs_val = model(images_val)
                    iloss_val = criterion(outputs_val, labels_val.float())
                    if ~torch.isnan(iloss_val):
                        icnt_val += 1
                        loss_valid += iloss_val
                    labels_val = labels_val.cpu()
                    predicted_val = [ torch.topk(outputs_val.data[ii], np.where(labels_val[ii])[0].shape[0])[1].cpu().sort()[0].numpy() for ii in range(labels_val.shape[0]) ]
                    total_val += int(labels_val.sum())
                    correct_val += np.sum([ (predicted_val[ii] == np.sort(np.where(labels_val[ii])[0])).sum() for ii in range(labels_val.shape[0]) ])
                    total_class_val += labels_val.sum(0).cpu().float()
                    correct_class_val = labels_loop(correct_class_val, [ np.where(labels_val[ii])[0][predicted_val[ii] == np.where(labels_val[ii])[0]] for ii in range(labels_val.shape[0]) ])

                    t8 = time.time()
                    t77 = t8
                    print("Inference:{}".format(t8-t7))
                    
                    writer.add_scalar('loss/valid_per_batch', iloss_val.item(), iepoch*total_step+ival)

                acc_train = correct / total
                acc_class_train = correct_class / total_class
                acc_class_train[torch.isnan(acc_class_train)] = 0.
                acc_valid = correct_val / total_val
                acc_class_valid = correct_class_val / total_class_val
                acc_class_valid[torch.isnan(acc_class_valid)] = 0.

                writer.add_scalars('loss', {'valid': loss_valid/icnt_val,
                                            'train': loss/icnt }, iepoch+(1+itrn)/log_freq )
                writer.add_scalars('accuracy', { 'val_avg': acc_valid,
                                                 'trn_avg': acc_train,
                                                 'trn_max0'+imgset_train.class_to_desc[class_monitor[0]]: acc_class_train[imgset_train.class_to_idx[class_monitor[0]]], 
                                                 'trn_max1'+imgset_train.class_to_desc[class_monitor[1]]: acc_class_train[imgset_train.class_to_idx[class_monitor[1]]], 
                                                 'trn_max2'+imgset_train.class_to_desc[class_monitor[2]]: acc_class_train[imgset_train.class_to_idx[class_monitor[2]]], 
                                                 'trn_max3'+imgset_train.class_to_desc[class_monitor[3]]: acc_class_train[imgset_train.class_to_idx[class_monitor[3]]], 
                                                 'trn_max4'+imgset_train.class_to_desc[class_monitor[4]]: acc_class_train[imgset_train.class_to_idx[class_monitor[4]]], 
                                                 'val_max0'+imgset_train.class_to_desc[class_monitor[0]]: acc_class_valid[imgset_train.class_to_idx[class_monitor[0]]], 
                                                 'val_max1'+imgset_train.class_to_desc[class_monitor[1]]: acc_class_valid[imgset_train.class_to_idx[class_monitor[1]]], 
                                                 'val_max2'+imgset_train.class_to_desc[class_monitor[2]]: acc_class_valid[imgset_train.class_to_idx[class_monitor[2]]], 
                                                 'val_max3'+imgset_train.class_to_desc[class_monitor[3]]: acc_class_valid[imgset_train.class_to_idx[class_monitor[3]]], 
                                                 'val_max4'+imgset_train.class_to_desc[class_monitor[4]]: acc_class_valid[imgset_train.class_to_idx[class_monitor[4]]] 
                                                },
                                                 iepoch+(1+itrn)/log_freq )
                
            
            save_checkpoint({'epoch': iepoch,
                         'step': itrn,
                         'state_dict': model.state_dict(),
                         'acc_valid': acc_valid,
                         'acc_train': acc_train,
                         'loss_train': loss,
                         'loss_valid': loss_valid,
                         'optimizer': optimizer.state_dict()},
                          1, 'chkpt/accval-{:.4f}_lossval-{:.4f}_epoch-{}_step-{}_checkpoint.pth'.format(acc_valid, loss_valid, iepoch+1, itrn))
            model.train()

writer.close()

Trn_Step:51001
Loader:85.2647533416748
Image transfer:0.10731649398803711
Label transfer:0.010593175888061523
Out:50.772974729537964
Loss:0.02975320816040039




Eval:0.5344877243041992
Opt:1.8852379322052002
Trn_Step:51002
Loader:0.07039713859558105
Image transfer:0.11089920997619629
Label transfer:0.01075601577758789
Out:1.3796422481536865
Loss:0.023973941802978516
Eval:0.48139500617980957
Opt:0.7518131732940674
Trn_Step:51003
Loader:0.08706450462341309
Image transfer:0.09698033332824707
Label transfer:0.009490013122558594
Out:0.6214425563812256
Loss:0.012287616729736328
Eval:0.39489054679870605
Opt:0.6409358978271484
Trn_Step:51004
Loader:0.07942557334899902
Image transfer:0.10358071327209473
Label transfer:0.01395869255065918
Out:0.6310863494873047
Loss:0.012705326080322266
Eval:0.39255857467651367
Opt:0.6552886962890625
Trn_Step:51005
Loader:0.07832932472229004
Image transfer:0.09982681274414062
Label transfer:0.010576009750366211
Out:0.6385772228240967
Loss:0.009493112564086914
Eval:0.42600488662719727
Opt:0.6373403072357178
Trn_Step:51006
Loader:0.08339548110961914
Image transfer:0.10341548919677734
Label transfer:0.01023411750793457
Out

Opt:0.5455746650695801
Trn_Step:51041
Loader:0.05199003219604492
Image transfer:0.0870978832244873
Label transfer:0.008691787719726562
Out:0.4547402858734131
Loss:0.014782190322875977
Eval:0.26967835426330566
Opt:0.5406460762023926
Trn_Step:51042
Loader:0.05165672302246094
Image transfer:0.07536959648132324
Label transfer:0.007944822311401367
Out:0.5361886024475098
Loss:0.007927417755126953
Eval:0.2653226852416992
Opt:0.5420255661010742
Trn_Step:51043
Loader:0.049569129943847656
Image transfer:0.09025216102600098
Label transfer:0.008633852005004883
Out:0.5019402503967285
Loss:0.010627031326293945
Eval:0.2744567394256592
Opt:0.5403985977172852
Trn_Step:51044
Loader:0.04927182197570801
Image transfer:0.0863945484161377
Label transfer:0.008651256561279297
Out:0.501814603805542
Loss:0.010933399200439453
Eval:0.2452411651611328
Opt:0.5415980815887451
Trn_Step:51045
Loader:0.04907059669494629
Image transfer:0.0871591567993164
Label transfer:0.008599519729614258
Out:0.4973301887512207
Loss:0.

Out:0.5081386566162109
Loss:0.009374618530273438
Eval:0.27668213844299316
Opt:0.5441262722015381
Trn_Step:51081
Loader:0.05025339126586914
Image transfer:0.0866849422454834
Label transfer:0.008595943450927734
Out:0.5061280727386475
Loss:0.011879205703735352
Eval:0.2859158515930176
Opt:0.5409743785858154
Trn_Step:51082
Loader:0.054053544998168945
Image transfer:0.07344865798950195
Label transfer:0.007393836975097656
Out:0.5063905715942383
Loss:0.014100074768066406
Eval:0.2941579818725586
Opt:0.5986731052398682
Trn_Step:51083
Loader:0.05028033256530762
Image transfer:0.0865945816040039
Label transfer:0.008720159530639648
Out:0.49507951736450195
Loss:0.009280681610107422
Eval:0.29289698600769043
Opt:0.5410847663879395
Trn_Step:51084
Loader:0.04863572120666504
Image transfer:0.08768296241760254
Label transfer:0.008589744567871094
Out:0.49907946586608887
Loss:0.009081125259399414
Eval:0.2816474437713623
Opt:0.5466516017913818
Trn_Step:51085
Loader:0.05571699142456055
Image transfer:0.075111

Eval:0.26893162727355957
Opt:0.5418522357940674
Trn_Step:51120
Loader:0.05066227912902832
Image transfer:0.07349824905395508
Label transfer:0.007903814315795898
Out:0.5210204124450684
Loss:0.017026424407958984
Eval:0.28292107582092285
Opt:0.5414080619812012
Trn_Step:51121
Loader:0.046087026596069336
Image transfer:0.08357024192810059
Label transfer:0.008486270904541016
Out:0.5039632320404053
Loss:0.011416196823120117
Eval:0.2657206058502197
Opt:0.5392322540283203
Trn_Step:51122
Loader:0.041265010833740234
Image transfer:0.08421063423156738
Label transfer:0.008502960205078125
Out:0.49332690238952637
Loss:0.00915837287902832
Eval:0.27028489112854004
Opt:0.5404348373413086
Trn_Step:51123
Loader:0.044999122619628906
Image transfer:0.08588886260986328
Label transfer:0.008579254150390625
Out:0.5047385692596436
Loss:0.01223301887512207
Eval:0.25926995277404785
Opt:0.541292667388916
Trn_Step:51124
Loader:0.04743027687072754
Image transfer:0.08507680892944336
Label transfer:0.008597373962402344

Out:0.5200724601745605
Loss:0.012041091918945312
Eval:0.26018643379211426
Opt:0.552391767501831
Trn_Step:51160
Loader:0.05231285095214844
Image transfer:0.07410359382629395
Label transfer:0.007311344146728516
Out:0.5094823837280273
Loss:0.011958599090576172
Eval:0.26084089279174805
Opt:0.5414650440216064
Trn_Step:51161
Loader:0.04552412033081055
Image transfer:0.08569812774658203
Label transfer:0.008357763290405273
Out:0.5056822299957275
Loss:0.009184837341308594
Eval:0.2552037239074707
Opt:0.540830135345459
Trn_Step:51162
Loader:0.04642009735107422
Image transfer:0.08706045150756836
Label transfer:0.008451700210571289
Out:0.5163924694061279
Loss:0.012386798858642578
Eval:0.2585911750793457
Opt:0.5461826324462891
Trn_Step:51163
Loader:0.051328182220458984
Image transfer:0.07185697555541992
Label transfer:0.00737309455871582
Out:0.5267982482910156
Loss:0.01243901252746582
Eval:0.26670360565185547
Opt:0.5465691089630127
Trn_Step:51164
Loader:0.050462961196899414
Image transfer:0.08664464

KeyboardInterrupt: 