In [1]:
import os
from pathlib import Path

import cv2 as cv
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as tranforms
from PIL import Image
from torchvision.utils import save_image

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [3]:
# Display the PyTorch build string so results can be tied to a library version.
torch.version.__version__

'2.1.2+cu118'

In [4]:
class vggfeatureextrator(nn.Module):
  """Frozen VGG19 convolutional trunk that returns selected intermediate activations.

  The input is run through torchvision's VGG19 ``features`` stack, and the
  outputs of the layers whose indices appear in ``selectedlayers`` (indices
  into ``self.model``, stored as strings) are collected in network order.

  NOTE(review): torchvision's pretrained VGG weights are documented to expect
  ImageNet mean/std-normalised input; the loader in this notebook only applies
  Resize + ToTensor. Confirm whether unnormalised input is intended.
  """

  def __init__(self):
    super(vggfeatureextrator, self).__init__()
    self.model = models.vgg19(weights='VGG19_Weights.DEFAULT').features
    # Indices (as strings) of the feature layers whose activations are returned.
    self.selectedlayers = ['0', '5', '10', '19', '28']
    # The network is used purely as a fixed feature extractor: the optimiser
    # updates the image, never these weights. Freezing them stops backward()
    # from computing and accumulating .grad buffers for every VGG parameter on
    # each optimisation step. Gradients still flow through to the input.
    self.model.requires_grad_(False)

  def forward(self, x):
    """Return the activations of the selected layers for input ``x``.

    Args:
      x: tensor accepted by the VGG ``features`` stack (presumably
        (batch, 3, H, W) image batches -- TODO confirm against callers).

    Returns:
      list of tensors, one per selected layer index present in the network,
      ordered by layer index.
    """
    wanted = {int(idx) for idx in self.selectedlayers}
    # Layers after the deepest selected one never contribute to the output,
    # so stop there instead of running the remainder of the network.
    last = max(wanted, default=-1)
    outputfeatures = []
    for layernum, layer in enumerate(self.model):
      if layernum > last:
        break
      x = layer(x)
      if layernum in wanted:
        outputfeatures.append(x)
    return outputfeatures


In [5]:
# Build the feature extractor, move it to the selected device and switch it to
# eval mode; each step is a separate assignment so nothing is echoed by the cell.
modelvgg = vggfeatureextrator()
modelvgg = modelvgg.to(device)
modelvgg = modelvgg.eval()

In [6]:
imgsize = 720
# Every image is resized to a square imgsize x imgsize and converted to a tensor.
loader = tranforms.Compose(
  [
    tranforms.Resize(size=(imgsize,) * 2),
    tranforms.ToTensor(),
  ]
)

In [7]:
def imageloader(imgpath):
  """Load an image file as a single-image batch tensor on ``device``.

  Args:
    imgpath: path to the image (str or os.PathLike), passed to ``Image.open``.

  Returns:
    The ``loader`` output for the RGB-converted image, with a leading batch
    dimension of size 1, moved to ``device``.
  """
  # Open inside a context manager so the underlying file handle is closed
  # deterministically rather than whenever PIL / garbage collection gets to it.
  # convert() reads the pixel data, so the result stays valid after closing.
  with Image.open(imgpath) as raw:
    rgb = raw.convert("RGB")  # always 3 channels (drops alpha / palette modes)
  return loader(rgb).unsqueeze(0).to(device)

In [8]:
# Texture folder: set the TEXTURE_DIR environment variable to run this notebook
# from another checkout; it defaults to the original author's local path.
texturedir = Path(os.environ.get(
  'TEXTURE_DIR',
  '/home/tarek/projects/cameras-simulation-tool/src/uuv_simulator/uuv_gazebo_worlds/Media/materials/textures'))
# Content image: the structure the generated image should keep.
contentimg = imageloader(imgpath=texturedir / 'Rusty-mat.jpg')

In [9]:
# Texture folder: set the TEXTURE_DIR environment variable to run this notebook
# from another checkout; it defaults to the original author's local path.
texturedir = Path(os.environ.get(
  'TEXTURE_DIR',
  '/home/tarek/projects/cameras-simulation-tool/src/uuv_simulator/uuv_gazebo_worlds/Media/materials/textures'))
# Style image: the texture statistics the generated image should adopt.
styleimg = imageloader(imgpath=texturedir / 'top-view-dark-soil-background.jpg')

In [11]:
# Sanity-check the loader output: one 3-channel image resized to imgsize x imgsize.
assert contentimg.shape == (1, 3, imgsize, imgsize), contentimg.shape
contentimg.shape

torch.Size([1, 3, 720, 720])


In [None]:
# Inspect the style tensor's dimensions (same torch.Size as .shape).
styleimg.size()

In [12]:
# The optimised image starts as an exact copy of the content image and must
# track gradients so the optimiser can update its pixels directly.
generatedimg = contentimg.clone()
generatedimg.requires_grad = True

In [19]:
# Optimisation hyper-parameters.
totalsteps = 1_000  # number of optimiser updates applied to the generated image
lr = 0.001          # Adam learning rate
alpha = 0.025       # weight on the content loss term
beta = 1            # weight on the style loss term
# Only the generated image is handed to the optimiser.
optimizer = optim.Adam(params=[generatedimg], lr=lr)

In [20]:
def computecontentloss(g, c):
  """Content loss: mean of the element-wise squared difference between g and c."""
  diff = g - c
  return diff.pow(2).mean()

In [21]:
def computestyleloss(g, s, normalize=False):
  """Mean squared difference between the Gram matrices of two feature maps.

  Args:
    g: generated-image features, shape (b, c, h, w).
    s: style-image features, shape (b_s, c, h_s, w_s). b_s must be 1 or equal
      to b (the Gram matrices broadcast). The spatial size may differ from g's;
      in that case consider normalize=True so both Gram matrices share a scale.
    normalize: if True, divide each Gram matrix by c * h * w of its own input.
      The default (False) keeps the original unnormalised Gram matrices.

  Returns:
    0-d tensor: mean over all Gram-matrix entries of the squared difference.
  """
  def gram(feat):
    # Flatten spatial dims and take channel-by-channel inner products per
    # batch element: (b, c, h*w) @ (b, h*w, c) -> (b, c, c).
    b, c, h, w = feat.shape
    flat = feat.reshape(b, c, h * w)
    out = flat @ flat.transpose(1, 2)
    if normalize:
      out = out / (c * h * w)
    return out

  # Each tensor is reshaped with its own dimensions, so batches larger than 1
  # and style maps of a different spatial size are supported.
  return torch.mean((gram(g) - gram(s)) ** 2)

In [22]:
# Let cuBLAS matmuls and cuDNN convolutions use TensorFloat-32 (faster,
# lower-precision float32 math on GPUs that support it), then release cached
# GPU memory that PyTorch's allocator is holding but not using.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()


In [23]:
# torch.cuda.memory_summary() only accepts CUDA devices, so skip the report on
# CPU-only machines (otherwise this cell raises and breaks Restart & Run All).
# print() renders the multi-line report instead of the string's escaped repr.
if device.type == 'cuda':
  print(torch.cuda.memory_summary(device=device, abbreviated=False))



In [24]:
# The content and style targets never change during optimisation, so extract
# their VGG features once, outside the autograd graph, instead of re-running two
# extra forward passes (and building graphs for them) on every step.
with torch.no_grad():
  contentimgfeatures = modelvgg(contentimg)
  styleimgfeatures = modelvgg(styleimg)

outputpath = 'results_sand_3.png'
for step in range(totalsteps):
  generatedimgfeatures = modelvgg(generatedimg)
  contentloss = 0
  styleloss = 0
  # Sum the per-layer losses over every layer the extractor returns.
  for genfeature, contentfeature, stylefeature in zip(
      generatedimgfeatures, contentimgfeatures, styleimgfeatures):
    contentloss += computecontentloss(genfeature, contentfeature)
    styleloss += computestyleloss(genfeature, stylefeature)
  totalloss = alpha * contentloss + beta * styleloss
  optimizer.zero_grad()
  totalloss.backward()
  optimizer.step()
  if step % 100 == 0:
    print('current losses :', 'total loss is', totalloss.item(), " style loss:", styleloss.item(), ' content loss:', contentloss.item(), '\n')
    save_image(generatedimg, outputpath)

# The periodic save above only fires on multiples of 100, so the final updates
# would otherwise never reach disk; write the fully optimised image explicitly.
save_image(generatedimg, outputpath)

current losses : total loss is 6509369.0  style loss: 6509369.0  content loss: 2.877631664276123 

current losses : total loss is 5290246.5  style loss: 5290246.5  content loss: 3.027958631515503 

current losses : total loss is 4649101.5  style loss: 4649101.5  content loss: 3.1356170177459717 

current losses : total loss is 4119426.5  style loss: 4119426.5  content loss: 3.222456455230713 

current losses : total loss is 3662617.75  style loss: 3662617.75  content loss: 3.2945945262908936 

current losses : total loss is 3281167.5  style loss: 3281167.5  content loss: 3.3551249504089355 

current losses : total loss is 2941830.5  style loss: 2941830.5  content loss: 3.402787446975708 

current losses : total loss is 2643784.75  style loss: 2643784.75  content loss: 3.4420876502990723 

current losses : total loss is 2378257.5  style loss: 2378257.5  content loss: 3.47892427444458 

current losses : total loss is 2134676.75  style loss: 2134676.75  content loss: 3.5129282474517822 

