In [None]:
# DATA PREPARATION

from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch

# Load the class-per-folder image dataset. Each image is resized to a fixed
# square and converted by ToTensor into a float tensor shaped (C, H, W) with
# values in [0, 1].
# NOTE: images are deliberately NOT flattened. The model is convolutional
# (nn.Conv2d), so it needs the spatial layout. A batch of flattened vectors is
# 2-D and Conv2d rejects it.
DATA_DIR = 'data/raw'
IMAGE_SIZE = 32  # the encoder halves H/W four times, so keep this divisible by 16

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

dataset = ImageFolder(DATA_DIR, transform=transform)
dataset.classes

In [None]:
# DATA SPLIT

from sklearn.model_selection import train_test_split

# Stratified ~70 / 15 / 15 split over sample indices: each subset keeps the
# class proportions of the full dataset, and a fixed seed makes it repeatable.
TEST_FRAC = 0.15
VALID_FRAC = 0.15
SPLIT_SEED = 42

ds_idx = list(range(len(dataset)))

# Step 1: hold out the test set from the full index list.
train_idx, test_idx = train_test_split(
    ds_idx,
    test_size=TEST_FRAC,
    shuffle=True,
    stratify=dataset.targets,
    random_state=SPLIT_SEED,
)

# Step 2: split what remains into train / validation. The fraction is rescaled
# (0.15 / 0.85) so validation ends up at VALID_FRAC of the *whole* dataset.
remaining_labels = [dataset.targets[i] for i in train_idx]
train_idx, valid_idx = train_test_split(
    train_idx,
    test_size=VALID_FRAC / (1 - TEST_FRAC),
    shuffle=True,
    stratify=remaining_labels,
    random_state=SPLIT_SEED,
)

len(train_idx), len(valid_idx), len(test_idx)

In [None]:
# GENERATE DATASETS

# Index-based views over the one ImageFolder instance. All three subsets share
# its transform; Subset only stores the indices, not copies of the samples.
train_ds, valid_ds, test_ds = (
    torch.utils.data.Subset(dataset, indices)
    for indices in (train_idx, valid_idx, test_idx)
)

In [None]:
# DATA LOADERS

batch_size = 32

# Only the training loader shuffles. Validation and test keep a fixed order so
# evaluation passes are comparable run to run.
train_dl, valid_dl, test_dl = (
    torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=do_shuffle)
    for subset, do_shuffle in ((train_ds, True), (valid_ds, False), (test_ds, False))
)

In [None]:
from torch import nn

# Convolutional encoder: four 3x3 stride-2 convolutions with padding 1.
# Each one maps H -> ceil(H / 2).
# Sizes in the comments assume the 32x32 images from transforms.Resize((32, 32)).
encoder = nn.Sequential(
    nn.Conv2d(3, 4, 3, 2, 1),    # 32 -> 16
    nn.ReLU(),
    nn.Conv2d(4, 8, 3, 2, 1),    # 16 -> 8
    nn.ReLU(),
    nn.Conv2d(8, 16, 3, 2, 1),   # 8 -> 4
    nn.ReLU(),
    nn.Conv2d(16, 32, 3, 2, 1),  # 4 -> 2   (latent: 32 x 2 x 2)
    nn.ReLU(),
)

# Mirror-image decoder. A ConvTranspose2d with kernel 3, stride 2, padding 1
# outputs 2*H - 1. output_padding=1 adds the missing row/column so each stage
# exactly doubles. Without it, a 32x32 input came back as 17x17 and MSELoss
# against the input failed on the shape mismatch.
decoder = nn.Sequential(
    nn.ConvTranspose2d(32, 16, 3, 2, 1, output_padding=1),  # 2 -> 4
    nn.ReLU(),
    nn.ConvTranspose2d(16, 8, 3, 2, 1, output_padding=1),   # 4 -> 8
    nn.ReLU(),
    nn.ConvTranspose2d(8, 4, 3, 2, 1, output_padding=1),    # 8 -> 16
    nn.ReLU(),
    nn.ConvTranspose2d(4, 3, 3, 2, 1, output_padding=1),    # 16 -> 32
    nn.Sigmoid(),  # squashes to [0, 1], the same range ToTensor produces
)

In [None]:
import lightning as L
import torch

class AutoEncoder(L.LightningModule):
    """Lightning module that trains an encoder/decoder pair to reconstruct its input.

    The class label in each batch is ignored; the target is the input itself.

    Args:
        encoder: module mapping an input batch to a latent representation.
        decoder: module mapping that latent back to input space. Its output must
            have the same shape as the encoder input, because the loss is MSE
            against the input.
        lr: Adam learning rate (default 1e-3).
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, lr: float = 1e-3):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr
        self.criterion = nn.MSELoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode then decode ``x``, returning the reconstruction."""
        # Required: nn.Module provides no default forward, so without this
        # method ``self(x)`` in the step methods raises NotImplementedError.
        return self.decoder(self.encoder(x))

    def _reconstruction_loss(self, batch) -> torch.Tensor:
        """Return the MSE between the inputs of an ``(x, label)`` batch and their reconstruction."""
        x, _ = batch
        x_hat = self(x)
        return self.criterion(x_hat, x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training reconstruction loss."""
        loss = self._reconstruction_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation reconstruction loss."""
        loss = self._reconstruction_loss(batch)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters (encoder and decoder) at ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)