In [None]:
import copy
import os
import json

In [None]:
import argparse
from torch.utils.data import Dataset, DataLoader
import torch

# Sanity check: report whether PyTorch can see a CUDA device before any training is attempted.
print(torch.cuda.is_available())

In [3]:
def get_hyper_parameters():
    """Return the fixed training configuration used by this notebook.

    Returns:
        tuple: ``(para_list, num_epoch, patience, device)`` where ``para_list``
        is a list holding a single optimizer-settings dict, ``num_epoch`` is the
        epoch budget, ``patience`` is the early-stopping patience in epochs and
        ``device`` is the requested device string.
    """
    adam_settings = {
        "optim_type": 'adam',
        'lr': 0.0001,
        "weight_decay": 1e-5,
        "store_img_training": True,
    }
    return [adam_settings], 40, 10, 'cuda'

In [11]:
# Single-channel output network plus the training configuration.
model = StipplingReconstructionCNN(1)
para_list, num_epoch, patience, device_str = get_hyper_parameters()

# Train loader: reshuffled every epoch.
train_dataset = DataLoader_Stippling('./data/median_res/stipple/train', './data/median_res/grayscale/train')
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Val loader: fixed order. The validation loss is an average of per-batch means,
# so shuffling would change the batch composition (and, with a smaller last
# batch, the reported value) from epoch to epoch for no benefit.
val_dataset = DataLoader_Stippling('./data/median_res/stipple/val', './data/median_res/grayscale/val')
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

In [None]:
# Train with the first (and only) hyper-parameter set. `best_model` is the model returned by
# train_model and `stats` holds the per-iteration training loss and its iteration index.
best_model, stats = train_model(model, train_loader, val_loader, num_epoch, para_list[0], patience, device_str)

In [None]:
# Label the checkpoint with the configured epoch budget instead of a duplicated
# literal, so the file name follows get_hyper_parameters() if it is changed.
# NOTE(review): early stopping may end training before `num_epoch` epochs; the
# label is the budget, not necessarily the number of epochs actually run.
save_checkpoint(best_model, num_epoch, "./checkpoints", stats)

# Training loss per optimisation step.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(stats['loss_ind'], stats['loss'], marker='o', linestyle='-', color='b')
ax.set_title('Loss over Time')
ax.set_xlabel('Index')
ax.set_ylabel('Loss')
ax.grid(True)
plt.show()

In [None]:
# Test loader: one image per batch in a fixed order so output indices are stable.
test_dataset = DataLoader_Stippling('./data/median_res/stipple/test', './data/median_res/grayscale/test')
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Use the GPU when present, otherwise fall back to the CPU instead of failing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
best_model = StipplingReconstructionCNN(1)
# NOTE(review): save_checkpoint() writes "epoch=<n>.checkpoint.pth.tar"; no visible
# cell creates "final.pth.tar" -- confirm this file is produced elsewhere.
# map_location lets a checkpoint saved on the GPU be loaded on a CPU-only machine.
checkpoint = torch.load("./checkpoints/final.pth.tar", map_location=device)
best_model.load_state_dict(checkpoint["state_dict"])

best_model.to(device)
best_model.eval()

# cv2.imwrite reports failure through its return value rather than an exception,
# so make sure the output directory exists before writing.
os.makedirs("./test_res/", exist_ok=True)


def _to_uint8_image(tensor):
    """Convert a [0, 1] image tensor to a 2-D uint8 array suitable for cv2.imwrite."""
    scaled = tensor.detach().squeeze().clamp(0.0, 1.0) * 255.0
    return scaled.round().to(torch.uint8).cpu().numpy()


print('------------------------ Start Testing ------------------------')

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for batch, data in enumerate(test_loader):
        print('Batch No. {0}'.format(batch))

        X = data['stippling'].to(device)
        y = data['grayscale'].to(device)

        reconstructed_img = best_model(X)

        # batch_size is 1, so the batch index doubles as the image index.
        prefix = "./test_res/" + str(batch)
        cv2.imwrite(prefix + "_grayscale.png", _to_uint8_image(y[0]))
        cv2.imwrite(prefix + "_stippling.png", _to_uint8_image(X[0]))
        cv2.imwrite(prefix + "_reconstruction.png", _to_uint8_image(reconstructed_img[0]))

In [None]:
# Reload weights into `model` from ./checkpoints. With force=False, restore_checkpoint asks on stdin
# which epoch to load; answering 0 removes the saved checkpoints and returns the model unchanged.
# The returned epoch number and stats are discarded here.
model, _, _ = restore_checkpoint(model, "./checkpoints", cuda=True, force=False, pretrain=False)

In [8]:
from utils import get_optimizer, reconstruct_image
from performance import calculate_loss, get_performance
import torch
import time
import copy
import cv2
def train_model(model, train_loader, val_loader, num_epoch, parameter, patience, device_str):
    """Train `model` with early stopping on the validation loss.

    Args:
        model: Network to optimise; it is moved to the selected device.
        train_loader: Iterable of batches, each a dict with 'stippling' and
            'grayscale' tensors.
        val_loader: Iterable of validation batches passed to `get_performance`.
        num_epoch: Maximum number of epochs; also the length of the linear
            learning-rate decay.
        parameter: Dict providing "optim_type", "lr" and "weight_decay".
        patience: Number of consecutive epochs without validation improvement
            after which training stops.
        device_str: 'cuda' to use the GPU when available; anything else (or no
            CUDA device) selects the CPU.

    Returns:
        tuple: ``(best_model, stats)``. ``best_model`` is a deep copy of the
        model taken at the epoch with the lowest validation loss (or `model`
        itself if the validation loss never improved). ``stats`` is a dict with
        'loss' (training loss per iteration) and 'loss_ind' (1-based iteration
        numbers).
    """
    device = torch.device(device_str if device_str == 'cuda' and torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.train()

    loss_all, loss_index = [], []
    num_iteration = 0
    best_model, best_val_loss = model, float('inf')
    num_bad_epoch = 0

    optim = get_optimizer(model, parameter["optim_type"], parameter["lr"], parameter["weight_decay"])
    # Learning rate decays linearly from its initial value towards 0 over num_epoch scheduler steps.
    scheduler = torch.optim.lr_scheduler.LinearLR(optim, start_factor=1.0, end_factor=0, total_iters=num_epoch)

    print('------------------------ Start Training ------------------------')
    t_start = time.time()
    # Loss of the most recent batch as a plain float; stays NaN if no batch was
    # ever processed so the epoch summary below cannot fail on an empty loader.
    last_loss = float('nan')
    for epoch in range(num_epoch):
        for batch, data in enumerate(train_loader):
            if batch % 5 == 0:
                print('Batch No. {0}'.format(batch))

            X = data['stippling'].to(device)
            y = data['grayscale'].to(device)

            num_iteration += 1

            # Re-assert train mode each step in case validation changed it.
            model.train()
            optim.zero_grad()
            reconstructed_img = model(X)

            loss = calculate_loss(reconstructed_img, y)
            # Plain backward(): if the loss is detached from the model this now
            # raises instead of silently performing a no-op update.
            loss.backward()
            optim.step()

            last_loss = loss.item()
            loss_all.append(last_loss)
            loss_index.append(num_iteration)

        # Reports the loss of the last batch of the epoch (not an epoch average).
        print('Epoch No. {0} -- loss = {1:.4f}'.format(
            epoch + 1,
            last_loss,
        ))

        # Validation:
        print('Validation')
        val_loss = get_performance(model, val_loader, device_str)
        model.to(device)
        print("Validation loss: {:.4f}".format(val_loss))

        if val_loss < best_val_loss:
            # Snapshot the weights so later epochs cannot overwrite the best model.
            best_model = copy.deepcopy(model)
            best_val_loss = val_loss
            num_bad_epoch = 0
        else:
            num_bad_epoch += 1

        # early stopping
        if num_bad_epoch >= patience:
            break

        # learning rate scheduler
        scheduler.step()

    t_end = time.time()
    print('Training lasted {0:.2f} minutes'.format((t_end - t_start) / 60))
    print('------------------------ Training Done ------------------------')
    stats = {'loss': loss_all,
             'loss_ind': loss_index,
             }

    return best_model, stats

In [7]:
from torch import nn
import numpy as np
import torch.optim as optimizer
from matplotlib import pyplot as plt
import cv2
import torch
import torch.nn.functional as F
from utils import get_optimizer, reconstruct_image

def calculate_loss(reconstructed_img, target_img):
    """Return the mean squared error between the reconstruction and its target."""
    return F.mse_loss(reconstructed_img, target_img)

def get_performance(model, val_loader, device_str):
    """Compute the average validation loss of `model` over `val_loader`.

    Args:
        model: Network to evaluate; it is moved to the selected device.
        val_loader: Iterable of batches, each a dict with 'stippling' and
            'grayscale' tensors.
        device_str: 'cuda' to use the GPU when available; anything else (or no
            CUDA device) selects the CPU.

    Returns:
        float: Mean of the per-batch losses.

    Raises:
        ValueError: If `val_loader` yields no batches.
    """
    device = torch.device(device_str if device_str == 'cuda' and torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Evaluate in eval mode and put the model back in its previous mode afterwards.
    was_training = model.training
    model.eval()

    loss_all = []
    try:
        # Validation needs no gradients; skipping the autograd graph saves memory.
        # NOTE(review): assumes reconstruct_image is a plain forward pass that does
        # not itself need gradients -- confirm against utils.reconstruct_image.
        with torch.no_grad():
            for data in val_loader:
                X = data['stippling'].to(device)
                y = data['grayscale'].to(device)

                reconstructed_img = reconstruct_image(X, model, False)

                loss_all.append(calculate_loss(reconstructed_img, y).item())
    finally:
        model.train(was_training)

    if not loss_all:
        raise ValueError("val_loader produced no batches; cannot compute a validation loss")

    return sum(loss_all) / len(loss_all)

In [6]:
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import glob

class DataLoader_Stippling(Dataset):
    """Dataset of paired stippling / grayscale PNG images.

    Files in the two directories are paired by position after sorting their
    paths, so both directories are expected to contain matching file names.
    """

    def __init__(self, stippling_dir, grayscale_dir, image_size=(800, 592)):
        """
        Args:
            stippling_dir (string): Directory with all the stippling images.
            grayscale_dir (string): Directory with all the grayscale images.
            image_size (tuple): Target ``(width, height)`` passed to
                ``cv2.resize``; the default yields 1x592x800 tensors.

        Raises:
            ValueError: If the two directories hold a different number of PNG
                files, since positional pairing would then be misaligned.
        """
        self.stippling_files = sorted(glob.glob(stippling_dir + '/*.png'))
        self.grayscale_files = sorted(glob.glob(grayscale_dir + '/*.png'))
        if len(self.stippling_files) != len(self.grayscale_files):
            raise ValueError(
                "Mismatched dataset: {} stippling images in '{}' but {} grayscale images in '{}'".format(
                    len(self.stippling_files), stippling_dir,
                    len(self.grayscale_files), grayscale_dir,
                )
            )
        self.image_size = tuple(image_size)

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.stippling_files)

    def _load_image(self, path):
        """Read `path` as grayscale, resize it and scale it to a 1xHxW float tensor in [0, 1]."""
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals an unreadable file by returning None, not by raising.
            raise FileNotFoundError("Could not read image: {}".format(path))
        resized = cv2.resize(img, self.image_size)
        return torch.FloatTensor(resized).unsqueeze(0) / 255.0

    def __getitem__(self, idx):
        """Return ``{'stippling': tensor, 'grayscale': tensor}`` for pair `idx`."""
        return {
            'stippling': self._load_image(self.stippling_files[idx]),
            'grayscale': self._load_image(self.grayscale_files[idx]),
        }

def get_data_loader(stippling_dir, grayscale_dir, batch_size, shuffle=False, num_workers=8):
    """Build a DataLoader over the paired images in the two directories.

    Args:
        stippling_dir (string): Directory with all the stippling images.
        grayscale_dir (string): Directory with all the grayscale images.
        batch_size (int): Number of image pairs per batch.
        shuffle (bool): Whether to reshuffle the data every epoch. Defaults to
            False, the previously hard-coded value.
        num_workers (int): Number of loader worker processes. Defaults to 8,
            the previously hard-coded value.

    Returns:
        DataLoader: Loader yielding dicts with 'stippling' and 'grayscale' batches.
    """
    custom_dataset = DataLoader_Stippling(stippling_dir, grayscale_dir)
    return DataLoader(custom_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

In [10]:
import torch
import torch.nn as nn
from torchvision import models
from torch.nn.functional import relu

class StipplingReconstructionCNN(nn.Module):
    """U-Net style encoder/decoder mapping a stippling image to a reconstruction.

    Every 3x3 convolution uses ``padding=1`` and therefore keeps the spatial
    size; only the four 2x2 max-pools (halving) and the four transposed
    convolutions (doubling) change it. Because the decoder concatenates each
    upsampled map with the matching encoder map, the input height and width
    must be divisible by 16.
    """

    def __init__(self, n_class):
        """
        Args:
            n_class (int): Number of output channels produced by the final 1x1 convolution.
        """
        super().__init__()

        # Encoder
        # Each encoder block is two 3x3 convolutions followed by a 2x2 max-pool,
        # except the last block, which has no pooling. The spatial sizes in the
        # comments are (H x W) for the 592x800 images produced by DataLoader_Stippling.
        # -------
        # block 1: 592x800, 1 -> 64 channels
        self.e11 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.e12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # -> 296x400

        # block 2: 296x400, 64 -> 128 channels
        self.e21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.e22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # -> 148x200

        # block 3: 148x200, 128 -> 256 channels
        self.e31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.e32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # -> 74x100

        # block 4: 74x100, 256 -> 512 channels
        self.e41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.e42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)  # -> 37x50

        # bottleneck: 37x50, 512 -> 1024 channels
        self.e51 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
        self.e52 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)

        # Decoder
        # Each decoder block doubles the spatial size with a transposed
        # convolution, concatenates the matching encoder map (which doubles the
        # channel count fed to the first convolution) and applies two 3x3 convolutions.
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.d11 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.d12 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.d21 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.d22 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.d31 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.d32 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.d41 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.d42 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # Output layer: 1x1 convolution to the requested number of channels.
        self.outconv = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape (N, 1, H, W) with H and W divisible by 16.

        Returns:
            Tensor of shape (N, n_class, H, W); the raw output of the final
            convolution (no activation is applied).
        """
        # Encoder
        xe11 = relu(self.e11(x))
        xe12 = relu(self.e12(xe11))
        xp1 = self.pool1(xe12)

        xe21 = relu(self.e21(xp1))
        xe22 = relu(self.e22(xe21))
        xp2 = self.pool2(xe22)

        xe31 = relu(self.e31(xp2))
        xe32 = relu(self.e32(xe31))
        xp3 = self.pool3(xe32)

        xe41 = relu(self.e41(xp3))
        xe42 = relu(self.e42(xe41))
        xp4 = self.pool4(xe42)

        xe51 = relu(self.e51(xp4))
        xe52 = relu(self.e52(xe51))

        # Decoder: upsample, concatenate the skip connection along channels, convolve.
        xu1 = self.upconv1(xe52)
        xu11 = torch.cat([xu1, xe42], dim=1)
        xd11 = relu(self.d11(xu11))
        xd12 = relu(self.d12(xd11))

        xu2 = self.upconv2(xd12)
        xu22 = torch.cat([xu2, xe32], dim=1)
        xd21 = relu(self.d21(xu22))
        xd22 = relu(self.d22(xd21))

        xu3 = self.upconv3(xd22)
        xu33 = torch.cat([xu3, xe22], dim=1)
        xd31 = relu(self.d31(xu33))
        xd32 = relu(self.d32(xd31))

        xu4 = self.upconv4(xd32)
        xu44 = torch.cat([xu4, xe12], dim=1)
        xd41 = relu(self.d41(xu44))
        xd42 = relu(self.d42(xd41))

        # Output layer
        return self.outconv(xd42)

In [5]:
import numpy as np
import itertools
import os
import torch
from torch.nn.functional import softmax
from sklearn import metrics
import utils

def save_checkpoint(model, epoch, checkpoint_dir, stats):
    """Save a checkpoint file to `checkpoint_dir`.

    The file is named ``epoch=<epoch>.checkpoint.pth.tar`` and stores the epoch
    number, the model's state dict and `stats`.

    Args:
        model: Module whose ``state_dict()`` is saved.
        epoch (int): Epoch label stored in the checkpoint and used in the file name.
        checkpoint_dir (str): Target directory; created if it does not exist.
        stats: Training statistics stored alongside the weights.
    """
    state = {
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "stats": stats,
    }

    # torch.save does not create missing parent directories.
    os.makedirs(checkpoint_dir, exist_ok=True)
    filename = os.path.join(checkpoint_dir, "epoch={}.checkpoint.pth.tar".format(epoch))
    torch.save(state, filename)

def restore_checkpoint(model, checkpoint_dir, cuda=False, force=False, pretrain=False):
    """Restore model from checkpoint if it exists.

    Checkpoints are files named ``epoch=<n>.checkpoint.pth.tar`` in
    `checkpoint_dir`. The epoch to load is read interactively from stdin.

    Args:
        model: Module that receives the stored state dict.
        checkpoint_dir (str): Directory to search; created if it is missing.
        cuda (bool): If False, tensors are mapped to the CPU while loading.
        force (bool): If True, a checkpoint must exist and must be chosen;
            if False, entering 0 clears all checkpoints and returns the model
            unchanged.
        pretrain (bool): If True, the state dict is loaded with ``strict=False``.

    Returns:
        tuple: ``(model, epoch, stats)``. When nothing is loaded this is
        ``(model, 0, [])``.

    Raises:
        Exception: If `force` is set and no checkpoint exists, or if the
            entered epoch has no checkpoint file.
    """
    prefix, suffix = "epoch=", ".checkpoint.pth.tar"
    try:
        cp_files = [
            file_
            for file_ in os.listdir(checkpoint_dir)
            if file_.startswith(prefix) and file_.endswith(suffix)
        ]
    except FileNotFoundError:
        cp_files = []
        os.makedirs(checkpoint_dir)

    # Collect the epochs that actually have a file on disk. Scanning only for a
    # contiguous run starting at 1 would find nothing when, e.g., only the
    # final epoch was saved.
    available = set()
    for file_ in cp_files:
        label = file_[len(prefix):-len(suffix)]
        # Accept only canonical positive integers so the name can be rebuilt exactly.
        if label.isdigit() and int(label) > 0 and str(int(label)) == label:
            available.add(int(label))

    if not available:
        print("No saved model parameters found")
        if force:
            raise Exception("Checkpoint not found")
        else:
            return model, 0, []

    # Find latest epoch
    epoch = max(available)
    if len(available) != epoch:
        # Not every epoch in [1, epoch] has a file; show the valid choices.
        print("Available epochs: {}".format(sorted(available)))

    if not force:
        print(
            "Which epoch to load from? Choose in range [0, {}].".format(epoch),
            "Enter 0 to train from scratch.",
        )
        print(">> ", end="")
        inp_epoch = int(input())
        if inp_epoch != 0 and inp_epoch not in available:
            raise Exception("Invalid epoch number")
        if inp_epoch == 0:
            print("Checkpoint not loaded")
            clear_checkpoint(checkpoint_dir)
            return model, 0, []
    else:
        print("Which epoch to load from? Choose in range [1, {}].".format(epoch))
        inp_epoch = int(input())
        if inp_epoch not in available:
            raise Exception("Invalid epoch number")

    filename = os.path.join(
        checkpoint_dir, "epoch={}.checkpoint.pth.tar".format(inp_epoch)
    )

    print("Loading from checkpoint {}?".format(filename))

    if cuda:
        checkpoint = torch.load(filename)
    else:
        # Load GPU model on CPU
        checkpoint = torch.load(filename, map_location=lambda storage, loc: storage)

    try:
        start_epoch = checkpoint["epoch"]
        stats = checkpoint["stats"]
        if pretrain:
            # Tolerate missing / unexpected keys when only part of the weights is reused.
            model.load_state_dict(checkpoint["state_dict"], strict=False)
        else:
            model.load_state_dict(checkpoint["state_dict"])
        print(
            "=> Successfully restored checkpoint (trained for {} epochs)".format(
                start_epoch
            )
        )
    except Exception:
        print("=> Checkpoint not successfully restored")
        raise

    return model, inp_epoch, stats

def clear_checkpoint(checkpoint_dir):
    """Delete every file ending in ``.pth.tar`` from `checkpoint_dir`."""
    for entry in os.listdir(checkpoint_dir):
        if not entry.endswith(".pth.tar"):
            continue
        os.remove(os.path.join(checkpoint_dir, entry))

    print("Checkpoint successfully removed")