In [51]:
import numpy as np
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.onnx
import matplotlib.pyplot as plt

In [52]:
import utils
from transformer_net import TransformerNet
from vgg import Vgg16

In [53]:
# PARAMS

SEED = 128
IMAGE_SIZE = 256
LR = 0.001
EPOCHS = 2
CONTENT_WEIGHT = 1e3
STYLE_WEIGHT = 1e10
BATCH_SIZE = 4
MODEL_PATH = 'data/style1/model'
DATASET_PATH = 'data/style1/dataset'
STYLE_IMAGE = 'data/style1/style1.jpg'

In [54]:
def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(DATASET_PATH, transform)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
    
    # TransformerNet
    transformer = TransformerNet().to(device)
    optimizer = torch.optim.Adam(transformer.parameters(), LR)
    mse_loss = torch.nn.MSELoss()

    # VGG
    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(STYLE_IMAGE)
    style = style_transform(style)
    style = style.repeat(BATCH_SIZE, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]


    # Training
    for epoch in range(EPOCHS):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            x = utils.normalize_batch(x)
            y = utils.normalize_batch(y)

            features_x = vgg(x)
            features_y = vgg(y)

            content_loss = CONTENT_WEIGHT * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for feat_y, gram_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(feat_y)
                style_loss += mse_loss(gram_y, gram_s[:n_batch, :, :])
            style_loss *= STYLE_WEIGHT

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % LOG_INTERVAL == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), epoch + 1, count, len(loader.dataset),
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)
            
    
    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(EPOCHS) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        STYLE_IMAGE.split('/')[-1].split('.')[0]) + "_" + str(CONTENT_WEIGHT) + "_" + str(STYLE_WEIGHT) + ".model"
    save_model_path = os.path.join(MODEL_PATH, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)



In [55]:
content_image = 'data/style1/original_images/person_0000.jpg'
model = 'data/style1/model/*.model'

In [56]:
def stylize():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(model)
        
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        style_model = torch.nn.DataParallel(style_model)
        output = style_model(content_image).cpu()
    display(output)


    

In [57]:
train()



OutOfMemoryError: CUDA out of memory. Tried to allocate 724.00 MiB (GPU 0; 3.81 GiB total capacity; 1.62 GiB already allocated; 629.62 MiB free; 2.24 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF