In [1]:
import warnings
warnings.filterwarnings(action='ignore') 


In [2]:
import math
import time
import pandas as pd

import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim

import gc


In [3]:
from dataloader import makeDatapathList,dataTransform,VOCDataset

root_path="./data/VOCdevkit/VOC2012/"

datapath_list=makeDatapathList(root_path)
train_img_list,train_anno_list=datapath_list('train')
val_img_list,val_anno_list=datapath_list('val')

color_mean=(0.485,0.456,0.406)
color_std=(0.29,0.224,0.225)

transform=dataTransform(475,color_mean,color_std)
train_dataset=VOCDataset(train_img_list,train_anno_list,phase="train",transform=transform)
val_dataset=VOCDataset(val_img_list,val_anno_list,phase="val",transform=transform)

batch_size=4
train_dataloader=data.DataLoader(train_dataset,batch_size,True)
val_dataloader=data.DataLoader(val_dataset,batch_size,False)

dataloaders_dict={"train":train_dataloader,"val":val_dataloader}

In [4]:
from pspnet import PSPNet

net=PSPNet(n_classes=150)

state_dict=torch.load("./weights/pspnet50_ADE20K.pth") #fine tuning
net.load_state_dict(state_dict, strict=False)

n_classes=21
net.decode_feature.classification=nn.Conv2d(in_channels=512,out_channels=n_classes,kernel_size=1,stride=1,padding=0)
net.aux.classification=nn.Conv2d(in_channels=256,out_channels=n_classes,kernel_size=1,stride=1,padding=0)

def weights_init(m):
    if isinstance(m,nn.Conv2d):
        nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias,0.0)

net.decode_feature.classification.apply(weights_init)
net.aux.classification.apply(weights_init)


Conv2d(256, 21, kernel_size=(1, 1), stride=(1, 1))

In [5]:
class PSPLoss(nn.Module):
    def __init__(self,aux_weight=0.4):
        super(PSPLoss,self).__init__()
        self.aux_weight=aux_weight
    
    def forward(self,outputs,targets):
        loss=F.cross_entropy(outputs[0],targets,reduction='mean')
        loss_aux=F.cross_entropy(outputs[1],targets,reduction='mean')
        return loss+self.aux_weight*loss_aux

In [6]:
criterion=PSPLoss()

In [7]:
optimizer=optim.SGD([
    {'params':net.feature_conv.parameters(),'lr':1e-3},
    {'params':net.feature_res_1.parameters(),'lr':1e-3},
    {'params':net.feature_res_2.parameters(),'lr':1e-3},
    {'params':net.feature_dilated_res_1.parameters(),'lr':1e-3},
    {'params':net.feature_dilated_res_2.parameters(),'lr':1e-3},
    {'params':net.pyramid_pooling.parameters(),'lr':1e-3},
    {'params':net.decode_feature.parameters(),'lr':1e-3},
    {'params':net.aux.parameters(),'lr':1e-3},
],momentum=0.9,weight_decay=0.0001)

def lambda_epoch(epoch):
    max_epoch=30
    return math.pow((1-epoch/max_epoch),0.9)

scheduler=optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda_epoch)

In [8]:
def train_model(net,dataloaders_dict,criterion,scheduler,optimizer,num_epochs):
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    net.to(device)
    torch.backends.cudnn.benchmark=True

    num_train_imgs=len(dataloaders_dict["train"].dataset)
    num_val_imgs=len(dataloaders_dict["val"].dataset)
    
    batch_size=dataloaders_dict["train"].batch_size

    iteration=1
    logs=[]

    batch_multiplier=3

    for epoch in range(num_epochs):
        gc.collect()
        torch.cuda.empty_cache()
        t_epoch_start=time.time()
        t_iter_start=time.time()
        epoch_train_loss=0.0
        epoch_val_loss=0.0

        print(f"{'='*5}Epoch {epoch+1}/{num_epochs}{'='*5}")
        for phase in ["train","val"]:
            if phase=="train":
                net.train()
                scheduler.step()
                optimizer.zero_grad()
                print("[train]")
            else:
                if (epoch+1)%5==0:
                    net.eval()
                    print("[val]")
                else: 
                    continue
            count=0
            for imgs,anno_class_imgs in dataloaders_dict[phase]:
                if imgs.size()[0]==1:
                    continue
                imgs=imgs.to(device)
                anno_class_imgs=anno_class_imgs.to(device)

                if phase=='train' and count==0:
                    optimizer.step()
                    optimizer.zero_grad()
                    count=batch_multiplier
                    with torch.set_grad_enabled(phase=="train"):
                        outputs=net(imgs)
                        loss=criterion(outputs,anno_class_imgs.long())/batch_multiplier
                        if phase=="train":
                            loss.backward()
                            count-=1
                            if iteration%10==0:
                                t_iter_finish=time.time()
                                print("iter [{}] loss:{:.4f}   | {:.4f}sec".format(iteration,loss.item()/batch_size*batch_multiplier,t_iter_finish-t_iter_start))
                                t_iter_start=time.time()
                            epoch_train_loss+=loss.item()*batch_multiplier
                            iteration+=1
                        else:
                            epoch_val_loss+=loss.item()*batch_multiplier
            t_epoch_finish=time.time()
            print("epoch [{}] train loss:{:.4f}, val loss:{:.4f}  | {:.4f}sec".format(iteration,epoch_train_loss/num_train_imgs,epoch_val_loss/num_val_imgs,t_epoch_finish-t_epoch_start))
            t_epoch_start=time.time()

            log_epoch={'epoch':epoch+1,'train_loss':epoch_train_loss/num_train_imgs,'val_loss':epoch_val_loss/num_val_imgs}
            logs.append(log_epoch)
            df=pd.DataFrame(logs)
            df.to_csv('log_output.csv')
        torch.save(net.state_dict(),'weights/pspnet50_'+str(epoch+1)+'.pth')

In [9]:
num_epochs=30
train_model(net,dataloaders_dict,criterion,scheduler,optimizer,num_epochs)

cuda:0
=====Epoch 1/30=====
[train]
epoch [2] train loss:0.0032, val loss:0.0000  | 29.6834sec
=====Epoch 2/30=====
[train]
epoch [3] train loss:0.0032, val loss:0.0000  | 26.8674sec
=====Epoch 3/30=====
[train]
epoch [4] train loss:0.0033, val loss:0.0000  | 26.8150sec
=====Epoch 4/30=====
[train]
epoch [5] train loss:0.0032, val loss:0.0000  | 26.7496sec
=====Epoch 5/30=====
[train]
epoch [6] train loss:0.0034, val loss:0.0000  | 26.8550sec
[val]
epoch [6] train loss:0.0034, val loss:0.0000  | 14.2387sec
=====Epoch 6/30=====
[train]
epoch [7] train loss:0.0032, val loss:0.0000  | 26.7391sec
=====Epoch 7/30=====
[train]
epoch [8] train loss:0.0033, val loss:0.0000  | 26.7467sec
=====Epoch 8/30=====
[train]
epoch [9] train loss:0.0031, val loss:0.0000  | 26.6929sec
=====Epoch 9/30=====
[train]
epoch [10] train loss:0.0034, val loss:0.0000  | 26.7400sec
=====Epoch 10/30=====
[train]
iter [10] loss:1.2985   | 0.1271sec
epoch [11] train loss:0.0035, val loss:0.0000  | 26.9709sec
[val]
epo