# Imports

In [6]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
import os
%matplotlib inline
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2

from torch.nn.parameter import Parameter
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, random_split

# Data Loader

1. Load images
2. Resize images to 256 x 256
3. Convert images to tensors

<span style="color:red"> Why do we only resize the images rather than also normalizing them (e.g. to zero mean and unit variance)? Note that `ToTensor` already rescales pixel values to [0, 1]. </span>

In [None]:
# Set image size to resize to
IMAGE_SIZE = [256, 256]


# Load and resize images
class DataLoader(Dataset):
    """Unpaired Monet / photo dataset.

    Reads file names from ``<_path>/monet_jpg`` and ``<_path>/photo_jpg``.
    Each item pairs the Monet image selected by ``index`` (wrapped modulo the
    number of Monet images) with a photo drawn uniformly at random from *all*
    photos, so the two domains are sampled independently.

    NOTE: this class shadows ``torch.utils.data.DataLoader`` imported at the
    top of the notebook; refer to ``torch.utils.data.DataLoader`` explicitly
    when batching this dataset.

    Args:
        _path: Root directory containing ``monet_jpg`` and ``photo_jpg``.
        transform: ``True`` (default) resizes to ``IMAGE_SIZE`` and converts to
            a float tensor in [0, 1]. ``False`` only converts to a tensor (no
            resize). Any callable is used as the transform as-is.
    """

    def __init__(self, _path, transform=True):
        # Path of where data is located
        self._path = _path
        # Sorted so the index -> file mapping is reproducible across runs
        # (os.listdir returns entries in arbitrary order).
        self.monets = sorted(os.listdir(os.path.join(_path, "monet_jpg")))
        self.photos = sorted(os.listdir(os.path.join(_path, "photo_jpg")))

        # Index -> file-name lookup tables.
        self.monet_indices = dict(enumerate(self.monets))
        self.photo_indices = dict(enumerate(self.photos))

        if callable(transform):
            self.transform = transform
        elif transform:
            # Resize, then ToTensor (which also rescales pixels to [0, 1]).
            self.transform = transforms.Compose((
                transforms.Resize(IMAGE_SIZE, antialias=False),
                transforms.ToTensor(),
            ))
        else:
            # Always define a transform; otherwise __getitem__ would fail
            # with AttributeError.
            self.transform = transforms.ToTensor()

    def __len__(self):
        """Return the size of the smaller of the two domains."""
        return min(len(self.monets), len(self.photos))

    def __getitem__(self, index):
        """Return a ``(photo, monet)`` pair of transformed images.

        The Monet image is chosen by ``index`` modulo the number of Monet
        images; the photo is chosen uniformly at random among all photos.
        """
        # Draw from the photo count, since this index selects a photo.
        # torch's RNG is reseeded per DataLoader worker, whereas NumPy's global
        # state is copied verbatim into every forked worker.
        random_index = torch.randint(len(self.photos), (1,)).item()
        monet_name = self.monet_indices[index % len(self.monets)]
        photo_name = self.photo_indices[random_index]
        monet_src = self._load("monet_jpg", monet_name)
        photo_src = self._load("photo_jpg", photo_name)
        return photo_src, monet_src

    def _load(self, folder, name):
        """Open one image, force 3-channel RGB and apply ``self.transform``."""
        # The context manager closes the file handle once the image is decoded.
        with Image.open(os.path.join(self._path, folder, name)) as img:
            return self.transform(img.convert("RGB"))

In [8]:
# Build the unpaired Monet/photo dataset rooted at ./data
dataset = DataLoader(_path="./data")

# Weight Normalization
[Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks](https://arxiv.org/pdf/1602.07868)

A reparameterization that decouples the magnitude of each weight vector from its direction $\to$ **speeds up convergence**!

<span style='color:red'> Where does the ***WeightNormalization*** superclass come from? </span>

In [None]:
# Custom layer built directly on nn.Module
class WeightNormalization(nn.Module):
    """Per-channel L2 normalization of activations with a learnable gain.

    Despite its name, this layer does not reparameterize a layer's *weights*
    (as in Salimans & Kingma, 2016). It rescales its *input*: each
    (sample, channel) feature map is divided by its spatial L2 norm and then
    multiplied by a learnable per-channel ``scaling`` factor, initialised to 1.

    Args:
        in_channels: Number of channels (dim 1 of the batched input).
        epsilon: Added to the squared norm before the square root to avoid
            division by zero.

    Input of shape ``(N, C, H, W)`` or unbatched ``(C, H, W)`` is returned
    with the same shape.
    """

    def __init__(self, in_channels, epsilon=1e-6):
        super().__init__()
        # NOTE(review): ``weights`` is registered but never read in forward(),
        # so it never receives a gradient. Kept so existing state_dicts load.
        self.weights = nn.Parameter(torch.randn(in_channels))
        # Learnable per-channel gain applied after normalization.
        self.scaling = nn.Parameter(torch.ones(in_channels))
        self.epsilon = epsilon

    def forward(self, x):
        """Normalize each channel's feature map to unit L2 norm, then scale."""
        # Promote an unbatched (C, H, W) input to a batch of one.
        was_unbatched = x.dim() == 3
        batch = x.unsqueeze(0) if was_unbatched else x

        # Spatial L2 norm per (sample, channel), kept broadcastable.
        norm = (batch.pow(2).sum(dim=(2, 3), keepdim=True) + self.epsilon).sqrt()
        gain = self.scaling.view(1, -1, 1, 1)
        out = batch / norm * gain

        return out.squeeze(0) if was_unbatched else out

# Model Training