In [1]:
import os
import random
import math
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

In [2]:
seed = 1234
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# FineTuning

In [3]:
from utils.dataloader import make_datapath_list, DataTransform, VOCDataset

# file path list の作成
rootpath = './data/VOCdevkit/VOC2012/'
train_img_list, train_anno_list, val_img_list, val_anno_list = make_datapath_list(rootpath=rootpath)

# Dataset作成
color_mean = (0.485, 0.456, 0.406)
color_std = (0.229, 0.224, 0.225)

train_dataset = VOCDataset(train_img_list, train_anno_list, phase='train',
                          transform=DataTransform(input_size=475, color_mean=color_mean, color_std=color_std))

val_dataset = VOCDataset(val_img_list, val_anno_list, phase='val', 
                        transform=DataTransform(input_size=475, color_mean=color_mean, color_std=color_std))

batch_size = 8

train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

dataloaders_dict = {'train': train_dataloader, 'val': val_dataloader}

In [4]:
from utils.pspnet import PSPNet

# ADE20Kのpretrainを読み込み
net = PSPNet(n_classes=150)

state_dict = torch.load('./weights/pspnet50_ADE20K.pth')
net.load_state_dict(state_dict)

# 分類用の畳み込み層を21chのものに付け替え
n_classes = 21
net.decode_feature.classification = nn.Conv2d(in_channels=512, out_channels=n_classes, kernel_size=1, stride=1, padding=0)
net.aux.classification = nn.Conv2d(in_channels=256, out_channels=n_classes, kernel_size=1, stride=1, padding=0)

# 付け替えた畳み込み層を初期化する。活性化関数がシグモイド関数なのでXavierを使用する (Heじゃないよ)
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
            
net.decode_feature.classification.apply(weights_init)
net.aux.classification.apply(weights_init)

Conv2d(256, 21, kernel_size=(1, 1), stride=(1, 1))

In [5]:
# 損失関数の設定
class PSPLoss(nn.Module):
    def __init__(self, aux_weight=0.4):
        super(PSPLoss, self).__init__()
        self.aux_weight = aux_weight
        
    def forward(self, outputs, targets):
        '''
        Parameters
        ----------
        outputs: PSPNetの出力 (tuple)
            (output=torch.Size([num_batch, 21, 475, 475]), 
             output_aux=torch.Size([num_batch, 21, 475, 475]))
             
        targets: [num_batch, 475, 475]
        
        Returns
        -------
        loss: テンソル 損失の値
        '''
        loss = F.cross_entropy(outputs[0], targets, reduction='mean')
        loss_aux = F.cross_entropy(outputs[1], targets, reduction='mean')
        
        return loss+self.aux_weight*loss_aux

criterion = PSPLoss(aux_weight=0.4)
        

スケジューラを利用したepochごとの学習率の変更

In [6]:
optimizer = optim.SGD([
    {'params': net.feature_conv.parameters(), 'lr': 1e-3},
    {'params': net.feature_res_1.parameters(), 'lr': 1e-3},
    {'params': net.feature_res_2.parameters(), 'lr': 1e-3},
    {'params': net.feature_dilated_res_1.parameters(), 'lr': 1e-3},
    {'params': net.feature_dilated_res_2.parameters(), 'lr': 1e-3},
    {'params': net.pyramid_pooling.parameters(), 'lr': 1e-3},
    {'params': net.decode_feature.parameters(), 'lr': 1e-2},
    {'params': net.aux.parameters(), 'lr': 1e-2},
], momentum=0.9, weight_decay=0.0001)

def lambda_epoch(epoch):
    max_epoch = 30
    return math.pow((1-epoch/max_epoch), 0.9)
    
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_epoch)

In [7]:
# モデルを学習させる関数
def train_model(net, dataloaders_dict, criterion, scheduler, optimizer, num_epochs):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    
    torch.backends.cudnn.benchmark = True
    
    num_train_imgs = len(dataloaders_dict['train'].dataset)
    num_val_imgs = len(dataloaders_dict['val'].dataset)
    batch_size = dataloaders_dict['train'].batch_size
    
    iteration = 1
    logs = []
    
    batch_multiplier = 3
    
    for epoch in range(num_epochs):
        t_epoch_start = time.time()
        t_iter_start = time.time()
        epoch_train_loss = 0.0
        epoch_val_loss = 0.0
        
        print(f'Epoch {epoch+1}/{num_epochs}')
        
        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()
                scheduler.step()
                optimizer.zero_grad()
                print(' (train) ')
                
            else:
                if ((epoch+1) % 5 == 0):
                    net.eval()
                    print('---------')
                    print(' (val) ')
                else:
                    continue
                    
            count = 0
            for imgs, anno_class_imgs in dataloaders_dict[phase]:
                if imgs.size()[0] == 1:
                    continue
                    
                imgs = imgs.to(device)
                anno_class_imgs = anno_class_imgs.to(device)
                
                # multiple minibatchでのパラメタ更新
                if (phase == 'train') and (count == 0):
                    optimizer.step()
                    optimizer.zero_grad()
                    count = batch_multiplier
                    
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = net(imgs)
                    loss = criterion(outputs, anno_class_imgs.long()) / batch_multiplier
                    
                    # 訓練時はbackpropagation
                    if phase == 'train':
                        loss.backward()
                        count -=1
                        
                        if (iteration % 10 == 0):
                            t_iter_finish = time.time()
                            duration = t_iter_finish - t_iter_start
                            print(f'iteration {iteration} || \
                            Loss: {loss.item()/batch_size*batch_multiplier:.4f} || \
                            10iter: {duration:.4f}sec.')
                        
                        epoch_train_loss += loss.item() * batch_multiplier
                        iteration += 1
                    
                    else:
                        epoch_val_loss += loss.item() * batch_multiplier
                        
        t_epoch_finish = time.time()
        print('------------')
        print(f'epoch {epoch+1} || epoch_train_loss: {epoch_train_loss/num_train_imgs:.4f} \
        || epoch_val_loss: {epoch_val_loss/num_val_imgs:.4f}')
        t_epoch_start = time.time()
        
        log_epoch = {'epoch': epoch+1, 'train_loss': epoch_train_loss/num_train_imgs,
                    'val_loss': epoch_val_loss/num_val_imgs}
        logs.append(log_epoch)
        df = pd.DataFrame(logs)
        df.to_csv('log_output.csv')
        
    torch.save(net.state_dict(), 'weights/pspnet50_' + str(epoch+1) + '.pth')
    

In [None]:
num_epochs = 30
train_model(net, dataloaders_dict, criterion, scheduler, optimizer, num_epochs=num_epochs)

Epoch 1/30
 (train) 
iteration 10 ||                             Loss: 0.3835 ||                             10iter: 14.5618sec.
iteration 20 ||                             Loss: 0.2189 ||                             10iter: 23.2726sec.
iteration 30 ||                             Loss: 0.1510 ||                             10iter: 31.8795sec.
iteration 40 ||                             Loss: 0.1658 ||                             10iter: 40.5575sec.
iteration 50 ||                             Loss: 0.0886 ||                             10iter: 49.2204sec.
iteration 60 ||                             Loss: 0.0729 ||                             10iter: 57.9064sec.
iteration 70 ||                             Loss: 0.1165 ||                             10iter: 66.6232sec.
iteration 80 ||                             Loss: 0.1351 ||                             10iter: 75.3311sec.
iteration 90 ||                             Loss: 0.2174 ||                             10iter: 84.0095sec.
iterati