### Data

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
from torchvision import datasets, transforms
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import glob
import random
import cv2
import imutils
from skimage import color
from skimage.util.dtype import convert
import import_ipynb #pip install import-ipynb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Another dataset class incase mflickr is not enough data, based on a subset of Imagenet
class ImageNet(Dataset):
    '''
    Constructor 
    Inputs:
        root -> root dir of images,
        ext -> optional extension of images, default jpg
        size -> optional, crops data size for faster testing
    '''
    def __init__(self, root, ext='.JPEG', size=None, transform=None):
        self.paths = glob.glob(f'{root}/*{ext}', recursive=True)
        self.size = size
        self.root = root
        self.transform = transform
        self.height = 256
        self.width = 256
        
        if size: #speed up testing
            self.paths = self.paths[:size]
    
    '''
    get the img at index i
    '''
    def __getitem__(self, i):
        s = self.paths[i]
        img = Image.open(s)
        if self.transform:
            img = img.resize((self.height, self.width))
            img = convert(img, np.float32)
            img = color.rgb2lab(img)
            img = self.transform(img)
            
        l = (img[0,:,:])[None,:,:]
        ab = img[1:,:,:]
        
        return l, ab
    
    '''
    return length of dataset
    '''
    def __len__(self):
        return(len(self.paths))
    
    '''
    Input:
        tens_orig_l -> [batch_size, 1, height, width]
        out_ab      -> [batch_size, 2, 256,    256]
    tens_orig_l is the original images L dimension that is obtained from pre processing
    out_ab is the 2-channel prediction outputted by the model
    '''
    def postprocess_batch(tens_orig_l, out_ab):
        out_shape = out_ab.shape[2:]
        orig_shape = tens_orig_l.shape[2:]
        
        # resize the model prediction to original image size using interpolate
        if(out_shape!=orig_shape):
            out_ab_orig = F.interpolate(out_ab, size=orig_shape, mode='bilinear')
        
        concat = torch.cat((tens_orig_l, out_ab_orig), dim=1) # [batch_size, 3, height, width]
        return concat 
    
    '''
    This function takes in an image and resizes it to 256 by 256
    '''
    def resize_img(self, img):
        return (img.resize((self.height, self.width)))
    
    '''
    Selects random start point in the dataset and prints 7 images
    Inputs:
        color-> 1 to print color images, 0 for grayscale images
    '''
    def print_samples(self):
        figure, axes = plt.subplots(1, 7, figsize=(18,10))
        axes = axes.flatten()
        max_idx = self.size if self.size!=None else len(self.paths)
        i = random.randint(0, max_idx)
        for axis in axes:
            l, ab = self[i]
            img = self.rebuild_image(l,ab)
            axis.imshow(img)
            axis.set_xlabel(i)
            i+=1
        plt.show()
    
    '''
    prints all three LAB channels of a given image index.
    also demonstrates how to rebuild original image from l,a,b channels
    '''
    def print_lab_channels(self,index):
        img_l, img_ab = self[index]
        
        img_lab = torch.concat((img_l,img_ab), dim=0).permute(1,2,0)
        l = torch.Tensor(img_l)
        a = torch.Tensor(img_ab[0,:,:])[None,:,:]
        b = torch.Tensor(img_ab[1,:,:])[None,:,:]
        
        figure, axes = plt.subplots(1, 5, figsize=(18,10))
        axes = axes.flatten()
        
        # Plot lab iamge and all 3 channels seperately
        axes[0].imshow(img_lab) # currently weird colors because img_lab contains negative values
        axes[0].set_xlabel("LAB image")
        axes[1].imshow(l.permute(1,2,0),cmap="gray")
        axes[1].set_xlabel("L Channel")
        axes[2].imshow(a.permute(1,2,0))
        axes[2].set_xlabel("A Channel")
        axes[3].imshow(b.permute(1,2,0))
        axes[3].set_xlabel("B Channel")
        
        orig = self.rebuild_image(img_l,img_ab)
        axes[4].imshow(orig)
        axes[4].set_xlabel("Rebuilt image")
        
        plt.tight_layout()
        plt.show()

    '''
    takes an index and prints the original plus the L channel image
    side by side
    '''
    def print_side_by_side(self,i):
        l,ab = self[i]
        
        img = self.rebuild_image(l,ab)
        l = l.permute(1,2,0) # original size l

        figure, axes = plt.subplots(1, 2, figsize=(18,10))
        axes = axes.flatten()
        axes[0].imshow(img)
        axes[1].imshow(l, cmap="gray")
        plt.tight_layout()
        plt.show()
    
    '''
    rebuilds original rgb image from l,ab inputs
    '''
    def rebuild_image(self, l,ab):
        return color.lab2rgb(torch.cat((l,ab), dim=0).permute(1,2,0))
    
    '''
    Calculates average width and height of all images in dataset
    '''
    def calc_average_dimension(self):
        totalw = 0
        totalh = 0
        for i in range(len(self.paths)):
            image = Image.open(self.paths[i])
            w, h = image.size
            totalw += w
            totalh += h
        avgw = totalw//len(self.paths)
        avgh = totalh//len(self.paths)
        return avgw, avgh

In [3]:
trans = transforms.Compose([transforms.ToTensor()])
ImageNet_train_dataset = ImageNet(root="./imagenet/train", transform=trans)
ImageNet_eval_dataset = ImageNet(root="./imagenet/val", transform=trans)

In [4]:
train_dataloader_imagenet = DataLoader(ImageNet_train_dataset, batch_size=32, shuffle=True)
eval_dataloader_imagenet = DataLoader(ImageNet_eval_dataset, batch_size=32, shuffle=True)

### Testing LAB conversion for ImageNet

In [None]:
temp = ImageNet(root="./imagenet/train", transform=trans, size=1000)

# Class for the mirFlickr dataset

In [None]:
# This code was based on the FlowerDataset class we did in Assignment 2
class VisionDataset(Dataset):
    '''
    Constructor 
    Inputs:
        root -> root dir of images,
        ext -> optional extension of images, default jpg
        size -> optional, crops data size for faster testing
    '''
    def __init__(self, root, ext='.jpg', size=None, transform=None):
        self.root = root
        self.paths = glob.glob(f'{root}/*{ext}', recursive=True)
        self.dataset = []         # array of tuples (black/white , original)
        self.size = size if size else len(self.paths)
        self.transform = transform
        

    '''
    return length of dataset
    '''
    def __len__(self):
        return(len(self.paths))
    
    '''
    get the img,label tuple corresponding to index i
    '''
    def __getitem__(self, i):
        s = (self.paths[i])
        img = Image.open(s)
        if self.transform:
            img = self.transform(img)
        return img
    
    
    '''
    Selects random start point in the dataset and prints 7 images
    Inputs:
        color-> 1 to print color images, 0 for grayscale images
    '''
    def print_samples(self, color:int):
        figure, axes = plt.subplots(1, 7, figsize=(18,10))
        axes = axes.flatten()
        i = random.randint(0, self.size-1000)
        print(i, self.size)

        for axis in axes:
            x = self[i]
            if color:
                axis.imshow(x)
            else:
                axis.imshow(x,cmap="gray")
            label = self.paths[i]
            axis.set_xlabel(label)
            i+=1
        plt.show()
    
    '''
    Calculates average width and height of all images in dataset
    '''
    def calc_average_dimension(self):
        totalw = 0
        totalh = 0
        for i in range(len(self.paths)):
            image = Image.open(self.paths[i])
            w, h = image.size
            totalw += w
            totalh += h
        avgw = totalw//len(self.paths)
        avgh = totalh//len(self.paths)
        return avgw, avgh

In [5]:
# we have two models
#add an import somewhere, we need import import_ipynb
from CUnet import CUNet
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CUNet()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1) #change as needed
loss_function = torch.nn.MSELoss() #change as needed
EPOCHS = 2

importing Jupyter notebook from CUnet.ipynb
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 128, 128])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 2, 256, 256])


In [None]:
from CUnet import CUNet

In [6]:
from eccv16 import eccv16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = eccv16(device)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1) #change as needed
loss_function = torch.nn.MSELoss() #change as needed
EPOCHS = 20

importing Jupyter notebook from eccv16.ipynb


  validate(nb)


AttributeError: 'eccv16' object has no attribute 'model'

In [None]:
flickr_dataset = VisionDataset("./mirflickr", ext=".jpg")
train_dataloader_vision = DataLoader(flickr_dataset, batch_size=32, shuffle=True)
validation_dataloader = None



In [7]:
#filename = model/type(model)_epoch_avg_validation_loss.pth'


def name_to_epoch(name):
    _, a = name.split("/")
    _, epoch, _ = name.split("_") #name, epoch, avg_l.pth
    return(int(epoch))

paths = glob.glob(f'model/{type(model).__name__}_*.pth')
#only find models that have the same type as the model being trained

paths.sort(key = name_to_epoch) #most recent epoch
start = 0
if len(paths) > 0:
    target = paths[-1]
    start = name_to_epoch(target)
    model.load_state_dict(torch.load(target)) 


model = model.to(device)

for epoch in range(start, EPOCHS):
    model.train()
    train_loss = 0
    for batch_index,(inputs, expected) in enumerate(train_dataloader_imagenet):
        optimizer.zero_grad() 
        inputs = inputs.to(device)
        expected = expected.to(device)
        outputs = model(inputs)
        loss = loss_function(outputs, expected)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*len(expected)
    
    model.eval()
    val_loss = 0
    for batch_index,(inputs, expected) in enumerate(eval_dataloader_imagenet):
        # inputs = torchvision.transforms.ToTensor(train_dataloader_imagenet.resize_image(lightness))
        # expected = torchvision.transforms.ToTensor(train_dataloader_imagenet.resize_image(colors))
        # inputs.type(torch.float32)
        # expected.type(torch.float32)
        print(type(inputs))
        inputs = inputs.to(device)
        expected = expected.to(device)
        outputs = model(inputs)
        loss = loss_function(outputs, expected)
        val_loss += loss.item()*len(expected)
    avg_val_loss = val_loss/len(eval_dataloader_imagenet)
    filename = f'model/{type(model).__name__}_{epoch}_{avg_val_loss}.pth'
    torch.save(model.state_dict(), filename)

  img = convert(img, np.float32)


<class 'torch.Tensor'>
<class 'torch.Tensor'>
