In [2]:
import yaml
import torch
import os

from torch.utils.data import DataLoader
from torch import optim

from torchvision import transforms
from torchvision.datasets import CelebA
from torchvision.utils import save_image

# project modules
from vae import VAE


In [25]:
model = VAE(64, 1)

In [14]:
from torch.utils.tensorboard import SummaryWriter


In [32]:
writer_test = SummaryWriter()
writer_train = SummaryWriter()

x = torch.arange(-5, 5, 0.1).view(-1, 1)
y = -5 * x + 0.1 * torch.randn(x.size())

model = torch.nn.Linear(1, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1)


for epoch in range(20):
    y1 = model(x)
    loss = criterion(y1, y)
    writer_train.add_scalar("Loss", loss, epoch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch % 2 == 0:
        writer_test.add_scalar("Loss", loss*2, epoch)


writer_train.flush()
writer_test.flush()



In [22]:
writer_train.close()
writer_test.close()

In [34]:
torch.save(model, writer_test.log_dir + '\\vae_model.pth')

In [28]:
writer_test.log_dir.replace('\\', '/')

'runs/Jan05_16-21-37_larschen'

In [7]:
print(model)

VAE(
  (prior_nn): Sequential(
    (0): Linear(in_features=6, out_features=100, bias=True)
    (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (prior_mu): Linear(in_features=100, out_features=50, bias=True)
  (prior_var): Linear(in_features=100, out_features=50, bias=True)
  (cnn): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_r

In [41]:
model.eval()
sample = model.prior_distribution(torch.zeros((1,6))).sample((64,1)).squeeze()


In [42]:
sample = model.decode(sample, torch.zeros((64,1)))

In [43]:
save_image(
    sample.view(64, 3, model.image_size, model.image_size),
    f"models/sample0_{str(10)}.png",
)