In [None]:
! pip install -q kaggle
! mkdir ~/.kaggle
#Upload the token json file
from google.colab import files
files.upload()
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json
! kaggle datasets list

In [None]:
# Download the image-colorization dataset archive (image-colorization.zip) into the
# current working directory; the next cell unzips it from there.
! kaggle datasets download -d shravankumar9892/image-colorization

In [None]:
!unzip "image-colorization.zip"

In [None]:

import os
from pathlib import Path

# Import glob to get the files directories recursively
import glob

# Import Garbage collector interface
import gc

# Import OpenCV to transform pictures
import cv2

# Import Time
import time

# import numpy for math calculations
import numpy as np

# Import pandas for data (csv) manipulation
import pandas as pd

# Import matplotlib for plotting
import matplotlib.pyplot as plt
import matplotlib
matplotlib.style.use('fivethirtyeight')
%matplotlib inline

import PIL
from PIL import Image
from skimage.color import rgb2lab, lab2rgb

#import pytorch_lightning as pl

# Import PyTorch to build deep learning models
import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchvision import models
from torch.nn import functional as F
import torch.utils.data
from torchvision.models.inception import inception_v3
from scipy.stats import entropy

from torchsummary import summary

# Import tqdm to show a smart progress meter
from tqdm import tqdm

# Import torchvision's image utilities
import torchvision.utils as vutils

In [None]:
# Locations of the extracted arrays: the ab colour channels (ab1.npy) and the
# matching grayscale L images.
ab_path, l_path = "/content/ab/ab/ab1.npy", "/content/l/gray_scale.npy"

In [None]:
# Number of (L, ab) image pairs to load; increase for a full training run.
N_SAMPLES = 100

# Memory-map the .npy files so only the requested slice is read from disk
# (the full arrays are large), then copy the slice into RAM with np.array so
# the result no longer depends on the open mapping.
ab_df = np.array(np.load(ab_path, mmap_mode="r")[:N_SAMPLES])
L_df = np.array(np.load(l_path, mmap_mode="r")[:N_SAMPLES])
assert len(L_df) == len(ab_df), "L and ab arrays must be paired one-to-one"

# (L images, ab images) pair consumed by ImageColorizationDataset.
dataset = (L_df, ab_df)
gc.collect();

In [None]:
def setup_input(self, data):
    """Copy the 'L' and 'ab' entries of ``data`` onto ``self`` as attributes.

    Assigns ``self.L`` first, then ``self.ab``; a missing key raises KeyError.

    NOTE(review): takes ``self`` but is defined at module level and is not called
    anywhere in the visible notebook -- looks like a leftover from a model class.
    """
    for key in ('L', 'ab'):
        setattr(self, key, data[key])

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision
import torchvision.models as models

class UNet(nn.Module):
    """Encoder-decoder generator that predicts the two ab colour channels.

    NOTE(review): despite the name there are no skip connections between the
    encoder and the decoder, so this is a plain encoder-decoder, not a U-Net.

    Encoder: a torchvision ResNet-34 without its final avgpool/fc layers. Its
    first convolution is replaced so it accepts 2 input channels (in the
    training loop: L plus one noise channel). That new layer is randomly
    initialised; the rest of the trunk keeps the ImageNet weights when
    ``pretrained`` is True. The trunk downsamples by 32 and outputs 512 channels.

    Decoder: transposed convolutions upsampling by 2 * 2 * 2 * 4 = 32, back to
    the input resolution for sides that are multiples of 32, mapping 512 -> 2
    channels.

    :param pretrained: load ImageNet weights for the ResNet-34 trunk (downloads
        them on first use).
    """

    def __init__(self, pretrained=True):
        super().__init__()

        backbone = models.resnet34(pretrained=pretrained)

        # Replace the RGB stem with one that takes 2 input channels.
        backbone.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Keep everything except the global average pool and the classifier head.
        self.encoder = nn.Sequential(*list(backbone.children())[:-2])

        # Without activations, stacked transposed convolutions compose into a single
        # linear upsampling; ReLU between them lets the decoder learn non-linear maps.
        # The final sigmoid bounds predictions to (0, 1), the range lab_to_rgb assumes
        # for ab ((ab - 0.5) * 256) and presumably the range of the real ab targets
        # (ToTensor of 8-bit data).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 2, kernel_size=4, stride=4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map a (N, 2, H, W) input to a (N, 2, H, W) ab prediction in (0, 1)."""
        return self.decoder(self.encoder(x))

class Discriminator(nn.Module):
    """Conditional, fully convolutional discriminator for (L, ab) pairs.

    ``forward(x, y)`` concatenates its two inputs along the channel axis; conv1
    expects the result to have 3 channels (e.g. 1 L channel + 2 ab channels).
    Five 5x5, stride-1, padding-2 convolutions keep the spatial size unchanged,
    so the output is a 3-channel map with the input's height and width, squashed
    to probabilities by a final sigmoid. Layers 2-4 use BatchNorm; all hidden
    layers use LeakyReLU with slope 0.2.
    """

    def __init__(self):
        super().__init__()

        def conv5x5(in_channels, out_channels):
            # Bias-free 5x5 convolution that preserves spatial size.
            return nn.Conv2d(in_channels, out_channels, 5, stride=1, padding=2, bias=False)

        # Attribute names and registration order determine the state_dict keys and
        # the order in which seeded initialisation draws random numbers.
        self.conv1 = conv5x5(3, 64)
        self.conv2 = conv5x5(64, 128)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = conv5x5(128, 256)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = conv5x5(256, 512)
        self.bn4 = nn.BatchNorm2d(512)
        self.conv5 = conv5x5(512, 3)

    def forward(self, x, y):
        lrelu = nn.functional.leaky_relu
        features = lrelu(self.conv1(torch.cat([x, y], 1)), 0.2)
        for conv, norm in ((self.conv2, self.bn2), (self.conv3, self.bn3), (self.conv4, self.bn4)):
            features = lrelu(norm(conv(features)), 0.2)
        return torch.sigmoid(self.conv5(features))





In [None]:
class ImageColorizationDataset(Dataset):
    '''Paired grayscale (L) images and their corresponding ab colour channels.'''
    def __init__(self, dataset, transform=None):
        '''
        :param dataset: Pair ``(L_images, ab_images)`` of indexable arrays of equal
            length; item ``i`` of each belongs to the same picture. Each L image is
            H x W (or H x W x 1); each ab image is H x W x C for ToTensor.
        :param transform: Optional callable applied to each sample. It receives the
            channel-concatenated tensor ``cat([L, ab])`` so that spatial transforms
            (crop, resize, flip) stay aligned between L and ab; the result is split
            back into ``(L, ab)`` with L as the first channel.
        '''
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset[0])

    def __getitem__(self, idx):
        # Read from the instance's data (not a same-named global), and copy with
        # np.array so the returned tensors never alias the source arrays.
        L = np.array(self.dataset[0][idx])
        if L.ndim == 2:
            # ToTensor expects H x W x C; give the single-channel image its channel axis.
            L = L[:, :, np.newaxis]
        L = transforms.ToTensor()(L)

        ab = transforms.ToTensor()(np.array(self.dataset[1][idx]))

        if self.transform is not None:
            sample = self.transform(torch.cat([L, ab], dim=0))
            L, ab = sample[:1], sample[1:]

        return L, ab

In [None]:
# Release unoccupied cached GPU memory held by PyTorch's CUDA caching allocator.
# Memory still referenced by live tensors is not freed by this call.
torch.cuda.empty_cache()

In [None]:
def lab_to_rgb(L, ab):
    """
    Convert an image or a batch of images from normalised Lab tensors to RGB.

    :param L: lightness tensor expected in [0, 1] (scaled to Lab L in [0, 100]);
        shape (N, 1, H, W) for a batch or (1, H, W) for a single image.
    :param ab: chrominance tensor expected in [0, 1] with 0.5 as neutral (scaled
        to roughly [-128, 128]); shape (N, 2, H, W) or (2, H, W).
    :return: float numpy RGB array of shape (N, H, W, 3) for a batch, or
        (H, W, 3) for a single image.
    """
    single_image = L.dim() == 3
    if single_image:
        L, ab = L.unsqueeze(0), ab.unsqueeze(0)

    L = L * 100
    ab = (ab - 0.5) * 128 * 2
    # skimage's lab2rgb expects channels-last images, so move channels to the end.
    # detach/cpu allow converting generator outputs that live on the GPU or
    # still require grad (.numpy() rejects both).
    lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).detach().cpu().numpy()
    rgb_imgs = np.stack([lab2rgb(img) for img in lab], axis=0)
    return rgb_imgs[0] if single_image else rgb_imgs

In [None]:
# The CUDA caching allocator reads this variable when it is initialised, so it must
# be set before the first CUDA allocation to take effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"  # Set the desired value for max_split_size_mb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training hyper-parameters.
NUM_EPOCHS = 100
LOG_EVERY = 50  # print losses every LOG_EVERY steps and at the end of each epoch
SEED = 0

# Seed before building the models so random initialisation, the generator noise and
# the DataLoader shuffling repeat across runs (cuDNN kernels may still be
# non-deterministic on GPU).
torch.manual_seed(SEED)

# The discriminator ends in a sigmoid, so plain BCE on its probabilities.
criterion = nn.BCELoss()

generator = UNet().to(device)
discriminator = Discriminator().to(device)

optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# NOTE(review): `transform` is not passed to the dataset below, so it currently has
# no effect on training.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = ImageColorizationDataset(dataset)
train_loader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True, pin_memory=True)
assert len(train_loader) > 0, "training set is empty"

# One detached CPU copy of the last (L, fake ab) batch per epoch, for inspection.
img_list = []

for epoch in range(NUM_EPOCHS):
    for i, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)

        # Generator input: the L channel plus one channel of Gaussian noise of the
        # same spatial size as x.
        noise = torch.randn(x.size(0), 1, *x.shape[-2:], device=device)
        y_fake = generator(torch.cat([x, noise], 1))

        # ---- Discriminator step: real pairs -> 1, generated pairs -> 0 ----
        optimizer_d.zero_grad()
        d_real = discriminator(x, y)
        # detach(): the discriminator loss must not back-propagate into the generator.
        d_fake = discriminator(x, y_fake.detach())
        real_labels = torch.ones_like(d_real)
        fake_labels = torch.zeros_like(d_fake)
        loss_d = criterion(d_real, real_labels) + criterion(d_fake, fake_labels)
        loss_d.backward()
        optimizer_d.step()

        # ---- Generator step: make the updated discriminator label fakes as real ----
        # Reuses y_fake; its generator graph is intact because the D step used a
        # detached copy.
        optimizer_g.zero_grad()
        d_fake = discriminator(x, y_fake)
        loss_g = criterion(d_fake, real_labels)
        loss_g.backward()
        optimizer_g.step()

        if (i + 1) % LOG_EVERY == 0 or i + 1 == len(train_loader):
            print('Epoch [{}/{}], Step [{}/{}], Loss D: {:.4f}, Loss G: {:.4f}'
                  .format(epoch+1, NUM_EPOCHS, i+1, len(train_loader), loss_d.item(), loss_g.item()))

    # Detach and move to CPU so the list holds neither GPU memory nor autograd graphs.
    img_list.append(torch.cat([x, y_fake], 1).detach().cpu())

# Save the model
torch.save(generator.state_dict(), 'generator.pth')

In [None]:
# Preview samples 1, 3, ..., 15: each grayscale input next to its ground-truth colour image.
fig, axes = plt.subplots(4, 4, figsize=(30, 30))
height, width = L_df.shape[1:3]
for pair, sample_idx in enumerate(range(1, 16, 2)):
    ax_gray, ax_color = axes.flat[2 * pair], axes.flat[2 * pair + 1]

    # Show L directly as grayscale. The stored L values appear to be 8-bit (0-255),
    # since the colour panel below casts them to uint8 for OpenCV. skimage's lab2rgb
    # expects L in [0, 100], so routing them through lab2rgb saturated bright regions.
    ax_gray.imshow(L_df[sample_idx], cmap='gray', vmin=0, vmax=255)
    ax_gray.set_title('B&W')
    ax_gray.axis('off')

    # Assemble an 8-bit Lab image and let OpenCV convert it to RGB.
    lab = np.zeros((height, width, 3), dtype=np.uint8)
    lab[:, :, 0] = L_df[sample_idx]
    lab[:, :, 1:] = ab_df[sample_idx]
    ax_color.imshow(cv2.cvtColor(lab, cv2.COLOR_LAB2RGB))
    ax_color.set_title('Colored')
    ax_color.axis('off')

plt.show()