In [None]:
import os
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torchvision.transforms as standard_transforms
from skimage import io, transform
from sklearn.model_selection import train_test_split

from data_loader import TextSegDataset
import numpy as np
import glob
import os

from matplotlib import pyplot as plt
from u2net import U2NET

from PIL import Image

import time
import datetime

from PIL import ImageFile

In [None]:
# Eagerly read every image and every semantic-label mask into memory.
images_path = 'data/image/'
masks_path = 'data/semantic_label/'
img_size = 500  # intended target side length; no resizing is applied in this cell

# glob.glob() gives no ordering guarantee (it depends on the file system), so
# both listings are sorted to make the result reproducible and to keep
# images[i] next to masks[i].
# NOTE(review): this assumes image and mask file names sort alike - confirm.
images = [io.imread(path) for path in sorted(glob.glob(images_path + '*'))]
masks = [io.imread(path) for path in sorted(glob.glob(masks_path + '*'))]

# NOTE(review): the files are not resized here, so np.array() only produces a
# regular stacked array if all of them share one shape - confirm for this data.
images = np.array(images)
masks = np.array(masks)


In [None]:
# Build the training data: every sample once with transform=False and once
# with transform=True.
images_path = 'data/image/'
masks_path = 'data/semantic_label/'

# glob.glob() returns paths in arbitrary, file-system dependent order. Sorting
# both listings keeps the i-th image aligned with the i-th mask and makes the
# seeded random split further down reproducible across machines.
# NOTE(review): assumes image and mask file names sort alike - confirm.
images_path_list = sorted(glob.glob(images_path + '*'))
mask_path_list = sorted(glob.glob(masks_path + '*'))

original_dataset = TextSegDataset(
    images_paths=images_path_list,
    masks_path=mask_path_list,
    transform=False
)

# Same files again; transform=True presumably enables the dataset's
# augmentation - see data_loader.TextSegDataset to confirm.
modified_dataset = TextSegDataset(
    images_paths=images_path_list,
    masks_path=mask_path_list,
    transform=True
)

# `+` on two datasets joins them (presumably into a torch ConcatDataset via
# Dataset.__add__), so the result holds both variants of every sample.
augmented_dataset = original_dataset + modified_dataset

In [None]:

# Training hyper-parameters.
epoch_num = 1000
batch_size = 10
test_batch_size = 8
train_num = len(images_path_list)
validation_split = 0.15  # fraction of the samples held out for testing

# Size the split from the dataset that is actually being split instead of
# assuming it holds exactly 2 * train_num samples: random_split() raises when
# the requested lengths do not add up to len(dataset).
total_dataset_size = len(augmented_dataset)
train_dataset_size = int(total_dataset_size * (1 - validation_split))
test_dataset_size = total_dataset_size - train_dataset_size

# A fixed generator seed keeps the train/test membership identical between runs.
train_dataset, test_dataset = torch.utils.data.random_split(
    augmented_dataset,
    (train_dataset_size, test_dataset_size),
    generator=torch.Generator().manual_seed(721),
)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
# Evaluation only accumulates a loss over all batches, so shuffling buys nothing.
test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=1)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# U2NET(3, 1): presumably 3 input channels and 1 output channel - see u2net.py.
net = U2NET(3, 1)
net.to(device)

# Prefix for the checkpoints written during training, stamped with today's date.
model_name = f'u2net_{datetime.date.today()}'

folder_name = 'saved_models_rg/'
# Set checkpoint_name to False to skip loading and keep the freshly
# constructed weights.
checkpoint_name = 'u2net_2021-12-23_epoch_0_train_0.5615894327386778_test_0.4500891561835807.pth'
# checkpoint_name = False
if checkpoint_name:
    net.load_state_dict(
        torch.load(folder_name + checkpoint_name, map_location=device)
    )


In [None]:
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    """Sum the binary cross-entropy of seven predictions against one target.

    Args:
        d0, d1, d2, d3, d4, d5, d6: prediction tensors. Each one is passed to
            nn.BCELoss together with ``labels_v``, so it has to hold
            probabilities in [0, 1] and match the shape of ``labels_v``.
        labels_v: target tensor shared by all seven predictions.

    Returns:
        A ``(loss0, loss)`` tuple: the mean BCE of ``d0`` alone, and the sum of
        the mean BCE of all seven predictions.
    """
    # reduction='mean' is the documented replacement for the deprecated
    # size_average=True argument and yields the same value.
    criterion = nn.BCELoss(reduction='mean')
    losses = [criterion(prediction, labels_v)
              for prediction in (d0, d1, d2, d3, d4, d5, d6)]

    # Accumulate left to right, i.e. loss0 + loss1 + ... + loss6.
    loss = losses[0]
    for side_loss in losses[1:]:
        loss = loss + side_loss

    return losses[0], loss

In [None]:
# Alternative optimizers (disabled):
#optimizer = optim.Adagrad(net.parameters(), lr=0.00001, eps=1e-08, weight_decay=1e-4)
#optimizer = optim.RMSprop(net.parameters(), lr=0.00001, eps=1e-08, weight_decay=1e-6, centered=True)
optimizer = optim.Adam(net.parameters(), lr=0.0001, eps=1e-08, weight_decay=1e-6, amsgrad=True)

# A checkpoint is written into folder_name after every epoch; create the
# folder up front so the first torch.save() cannot fail on a missing directory.
if folder_name:
    os.makedirs(folder_name, exist_ok=True)

# len(DataLoader) is the real number of batches, including a smaller last one.
train_batch_count = len(train_dataloader)
test_batch_count = len(test_dataloader)

for epoch in range(0, epoch_num):
    net.train()
    train_loss = 0
    test_loss = 0

    for i, train_data in enumerate(train_dataloader):
        start_time = time.time()

        train_inputs = train_data['image'].to(device)
        train_labels = train_data['mask'].to(device)

        # Zero the parameter gradients left over from the previous step.
        optimizer.zero_grad()

        # forward + backward + optimize
        d0, d1, d2, d3, d4, d5, d6 = net(train_inputs)
        _, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, train_labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Estimate the time left in this epoch from the duration of this step
        # and the number of batches still to go (never negative).
        eta = (time.time() - start_time) * (train_batch_count - (i + 1))
        print(
            f"epoch: {epoch + 1}/{epoch_num} eta:{int(eta)} s batch: {(i + 1)}/{train_batch_count},"
            f" loss: {train_loss / (i + 1)} ")

    print('testing')
    # Switch to inference behaviour for the evaluation pass; without this the
    # test loss is computed with the network still in training mode.
    net.eval()
    with torch.no_grad():
        for test_data in test_dataloader:
            test_inputs = test_data['image'].to(device)
            test_labels = test_data['mask'].to(device)

            d0, d1, d2, d3, d4, d5, d6 = net(test_inputs)
            _, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, test_labels)

            test_loss += loss.item()

    # Average per batch, each loader using its own batch count. The original
    # scaled the test loss by the training batch size.
    mean_train_loss = train_loss / max(train_batch_count, 1)
    mean_test_loss = test_loss / max(test_batch_count, 1)

    print(
        f"epoch: {epoch + 1}/{epoch_num} loss: {mean_train_loss} test loss: {mean_test_loss} ")

    # The file name keeps the zero-based epoch index, as in existing checkpoints.
    torch.save(net.state_dict(),
               folder_name + model_name + f"_epoch_{epoch}_train_{mean_train_loss}_test_{mean_test_loss}.pth")
    net.train()  # resume train

In [None]:
def save_output(img, mask):
    """Turn a predicted mask into the alpha channel of its source image.

    Despite the name, nothing is written to disk; the two images are returned.

    Args:
        img: image as a NumPy array accepted by ``PIL.Image.fromarray``.
        mask: torch tensor that squeezes down to two dimensions, with values
            expected in the range 0.0-1.0.

    Returns:
        A ``(mask, image)`` tuple of PIL images: the mask in mode "L" resized
        to the size of the image, and the image in RGBA mode with that mask as
        its alpha channel.
    """
    image = Image.fromarray(img)

    # Swap the two mask axes and scale (0.0, 1.0) -> (0, 255).
    # NOTE(review): the swap presumably undoes the transpose(0, 2) that the
    # inference cell applies to the network input - confirm.
    mask_array = mask.squeeze().transpose(0, 1).detach().cpu().numpy() * 255
    # Hand the mask over to Pillow as an 8-bit grayscale image.
    mask = Image.fromarray(mask_array).convert("L")
    # Scale it up to the size of the source image.
    mask = mask.resize(image.size, resample=Image.BILINEAR)

    # Image.convert() returns a new image rather than converting in place; the
    # original code discarded that result.
    image = image.convert('RGBA')
    # Fill the alpha channel with the mask.
    image.putalpha(mask)

    return mask, image

In [None]:
# Inference: run a saved checkpoint over every file in creatives_folder and
# show each image next to its predicted mask.
creatives_folder = "creatives/"
image_size=500

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# A fresh network is built here, so this cell shows the checkpoint below and
# not whatever weights a preceding training run left in memory.
net = U2NET(3, 1)
net.to(device)

checkpoint_name = 'u2net_2021-12-23_epoch_0_train_0.5615894327386778_test_0.4500891561835807.pth'
folder_name = 'saved_models_rg/'

net.load_state_dict(torch.load(folder_name + checkpoint_name, map_location=torch.device(device)))

net.eval()

# Sorted so the figures appear in the same order on every run.
for path in sorted(glob.glob(creatives_folder+'*')):
    img=io.imread(path)

    # The network is built for 3 input channels: replicate a grayscale image
    # across three channels and drop any channel beyond the third (alpha).
    if img.ndim == 2:
        rgb = np.stack([img] * 3, axis=-1)
    else:
        rgb = np.ascontiguousarray(img[:, :, :3])

    # Scale to 0..1 and swap axes 0 and 2 so the channel axis comes first.
    image_tensor = torch.from_numpy(rgb.astype(np.float32) / 255).transpose( 0, 2).to(device)
    image_tensor = T.Resize((image_size, image_size))(image_tensor)

    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7 = net(image_tensor[None,:,:,:])

    # Take the mask from d1.
    # The other d# outputs are masks as well, but of lower quality.
    mask = d1[:, 0, :, :]

    # Build the displayable mask and the image with the mask as alpha channel.
    mask,image=save_output(rgb, mask)

    # NOTE(review): save_name is not used anywhere in the visible code, and the
    # slice also drops the first two characters of the file name - confirm the
    # intent before writing results to disk with it.
    save_name=path[len(creatives_folder)+2:-3] + "result.png"

    fig,(ax1,ax2)=plt.subplots(1,2,figsize=(15,15))

    ax1.imshow(image)
    ax1.set_title('image')

    ax2.imshow(mask,cmap='gray')
    ax2.set_title('mask')

    plt.show()
    # Release the figure so a large folder does not pile up open figures.
    plt.close(fig)

    del d1, d2, d3, d4, d5, d6, d7

In [None]:
# List the entries of the creatives folder; as the last expression of the cell
# the returned list is displayed by the notebook.
glob.glob(pathname='creatives/' + '*')