# Imports

In [None]:
import os
import sys

import numpy as np
import random
import time

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

import PIL
from PIL import Image
from IPython import display

import cv2
from google.colab.patches import cv2_imshow

import torch
import torchvision.transforms as transforms

from copy import deepcopy

from portraitsegmenter import PortraitSegmenter
from datasets_portraitseg import PortraitSegDatasetAug
from segment_trainer import SegmentTrainer

# Set-up

In [None]:
# Load the pre-built numpy arrays from data/: training images and masks
# (img_uint8 / msk_uint8) followed by the held-out test inputs and targets.
_array_files = (
    "data/img_uint8.npy",
    "data/msk_uint8.npy",
    "data/test_xtrain.npy",
    "data/test_ytrain.npy",
)
x_train, y_train, x_test, y_test = map(np.load, _array_files)

In [None]:
# Augmentation settings for the training set only. angle_range / zoom /
# noise_scale presumably configure random rotation, zoom and noise -- confirm
# in datasets_portraitseg. The validation set is built with aug=False.
train_aug = dict(angle_range=30, zoom=0.5, noise_scale=10.0)
datavals = PortraitSegDatasetAug(x_train, y_train, **train_aug)
valvals = PortraitSegDatasetAug(x_test, y_test, aug=False)

# Segmentation network; down_depth / up_depth / filters presumably set the
# encoder/decoder block counts and per-stage widths -- confirm in
# portraitsegmenter.
port_seg = PortraitSegmenter(
    down_depth=[1, 2, 2, 2],
    up_depth=[1, 1, 1],
    filters=[16, 24, 32, 48],
)
trainer = SegmentTrainer(port_seg)

# Cursor into valvals for the preview cell; that cell increments it before
# indexing, so the first preview shows valvals[1].
iiii = 0

In [None]:
# Preview the next validation sample: show the input image and the two target
# masks, print the IoU of the first model output against the second mask, and
# show both model outputs binarized at `thresh`.
iiii += 1
x, y, z, w = valvals[iiii]  # y is not used by this preview
# Assumes x is channels-first in [-1, 1]: rescale to [0, 255], move the channel
# axis last, and reverse channel order (cv2_imshow expects BGR).
cv2_imshow(np.moveaxis((x + 1) * 127.5, 0, -1)[:, :, ::-1])
cv2_imshow(np.expand_dims(z * 255., axis=2))
cv2_imshow(np.expand_dims(w * 255., axis=2))

# Run on whatever device the model's weights live on instead of hard-coding
# CUDA (assumes port_seg is a torch.nn.Module). Switch to eval mode for
# inference: torch.no_grad() alone does not stop BatchNorm layers from
# updating their running statistics in training mode. The previous mode is
# restored afterwards so a later trainer.train(...) is unaffected.
device = next(port_seg.parameters()).device
was_training = port_seg.training
port_seg.eval()
try:
    with torch.no_grad():
        a1, a2 = port_seg(torch.tensor(x).unsqueeze(0).to(device))
finally:
    port_seg.train(was_training)
print(a1.shape)
print(a2.shape)

thresh = 1.64872
# Pass a detached CPU copy of the raw output; torch.tensor(<tensor>) is
# discouraged by PyTorch, and clone() keeps the copy semantics.
print(trainer.calcIOU(torch.tensor(w), a1.detach().cpu().clone()))

# Binarize: 1 where the output is >= thresh, else 0. A single comparison is
# correct for any thresh, whereas two sequential masked assignments would turn
# freshly written zeros into ones whenever thresh <= 0.
a1 = (a1 >= thresh).to(a1.dtype)
a2 = (a2 >= thresh).to(a2.dtype)
cv2_imshow(a2.squeeze().cpu().numpy() * 255.)
cv2_imshow(a1.squeeze().cpu().numpy() * 255.)

# Train

In [None]:
# Training hyper-parameters. es_patience is presumably the early-stopping
# patience (in epochs) -- confirm in segment_trainer.
train_kwargs = dict(
    batch_size=128,
    epochs=50,
    lr=0.001,
    es_patience=30,
    mask_weight=10,
    mask_loss='CE',
    edge_loss=None,
)
history = trainer.train(datavals, valvals, **train_kwargs)

# Reload the weights stored in "best.pth" (presumably the best checkpoint
# written by trainer.train) and export them as the final model file.
best_state = torch.load("best.pth")
trainer.segmenter.load_state_dict(best_state)
torch.save(trainer.segmenter.state_dict(), "portraitCE.pth")