In [70]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import pandas as pd

import ssl
import certifi
ssl._create_default_https_context = ssl._create_unverified_context

In [71]:
device = 'mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device = " + device)
if device == 'cpu':
    print("WARNING: Using CPU will cause slower train times")

Using device = mps


# Variables

In [81]:
image_size = 256
content_filename = 'cityscape.png'
style_filename = 'cathedral.png'

# image_save_folder = 'perceptual_cityscape_cathedral'

Neural_Style_Layer_List = ['0', '5', '10', '19', '28']

Normalization_Method = 'None' # 'None' or 'imagenet'
loss_method = 'Perceptual' # 'Base', 'Perceptual', 'Wasserstein'

image_save_folder = 'perceptual_cityscape_cathedral'

total_steps = 3100
save_steps = 100
learning_rate = 0.0001 # loss for base 0.01
alpha = 1 
beta = 1 # beta for base 0.01

In [82]:
# if image_save_folder does not exist, create it
import os
if not os.path.exists('final_project_gen_images/' + image_save_folder):
    os.makedirs('final_project_gen_images/' + image_save_folder)
    

# Neural Style Transfer Network

In [83]:
class VGG19(nn.Module):
    def __init__(self, layer_list):
        super(VGG19, self).__init__()
        self.chosen_features = layer_list
        self.model = models.vgg19(pretrained=True).features[:29]

    def forward(self, x):
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if str(layer_num) in self.chosen_features:
                features.append(x)
        return features

In [84]:
ns_model = VGG19(Neural_Style_Layer_List).to(device).eval()



# perceptual Neural Network

In [85]:
per_model = models.vgg16(pretrained=True).features.to(device).eval()
# per_model = nn.Sequential(*list(per_model.children())[:-1])  # Remove the classification layers
# per_model.eval()  # Set to evaluation mode



# Load Image Function

In [90]:
def load_image(image_name, image_size=256):
    if Normalization_Method == 'None':
        loader = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
    ])
        
    elif Normalization_Method == 'imagenet':
        loader = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    image = Image.open(image_name).convert('RGB')
    image = loader(image).unsqueeze(0)
    return image.to(device)

In [91]:
# Load pre-trained VGG16 model for perceptual loss
class VGG16Features(nn.Module):
    def __init__(self):
        super(VGG16Features, self).__init__()
        vgg16_model = models.vgg16(pretrained=True).features
        self.layers = nn.Sequential(
            vgg16_model[0], vgg16_model[1], vgg16_model[2], vgg16_model[3],
            vgg16_model[4], vgg16_model[5], vgg16_model[6], vgg16_model[7]
        )
        
    def forward(self, x):
        return self.layers(x)

# Function to extract features using VGG16
def extract_features(img, model, device):
    model = model.to(device)
    img = img.to(device)
    with torch.no_grad():
        features = model(img)
    return features

# Neural Style Transfer Code

In [92]:
content = load_image("final_project_images/content_images/" + content_filename, image_size)
style = load_image("final_project_images/style_images/" + style_filename, image_size)

In [93]:
# Extract features for perceptual loss
per_model = VGG16Features().to(device).eval()
content_features_per = extract_features(content, per_model, device)
style_features_per = extract_features(style, per_model, device)



In [94]:
generated_image = content.clone().requires_grad_(True).to(device)
optimizer = optim.Adam([generated_image], lr=learning_rate, betas=[0.5, 0.999])

In [95]:
loss_values = []
step_list = []
for step in range(total_steps):
    generated_features = ns_model(generated_image)
    content_features = ns_model(content)
    style_features = ns_model(style)
    
    style_loss = 0
    content_loss = 0
    
    for gen_feature, content_feature, style_feature in zip(generated_features,content_features, style_features):
        batch_size, channel, height, width = gen_feature.shape
        
    
        # Compute Gram Matrix
        G_gen = gen_feature.view(channel, height*width).mm(gen_feature.view(channel, height*width).t())
        G_style = style_feature.view(channel, height*width).mm(style_feature.view(channel, height*width).t())
        
        # Compute Loss
        if loss_method == 'Base':
            content_loss += torch.mean((gen_feature - content_feature)**2)
            style_loss += torch.mean((G_gen - G_style)**2)
        elif loss_method == 'Perceptual':
            gen_features_per = per_model(generated_image)
            content_loss += 100 * torch.mean((gen_features_per - content_features_per) ** 2)
            style_loss += 100 * torch.mean((gen_features_per - style_features_per) ** 2)
        elif loss_method == 'Wasserstein':
            content_loss += torch.mean(generated_image) - torch.mean(content)
            style_loss += torch.mean(generated_image) - torch.mean(style)
        elif loss_method == 'total_variation':
            content_loss += torch.mean(torch.abs(generated_image[:, :, :, 1:] - generated_image[:, :, :, :-1])) + torch.mean(torch.abs(generated_image[:, :, 1:, :] - generated_image[:, :, :-1, :]))
            style_loss += torch.mean(torch.abs(generated_image[:, :, :, 1:] - generated_image[:, :, :, :-1])) + torch.mean(torch.abs(generated_image[:, :, 1:, :] - generated_image[:, :, :-1, :]))
            
       
        
    total_loss = alpha*content_loss + beta*style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    if step %  save_steps == 0:
        loss_value = total_loss.item()
        loss_values.append(loss_value)
        step_list.append(step)
        print(f'step {step}: {loss_value}')
        image_name = "final_project_gen_images/" + image_save_folder + "/" + str(step) + ".png"
        save_image(generated_image, image_name)
        
# Save loss values
print('saving loss values')
df = pd.DataFrame(list(zip(step_list, loss_values)), columns =['Step', 'Loss'])
df.to_csv("final_project_gen_images/" + image_save_folder + "/loss_values.csv", index=False)

step 0: 5644.9384765625
step 100: 4887.3173828125
step 200: 4427.0869140625
step 300: 4147.41064453125
step 400: 3956.109375
step 500: 3816.333984375
step 600: 3709.689453125
step 700: 3625.794921875
step 800: 3558.0947265625
step 900: 3502.44775390625
step 1000: 3455.856689453125
step 1100: 3416.4052734375
step 1200: 3382.4970703125
step 1300: 3352.92041015625
step 1400: 3327.178955078125
step 1500: 3304.451171875
step 1600: 3284.17578125
step 1700: 3266.0703125
step 1800: 3249.911376953125
step 1900: 3235.38623046875
step 2000: 3222.265625
step 2100: 3210.44189453125
step 2200: 3199.79345703125
step 2300: 3190.123046875
step 2400: 3181.20947265625
step 2500: 3173.14892578125
step 2600: 3165.79296875
step 2700: 3159.01953125
step 2800: 3152.7568359375
step 2900: 3146.99267578125
step 3000: 3141.72607421875
saving loss values
