# Imports

In [1]:
# imports
import numpy as np

import os 
import torch
import pandas as pd
from skimage import io, transform, color
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, utils
from PIL import Image

import dataLoader
from dataLoader import PhotoDataset, Rescale, Rotate, ToTensor, ToGreyNormalize, ColorJitter
from torch.utils.data.sampler import SubsetRandomSampler
from torchsummary import summary

#optimizer
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import cv2
import torchvision.models as models



# Prep Data

In [2]:
#gather data set

#van-gogh from: https://www.kaggle.com/ipythonx/van-gogh-paintings

image_names =  []
path = "training_data/van_gogh/"
for file in os.listdir(path):
    if file.endswith(".jpg"):
        image_names.append(file)
image_names = np.asarray(image_names)    

In [None]:
path = "training_data/van_gogh/"
transformed_dataset_train = PhotoDataset(image_names = image_names,
                                           root_dir=path,
                                           transform= [#ToGreyNormalize(),
                                                       Rescale((512,512)),
                                                       ToTensor()])

In [3]:
path = "content_images/"
content_names = ['nutmeg.jpg']
content_dataset = PhotoDataset(image_names = content_names,
                                           root_dir=path,
                                           transform= [#ToGreyNormalize(),
                                                       transforms.Resize((512,512)),
                                                       transforms.ToTensor(),
                                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                                             [0.229, 0.224, 0.225])])

In [10]:
transform = transforms.Compose([transforms.Resize((512,512)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],  
                                 [0.229, 0.224, 0.225])])

In [13]:
#open content img
content_img = Image.open("content_images/nutmeg.jpg")
content_image = transform(content_img).unsqueeze(0)

In [None]:
#read the images
style_img = Image.open("training_data/van_gogh/")

In [None]:
# Sampled image from your dataloader visualized with ground-truth keypoints.
random_indx = [2, 36, 47]
for i in random_indx:
    sample = transformed_dataset_train[i]
    im = sample['image']
    image = im.data
    image = image.numpy()
    s =  image.shape
    print(s)
    image = np.reshape(image, (s[1], s[2], 3))
    #have to have this to save the image 
    implot = plt.imshow(image)
    plt.show()

In [25]:
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        
        # load the vgg model's features
        self.vgg = models.vgg19(pretrained=True).features
    
    def get_content_activations(self, x: torch.Tensor) -> torch.Tensor:
        """
            Extracts the features for the content loss from the block4_conv2 of VGG19
            Args:
                x: torch.Tensor - input image we want to extract the features of
            Returns:
                features: torch.Tensor - the activation maps of the block4_conv2 layer
        """
        features = self.vgg[:23](x)
        return features
    
    def get_style_activations(self, x) -> torch.Tensor:
        """
            Extracts the features for the style loss from the block1_conv1, 
                block2_conv1, block3_conv1, block4_conv1, block5_conv1 of VGG19
            Args:
                x: torch.Tensor - input image we want to extract the features of
            Returns:
                features: list - the list of activation maps of the block1_conv1, 
                    block2_conv1, block3_conv1, block4_conv1, block5_conv1 layers
        """
        features = [self.vgg[:4](x)] + [self.vgg[:7](x)] + [self.vgg[:12](x)] + [self.vgg[:21](x)] + [self.vgg[:30](x)] 
        return features
    
    def forward(self, x):
        return self.vgg(x)

# Model

In [14]:
#load initial model
vgg19 = models.vgg19(pretrained = True)

In [15]:
#save avgpool layer
avgPool = vgg19.avgpool

#only keep feature space
vgg19 = vgg19.features

In [16]:
#change max pool layers to avg pool layers
for i, child in vgg19.named_children():
    if isinstance(child, nn.MaxPool2d):
        vgg19[int(i)] = avgPool

In [None]:
sx = (s[2], s[0], s[1])

In [None]:
summary(vgg19, sx)

In [18]:
# stationary feature extractor
for param in vgg19.parameters():
    param.requires_grad  = False

In [19]:
#extract content features (conv4_2)
content_feature = vgg19[:22]

#extract style features(conv1_1, conv2_1, conv3_1, conv4_1, conv5_1)
style_features = [vgg19[:0]]+[vgg19[:5]]+[vgg19[:10]]+[vgg19[:19]]+[vgg19[:28]]

In [None]:
content_feature

In [None]:
content_im.unsqueeze(0).shape

In [None]:
content_dataset[0]

In [22]:

content_act = vgg19[:22](content_image)
content_act = content_act.view(512, -1)

In [23]:
content_act

tensor([[-0.4922, -1.3637, -1.9718,  ...,  1.2022, -2.1436, -3.2753],
        [ 0.0407, -0.7288, -1.4203,  ..., -3.3852, -2.1745, -0.3340],
        [ 0.1702, -3.6247, -5.0771,  ..., -1.0185,  0.8282,  1.0640],
        ...,
        [ 4.0950,  3.9521,  2.3959,  ..., -2.3100, -1.6659,  0.9146],
        [ 0.4880, -1.8757, -2.7719,  ..., -0.7039,  0.2067, -0.9594],
        [-3.5972, -5.5392, -1.6151,  ..., -3.5265, -3.4930, -0.2222]])

In [None]:
# apply changes to model
relu = torch.nn.modules.activation.ReLU
maxpool = torch.nn.modules.pooling.MaxPool2d
new_features = []
for feature in vgg19.features:
    
    # get rid of RELU layers
    if type(feature) != relu:
        
        # change max_pool layers to avgpool
        if  type(feature) == maxpool:
            new_features.append(vgg19.avgpool)
        else:
            new_features.append(feature)
            
vgg19 = nn.Sequential(*new_features)          

In [None]:
vgg19 = models.vgg19(pretrained=True)

In [None]:
#prediction loss
criterion = nn.MSELoss() # mean squared error loss (torch.nn.MSELoss)

In [None]:
# create your optimizer
optimizer = optim.Adam(vgg19.parameters(), lr=0.001)

In [None]:
# Training

In [None]:
len(image_names)

In [None]:
#split into train and validation
indices = list(range(len(transformed_dataset_train)))
train_indices, val_indices = indices[:350], indices[350:]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

#Create Loaders
train_loader = torch.utils.data.DataLoader(transformed_dataset_train, batch_size=1, 
                                           sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(transformed_dataset_train, batch_size=1,
                                                sampler=valid_sampler)

In [None]:
# Training
#train
training_loss = []
validation_loss = []

for epoch in range(25):  # loop over the dataset multiple times
    running_train_loss = 0.0
    running_val_loss = 0.0
    
    
    #TRAINING
    for i, sample in enumerate(train_loader):
        # get the inputs; data is a list of [inputs, labels]
        image = sample['image']
        image = image.type(torch.FloatTensor)
        print(image.shape)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = vgg19(image)
        print(outputs.shape)
        
        #print(outputs)
        loss = criterion(outputs)
        
        #back propagation
        loss.backward()
        optimizer.step()
        
        # print statistics
        running_train_loss += loss.item()

     
        
    #store avg loss
    loss_t = running_train_loss/len(train_loader)
   # loss_v = running_val_loss/len(val_loader)
    
    training_loss.append(loss_t)
   # validation_loss.append(loss_v)
    print("Epoch: " + str(epoch) + " Training_Loss: " + str(loss_t))# + " Val_Loss: " + str(loss_v))
        
print('Finished Training')

In [None]:
vgg19
for i, child

In [None]:
vgg19[:21]

In [None]:
#set content layers [conv4_1] [4 = block]
content_layers = vgg[]

In [None]:
for name, child in vgg19.named_children():
    print(name, child)

In [None]:
#remove fully connected layers(ie only feature space)
#run the pre-trained model as a fixed feature extractor
#and then use the resulting features to train a new classifier.

In [None]:
#only want the feature space 
vgg19 = models.vgg19(pretrained=True).features

In [None]:
vgg19.summary()

In [None]:
summary(vgg19, (3, 600, 600))


In [None]:
# Tweak the model
#  - avgpool instead of max pool
#  - #content reconstruction : conv1 1, conv2 1, conv3 1, conv4 1, conv5 1
#  - #style reconstruction

In [None]:
# Higher layers in the network capture the high-level content
# feature responses in higher layers of the network is the content representation

In [None]:
# reconstructions from the lower layers simply reproduce the exact pixel values 
# of the original image 

In [None]:
#create white noise image

In [None]:
# gram matrix

In [None]:
#for style representation
# use gradient descent from a white noise image to find another image that 
# matches the style representation of the original image


# minimising the mean-squared distance between the entries of the 
# Gram matrix from the original image 
# and the Gram matrix of the image to be generated

In [None]:
# mix the content of a photograph with the style
# jointly minimise the distance of a white noise 
# image from the content representation of the photograph in one layer of the network 
# and the style representation of the painting in a number of layers of the CNN

# content reconstructions

In [None]:
# perform gradient descent on a white noise image to 
# - > find another image that matches the feature responses of the original image
