In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython import embed
from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(669)

### Hooks + VGG19

In [None]:
def show_params(model):
    """Print the name and repr of every entry yielded by `model.named_modules()`.

    Despite the function's name, the entries are whatever `named_modules()`
    yields (name, object) pairs for; each pair is followed by a blank line.
    """
    for module_name, module in model.named_modules():
        # name, repr, and an empty trailing line in a single call
        print(module_name, module, "", sep="\n")

# Forward-hook pattern adapted from:
# https://discuss.pytorch.org/t/how-can-l-load-my-best-model-as-a-feature-extractor-evaluator/17254/6
acts = {}      # layer name -> output captured by the most recent forward pass
handles = []   # hook handles, kept so the hooks can be removed later
def get_activation(name, detach):
    """Build a forward hook that stores a layer's output in the global `acts`.

    name   -- key under which the output is stored in `acts`.
    detach -- when truthy, store `outputs.detach()`; otherwise store
              `outputs` itself.
    Returns a callable with the `(model, inputs, outputs)` hook signature.
    """
    def hook(model, inputs, outputs):
        acts[name] = outputs.detach() if detach else outputs
    return hook

In [None]:
import torchvision.models as models
# NOTE: `pretrained=True` is kept for compatibility with older torchvision
# releases; newer releases prefer the `weights=` argument.
vgg = models.vgg19(pretrained=True)
# Replace maxpool with avgpool as per: https://stackoverflow.com/a/65429290
# (iterate over a snapshot because the container is modified inside the loop)
for i,layer in list(vgg.features.named_children()):
    if isinstance(layer, nn.MaxPool2d):
        vgg.features[int(i)] = nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=False)
# VGG is only used as a fixed feature extractor: the optimizer defined later
# in this notebook updates the image `x`, never the network. Freezing the
# weights stops `loss.backward()` from computing (and accumulating) gradients
# for every VGG parameter, and eval mode disables the classifier's dropout.
vgg.requires_grad_(False)
vgg.eval()
vgg = vgg.to('cuda')
# Select layers: (index into vgg.features, name, role in the loss)
layers = [
          (0,'conv1_1','style'),
          (5,'conv2_1','style'),
          (10,'conv3_1','style'),
          (19,'conv4_1','style'),
          (21,'conv4_2','content'),
          (28,'conv5_1','style')
]
# Register forward hooks on layers. detach=True: these hooks are used to
# capture the fixed style/content targets, which need no autograd history.
for layer,layer_name,_ in layers:
    h = vgg.features[layer].register_forward_hook(get_activation(layer_name, detach=True))
    handles.append(h)

### Load content and style targets

In [None]:
# Load style (a) and content (p) images
a = Image.open("./webb1.png")
a = a.convert('RGB')
# a = Image.open("./sower.png")
p = Image.open("./eee3.png")
p = p.convert('RGB')

# Crop the style image, then resize both images to the same square size
import PIL.ImageOps
w,h = a.size
a = a.crop((750,0,w/2,h/2))
print(a.size)
print(p.size)
a = a.resize((1000,1000))
p = p.resize((1000,1000))
display(a)
display(p)
# PIL image to numpy
a = np.asarray(a)
p = np.asarray(p)[:,:,:3]
# Convert to Torch tensor, scale...
a = torch.tensor(a, dtype=torch.float32, requires_grad=False, device='cuda') / 255
p = torch.tensor(p, dtype=torch.float32, requires_grad=False, device='cuda') / 255
# ... and normalize as per: # https://pytorch.org/vision/0.12/models.html
mean = torch.tensor([0.485, 0.456, 0.406], device='cuda')
std = torch.tensor([0.229, 0.224, 0.225], device='cuda')
a = (a - mean) / std
p = (p - mean) / std

# Move channels first and add a batch axis.
# NOTE(review): the pattern labels the array's first axis "W" and its second
# "H" and then swaps them. Arrays built from PIL images are presumably
# (rows, cols, channels), in which case the network sees a transposed image.
# `unnormalize` applies the inverse pattern ("B C H W -> W H (C B)"), so
# renders come out in the original orientation -- confirm this is intended.
from einops import rearrange
a = rearrange(a, "W H C -> 1 C H W")
p = rearrange(p, "W H C -> 1 C H W")

# Get target activations. The hooks fill the global `acts`, which is handed
# over to acts_a / acts_p and replaced by a fresh dict after each pass.
# All hooked layers live in vgg.features, so the classifier head is skipped;
# no_grad avoids building an autograd graph for tensors that are detached anyway.
with torch.no_grad():
    vgg.features(a)
    acts_a = acts
    acts = {}
    vgg.features(p)
    acts_p = acts
    acts = {}
for layer,name,layer_type in layers:
    print(f"layer #{layer}, {layer_type}")
    print(f"{name} - {acts_a[name].shape}")
    print(f"{name} - {acts_a[name].requires_grad}")
    # Sanity check: style and content targets must not be the same tensor values
    assert not torch.equal(acts_a[name],acts_p[name])
    print()

# The target-capturing hooks are no longer needed
for h in handles:
    h.remove()

### Training

In [None]:
def style_loss(acts_x, acts_a, layers):
    """Mean, over the 'style' layers, of the MSE between Gram matrices.

    acts_x, acts_a -- mappings from layer name to activations; each entry is
                      expected to be (1, C, H, W) (or already (C, H, W)).
    layers         -- iterable of (index, name, layer_type) triples; only
                      entries whose layer_type is 'style' contribute.
    Returns a scalar tensor.
    Raises ValueError if `layers` has no 'style' entry or an activation does
    not reduce to three dimensions.
    """
    def _gram(features):
        # Drop the singleton batch axis, flatten H and W into one spatial axis
        flat = torch.squeeze(features, dim=0)
        if flat.dim() != 3:
            raise ValueError(
                f"style_loss expects (1, C, H, W) activations, got shape {tuple(features.shape)}"
            )
        flat = flat.flatten(start_dim=1)          # (C, H*W)
        # Channel-by-channel inner products over all spatial positions
        return flat @ flat.T                      # (C, C)

    losses = [
        F.mse_loss(_gram(acts_x[name]), _gram(acts_a[name]))
        for _, name, layer_type in layers
        if layer_type == 'style'
    ]
    if not losses:
        raise ValueError("style_loss: `layers` contains no 'style' entries")
    return sum(losses) / len(losses)
            
            
def content_loss(acts_x, acts_p, layers):
    """Mean, over the 'content' layers, of the MSE between activations.

    acts_x, acts_p -- mappings from layer name to activation tensors; a
                      leading singleton axis is squeezed away before comparing.
    layers         -- iterable of (index, name, layer_type) triples; only
                      entries whose layer_type is 'content' contribute.
    Returns a scalar tensor.
    Raises ValueError if `layers` has no 'content' entry.
    """
    losses = [
        F.mse_loss(torch.squeeze(acts_x[name], dim=0), torch.squeeze(acts_p[name], dim=0))
        for _, name, layer_type in layers
        if layer_type == 'content'
    ]
    if not losses:
        # Previously this surfaced as an unexplained division by zero
        raise ValueError("content_loss: `layers` contains no 'content' entries")
    return sum(losses) / len(losses)

def var_loss(x):
    """Return the mean absolute value of `x` as a scalar tensor.

    NOTE(review): despite the name this penalizes magnitude, not differences
    between neighbouring pixels; the neighbour-difference version is kept
    below, disabled. Confirm which one is intended.
    """
    return x.abs().mean()
# Disabled alternative (squared differences between neighbouring pixels):
#     width_diff = x[:,:,:,1:] - x[:,:,:,:-1] # ie. col_2 - col_1
#     height_diff = x[:,:,1:,:] - x[:,:,:-1,:] # ie. row_2 - row_1
#     diag_diff1 = x[:,:,1:,1:] - x[:,:,:-1,:-1]
#     diag_diff2 = x[:,:,:-1,1:] - x[:,:,1:,:-1]
#     diff_types = [width_diff, height_diff, diag_diff1, diag_diff2]
#     loss = 0.0
#     for t in diff_types:
#         loss = loss + torch.mean(torch.square(t))
#     return loss

In [None]:
def unnormalize(x, std, mean):
    """Undo channel normalization and convert a tensor to a uint8 image array.

    x    -- tensor laid out as (B, C, H, W); B is expected to be 1 so that the
            merged channel axis matches the length of `std` / `mean`.
    std  -- per-channel scale (1-D tensor), multiplied back in.
    mean -- per-channel offset (1-D tensor), added back in.
    Returns a NumPy uint8 array indexed as [W, H, C*B]: the two spatial axes
    are swapped relative to the input, values are clamped to [0, 1], scaled by
    255 and truncated.
    """
    # Work on CPU copies without autograd history so callers may pass the
    # trainable / GPU tensor directly.
    x = x.detach().to('cpu')
    std = std.detach().to('cpu')
    mean = mean.detach().to('cpu')
    # (B, C, H, W) -> (W, H, C, B) -> (W, H, C*B)
    x = x.permute(3, 2, 1, 0)
    x = x.reshape(x.shape[0], x.shape[1], -1)
    x = x * std[None,None,:] + mean[None,None,:]
    # Clamp with torch itself rather than handing a tensor to NumPy
    x = torch.clamp(x, 0.0, 1.0)
    x = x.numpy()*255
    x = x.astype(np.uint8)
#     x = Image.fromarray(x,'RGB')
    return x

def pixel_delta(x_old, x_new):
    """Print the mean and standard deviation of the absolute pixel change.

    x_old, x_new -- array-likes of equal shape (e.g. uint8 renders).
    Returns None; the statistics are only printed.
    """
    # Widen before subtracting: unsigned 8-bit arithmetic wraps modulo 256,
    # which would report a change of -1 as 255.
    x_old = np.asarray(x_old).astype(np.float64)
    x_new = np.asarray(x_new).astype(np.float64)
    x_delta = np.abs(x_new - x_old)
    delta_mean, delta_std = np.mean(x_delta), np.std(x_delta)
    print(f"{delta_mean=},\n{delta_std=}")
    return

In [None]:
# Progress snapshots appended by the training loop; kept in its own cell so
# re-running the training cell does not reset it.
trained_x = list()

In [None]:
# Initialize the trainable image x from the (already normalized) content image p.
# Move to the device first and enable gradients last, so that x is always a
# leaf tensor: calling .to() on a tensor that already requires grad returns a
# non-leaf copy whenever a real device transfer happens, and the optimizer
# cannot update a non-leaf tensor.
x = p.detach().clone().to('cuda').requires_grad_(True)

### Set loss weights

In [None]:
# Adam updates the image tensor x directly; weight decay is disabled.
optimizer = torch.optim.Adam([x], lr=6e-3, weight_decay=0.0)

In [None]:
# Weights of the three loss terms: content, style, pixel variation
alpha, beta, gamma = 1e-8, 1e0, 1e3
# Tuning note: start with a high content weight to initialize x, then switch
# to emphasize style. This seems to make colors more coherent (ie. blue on
# sweater, yellow on skin, etc.). Finish up with higher gamma to fix artifacts.
# Weighted var_loss of the style image a, printed as a reference value:
print(f"texture var_loss baseline = {var_loss(a)*gamma}")

In [None]:
from tqdm import tqdm

iter_k = 100          # report / snapshot interval, in iterations
n_iters = 5000        # the loop runs n_iters + 1 steps so the last one is reported
render_old = None
render_new = unnormalize(x.detach().to('cpu'), std, mean)

# Register the forward hooks once for the whole run. detach=False keeps the
# activations attached to the autograd graph so the losses can reach x.
handles = []
for layer,layer_name,_ in layers:
    h = vgg.features[layer].register_forward_hook(get_activation(layer_name, detach=False))
    handles.append(h)

try:
    for i in tqdm(range(n_iters+1)):
        # Run x through the convolutional stack; the hooks fill `acts`.
        # Every hooked layer is in vgg.features, so the classifier head is skipped.
        vgg.features(x)
        content_loss_ = alpha*content_loss(acts, acts_p, layers)
        style_loss_ = beta*style_loss(acts, acts_a, layers)
        var_loss_ = gamma*var_loss(x)
        loss = content_loss_ + style_loss_ + var_loss_
        # Run backprop, etc.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i%iter_k==0 and i>0:
            # Losses (as plain numbers rather than tensor reprs)
            print(f"style loss = {style_loss_.item()}")
            print(f"content loss = {content_loss_.item()}")
            print(f"var loss = {var_loss_.item()}")
            # Previous snapshot (iter_k steps ago, or the initial image) vs. current
            render_old = render_new
            render_new = unnormalize(x.detach().to('cpu'), std, mean)
            pixel_delta(render_old, render_new)
            # Absolute change, computed in a wider signed type because
            # subtracting uint8 arrays directly wraps modulo 256.
            render_delta = np.abs(
                render_new.astype(np.int16) - render_old.astype(np.int16)
            ).astype(np.uint8)
            fused = np.concatenate((render_new, render_delta),axis=1)
            trained_x.append(Image.fromarray(fused,'RGB')) # Store progress
            # Display the current render (divide by >1 to shrink the preview)
            render_new_pil = Image.fromarray(render_new,'RGB')
            display(render_new_pil.resize(tuple(int(dim/1) for dim in render_new_pil.size)))
finally:
    # Always detach the hooks, even if the run is interrupted
    for h in handles:
        h.remove()
    


In [None]:
# Replay every stored progress snapshot, oldest first
for snapshot in trained_x:
    display(snapshot)

In [None]:
# Dump the optimized image tensor (still in normalized space) for inspection
print(x)