In [None]:
# System tools
import scipy.misc
import random
import time
import sys
import os
import numpy as np

# Pytorch
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import segmentation_models_pytorch as smp

# Image I/O
import cv2
import PIL
from  matplotlib import pyplot as plt

# Analysis
from sklearn.metrics import confusion_matrix
import pandas as pd

%matplotlib inline

In [None]:
def check_and_create_folder(directory):
    """Make sure ``directory`` exists, creating it when it is missing.

    Args:
        directory: Path of the folder to check. Missing parent folders are
            created as well.

    Prints a short notice saying whether the path was already present or has
    just been created. An existing path is never modified or removed.
    """
    # Test for existence explicitly instead of wrapping os.stat in a bare
    # ``except``, which would also swallow unrelated errors (permission
    # problems, KeyboardInterrupt, ...) and then try to create the folder.
    if os.path.exists(directory):
        print ('folder: ', directory, 'exists, do you want to remove it')
    else:
        # exist_ok guards against another process creating it in between.
        os.makedirs(directory, exist_ok=True)
        print ('create ', directory)

# UNet training

### Hyperparameters

In [None]:
# ---- Data / batching ----
INPUT_IMG_SIZE = (480, 640)             # HEIGHT, WIDTH
BATCH_SIZE   = 5
NUM_EPOCHS   = 1000
NUM_WORKERS  = 4                        # value handed to DataLoader(num_workers=...)
NUM_WROKERS  = NUM_WORKERS              # misspelled alias kept for existing references

# ---- Optimizer (RMSprop) and StepLR schedule ----
LR           = 1e-3
MOMENTUM     = 0
WEIGHT_DECAY = 1e-5
STEP_SIZE    = 50                       # StepLR step_size
GAMMA        = 0.5                      # StepLR gamma

# ---- Paths ----
DATASET_ROOT = "./shoes_dataset_folder"
MODELS_ROOT  = "./models"
CLASSES = ["background", "right_shoes", "left_shoes"]    # classes with 'background' element

MASKS_DIR    = os.path.join(DATASET_ROOT, "masks")
LABELS_DIR   = os.path.join(DATASET_ROOT, "labels")
IMAGES_DIR   = os.path.join(DATASET_ROOT, "images")

# makedirs(exist_ok=True) is idempotent on re-runs, creates missing parents and
# avoids the check-then-create race of isdir() followed by mkdir().
os.makedirs(MODELS_ROOT, exist_ok=True)
NUM_CLASSES = len(CLASSES)

In [None]:
# Load the training index and preview only its first rows; rendering the whole
# frame bloats the notebook output without adding information.
data_list = pd.read_csv(os.path.join(DATASET_ROOT, "train.csv"))
data_list.head()

In [None]:
class CustomDataset(Dataset):
    """Segmentation dataset indexed by a CSV file of image/mask path pairs.

    Column 0 of the CSV is the image path and column 1 the mask path, both
    joined onto ``DATASET_ROOT``. Masks are read as single-channel images and
    each pixel value is compared against the class indices
    ``0 .. NUM_CLASSES - 1`` to build a one-hot target.
    """

    def __init__(self, dataset_csv_file, phase):
        """Read the CSV index and build the image preprocessing pipeline.

        Args:
            dataset_csv_file: Path of the CSV file listing the samples.
            phase: Split label supplied by the caller (e.g. "train"/"test").
                It is stored on the instance but does not alter loading.
        """
        self.phase = phase
        self.data_list = pd.read_csv(dataset_csv_file)

        print("********** Dataset Info start **********\n")
        print("Source: " + dataset_csv_file)
        print("Classes: {}".format(CLASSES))
        print("Amount of data: {}".format(len(self.data_list)))
        print("\n*********** Dataset Info end ***********\n")

        # Resize to the network input size, convert to a tensor and normalize
        # each channel with the mean/std constants below.
        self.data_transform = transforms.Compose([
            transforms.Resize(INPUT_IMG_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Row position in the CSV index.

        Returns:
            dict with
              * ``'input'``: the transformed image tensor, and
              * ``'target'``: a float tensor of shape
                ``(NUM_CLASSES, HEIGHT, WIDTH)`` holding 1 where the mask
                pixel equals the channel's class index and 0 elsewhere.

        Raises:
            FileNotFoundError: if the mask file cannot be read.
        """
        image_path = self.data_list.iloc[index, 0]
        mask_path = self.data_list.iloc[index, 1]

        # Read image
        image_raw = self.default_loader(os.path.join(DATASET_ROOT, image_path))
        input_image = self.data_transform(image_raw)

        # Read mask
        mask_file = os.path.join(DATASET_ROOT, mask_path)
        mask_raw = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
        if mask_raw is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError("Cannot read mask image: {}".format(mask_file))

        # Pixel values are class ids, so they must never be blended: the
        # default bilinear interpolation would create ids that do not exist
        # in the mask (e.g. a 1 on the border between 0 and 2).
        mask_raw = cv2.resize(mask_raw,
                              (INPUT_IMG_SIZE[1], INPUT_IMG_SIZE[0]),   # cv2 expects (width, height)
                              interpolation=cv2.INTER_NEAREST)

        # One binary channel per class; pixels with an id outside
        # range(NUM_CLASSES) stay 0 in every channel.
        mask_ids = torch.from_numpy(mask_raw)
        mask_each_classes = torch.stack(
            [mask_ids == class_id for class_id in range(NUM_CLASSES)]
        ).float()

        batch = {'input': input_image, 'target': mask_each_classes}
        return batch

    def pil_loader(self, path):
        """Open ``path`` with PIL and return it converted to RGB."""
        with open(path, "rb") as f:
            with PIL.Image.open(f) as img:
                return img.convert("RGB")

    def accimage_loader(self, path):
        """Open ``path`` with accimage, falling back to PIL on I/O errors."""
        # Imported lazily: accimage is optional and only needed when it has
        # been selected as the torchvision image backend.
        import accimage
        try:
            return accimage.Image(path)
        except OSError:
            # Potentially a decoding problem, fall back to PIL.Image
            return self.pil_loader(path)

    def default_loader(self, path):
        """Load ``path`` with the loader matching torchvision's image backend."""
        if torchvision.get_image_backend() == "accimage":
            return self.accimage_loader(path)
        return self.pil_loader(path)

In [None]:
# ---- Training split: shuffled mini-batches of BATCH_SIZE samples ----
train_csv_path = os.path.join(DATASET_ROOT, "train.csv")
train_dataset = CustomDataset(train_csv_path, "train")
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WROKERS,
)

# ---- Test split: CSV order, fixed batches of 4 ----
test_csv_path = os.path.join(DATASET_ROOT, "test.csv")
test_dataset = CustomDataset(test_csv_path, 'test')
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=4,
    num_workers=1,
)
# Iterator over the test batches, kept at module level for later cells.
dataiter = iter(test_dataloader)

In [None]:
# Load model
# activation=None makes the network return raw logits. BCEWithLogitsLoss
# applies a sigmoid internally, so feeding it softmax probabilities would apply
# an activation twice and squash the loss into a narrow range.
# NOTE: apply softmax/argmax over the class dimension at inference time to
# turn these logits into per-pixel class predictions.
model = smp.Unet('resnet18', classes=NUM_CLASSES, activation=None, encoder_weights='imagenet')
# Only move to the GPU when one exists so the notebook also runs on CPU.
if torch.cuda.is_available():
    model.cuda()

# define loss function
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# Multiply the learning rate by GAMMA every STEP_SIZE scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

In [None]:
# Show the model's repr (its module structure) as this cell's output.
model

In [None]:
def train(model, optimizer, scheduler, loss_list, model_name, save_every=50):
    """Train ``model`` for ``NUM_EPOCHS`` epochs over ``train_dataloader``.

    Uses the module-level ``criterion`` as the loss and ``train_dataloader``
    as the batch source; each batch is a dict with ``'input'`` and
    ``'target'`` tensors.

    Args:
        model: Network to optimize; trained in place.
        optimizer: Optimizer bound to ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        loss_list: List that receives the loss of the last batch of every
            epoch as a detached CPU tensor.
        model_name: Tag embedded in the checkpoint file name.
        save_every: Save a checkpoint every this many epochs (the final
            epoch is always saved).

    Checkpoints are written to ``MODELS_ROOT`` as
    ``FCN_<model_name>_batch<B>_epoch<E>_RMSprop_lr<LR>.pkl``.
    """
    # Send batches to wherever the model's weights live, so inputs and model
    # always agree regardless of GPU availability.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    final_epoch = NUM_EPOCHS - 1

    for epoch in range(NUM_EPOCHS):
        model.train()
        configs    = "FCN_{}_batch{}_epoch{}_RMSprop_lr{}".format(model_name, BATCH_SIZE, epoch, LR)
        model_path = os.path.join(MODELS_ROOT, configs)

        loss = None   # stays None when the loader yields no batch
        for index, batch in enumerate(train_dataloader):
            optimizer.zero_grad()

            inputs = batch['input'].to(device)
            targets = batch['target'].to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            if index % 10 == 0:
                print("epoch{}, iter{}, loss: {}".format(epoch, index, loss.item()))

        scheduler.step()
        if loss is not None:
            # Detach and move to CPU: keeping the live tensor would retain its
            # autograd graph (and GPU memory) for every recorded epoch.
            loss_list.append(loss.detach().cpu())
        print("==== Finish epoch {} ====".format(epoch))

        # Checkpoint on the epoch counter; the last batch index is the same
        # in every epoch and says nothing about training progress.
        if epoch % save_every == 0 or epoch == final_epoch:
            torch.save(model.state_dict(), model_path + '.pkl')

In [None]:
# Start from an empty history; train() appends to this list as it runs.
loss_list = list()
train(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_list=loss_list,
    model_name="resnet18",
)