In [4]:
import os
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms

import numpy as np
import glob
import os
import pickle
import sys

In [20]:
# Put the project directory first on sys.path so the project-local modules
# `data_loader` and `model` imported below can be found.
# NOTE(review): this is a hardcoded absolute path to one user's machine; the
# notebook will not import these modules anywhere else — consider making the
# location configurable.
sys.path.insert(0, '/Users/narekgeghamyan/Classes/MLE_bootcamp/Capstone_Project')
# Transform and dataset classes from the project's data_loader module.
# NOTE(review): a later cell instantiates `SalObjDataset_modified`, which is not
# imported here — confirm where that class is supposed to come from.
from data_loader import Rescale
from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

# Network class from the project's model module.
from model import U2NET

### Load data

In [7]:
# Switch the working directory so the relative .pkl file names opened in the
# following cells resolve against this data folder.
# NOTE(review): hardcoded absolute, machine-specific path — consider a
# configurable data directory instead of changing the process-wide cwd.
os.chdir('/Users/narekgeghamyan/local_data/capstone_data')

In [8]:
# Unpickle u2net_684_train_images.pkl into trn_images. The context manager
# closes the file handle even if unpickling raises (the previous bare open()
# left it open for the life of the kernel).
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('u2net_684_train_images.pkl', 'rb') as filehandler:
    trn_images = pickle.load(filehandler)

In [9]:
# Unpickle u2net_684_seg_images.pkl into seg_images. The context manager
# closes the file handle even if unpickling raises (the previous bare open()
# left it open for the life of the kernel).
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('u2net_684_seg_images.pkl', 'rb') as filehandler:
    seg_images = pickle.load(filehandler)

In [10]:
# Unpickle u2net_684_img_name_list.pkl into img_name_list. The context manager
# closes the file handle even if unpickling raises (the previous bare open()
# left it open for the life of the kernel).
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('u2net_684_img_name_list.pkl', 'rb') as filehandler:
    img_name_list = pickle.load(filehandler)

In [11]:
# Sanity check: display how many items were loaded into trn_images.
len(trn_images)

684

### Set-up Data

In [None]:
# Number of loaded training items; the training loop prints it as the
# denominator of its "batch: x/N" progress field.
train_num = len(trn_images)

In [25]:
# Batch size used by the training DataLoader.
batch_size_train = 12
# Batch size reserved for validation (not used by the cells shown here).
batch_size_val = 1
# Sample counts. train_num is the denominator of the "batch: x/N" progress
# field printed by the training loop, so it must hold the real number of
# training items; resetting it to 0 here made every progress line read "x/0".
train_num = len(trn_images)
val_num = 0

In [None]:
# Build the training dataset with the transform pipeline
# RescaleT(320) -> RandomCrop(288) -> ToTensorLab(flag=0).
# NOTE(review): `SalObjDataset_modified` and `tra_lbl_name_list` are not defined
# or imported in any cell shown in this notebook (only `SalObjDataset` is
# imported), so this cell depends on hidden kernel state and will raise
# NameError after Restart & Run All — confirm where they are defined.
# NOTE(review): the `trn_images` / `seg_images` objects loaded earlier are not
# passed in here — confirm the dataset reads the data you intend.
salobj_dataset = SalObjDataset_modified(
    img_name_list=img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([
        RescaleT(320),
        RandomCrop(288),
        ToTensorLab(flag=0)]))

# Shuffled mini-batches of batch_size_train samples, loaded by one worker process.
salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)

### Define Loss Function

In [30]:
# ------- 1. define loss function --------

# Criterion shared by all seven network outputs (mean-reduced cross entropy).
# An earlier variant used nn.BCELoss(reduction='mean') instead.
loss_fxn = nn.CrossEntropyLoss(reduction='mean')


def loss_fuction(d0, d1, d2, d3, d4, d5, d6, labels_v):
    """Score the seven network outputs against one target and sum the losses.

    Each of d0..d6 is passed to ``loss_fxn`` together with ``labels_v``. The
    seven individual values are printed on every call.

    NOTE(review): nn.CrossEntropyLoss expects raw logits and either integer
    class-index targets or floating-point class-probability targets with the
    same shape as the input; the training loop passes float labels — confirm
    the label tensor matches one of those forms.

    Args:
        d0, d1, d2, d3, d4, d5, d6: the seven prediction tensors.
        labels_v: target tensor, forwarded unchanged to ``loss_fxn``.

    Returns:
        tuple: ``(loss0, loss)`` — the loss of ``d0`` alone and the sum of all
        seven losses.
    """
    side_losses = [loss_fxn(pred, labels_v) for pred in (d0, d1, d2, d3, d4, d5, d6)]

    # Accumulate left to right, starting from the d0 loss.
    total = side_losses[0]
    for extra in side_losses[1:]:
        total = total + extra

    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % tuple(
        term.item() for term in side_losses))

    return side_losses[0], total

### Define model

In [24]:
# Instantiate the network selected by model_name.
model_name = 'u2net'
if model_name == 'u2net':
    # 3 input channels, 4 output channels (an earlier variant used U2NET(3, 1)).
    net = U2NET(in_ch=3, out_ch=4)

# Move the model's parameters to the GPU when CUDA is usable.
if torch.cuda.is_available():
    net.cuda()
    print('Cuda available..')

### Define optimizer

In [28]:
# Adam over every parameter of the network, with its hyper-parameters spelled
# out explicitly so they are easy to tune.
optimizer = optim.Adam(
    net.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
)

### Train Model

In [None]:
epoch_num = 3           # number of passes over the training DataLoader
ite_num = 0             # total training iterations performed so far
running_loss = 0.0      # summed total loss since the last checkpoint
running_tar_loss = 0.0  # summed d0 ("target") loss since the last checkpoint
ite_num4val = 0         # iterations since the last checkpoint (averaging window)
save_frq = 3            # save a checkpoint every 3 iterations

# Directory the training loop writes checkpoints into. The loop builds the file
# name as `model_dir + model_name + ...`, so the path must end with a separator.
# It was not defined in any visible cell, which made the first checkpoint save
# fail with NameError on a fresh kernel.
# NOTE(review): the 'saved_models/<model_name>/' location under the current
# working directory is a default chosen here — adjust if checkpoints belong elsewhere.
model_dir = os.path.join(os.getcwd(), 'saved_models', model_name) + os.sep
os.makedirs(model_dir, exist_ok=True)

In [None]:
# Training loop: for every batch run forward/backward/step, print running
# averages of the total loss and the d0 ("tar") loss, and save a checkpoint
# every `save_frq` iterations (the running averages restart after each save).
use_cuda = torch.cuda.is_available()  # loop-invariant, so checked once

for epoch in range(epoch_num):
    net.train()

    for i, data in enumerate(salobj_dataloader):
        ite_num += 1
        ite_num4val += 1

        # Cast both tensors to float32 on the CPU first.
        # NOTE(review): the loss is nn.CrossEntropyLoss, which takes either
        # integer class-index targets or float class-probability targets shaped
        # like the prediction — confirm float labels are what is intended.
        inputs = data['image'].type(torch.FloatTensor)
        labels = data['label'].type(torch.FloatTensor)

        # Plain tensors are used directly: the torch.autograd.Variable wrapper
        # is deprecated and tensors do not require grad by default.
        if use_cuda:
            inputs_v, labels_v = inputs.cuda(), labels.cuda()
        else:
            inputs_v, labels_v = inputs, labels

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # forward + backward + optimize
        d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
        loss2, loss = loss_fuction(d0, d1, d2, d3, d4, d5, d6, labels_v)

        loss.backward()
        optimizer.step()

        # Accumulate plain Python floats so no autograd graph is kept alive.
        running_loss += loss.item()
        running_tar_loss += loss2.item()

        # Drop references to the outputs and losses so their memory can be freed.
        del d0, d1, d2, d3, d4, d5, d6, loss2, loss

        avg_loss = running_loss / ite_num4val
        avg_tar_loss = running_tar_loss / ite_num4val

        print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
            epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, avg_loss, avg_tar_loss))

        if ite_num % save_frq == 0:
            # The checkpoint name records the iteration and both running averages.
            checkpoint_name = model_name + "_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, avg_loss, avg_tar_loss)
            torch.save(net.state_dict(), model_dir + checkpoint_name)

            # Start a fresh averaging window after each checkpoint.
            running_loss = 0.0
            running_tar_loss = 0.0
            ite_num4val = 0