In [None]:
import matplotlib
matplotlib.use('Agg')
%matplotlib inline
import matplotlib.pyplot as plt
import itertools

import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch

import numpy as np
import random
import copy
import time
from sklearn.model_selection import train_test_split

import os
# Limit the CUDA devices visible to this process to indices 0-3.
# NOTE(review): torch is imported before this assignment; this presumably still
# takes effect because nothing has touched CUDA yet — confirm on the target machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"


# The helpers below used to be imported from project modules; this notebook
# defines CustomDataset, GANLoss, Vgg19 and Discriminator inline instead.
# from utils import CustomDataset, GANLoss, Vgg19
# from generator import AttU_Net
# from discriminator import Discriminator

from PIL import Image
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Read data

In [None]:
# Probe one source photo to inspect its native resolution.
# (Relative-path alternative: "../OCR_data/doc_unet/original/1_1.jpg")
img_path = "/srv/data/hamin/OCR/doc_unet/original/1_1.jpg"

# Decode straight into an array; no resizing happens at this stage.
im = np.array(Image.open(img_path))

# Report half of each of the first two array dimensions.
half_dim0, half_dim1 = im.shape[0] // 2, im.shape[1] // 2
print(f"0: {half_dim0}, 1: {half_dim1}")


In [None]:
# Dataset folders. Relative-path alternatives used previously:
#   "../OCR_data/doc_unet/original/", "../OCR_data/doc_unet/crop/", "../OCR_data/doc_unet/scan/"
orig_base_path = "/srv/data/hamin/OCR/doc_unet/original/"
crop_base_path = "/srv/data/hamin/OCR/doc_unet/crop/"
scan_base_path = "/srv/data/hamin/OCR/doc_unet/scan/"

# Working resolution: 1/16 of 2016 x 1512. num_pixels_x is the first (row) axis
# of the arrays and num_pixels_y the second (column) axis.
num_pixels_x = 2016 // 16
num_pixels_y = 1512 // 16
num_channels = 3

# All three stacks are sized by the number of files in the *original* folder.
num_images = len(os.listdir(orig_base_path))
image_stack_shape = (num_images, num_pixels_x, num_pixels_y, num_channels)

orig_images, crop_images, scan_images = (
    np.zeros(image_stack_shape, dtype=np.uint8) for _ in range(3)
)

In [None]:
from tqdm import tqdm
# Load every original photo (in sorted filename order) into orig_images.
for count, im_path in enumerate(tqdm(sorted(os.listdir(orig_base_path)))):

    # The context manager releases the file handle once the pixels are copied out.
    with Image.open(os.path.join(orig_base_path, im_path)) as pil_image:
        # Force 3-channel RGB (as the scan-loading loop does) so grayscale or RGBA
        # files cannot break the assignment into the (H, W, 3) slot below.
        # PIL's resize takes (width, height), hence (num_pixels_y, num_pixels_x).
        im = np.array(pil_image.convert('RGB').resize((num_pixels_y, num_pixels_x)))

    orig_images[count] = im

In [None]:
pip install tqdm

In [None]:
# Load every cropped target image (in sorted filename order) into crop_images.
for count, im_path in enumerate(tqdm(sorted(os.listdir(crop_base_path)))):

    # The context manager releases the file handle once the pixels are copied out.
    with Image.open(os.path.join(crop_base_path, im_path)) as pil_image:
        # Force 3-channel RGB (as the scan-loading loop does) so grayscale or RGBA
        # files cannot break the assignment into the (H, W, 3) slot below.
        # PIL's resize takes (width, height), hence (num_pixels_y, num_pixels_x).
        im = np.array(pil_image.convert('RGB').resize((num_pixels_y, num_pixels_x)))

    crop_images[count] = im

In [None]:
# One-off rename of the scan folder to make everything sorted (kept for reference):
# for im_path in sorted(os.listdir(scan_base_path)):
#     os.rename(f"{scan_base_path}{im_path}", f"{scan_base_path}{im_path[:-4]}_1.png")

scan_files = sorted(os.listdir(scan_base_path))
for count, im_path in enumerate(tqdm(scan_files)):

    # Decode, force RGB and downscale; PIL's resize takes (width, height).
    scan = Image.open(f"{scan_base_path}/{im_path}").convert('RGB')
    im = np.array(scan.resize((num_pixels_y, num_pixels_x)))

    # Every scan is stored twice, in two consecutive slots of scan_images.
    cur_count = count*2
    for offset in (0, 1):
        scan_images[cur_count + offset] = im
    

In [None]:
# Side-by-side preview of one (original, cropped, scanned) triplet.
f, axarr = plt.subplots(1,3, figsize=(10,7))

# Negative index: counts back from the end of the stacks.
i = -2
cur_orig, cur_crop, cur_scan = orig_images[i], crop_images[i], scan_images[i]

panels = (
    (cur_orig, 'Original'),
    (cur_crop, 'Cropped'),
    (cur_scan, 'scanned'),
)
for ax, (picture, caption) in zip(axarr, panels):
    ax.imshow(picture)
    ax.title.set_text(caption)

In [None]:
# Root folder of the dataset; the commented-out snippets below read/write .npy files under it.
base_path = "/srv/data/hamin/OCR/doc_unet/"

# --- Save the three image stacks as .npy files (currently disabled) ---
# with open(f"{base_path}/orig.npy", 'wb') as f:
#     np.save(f, orig_images)
    
# with open(f"{base_path}/crop.npy", 'wb') as f:
#     np.save(f, crop_images)
    
# with open(f"{base_path}/scan.npy", 'wb') as f:
#     np.save(f, scan_images)
    
    
# --- Load the three image stacks back from the .npy files (currently disabled) ---
# with open(f"{base_path}/orig.npy", 'rb') as f:
#     orig_images = np.load(f)
    
# with open(f"{base_path}/crop.npy", 'rb') as f:
#     crop_images = np.load(f)
    
# with open(f"{base_path}/scan.npy", 'rb') as f:
#     scan_images = np.load(f)    

# Phase 1: get cropped image

In [None]:
# Inputs are the original photos and targets are the cropped images; 20% of the
# pairs are held out for testing, with a fixed random_state for a repeatable split.
x_train, x_test, y_train, y_test = train_test_split(orig_images, crop_images, test_size=0.2, random_state=42)


In [None]:
# Preview one (input, target) pair from the training split.
f, axarr = plt.subplots(1,2, figsize=(10,7))

# Negative index: counts back from the end of the training arrays.
i = -6
cur_orig, cur_crop = x_train[i], y_train[i]

for ax, picture, caption in zip(axarr, (cur_orig, cur_crop), ('Original', 'Cropped')):
    ax.imshow(picture)
    ax.title.set_text(caption)

# The scanned image is not part of this phase, so no third panel is drawn.

In [None]:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from PIL import Image


class CustomDataset(Dataset):
    """Paired image dataset yielding ``(input, target)`` samples.

    Args:
        x_data: Input images as an array indexed ``(N, H, W, C)``.
        y_data: Target images as an array with the same layout.
        transform_input: Optional callable applied to each input PIL image.
        transform_output: Optional callable applied to each target PIL image.

    Attributes:
        data: Inputs stored channel-first, ``(N, C, H, W)``.
        targets: Targets stored channel-first, ``(N, C, H, W)``.
    """

    def __init__(self, x_data, y_data, transform_input=None, transform_output=None ):
        self.transform_input = transform_input
        self.transform_output = transform_output

        # Keep the arrays channel-first; __getitem__ converts back for PIL.
        self.data = np.transpose(x_data, (0, 3, 1, 2))
        self.targets = np.transpose(y_data, (0, 3, 1, 2))

    def __len__(self):
        """Number of samples, taken from the target array."""
        return self.targets.shape[0]

    @staticmethod
    def _to_pil(chw_image):
        """Convert one channel-first image to a channel-last uint8 PIL image."""
        return Image.fromarray(chw_image.astype(np.uint8).transpose(1, 2, 0))

    def __getitem__(self, index):
        """Return the ``index``-th ``(input, target)`` pair.

        Each transform is applied only when it was provided; with the default
        ``None`` the PIL image is returned unchanged instead of raising
        ``TypeError: 'NoneType' object is not callable``.
        """
        x = self._to_pil(self.data[index])
        if self.transform_input is not None:
            x = self.transform_input(x)

        y = self._to_pil(self.targets[index])
        if self.transform_output is not None:
            y = self.transform_output(y)

        return x, y


In [None]:
# Input pipeline for training samples. Only ToTensor is active; the commented-out
# entries are optional augmentation / normalization steps that are currently disabled.
transform_train = transforms.Compose([
    #transforms.RandomAffine(degrees=[-10,10], translate=[0.00,0.15], scale=[0.7,1.00], shear=1, fill=(0,0,0)),
    transforms.ToTensor(),
    #transforms.Normalize(mean = [ 0.5, 0.5, 0.5 ], std = [ 0.5, 0.5, 0.5 ]),
    #transforms.RandomErasing(scale=[0.05, 0.08], ratio=[0.02,0.05], p=0.5),
])

# Pipeline for target images (used for both the train and test datasets).
transform_output = transforms.Compose([
    transforms.ToTensor(),
    #transforms.Normalize(mean = [ 0.5, 0.5, 0.5 ], std = [ 0.5, 0.5, 0.5 ]),
])

# Input pipeline for test samples: no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    #transforms.Normalize(mean = [ 0.5, 0.5, 0.5 ], std = [ 0.5, 0.5, 0.5 ]),

])

In [None]:
# Both loaders share the same batching / shuffling / worker settings.
loader_kwargs = dict(batch_size=64, shuffle=True, num_workers=16)

# Training pairs: augmenting input transform plus the common target transform.
train_set = CustomDataset(
    x_train,
    y_train,
    transform_input=transform_train,
    transform_output=transform_output,
)
trainloader = DataLoader(train_set, **loader_kwargs)

# Held-out pairs: plain input transform plus the common target transform.
test_set = CustomDataset(
    x_test,
    y_test,
    transform_input=transform_test,
    transform_output=transform_output,
)
testloader = DataLoader(test_set, **loader_kwargs)

In [None]:
#Sanity test
class UnNormalize(object):
    """Undo a per-channel ``Normalize(mean, std)``, in place.

    Args:
        mean: Sequence of per-channel means that were subtracted.
        std: Sequence of per-channel standard deviations that were divided by.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Image of size (C, H, W) or batch of images of size
                (N, C, H, W). The tensor is modified in place.
        Returns:
            Tensor: The same tensor object, un-normalized.
        """
        if tensor.dim() == 4:
            # Batched input: the channel axis is dim 1. Iterating the tensor itself
            # would walk over the first few *samples* instead of the channels.
            channels = [tensor[:, c] for c in range(tensor.shape[1])]
        else:
            # Single image: the leading axis is the channel axis.
            channels = tensor

        for t, m, s in zip(channels, self.mean, self.std):
            # Inverse of the normalize step t.sub_(m).div_(s); t is a view, so
            # the in-place ops write through to `tensor`.
            t.mul_(s).add_(m)
        return tensor
    
    
# Pull a single batch from the training loader.
x, y = next(iter(trainloader))

# transform_train / transform_output currently apply only ToTensor (their Normalize
# steps are commented out), so the tensors are already in [0, 1]. Un-normalizing
# them with mean/std 0.5 would squeeze the displayed sample into [0.5, 1].
# If Normalize is re-enabled in the transforms, apply UnNormalize here first.
image_x = torch.permute(x, (0, 2, 3, 1)).numpy()  # (N, C, H, W) -> (N, H, W, C) for imshow
image_y = torch.permute(y, (0, 2, 3, 1)).numpy()


f, axarr = plt.subplots(1,2, figsize=(10,7))
i = 0
cur_orig = image_x[i]
cur_crop = image_y[i]


axarr[0].imshow(cur_orig)
axarr[0].title.set_text('Original')

axarr[1].imshow(cur_crop)
axarr[1].title.set_text('Cropped')

In [None]:
from models.u2net import U2NETP, U2NET

In [None]:
# Wrap the generator in DataParallel and move it to the selected device.
# U2NETP is the backbone in use; U2NET() is the alternative that was tried.
generator = nn.DataParallel(U2NETP())
generator.to(device)

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):
    """Fully-convolutional discriminator that returns one score per image.

    Args:
        input_nc: Number of channels of the input images.
    """

    def __init__(self, input_nc=3):
        super(Discriminator, self).__init__()

        # A bunch of convolutions one after another
        model = [   nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(64, 128, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(128), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(256), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(256, 512, 4, padding=1),
                    nn.InstanceNorm2d(512), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        # FCN classification layer
        model += [nn.Conv2d(512, 1, 4, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor of shape (N, input_nc, H, W).
        Returns:
            Tensor of shape (N, 1): the score map averaged over its spatial extent.
        """
        # Add Gaussian noise (variance 0.1) to the input. randn_like creates the
        # noise on x's own device/dtype, so this works on CPU as well as on GPU
        # (the previous hard-coded .cuda() crashed without CUDA).
        x = x + (0.1**0.5) * torch.randn_like(x)
        x =  self.model(x)
        # Average pooling and flatten
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

In [None]:

# Wrap the discriminator in DataParallel and move it to the selected device.
discriminator = nn.DataParallel(Discriminator())
discriminator.to(device)

In [None]:
from torchvision.models import vgg19
from collections import namedtuple

class Vgg19(torch.nn.Module):
    """Frozen, pretrained VGG-19 feature extractor for a perceptual loss.

    Keeps the first 36 modules of ``vgg19(...).features`` in eval mode and
    returns the activations produced by the modules at indices 2, 7, 12, 21
    and 30.
    """

    # Indices (into self.features) whose outputs are collected by forward().
    _tap_indices = frozenset({2, 7, 12, 21, 30})

    def __init__(self):
        super(Vgg19, self).__init__()
        features = list(vgg19(pretrained = True).features)[:36]
        self.features = nn.ModuleList(features).eval()

        # The extractor is never optimized, so freeze its weights. Gradients still
        # flow back to the *input*, but backward no longer computes and accumulates
        # never-zeroed gradients for every VGG parameter.
        for param in self.features.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the layers and return the list of tapped activations."""
        results = []
        for ii, model in enumerate(self.features):
            x = model(x)
            if ii in self._tap_indices:
                results.append(x)
        return results

In [None]:
# Wrap the VGG feature extractor in DataParallel and move it to the selected device.
perceptual_model = nn.DataParallel(Vgg19())
perceptual_model.to(device)

In [None]:
class GANLoss(nn.Module):
    """Adversarial loss that builds the real/fake target tensor itself.

    Args:
        use_lsgan: If True use ``nn.MSELoss``, otherwise ``nn.BCELoss``.
        target_real_label: Target value for "real" predictions.
        target_fake_label: Target value for "fake" predictions.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        # Buffers so the labels follow the module across .to(device) calls.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the real or fake label expanded to the shape of ``input``."""
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(input)

    def forward(self, input, target_is_real):
        """Compute the loss of ``input`` against an all-real or all-fake target.

        Defined as ``forward`` (rather than overriding ``__call__``) so that
        ``criterion(input, flag)`` goes through ``nn.Module.__call__`` and the
        usual module hooks keep working.
        """
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)

In [None]:
# Both Adam optimizers share the same learning rate and betas.
lr = 0.001
adam_settings = dict(lr=lr, betas=(0.5, 0.999))

# Generator: pixel-wise MSE loss and its optimizer.
criterion_g = torch.nn.MSELoss()
optimizer_g = torch.optim.Adam(generator.parameters(), **adam_settings)

# Discriminator: adversarial loss (moved to the device) and its optimizer.
criterion_d = GANLoss().to(device)
optimizer_d = torch.optim.Adam(discriminator.parameters(), **adam_settings)

# Perceptual loss: MSE between feature maps.
criterion_p = torch.nn.MSELoss()
    

In [None]:
# Sample counts and batch counts for the two splits. The batch counts are used
# by the training loop to average the accumulated per-batch losses.
train_total = len(train_set)
train_batches = len(trainloader)

test_total = len(test_set)
# NOTE(review): "test_baches" is a typo for "test_batches"; the evaluation code reads
# this exact name, so both places must be renamed together.
test_baches = len(testloader)

In [None]:
# Upper bound on the number of training epochs.
num_epochs = 5000

# NOTE(review): patience is initialised here but the visible training loop never
# reads or updates it — presumably a leftover from an early-stopping scheme.
patience = 0    # Bad epoch counter
# Starting value for the best test loss seen so far; the training loop lowers it
# whenever an epoch's test loss is smaller.
best_loss = 1024


In [None]:
count = 0
for epoch in range(num_epochs):
    # ------------------------------------------------------------------
    # Train
    # ------------------------------------------------------------------
    generator.train()
    discriminator.train()

    train_loss = 0

    start_time = time.time()
    for x, y in trainloader:
        x = x.to(device)
        y = y.to(device)

        # Forward pass; the generator returns seven outputs, only the first is used.
        fake, _, _, _, _, _, _ = generator(x)

        #########################
        # (1) Calculate Perceptual Loss
        #########################
        fake_feature = perceptual_model(fake)
        # The target features are constants for the optimization, so no graph is needed.
        with torch.no_grad():
            real_feature = perceptual_model(y)

        loss_p = 0
        for i in range(5):
            loss_p += criterion_p(fake_feature[i], real_feature[i])

        ##########################
        # (2) Update Discriminator
        ##########################
        optimizer_d.zero_grad()

        # Train with fake. detach() keeps this step from backpropagating through the
        # generator: those gradients were discarded by optimizer_g.zero_grad() anyway,
        # and it leaves the generator graph intact for step (3).
        pred_fake = discriminator(fake.detach())
        loss_d_fake = criterion_d(pred_fake, False)

        # Train with real
        pred_real = discriminator(y)
        loss_d_real = criterion_d(pred_real, True)

        # Average and update
        loss_d_total = (loss_d_fake + loss_d_real) * 0.5
        loss_d_total.backward()
        optimizer_d.step()

        ###########################
        # (3) Update Generator
        ###########################
        optimizer_g.zero_grad()

        # Adversarial term: the (just updated) discriminator should call the fake "real".
        pred_fake = discriminator(fake)
        loss_gan = criterion_d(pred_fake, True)

        # Pixel-wise reconstruction term
        loss_g = criterion_g(fake, y)

        loss_total = loss_gan + loss_g + loss_p
        # Each graph is traversed exactly once per iteration, so retain_graph is not needed.
        loss_total.backward()
        optimizer_g.step()

        train_loss += loss_total.item()

    train_loss = train_loss / train_batches

    #scheduler.step(1.)

    end_time = time.time()

    print('[%2d / %d] train_loss: %.5f  ' % (epoch+1, num_epochs , train_loss), end = ' ')

    # ------------------------------------------------------------------
    # Evaluate: reconstruction loss only, on the held-out set
    # ------------------------------------------------------------------
    generator.eval()

    test_loss = 0
    with torch.no_grad():
        for data in testloader:
            x, y = data[0].to(device), data[1].to(device)
            outputs, _,_,_,_,_,_ = generator(x)

            loss = criterion_g(outputs, y)
            test_loss += loss.item()

        test_loss = test_loss / test_baches

        if(test_loss < best_loss):
            #torch.save(generator.state_dict(), path_checkpoint)
            #torch.save(discriminator.state_dict(), path_checkpoint2)
            best_loss = test_loss
        print('test_loss: %.5f -- Best loss: %.5f --- %.2f seconds' %(test_loss, best_loss, (end_time-start_time)))

    # Inputs are not normalized, so the tensors can be shown as they are.
    # If Normalize is re-enabled in the transforms, un-normalize x, outputs and y here.
    y_hat = outputs

    # (N, C, H, W) -> (N, H, W, C) numpy arrays of the last evaluated batch.
    x = torch.permute(x.detach().cpu(), (0, 2,3, 1)).numpy()
    y_hat = torch.permute(y_hat.detach().cpu(), (0, 2,3, 1)).numpy()
    y = torch.permute(y.detach().cpu(), (0, 2,3, 1)).numpy()

    # Every fifth epoch, show the first sample of the last test batch.
    if epoch % 5 == 0:
        f, axarr = plt.subplots(1,3, figsize=(10,7))
        i = 0
        cur_orig = x[i]
        # The raw prediction may leave [0, 1]; clip it for display only.
        cur_pred = np.clip(y_hat[i], 0.0, 1.0)
        cur_crop = y[i]

        axarr[0].imshow(cur_orig)
        axarr[0].title.set_text('Original')

        axarr[1].imshow(cur_pred)
        axarr[1].title.set_text('Predicted')

        axarr[2].imshow(cur_crop)
        axarr[2].title.set_text('real')
        plt.show()
        # Release the figure so figures do not pile up over thousands of epochs.
        plt.close(f)

In [None]:
## Try doing CV2 align first, then run the models?