# Imports

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from torch.nn import init
import numpy as np

import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.utils.data import Dataset, DataLoader
import os
import glob

import time
import datetime
from torchvision import datasets, transforms
from src.models import train_model, model, vanilla_model

from src.features import utils
import pathlib
import PIL
from torchvision.transforms import ToTensor, ToPILImage

from imageaugment import augment
import PIL.Image
import pathlib
import json

# Load MNIST

In [2]:
# MNIST training split as [0, 1] tensors, reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../../DivNoising/examples/data', train=True, transform=transforms.ToTensor()),
    batch_size=128,
    shuffle=True,
)

In [3]:
# MNIST test split, used here as the validation set.
val_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../../DivNoising/examples/data', train=False, transform=transforms.ToTensor()),
    batch_size=32,
    shuffle=True,
)

## Get Mean & STD of the Data

In [15]:
# Global pixel mean / std of the MNIST training split.
# No transform is passed, so each sample is presumably a raw PIL image and the
# statistics are on the 0-255 scale (not the [0, 1] scale ToTensor produces).
# The std is computed over *all* pixels (via running sums), not averaged per image.
dataset = datasets.MNIST('../../DivNoising/examples/data')
pixel_sum = 0.0
pixel_sq_sum = 0.0
pixel_count = 0
for image, _ in dataset:
    pixels = np.asarray(image, dtype=np.float64)
    pixel_sum += pixels.sum()
    pixel_sq_sum += np.square(pixels).sum()
    pixel_count += pixels.size

data_mean = pixel_sum / pixel_count
data_std = np.sqrt(pixel_sq_sum / pixel_count - data_mean ** 2)
data_mean, data_std

# Load Clean and Simulated Noisy Pages as Patches

In [3]:
def load_patches(image_folder_path, image_size=(2400, 3500), patch_size=(192, 128)):
    """Load every PNG under a folder and cut each page into equal-sized patches.

    Each image is resized to ``image_size`` and converted to 8-bit grayscale
    ("L"). It is turned into a ``(1, H, W)`` tensor with ``ToTensor``, padded
    with value 1 up to the next multiple of the patch size (split evenly on
    both sides), and cut into non-overlapping patches in row-major order.

    Args:
        image_folder_path: ``pathlib.Path`` searched recursively for ``*.png``;
            files are processed in sorted path order.
        image_size: PIL ``(width, height)`` every page is resized to.
        patch_size: ``(height, width)`` of each patch; also used as the stride,
            so patches do not overlap.

    Returns:
        Tensor of shape ``(total_patches, 1, patch_height, patch_width)``.

    Raises:
        ValueError: if no PNG files are found under ``image_folder_path``.
    """
    kh, kw = patch_size
    image_paths = sorted(image_folder_path.rglob("*.png"))
    if not image_paths:
        raise ValueError(f"No .png files found under {image_folder_path}")

    all_patches = []
    for image_path in image_paths:
        # The context manager closes the underlying file handle.
        with PIL.Image.open(image_path) as image:
            page = ToTensor()(image.resize(image_size).convert("L"))

        # Padding needed to reach the next multiple; 0 when already aligned
        # (``k - size % k`` would add a whole blank patch in that case).
        pad_w = -page.size(2) % kw
        pad_h = -page.size(1) % kh
        left = pad_w // 2
        top = pad_h // 2
        page = F.pad(page, (left, pad_w - left, top, pad_h - top), value=1)

        # (1, H, W) -> (1, H // kh, W // kw, kh, kw) -> (patches_per_page, kh, kw)
        patches = page.unfold(1, kh, kh).unfold(2, kw, kw)
        all_patches.append(patches.contiguous().view(-1, kh, kw))

    # Concatenate pages in order and add the channel axis.
    return torch.cat(all_patches).unsqueeze(1)

## Load Training Data

In [4]:
clean_image_folder_train = pathlib.Path("/home/fahad/training_data_with_bbox/train/documents/")
patched_training_clean_images = load_patches(image_folder_path=clean_image_folder_train)

noisy_image_folder_train = pathlib.Path("/home/fahad/master_thesis/data/simulated_noisy_templates/train/")
patched_training_noisy_images = load_patches(image_folder_path=noisy_image_folder_train)

# Clean and noisy patches are later paired batch-by-batch, so they must align one-to-one.
assert patched_training_clean_images.shape == patched_training_noisy_images.shape, (
    patched_training_clean_images.shape,
    patched_training_noisy_images.shape,
)
patched_training_noisy_images.shape

torch.Size([43320, 1, 192, 128])

## Load Test Data

In [5]:
clean_image_folder_test = pathlib.Path("/home/fahad/training_data_with_bbox/val/documents/")
patched_test_clean_images = load_patches(image_folder_path=clean_image_folder_test)

noisy_image_folder_test = pathlib.Path("/home/fahad/master_thesis/data/simulated_noisy_templates/val/")
patched_test_noisy_images = load_patches(image_folder_path=noisy_image_folder_test)

# Clean and noisy patches are compared position-by-position, so they must align one-to-one.
assert patched_test_clean_images.shape == patched_test_noisy_images.shape, (
    patched_test_clean_images.shape,
    patched_test_noisy_images.shape,
)
patched_test_noisy_images.shape

## Check Dimensions

In [None]:
# Show the clean shape, then the noisy shape, one per line.
for patched_images in (patched_training_clean_images, patched_training_noisy_images):
    print(patched_images.shape)

# Create Noisy Templates

## Set Transform Parameters

In [8]:
# Ranges handed to the fax-simulation augmentation; each tuple is passed
# unchanged to augment.get_random_faxify.
faxify_params = {
    "gamma": (.8, 1.0),
    "angle_final": (0, 3),
    "angle_transient": (0, 3),
    "shift": (.005, .01),
    "scale": (1.0, 1.0),
    "threshold": (.65, .80),
    "brightness": (1.0, 1.3),
    "ditherprob": 0.0,
    "flipprob": 0.0,
    "vlineprob": .5,
    "maxvlines": 2,
    "linewidth": (0.001, 0.002),
    "particledensity": (.001, .005),
    "particlesize": (.0001, .001),
}
image_transform = augment.get_random_faxify(**faxify_params)

In [None]:
def save_faxified_templates(image_transform, image_folder_path, save_directory_path):
    """Apply ``image_transform`` to every PNG under a folder and save the results.

    Args:
        image_transform: callable taking an opened PIL image and returning an
            object with a ``save(path)`` method.
        image_folder_path: ``pathlib.Path`` searched recursively for ``*.png``;
            files are processed in sorted order.
        save_directory_path: output directory (``pathlib.Path`` or ``str``);
            created if it does not exist. Each output keeps its source file
            name, so identically named files in different sub-folders overwrite
            each other.
    """
    save_directory_path = pathlib.Path(save_directory_path)
    save_directory_path.mkdir(parents=True, exist_ok=True)
    for image_path in sorted(image_folder_path.rglob("*.png")):
        # Save while the source is still open, in case the transform returns
        # an image that still reads from the file.
        with PIL.Image.open(image_path) as image:
            faxified = image_transform(image)
            faxified.save(save_directory_path / image_path.name)

## Training Data

In [4]:
# Faxify the clean training documents into the simulated noisy template folder.
image_folder_train = pathlib.Path("/home/fahad/training_data_with_bbox/train/documents")
train_image_directory_path = pathlib.Path("/home/fahad/master_thesis/data/simulated_noisy_templates/train/")

save_faxified_templates(
    image_transform=image_transform,
    image_folder_path=image_folder_train,
    save_directory_path=train_image_directory_path,
)

## Validation Data

In [3]:
# Faxify the clean validation documents into the simulated noisy template folder.
image_folder_val, val_image_directory_path = (
    pathlib.Path("/home/fahad/training_data_with_bbox/val/documents"),
    pathlib.Path("/home/fahad/master_thesis/data/simulated_noisy_templates/val/"),
)
save_faxified_templates(image_transform, image_folder_val, val_image_directory_path)

# Create Crops

In [None]:
def create_crops(image_folder_path, graph_annotation_folder_path, save_directory, clean_crop):
    """Cut every "numeric" node's bounding box out of each page and save it as a PNG.

    Pages (``*.png``) and annotation files (``*.json``) are paired by position
    after sorting, so both folders must contain the same number of files in
    matching order. Crops are numbered consecutively from 0 across all pages
    (``0.png``, ``1.png``, ...). Running this on the clean and the noisy pages
    with the same annotations therefore yields crops that line up by file name.

    Args:
        image_folder_path: ``pathlib.Path`` searched recursively for page images.
        graph_annotation_folder_path: ``pathlib.Path`` searched recursively for
            JSON files. Each has a ``"NODES"`` list whose entries carry
            ``category``, ``origin_x``, ``origin_y``, ``width`` and ``height``.
        save_directory: output directory (``str`` or ``pathlib.Path``, with or
            without a trailing slash); created if missing.
        clean_crop: if True, convert each crop to PIL mode ``"1"`` before saving.

    Returns:
        The number of crops written.

    Raises:
        ValueError: if the image and annotation counts differ.
    """
    images = sorted(image_folder_path.rglob("*.png"))
    graph_annotations = sorted(graph_annotation_folder_path.rglob("*.json"))
    # zip() would silently drop the surplus and misalign clean/noisy crop numbering.
    if len(images) != len(graph_annotations):
        raise ValueError(
            f"Found {len(images)} images but {len(graph_annotations)} annotation files; "
            "they are paired by sorted order and must match one-to-one."
        )

    save_directory = pathlib.Path(save_directory)
    save_directory.mkdir(parents=True, exist_ok=True)

    index = 0
    for image_path, annotation_path in zip(images, graph_annotations):
        with open(annotation_path) as f:
            annotations = json.load(f)
        with PIL.Image.open(image_path) as image:
            for annotation in annotations["NODES"]:
                if annotation["category"] != "numeric":
                    continue
                left = annotation["origin_x"]
                top = annotation["origin_y"]
                box = (left, top, left + annotation["width"], top + annotation["height"])
                crop = image.crop(box).resize((150, 100))
                if clean_crop:
                    crop = crop.convert("1")
                crop.save(save_directory / f"{index}.png")
                index += 1
    return index

## Create Clean Crops

### Training Data

In [89]:
# Binarized crops from the clean training documents.
image_folder_path, graph_annotation_folder_path = (
    pathlib.Path("/home/fahad/training_data_with_bbox/train/documents"),
    pathlib.Path("/home/fahad/training_data_with_bbox/train/graph_annotations/"),
)
save_directory = "/home/fahad/master_thesis/data/crops/clean_crops/train/"
create_crops(image_folder_path, graph_annotation_folder_path, save_directory, clean_crop=True)

### Validation Data

In [9]:
# Binarized crops from the clean validation documents.
image_folder_path, graph_annotation_folder_path = (
    pathlib.Path("/home/fahad/training_data_with_bbox/val/documents"),
    pathlib.Path("/home/fahad/training_data_with_bbox/val/graph_annotations/"),
)
save_directory = "/home/fahad/master_thesis/data/crops/clean_crops/val/"
create_crops(image_folder_path, graph_annotation_folder_path, save_directory, clean_crop=True)

## Create Noisy Crops

### Train

In [78]:
# Crops from the *simulated noisy* training pages, cut with the same annotations
# as the clean crops so that the numbered files line up.
# Pass the variables defined in this cell, not leftovers from the clean-crop cells.
noisy_image_folder_path = pathlib.Path("/home/fahad/master_thesis/data/simulated_noisy_templates/train/")
noisy_graph_annotation_folder_path = pathlib.Path("/home/fahad/training_data_with_bbox/train/graph_annotations/")
noisy_save_directory = "/home/fahad/master_thesis/data/crops/noisy_crops/train/"
create_crops(
    image_folder_path=noisy_image_folder_path,
    graph_annotation_folder_path=noisy_graph_annotation_folder_path,
    save_directory=noisy_save_directory,
    clean_crop=False,
)

### Validation

In [10]:
# Crops from the *simulated noisy* validation pages, cut with the same annotations
# as the clean crops so that the numbered files line up.
# Pass the variables defined in this cell, not leftovers from the clean-crop cells.
noisy_image_folder_path = pathlib.Path("/home/fahad/master_thesis/data/simulated_noisy_templates/val/")
noisy_graph_annotation_folder_path = pathlib.Path("/home/fahad/training_data_with_bbox/val/graph_annotations/")
noisy_save_directory = "/home/fahad/master_thesis/data/crops/noisy_crops/val/"
create_crops(
    image_folder_path=noisy_image_folder_path,
    graph_annotation_folder_path=noisy_graph_annotation_folder_path,
    save_directory=noisy_save_directory,
    clean_crop=False,
)

## Load Crops

In [None]:
def load_crops(crops_folder_path):
    """Load all PNG crops under a folder into one ``(N, 1, H, W)`` float tensor.

    Files are read in sorted (lexicographic) path order, e.g. ``0.png, 1.png,
    10.png``. Clean and noisy crop folders therefore stay aligned as long as
    both contain the same file names. Every crop is converted to 8-bit
    grayscale ("L"), matching ``load_patches``, so each file contributes
    exactly one sample. A multi-channel crop would otherwise be split into
    several single-channel "samples" by a ``view(-1, 1, H, W)``. Mode ``"1"``
    crops load to the same 0/1 values either way.

    Args:
        crops_folder_path: ``pathlib.Path`` searched recursively for ``*.png``.
            All crops must share one size (``create_crops`` resizes to 150x100).

    Returns:
        Float tensor of shape ``(N, 1, H, W)``.

    Raises:
        ValueError: if no PNG files are found.
    """
    crop_paths = sorted(crops_folder_path.rglob("*.png"))
    if not crop_paths:
        raise ValueError(f"No .png crops found under {crops_folder_path}")

    to_tensor = ToTensor()
    crops = []
    for crop_path in crop_paths:
        with PIL.Image.open(crop_path) as crop:
            crops.append(to_tensor(crop.convert("L")))
    # Each crop is (1, H, W), so stacking yields (N, 1, H, W).
    return torch.stack(crops)

### Load Clean and Noisy Crops as Training Data

In [8]:
clean_crops_folder_path = pathlib.Path("/home/fahad/master_thesis/data/crops/clean_crops/train/")
noisy_crops_folder_path = pathlib.Path("/home/fahad/master_thesis/data/crops/noisy_crops/train/")

# Load the clean crops, then the noisy ones, and report both shapes.
training_clean_crops_data, training_noisy_crops_data = (
    load_crops(crops_folder_path=folder)
    for folder in (clean_crops_folder_path, noisy_crops_folder_path)
)
for crops_tensor in (training_clean_crops_data, training_noisy_crops_data):
    print(crops_tensor.shape)

### Load Clean and Noisy Crops as Validation Data

In [None]:
clean_crops_folder_path = pathlib.Path("/home/fahad/master_thesis/data/crops/clean_crops/val/")
noisy_crops_folder_path = pathlib.Path("/home/fahad/master_thesis/data/crops/noisy_crops/val/")

# Load the clean crops, then the noisy ones, and report both shapes.
validation_clean_crops_data, validation_noisy_crops_data = (
    load_crops(crops_folder_path=folder)
    for folder in (clean_crops_folder_path, noisy_crops_folder_path)
)
for crops_tensor in (validation_clean_crops_data, validation_noisy_crops_data):
    print(crops_tensor.shape)

# Set Model Parameters

In [12]:
# 361 = 19 * 19: a 2400x3500 page padded to 2432x3648 yields 19 x 19 patches of
# 192x128 in load_patches, so with unshuffled loaders each batch is one full page.
batch_size = 361
directory_path = "/home/fahad/master_thesis/vanilla_vae/models/"
n_epochs = 50
lr = 0.001
model_name = "templates"

# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")

In [6]:
# shuffle must stay False: the two loaders are zipped during training, and
# shuffling them independently would pair unrelated clean/noisy patches.
train_clean_image_loader, train_noisy_image_loader = (
    DataLoader(patches, batch_size=batch_size, shuffle=False)
    for patches in (patched_training_clean_images, patched_training_noisy_images)
)

In [13]:
# Unshuffled so the clean and noisy loaders stay aligned patch-for-patch.
test_clean_image_loader, test_noisy_image_loader = (
    DataLoader(patches, batch_size=batch_size, shuffle=False)
    for patches in (patched_test_clean_images, patched_test_noisy_images)
)

In [None]:
# Initialize the VAE and the Adam optimizer.
net = vanilla_model.VAE().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# The loaders are zipped, so they must yield the same number of batches
# (zip would otherwise silently stop at the shorter one).
if len(train_clean_image_loader) != len(train_noisy_image_loader):
    raise ValueError("Clean and noisy training loaders have different lengths")

n_train_images = len(train_noisy_image_loader.dataset)

# Train for n_epochs; after each epoch print the mean loss per training image.
for epoch in range(n_epochs):
    net.train()
    epoch_loss = 0.0
    for clean_images, noisy_images in zip(train_clean_image_loader, train_noisy_image_loader):
        clean_images = clean_images.to(device)
        noisy_images = noisy_images.to(device)

        # Denoising setup: encode the noisy patch, reconstruct the clean one.
        out, mu, logVar = net(noisy_images)

        # `kl_divergence` holds the *negative* KL divergence of N(mu, exp(logVar))
        # from N(0, I), so subtracting it adds the KL term to the summed BCE.
        # BCE assumes `out` lies in [0, 1] (presumably a sigmoid output — confirm in VAE).
        kl_divergence = 0.5 * torch.sum(1 + logVar - mu.pow(2) - logVar.exp())
        loss = F.binary_cross_entropy(out, clean_images, reduction='sum') - kl_divergence

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print('Epoch {}: Loss {:.4f}'.format(epoch, epoch_loss / n_train_images))

## Save Model

In [11]:
# Save the trained weights into the configured model directory (created if missing).
os.makedirs(directory_path, exist_ok=True)
torch.save(net.state_dict(), os.path.join(directory_path, "net_final.pth"))

## Load Model

In [None]:
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
net = vanilla_model.VAE().to(device)
net.load_state_dict(torch.load(os.path.join(directory_path, "net_final.pth"), map_location=device))

## Test Model

In [22]:
import matplotlib.pyplot as plt
import random
from itertools import islice

# Denoise one randomly chosen validation batch. The test loaders are unshuffled
# and batch_size equals the patches per page, so a batch is presumably one page.
clean_image_list = []  # network reconstructions, one 2-D array per patch
noisy_image_list = []  # the corresponding noisy input patches

# Pick a batch without materializing every batch of the loader.
batch_index = random.randrange(len(test_noisy_image_loader))
noisy_batch = next(islice(iter(test_noisy_image_loader), batch_index, None))

net.eval()
with torch.no_grad():
    imgs = noisy_batch.to(device)
    # A single forward pass for the whole batch.
    out, mu, logVAR = net(imgs)
    # Iterate over the actual batch length (a final batch may be shorter than batch_size).
    for noisy_patch, denoised_patch in zip(imgs.cpu().numpy(), out.cpu().numpy()):
        # (C, H, W) -> (H, W, C) -> (H, W) for single-channel patches.
        noisy_image_list.append(np.squeeze(np.transpose(noisy_patch, [1, 2, 0])))
        clean_image_list.append(np.squeeze(np.transpose(denoised_patch, [1, 2, 0])))

In [None]:
# Stitch the patches back into full pages. load_patches emits patches in row-major
# order, so each run of n consecutive patches forms one row of the page.
n = 19  # patches per row and per column: 19 * 19 = 361 = batch_size
clean_image_patches = [
    np.concatenate(clean_image_list[i * n:(i + 1) * n], axis=1) for i in range(n)
]
noisy_image_patches = [
    np.concatenate(noisy_image_list[i * n:(i + 1) * n], axis=1) for i in range(n)
]
full_clean_image = np.concatenate(clean_image_patches, axis=0)
full_noisy_image = np.concatenate(noisy_image_patches, axis=0)

# Side-by-side comparison of the noisy input and the denoised output.
fig, axes = plt.subplots(1, 2, figsize=(30, 20))
for ax, page, title in zip(axes, (full_noisy_image, full_clean_image), ("Noisy input", "Denoised output")):
    ax.imshow(page, cmap="gray")
    ax.set_title(title)
    ax.axis("off")
plt.show()

## Plot Input and Output

In [None]:
# Full noisy input page, reassembled from its patches.
fig, ax = plt.subplots(figsize=(30, 20))
ax.imshow(full_noisy_image, cmap="gray")
ax.set_title("Noisy input page")
ax.axis("off")
plt.show()

In [None]:
# Full denoised output page, reassembled from the network's reconstructions.
fig, ax = plt.subplots(figsize=(30, 20))
ax.imshow(full_clean_image, cmap="gray")
ax.set_title("Denoised output page")
ax.axis("off")
plt.show()