In [2]:
import random
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data

In [3]:
np.random.seed(1234)
torch.manual_seed(1234)
random.seed(1234)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

## Data

In [5]:
train_data = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    ]),
    download=True
)

test_data = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    ]),
    download=True
)

In [6]:
print(train_data.data.size())
print(test_data.data.size())

torch.Size([60000, 28, 28])
torch.Size([10000, 28, 28])


## DataLoader

In [7]:
batch_size = 128
train_dataloader = data.DataLoader(dataset = train_data, batch_size=batch_size, shuffle=True)
test_dataloader = data.DataLoader(dataset = test_data, batch_size=batch_size, shuffle=False)

dataloaders_dict = {'train': train_dataloader, 'val': test_dataloader}

## Medel

In [23]:
class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28*28, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 8),
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(8, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 28*28),
            torch.nn.Sigmoid()
        )
    
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [24]:
net = AutoEncoder()

## Loss Function

In [13]:
criterion = torch.nn.MSELoss()

## Optimizer

In [25]:
optim = torch.optim.SGD(params=net.parameters(), lr=0.01, momentum=0.9)

## Training

In [26]:
def train_model(net, dataloader_dict, criterion, optimizer, epochs):
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch+1, epochs))
        print("------------")

        for phase in ['train', 'val']:
            if phase == 'train': net.train()
            else: net.eval()

            epoch_loss = 0

            if (epoch==0) and (phase=='train'): continue

            for inputs, labels in tqdm(dataloader_dict[phase]):
                inputs = inputs.reshape(-1, 28*28)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase=='train'):
                    outputs = net(inputs)
                    loss = criterion(outputs, inputs)
                    #_, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    
                    epoch_loss += loss.item() * inputs.size(0)
                
            epoch_loss = epoch_loss / len(dataloader_dict[phase].dataset)

            print('{} Loss: {:.4f}'.format(epoch, epoch_loss))


In [28]:
train_model(net, dataloaders_dict, criterion, optim, epochs=5)

Epoch 1/5
------------


100%|██████████| 79/79 [00:01<00:00, 58.61it/s]


0 Loss: 1.2403
Epoch 2/5
------------


100%|██████████| 469/469 [00:09<00:00, 51.55it/s]


1 Loss: 1.1819


100%|██████████| 79/79 [00:01<00:00, 62.84it/s]


1 Loss: 1.0357
Epoch 3/5
------------


100%|██████████| 469/469 [00:09<00:00, 51.05it/s]


2 Loss: 0.8392


100%|██████████| 79/79 [00:01<00:00, 62.37it/s]


2 Loss: 0.8076
Epoch 4/5
------------


100%|██████████| 469/469 [00:09<00:00, 52.09it/s]


3 Loss: 0.8027


100%|██████████| 79/79 [00:01<00:00, 56.91it/s]


3 Loss: 0.8052
Epoch 5/5
------------


100%|██████████| 469/469 [00:09<00:00, 50.56it/s]


4 Loss: 0.8011


100%|██████████| 79/79 [00:01<00:00, 56.50it/s]

4 Loss: 0.8040



