### Data

In [148]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
from torchvision import datasets, transforms

import albumentations as A
from albumentations.pytorch import ToTensor
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import glob
import random
import cv2
import imutils
from skimage import color
from skimage.util.dtype import convert
import import_ipynb #pip install import-ipynb

In [158]:
# Another dataset class in case MIRFLICKR does not provide enough data; built on a subset of ImageNet.
class ImageNet(Dataset):
    '''
    Dataset of images stored directly under `root`, served as LAB (l, ab) pairs.

    Constructor inputs:
        root      -> root dir of images
        ext       -> optional extension of images, default '.JPEG'
        size      -> optional, keep only the first `size` images for faster testing
        transform -> optional callable applied to the H x W x 3 LAB ndarray; it is expected
                     to return a [3, H, W] tensor (e.g. transforms.ToTensor()). Without a
                     transform the array is converted with torch.from_numpy.
    '''
    def __init__(self, root, ext='.JPEG', size=None, transform=None):
        # glob.glob returns paths in arbitrary order; sorting makes `size` subsets reproducible.
        self.paths = sorted(glob.glob(f'{root}/*{ext}'))
        self.size = size
        self.root = root
        self.transform = transform
        self.height = 256
        self.width = 256

        if size:  # speed up testing
            self.paths = self.paths[:size]

    def _to_chw_tensor(self, lab):
        '''Turn an H x W x 3 LAB array into a [3, H, W] tensor, via self.transform when set.'''
        if self.transform:
            return self.transform(lab)
        return torch.from_numpy(np.ascontiguousarray(lab)).permute(2, 0, 1).contiguous()

    def __getitem__(self, i):
        '''
        Return the (l, ab) pair for image i after resizing to self.width x self.height.
            l  -> [1, H, W] L channel
            ab -> [2, H, W] a and b channels
        '''
        # convert() returns an independent image, so the file can be closed right away.
        with Image.open(self.paths[i]) as src:
            img = src.convert('RGB')
        img = self.resize_img(img)
        # uint8 RGB -> float32 in [0, 1] before the LAB conversion
        rgb = np.asarray(img, dtype=np.float32) / 255.0  # [H, W, 3]
        lab = self._to_chw_tensor(color.rgb2lab(rgb))     # [3, H, W]
        return lab[0:1, :, :], lab[1:, :, :]

    def get_original(self, i):
        '''
        Return the (l, ab) pair for image i at its original resolution (no resize).
        self.transform (if any) is still applied to the LAB array.
        '''
        with Image.open(self.paths[i]) as src:
            img = src.convert('RGB')
        lab = self._to_chw_tensor(color.rgb2lab(np.asarray(img)))  # [3, H, W]
        return lab[0:1, :, :], lab[1:, :, :]

    def __len__(self):
        '''Return the number of images in the dataset.'''
        return len(self.paths)

    @staticmethod
    def postprocess_batch(tens_orig_l, out_ab):
        '''
        Combine original L channels with model ab predictions.
        Input:
            tens_orig_l -> [batch_size, 1, height, width], L channel from pre-processing
            out_ab      -> [batch_size, 2, h, w], 2-channel prediction output by the model
        Returns:
            [batch_size, 3, height, width] LAB batch; out_ab is bilinearly resized to
            (height, width) when its spatial size differs.
        '''
        orig_shape = tens_orig_l.shape[2:]
        out_ab_orig = out_ab
        if out_ab.shape[2:] != orig_shape:
            out_ab_orig = F.interpolate(out_ab, size=orig_shape, mode='bilinear', align_corners=False)
        return torch.cat((tens_orig_l, out_ab_orig), dim=1)

    def resize_img(self, img):
        '''Resize a PIL image to self.width x self.height (PIL takes (width, height)).'''
        return img.resize((self.width, self.height))

    def check_output(self, i, model):
        '''
        Return (original RGB image, predicted RGB image) for index i.
        The model receives the resized L channel as a batch of one; its ab prediction is
        upsampled to the original resolution and combined with the original L channel.
        '''
        l_org, ab_org = self.get_original(i)
        img_original = self.rebuild_image(l_org, ab_org)

        l, _ = self[i]
        with torch.no_grad():
            # NOTE(review): assumes model accepts l's dtype/device and returns [1, 2, h, w] — confirm
            pred_ab = model(l[None])
        pred_ab = pred_ab.to(device=l_org.device, dtype=l_org.dtype)
        lab = self.postprocess_batch(l_org[None], pred_ab)[0]  # [3, H, W]
        img_pred = self.rebuild_image(lab[0:1], lab[1:])
        return img_original, img_pred

    def print_samples(self):
        '''Pick a random start index and plot up to 7 consecutive images.'''
        n_show = min(7, len(self))
        if n_show == 0:
            return
        # squeeze=False keeps `axes` a 2-D array even when only one image is shown.
        figure, axes = plt.subplots(1, n_show, figsize=(18, 10), squeeze=False)
        # randint is inclusive, so the last image shown is at index len(self) - 1.
        start = random.randint(0, len(self) - n_show)
        for offset, axis in enumerate(axes.flatten()):
            idx = start + offset
            l, ab = self[idx]
            axis.imshow(self.rebuild_image(l, ab))
            axis.set_xlabel(idx)
        plt.show()

    def print_lab_channels(self, index):
        '''
        Plot the LAB image, its L, A and B channels separately, and the RGB image rebuilt
        from them for the given index.
        '''
        img_l, img_ab = self[index]
        img_lab = torch.cat((img_l, img_ab), dim=0).permute(1, 2, 0)

        figure, axes = plt.subplots(1, 5, figsize=(18, 10))
        axes = axes.flatten()

        # Colors look odd: LAB values are outside [0, 1] (a/b can be negative) and get clipped.
        axes[0].imshow(img_lab)
        axes[0].set_xlabel("LAB image")
        axes[1].imshow(img_l[0], cmap="gray")
        axes[1].set_xlabel("L Channel")
        axes[2].imshow(img_ab[0])
        axes[2].set_xlabel("A Channel")
        axes[3].imshow(img_ab[1])
        axes[3].set_xlabel("B Channel")

        axes[4].imshow(self.rebuild_image(img_l, img_ab))
        axes[4].set_xlabel("Rebuilt image")

        plt.tight_layout()
        plt.show()

    def print_side_by_side(self, i):
        '''Plot the rebuilt RGB image and its L channel side by side for index i.'''
        l, ab = self[i]
        img = self.rebuild_image(l, ab)

        figure, axes = plt.subplots(1, 2, figsize=(18, 10))
        axes = axes.flatten()
        axes[0].imshow(img)
        axes[1].imshow(l[0], cmap="gray")
        plt.tight_layout()
        plt.show()

    def rebuild_image(self, l, ab):
        '''Rebuild an RGB image (H x W x 3 ndarray) from [1, H, W] l and [2, H, W] ab tensors.'''
        lab = torch.cat((l, ab), dim=0).permute(1, 2, 0)
        # skimage needs a NumPy array; detach/cpu so model outputs (grad/GPU) also work.
        return color.lab2rgb(lab.detach().cpu().numpy())

    def calc_average_dimension(self):
        '''
        Return the integer-average (width, height) of all images in the dataset.
        Raises ValueError if the dataset is empty.
        '''
        if not self.paths:
            raise ValueError("cannot average dimensions of an empty dataset")
        total_w = total_h = 0
        for path in self.paths:
            # the context manager releases the file handle Image.open keeps open
            with Image.open(path) as image:
                w, h = image.size
            total_w += w
            total_h += h
        n = len(self.paths)
        return total_w // n, total_h // n

In [159]:
# The transform receives the LAB ndarray built in ImageNet.__getitem__ and turns it into a [3, H, W] tensor.
trans = transforms.Compose(transforms=[transforms.ToTensor()])
ImageNet_train_dataset = ImageNet(transform=trans, root="./imagenet/train")
ImageNet_eval_dataset = ImageNet(transform=trans, root="./imagenet/val")

In [160]:
# Only the training data is shuffled; a fixed evaluation order keeps eval runs reproducible.
train_dataloader_imagenet = DataLoader(ImageNet_train_dataset, batch_size=32, shuffle=True)
eval_dataloader_imagenet = DataLoader(ImageNet_eval_dataset, batch_size=32, shuffle=False)

### Testing LAB conversion for ImageNet

In [161]:
# Small 20-image subsets for quick checks of the LAB pipeline.
temp_train = ImageNet(root="./imagenet/train", transform=trans, size=20)
temp_eval = ImageNet(root="./imagenet/val", transform=trans, size=20)
# Only the training subset is shuffled; evaluation order stays fixed for reproducibility.
train_dataloader_imagenet2 = DataLoader(temp_train, batch_size=8, shuffle=True)
eval_dataloader_imagenet2 = DataLoader(temp_eval, batch_size=8, shuffle=False)

# Class for the mirFlickr dataset

In [162]:
# This code was based on the FlowerDataset class we did in Assignment 2
class VisionDataset(Dataset):
    '''
    Dataset of images stored directly under `root`.

    Constructor inputs:
        root      -> root dir of images
        ext       -> optional extension of images, default '.jpg'
        size      -> optional, keep only the first `size` images for faster testing
        transform -> optional callable applied to each PIL image in __getitem__
    '''
    def __init__(self, root, ext='.jpg', size=None, transform=None):
        self.root = root
        # glob.glob returns paths in arbitrary order; sorting makes `size` subsets reproducible.
        self.paths = sorted(glob.glob(f'{root}/*{ext}'))
        if size:  # speed up testing
            self.paths = self.paths[:size]
        self.dataset = []  # not used by this class
        self.size = len(self.paths)
        self.transform = transform

    def __len__(self):
        '''Return the number of images in the dataset.'''
        return len(self.paths)

    def __getitem__(self, i):
        '''Return image i as a PIL image, or self.transform(image) when a transform is set.'''
        img = Image.open(self.paths[i])
        if self.transform:
            img = self.transform(img)
        return img

    def print_samples(self, color: int):
        '''
        Pick a random start index and plot up to 7 consecutive images, labelled with their paths.
        Inputs:
            color -> 1 to print color images, 0 for grayscale images
                     (this parameter shadows the skimage `color` module inside this method)
        '''
        n_show = min(7, len(self))
        if n_show == 0:
            return
        # squeeze=False keeps `axes` a 2-D array even when only one image is shown.
        figure, axes = plt.subplots(1, n_show, figsize=(18, 10), squeeze=False)
        # randint is inclusive, so the last image shown is at index len(self) - 1.
        start = random.randint(0, len(self) - n_show)
        for offset, axis in enumerate(axes.flatten()):
            idx = start + offset
            x = self[idx]
            if color:
                axis.imshow(x)
            else:
                axis.imshow(x, cmap="gray")
            axis.set_xlabel(self.paths[idx])
        plt.show()

    def calc_average_dimension(self):
        '''
        Return the integer-average (width, height) of all images in the dataset.
        Raises ValueError if the dataset is empty.
        '''
        if not self.paths:
            raise ValueError("cannot average dimensions of an empty dataset")
        total_w = total_h = 0
        for path in self.paths:
            # the context manager releases the file handle Image.open keeps open
            with Image.open(path) as image:
                w, h = image.size
            total_w += w
            total_h += h
        n = len(self.paths)
        return total_w // n, total_h // n

In [163]:
# NOTE(review): no transform is passed, so VisionDataset.__getitem__ returns PIL images, which
# PyTorch's default collate function cannot batch; iterating this loader will likely fail until a
# transform producing equally sized tensors is supplied — confirm before training on it.
flickr_dataset = VisionDataset(root="./mirflickr", ext=".jpg")
train_dataloader_vision = DataLoader(dataset=flickr_dataset, shuffle=True, batch_size=32)
validation_dataloader = None  # no validation loader is set up yet