In [9]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

In [10]:
# Per-sample preprocessing handed to the MNIST datasets below: ToTensor turns
# each image into a float tensor (per the torchvision docs, scaled to [0, 1]).
transform = transforms.ToTensor()

In [11]:
# Build the MNIST train split first, then the held-out split, from the same
# root directory and with the same transform; only the `train` flag differs.
# Files are downloaded into ./data when they are not already present.
train_dataset, valid_dataset = (
    torchvision.datasets.MNIST(
        root="./data",
        train=is_train_split,
        download=True,
        transform=transform,
    )
    for is_train_split in (True, False)
)

In [12]:
# Mini-batch loader for training (100 images per batch).
# shuffle=True re-draws the sample order every epoch; without it each epoch
# replays exactly the same batches in exactly the same order.
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)

In [14]:
class Encoder(nn.Module):
    """Fully connected encoder: input vector -> latent code.

    Three linear layers (input_size -> hidden_size1 -> hidden_size2 -> z_dim)
    with ReLU after the first two; the final layer is left linear.
    """

    def __init__(self, input_size=28*28, hidden_size1=128, hidden_size2=16, z_dim=2):
        """Create the layers.

        Args:
            input_size: number of features per input sample.
            hidden_size1: width of the first hidden layer.
            hidden_size2: width of the second hidden layer.
            z_dim: dimensionality of the latent code that is returned.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, z_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the latent code for `x` (last dimension must be input_size)."""
        hidden = x
        # Both hidden layers share the same "linear then ReLU" pattern.
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        # No activation on the latent projection.
        return self.fc3(hidden)

In [15]:
class Decoder(nn.Module):
    """Fully connected decoder: latent code -> reconstructed vector.

    Three linear layers (z_dim -> hidden_size2 -> hidden_size1 -> output_size)
    with ReLU after the first two and a sigmoid on the output, so every
    reconstructed value lies in (0, 1).
    """

    def __init__(self, output_size=28*28, hidden_size1=128, hidden_size2=16, z_dim=2):
        """Create the layers.

        Args:
            output_size: number of features per reconstructed sample.
            hidden_size1: width of the wider hidden layer (closest to the output).
            hidden_size2: width of the narrower hidden layer (closest to the latent).
            z_dim: dimensionality of the latent code that is consumed.
        """
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_size2)
        self.fc2 = nn.Linear(hidden_size2, hidden_size1)
        self.fc3 = nn.Linear(hidden_size1, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the reconstruction for latent `x` (last dimension must be z_dim)."""
        hidden = x
        # Both hidden layers share the same "linear then ReLU" pattern.
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        # Sigmoid squashes the output layer's logits into (0, 1).
        logits = self.fc3(hidden)
        return logits.sigmoid()

In [16]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device  # bare trailing expression so the notebook echoes the chosen device

device(type='cpu')

In [17]:
# Instantiate both halves of the autoencoder with their default sizes
# (encoder first, then decoder) and move their parameters to `device`.
enc, dec = Encoder().to(device), Decoder().to(device)

In [18]:
# Reconstruction objective: mean squared error between output and input.
loss_fn = nn.MSELoss()

# One Adam optimizer per sub-network, each with Adam's default settings;
# the encoder's optimizer is created first, then the decoder's.
optimizer_enc, optimizer_dec = (
    torch.optim.Adam(network.parameters()) for network in (enc, dec)
)

In [20]:
# Train the autoencoder to reconstruct its own input.
# train_loss collects one value per epoch: the SUM of the per-batch MSE values
# (not their mean), stored as a plain Python float.
train_loss = []
num_epochs = 200
for epoch in range(num_epochs):
    train_epoch_loss = 0.0
    for (imgs, _) in train_dl:  # labels are not needed for reconstruction
        # Move the batch to the compute device and flatten every image to a
        # vector (all dimensions after the batch dimension are merged).
        imgs = imgs.to(device).flatten(1)

        # Forward pass: encode to the latent space, then decode back.
        latents = enc(imgs)
        output = dec(latents)
        loss = loss_fn(output, imgs)

        # .item() yields a detached Python float, so the running total is
        # accumulated in double precision instead of as a NumPy float32 value.
        train_epoch_loss += loss.item()

        # Backward pass: clear stale gradients on both networks, backpropagate
        # once through decoder and encoder, then update each network.
        optimizer_enc.zero_grad()
        optimizer_dec.zero_grad()
        loss.backward()
        optimizer_enc.step()
        optimizer_dec.step()
    train_loss.append(train_epoch_loss)