## Training a convolutional autoencoder on MNIST
This notebook trains a convolutional autoencoder (`ConvAutoencoder`) to reconstruct MNIST digits, logging the training loss to TensorBoard.

In [None]:
import torch
import numpy as np
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

Select the compute device (CUDA GPU if available, otherwise CPU) and print it, so the output shows where training runs.

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

### Download and load the MNIST dataset

In [None]:
# ToTensor converts each image to a torch.FloatTensor
# (pixel values are rescaled from 0-255 to [0, 1]).
transform = transforms.ToTensor()

# Both splits share the same cache directory, download flag and transform;
# only the `train` flag differs between them.
mnist_kwargs = dict(root='~/.pytorch/MNIST_data/', download=True, transform=transform)
train_data = datasets.MNIST(train=True, **mnist_kwargs)
test_data = datasets.MNIST(train=False, **mnist_kwargs)

### Prepare dataloaders

In [None]:
# Create training and test dataloaders.
# num_workers=0 loads batches in the main process (simplest and most portable).
num_workers = 0
# how many samples per batch to load
batch_size = 20

# Shuffle the training set so each epoch presents the mini-batches in a different
# order (standard for SGD-style training); keep the test set in a fixed order so
# evaluations are comparable across runs.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, num_workers=num_workers
)

### Visualize the data

Display one training image to sanity-check the data pipeline.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
    
# obtain one batch of training images
images, labels = train_data.__getitem__(0)
images = images.numpy()

# get one image from the batch
img = np.squeeze(images[0])

fig = plt.figure(figsize = (5,5)) 
ax = fig.add_subplot(111)
ax.imshow(img, cmap='gray')

## Train Autoencoder

Completed:
- Training converges.

TODO:
- Visualize reconstruction every `VISUALIZE_EVERY` epochs.

In [None]:
from models.autoencoder import ConvAutoencoder
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Build the autoencoder and move its parameters to the selected device
# (`Module.to` returns the module itself, so the call can be chained).
model = ConvAutoencoder().to(device)
print(model)

In [None]:
# Training hyper-parameters.
LEARNING_RATE = 0.001
N_EPOCHS = 30
# Interval (in epochs) at which to visualize reconstructions -- see the TODO above.
VISUALIZE_EVERY = 10

# Reconstruction loss: mean-squared error between the model output and its own input.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# TensorBoard event files are written under ./logs (view with `tensorboard --logdir logs`).
# NOTE: the former module-level `i = 0` was removed: the training loop rebinds `i`
# via `enumerate`, so it never acted as a persistent step counter.
writer = SummaryWriter('logs')

In [None]:
# Train the autoencoder to reconstruct its own input.
model.train()
# Number of training samples seen so far; used as the TensorBoard x-axis so that
# successive epochs extend the curve instead of overwriting each other.
global_step = 0

for epoch in range(N_EPOCHS):
    train_loss = 0.0  # sum of per-sample losses over this epoch
    # Labels are ignored: the reconstruction target is the input image itself.
    for images, _ in tqdm(train_loader):
        images = images.to(device)
        optimizer.zero_grad()
        out = model(images)
        loss = criterion(out, images)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        n_in_batch = images.size(0)
        # MSELoss's default 'mean' reduction averages over the batch, so scale back
        # up to accumulate a per-sample sum (the last batch may be smaller).
        train_loss += batch_loss * n_in_batch
        writer.add_scalar("train_loss", batch_loss, global_step)
        global_step += n_in_batch

    # Log the last batch of the epoch next to its reconstruction.
    # Also log the final epoch so the fully trained result is always visible.
    # NOTE(review): assumes model outputs lie in [0, 1] like the ToTensor inputs -- confirm.
    if epoch % VISUALIZE_EVERY == 0 or epoch == N_EPOCHS - 1:
        writer.add_image('images', torchvision.utils.make_grid(images.cpu()), epoch)
        writer.add_image(
            'reconstructed images',
            torchvision.utils.make_grid(out.detach().cpu()),
            epoch,
        )
    writer.flush()

    # Average loss per training sample (not per batch).
    train_loss = train_loss / len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(
        epoch,
        train_loss
        ))

In [None]:
# Loss of the final mini-batch of the last epoch, shown as a plain float rather than
# a grad-tracking tensor repr. This is a single-batch value, not the epoch average
# printed by the training loop.
loss.item()