In [1]:
from datetime import datetime
# Preview of the timestamp format used to tag experiment runs.
format(datetime.now(), "%Y-%m-%d_%H_%M_%S")

'2021-08-27_03_49_17'

In [2]:
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [3]:
from module.GlobalSeed import seed_everything
from module.DataLoader import MaskDataset#ImageDatasetPreTransform
from model.ResNet import ResNet20
from module.F1_score import F1_Loss


In [4]:
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter

In [5]:
# HyperParameter
SEED          = 777
# Fall back to CPU so the notebook still runs on a machine without a GPU.
DEVICE        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH_SIZE    = 128
LEARNING_RATE = 0.0001
TOTAL_EPOCH   = 100
# Run tag used in the TensorBoard log directories and checkpoint filenames.
exp_num       = datetime.now().strftime("%Y-%m-%d_%H_%M_%S")

In [6]:
# set random seed
# seed_everything is a project helper (module.GlobalSeed). random_split and the
# shuffling DataLoader below draw from torch's global RNG when no generator is
# passed, so reproducibility presumably depends on this seeding torch as well
# as python/numpy — TODO confirm in module.GlobalSeed.
seed_everything(SEED)

In [7]:
dataset = MaskDataset(
    target         = "gender",
    realign        = True,
    csv_path       = '../../input/data/train/train.csv',
    images_path    = '../../input/data/train/images/',
    # crop(img, top, left, height, width): keep a 320x256 region starting at
    # (top=80, left=50), then downsample to 64x64.
    pre_transforms = transforms.Compose([
        lambda img : transforms.functional.crop(img, 80, 50, 320, 256),
        transforms.Resize((64,64)),
    ]),
    transforms = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor()
    ])
)

# Split sizes are derived from the dataset itself. random_split raises unless
# the sizes sum to len(dataset), so hard-coded counts break whenever the CSV
# changes. Validation takes the remainder, so no sample is lost to rounding.
# NOTE(review): both subsets wrap the same `dataset`, so validation samples
# also receive RandomHorizontalFlip — confirm that is acceptable for eval.
TRAIN_RATIO = 0.8
n_train = int(len(dataset) * TRAIN_RATIO)
size = [n_train, len(dataset) - n_train]
train_set, val_set = torch.utils.data.random_split(dataset, size)

100%|██████████| 2700/2700 [01:08<00:00, 39.63it/s]


In [8]:
# Training loader: reshuffled every epoch; drop_last discards the final
# incomplete batch so every step sees BATCH_SIZE samples.
train_dataloader = DataLoader(
    train_set,
    batch_size  = BATCH_SIZE,
    shuffle     = True,
    sampler     = None,
    num_workers = 4,
    drop_last   = True
)

# Validation loader: fixed order and no dropped tail batch, so every held-out
# sample is scored each epoch (drop_last=True would silently skip up to
# BATCH_SIZE-1 validation images).
val_dataloader = DataLoader(
    val_set,
    batch_size  = BATCH_SIZE,
    shuffle     = False,
    sampler     = None,
    num_workers = 4,
    drop_last   = False
)

In [9]:
# Pull one training batch to check tensor shapes before building the model.
single_batch_X, single_batch_y = next(iter(train_dataloader))
print(single_batch_X.shape, single_batch_y.shape, sep="\n")

torch.Size([128, 3, 64, 64])
torch.Size([128])


In [10]:
# Build the network. ResNet20 (model.ResNet) receives the full batch shape
# (torch.Size([128, 3, 64, 64]) per the output above) and the target device.
# NOTE(review): no later cell calls resnet20.to(DEVICE), so this presumably
# relies on the constructor placing the model on DEVICE — confirm.
resnet20 = ResNet20(single_batch_X.shape, DEVICE)

In [11]:
#summary(resnet20, single_batch_X.shape[1:])

In [12]:
# Initialization

def weight_initialization(module):
    """Re-initialise the parameters of a single layer in place.

    Intended for ``nn.Module.apply``, which calls it on every submodule.

    * Conv2d / Linear: Kaiming-normal weight, zero bias.
    * BatchNorm2d: weight 1, bias 0.

    Other module types are left untouched. Layers without a bias
    (``bias=False``) only get their weight initialised, and non-affine
    BatchNorm layers (no learnable weight/bias) are skipped.

    Args:
        module: the ``nn.Module`` to initialise.
    """
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(module.weight)
        # Convs feeding a BatchNorm are often built with bias=False, in which
        # case module.bias is None and must not be passed to nn.init.
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm2d) and module.affine:
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)


In [13]:
# Module.apply runs the initialiser on every submodule (and the model itself)
# in place and returns the same module, so no re-binding is needed; the
# trailing ';' keeps Jupyter from printing the model repr.
resnet20.apply(weight_initialization);

In [14]:
# Criteria and optimizer used by the training loop.
loss = torch.nn.CrossEntropyLoss()
# Place the F1 criterion on DEVICE (rather than a hard-coded .cuda()) so it
# always lives wherever the model inputs are sent.
# NOTE(review): num_classes=18 while the dataset target is "gender" — confirm
# this matches the number of logits ResNet20 produces for this task.
f1_loss = F1_Loss(num_classes = 18).to(DEVICE)
opt = torch.optim.Adam(resnet20.parameters(), lr = LEARNING_RATE)

In [15]:
def func_eval(model, data_iter, device):
    """Return the classification accuracy of ``model`` over ``data_iter``.

    Args:
        model: network mapping an input batch to class scores, with classes
            along dim 1.
        data_iter: iterable of ``(inputs, targets)`` batches, e.g. a DataLoader.
        device: device each batch is moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples, or ``nan`` when
        ``data_iter`` yields no samples.

    The model is put in eval mode (affects Dropout and BatchNorm) for the pass
    and restored to the mode it was in beforehand, even if an error occurs.
    """
    was_training = model.training
    model.eval()
    n_total, n_correct = 0, 0
    try:
        with torch.no_grad():
            for batch_in, batch_out in data_iter:
                y_trgt = batch_out.to(device)
                y_pred = model(batch_in.to(device)).argmax(dim=1)
                n_correct += (y_pred == y_trgt).sum().item()
                n_total += batch_in.size(0)
    finally:
        model.train(was_training)
    return n_correct / n_total if n_total else float("nan")
print ("Done")

Done


In [16]:
tr_writer = SummaryWriter('logs/exp_%s/tr'%exp_num)
val_writer = SummaryWriter('logs/exp_%s/val'%exp_num)

# Weight of the F1 term in the training objective:
#   total = (1 - F1_LOSS_WEIGHT) * cross_entropy + F1_LOSS_WEIGHT * f1_loss
F1_LOSS_WEIGHT = 0.01

global_step = 0

try:
    for ep in range(TOTAL_EPOCH):
        #= Training phase =========
        # Running sums are kept as Python floats via .item(); summing the loss
        # tensors themselves would keep every batch's autograd graph alive
        # until the end of the epoch.
        tr_mean_loss, tr_mean_f1 = 0.0, 0.0
        resnet20.train()
        for X, y in train_dataloader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            predict = resnet20(X)
            ce_val = loss(predict, y)
            f1_val = f1_loss(predict, y)
            tr_mean_loss += ce_val.item()
            tr_mean_f1 += f1_val.item()

            loss_val = (1 - F1_LOSS_WEIGHT) * ce_val + F1_LOSS_WEIGHT * f1_val

            opt.zero_grad()
            loss_val.backward()
            opt.step()

        #= Validation phase =============
        # Accuracy is counted in the same eval-mode pass that computes the
        # validation loss, avoiding a second full sweep over val_dataloader.
        val_mean_loss, val_mean_f1 = 0.0, 0.0
        val_correct, val_total = 0, 0
        resnet20.eval()
        with torch.no_grad():
            for X, y in val_dataloader:
                X, y = X.to(DEVICE), y.to(DEVICE)
                predict = resnet20(X)
                val_mean_loss += loss(predict, y).item()
                val_mean_f1 += f1_loss(predict, y).item()
                val_correct += (predict.argmax(dim=1) == y).sum().item()
                val_total += y.size(0)

        #= Training writer =========
        tr_writer.add_scalar('loss', tr_mean_loss / len(train_dataloader), ep)
        tr_writer.add_scalar('score/acc', func_eval(resnet20, train_dataloader, DEVICE), ep)
        tr_writer.add_scalar('score/F1', tr_mean_f1 / len(train_dataloader), ep)

        #= Validation writer =========
        # Validation sums are per validation batch, so both averages divide
        # by len(val_dataloader).
        val_writer.add_scalar('loss', val_mean_loss / len(val_dataloader), ep)
        val_writer.add_scalar('score/acc', val_correct / val_total, ep)
        val_writer.add_scalar('score/F1', val_mean_f1 / len(val_dataloader), ep)

        print(ep)
finally:
    # Flush and close the event files even when training is interrupted
    # (e.g. KeyboardInterrupt); the exception still propagates afterwards.
    tr_writer.close()
    val_writer.close()

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37


KeyboardInterrupt: 

In [17]:
# Persist the model at the last epoch index the training loop reached (`ep`).
# The state_dict file is the portable checkpoint; the whole-module file is a
# pickle that only loads where model.ResNet.ResNet20 is importable.
# torch.save does not create directories, so make sure the target exists.
os.makedirs('./saved_model', exist_ok=True)
torch.save(resnet20,'./saved_model/model_%s_ep_%d.pt'%(exp_num,ep))
torch.save(resnet20.state_dict(),'./saved_model/model_weights_%s_ep_%d.pt'%(exp_num,ep))