In [2]:
import os
import time
import glob
import sys
import random
import warnings

import numpy as np
import pandas as pd
import cv2

import matplotlib.pyplot as plt

from tqdm import tqdm
from itertools import chain
import skimage
from PIL import Image
from skimage.io import imread, imshow, imread_collection, concatenate_images,imsave
from skimage.transform import resize
from skimage.util import crop, pad
from skimage.morphology import label
from skimage.color import rgb2gray, gray2rgb, rgb2lab, lab2rgb
from sklearn.model_selection import train_test_split

import keras
from keras.applications.xception import Xception
from keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input
from keras.models import Model, load_model,Sequential
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Input, Dense, UpSampling2D, RepeatVector, Reshape
from keras.layers.core import Dropout, Lambda
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D
from keras.layers.merge import concatenate
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras import backend as K

import tensorflow as tf
from tensorflow.python.client import device_lib
from tensorflow.python.ops.numpy_ops import np_config
        

warnings.filterwarnings('ignore', category=UserWarning, module='skimage')
seed = 42
random.seed = seed
np.random.seed = seed

In [3]:
from skimage import io
import os

# Create New Directories
for name in ['train', 'validation','test']:
    path = "dataset/" + name + "/"
    os.makedirs(path, exist_ok=True)


Extracting the files from the zip to the folder

In [4]:
filePath = '/content/drive/MyDrive/Datasets/GCScrapped/gc-dataset.zip'

In [5]:
!mkdir gcImages

In [None]:
!unzip /content/drive/MyDrive/Datasets/GCScrapped/gc-dataset.zip -d /content/gcImages

### --- 

In [7]:
dataset_path = '/content/gcImages/content/gc-dataset'
imgDirPath = '/content/gcImages/content/gc-dataset/train'
osSep = '/'
dataPath = imgDirPath + osSep + '*.jpg'

In [8]:
# Setting up Breakpoints
num_train = 1000
num_val = 200
num_test = 100

files = os.listdir(imgDirPath)
index = 0
for i,image in enumerate(files):

    test = io.imread(imgDirPath + osSep + image)
    if test.ndim != 3:
        continue

    # Pick what folder to place image into
    if i < num_train:
        os.rename(imgDirPath + osSep + image, "./dataset/train/" + image)
    elif i < (num_train + num_val):
        os.rename(imgDirPath + osSep + image, "./dataset/validation/" + image)
    elif i < (num_train + num_val + num_test):
        os.rename(imgDirPath + osSep + image, "./dataset/test/" + image)
    else:
        break
    index += 1


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [10]:
# Torch dataloader likes having a class instance that represents the data
# We preprocess images here
# ref: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
# ref: https://discuss.pytorch.org/t/how-to-load-images-without-using-imagefolder/59999

from skimage.transform import resize

class ModelData(torch.utils.data.Dataset):

    def __init__(self, base_dir):
        self.base_dir = base_dir
        self.all_imgs = os.listdir(base_dir)

    def __len__(self):
        return len(self.all_imgs)

    def __getitem__(self, idx):

        # Import Image and get LAB Image
        img_name = os.path.join(self.base_dir, self.all_imgs[idx])
        image = io.imread(img_name)
        image  = image / 255.0
        lab = rgb2lab(image)

        # L Channel
        L = np.expand_dims(lab[:,:,0], axis=2)
        L /= 50.0
        L -= 1.0

        # AB Channels
        AB = lab[:,:,1:].transpose(2,0,1).astype(np.float32)
        AB /= 128.0

        # Scale For Inception
        L_inc = resize(np.repeat(L,3,axis=2), (299, 299)).transpose(2,0,1).astype(np.float32)

        # Scale for Encoder
        L_enc = resize(np.repeat(L,3,axis=2), (256, 256)).transpose(2,0,1).astype(np.float32)

        # Build Sample Dict.
        sample = {"L":L.transpose(2,0,1).astype(np.float32),
                 "L_inc":L_inc, "L_enc":L_enc, "AB":AB,
                 "RGB": image}

        return sample


In [11]:
# PyTorch Loaders for Train and Validation

# Get Dataset Objects
train_images = ModelData(base_dir="./dataset/train")
val_images = ModelData(base_dir="./dataset/validation")
test_images = ModelData(base_dir="./dataset/test")

# Pass Into Loaders
train_loader = torch.utils.data.DataLoader(train_images, batch_size=40, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_images, batch_size=40, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_images, batch_size=40)


In [12]:
class ImgEncoder(nn.Module):

    def __init__(self):
        super(ImgEncoder, self).__init__()

        self.layers = nn.Sequential(
            # Conv1
            nn.Conv2d(3, 64, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),

            # Conv2
            nn.Conv2d(64, 128, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(128),

            # Conv3
            nn.Conv2d(128, 128, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(128),

            # Conv4
            nn.Conv2d(128, 256, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),

            # Conv5
            nn.Conv2d(256, 256, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),

            # Conv6
            nn.Conv2d(256, 512, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(512),

            # Conv7
            nn.Conv2d(512, 512, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(512),

            # Conv8
            nn.Conv2d(512, 256, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),
        )

    def forward(self, x):
        return self.layers(x)


In [13]:
import torchvision.models as models

inception = models.inception_v3(pretrained=True)


Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

In [14]:
class ImgFusion(nn.Module):

    def __init__(self):
        super(ImgFusion, self).__init__()
        # In practice nothing here

    def forward(self, img1, img2):

        img2 = torch.stack([torch.stack([img2],dim=2)],dim=3)
        img2 = img2.repeat(1, 1, img1.shape[2], img1.shape[3])
        return torch.cat((img1, img2),1)


In [15]:
class ImgDecoder(nn.Module):

    def __init__(self):
        super(ImgDecoder, self).__init__()

        self.layers = nn.Sequential(

            # Conv1
            nn.Conv2d(256, 128, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(128),

            # Upsample1
            nn.Upsample(scale_factor=2.0),

            # Conv2
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),

            # Conv3
            nn.Conv2d(64, 64, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),

            # Upsample2
            nn.Upsample(scale_factor=2.0),

            # Conv4
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),

            # Conv5
            nn.Conv2d(32, 2, 3, stride=1, padding=1),
            nn.Tanh(),

            # Upsample3
            nn.Upsample(scale_factor=2.0),
        )

    def forward(self, x):
        return self.layers(x)


In [16]:
class ColorNet(nn.Module):

    def __init__(self):
        super(ColorNet, self).__init__()
        self.encoder = ImgEncoder()
        self.fusion = ImgFusion()
        self.decoder = ImgDecoder()
        self.post_fuse = nn.Conv2d(1256, 256, 1, stride=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, img1, img2):

        # Encoder Output
        out_enc = self.encoder(img1)

        # Fusion
        temp = self.fusion(out_enc, img2)
        temp = self.post_fuse(temp)
        temp = self.relu(temp)

        return self.decoder(temp)


In [17]:
model = ColorNet()

In [18]:
import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=0.0012, weight_decay=1e-6)

In [19]:
criterion = nn.MSELoss()

In [20]:
train_batch_loss = []
train_epoch_loss = []

In [23]:
epochs = 50

for i in range(epochs):

    train_total_loss = 0
    val_total_loss = 0
    print("Epoch:", i+1)

    model.train()

    # Train
    for data in train_loader:

        # Move Data to GPU
        enc_in = data["L_enc"].cuda()
        inc_in = data["L_inc"].cuda()
        AB = data["AB"].cuda()

        # Init. Optim. Params.
        optimizer.zero_grad()

        # Forward Prop.
        # Get Inception Output
        out_incept = inception(inc_in)
        # Get Network AB
        net_AB = model(enc_in, out_incept)

        # Determine Loss
        loss = criterion(net_AB, AB)

        # Back Prop.
        loss.backward()

        # Update Weights
        optimizer.step()

        # Update Loss Saves
        train_batch_loss.append(loss.item())
        train_total_loss += loss.item()

    train_epoch_loss.append(train_total_loss)

    # Print Info Every Epoch
    print("Train Loss: ", train_total_loss)
    # print("Val. Loss: ", val_total_loss)


Epoch: 1


Process Process-5:
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Process Process-6:
  File "/usr/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 218, in _worker_loop
    random.seed(seed)
Traceback (most recent call last):
TypeError: 'int' object is not callable
  File "/usr/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/usr/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 218, in _worker_loop
    random.seed(seed)
TypeError: 'int' object is not callable


RuntimeError: ignored

In [None]:
checkpoint = {'model': ColorNet(),
          'state_dict': model.state_dict(),
          'optimizer' : optimizer.state_dict()}

torch.save(checkpoint, 'checkpoint.pth')


In [None]:
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False

    model.eval()
    return model

model = load_checkpoint('checkpoint.pth')
