In [2]:
import os
_, current_folder_name = os.path.split(os.getcwd())
if current_folder_name == "notebooks":
    os.chdir("..")  
print(os.getcwd())  

/gpfs/data/oermannlab/users/ngok02/ModelCollapse/vae


In [7]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

In [12]:
latent_dims = 20
batch_size = 256
capacity = 64
learning_rate = 1e-3
num_epochs = 30

In [5]:
train_dataset = datasets.MNIST(root='../data', train = True, download = True, transform = ToTensor())
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='../data', train = False, download = True, transform = ToTensor())
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [6]:
# device = 'cpu'
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

In [36]:
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        c = capacity
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=c, kernel_size=4, stride=2, padding=1) # out: c x 14 x 14
        self.conv2 = nn.Conv2d(in_channels=c, out_channels=c*2, kernel_size=4, stride=2, padding=1) # out: c x 7 x 7
        self.fc1 = nn.Linear(in_features=c*2*7*7, out_features=latent_dims)
        self.fc2 = nn.Linear(in_features=latent_dims, out_features=10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1) # flatten batch of multi-channel feature maps to a batch of feature vectors
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        return x

In [37]:
classifier = Classifier()
classifier = classifier.to(device)

optimizer = torch.optim.Adam(params=classifier.parameters(), lr=learning_rate, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
classifier.train()

train_loss_avg = []

print('Training ...')
for epoch in range(num_epochs):
    train_loss_avg.append(0)
    num_batches = 0
    
    for image_batch, y in train_dataloader:
        image_batch = image_batch.to(device)
        target = F.one_hot(y, num_classes=10).float().to(device)
        
        pred = classifier(image_batch)
        loss = criterion(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss_avg[-1] += loss.item()
        num_batches += 1

    train_loss_avg[-1] /= num_batches
    print('Epoch [%d / %d] average cross entropy error: %f' % (epoch+1, num_epochs, train_loss_avg[-1]))

Training ...
Epoch [1 / 30] average cross entropy error: 1.700136
Epoch [2 / 30] average cross entropy error: 1.591612
Epoch [3 / 30] average cross entropy error: 1.578432
Epoch [4 / 30] average cross entropy error: 1.574034
Epoch [5 / 30] average cross entropy error: 1.570537
Epoch [6 / 30] average cross entropy error: 1.568301
Epoch [7 / 30] average cross entropy error: 1.528654
Epoch [8 / 30] average cross entropy error: 1.478766
Epoch [9 / 30] average cross entropy error: 1.474834
Epoch [10 / 30] average cross entropy error: 1.472910
Epoch [11 / 30] average cross entropy error: 1.471807
Epoch [12 / 30] average cross entropy error: 1.470991
Epoch [13 / 30] average cross entropy error: 1.470055
Epoch [14 / 30] average cross entropy error: 1.469671
Epoch [15 / 30] average cross entropy error: 1.469444
Epoch [16 / 30] average cross entropy error: 1.468882
Epoch [17 / 30] average cross entropy error: 1.468390
Epoch [18 / 30] average cross entropy error: 1.467627
Epoch [19 / 30] average 

In [43]:
classifier.eval()
accuracy = 0
num_samples = 0
for image_batch, y in test_dataloader:
    image_batch = image_batch.to(device)
    target = y.to(device)
    
    pred = torch.argmax(classifier(image_batch), dim=1)
    accuracy += torch.sum(pred == target).item()
    num_samples += len(image_batch)

print(f"Accuracy: {accuracy / num_samples}")

Accuracy: 0.9883
