In [35]:
import torch 
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
from torchvision.models import vgg19,VGG19_Weights
from torchvision.utils import save_image

In [36]:
model = vgg19(weights='IMAGENET1K_V1').features
#note .features returns only the conv layers
print(model)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [37]:
class NSTmodel(nn.Module):
  def __init__(self):
    super(NSTmodel,self).__init__()
    self.chosen_features = ["0","5","10","19","28"] #layer numbers of the conv layers for which we will find loss
    self.model = vgg19(weights='IMAGENET1K_V1').features[:29]
  def forward(self,x):
    features = []
    for layer_num, layer in enumerate(self.model):
      x = layer(x)
      if str(layer_num) in self.chosen_features:
        features.append(x)
    return features

In [38]:
device = torch.device("cuda" if torch.cuda.is_available else "cpu")
image_size = 500

In [39]:
def img_loader(image_name):
  image = Image.open(image_name)
  image = loader(image).unsqueeze(0)
  return image.to(device)

In [40]:
loader = transforms.Compose(
    [
        transforms.Resize((image_size,image_size)),
        transforms.ToTensor(),
    ]
  ) 

In [41]:
original_image = img_loader("mona_lisa.jpg")
style_image =  img_loader("great_wave.jpg")
generated_image = original_image.clone().requires_grad_(True)
model = NSTmodel().to(device).eval()
print(model)

NSTmodel(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), pa

In [42]:
epochs = 6000
learning_rate = 0.001
alpha = 1
beta = 0.01

optimizer = optim.Adam([generated_image],lr = learning_rate)

In [44]:
for step in range(epochs+1):
  generated_image_features = model(generated_image)
  orignal_image_features = model(original_image)
  style_image_features = model(style_image)

  style_loss = 0
  content_loss = 0

  for (gen,ori,sty) in zip(generated_image_features,orignal_image_features,style_image_features):
    batch_size,channel,height,width = gen.shape

    content_loss += torch.mean((gen-ori)**2)
    Gg = gen.view(channel,height*width).mm(gen.view(channel,height*width).t())
    Gs = sty.view(channel,height*width).mm(sty.view(channel,height*width).t())
    style_loss += torch.mean((Gg-Gs)**2)
    
  total_loss = alpha*content_loss + beta*style_loss

  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

  if step%100 == 0:
    print(f"Epoch = {step}, Loss = {total_loss}")
  if step%1000 == 0:
    save_image(generated_image, f"generated{step}.png")

Epoch = 0, Loss = 2322152.25
Epoch = 100, Loss = 448993.90625
Epoch = 200, Loss = 291006.28125
Epoch = 300, Loss = 228339.3125
Epoch = 400, Loss = 191686.171875
Epoch = 500, Loss = 166911.40625
Epoch = 600, Loss = 148581.234375
Epoch = 700, Loss = 134402.28125
Epoch = 800, Loss = 123061.8359375
Epoch = 900, Loss = 113794.765625
Epoch = 1000, Loss = 106119.5625
Epoch = 1100, Loss = 99763.5546875
Epoch = 1200, Loss = 94446.78125
Epoch = 1300, Loss = 89935.671875
Epoch = 1400, Loss = 86050.96875
Epoch = 1500, Loss = 82675.1875
Epoch = 1600, Loss = 79723.171875
Epoch = 1700, Loss = 77103.4140625
Epoch = 1800, Loss = 74743.5546875
Epoch = 1900, Loss = 72600.6640625
Epoch = 2000, Loss = 70635.4296875
Epoch = 2100, Loss = 68814.5234375
Epoch = 2200, Loss = 67115.4453125
Epoch = 2300, Loss = 65522.453125
Epoch = 2400, Loss = 64012.15234375
Epoch = 2500, Loss = 62565.41796875
Epoch = 2600, Loss = 61175.59375
Epoch = 2700, Loss = 59832.13671875
Epoch = 2800, Loss = 58528.26171875
Epoch = 2900, L