In [229]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

In [230]:
# run params
SECTION = 'vae'
RUN_ID = '0001'
DATA_NAME = 'digits'
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

if not os.path.exists(RUN_FOLDER):
    os.mkdir(RUN_FOLDER)
    os.mkdir(os.path.join(RUN_FOLDER, 'viz'))
    os.mkdir(os.path.join(RUN_FOLDER, 'images'))
    os.mkdir(os.path.join(RUN_FOLDER, 'weights'))

MODE =  'build' #'load' #

In [231]:
NUM_CLASSES = 10

transform = transforms.Compose(
    [transforms.ToTensor()])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                         shuffle=False)

In [232]:
encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            
            nn.Conv2d(32, 64, 3, 2),
            
            nn.Conv2d(64, 64, 3, 2),
            
            nn.Conv2d(64, 64, 3, 1),
            
        )
for i in encoder:
    print(i)

Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))


In [233]:
class AutoEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.leaky_relu = nn.LeakyReLU()
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(64*6*6, 2)
        self.linear2 = nn.Linear(2, 64*6*6)
        self.sigmoid = nn.Sigmoid()


        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.Conv2d(32, 64, 3, 2),
            nn.Conv2d(64, 64, 3, 2),
            nn.Conv2d(64, 64, 3, 1),
        )

        

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, 3, 1, 1),
            nn.ConvTranspose2d(64, 64, 3, 2, 1),
            nn.ConvTranspose2d(64, 32, 3, 2, 1),
            nn.ConvTranspose2d(32, 1, 3, 1, 1),
        )
    def forward(self, x):
        for layer in self.encoder:
            x = F.pad(layer(x), (1,0,1,0))
            x = self.leaky_relu(x)
        x = self.linear1(self.flatten(x))
        x = self.linear2(x)
        x = x.reshape(-1, 64, 6, 6)
        for i, layer in enumerate(self.decoder):
            if i == len(self.decoder) - 1:
                x = layer(x)
            else:
                x = F.pad(layer(x), (0,1,0,1))
                x = self.leaky_relu(x)
        return self.sigmoid(x)

In [234]:
model =AutoEncoder()
model

AutoEncoder(
  (leaky_relu): LeakyReLU(negative_slope=0.01)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=2304, out_features=2, bias=True)
  (linear2): Linear(in_features=2, out_features=2304, bias=True)
  (sigmoid): Sigmoid()
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (2): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ConvTranspose2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [235]:
optimizer = optim.Adam(model.parameters(), lr=0.0005)

criterion = nn.MSELoss()

In [236]:
for epoch in range(10):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = torch.sqrt(criterion(outputs, inputs))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # print statistics
        running_loss += loss.item()
    print(f'[{epoch + 1}] loss: {running_loss / len(trainloader):.3f}')

print('Finished Training')

[1] loss: 0.452
[2] loss: 0.414
[3] loss: 0.405
[4] loss: 0.400
[5] loss: 0.397
[6] loss: 0.395
[7] loss: 0.392
[8] loss: 0.391
[9] loss: 0.389
[10] loss: 0.388
Finished Training


In [237]:
torch.save(model, "run/vae/0001_digits/weights/weight.pt")