In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

In [None]:
# Inspection only: print the convolutional stack of pretrained VGG19 so the
# layer indices used for feature extraction can be read off the output.
# Bound to its own name (not `model`) because `model` is later assigned the
# VGG feature-extractor wrapper; sharing the name means re-running this cell
# out of order would silently replace the wrapper with a plain Sequential.
vgg19_features = models.vgg19(pretrained=True).features
print(vgg19_features)


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(FloatProgress(value=0.0, max=574673361.0), HTML(value='')))


Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPo

In [None]:
class VGG(nn.Module):
  """Feature extractor built on the convolutional part of pretrained VGG19.

  Calling the module runs the input through the truncated network and
  returns the intermediate activations captured at the layer indices listed
  in ``choosen_features``.
  """

  def __init__(self):
    super(VGG, self).__init__()

    # Indices (as strings) into ``vgg19().features`` whose outputs are
    # collected. In the printed architecture 0, 5 and 10 are Conv2d layers;
    # 19 and 28 are presumably conv layers as well -- TODO confirm.
    self.choosen_features=['0', '5', '10', '19', '28']
    # Nothing after index 28 is ever read, so the network is cut there.
    self.model= models.vgg19(pretrained= True).features[:29]

    # The network is used as a fixed feature extractor: only the input image
    # is optimised. Freezing the weights stops backward() from computing and
    # accumulating gradients for parameters that are never updated.
    for param in self.model.parameters():
      param.requires_grad_(False)

  def forward(self, x):
    """Return a list with the activation after each chosen layer, in order.

    NOTE(review): the ReLU layers of this network are ``inplace=True``, so a
    tensor captured at a conv index is overwritten by the ReLU that follows
    it. The returned entries therefore hold post-ReLU values for every chosen
    index that is followed by a ReLU inside the truncated network.
    """
    features= []

    for layer_num, layer in enumerate(self.model):
      x= layer(x)

      if str(layer_num) in self.choosen_features:
        features.append(x)

    return features

In [None]:
def load_image(image_name):
  """Load an image file and return it as a batched tensor on ``device``.

  The image is passed through the module-level ``loader`` transform and
  given a leading batch dimension of size 1.

  Args:
    image_name: path (or file object) accepted by ``PIL.Image.open``.

  Returns:
    The transformed image tensor with an extra leading dimension, moved to
    the module-level ``device``.
  """
  # The context manager closes the underlying file handle. convert() returns
  # a fully loaded copy and forces three channels, so grayscale / RGBA /
  # palette files do not reach the 3-channel first convolution with the
  # wrong channel count.
  with Image.open(image_name) as image_file:
    image = image_file.convert("RGB")
  image = loader(image).unsqueeze(0)
  return image.to(device)

# is_available must be *called*: the bare function object is always truthy,
# which selected "cuda" even on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 356

# Compose applies its transforms in sequence, so they are given as a list;
# a set literal has no guaranteed iteration order.
loader = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        # Normalize stays disabled as in the original; the optimised tensor is
        # written out by save_image without any de-normalisation step.
        #transforms.Normalize(mean=[], std=[])
    ]
)

In [None]:
# Content image: the generated image starts from it and is compared against
# its features. Style image: supplies the Gram-matrix targets.
orignal_img = load_image(image_name="GoldenGate.jpg")
style_img = load_image(image_name="starry_night.jpg")


In [None]:
# Feature extractor on the same device as the images, in eval mode.
model = VGG().to(device).eval()

# The image being optimised starts as a copy of the content image and is the
# only tensor handed to the optimiser, so it must track gradients.
# (Previously referenced the undefined name `orignal_image` -> NameError.)
generated= orignal_img.clone().requires_grad_(True)

In [None]:
#Hyperparameters
total_steps = 6000
learning_rate = 0.001
alpha = 1      # weight of the content (original-image) loss
beta = 0.01    # weight of the style (Gram-matrix) loss
optimizer= optim.Adam([generated], lr= learning_rate)


def _gram_matrix(feature):
  """Return the (channel x channel) Gram matrix of one feature map.

  The feature map is flattened with ``view(channel, height*width)``, which
  only works when the batch dimension is 1.
  """
  _, channel, height, width = feature.shape
  flat = feature.view(channel, height*width)
  return flat.mm(flat.t())


# The content and style images never change, so their features (and the style
# Gram matrices) are loop-invariant. Compute them once, without autograd
# tracking, instead of re-running two extra forward passes on every step.
with torch.no_grad():
  orignal_img_features = model(orignal_img)
  style_features = model(style_img)
  style_grams = [_gram_matrix(style_feature) for style_feature in style_features]

for step in range(total_steps):
  generated_features= model(generated)

  style_loss = orignal_loss = 0

  for gen_feature, orig_feature, style_gram in zip(
      generated_features, orignal_img_features, style_grams
  ):
    # Content loss: mean squared difference of raw activations.
    orignal_loss += torch.mean((gen_feature - orig_feature) **2)

    # Style loss: mean squared difference of Gram matrices.
    style_loss += torch.mean((_gram_matrix(gen_feature) - style_gram)**2)

  total_loss = alpha*orignal_loss + beta*style_loss
  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

  # Periodic progress report; the checkpoint file is overwritten each time.
  if step % 200 == 0:
    print(total_loss)
    save_image(generated, "generated.jpg")

tensor(912881.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22698.1094, device='cuda:0', grad_fn=<AddBackward0>)
tensor(10493.7217, device='cuda:0', grad_fn=<AddBackward0>)
tensor(6883.6309, device='cuda:0', grad_fn=<AddBackward0>)
tensor(5185.8511, device='cuda:0', grad_fn=<AddBackward0>)
tensor(4170.8501, device='cuda:0', grad_fn=<AddBackward0>)
tensor(3485.0869, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2992.1936, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2619.2512, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2324.3472, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2083.4709, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1884.6556, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1718.4493, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1575.6732, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1452.0513, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1345.0640, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1251.6470, device='cuda:0', grad_fn=<AddBackw