## Graphical User Interface for Image Coloration

In [None]:
import tensorflow as tf
with tf.device('/CPU:0'):
    import tkinter as tk
    from tkinter import filedialog
    from PIL import Image, ImageTk
    import torch
    from torch import nn, optim
    import torchvision.transforms as transforms
    import glob
    import numpy as np
    from torch.utils.data import Dataset, DataLoader
    from PIL import Image
    from pathlib import Path
    from tqdm.notebook import tqdm
    import matplotlib.pyplot as plt
    from skimage.color import rgb2lab, lab2rgb
    SIZE = 256

In [None]:
class ColorizationDataset(Dataset):
    """Map-style dataset that turns image files into normalized Lab tensors.

    Each image is loaded as RGB and resized to ``SIZE`` x ``SIZE`` (bicubic).
    It is converted to Lab with ``rgb2lab`` and split into:

    * ``'L'``  -- lightness channel, tensor of shape (1, SIZE, SIZE), scaled as ``L / 50 - 1``
    * ``'ab'`` -- colour channels, tensor of shape (2, SIZE, SIZE), scaled as ``ab / 110``
    """

    def __init__(self, paths):
        self.transforms = transforms.Resize((SIZE, SIZE), Image.BICUBIC)
        self.size = SIZE
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        rgb = np.array(self.transforms(Image.open(self.paths[idx]).convert("RGB")))
        # ToTensor on a float array only reorders HWC -> CHW (no rescaling).
        lab = transforms.ToTensor()(rgb2lab(rgb).astype("float32"))
        return {
            'L': lab[[0], ...] / 50. - 1.,
            'ab': lab[[1, 2], ...] / 110.,
        }


def make_dataloaders(batch_size=1, **kwargs):
    """Build a ColorizationDataset from ``kwargs`` and wrap it in a DataLoader.

    Args:
        batch_size: number of samples per batch (default 1).
        **kwargs: forwarded to ``ColorizationDataset`` (e.g. ``paths=[...]``).

    Returns:
        A ``DataLoader`` with default (non-shuffled) sampling.
    """
    return DataLoader(ColorizationDataset(**kwargs), batch_size=batch_size)


class UnetBlock(nn.Module):
    """One level of a U-Net: down-convolve, run ``submodule``, up-convolve.

    Non-outermost blocks return ``cat([x, model(x)], dim=1)`` (skip
    connection). The outermost block returns ``model(x)`` directly, ending in
    a Tanh.

    Args:
        nf: channels produced by the up-convolution; also the input channel
            count when ``input_c`` is None.
        ni: channels produced by the down-convolution.
        submodule: inner block placed between the down and up paths
            (unused for the innermost block).
        input_c: input channel count; defaults to ``nf``.
        dropout: append ``Dropout(0.5)`` to the up path (middle blocks only).
        innermost: build the bottom-level variant (no submodule).
        outermost: build the top-level variant (no norms, Tanh output).
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        in_channels = nf if input_c is None else input_c

        # Construction order is kept stable so random initialization (and the
        # resulting Sequential indices / state_dict keys) do not change.
        down_conv = nn.Conv2d(in_channels, ni, kernel_size=4,
                              stride=2, padding=1, bias=False)
        down_act = nn.LeakyReLU(0.2, True)
        down_bn = nn.BatchNorm2d(ni)
        up_act = nn.ReLU(True)
        up_bn = nn.BatchNorm2d(nf)

        if outermost:
            up_conv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                         stride=2, padding=1)
            layers = [down_conv, submodule, up_act, up_conv, nn.Tanh()]
        elif innermost:
            up_conv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                         stride=2, padding=1, bias=False)
            layers = [down_act, down_conv, up_act, up_conv, up_bn]
        else:
            # Input to up_conv is the submodule's skip-concatenated output,
            # hence ni * 2 channels.
            up_conv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                         stride=2, padding=1, bias=False)
            layers = [down_act, down_conv, down_bn, submodule,
                      up_act, up_conv, up_bn]
            if dropout:
                layers.append(nn.Dropout(0.5))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        y = self.model(x)
        if self.outermost:
            return y
        return torch.cat([x, y], 1)

class Unet(nn.Module):
    """U-Net generator assembled from nested ``UnetBlock`` levels.

    Blocks are built from the innermost level outward:

    * one innermost block,
    * ``n_down - 5`` dropout blocks at ``num_filters * 8`` channels,
    * three blocks that halve the channel count,
    * an outermost block mapping ``input_c`` -> ``output_c`` channels.

    That gives ``n_down`` stride-2 downsamplings in total.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        widest = num_filters * 8
        block = UnetBlock(widest, widest, innermost=True)
        for _ in range(n_down - 5):
            block = UnetBlock(widest, widest, submodule=block, dropout=True)

        width = widest
        for _ in range(3):
            block = UnetBlock(width // 2, width, submodule=block)
            width //= 2

        self.model = UnetBlock(output_c, width, input_c=input_c,
                               submodule=block, outermost=True)

    def forward(self, x):
        return self.model(x)


def init_weights(net, init='norm', gain=0.02):
    """Re-initialize convolution and BatchNorm2d weights of ``net`` in place.

    Any submodule whose class name contains ``'Conv'`` (Conv2d,
    ConvTranspose2d, ...) and has a ``weight`` gets its weight initialized by
    the chosen scheme and its bias (if any) zeroed. ``BatchNorm2d`` layers
    with affine parameters get weight ~ N(1, gain) and zero bias.

    Args:
        net: module to initialize; every submodule is visited via ``net.apply``.
        init: ``'norm'`` (N(0, gain)), ``'xavier'`` (xavier_normal_ with
            ``gain``) or ``'kaiming'`` (kaiming_normal_, fan_in; ``gain`` is
            not used by this scheme).
        gain: standard deviation / gain used by the schemes above.

    Returns:
        ``net`` itself, to allow chaining.

    Raises:
        ValueError: if ``init`` is not a supported scheme. The check happens
            before any weight is touched.
    """
    conv_initializers = {
        'norm': lambda w: nn.init.normal_(w, mean=0.0, std=gain),
        'xavier': lambda w: nn.init.xavier_normal_(w, gain=gain),
        'kaiming': lambda w: nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
    }
    # Fail loudly: a typo used to silently leave conv weights at their defaults.
    if init not in conv_initializers:
        raise ValueError(
            f"initialization method {init!r} is not implemented; "
            f"expected one of {sorted(conv_initializers)}"
        )
    init_conv_weight = conv_initializers[init]

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            init_conv_weight(m.weight.data)
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # affine=False BatchNorm has weight/bias set to None; nothing to init.
            if m.weight is not None:
                nn.init.normal_(m.weight.data, 1., gain)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    return net


def init_model(model, device):
    """Move ``model`` to ``device`` and apply ``init_weights`` defaults.

    Returns the same (moved, re-initialized) module.
    """
    return init_weights(model.to(device))


class PatchDiscriminator(nn.Module):
    """PatchGAN-style discriminator producing a grid of per-patch logits.

    Layout of ``self.model`` (each entry is a Conv[-BN][-LeakyReLU] stage):

    * stage 0: ``input_c -> num_filters``, stride 2, no BatchNorm;
    * ``n_down`` stages doubling the channels; the last of them uses stride 1;
    * a final 1-channel conv, stride 1, with no norm or activation.
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        stages = [self.get_layers(input_c, num_filters, norm=False)]
        for i in range(n_down):
            stride = 1 if i == n_down - 1 else 2
            stages.append(self.get_layers(num_filters * 2 ** i,
                                          num_filters * 2 ** (i + 1),
                                          s=stride))
        stages.append(self.get_layers(num_filters * 2 ** n_down, 1,
                                      s=1, norm=False, act=False))
        self.model = nn.Sequential(*stages)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Return Conv2d(ni->nf, k, s, p), optionally followed by BatchNorm2d and LeakyReLU(0.2).

        The conv carries a bias only when no BatchNorm follows it.
        """
        stage = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm:
            stage.append(nn.BatchNorm2d(nf))
        if act:
            stage.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*stage)

    def forward(self, x):
        return self.model(x)


class GANLoss(nn.Module):
    """Adversarial loss of predictions against a constant real/fake target.

    Args:
        gan_mode: ``'vanilla'`` uses ``BCEWithLogitsLoss``, so predictions
            must be raw logits. ``'lsgan'`` uses ``MSELoss``.
        real_label: target value used when ``target_is_real`` is true.
        fake_label: target value used otherwise.

    Raises:
        ValueError: if ``gan_mode`` is not ``'vanilla'`` or ``'lsgan'``.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Registered as buffers so they follow .to(device) and are kept in
        # the state dict (existing checkpoints contain these keys).
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Previously self.loss was left undefined, failing later with a
            # confusing AttributeError at the first loss computation.
            raise ValueError(
                f"gan_mode {gan_mode!r} is not supported; use 'vanilla' or 'lsgan'"
            )

    def get_labels(self, preds, target_is_real):
        """Return the real or fake target broadcast (without copying) to ``preds``' shape."""
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Compute the loss of ``preds`` against the real/fake target.

        Defined as ``forward`` (not ``__call__``) so nn.Module's ``__call__``
        dispatches here and registered hooks run. Calling the instance,
        ``criterion(preds, True)``, works exactly as before.
        """
        return self.loss(preds, self.get_labels(preds, target_is_real))


class MainModel(nn.Module):
    """Conditional GAN that predicts ab colour channels from an L channel.

    ``net_G`` maps the 1-channel L tensor to 2 ab channels. ``net_D`` scores
    the 3-channel concatenation of L with either the real or the generated ab.
    Each :meth:`optimize` call performs one discriminator update and then one
    generator update. The generator objective is the adversarial loss plus
    ``lambda_L1`` times the L1 distance to the real ab.

    Args:
        net_G: optional prebuilt generator. When None, a U-Net
            (1 -> 2 channels, 8 downsamplings) is created via ``init_model``.
        lr_G: Adam learning rate for the generator.
        lr_D: Adam learning rate for the discriminator.
        beta1, beta2: Adam betas shared by both optimizers.
        lambda_L1: weight of the L1 term in the generator loss.
    """

    def __init__(self, net_G=None, lr_G=1e-4, lr_D=1e-4,
                 beta1=0.40, beta2=0.9, lambda_L1=90.):
        super().__init__()

        # Use the GPU when present; setup_input moves each batch here.
        has_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if has_gpu else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is not None:
            self.net_G = net_G.to(self.device)
        else:
            generator = Unet(input_c=1, output_c=2, n_down=8, num_filters=64)
            self.net_G = init_model(generator, self.device)

        discriminator = PatchDiscriminator(input_c=3, n_down=3, num_filters=64)
        self.net_D = init_model(discriminator, self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()

        adam_betas = (beta1, beta2)
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=adam_betas)
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=adam_betas)

    def set_requires_grad(self, model, requires_grad=True):
        """Toggle ``requires_grad`` on every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad_(requires_grad)

    def setup_input(self, data):
        """Store the batch's ``'L'`` and ``'ab'`` tensors on ``self.device``."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Predict ab channels for the stored L batch into ``self.fake_color``."""
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        """Compute the discriminator loss (mean of fake and real terms) and backprop it."""
        generated = torch.cat([self.L, self.fake_color], dim=1)
        # detach(): the discriminator loss must not flow back into net_G.
        self.loss_D_fake = self.GANcriterion(self.net_D(generated.detach()), False)
        genuine = torch.cat([self.L, self.ab], dim=1)
        self.loss_D_real = self.GANcriterion(self.net_D(genuine), True)
        self.loss_D = 0.5 * (self.loss_D_fake + self.loss_D_real)
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (fool D + weighted L1) and backprop it."""
        generated = torch.cat([self.L, self.fake_color], dim=1)
        self.loss_G_GAN = self.GANcriterion(self.net_D(generated), True)
        self.loss_G_L1 = self.lambda_L1 * self.L1criterion(self.fake_color, self.ab)
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """Run one training iteration: a D update followed by a G update."""
        self.forward()

        # --- discriminator step ---
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        # --- generator step (D frozen so only G accumulates gradients) ---
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

def lab_to_rgb(L, ab):
    """Convert a normalized Lab batch back to RGB images.

    Reverses ColorizationDataset's scaling: L is mapped back with
    ``(L + 1) * 50`` and ab with ``ab * 110``. Each image is then converted
    with ``skimage.color.lab2rgb``.

    Args:
        L: lightness tensor shaped (N, 1, H, W), as batched from the dataset.
        ab: colour tensor shaped (N, 2, H, W). It may still require grad,
            e.g. a generator output during training.

    Returns:
        numpy array of shape (N, H, W, 3) with the RGB images.
    """
    L = (L + 1.) * 50.
    ab = ab * 110.
    # Channels-last layout for skimage. detach() is needed because
    # Tensor.numpy() raises on tensors that require grad.
    lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).detach().cpu().numpy()
    return np.stack([lab2rgb(img) for img in lab], axis=0)

In [None]:
import os

# Checkpoint location. Set COLORATION_MODEL_PATH to point elsewhere instead
# of editing this machine-specific default.
MODEL_PATH = os.environ.get("COLORATION_MODEL_PATH",
                            "C:\\Users\\dhanr\\Downloads\\coloration.pt")
device = torch.device('cpu')
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source (newer PyTorch versions
# support weights_only=True for plain state dicts).
model_state_dict = torch.load(MODEL_PATH, map_location=device)
# MainModel places its networks on CUDA when available; load_state_dict
# copies the CPU-mapped tensors into those parameters.
model = MainModel()
model.load_state_dict(model_state_dict)
model.eval()

In [None]:
# Main window with a fixed-size canvas. browse_file draws the input image
# on the left and the colorized result on the right of this canvas.
root = tk.Tk()
root.geometry('800x800')
root.title("Image Converter")
canvas = tk.Canvas(master=root, height=600, width=600)
canvas.pack()
def browse_file():
    """Ask for an image file, then show it next to the model's colorization.

    The chosen image is resized to 256x256 and drawn at the left of the
    module-level ``canvas``. It is run through ``model.net_G``, and the
    predicted colour image is drawn to its right. Cancelling the dialog does
    nothing. An unreadable file is reported in an error dialog instead of
    raising inside the Tk callback.
    """
    from tkinter import messagebox

    file_path = filedialog.askopenfilename()
    # askopenfilename returns an empty value when the dialog is cancelled.
    if not file_path:
        return

    try:
        # Context manager closes the file handle once the resized copy exists.
        with Image.open(file_path) as source:
            input_image = source.resize((256, 256))
        # Pass the path directly: glob.glob() would treat characters such as
        # '[' or '*' in the file name as patterns and could match nothing.
        data = next(iter(make_dataloaders(paths=[file_path])))
    except OSError as err:  # includes PIL.UnidentifiedImageError
        messagebox.showerror("Image Converter", f"Could not open image:\n{err}")
        return

    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    fake_rgb = lab_to_rgb(model.L, model.fake_color.detach())[0]

    # Stretch the prediction to the full 0-255 range for display. Guard the
    # zero-range case (constant image), which used to divide by zero.
    lo, hi = fake_rgb.min(), fake_rgb.max()
    if hi > lo:
        fake_rgb = (fake_rgb - lo) / (hi - lo)
    output_image = Image.fromarray(np.uint8(np.clip(fake_rgb, 0.0, 1.0) * 255))

    # Remove items from a previous run so images and labels do not pile up.
    canvas.delete("all")
    input_image_tk = ImageTk.PhotoImage(input_image)
    photo_img = ImageTk.PhotoImage(output_image)
    canvas.create_image(0, 0, anchor="nw", image=input_image_tk)
    canvas.create_text(100, 300, text="B/W Image(Input)", fill="black",
                       font="Helvetica 10 bold")
    canvas.create_image(256, 0, anchor="nw", image=photo_img)
    canvas.create_text(400, 300, text="Color Image(Output)", fill="black",
                       font="Helvetica 10 bold")
    # Tk keeps no Python reference to PhotoImage objects; store them on the
    # canvas so they are not garbage-collected (which would render blank).
    canvas.input_image_tk = input_image_tk
    canvas.photo_img = photo_img

# Single control: open the file picker and colorize the chosen image.
button = tk.Button(master=root, command=browse_file, text="Browse input file")
button.pack()
# Enter Tk's event loop; this call blocks until the window is closed.
root.mainloop()