In [1]:
import time
from load_data import load_images
from losses import content_loss, style_loss
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from transforms import prep, post
from torch.autograd import Variable
from macros import *

In [2]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('Current Device:', device)

Current Device: cpu


## Image Preparations

In [3]:
imgs = load_images(1, 2)
imgs = [prep(img) for img in imgs]
imgs = [Variable(img.unsqueeze(0).to(device)) for img in imgs]
img_con, img_sty = imgs
opt_img = Variable(img_con.data.clone(), requires_grad=True)

## Model Preparations

In [4]:
print('Optimizing from content image... Using pre-trained model VGG19_bn')
model = models.vgg19_bn(pretrained=True).to(device)
for param in model.parameters():
    param.requires_grad = False

Optimizing from content image... Using pre-trained model VGG19_bn




## Feature Extraction

In [5]:
class FeatureSaver(nn.Module):
    feature = None
    def __init__(self, layer):
        self.hook = layer.register_forward_hook(self.hook_func)
    def hook_func(self, module, input, output):
        self.feature = output
    def close(self):
        self.hook.remove()

In [6]:
content_feature_savers = [FeatureSaver(model.features[layer]) for layer in content_layers]
model(Variable(img_con))
content_features = [saver.feature.clone() for saver in content_feature_savers]
print(f'Saved content features from layer {content_layers[0]} of the model')
style_feature_savers  = [FeatureSaver(model.features[layer]) for layer in style_layers]
model(Variable(img_sty))
style_features = [saver.feature.clone() for saver in style_feature_savers]
print(f'Saved style features from layer {layer} of the model' for layer in style_layers)

Saved content features from layer 37 of the model
<generator object <genexpr> at 0x16282d0e0>


## Optimizer Setup

In [7]:
optimizer = optim.LBFGS([opt_img])

def closure():
    global i
    model(opt_img)
    gen_content_feats = [saver.feature.clone() for saver in content_feature_savers]
    gen_style_feats = [saver.feature.clone() for saver in style_feature_savers]

    contentloss = WEIGHT_CONTENT * content_loss(gen_content_feats, content_features)
    styleloss = style_loss(gen_style_feats, style_features, WEIGHTS_STYLE)
    loss = contentloss + styleloss

    optimizer.zero_grad()
    loss.backward()

    if i % show_iter == 0:
        print(f"Epoch: {i}, Content loss: {contentloss}, Style loss: {styleloss}, Total loss: {loss}")
    i += 1
    return loss

In [8]:
start_time = time.time()
i = 0
print('Start Training...')
while i < max_iter:
    optimizer.step(closure)

end_time = time.time()
print(f"Training completed in {end_time - start_time:.2f} seconds.")

out_img = post(opt_img.data[0].cpu().squeeze())
out_img.save('Result/kanagawa_neckarfront.png', format='png')

Start Training...
Epoch: 0, Content loss: 0.0, Style loss: 188437.03125, Total loss: 188437.03125
Epoch: 50, Content loss: 0.0058092097751796246, Style loss: 2942.35205078125, Total loss: 2942.35791015625
Epoch: 100, Content loss: 0.006294267252087593, Style loss: 638.6104125976562, Total loss: 638.61669921875
Epoch: 150, Content loss: 0.006473978981375694, Style loss: 263.77642822265625, Total loss: 263.78289794921875
Epoch: 200, Content loss: 0.006575554143637419, Style loss: 141.4736328125, Total loss: 141.48020935058594
Epoch: 250, Content loss: 0.006645968183875084, Style loss: 93.17888641357422, Total loss: 93.18553161621094
Epoch: 300, Content loss: 0.006698772311210632, Style loss: 66.3231201171875, Total loss: 66.32981872558594
Epoch: 350, Content loss: 0.006747280713170767, Style loss: 51.66287612915039, Total loss: 51.66962432861328
Epoch: 400, Content loss: 0.0067830318585038185, Style loss: 42.37866973876953, Total loss: 42.38545227050781
Epoch: 450, Content loss: 0.006819