In [2]:
import os
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torchvision.transforms as standard_transforms
from skimage import io, transform
from sklearn.model_selection import train_test_split


from data_loader import TextSegDataset
import numpy as np
import glob
import os

from matplotlib import pyplot as plt
from u2net import U2NET

from PIL import Image

import time
import datetime

from PIL import ImageFile

In [11]:
images = []
masks = []

images_path = 'data/image/'
masks_path = 'data/semantic_label/'
img_size = 700

for path in glob.glob(images_path + '*'):
    # img_tensor = torch.from_numpy(io.imread(path))
    # img_resized = T.Resize((img_size, img_size))(img_tensor)
    # images.append(img_resized)
    images.append(io.imread(path))

for path in glob.glob(masks_path + '*'):
    # mask_tensor = torch.from_numpy(io.imread(path))
    # mask_resized = T.Resize((img_size, img_size))(mask_tensor)
    # masks.append(mask_resized)
    masks.append(io.imread(path))

images = np.array(images)
masks = np.array(masks)


  images = np.array(images)
  masks = np.array(masks)


In [3]:
images_path = 'data/image/'
masks_path = 'data/semantic_label/'

images_path_list=glob.glob(images_path + '*')
mask_path_list=glob.glob(masks_path + '*')

original_dataset = TextSegDataset(
    images_paths=images_path_list,
     masks_path=mask_path_list,
    transform=False
)

modified_dataset = TextSegDataset(
    images_paths=images_path_list,
    masks_path=mask_path_list,
    transform=True
)


augmented_dataset=original_dataset+modified_dataset

In [8]:

epoch_num = 1000
batch_size = 4
test_batch_size = 2
train_num = len(images_path_list)
validation_split = 0.15

train_dataset_size = int(2 * train_num * (1 - validation_split))
test_dataset_size = 2 * train_num - train_dataset_size

train_dataset, test_dataset = torch.utils.data.random_split(augmented_dataset, (train_dataset_size, test_dataset_size))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=1)

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)


net = U2NET(3, 1)
net.to(device)

model_name = f'u2net_{datetime.datetime.now().date()}'

checkpoint_name = 'u2net_2021-12-16_epoch_19_train_0.7557503520919565_test_0.7732233350837467.pth'
checkpoint_name = False
folder_name = 'saved_models_rf/'
if checkpoint_name:
    net.load_state_dict(torch.load(folder_name + checkpoint_name, map_location=torch.device(device)))


cuda


In [6]:
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    loss0 =  nn.BCELoss(size_average=True)(d0, labels_v)
    loss1 =  nn.BCELoss(size_average=True)(d1, labels_v)
    loss2 =  nn.BCELoss(size_average=True)(d2, labels_v)
    loss3 =  nn.BCELoss(size_average=True)(d3, labels_v)
    loss4 =  nn.BCELoss(size_average=True)(d4, labels_v)
    loss5 =  nn.BCELoss(size_average=True)(d5, labels_v)
    loss6 =  nn.BCELoss(size_average=True)(d6, labels_v)

    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    # print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
    #     loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(),
    #     loss5.data.item(),
    #     loss6.data.item()))

    return loss0, loss

In [None]:
# optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
optimizer = optim.RMSprop(net.parameters(), lr=0.001, eps=1e-08, weight_decay=0)

for epoch in range(0, epoch_num):
    net.train()
    train_loss = 0
    test_loss = 0

    for i, train_data in enumerate(train_dataloader):
        start_time = time.time()

        train_inputs = train_data['image'].to(device)
        train_labels = train_data['mask'].to(device)

        # y zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        d0, d1, d2, d3, d4, d5, d6 = net(train_inputs)

        _, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, train_labels)

        loss.backward()

        optimizer.step()

        # # print statistics
        train_loss += loss.data.item()

        end_time = time.time()
        eta = (end_time - start_time) * (train_dataset_size - (i + 1) * batch_size) / batch_size
        print(
            f"epoch: {epoch + 1}/{epoch_num} eta:{int(eta)} s batch: {(i + 1)}/{int(train_dataset_size / batch_size)},"
            f" loss: {train_loss / (i + 1)} ")

    print('testing')
    with torch.no_grad():
        for j, test_data in enumerate(test_dataloader):
            test_inputs = test_data['image'].to(device)
            test_labels = test_data['label'].to(device)

            d0, d1, d2, d3, d4, d5, d6 = net(test_inputs)
            _, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, test_labels)

            # # print statistics
            test_loss += loss.data.item()

    print(
        f"epoch: {epoch + 1}/{epoch_num} loss: {train_loss * batch_size / train_dataset_size} test loss: {test_loss * batch_size / test_dataset_size} ")

    torch.save(net.state_dict(),
               folder_name + model_name + f"_epoch_{epoch}_train_{train_loss * batch_size / train_dataset_size}_test_{test_loss * batch_size / test_dataset_size}.pth")
    # train_loss = 0.0
    net.train()  # resume train
#   train_step_save = 0


epoch: 1/1000 eta:1205 s batch: 1/1710, loss: 3.8795864582061768 
epoch: 1/1000 eta:1187 s batch: 2/1710, loss: 19.837490677833557 
epoch: 1/1000 eta:1193 s batch: 3/1710, loss: 16.24286373456319 
epoch: 1/1000 eta:1194 s batch: 4/1710, loss: 13.126161396503448 
epoch: 1/1000 eta:1185 s batch: 5/1710, loss: 11.456288766860961 
epoch: 1/1000 eta:1191 s batch: 6/1710, loss: 10.242740750312805 
epoch: 1/1000 eta:1190 s batch: 7/1710, loss: 9.026218857084002 
epoch: 1/1000 eta:1187 s batch: 8/1710, loss: 8.675919145345688 
epoch: 1/1000 eta:1191 s batch: 9/1710, loss: 8.201924827363756 
epoch: 1/1000 eta:1188 s batch: 10/1710, loss: 7.615155959129334 
epoch: 1/1000 eta:1189 s batch: 11/1710, loss: 7.440383065830577 
epoch: 1/1000 eta:1180 s batch: 12/1710, loss: 6.937845607598622 
epoch: 1/1000 eta:1190 s batch: 13/1710, loss: 6.468071213135352 
epoch: 1/1000 eta:1206 s batch: 14/1710, loss: 6.033479814018522 
epoch: 1/1000 eta:1192 s batch: 15/1710, loss: 5.721175642808278 
epoch: 1/1000 