In [1]:
import torchvision
from torchvision import transforms, datasets
import models
from torch import optim
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.distributions.multivariate_normal import MultivariateNormal
import numpy as np
from dataloader import MNISTIndexed
from losses import LossG, NaiveLoss
import yaml
import random

from matplotlib import pyplot as plt



In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATASET_SIZE = 60000
IMAGES_TO_USE = 1600
CONTENT_CODE_LEN = 20
BATCH_SIZE = 16

In [4]:
with open("./conf.yaml", "r") as f:
    cfg = yaml.safe_load(f)
seed = cfg['seed']
if seed == -1:
    seed = np.random.randint(2 ** 32 - 1, dtype=np.int64)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
print(f'running with seed: {seed}.')

running with seed: 1591025859.


since Python 3.9 and will be removed in a subsequent version. The only 
supported seed types are: None, int, float, str, bytes, and bytearray.
  random.seed(seed)


In [5]:
data_path = "./datasets/MNIST/"

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                              ])

# Download and load the data
mnist_data = MNISTIndexed(data_path, download=True, train=True, transform=transform)
used_indices = np.random.randint(0, DATASET_SIZE, IMAGES_TO_USE)
mnist_subsample = torch.utils.data.Subset(mnist_data, used_indices)
mnist_dataloader = torch.utils.data.DataLoader(mnist_subsample, batch_size=BATCH_SIZE, shuffle=True)

In [6]:
criterion = NaiveLoss(cfg)

In [7]:
model = models.GeneratorBasic(CONTENT_CODE_LEN, 4, 10, (BATCH_SIZE, 1, 28, 28))

In [8]:
def train_model(model, tboard_name, loss_func, train_loader, epochs=50, lr=1e-3, noise_std=0.5, reg_factor=1e-6):
    writer = SummaryWriter(log_dir='logs/' + tboard_name)
    optimizer = optim.Adam(model.parameters(), lr=lr)
 
    # prepare the data
    # TODO play with initizalization and refactor?
    class_codes = torch.normal(0.5, noise_std, (10, 10)).to(device)
    content_codes = torch.normal(0.5, noise_std, (DATASET_SIZE, CONTENT_CODE_LEN)).to(device)
    
    # set up some variables for the visualizations
    display_contents = used_indices[:4]
    display_classes = [0, 1, 2, 3]

    for epoch in range(epochs):
        model.train()

        losses = []
        for data_row in train_loader:
            # get the inputs; data is a list of [inputs, labels]
            images, labels, indices = data_row
            images = images.to(device)

            # create input for network
            cur_content, cur_class = content_codes[indices], class_codes[labels]
            cur_content.requires_grad_(True)
            cur_class.requires_grad_(True)
            noisy_code = cur_content + torch.rand(CONTENT_CODE_LEN) * noise_std
            inputs = torch.cat((cur_class, noisy_code), 1)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            
            loss = loss_func(torch.cat([outputs, outputs, outputs], dim=1), torch.cat([images, images, images], dim=1), cur_content)
            loss.backward()
            optimizer.step()

            # statistics
            losses.append(loss.item())
        
        model.eval()
        inputs = []
        for disp_classes in display_classes:
            for disp_contents in display_contents:
                inputs.append(torch.cat((class_codes[disp_classes], content_codes[disp_contents])).unsqueeze(0))
        outputs = model(torch.cat(inputs, 0))
        img_grid = torchvision.utils.make_grid(outputs,nrow=4)
        writer.add_image('Epoch ' + str(epoch+1),img_grid)

        writer.add_scalar('loss', np.mean(losses), epoch)
        print("Epoch: {}, loss: {}\n".format(epoch, np.mean(losses)))

    writer.close()

In [9]:
train_model(model, "checkNaiveGen4", criterion, mnist_dataloader, epochs=50, lr=1e-3, noise_std=0.3, reg_factor=1e-6)

Epoch: 0, loss: 0.37437031596899034

Epoch: 1, loss: 0.30358454495668413

Epoch: 2, loss: 0.28559818655252456

Epoch: 3, loss: 0.27821382507681847

Epoch: 4, loss: 0.2758671416342258

Epoch: 5, loss: 0.2744828520715237

Epoch: 6, loss: 0.2737030744552612

Epoch: 7, loss: 0.27254353150725363

Epoch: 8, loss: 0.27210637167096136

Epoch: 9, loss: 0.27127748414874076

Epoch: 10, loss: 0.2702812401950359

Epoch: 11, loss: 0.26972040578722956

Epoch: 12, loss: 0.2687340289354324

Epoch: 13, loss: 0.2681415903568268

Epoch: 14, loss: 0.26724010914564134

Epoch: 15, loss: 0.26597691252827643

Epoch: 16, loss: 0.26532404124736786

Epoch: 17, loss: 0.2648063203692436

Epoch: 18, loss: 0.26331633076071737

Epoch: 19, loss: 0.26277846455574033

Epoch: 20, loss: 0.2606666797399521

Epoch: 21, loss: 0.26011467203497884

Epoch: 22, loss: 0.25858143880963325

Epoch: 23, loss: 0.25717959985136984

Epoch: 24, loss: 0.2569903887808323

Epoch: 25, loss: 0.2552080246806145

Epoch: 26, loss: 0.2542532651126