## Notebook for training AutoEncoders
We will use this notebook to train our autoencoders. The aim is to set up the training procedure explained in the Datasets That Are Not paper.

In [42]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import random

Log the current device to confirm whether training will run on the GPU.

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Torch running on {device}")

Torch running on cuda


### Import and prepare MNIST dataset
We will work with the MNIST dataset for experimentation and setup. Let's download it with the handy `torchvision.datasets.MNIST` dataset class, then prepare our training and validation splits.

In [44]:
# convert images to torch.FloatTensor with pixel values scaled to [0, 1]
transform = transforms.ToTensor()

# load the MNIST training set; the validation split is carved out of it below
mnist_dataset = datasets.MNIST(root='~/.pytorch/MNIST_data/', train=True,
                               download=True, transform=transform)

# Define the split ratio for validation data
validation_split = 0.2
num_train = len(mnist_dataset)
indices = list(range(num_train))
split = int(validation_split * num_train)

# Shuffle the indices
random.seed(42)  # For reproducibility
random.shuffle(indices)

# Split the indices into training and validation sets
train_indices, val_indices = indices[split:], indices[:split]
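The same shuffle-then-slice split can be wrapped in a small helper. This is a sketch, assuming we prefer a private `random.Random(seed)` instance over reseeding the global `random` module, so that other cells drawing random numbers cannot change which indices end up in validation. The name `split_indices` is hypothetical and not part of the notebook's code.

```python
import random


def split_indices(num_samples: int, validation_split: float, seed: int = 42):
    """Return (train_indices, val_indices) as disjoint lists covering range(num_samples).

    Hypothetical helper: mirrors the notebook's shuffle-then-slice split, but uses a
    private random.Random instance instead of reseeding the global RNG.
    """
    indices = list(range(num_samples))
    rng = random.Random(seed)  # local RNG; the global random state is left untouched
    rng.shuffle(indices)
    split = int(validation_split * num_samples)
    return indices[split:], indices[:split]


# MNIST's training set has 60,000 images, so a 0.2 split gives 48,000 / 12,000.
train_idx, val_idx = split_indices(60_000, 0.2)
print(len(train_idx), len(val_idx))  # 48000 12000
```

In the notebook this would be called as `train_indices, val_indices = split_indices(len(mnist_dataset), validation_split)`.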

### Prepare dataloaders
The dataloaders let us iterate over the training and validation subsets in randomly ordered batches.

In [48]:
# Create data loaders for training and validation
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

batch_size = 32
train_dataloader = torch.utils.data.DataLoader(mnist_dataset, batch_size=batch_size, sampler=train_sampler)
validation_dataloader = torch.utils.data.DataLoader(mnist_dataset, batch_size=batch_size, sampler=val_sampler)
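Calling `len()` on a `DataLoader` with a sampler counts batches, not images. With the default `drop_last=False`, the final partial batch is kept, so the count is the ceiling of samples over batch size. A hypothetical `num_batches` helper makes the arithmetic for our split explicit; the distinction matters when averaging the per-sample training loss later in the notebook.

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch (hypothetical helper)."""
    if drop_last:
        return num_samples // batch_size  # the incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)  # the final batch may be smaller


print(num_batches(48_000, 32))  # 1500 training batches
print(num_batches(12_000, 32))  # 375 validation batches
```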

Visualize a sample from the training set

In [17]:
import matplotlib.pyplot as plt
%matplotlib inline

# obtain one batch of training images
images, labels = next(iter(train_dataloader))
images = images.numpy()

# get one image from the batch; squeeze drops the channel axis: (1, 28, 28) -> (28, 28)
img = np.squeeze(images[0])

fig = plt.figure(figsize=(2, 2))
ax = fig.add_subplot(111)
ax.imshow(img, cmap='gray')
ax.set_title(f"Label: {labels[0].item()}")
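A single digit says little about the variety in a batch. The sketch below tiles several images into one grid using only NumPy; `tile_images` is a hypothetical helper (for tensors, `torchvision.utils.make_grid` offers similar functionality).

```python
import numpy as np


def tile_images(batch: np.ndarray, ncols: int) -> np.ndarray:
    """Arrange an (N, 1, H, W) or (N, H, W) batch into one 2-D grid image.

    Empty cells in the last row are left as zeros (black).
    """
    if batch.ndim == 4:
        batch = batch[:, 0]  # drop the single grayscale channel axis
    n, h, w = batch.shape
    nrows = -(-n // ncols)  # ceiling division
    grid = np.zeros((nrows * h, ncols * w), dtype=batch.dtype)
    for idx in range(n):
        row, col = divmod(idx, ncols)
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = batch[idx]
    return grid


# Stand-in for a batch of 32 MNIST images with shape (32, 1, 28, 28)
fake_batch = np.random.rand(32, 1, 28, 28).astype(np.float32)
print(tile_images(fake_batch, ncols=8).shape)  # (112, 224)
```

With the batch loaded above, `ax.imshow(tile_images(images[:16], ncols=8), cmap='gray')` would show 16 digits in two rows (a wider `figsize` helps).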

## Train Autoencoder
Let's train our autoencoder with a basic training loop to check that it runs end to end and that its reconstructions look reasonable.

Import the model from our models repository

In [70]:
from models.autoencoder import ConvAutoencoder

In [71]:
model = ConvAutoencoder()
model.to(device)
print(f"Model architecture:\n\n{model}")
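Because the loss compares the output with the input, the autoencoder must return an image of the same spatial size it received. The sketch below applies PyTorch's documented output-size formulas for `Conv2d`/`MaxPool2d` and `ConvTranspose2d` (dilation 1) to a *hypothetical* encoder/decoder stack; the layer sequence is an assumption for illustration, not read from `models.autoencoder.ConvAutoencoder`.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output size of Conv2d / MaxPool2d along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int, stride: int = 1,
                       padding: int = 0, output_padding: int = 0) -> int:
    """Output size of ConvTranspose2d along one spatial axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Hypothetical layer stack, NOT taken from ConvAutoencoder
sizes = [28]
sizes.append(conv_out(sizes[-1], kernel=3, padding=1))           # Conv2d 3x3, padding 1
sizes.append(conv_out(sizes[-1], kernel=2, stride=2))            # MaxPool2d 2x2
sizes.append(conv_out(sizes[-1], kernel=3, padding=1))           # Conv2d 3x3, padding 1
sizes.append(conv_out(sizes[-1], kernel=2, stride=2))            # MaxPool2d 2x2
sizes.append(conv_transpose_out(sizes[-1], kernel=2, stride=2))  # ConvTranspose2d 2x2, stride 2
sizes.append(conv_transpose_out(sizes[-1], kernel=2, stride=2))  # ConvTranspose2d 2x2, stride 2
print(sizes)  # [28, 28, 14, 14, 7, 14, 28]
```

In practice, the quickest check on the real model is a dummy forward pass such as `model(torch.zeros(1, 1, 28, 28, device=device)).shape`.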

Set up basic training hyperparameters and the TensorBoard writer used to visualize training

In [72]:
LEARNING_RATE = 0.001
N_EPOCHS = 10
VISUALIZE_EVERY = 1
writer = SummaryWriter('logs')

Configure an Adam optimizer and a mean-squared-error reconstruction loss.

In [73]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
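With its default `reduction='mean'`, `nn.MSELoss` averages the squared error over every element of the batch. For a batch of $B$ images with $C$ channels and spatial size $H \times W$ (for MNIST, $C = 1$ and $H = W = 28$), reconstructions $\hat{x}$ and inputs $x$:

$$
\mathcal{L}(\hat{x}, x) = \frac{1}{B\,C\,H\,W} \sum_{b=1}^{B} \sum_{c=1}^{C} \sum_{h=1}^{H} \sum_{w=1}^{W} \left( \hat{x}_{bchw} - x_{bchw} \right)^2
$$

Because the loss is a per-element mean, multiplying `loss.item()` by the batch size $B$ gives that batch's contribution to a per-sample average over the epoch.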

Train model for `N_EPOCHS`:

In [74]:
print("Training network...")
global_step = 0  # training samples seen so far; x-axis for the per-batch loss plot
for epoch in tqdm(range(N_EPOCHS)):
    model.train()
    train_loss = 0.0
    for images, _ in train_dataloader:
        optimizer.zero_grad()
        images = images.to(device)
        out = model(images)
        loss = criterion(out, images)  # reconstruction loss: the target is the input itself
        loss.backward()
        optimizer.step()
        # weight the batch-mean loss by the batch size so the epoch average is per sample
        train_loss += loss.item() * images.size(0)
        writer.add_scalar("train_loss", loss.item(), global_step)
        global_step += images.size(0)

    if epoch % VISUALIZE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            val_images, _ = next(iter(validation_dataloader))
            val_images = val_images.to(device)

            # Forward pass
            output = model(val_images)

            # Log images to TensorBoard
            writer.add_images('Input Images', val_images, global_step=epoch)
            writer.add_images('Reconstructed Images', output, global_step=epoch)

    # log average training loss per sample (train_loss is summed over samples, not batches)
    train_loss = train_loss / len(train_sampler)
    writer.add_scalar("avg_train_loss", train_loss, epoch)

writer.flush()

Training network...


100%|███████████████████████████████████████████| 10/10 [01:29<00:00,  8.91s/it]
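The epoch-level `avg_train_loss` should be a per-sample average: each batch's mean loss is weighted by its size, and the total is divided by the number of training samples rather than the number of batches. A minimal pure-Python sketch of that bookkeeping, using a hypothetical `epoch_average_loss` helper and illustrative (not measured) loss values:

```python
def epoch_average_loss(batch_losses, batch_sizes):
    """Per-sample average of batch-mean losses (hypothetical helper)."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Illustrative values: two full batches of 32 and a final partial batch of 16
losses = [0.5, 0.5, 0.2]
sizes = [32, 32, 16]
per_sample = epoch_average_loss(losses, sizes)  # 35.2 / 80, about 0.44
per_batch_divisor = sum(l * n for l, n in zip(losses, sizes)) / len(sizes)  # 35.2 / 3, about 11.73
print(per_sample, per_batch_divisor)
```

Dividing the weighted sum by the number of batches instead inflates the value by roughly the batch size, which makes epoch curves hard to compare across different `batch_size` settings.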
