In [1]:
import os
import torch
import pandas as pd
import numpy as np
import ipdb
from timeit import default_timer as timer
from torch.utils.data import DataLoader
from data import WhaleData, pic_show
from model import Net
from loss import FocalLossQb, bce_loss
from eval import do_valid, do_valid_arcFace, do_valid_arcFace_preload
from lr import CosineAnnealingLR_with_Restart

def time_to_str(t, mode='min'):
    """Format an elapsed duration, given in seconds, as a short fixed-width string.

    Args:
        t: Elapsed time in seconds. Any fractional part is discarded via ``int()``.
        mode: ``'min'`` renders hours and minutes (``'%2d hr %02d min'``);
            ``'sec'`` renders minutes and seconds (``'%2d min %02d sec'``).

    Returns:
        The formatted duration string.

    Raises:
        NotImplementedError: If ``mode`` is neither ``'min'`` nor ``'sec'``.
    """
    if mode == 'min':
        # Work in whole minutes so hours/minutes stay integers on every Python version.
        hours, minutes = divmod(int(t) // 60, 60)
        return '%2d hr %02d min' % (hours, minutes)
    if mode == 'sec':
        minutes, seconds = divmod(int(t), 60)
        return '%2d min %02d sec' % (minutes, seconds)
    raise NotImplementedError('unsupported mode: %r' % (mode,))

In [2]:
# Output layout for this run: <out_dir>/checkpoint and <out_dir>/train.
out_dir = os.path.join('./models/', 'resnet101')

# exist_ok=True keeps the cell idempotent on re-run and avoids the
# check-then-create race; makedirs also creates out_dir itself as a parent.
for subdir in ('checkpoint', 'train'):
    os.makedirs(os.path.join(out_dir, subdir), exist_ok=True)
        
# Training split, constructed with augmentation switched on.
train_dataset = WhaleData(mode='train', augment=True)

# One validation dataset per fold id 0..5, bound to valid_0 ... valid_5.
valid_0, valid_1, valid_2, valid_3, valid_4, valid_5 = [
    WhaleData(mode='valid', fold_id=fold) for fold in range(6)
]

In [3]:
# Batch size shared by the validation loaders below and the training loader.
batch_size = 32 * 12


def _make_valid_loader(dataset):
    """Wrap a validation dataset in a DataLoader using the shared loader settings."""
    return DataLoader(dataset, batch_size=batch_size, drop_last=False,
                      num_workers=6, pin_memory=True)


valid_loader0 = _make_valid_loader(valid_0)
valid_loader1 = _make_valid_loader(valid_1)
valid_loader2 = _make_valid_loader(valid_2)
valid_loader3 = _make_valid_loader(valid_3)
valid_loader4 = _make_valid_loader(valid_4)
# Previously built from valid_4 (copy-paste slip), which made fold 5 a duplicate of fold 4.
valid_loader5 = _make_valid_loader(valid_5)

In [4]:
# Run configuration.
use_cuda = True
lr = 0.001
device = 'cuda'

# Build the network, then turn off gradients for layer0..layer2 of its base
# model so that the optimizer below only receives the remaining parameters.
net = Net(model_name='50', num_class=5005, arcFace=True, device=device)
backbone = net.basemodel
for frozen_stage in (backbone.layer0, backbone.layer1, backbone.layer2):
    for param in frozen_stage.parameters():
        param.requires_grad = False

# Wrap for multi-GPU execution and move to the target device.
net = torch.nn.DataParallel(net).to(device)

# SGD over the parameters that still require gradients.
trainable_params = [param for param in net.parameters() if param.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=lr,
                            momentum=0.9, weight_decay=0.0002)

# Project-local cosine schedule with restarts; snapshotting is disabled here.
scheduler = CosineAnnealingLR_with_Restart(
    optimizer,
    T_max=4,
    T_mult=1,
    model=net,
    out_dir='./',
    take_snapshot=False,
)

In [5]:
# Text log for this run: the training cell opens this path with mode 'w+'
# (truncating any existing file) and writes the per-epoch metrics to it.
log_file = './log_train4.txt'

In [6]:
# Scratch check: value of the expression assigned to batch_size (32*12) in the loader cell.
32*12

384

In [None]:
EPOCH = 100
i = 0  # running count of optimizer steps across all epochs

# Built once and reused: with shuffle=True a DataLoader draws a fresh sample
# order every time it is iterated, so it does not need rebuilding per epoch.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size,
                          drop_last=False, num_workers=6)

with open(log_file, 'w+') as log:
    start = timer()
    for epoch in range(EPOCH):
        # Read the LR before this epoch's scheduler.step() so the report shows
        # the value in effect while these batches were trained.
        tmp_lr = scheduler.get_lr()

        for images, truth_ in train_loader:
            # The network is called with the labels and returns (logit, loss).
            logit, loss = net(images.to(device), truth_.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            i += 1

        # epoch eval
        scheduler.step()
        net.eval()
        train_loss1, train_loss2, train_loss, train_acc, label_5, pred_5 = do_valid_arcFace(net, train_loader, device=device)
        valid_loss1, valid_loss2, valid_loss, valid_acc, label_5_val, pred_5_val = do_valid_arcFace(net, valid_loader0, device=device)
        net.train()

        # Format each report line once so the console and the log file agree
        # (including the elapsed-time stamp).
        elapsed = time_to_str((timer() - start), 'min')
        header = '---------------------------EPOCH:{}  LR:{}-----------------------------------------'.format(epoch, tmp_lr)
        train_line = 'train_loss1:{:.4f} || train_loss2:{:.4f} || train_loss:{:.4f} || train_acc:{:.4f} ||  use:{} ||'.format(train_loss1, train_loss2, train_loss, train_acc, elapsed)
        valid_line = 'valid_loss1:{:.4f} || valid_loss2:{:.4f} || valid_loss:{:.4f} || valid_acc:{:.4f} ||'.format(valid_loss1, valid_loss2, valid_loss, valid_acc)
        train_samples = 'five_sample_label:\n{}\nfive_sample_predict:\n{}\n'.format(label_5, pred_5)
        valid_samples = 'five_sample_label:\n{}\nfive_sample_predict:\n{}\n'.format(label_5_val, pred_5_val)

        print(header)
        print(train_line)
        print(valid_line)
        print(train_samples)

        # file.write() does not append a newline the way print() does; without
        # the explicit '\n' the header/train/valid lines run together in the log.
        for line in (header, train_line, valid_line):
            log.write(line + '\n')
        # The sample blocks already end in '\n'. First block: train samples,
        # second block: validation samples.
        log.write(train_samples)
        log.write(valid_samples)
        # Flush per epoch so the log can be followed during a long run and
        # survives an interrupted kernel.
        log.flush()


---------------------------EPOCH:0  LR:[0.0008535533905932737]-----------------------------------------
train_loss1:0.0000 || train_loss2:0.0000 || train_loss:28.1108 || train_acc:0.0000 ||  use: 0 hr 09 min ||
valid_loss1:0.0000 || valid_loss2:0.0000 || valid_loss:28.0395 || valid_acc:0.0000 ||
five_sample_label:
[3703, 3874, 5004, 3803, 2970]
five_sample_predict:
[[5004, 1107, 4737, 456, 427], [5004, 4820, 4649, 1338, 2643], [456, 1000, 4195, 3242, 0], [5004, 2797, 2348, 3794, 4358], [5004, 792, 22, 2978, 912]]

---------------------------EPOCH:1  LR:[0.0005]-----------------------------------------
train_loss1:0.0000 || train_loss2:0.0000 || train_loss:28.0744 || train_acc:0.0026 ||  use: 0 hr 20 min ||
valid_loss1:0.0000 || valid_loss2:0.0000 || valid_loss:27.4171 || valid_acc:0.0028 ||
five_sample_label:
[3762, 1266, 540, 5004, 1747]
five_sample_predict:
[[5004, 1992, 4405, 824, 2772], [5004, 427, 1780, 62, 2493], [5004, 4019, 824, 4405, 2000], [427, 3604, 2671, 1115, 375], [5004,