In [1]:
import os
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms

from sklearn.model_selection import train_test_split

import numpy as np
import glob
import os

from data_loader import Rescale
from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from u2net import U2NET

import time
import datetime
import gdown


In [33]:
# Shared criterion for every U^2-Net output. `reduction='mean'` is the current
# spelling of the deprecated `size_average=True` (identical averaging).
bce_loss = nn.BCELoss(reduction='mean')


def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    """Deep-supervision loss: sum of BCE losses of seven predictions vs one mask.

    Each of ``d0``..``d6`` is scored against the same ``labels_v`` with
    ``bce_loss`` (mean-reduced BCE), and the seven losses are summed with equal
    weight.

    Args:
        d0: Prediction whose loss is also returned on its own (presumably the
            fused U^2-Net output — confirm against u2net.py).
        d1, d2, d3, d4, d5, d6: Remaining (side-output) predictions.
            NOTE(review): BCELoss requires probabilities in [0, 1] with the same
            shape as ``labels_v``; this assumes the network already applies a
            sigmoid.
        labels_v: Ground-truth mask tensor with values in [0, 1].

    Returns:
        tuple: ``(loss0, loss)`` — the BCE of ``d0`` alone, and the sum of all
        seven BCE losses (the value to back-propagate).
    """
    per_output = [bce_loss(pred, labels_v) for pred in (d0, d1, d2, d3, d4, d5, d6)]
    loss0 = per_output[0]
    loss = sum(per_output)
    return loss0, loss

In [24]:
# Run name, stamped with today's date (e.g. u2net_2024-01-31); used both as the
# checkpoint sub-directory and as the checkpoint file-name prefix.
model_name = f'u2net_{datetime.datetime.now().date()}'

# Training images and their segmentation masks.
images_path = 'textseg/image/'
masks_path = 'textseg/semantic_label/'

# Keep a trailing separator so that appending a file name with `+` lands
# *inside* this directory instead of producing 'saved_models/<name><file>'.
model_dir = os.path.join('saved_models', model_name, '')
# torch.save does not create missing directories.
os.makedirs(model_dir, exist_ok=True)



In [35]:
# glob() returns files in arbitrary, filesystem-dependent order. Images and
# masks are passed as two parallel lists, so sort both to make the pairing
# deterministic.
# NOTE(review): pairing is positional — this assumes each mask's file name
# sorts into the same position as its image's (e.g. shared stems). Confirm
# against the textseg naming scheme.
tra_img_name_list = sorted(glob.glob(images_path + '*'))
tra_lbl_name_list = sorted(glob.glob(masks_path + '*'))

assert tra_img_name_list, f'no training images found under {images_path!r}'
assert len(tra_img_name_list) == len(tra_lbl_name_list), (
    f'{len(tra_img_name_list)} images vs {len(tra_lbl_name_list)} masks')

# Augmentation pipeline: presumably rescale to 320, random-crop to 288, then
# convert to tensors (ToTensorLab flag=0) — see data_loader.py for exact semantics.
salobj_dataset = SalObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([
        RescaleT(320),
        RandomCrop(288),
        ToTensorLab(flag=0)]))



In [36]:
# Instantiate U^2-Net. The (3, 1) arguments are presumably (in_ch, out_ch),
# i.e. RGB in and a single mask channel out — confirm against u2net.py.
net = U2NET(3, 1)
# Module.cuda() moves parameters in place and returns the module itself,
# so rebinding `net` is equivalent to calling it for its side effect.
net = net.cuda() if torch.cuda.is_available() else net

In [38]:

# Training-loop state. `ite_num` counts optimizer steps over the whole run;
# `ite_num4val` counts steps since the last checkpoint and is the denominator
# for the running averages printed and embedded in checkpoint names.
ite_num = 0
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
save_frq = 2000  # save the model every 2000 iterations

epoch_num = 1000
batch_size_train = 24
train_num = len(tra_img_name_list)
batch_size_val = 1  # NOTE(review): validation settings are unused in this cell
val_num = 0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
os.makedirs(model_dir, exist_ok=True)  # torch.save does not create directories

salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)

optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

for epoch in range(epoch_num):
    net.train()

    for i, data in enumerate(salobj_dataloader):
        ite_num += 1
        ite_num4val += 1

        # Cast to float32 and move to the training device. Plain tensors replace
        # the deprecated `Variable` wrapper (which has been a no-op since 0.4).
        inputs_v = data['image'].to(device=device, dtype=torch.float32)
        labels_v = data['label'].to(device=device, dtype=torch.float32)

        optimizer.zero_grad()

        # forward + backward + optimize
        d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
        loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Target (d0-only) loss; previously never accumulated, so the `tar`
        # value written into checkpoint names was always 0.
        running_tar_loss += loss2.item()

        # Drop references to this batch's outputs before the next batch.
        del d0, d1, d2, d3, d4, d5, d6, loss2, loss

        # The last batch of an epoch may be partial; don't report more samples
        # than exist.
        seen = min((i + 1) * batch_size_train, train_num)
        print(
            f"epoch: {epoch + 1}/{epoch_num}, batch: {seen}/{train_num}, ite: {ite_num} train loss: {running_loss / ite_num4val} ")

        if ite_num % save_frq == 0:
            # os.path.join guarantees a directory boundary between model_dir
            # and the file name.
            checkpoint_path = os.path.join(
                model_dir,
                model_name + "_bce_itr_%d_train_%3f_tar_%3f.pth" % (
                    ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
            torch.save(net.state_dict(), checkpoint_path)
            running_loss = 0.0
            running_tar_loss = 0.0
            net.train()  # resume train
            ite_num4val = 0


epoch: 1/1000, batch: 24/4024, ite: 1 train loss: 2.136038303375244 
epoch: 1/1000, batch: 48/4024, ite: 2 train loss: 2.7362329959869385 
epoch: 1/1000, batch: 72/4024, ite: 3 train loss: 2.462747891743978 
epoch: 1/1000, batch: 96/4024, ite: 4 train loss: 2.3425833880901337 
epoch: 1/1000, batch: 120/4024, ite: 5 train loss: 2.2762691736221314 
epoch: 1/1000, batch: 144/4024, ite: 6 train loss: 2.1904390255610147 
epoch: 1/1000, batch: 168/4024, ite: 7 train loss: 2.130296298435756 
epoch: 1/1000, batch: 192/4024, ite: 8 train loss: 2.1250045597553253 
epoch: 1/1000, batch: 216/4024, ite: 9 train loss: 2.1028853787316217 
epoch: 1/1000, batch: 240/4024, ite: 10 train loss: 2.0659117221832277 
epoch: 1/1000, batch: 264/4024, ite: 11 train loss: 2.0404987226833 
epoch: 1/1000, batch: 288/4024, ite: 12 train loss: 2.0281540354092917 
epoch: 1/1000, batch: 312/4024, ite: 13 train loss: 1.999640024625338 
epoch: 1/1000, batch: 336/4024, ite: 14 train loss: 1.9748997858592443 
epoch: 1/100

KeyboardInterrupt: 