In [None]:
if 'google.colab' in str(get_ipython()):
    print('Running on CoLab')
    !git clone https://github.com/dmhd1/mcproject.git
    %cd mcproject/
    %mkdir data/
    !pip install torch torchvision
else:
    print('Not running on CoLab')

In [2]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
import yaml
import matplotlib.pyplot as plt
import numpy as np
from lib.models import vae as vae

In [3]:
def matplotlib_imshow(img, one_channel=False, unnormalize=True):
    """Draw a channels-first image tensor on the current matplotlib axes.

    Args:
        img: image tensor laid out as (C, H, W), e.g. the output of
            ``torchvision.utils.make_grid``. It may live on any device and may
            require grad; it is detached and copied to the CPU before plotting.
        one_channel: if True, average over the channel axis and draw the
            resulting 2-D array with the ``"Greys"`` colormap.
        unnormalize: if True (the default, matching the original behavior),
            apply ``img / 2 + 0.5`` to undo a ``Normalize(0.5, 0.5)`` transform.
            Pass False for data loaded with ``ToTensor()`` alone. Such data is
            already in [0, 1], and the shift would squash it into [0.5, 1].
    """
    # .numpy() only works on CPU tensors that do not track gradients.
    img = img.detach().cpu()
    if one_channel:
        img = img.mean(dim=0)
    if unnormalize:
        img = img / 2 + 0.5
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        # matplotlib expects (H, W, C) for multi-channel images.
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

In [4]:
class _MissingMemberError(KeyError, AttributeError):
    """Raised when a missing key is accessed or deleted as an attribute.

    It subclasses KeyError, which the original helper raised, so existing
    `except KeyError` handlers keep working. It also subclasses
    AttributeError, which `hasattr`, `getattr(obj, name, default)` and
    `copy.deepcopy` rely on when they probe for optional attributes.
    """


class DictAsMember(dict):
    """A dict whose keys can also be read, set and deleted as attributes.

    Nested plain dicts are wrapped in DictAsMember on first attribute access
    and stored back. As a result, `cfg.section.key = value` writes into the
    underlying mapping rather than into a throw-away copy.
    """

    def __getattr__(self, name):
        # Only reached when normal lookup fails, so dict methods are unaffected.
        try:
            value = self[name]
        except KeyError:
            raise _MissingMemberError(name) from None
        if isinstance(value, dict) and not isinstance(value, DictAsMember):
            value = DictAsMember(value)
            self[name] = value
        return value

    def __setattr__(self, name, value):
        # Store attributes as keys so `obj.k = v` and `obj['k']` stay in sync.
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise _MissingMemberError(name) from None

In [5]:
options = """
batch_size: 128
epochs: 10
cuda: False
seed: 1
log_interval: 100
result_folder: './runs'
data_folder: '../data'
"""

In [6]:
# Parse the YAML options into an attribute-accessible dict.
args = DictAsMember(yaml.safe_load(options))
# Use CUDA only if it was requested AND a device is actually available.
# Write through the mapping rather than `args.cuda = ...`, so that `args.cuda`
# and `args['cuda']` agree for any code that reads either form.
args['cuda'] = args.cuda and torch.cuda.is_available()

In [7]:
# TensorBoard event files are written under <result_folder>/tb.
tb_log_dir = f'{args.result_folder}/tb'
writer = SummaryWriter(log_dir=tb_log_dir)

In [8]:
# Seed PyTorch's random number generators so the run is repeatable.
# The trailing `;` suppresses display of the returned Generator object.
torch.manual_seed(seed=args.seed);

Set up the compute device and the extra `DataLoader` options used when CUDA is enabled and available:

In [9]:
# Pick the compute device. With CUDA enabled, the DataLoaders also use one
# worker process and pinned host memory for faster host-to-GPU copies.
if args.cuda:
    device = torch.device("cuda:0")
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device('cpu')
    kwargs = {}

Load the MNIST training and test sets:

In [10]:
# MNIST as float tensors via ToTensor(); no Normalize transform is applied.
# download=True on both splits is a no-op once the files exist, and it keeps
# this cell working against an empty data folder.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(args.data_folder, train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)
# The test split is only evaluated, so it is kept in a fixed order. Shuffling
# gains nothing there and makes per-batch test outputs vary between epochs.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(args.data_folder, train=False, download=True,
                   transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=False, **kwargs)

NameError: name 'data_folder' is not defined

In [None]:
# Grab one training batch, preview it, and log it to TensorBoard.
# `images` is reused below as the example input for the graph trace.
dataiter = iter(train_loader)
# Use the builtin next(): DataLoader iterators no longer provide a `.next()`
# method in recent PyTorch releases.
images, labels = next(dataiter)
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
writer.add_image('mnist_images', img_grid)

In [None]:
# Build the VAE on the selected device, with the project's loss function and
# an Adam optimizer (lr=1e-3) over the model parameters.
loss_function = vae.loss_function
model = vae.VAE()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Log the model graph to TensorBoard, traced with the preview batch.
# Note: trace checking cannot be switched off by assigning
# `torch.jit.check_trace`. `check_trace` is a keyword argument of
# torch.jit.trace, not a module-level flag, so such an assignment has no
# effect on add_graph.
writer.add_graph(model, images.to(device))

In [None]:
if 'google.colab' in str(get_ipython()):
    %load_ext
    !tensorboard --logdir=./runs/
else:
    print('Run tensorboard on your local machine.')

### Training

In [None]:
# Train for args.epochs epochs (numbered from 1). After each epoch, evaluate
# on the test split and save a grid of images decoded from random latents.
for epoch in range(1, args.epochs + 1):
    vae.train(model, epoch, optimizer, train_loader, loss_function, device, args, writer=writer)
    vae.test(model, epoch, test_loader, loss_function, device, args, writer=writer)
    sample_path = f'{args.result_folder}/{epoch}.png'
    with torch.no_grad():
        # 64 standard-normal draws in a 20-d latent space.
        # NOTE(review): 20 is assumed to match vae.VAE's latent size; confirm.
        latent = torch.randn(64, 20)
        sample = model.decode(latent.to(device)).cpu()
        # Assumes decode yields 64 * 784 values, viewable as 1x28x28 images.
        save_image(sample.view(64, 1, 28, 28), sample_path)