In [1]:
import numpy as np
import torch
import os
import time
import pickle

from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

from utils import *
from networks import *

# Use the first GPU when available; fall back to CPU so the notebook still runs without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
COCO_PATH = './coco'                    # ImageFolder root holding the training images
STYLE_PATH = './style_imgs/rain.jpg'    # style image the network learns to reproduce

FEATURENET_TYPE = Vgg16                 # fixed network providing perceptual feature maps
CONTENT_FMID = 3                        # index of the feature map used for the content loss

IMAGE_SIZE = 256                        # training images are resized + center-cropped to this square size
BATCH_SIZE = 1
LEARNING_RATE = 1e-3
NUM_EPOCHS = 1
STYLE_WEIGHT = 1e5                      # multiplier on the Gram-matrix (style) loss
CONTENT_WEIGHT = 1e0                    # multiplier on the feature-map (content) loss
TV_WEIGHT = 1e-7                        # multiplier on the total-variation smoothness loss

# Seed the RNGs used by the shuffled DataLoader and by network initialization
# so that re-running the notebook reproduces the same training run.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Preprocessing shared by training, validation and inference images:
# scale the shortest side to IMAGE_SIZE, keep the centered IMAGE_SIZE x IMAGE_SIZE
# square, convert to a [0, 1] tensor, then normalize with ImageNet values.
dataset_transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        imagenet_normalize(),
    ]
)

# Training set: every image under COCO_PATH (ImageFolder layout), served in shuffled batches.
train_dataset = datasets.ImageFolder(COCO_PATH, transform=dataset_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [4]:
# Trainable image-transformation network, its optimizer, and the loss used for
# both the style (Gram) and content (feature) terms.
img_net = ImageTransformNet().to(device)
optimizer = Adam(img_net.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

# Fixed feature extractor used only to compute perceptual losses. Its parameters
# are not in `optimizer`, so freeze them explicitly: otherwise any gradient that
# reaches them would never be zeroed and would keep accumulating. Gradients still
# flow *through* the network to `img_net`'s output, since only parameters are frozen.
# (networks.Vgg16 may already do this itself; these calls are then no-ops.)
feature_net = FEATURENET_TYPE().to(device)
feature_net.eval()
feature_net.requires_grad_(False)




In [5]:
def _total_variation(img):
    """Anisotropic total variation of a 4-D image batch.

    Sums the absolute differences between horizontally adjacent pixels (last dim)
    and vertically adjacent pixels (second-to-last dim).
    """
    diff_i = torch.sum(torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1]))
    diff_j = torch.sum(torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :]))
    return diff_i + diff_j


def train(img_net, feature_net, STYLE_PATH, VAL_PATH=None):
    """Train `img_net` to render images in the style of the image at STYLE_PATH.

    Uses the notebook globals `train_loader`, `train_dataset`, `dataset_transform`,
    `optimizer`, `loss_fn`, `device`, the helpers `load_image` / `gram` /
    `save_image` / `imagenet_normalize`, and the hyper-parameters from the config cell.

    Args:
        img_net: image-transformation network being optimized (by global `optimizer`).
        feature_net: fixed feature extractor; called on a batch it returns a
            sequence of feature maps, indexed by position (e.g. CONTENT_FMID).
        STYLE_PATH: path of the style image. Its file stem names the checkpoint
            (``./weight/<stem>.pt``) and the validation images.
        VAL_PATH: optional content image, stylized and saved to ``./val/``
            every 10 batches.

    Returns:
        (style_losses, content_losses, tv_losses): lists holding the weighted
        loss of every training batch, in order.
    """
    print('training')

    style_transform = transforms.Compose([
        transforms.ToTensor(),                  # turn image from [0-255] to [0-1]
        imagenet_normalize()                    # normalize with ImageNet values
    ])
    style_img = load_image(STYLE_PATH)
    style_tensor = style_transform(style_img).repeat(BATCH_SIZE, 1, 1, 1).to(device)
    style_name = os.path.splitext(os.path.basename(STYLE_PATH))[0]
    print(f'style: {style_name}')
    # The style targets are constants: compute them without an autograd graph so
    # nothing is ever back-propagated through them (twice) across training steps.
    with torch.no_grad():
        style_gram = [gram(fmap) for fmap in feature_net(style_tensor)]

    if VAL_PATH:
        val_img = load_image(VAL_PATH)
        val_tensor = dataset_transform(val_img).unsqueeze(0).to(device)
        os.makedirs('./val', exist_ok=True)
    else:
        val_tensor = None
    os.makedirs('./weight', exist_ok=True)

    style_loss_list = []
    content_loss_list = []
    tv_loss_list = []

    min_loss = float('inf')
    num_images = len(train_dataset)

    for epoch_id in range(NUM_EPOCHS):
        print(f'epoch {epoch_id}')

        img_cnt = 0
        s_style_loss = 0.0
        s_content_loss = 0.0
        s_tv_loss = 0.0

        for batch_id, (x, _) in enumerate(train_loader):
            batch_cnt = len(x)
            img_cnt += batch_cnt

            img_net.train()
            optimizer.zero_grad()

            x = x.to(device)
            out = img_net(x)
            # Content targets come from the input image and need no gradient.
            with torch.no_grad():
                in_features = feature_net(x)
            out_features = feature_net(out)

            # Style: match Gram matrices of every feature map. Slicing the targets
            # handles a final batch smaller than BATCH_SIZE.
            b_style_loss = STYLE_WEIGHT * sum(
                loss_fn(gram(out_fmap), target_gram[:batch_cnt])
                for out_fmap, target_gram in zip(out_features, style_gram)
            )
            b_content_loss = CONTENT_WEIGHT * loss_fn(out_features[CONTENT_FMID], in_features[CONTENT_FMID])
            b_tv_loss = TV_WEIGHT * _total_variation(out)

            style_val = b_style_loss.item()
            content_val = b_content_loss.item()
            tv_val = b_tv_loss.item()
            s_style_loss += style_val
            s_content_loss += content_val
            s_tv_loss += tv_val
            style_loss_list.append(style_val)
            content_loss_list.append(content_val)
            tv_loss_list.append(tv_val)

            total_loss = b_style_loss + b_content_loss + b_tv_loss
            total_val = total_loss.item()
            # NOTE(review): "best" is judged on a single (noisy) batch loss; the
            # weights saved are the ones that produced it, i.e. before this step.
            if total_val < min_loss:
                min_loss = total_val
                torch.save(img_net.state_dict(), f'./weight/{style_name}.pt')
            total_loss.backward()
            optimizer.step()

            if (batch_id + 1) % 10 == 0:
                n_batches = batch_id + 1.0
                print(f'epoch: [{epoch_id + 1}]/[{NUM_EPOCHS}]; batch: [{img_cnt}]/[{num_images}]')
                print(f'style loss:   {style_val:6f}, avg = {s_style_loss / n_batches:6f}')
                print(f'content loss: {content_val:6f}, avg = {s_content_loss / n_batches:6f}')
                print(f'tv loss:      {tv_val:6f}, avg = {s_tv_loss / n_batches:6f}')

                if val_tensor is not None:
                    img_net.eval()
                    with torch.no_grad():
                        val_out = img_net(val_tensor).cpu()[0]
                    save_image(f'./val/{style_name}_{epoch_id:02d}_{img_cnt:05d}.jpg', val_out)

    return style_loss_list, content_loss_list, tv_loss_list


In [6]:
# Train against the style image configured in STYLE_PATH, validating on one content image.
style_loss, content_loss, tv_loss = train(img_net, feature_net, STYLE_PATH, VAL_PATH='./content_imgs/cute.jpg')


training
style: rain
epoch 0
epoch: [1]/[1]; batch: [10]/[5000]
style loss:   54.397243, avg = 160.240894
content loss: 3.970539, avg = 4.042556
tv loss:      0.001860, avg = 0.002452
epoch: [1]/[1]; batch: [20]/[5000]
style loss:   39.811367, avg = 103.062949
content loss: 3.244239, avg = 4.043823
tv loss:      0.001777, avg = 0.002095
epoch: [1]/[1]; batch: [30]/[5000]
style loss:   46.341198, avg = 83.148483
content loss: 3.112395, avg = 3.871325
tv loss:      0.001240, avg = 0.001985
epoch: [1]/[1]; batch: [40]/[5000]
style loss:   38.701042, avg = 72.497311
content loss: 6.773934, avg = 4.032267
tv loss:      0.002073, avg = 0.001920
epoch: [1]/[1]; batch: [50]/[5000]
style loss:   35.608440, avg = 65.289855
content loss: 2.077895, avg = 4.070799
tv loss:      0.001973, avg = 0.001935
epoch: [1]/[1]; batch: [60]/[5000]
style loss:   34.129147, avg = 60.739168
content loss: 7.137985, avg = 4.065157
tv loss:      0.002565, avg = 0.001982
epoch: [1]/[1]; batch: [70]/[5000]
style loss

In [7]:
def dump_loss(name, sl, cl, tl, out_dir='pickle_objs'):
    """Pickle the per-batch style, content and TV loss histories of one run.

    Writes ``<out_dir>/<name>_style_loss``, ``<out_dir>/<name>_content_loss`` and
    ``<out_dir>/<name>_tv_loss``, creating `out_dir` if it does not exist.

    Args:
        name: prefix for the three output files (e.g. the style name).
        sl: style-loss history.
        cl: content-loss history.
        tl: total-variation-loss history.
        out_dir: directory the files are written to.
    """
    os.makedirs(out_dir, exist_ok=True)
    for suffix, losses in (('style_loss', sl), ('content_loss', cl), ('tv_loss', tl)):
        with open(os.path.join(out_dir, f'{name}_{suffix}'), 'wb') as f:
            pickle.dump(losses, f)

In [8]:
# Name the saved loss files after the style image actually trained on (STYLE_PATH's file stem),
# so they cannot be mislabeled as a different style.
dump_loss(os.path.splitext(os.path.basename(STYLE_PATH))[0], style_loss, content_loss, tv_loss)

In [9]:
def transfer(in_path, out_path):
    """Stylize the image at `in_path` with the global `img_net` and save it to `out_path`.

    The input goes through `dataset_transform` (resize + center crop to
    IMAGE_SIZE, then normalization), so the result is a square crop of the input.
    Leaves `img_net` in eval mode.
    """
    img_net.eval()                      # inference mode (matters if the net has dropout/batch-norm)
    test_img = load_image(in_path)
    test_tensor = dataset_transform(test_img).unsqueeze(0).to(device)   # add batch dim
    with torch.no_grad():               # no autograd graph needed for inference
        out_tensor = img_net(test_tensor).cpu()[0]
    save_image(out_path, out_tensor)

In [10]:
# Example: stylize a single content image with the trained network (uncomment to run).
# transfer('./content_imgs/flower.jpg', 'test.jpg')