In [25]:
import time
from load_data import load_images
from losses import content_loss, style_loss
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from transforms import prep, post
from torch.autograd import Variable
from macros import *

In [26]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('Current Device:', device)

Current Device: cpu


## Image Preparations

In [27]:
imgs = load_images(2, 2)
imgs = [prep(img) for img in imgs]
imgs = [Variable(img.unsqueeze(0).to(device)) for img in imgs]
img_con, img_sty = imgs
opt_img = Variable(img_con.data.clone(), requires_grad=True)

## Model Preparations

In [28]:
print('Optimizing from content image... Using pre-trained model VGG19_bn')
model = models.vgg19_bn(pretrained=True).to(device)
for param in model.parameters():
    param.requires_grad = False

Optimizing from content image... Using pre-trained model VGG19_bn


## Feature Extraction

In [29]:
class FeatureSaver(nn.Module):
    feature = None
    def __init__(self, layer):
        self.hook = layer.register_forward_hook(self.hook_func)
    def hook_func(self, module, input, output):
        self.feature = output
    def close(self):
        self.hook.remove()

In [30]:
content_feature_savers = [FeatureSaver(model.features[layer]) for layer in content_layers]
model(Variable(img_con))
content_features = [saver.feature.clone() for saver in content_feature_savers]
print(f'Saved content features from layer {content_layers[0]} of the model')
style_feature_savers  = [FeatureSaver(model.features[layer]) for layer in style_layers]
model(Variable(img_sty))
style_features = [saver.feature.clone() for saver in style_feature_savers]
print(f'Saved style features from layer {layer} of the model' for layer in style_layers)

Saved content features from layer 37 of the model
<generator object <genexpr> at 0x13d4dda80>


## Optimizer Setup

In [31]:
optimizer = optim.LBFGS([opt_img])

def closure():
    global i
    model(opt_img)
    gen_content_feats = [saver.feature.clone() for saver in content_feature_savers]
    gen_style_feats = [saver.feature.clone() for saver in style_feature_savers]

    contentloss = WEIGHT_CONTENT * content_loss(gen_content_feats, content_features)
    styleloss = style_loss(gen_style_feats, style_features, WEIGHTS_STYLE)
    loss = contentloss + styleloss

    optimizer.zero_grad()
    loss.backward()

    if i % show_iter == 0:
        print(f"Epoch: {i}, Content loss: {contentloss}, Style loss: {styleloss}, Total loss: {loss}")
    i += 1
    return loss

In [32]:
start_time = time.time()
i = 0
print('Start Training...')
while i < max_iter:
    optimizer.step(closure)

end_time = time.time()
print(f"Training completed in {end_time - start_time:.2f} seconds.")

out_img = post(opt_img.data[0].cpu().squeeze())
out_img.save('Result/kanagawa_ghibli.png', format='png')

Start Training...
Epoch: 0, Content loss: 0.0, Style loss: 301846.03125, Total loss: 301846.03125
Epoch: 50, Content loss: 0.005699015222489834, Style loss: 6062.5048828125, Total loss: 6062.5107421875
Epoch: 100, Content loss: 0.006193954031914473, Style loss: 1069.41064453125, Total loss: 1069.4168701171875
Epoch: 150, Content loss: 0.006436120253056288, Style loss: 355.30615234375, Total loss: 355.3125915527344
Epoch: 200, Content loss: 0.006631428375840187, Style loss: 171.02200317382812, Total loss: 171.0286407470703
Epoch: 250, Content loss: 0.0067511689849197865, Style loss: 101.37615966796875, Total loss: 101.3829116821289
Epoch: 300, Content loss: 0.00684228865429759, Style loss: 68.37683868408203, Total loss: 68.38368225097656
Epoch: 350, Content loss: 0.006915047764778137, Style loss: 50.1961669921875, Total loss: 50.20308303833008
Epoch: 400, Content loss: 0.006958798039704561, Style loss: 39.56109619140625, Total loss: 39.56805419921875
Epoch: 450, Content loss: 0.00699027