In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torchvision import models
from torchvision.utils import save_image
from PIL import Image

In [None]:
# Feature extractor: VGG-19 with the fully connected head removed and the weights frozen.
class VggBackbone(nn.Module):
  """Truncated, frozen VGG-19 that returns selected intermediate activations.

  Only the first 29 modules of ``vgg19().features`` are kept. On every forward
  pass the output of each module whose index (as a string) appears in
  ``chosen_layer_outputs`` is collected and returned.
  """

  def __init__(self):
    super().__init__()
    # Indices (into ``self.model``) of the modules whose outputs are collected.
    self.chosen_layer_outputs = ["0", "5", "10", "19", "28"]
    # NOTE(review): newer torchvision releases deprecate ``pretrained=True`` in
    # favour of the ``weights=`` argument; kept as-is for compatibility.
    self.model = models.vgg19(pretrained=True).features[:29]

    # Actually freeze the backbone: without this, every backward pass also
    # computes (and stores) gradients for all VGG weights even though they are
    # never optimised.
    for param in self.model.parameters():
      param.requires_grad_(False)

  def forward(self, x):
    """Run ``x`` through the truncated network.

    Args:
      x: input tensor fed to the first module of ``self.model``.

    Returns:
      A list with one tensor per chosen module, in network order.
    """
    wanted = set(self.chosen_layer_outputs)
    features = []
    for layer_idx, layer in enumerate(self.model):
      x = layer(x)
      # NOTE(review): torchvision's VGG appears to use in-place ReLU modules, so
      # a tensor collected here may be overwritten by the activation that
      # follows it — confirm that post-activation features are intended.
      if str(layer_idx) in wanted:
        features.append(x)

    return features

In [None]:
def load_image_tensor(image_path, transforms, device):
  """Load an image from disk as a batched tensor on ``device``.

  Args:
    image_path: path (or file object) accepted by ``PIL.Image.open``.
    transforms: callable that maps a PIL image to a tensor.
    device: target device passed to ``Tensor.to``.

  Returns:
    The transformed image with a leading batch dimension of size 1.
  """
  # The context manager releases the underlying file handle once the pixels
  # have been read, instead of leaving it open until garbage collection.
  with Image.open(image_path) as image:
    # Force three channels so grayscale / RGBA / palette files do not produce
    # a tensor with the wrong channel count.
    image_tensor = transforms(image.convert("RGB")).unsqueeze(0)
  return image_tensor.to(device)

# Hyperparameters
EPOCHS = 6001  # number of optimisation steps
LR = 0.001     # Adam learning rate
alpha = 1      # weight of the content term in the total loss
beta = 0.1     # weight of the style term in the total loss

# Prefer the GPU when one is available.
if torch.cuda.is_available():
  DEVICE = "cuda"
else:
  DEVICE = "cpu"

# Every image is resized to a square of this side length before use.
image_dim = 356
transforms = T.Compose([
  T.Resize((image_dim, image_dim)),
  T.ToTensor(),
])

content = load_image_tensor("content2.jpg", transforms, DEVICE)
style = load_image_tensor("style.jpg", transforms, DEVICE)

# The generated image starts as a copy of the content image and is the only
# tensor handed to the optimizer.
generated = content.clone()
generated.requires_grad_(True)

# Feature extractor, switched to evaluation mode on the chosen device.
model = VggBackbone()
model.to(DEVICE)
model.eval()

optimizer = optim.Adam([generated], lr=LR)



In [None]:
# Training Loop
# The content and style images never change, so their targets are computed once
# up front, without autograd bookkeeping, instead of on every iteration.
with torch.no_grad():
  content_features = model(content)
  style_grams = []
  for style_feat in model(style):
    N, C, H, W = style_feat.shape
    # view(C, H*W) drops the batch axis, so this only works for a batch of one.
    style_flat = style_feat.view(C, H*W)
    style_grams.append(style_flat.mm(style_flat.t()))

for epoch in range(EPOCHS):
  gen_features = model(generated)
  CG_loss = 0  # content loss, summed over all chosen layers
  SG_loss = 0  # style loss, summed over all chosen layers

  for (content_feat, style_gram, gen_feat) in zip(content_features, style_grams, gen_features):
    N, C, H, W = gen_feat.shape
    CG_loss += torch.mean((gen_feat - content_feat)**2)

    # Gram matrix of the generated features: channel-by-channel correlations.
    gen_flat = gen_feat.view(C, H*W)
    gen_gram = gen_flat.mm(gen_flat.t())
    # Accumulate across layers; plain assignment here would silently discard
    # every layer except the last one.
    SG_loss += torch.mean((gen_gram - style_gram)**2)

  loss = alpha*CG_loss + beta*SG_loss
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  # Periodically report progress and save a snapshot of the generated image.
  if epoch % 1000 == 0:
    print("Loss : ", loss)
    save_image(generated, f"generated{epoch//1000}.png")

Loss :  tensor(471415.4062, device='cuda:0', grad_fn=<AddBackward0>)
Loss :  tensor(558.2444, device='cuda:0', grad_fn=<AddBackward0>)
Loss :  tensor(395.1545, device='cuda:0', grad_fn=<AddBackward0>)
Loss :  tensor(322.5838, device='cuda:0', grad_fn=<AddBackward0>)
Loss :  tensor(306.1647, device='cuda:0', grad_fn=<AddBackward0>)
Loss :  tensor(251.8336, device='cuda:0', grad_fn=<AddBackward0>)
Loss :  tensor(239.7029, device='cuda:0', grad_fn=<AddBackward0>)
