In [1]:
import torch
import data, model, loss
import numpy as np
from torch.utils.data import DataLoader
import torch.utils.tensorboard as tb
import torchvision

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# --- Data ---------------------------------------------------------------
VAL_PORTION = 0.2      # fraction of each dataset that goes to the validation split
RESOLUTION = 32        # forwarded as `resolution` to the dataset loaders
BATCH_SIZE = 32        # batch size shared by every DataLoader

# --- Schedule -----------------------------------------------------------
ITERATIONS = 5000      # upper bound on training iterations
VAL_ITERATIONS = 10    # upper bound on batches per validation pass

# --- Model --------------------------------------------------------------
# Presumably the sizes of the content and style codes -- confirm against `model`.
CONTENT_DIM = 512
STYLE_DIM = 512

# --- Loss ---------------------------------------------------------------
# Weights handed to loss.perceptual_loss; the key is presumably a feature
# layer name of the loss network -- confirm against `loss`.
CONTENT_LOSS_WEIGHTS = dict(relu_4_2=1.0)

In [13]:
# Set to True to train on the small debug datasets instead of the full ones.
# Previously the debug datasets were always loaded and split first and then
# every resulting name was overwritten by the full datasets, so the default
# (False) reproduces the state the notebook actually ended up in.
USE_DEBUG_DATASET = False


def _split_with_loaders(dataset):
    """Randomly split ``dataset`` into train/validation parts and build a DataLoader for each.

    ``int(VAL_PORTION * len(dataset))`` samples go to the validation part and
    the remainder to the training part. Both loaders use ``BATCH_SIZE``,
    shuffle, and drop the last incomplete batch.

    Returns:
        ``(train_part, val_part, train_loader, val_loader)``
    """
    num_val = int(VAL_PORTION * len(dataset))
    train_part, val_part = torch.utils.data.random_split(dataset, [len(dataset) - num_val, num_val])
    train_loader = DataLoader(train_part, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_part, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    return train_part, val_part, train_loader, val_loader


if USE_DEBUG_DATASET:
    data_style = data.load_debug_dataset('../dataset/debug/style', resolution=RESOLUTION, number_instances=200)
    data_content = data.load_debug_dataset('../dataset/debug/content', resolution=RESOLUTION, number_instances=200)
else:
    data_style = data.load_dataset('../dataset/style', resolution=RESOLUTION)
    data_content = data.load_dataset('../dataset/content', resolution=RESOLUTION)

# Style is split before content, matching the original order of the
# random_split calls (they draw from the same global RNG).
data_style_train, data_style_val, data_loader_style_train, data_loader_style_val = _split_with_loaders(data_style)
data_content_train, data_content_val, data_loader_content_train, data_loader_content_val = _split_with_loaders(data_content)

# Pair the content loader with the style loader for each split.
data_loader_train = data.DatasetPairIterator(data_loader_content_train, data_loader_style_train)
data_loader_val = data.DatasetPairIterator(data_loader_content_val, data_loader_style_val)

In [18]:
# Autoencoder under training.
# NOTE(review): the literal 512 presumably corresponds to CONTENT_DIM -- confirm
# against model.Encoder / model.Decoder before replacing it with the constant.
content_encoder = model.Encoder(512, normalization=True, residual=True, num_down_convolutions=4)
# The decoder's output size follows RESOLUTION so that it cannot drift away
# from the resolution the datasets are loaded with.
decoder = model.Decoder(512, None, (RESOLUTION, RESOLUTION), out_channels=3, residual=True, normalization='in', num_up_convolutions=4)

# Alternative architectures that are currently disabled:
#content_encoder = model.ResNetEncoder(CONTENT_DIM, architecture=torchvision.models.resnet18, pretrained=True)
#content_encoder = model.VGGEncoder((3, RESOLUTION, RESOLUTION), flattened_output_dim=CONTENT_DIM, architecture=torchvision.models.vgg19, n_layers=19)
#decoder = model.Decoder(None, None, (RESOLUTION, RESOLUTION))
#decoder = model.VGGDecoder((3, RESOLUTION, RESOLUTION), None, CONTENT_DIM, architecture=torchvision.models.vgg19, n_layers=19)

# Network whose outputs are compared by the perceptual loss; switched to eval
# mode here. The assignment to `_` only suppresses the cell's echo of the module.
loss_net = loss.LossNet()
_ = loss_net.eval()
# Pixel-wise MSE; only referenced by a commented-out alternative loss in the
# training cell.
criterion = torch.nn.MSELoss()

In [19]:
def _count_trainable(network):
    """Return the total number of elements over all parameters of ``network`` with ``requires_grad`` set."""
    return sum(p.numel() for p in network.parameters() if p.requires_grad)


# Report the trainable size of each half of the autoencoder.
print(f'Content encoder has {_count_trainable(content_encoder)} parameters.')
print(f'Decoder has {_count_trainable(decoder)} parameters.')

Content encoder has 6500928 parameters.
Decoder has 4926234 parameters.


In [20]:
# Move all three networks to the GPU when one is present.
if torch.cuda.is_available():
    content_encoder, decoder, loss_net = content_encoder.cuda(), decoder.cuda(), loss_net.cuda()

# Parameters handed to the optimizer: encoder first, then decoder.
# loss_net is deliberately not part of this list.
trainable_parameters = [*content_encoder.parameters(), *decoder.parameters()]

# Disabled experiment: an extra linear projection on the flattened encoder output.
#linear = torch.nn.Sequential(
#    torch.nn.Linear(Cd * Hd * Wd, CONTENT_DIM),
#).cuda()
#for parameter in linear.parameters():
#    trainable_parameters.append(parameter)

In [21]:
# Adam over the encoder and decoder parameters collected above.
optimizer = torch.optim.Adam(params=trainable_parameters, lr=1e-4)

In [22]:
# TensorBoard event files for this run are written below log/autoencoder.
tb_writer = tb.SummaryWriter(log_dir='log/autoencoder')

In [23]:
def _to_device(batch):
    """Return ``batch`` moved to the GPU when CUDA is available, otherwise ``batch`` itself."""
    return batch.to('cuda') if torch.cuda.is_available() else batch


def _reconstruct(image_batch):
    """Run ``image_batch`` through the autoencoder and score the reconstruction.

    Returns:
        ``(z, decoded, loss_value)`` where ``z`` is the encoder output,
        ``decoded`` is the decoder output for ``(z, None)`` and ``loss_value``
        is ``loss.perceptual_loss`` evaluated on the ``loss_net`` outputs for
        the input and for the reconstruction.
    """
    z = content_encoder(image_batch)
    #z = linear(z.view(-1, Cd * Hd * Wd))
    decoded = decoder(z, None)
    # The features of the input are a fixed target: only the encoder and the
    # decoder are optimised, so this forward pass needs no autograd graph.
    with torch.no_grad():
        features_content = loss_net(image_batch)
    features_decoded = loss_net(decoded)
    return z, decoded, loss.perceptual_loss(features_content, features_decoded, CONTENT_LOSS_WEIGHTS)
    #total_loss = criterion(content_image, decoded)


def _stack_for_logging(*batches):
    """Pass every batch through ``data.vgg_normalization_undo`` and concatenate the results along axis 0."""
    return torch.from_numpy(np.concatenate([
        data.vgg_normalization_undo(batch.detach().cpu().numpy()) for batch in batches
    ]))


def _report(step, total, loss_value, z):
    """Overwrite the current console line with the step counter, loss and latent statistics."""
    z = z.detach().cpu().numpy()
    print(f'\r{step:5d} / {total}: loss : {loss_value:.4f} -- z_mean : {z.mean():.5f} -- z_std : {z.std(axis=0).mean():.5f}', end='\r')


# Anomaly detection is a debugging aid and slows training down; it is left
# enabled here as before. Disable it once the run is stable.
with torch.autograd.set_detect_anomaly(True):
    val_step = 0  # global step for validation logs, counted across all validation passes
    # range() comes first in zip so that no extra batch is fetched once the
    # iteration budget is used up.
    for iteration, ((content_image, _content_path), _unused) in zip(range(ITERATIONS), data_loader_train):
        content_image = _to_device(content_image)

        content_encoder.train()
        decoder.train()
        optimizer.zero_grad()

        z, decoded, total_loss = _reconstruct(content_image)
        total_loss.backward()
        optimizer.step()

        train_loss_value = total_loss.item()
        tb_writer.add_scalar('train loss', train_loss_value, iteration)
        _report(iteration, ITERATIONS, train_loss_value, z)

        if iteration % 100 != 0:
            continue

        tb_writer.add_images('train images', _stack_for_logging(content_image, decoded), iteration)

        # Validate
        print('\nValidation...')
        content_encoder.eval()
        decoder.eval()
        with torch.no_grad():
            # Separate val_* names keep the training batch and loss from being overwritten.
            for val_iteration, ((val_image, _val_path), _unused) in zip(range(VAL_ITERATIONS), data_loader_val):
                val_image = _to_device(val_image)

                val_z, val_decoded, val_loss = _reconstruct(val_image)
                val_loss_value = val_loss.item()
                _report(val_iteration, VAL_ITERATIONS, val_loss_value, val_z)

                tb_writer.add_scalar('validation loss', val_loss_value, val_step)
                tb_writer.add_images('validation images', _stack_for_logging(val_image, val_decoded), val_step)
                val_step += 1
        # Printed after every validation pass, even when the validation loader
        # runs out before VAL_ITERATIONS. '\n' replaces the former invalid
        # escape '\T', which printed a literal backslash on the progress line.
        print('\nTraining...')


    0 / 5000: loss : 36.5201 -- z_mean : 0.06726 -- z_std : 0.14791
Validation...
\Training...loss : 34.0379 -- z_mean : 0.07034 -- z_std : 0.14780
  100 / 5000: loss : 18.9468 -- z_mean : 0.07868 -- z_std : 0.14941
Validation...
\Training...loss : 24.7783 -- z_mean : 0.07850 -- z_std : 0.14218
  200 / 5000: loss : 14.2019 -- z_mean : 0.06630 -- z_std : 0.16028
Validation...
\Training...loss : 24.1431 -- z_mean : 0.06696 -- z_std : 0.14383
  300 / 5000: loss : 10.1863 -- z_mean : 0.05668 -- z_std : 0.17147
Validation...
\Training...loss : 27.2567 -- z_mean : 0.05536 -- z_std : 0.14547
  400 / 5000: loss : 7.3516 -- z_mean : 0.04749 -- z_std : 0.177941
Validation...
\Training...loss : 26.4079 -- z_mean : 0.04538 -- z_std : 0.14708
  500 / 5000: loss : 5.6895 -- z_mean : 0.03992 -- z_std : 0.18639
Validation...
\Training...loss : 27.1184 -- z_mean : 0.03935 -- z_std : 0.14811
  600 / 5000: loss : 4.7365 -- z_mean : 0.04106 -- z_std : 0.18961
Validation...
\Training...loss : 27.8698 -- z_

KeyboardInterrupt: 

In [None]:
# Show the decoder's string representation for inspection.
print(decoder)

In [None]:
# Show the content encoder's string representation for inspection.
print(content_encoder)