# AUTOENCODER

In [None]:
import torch.nn as nn
import torch 
import torchvision
import torchvision.transforms as T
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader

from torch.nn import Module, ModuleList, Conv2d, ConvTranspose2d
from torch.nn import ReLU, BCELoss, Sigmoid
from torch.utils.data import Subset
from torch.optim import Adam

from tqdm import tqdm
import os

In [None]:
# CONFIG
# Training hyperparameters.
batch_size = 64
epoch = 50  # number of passes over the training subset
lr = 0.01  # Adam learning rate

# Data loading / hardware.
num_workers = 4
device = "cuda" if torch.cuda.is_available() else "cpu"
pin_memory = device == "cuda"  # pinned host memory is only requested on GPU

# Directory where checkpoints are written.
base_output = "output"

## DATASET

In [None]:
# T.ToTensor converts the PIL images to float tensors in [0, 1], which is the
# range BCELoss expects for its targets.
transforms = T.Compose(
    [T.ToTensor()]
    )

# Only the first 2000 images of each split are used, to keep training short.
train = Subset(datasets.MNIST(root="./data", train = True, download = True, transform = transforms), range(2000))
test = Subset(datasets.MNIST(root="./data", train = False, download = True, transform = transforms), range(2000))

# Under Jupyter __name__ is "__main__", so this always runs in the notebook.
# NOTE(review): the training cell uses these loaders unconditionally, so
# importing this file as a module would fail there.
if __name__ == "__main__":
    # Worker processes are only used when CUDA is available.
    loader_workers = num_workers if torch.cuda.is_available() else 0
    train_dataloader = DataLoader(train, shuffle=True, batch_size=batch_size, num_workers=loader_workers, pin_memory=pin_memory)
    # Evaluation does not need shuffling; a fixed order makes the per-epoch
    # test loss (mean of batch means) deterministic for a given model.
    test_dataloader = DataLoader(test, shuffle=False, batch_size=batch_size, num_workers=loader_workers, pin_memory=pin_memory)

In [None]:
class MODEL(Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder is a stack of unpadded 3x3 ``Conv2d`` layers, one per
    consecutive pair in ``channels``; each shrinks height and width by 2.
    The final feature map is flattened and passed through an MLP
    (``3*d -> 2*d -> d -> flattened size``, with ``d = bottleneck_dim``).
    It is then reshaped and decoded by mirrored 3x3 ``ConvTranspose2d`` layers,
    and a sigmoid produces the output.

    Args:
        channels: Channel counts from the input to the deepest encoder layer.
            Defaults to ``[1, 16, 32, 64]``. The list is copied, never aliased.
        bottleneck_dim: Width of the narrowest bottleneck layer. The default
            484 equals 22 * 22, the encoder's spatial size for 28x28 input
            with the default channels.
        input_size: Height/width of the square input images (default 28).

    Raises:
        ValueError: if ``input_size`` is too small for the number of conv layers.
    """

    def __init__(self, channels=None, bottleneck_dim=484, input_size=28):
        super().__init__()
        # Avoid a shared mutable default: build a fresh list per instance.
        channels = [1, 16, 32, 64] if channels is None else list(channels)
        self.channels = channels
        self.bottleneck_dim = bottleneck_dim

        num_layers = len(channels) - 1
        # Each unpadded 3x3 convolution trims one pixel from every border.
        spatial = input_size - 2 * num_layers
        if spatial <= 0:
            raise ValueError(
                f"input_size={input_size} is too small for {num_layers} "
                "unpadded 3x3 conv layers"
            )
        self.spatial = spatial
        flat_dim = channels[-1] * spatial * spatial

        self.encoder = ModuleList([
            Conv2d(c_in, c_out, 3) for c_in, c_out in zip(channels[:-1], channels[1:])
        ])

        self.flatten = nn.Flatten()
        self.bottleneck = nn.Sequential(
            nn.Linear(flat_dim, bottleneck_dim * 3),
            nn.ReLU(),
            nn.Linear(bottleneck_dim * 3, bottleneck_dim * 2),
            nn.ReLU(),
            nn.Linear(bottleneck_dim * 2, bottleneck_dim * 1),
            nn.ReLU(),
            nn.Linear(bottleneck_dim * 1, flat_dim),  # back to the encoder's size
            nn.ReLU(),
        )

        self.unflatten = nn.Unflatten(1, (channels[-1], spatial, spatial))

        # The decoder mirrors the encoder: walk the channel list backwards.
        reversed_channels = channels[::-1]
        self.decoder = ModuleList([
            ConvTranspose2d(c_in, c_out, 3)
            for c_in, c_out in zip(reversed_channels[:-1], reversed_channels[1:])
        ])
        self.relu = ReLU()
        self.sigmoid = Sigmoid()

    def forward(self, x):
        """Reconstruct ``x``.

        Returns:
            ``(output, encoderList, decoderList)``: ``output`` is the sigmoid
            reconstruction. ``encoderList`` holds each encoder conv's output
            before its ReLU. ``decoderList`` holds each decoder layer's output
            after its ReLU; the last entry is taken before the sigmoid.
        """
        encoderList = []
        decoderList = []

        for conv in self.encoder:
            x = conv(x)
            encoderList.append(x)
            x = self.relu(x)

        x = self.flatten(x)
        x = self.bottleneck(x)
        x = self.unflatten(x)

        last = len(self.decoder) - 1
        for i, deconv in enumerate(self.decoder):
            x = deconv(x)
            if i < last:  # ReLU on all but the final layer
                x = self.relu(x)
            decoderList.append(x)

        x = self.sigmoid(x)
        return (x, encoderList, decoderList)

In [None]:
# Smoke test: push one random 1x1x28x28 tensor through an untrained model.
# `output` is the (reconstruction, encoder activations, decoder activations)
# tuple returned by MODEL.forward.
sample = torch.randn(1, 1, 28, 28)
model = MODEL()
# Call the module itself rather than .forward() so any registered hooks run.
output = model(sample)
print(model)

In [None]:
# Unpack the untrained model's output on `sample`: the sigmoid reconstruction
# plus the per-layer encoder and decoder activations.
reconstruction, encoderlist, decoderlist = output

for feature in encoderlist:
    print(f"Encoder List : {feature.shape}")
print()
for feature in decoderlist:
    print(f"Decoder List : {feature.shape}")

print("\n", torch.flatten(encoderlist[-1], start_dim = 1).shape)
# start_dim=1 keeps dim 0 (the batch) and flattens channels x height x width,
# e.g. [1, 64, 22, 22] becomes [1, 30976] for the 1x1x28x28 sample.

# Before training. decoderlist[-1] is taken before the final sigmoid, so plot
# the model's actual output instead, matching the after-training plot.
plt.imshow(reconstruction.detach().numpy().squeeze())
plt.show()

## Training

In [None]:
# Train the autoencoder with BCE reconstruction loss, report the mean per-batch
# train/test loss each epoch, and checkpoint the whole model every 10 epochs.
model = MODEL().to(device=device)
lossFunc = BCELoss()
opt = Adam(model.parameters(), lr = lr)

# Number of batches each loader actually yields, including the final partial
# batch; len(dataset) // batch_size would undercount by one here.
trainsteps = len(train_dataloader)
teststeps = len(test_dataloader)

# torch.save does not create missing directories.
os.makedirs(base_output, exist_ok=True)

h = {"train_loss": [], "test_loss": []}

for e in tqdm(range(epoch)):
    model.train()

    totaltrainloss = 0.0
    totaltestloss = 0.0

    for x, _ in train_dataloader:
        x = x.to(device)

        pred = model(x)[0]
        loss = lossFunc(pred, x)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # .item() yields a Python float, so the running total does not keep
        # every batch's autograd graph alive for the whole epoch.
        totaltrainloss += loss.item()

    model.eval()
    with torch.no_grad():
        for x, _ in test_dataloader:
            x = x.to(device)

            pred = model(x)[0]
            totaltestloss += lossFunc(pred, x).item()

    avgtrainloss = totaltrainloss / trainsteps
    avgtestloss = totaltestloss / teststeps

    h["train_loss"].append(avgtrainloss)
    h["test_loss"].append(avgtestloss)

    print(f"[INFO] EPOCH : {e+1}/{epoch}")
    print(f"[INFO] TRAIN LOSS : {avgtrainloss} --- TEST LOSS : {avgtestloss}")

    # Parentheses are required: `e+1 % 10` parses as `e + (1 % 10)`, which is
    # never 0. Files are named by 1-based epoch (model_10 ... model_50).
    if (e + 1) % 10 == 0:
        print("INFO MODEL SAVED")
        torch.save(model, os.path.join(base_output, f"model_{e + 1}.pth"))

In [None]:
# Reload the epoch-40 checkpoint and show one test digit next to its
# reconstruction.
# NOTE: weights_only=False unpickles arbitrary objects; only load checkpoints
# you produced yourself.
trained_model = torch.load(
    os.path.join(base_output, f"model_{40}.pth"),
    map_location=device,  # load onto this session's device, wherever it was saved
    weights_only=False,
)
trained_model.eval()

image = test[0][0]  # one (C, H, W) tensor produced by T.ToTensor

plt.imshow(image.squeeze())
plt.show()

with torch.no_grad():
    # The bottleneck's Flatten/Linear layers need a batch dimension, so feed a
    # batch of one, on the same device as the model.
    reconstruction = trained_model(image.unsqueeze(0).to(device))[0]
plt.imshow(reconstruction.cpu().numpy().squeeze())
plt.show()