In [1]:
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torchvision.utils import save_image

import ssl
import certifi


def _certifi_https_context(*args, **kwargs):
    """Build a certificate-verifying SSL context that trusts the certifi CA bundle.

    Accepts and forwards whatever arguments the caller passes to
    ``ssl.create_default_context`` so it can stand in for the default factory.
    """
    context = ssl.create_default_context(*args, **kwargs)
    context.load_verify_locations(cafile=certifi.where())
    return context


# Keep TLS verification ON. Swapping in an unverified context would let anyone on
# the network tamper with HTTPS downloads made by this process (e.g. pretrained
# model weights). If the system CA store is the problem, certifi's bundle fixes it
# without giving up verification.
ssl._create_default_https_context = _certifi_https_context

In [2]:
# Pick the compute backend in order of preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print("Using device = " + device)
if device == 'cpu':
    print("WARNING: Using CPU will cause slower train times")

Using device = mps


# Variables

In [59]:
# Side length (pixels) the content and style images are resized to.
image_size = 256
content_filename = 'Tuebingen_Neckarfront.jpg'
style_filename = 'Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg'

# Indices of the VGG19 feature layers whose activations are collected.
Neural_Style_Layer_List = ['0', '5', '10', '19', '28']

Normalization_Method = 'None'  # 'None' or 'imagenet'
loss_method = 'Perceptual'  # 'Base', 'Perceptual', 'Wasserstein' or 'total_variation'

# One output folder per loss method so runs do not overwrite each other.
image_save_folder = 'gatys_original_images_2' + '_' + loss_method

total_steps = 3100
save_steps = 100  # an image and a loss value are recorded every `save_steps` steps
learning_rate = 0.0001  # original note: 0.01 for the 'Base' loss
alpha = 1  # weight of the content loss
beta = 1  # weight of the style loss; original note: 0.01 for the 'Base' loss

In [60]:
# Create the output folder for generated images. `exist_ok=True` makes this a
# no-op when the folder is already there, so the cell is safe to re-run and
# avoids the check-then-create race.
os.makedirs('final_project_gen_images/' + image_save_folder, exist_ok=True)
    

# Neural Style Transfer Network

In [61]:
class VGG19(nn.Module):
    """Frozen VGG19 feature extractor returning the activations of selected layers.

    Args:
        layer_list: indices into the truncated ``vgg19.features`` stack whose
            outputs should be collected. Entries may be strings or ints.
    """

    def __init__(self, layer_list):
        super(VGG19, self).__init__()
        self.chosen_features = layer_list
        # Keep modules 0-28 of the convolutional stack.
        # NOTE(review): newer torchvision releases prefer the `weights=` argument
        # over `pretrained=True` — confirm against the installed version.
        self.model = models.vgg19(pretrained=True).features[:29]
        # The network is only used as a fixed feature extractor. Freezing its
        # weights stops every backward() from computing (and accumulating)
        # gradients for parameters that no optimizer ever updates.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the stack and return the list of selected activations."""
        # Normalise to strings so both ['0', '5'] and [0, 5] select layers.
        wanted = {str(index) for index in self.chosen_features}
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if str(layer_num) in wanted:
                features.append(x)
        return features

In [62]:
# Feature extractor for the style-transfer loss: collects the activations of the
# layers in Neural_Style_Layer_List, placed on `device` and switched to eval mode.
ns_model = VGG19(Neural_Style_Layer_List).to(device).eval()



# Perceptual Neural Network

In [63]:
# Full VGG16 convolutional stack, moved to `device` and put in eval mode.
# NOTE(review): `per_model` is rebound to VGG16Features() in a later cell before it
# is ever used, so this load (and weight download) looks redundant — confirm and
# consider removing this cell.
per_model = models.vgg16(pretrained=True).features.to(device).eval()
# per_model = nn.Sequential(*list(per_model.children())[:-1])  # Remove the classification layers
# per_model.eval()  # Set to evaluation mode



# Load Image Function

In [64]:
def load_image(image_name, image_size=256):
    """Load an image file as a normalised ``(1, 3, image_size, image_size)`` tensor.

    The image is converted to RGB, resized to a square, turned into a tensor and
    normalised according to the module-level ``Normalization_Method``:
    ``'None'`` applies an identity normalisation, ``'imagenet'`` applies the
    ImageNet channel statistics.

    Args:
        image_name: path of the image file to open.
        image_size: side length in pixels of the resized square image.

    Returns:
        The image tensor, with a leading batch dimension, on ``device``.

    Raises:
        ValueError: if ``Normalization_Method`` is not a supported value.
    """
    if Normalization_Method == 'None':
        mean, std = [0, 0, 0], [1, 1, 1]
    elif Normalization_Method == 'imagenet':
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    else:
        # Previously this fell through and crashed later on an unbound `loader`.
        raise ValueError(
            f"Unknown Normalization_Method {Normalization_Method!r}; "
            "expected 'None' or 'imagenet'"
        )

    loader = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Force three channels: greyscale or RGBA files would otherwise not match the
    # 3-channel normalisation. The context manager closes the file handle.
    with Image.open(image_name) as source:
        image = loader(source.convert('RGB')).unsqueeze(0)
    return image.to(device)

In [65]:
# Pre-trained VGG16 front end used as the feature extractor for the perceptual loss
class VGG16Features(nn.Module):
    """Frozen feature extractor made of the first eight modules of ``vgg16.features``."""

    def __init__(self):
        super(VGG16Features, self).__init__()
        # NOTE(review): newer torchvision releases prefer the `weights=` argument
        # over `pretrained=True` — confirm against the installed version.
        vgg16_model = models.vgg16(pretrained=True).features
        # Modules 0-7 of the convolutional stack, in order.
        self.layers = nn.Sequential(*list(vgg16_model.children())[:8])
        # Used only as a fixed feature extractor: freeze the weights so that
        # backward() through this network does not compute unused gradients.
        for param in self.layers.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Return the output of the truncated stack for input ``x``."""
        return self.layers(x)

# Run a model on an image without tracking gradients
def extract_features(img, model, device):
    """Return ``model(img)`` evaluated on ``device`` with autograd disabled.

    Both the model and the image are moved to ``device`` before the forward pass.
    """
    net = model.to(device)
    batch = img.to(device)
    with torch.no_grad():
        return net(batch)

# Neural Style Transfer Code

In [66]:
# Load the content and style images at the configured resolution.
content, style = (
    load_image("gatys_original_images/" + filename, image_size)
    for filename in (content_filename, style_filename)
)

In [67]:
# Perceptual-loss targets: VGG16 features of the content and style images,
# computed once without gradient tracking.
per_model = VGG16Features().to(device).eval()
content_features_per, style_features_per = (
    extract_features(reference, per_model, device)
    for reference in (content, style)
)



In [68]:
# The image being optimised starts as a copy of the content image.
# requires_grad_ must be the LAST call: if .to(device) ran after it and actually
# copied the tensor, the result would be a non-leaf tensor, which the optimizer
# refuses to optimise.
generated_image = content.detach().clone().to(device).requires_grad_(True)
optimizer = optim.Adam([generated_image], lr=learning_rate, betas=(0.5, 0.999))

In [69]:
SUPPORTED_LOSS_METHODS = ('Base', 'Perceptual', 'Wasserstein', 'total_variation')
if loss_method not in SUPPORTED_LOSS_METHODS:
    # Fail fast: otherwise the loss stays the int 0 and the problem only shows
    # up as an opaque error when calling total_loss.backward().
    raise ValueError(
        f"Unknown loss_method {loss_method!r}; expected one of {SUPPORTED_LOSS_METHODS}"
    )


def gram_matrix(feature):
    """Return the (channel x channel) Gram matrix of a single-image feature map.

    ``feature`` must have shape (1, channel, height, width); the reshape to
    (channel, height*width) only works for a batch of one.
    """
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())


# The content and style targets never change, so compute them once, without
# autograd, instead of re-running VGG19 on both images at every step.
with torch.no_grad():
    content_features = ns_model(content)
    style_features = ns_model(style)
    style_grams = [gram_matrix(style_feature) for style_feature in style_features]

num_layers = len(content_features)
if num_layers == 0:
    raise ValueError("Neural_Style_Layer_List does not select any layer of ns_model")

loss_values = []
step_list = []
for step in range(total_steps):
    if loss_method == 'Base':
        # Feature-matching content loss and Gram-matrix style loss, summed over
        # the selected VGG19 layers.
        generated_features = ns_model(generated_image)
        content_loss = 0
        style_loss = 0
        for gen_feature, content_feature, G_style in zip(generated_features, content_features, style_grams):
            content_loss += torch.mean((gen_feature - content_feature) ** 2)
            style_loss += torch.mean((gram_matrix(gen_feature) - G_style) ** 2)
    else:
        # The remaining losses do not depend on the VGG19 layer. Each term is
        # evaluated once and scaled by `num_layers`, which gives the same value
        # as adding the identical term once per selected layer.
        if loss_method == 'Perceptual':
            gen_features_per = per_model(generated_image)
            content_term = 100 * torch.mean((gen_features_per - content_features_per) ** 2)
            style_term = 100 * torch.mean((gen_features_per - style_features_per) ** 2)
        elif loss_method == 'Wasserstein':
            # NOTE(review): this only compares global pixel means of the images.
            generated_mean = torch.mean(generated_image)
            content_term = generated_mean - torch.mean(content)
            style_term = generated_mean - torch.mean(style)
        else:  # 'total_variation'
            horizontal = torch.mean(torch.abs(generated_image[:, :, :, 1:] - generated_image[:, :, :, :-1]))
            vertical = torch.mean(torch.abs(generated_image[:, :, 1:, :] - generated_image[:, :, :-1, :]))
            content_term = horizontal + vertical
            style_term = content_term
        content_loss = num_layers * content_term
        style_loss = num_layers * style_term

    total_loss = alpha * content_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % save_steps == 0:
        # Record the loss and write a snapshot of the current image.
        loss_value = total_loss.item()
        loss_values.append(loss_value)
        step_list.append(step)
        print(f'step {step}: {loss_value}')
        image_name = "final_project_gen_images/" + image_save_folder + "/" + str(step) + ".png"
        save_image(generated_image, image_name)

# Save loss values
print('saving loss values')
df = pd.DataFrame({'Step': step_list, 'Loss': loss_values})
df.to_csv("final_project_gen_images/" + image_save_folder + "/loss_values.csv", index=False)

step 0: 3726.54248046875
step 100: 3110.81884765625
step 200: 2788.247314453125
step 300: 2602.488037109375
step 400: 2479.765380859375
step 500: 2393.5654296875
step 600: 2329.72705078125
step 700: 2280.6640625
step 800: 2242.005126953125
step 900: 2210.60302734375
step 1000: 2184.590087890625
step 1100: 2162.80322265625
step 1200: 2144.36962890625
step 1300: 2128.5390625
step 1400: 2114.6953125
step 1500: 2102.5888671875
step 1600: 2091.94091796875
step 1700: 2082.481689453125
step 1800: 2073.969970703125
step 1900: 2066.3115234375
step 2000: 2059.37109375
step 2100: 2053.06201171875
step 2200: 2047.291259765625
step 2300: 2042.024658203125
step 2400: 2037.2115478515625
step 2500: 2032.771484375
step 2600: 2028.681396484375
step 2700: 2024.907958984375
step 2800: 2021.41748046875
step 2900: 2018.19140625
step 3000: 2015.191162109375
saving loss values
