In [1]:
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms

import numpy as np
import glob

from data_loader import Rescale
from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import CenterCrop
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import BASNet

import pytorch_ssim
import pytorch_iou

In [2]:
# ------- 1. define loss function --------

bce_loss = nn.BCELoss(size_average=True)
ssim_loss = pytorch_ssim.SSIM(window_size=11,size_average=True)
iou_loss = pytorch_iou.IOU(size_average=True)


reduction='mean'

def bce_ssim_loss(pred,target):

	bce_out = bce_loss(pred,target)
	ssim_out = 1 - ssim_loss(pred,target)
	iou_out = iou_loss(pred,target)

	loss = bce_out + ssim_out + iou_out

	return loss

def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, d7, labels_v):

	loss0 = bce_ssim_loss(d0,labels_v)
	loss1 = bce_ssim_loss(d1,labels_v)
	loss2 = bce_ssim_loss(d2,labels_v)
	loss3 = bce_ssim_loss(d3,labels_v)
	loss4 = bce_ssim_loss(d4,labels_v)
	loss5 = bce_ssim_loss(d5,labels_v)
	loss6 = bce_ssim_loss(d6,labels_v)
	loss7 = bce_ssim_loss(d7,labels_v)
	#ssim0 = 1 - ssim_loss(d0,labels_v)

	# iou0 = iou_loss(d0,labels_v)
	#loss = torch.pow(torch.mean(torch.abs(labels_v-d0)),2)*(5.0*loss0 + loss1 + loss2 + loss3 + loss4 + loss5) #+ 5.0*lossa
	loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7#+ 5.0*lossa
	#print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.item(),loss1.item(),loss2.item(),loss3.item(),loss4.item(),loss5.item(),loss6.item()))	# print("BCE: l1:%3f, l2:%3f, l3:%3f, l4:%3f, l5:%3f, la:%3f, all:%3f\n"%(loss1.data[0],loss2.data[0],loss3.data[0],loss4.data[0],loss5.data[0],lossa.data[0],loss.data[0]))

	return loss0, loss



In [3]:
# ------- 2. set the directory of training dataset --------

data_dir = './train_data/'
#tra_image_dir = 'DUTS/DUTS-TR/DUTS-TR/im_aug/'
#tra_image_dir = 'DUTS/image/'
tra_image_dir = 'NEW_DATA/image/'

#tra_label_dir = 'DUTS/DUTS-TR/DUTS-TR/gt_aug/'
#tra_label_dir = 'DUTS/mask/'
tra_label_dir = 'NEW_DATA/mask/'

image_ext = '.jpg'
label_ext = '.png'

model_dir = "./saved_models/"

PATH = "./saved_models/basnet.pth"

epoch_num = 200
batch_size_train = 8 
batch_size_val = 1
train_num = 0
val_num = 0

tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)


In [4]:
! ls ./saved_models/

basnet.pth  optimized_model.pth


In [5]:
tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)

tra_lbl_name_list = []
for img_path in tra_img_name_list:
	img_name = img_path.split("/")[-1]

	aaa = img_name.split(".")
	bbb = aaa[0:-1]
	imidx = bbb[0]
	for i in range(1,len(bbb)):
		imidx = imidx + "." + bbb[i]

	tra_lbl_name_list.append(data_dir + tra_label_dir + imidx + label_ext)

print("---")
print("train images: ", len(tra_img_name_list))
print("train labels: ", len(tra_lbl_name_list))
print("---")

---
train images:  3310
train labels:  3310
---


In [6]:
valid_size = round(len(tra_img_name_list) * 0.2)
train_size = len(tra_img_name_list) - valid_size

In [7]:
[train_size, train_size]

[2648, 2648]

In [8]:
train_num = len(tra_img_name_list)

salobj_dataset = SalObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([
        RescaleT(256),
        RandomCrop(224),
        ToTensorLab(flag=0)]))

train_data, val_data = torch.utils.data.random_split(salobj_dataset, [train_size, valid_size])

train_loader = DataLoader(train_data, batch_size=batch_size_train, shuffle=True, num_workers=0)
valid_loader = DataLoader(val_data, batch_size=batch_size_train, shuffle=True, num_workers=0)


In [9]:
# ------- 3. define model --------
# define the net
net = BASNet(3, 1)
net.load_state_dict(torch.load(PATH))

if torch.cuda.is_available():
    net.cuda()

In [10]:
# ------- 4. define optimizer --------
print("---define optimizer...")
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

# ------- 5. training process --------
print("---start training...")
ite_num = 0
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0


ite_num_valid = 0
running_loss_valid = 0.0

running_tar_loss_valid = 0.0
ite_num4val_valid = 0

valid_loss_min = np.Inf


---define optimizer...
---start training...


In [None]:
for epoch in range(0, epoch_num):

    net.train()
    for i, data in enumerate(train_loader):
        ite_num = ite_num + 1
        ite_num4val = ite_num4val + 1

        inputs, labels = data['image'], data['label']

        inputs = inputs.type(torch.FloatTensor)
        labels = labels.type(torch.FloatTensor)

        # wrap them in Variable
        if torch.cuda.is_available():
            inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(),
                                                                                        requires_grad=False)
        else:
            inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)

        # y zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        d0, d1, d2, d3, d4, d5, d6, d7 = net(inputs_v)
        loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, d7, labels_v)

        loss.backward()
        optimizer.step()

        # # print statistics
        running_loss += loss.item()
        running_tar_loss += loss2.item()

        # del temporary outputs and loss
        del d0, d1, d2, d3, d4, d5, d6, d7, loss2, loss

        print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f ]" % (
        epoch + 1, epoch_num, (i + 1) * batch_size_train, train_size, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
        
        
        net.eval()
        for i, data in enumerate(valid_loader):
            
            ite_num_valid = ite_num + 1
            ite_num4val_valid = ite_num4val + 1

            inputs, labels = data['image'], data['label']
        
            inputs = inputs.type(torch.FloatTensor)
            labels = labels.type(torch.FloatTensor)

            # wrap them in Variable
            if torch.cuda.is_available():
                inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(),
                                                                                            requires_grad=False)
            else:
                inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)

            # forward + backward + optimize
            d0, d1, d2, d3, d4, d5, d6, d7 = net(inputs_v)
            loss2_val, loss_val = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, d7, labels_v)
            
            
            # # print statistics
            running_loss_valid += loss_val.item()
            running_tar_loss_valid += loss2_val.item()

        
        # del temporary outputs and loss
        del d0, d1, d2, d3, d4, d5, d6, d7, loss2_val, loss_val
        
        valid_loss = running_tar_loss_valid / len(valid_loader.dataset)
        

        
        print("[Valid loss: %3f, tar: %3f ]" % (running_loss_valid / len(valid_loader.dataset),
                                               valid_loss))

        print('*************')
        
        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
            valid_loss_min,
            valid_loss))
            torch.save(net.state_dict(), model_dir + 'optimized_model.pth')
            valid_loss_min = valid_loss
            
            running_loss = 0.0
            running_tar_loss = 0.0
            net.train()  # resume train
            ite_num4val = 0
            print('##########')
            
            

        
        #if ite_num % 2000 == 0:  # save model every 2000 iterations

         #   torch.save(net.state_dict(), model_dir + "basnet_bsi_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
         #   running_loss = 0.0
         #   running_tar_loss = 0.0
         #   net.train()  # resume train
         #   ite_num4val = 0
        
        
        #if ite_num % 2000 == 0:  # save model every 2000 iterations

            #torch.save(net.state_dict(), model_dir + "basnet_bsi_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
            #running_loss = 0.0
            #running_tar_loss = 0.0
           
            #ite_num4val = 0

print('-------------Congratulations! Training Done!!!-------------')


  "See the documentation of nn.Upsample for details.".format(mode))


[epoch:   1/200, batch:     8/ 2648, ite: 1] train loss: 8.998521, tar: 0.943130 ]
[Valid loss: 1.222583, tar: 0.111891 ]
*************
Validation loss decreased (inf --> 0.111891).  Saving model ...
##########
[epoch:   1/200, batch:    16/ 2648, ite: 2] train loss: 7.917761, tar: 0.845558 ]
[Valid loss: 2.381782, tar: 0.219096 ]
*************
[epoch:   1/200, batch:    24/ 2648, ite: 3] train loss: 11.055979, tar: 1.100870 ]
[Valid loss: 3.459262, tar: 0.315325 ]
*************
[epoch:   1/200, batch:    32/ 2648, ite: 4] train loss: 10.583622, tar: 1.054717 ]
[Valid loss: 4.449437, tar: 0.400819 ]
*************
[epoch:   1/200, batch:    40/ 2648, ite: 5] train loss: 9.086501, tar: 0.891874 ]
[Valid loss: 5.355346, tar: 0.480542 ]
*************
[epoch:   1/200, batch:    48/ 2648, ite: 6] train loss: 8.429366, tar: 0.815002 ]
[Valid loss: 6.215232, tar: 0.556442 ]
*************
[epoch:   1/200, batch:    56/ 2648, ite: 7] train loss: 8.349116, tar: 0.793718 ]
[Valid loss: 7.069030, t

In [None]:
64 * 27