In [81]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

In [82]:
# run params
SECTION = 'vae'
RUN_ID = '0002'
DATA_NAME = 'digits'
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

if not os.path.exists(RUN_FOLDER):
    os.mkdir(RUN_FOLDER)
    os.mkdir(os.path.join(RUN_FOLDER, 'viz'))
    os.mkdir(os.path.join(RUN_FOLDER, 'images'))
    os.mkdir(os.path.join(RUN_FOLDER, 'weights'))

mode =  'build' #'load' #

In [83]:
NUM_CLASSES = 10

transform = transforms.Compose(
    [transforms.ToTensor()])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                         shuffle=False)

In [84]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [85]:
class Encoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(64*6*6, 2)
        self.mu = nn.Linear(64*6*6, 2)

        self.log_var = nn.Linear(64*6*6, 2)



        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.Conv2d(32, 64, 3, 2),
            nn.Conv2d(64, 64, 3, 2),
            nn.Conv2d(64, 64, 3, 1),
        )
    

    def forward(self, x):
        for layer in self.encoder:
            x = F.pad(layer(x), (1,0,1,0))
            x = F.leaky_relu(x)
        x = self.flatten(x)

        mu, log_var = self.mu(x), self.log_var(x)

        def sampling(args):
            mu, log_var = args
            epsilon = torch.normal(0., 1., size=mu.shape).to(device)
            return mu + torch.exp(log_var / 2) * epsilon

        x = sampling([mu, log_var])

        return x


In [86]:
Encoder().to(device)(torch.randn((1, 1, 28, 28)).to(device))

tensor([[-0.7162,  0.3403]], device='cuda:0', grad_fn=<AddBackward0>)

In [87]:
class Decoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = nn.Linear(2, 64*6*6)


        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, 3, 1, 1),
            nn.ConvTranspose2d(64, 64, 3, 2, 1),
            nn.ConvTranspose2d(64, 32, 3, 2, 1),
            nn.ConvTranspose2d(32, 1, 3, 1, 1),
        )
    def forward(self, x):
        x = self.linear1(x)
        x = x.reshape(-1, 64, 6, 6)
        for i, layer in enumerate(self.decoder):
            if i == len(self.decoder) - 1:
                x = layer(x)
            else:
                x = F.pad(layer(x), (0,1,0,1))
                x = F.leaky_relu(x)
        return F.sigmoid(x)


In [88]:
class AutoEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = Encoder().to(device)
        self.decoder = Decoder().to(device)
    def forward(self, x):
        t = self.encoder(x)
        x = self.decoder(t)
        return x, t

In [89]:
model =AutoEncoder().to(device)
model

AutoEncoder(
  (encoder): Encoder(
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (linear1): Linear(in_features=2304, out_features=2, bias=True)
    (mu): Linear(in_features=2304, out_features=2, bias=True)
    (log_var): Linear(in_features=2304, out_features=2, bias=True)
    (encoder): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
      (1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    )
  )
  (decoder): Decoder(
    (linear1): Linear(in_features=2, out_features=2304, bias=True)
    (decoder): Sequential(
      (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (2): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (3): ConvTranspose2d(32, 1, kernel_size=(3, 3), stride=(1, 1)

In [90]:
optimizer = optim.Adam(model.parameters(), lr=0.0005)

criterion = nn.MSELoss()

In [91]:
for epoch in range(200):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs, _ = model(inputs)
        loss = torch.sqrt(criterion(outputs, inputs))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # print statistics
        running_loss += loss.item()
    print(f'[{epoch + 1}] loss: {running_loss / len(trainloader):.3f}')

print('Finished Training')



[1] loss: 0.471
[2] loss: 0.422
[3] loss: 0.413
[4] loss: 0.408


In [None]:
torch.save(model, RUN_FOLDER + "/weights/weight.pt")