In [None]:
!pip install comet_ml

In [None]:
import comet_ml  # noqa: F401 (import comet_ml before pytorch)
import getpass
import os
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision.utils as vutils
from datasets import load_dataset
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
from pytorch_lightning.loggers import CometLogger
from torch.utils.data import DataLoader, IterableDataset
from torchvision import transforms

In [None]:
# Credentials must never be stored in the notebook. The API key is taken from
# the environment (e.g. a secret injected by the hosting platform); only when
# it is missing do we fall back to an interactive prompt that does not echo.
# NOTE: the key that used to be hard-coded in this cell should be treated as
# leaked and rotated in the Comet account settings.
if not os.environ.get("COMET_API_KEY"):
    os.environ["COMET_API_KEY"] = getpass.getpass("Comet API key: ")

# Project and workspace are not secrets: keep the notebook defaults, but let a
# value that is already exported in the environment take precedence.
os.environ.setdefault("COMET_PROJECT_NAME", "unsupervised-learning")
os.environ.setdefault("COMET_WORKSPACE", "ekipa")

In [None]:
# Central configuration for the whole notebook, grouped by concern. Later
# cells read their settings from here instead of hard-coding values.
CONFIG = dict(
    # Data loading: loader settings, image resolution and split fractions.
    data=dict(
        batch_size=128,
        num_workers=3,
        image_size=256,
        total_samples=2048,
        val_split=0.1,
        test_split=0.1,
        data_dir="/kaggle/working/data",
    ),
    # Constructor arguments of the autoencoder.
    model=dict(
        input_channels=3,
        latent_channels=128,
        learning_rate=1e-3,
    ),
    # Trainer, early-stopping and learning-rate scheduler settings.
    training=dict(
        max_epochs=150,
        gradient_clip_val=1.0,
        early_stopping_patience=10,
        lr_scheduler_patience=5,
        lr_scheduler_factor=0.5,
    ),
    # Run bookkeeping: experiment name, seed and reconstruction logging.
    experiment=dict(
        name="autoencoder-d-7-5s",
        seed=42,
        visualization_samples=8,
        recon_log_every_n_epochs=5,
    ),
)

In [None]:
class AutoEncoder(pl.LightningModule):
    """Convolutional autoencoder trained with a pixel-wise MSE reconstruction loss.

    The encoder stacks five stride-2 convolutions (each followed by batch
    normalisation and GELU), halving the spatial size at every stage. The
    decoder mirrors it with five stride-2 transposed convolutions and ends in a
    sigmoid, so reconstructions lie in [0, 1]. The shape comments in
    ``__init__`` assume 256 x 256 inputs.

    Batches are expected to be dicts holding the image tensor under ``"image"``.

    Args:
        input_channels: Number of channels of the input images; the decoder
            emits the same number of channels.
        latent_channels: Number of channels of the bottleneck feature map.
        learning_rate: Learning rate handed to the Adam optimizer.
        scheduler_patience: ``patience`` passed to ``ReduceLROnPlateau``.
        scheduler_factor: ``factor`` passed to ``ReduceLROnPlateau``.
    """

    def __init__(
        self,
        input_channels: int = 3,
        latent_channels: int = 128,
        learning_rate: float = 1e-3,
        scheduler_patience: int = 5,
        scheduler_factor: float = 0.5,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.scheduler_patience = scheduler_patience
        self.scheduler_factor = scheduler_factor

        # Plain convolutional encoder-decoder. Unlike U-Net
        # (https://arxiv.org/pdf/1505.04597) there are no skip connections
        # between encoder and decoder stages.
        # Both stacks are kept as flat nn.Sequential containers so parameter
        # names (encoder.0, encoder.1, ...) stay stable for saved checkpoints.

        self.encoder = nn.Sequential(
            # (input_channels x 256 x 256) -> (64 x 128 x 128)
            nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),

            # (64 x 128 x 128) -> (128 x 64 x 64)
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.GELU(),

            # (128 x 64 x 64) -> (256 x 32 x 32)
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.GELU(),

            # (256 x 32 x 32) -> (512 x 16 x 16)
            nn.Conv2d(256, 512, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(512),
            nn.GELU(),

            # (512 x 16 x 16) -> (latent_channels x 8 x 8)
            nn.Conv2d(512, latent_channels, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(latent_channels),
            nn.GELU(),
        )

        self.decoder = nn.Sequential(
            # (latent_channels x 8 x 8) -> (512 x 16 x 16)
            nn.ConvTranspose2d(
                latent_channels, 512, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(512),
            nn.GELU(),

            # (512 x 16 x 16) -> (256 x 32 x 32)
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.GELU(),

            # (256 x 32 x 32) -> (128 x 64 x 64)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.GELU(),

            # (128 x 64 x 64) -> (64 x 128 x 128)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),

            # (64 x 128 x 128) -> (input_channels x 256 x 256)
            # The output must have as many channels as the input, otherwise
            # the reconstruction loss cannot compare the two tensors.
            nn.ConvTranspose2d(
                64, input_channels, kernel_size=4, stride=2, padding=1
            ),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` to the latent feature map and decode it back to an image."""
        latent_space = self.encoder(x)
        reconstructed_image = self.decoder(latent_space)
        return reconstructed_image

    def _shared_step(self, batch, metric_name: str) -> torch.Tensor:
        """Compute the MSE reconstruction loss of ``batch`` and log it as ``metric_name``."""
        images = batch["image"]
        reconstructed = self(images)

        loss = nn.functional.mse_loss(reconstructed, images)

        # The batch is a dict, so the batch size is passed explicitly instead
        # of leaving it to be inferred from the batch structure.
        self.log(metric_name, loss, prog_bar=True, batch_size=images.size(0))

        return loss

    def training_step(self, batch, batch_idx):
        """Return the reconstruction loss for one training batch (logged as ``train_loss``)."""
        return self._shared_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Return the reconstruction loss for one validation batch (logged as ``val_loss``)."""
        return self._shared_step(batch, "val_loss")

    def configure_optimizers(self):
        """Set up Adam with a ``ReduceLROnPlateau`` scheduler driven by ``val_loss``.

        Returns:
            The optimizer / lr-scheduler dictionary expected by Lightning.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=self.scheduler_factor,
            patience=self.scheduler_patience,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # Monitor the validation loss, the same metric the checkpoint
                # and early-stopping callbacks in this notebook watch.
                # "train_loss" is logged per step, so at epoch end it only
                # reflects the last training batch.
                "monitor": "val_loss",
            },
        }

In [None]:
class WikiArtStreamingDataset(IterableDataset):
    """Streaming Dataset wrapper for WikiArt that applies transforms.

    Args:
        hf_dataset: Iterable of records; each record must provide the image
            under the ``"image"`` key.
        transform: Optional callable applied to every image before it is
            yielded.
        image_mode: Colour mode every image is converted to before the
            transform runs. Objects that expose a ``mode`` attribute different
            from this value are converted with ``.convert(image_mode)``, so
            grayscale or RGBA images do not reach the model with the wrong
            number of channels. Pass ``None`` to disable the conversion.
    """

    def __init__(self, hf_dataset, transform=None, image_mode: Optional[str] = "RGB"):
        self.dataset = hf_dataset
        self.transform = transform
        self.image_mode = image_mode

    def __iter__(self):
        """Yield ``{"image": ...}`` dicts, one per record of the wrapped dataset."""
        for item in self.dataset:
            image = item["image"]

            # Only objects with a `mode` attribute (PIL-style images) are
            # converted; anything else is passed through untouched.
            if self.image_mode is not None:
                current_mode = getattr(image, "mode", self.image_mode)
                if current_mode != self.image_mode:
                    image = image.convert(self.image_mode)

            if self.transform is not None:
                image = self.transform(image)

            yield {"image": image}

In [None]:
class WikiArtDataModule(pl.LightningDataModule):
    """PyTorch Lightning DataModule serving train/val/test slices of WikiArt.

    ``prepare_data`` downloads the dataset into ``data_dir``; ``setup`` opens
    the ``train`` split in streaming mode and carves it into consecutive
    train / validation / test slices.

    Args:
        batch_size: Batch size used by all three dataloaders.
        num_workers: Number of dataloader worker processes.
        image_size: Images are resized to ``image_size x image_size``.
        total_samples: Number of samples to use across all three slices;
            ``None`` selects the built-in default set in ``setup``.
        val_split: Fraction of ``total_samples`` used for validation.
        test_split: Fraction of ``total_samples`` used for testing.
        data_dir: Directory used as the dataset cache.

    Raises:
        ValueError: If a split fraction is negative or the two fractions leave
            no room for training data.
    """

    def __init__(
        self,
        batch_size: int = 16,
        num_workers: int = 0,
        image_size: int = 256,
        total_samples: Optional[int] = None,
        val_split: float = 0.1,
        test_split: float = 0.1,
        data_dir: str = "/kaggle/working/data",
    ):
        super().__init__()

        # Fail early on impossible splits instead of silently producing an
        # empty (or negative-sized) training slice later in setup().
        if val_split < 0 or test_split < 0:
            raise ValueError(
                f"val_split and test_split must be non-negative, "
                f"got val_split={val_split}, test_split={test_split}"
            )
        if val_split + test_split >= 1:
            raise ValueError(
                f"val_split + test_split must be < 1 to leave training data, "
                f"got {val_split + test_split}"
            )

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.image_size = image_size
        self.total_samples = total_samples
        self.val_split = val_split
        self.test_split = test_split
        self.data_dir = Path(data_dir)

        # Resize to a fixed square and convert to a tensor.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        """Download dataset to local directory if not already downloaded."""
        self.data_dir.mkdir(parents=True, exist_ok=True)

        # NOTE(review): the directory name checked here must match the layout
        # the `datasets` library creates inside cache_dir - confirm, otherwise
        # the "already exists" branch is never taken.
        # NOTE(review): setup() opens the dataset with streaming=True, which
        # presumably reads from the Hub rather than from this local copy -
        # confirm whether the full download is needed at all.
        dataset_path = self.data_dir / "Artificio___WikiArt_Full"
        if not dataset_path.exists():
            print(f"Downloading WikiArt dataset to {self.data_dir}...")
            load_dataset(
                "Artificio/WikiArt_Full",
                cache_dir=str(self.data_dir),
                keep_in_memory=False,
            )
            print("Dataset download completed!")
        else:
            print(f"Dataset already exists at {dataset_path}")

    def setup(self, stage: Optional[str] = None):
        """Open the ``train`` split in streaming mode and build the three slices.

        The first ``train_size`` records become the training set, the next
        ``val_size`` the validation set and the following ``test_size`` the
        test set. ``stage`` is accepted for Lightning compatibility and is not
        used.
        """

        dataset = load_dataset(
            "Artificio/WikiArt_Full",
            cache_dir=str(self.data_dir),
            split="train",
            streaming=True
        )

        if self.total_samples is None:
            # Default sample budget - presumably the number of records in the
            # train split; verify against the dataset card.
            self.total_samples = 103_250

        self.test_size = int(self.total_samples * self.test_split)
        self.val_size = int(self.total_samples * self.val_split)
        self.train_size = self.total_samples - self.val_size - self.test_size

        # Consecutive, non-overlapping windows over the stream.
        train_hf = dataset.take(self.train_size)
        val_hf = dataset.skip(self.train_size).take(self.val_size)
        test_hf = dataset.skip(self.train_size + self.val_size).take(self.test_size)

        self.train_dataset = WikiArtStreamingDataset(train_hf, transform=self.transform)
        self.val_dataset = WikiArtStreamingDataset(val_hf, transform=self.transform)
        self.test_dataset = WikiArtStreamingDataset(test_hf, transform=self.transform)

    def _make_loader(self, dataset, split_name: str) -> DataLoader:
        """Build a DataLoader for ``dataset`` with the module-wide loader settings."""
        if dataset is None:
            raise RuntimeError(
                f"The {split_name} dataset is not initialised; call setup() first."
            )
        # No `shuffle` argument: the wrapped datasets are iterable-style.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the DataLoader over the training slice."""
        return self._make_loader(self.train_dataset, "train")

    def val_dataloader(self):
        """Return the DataLoader over the validation slice."""
        return self._make_loader(self.val_dataset, "validation")

    def test_dataloader(self):
        """Return the DataLoader over the test slice."""
        return self._make_loader(self.test_dataset, "test")

In [None]:
class ReconstructionLogger(Callback):
    """Logs reconstruction examples to Comet every N epochs.

    The first ``num_samples`` images of the first validation batch seen are
    cached once, so every logged grid shows the same images.

    Args:
        log_every_n_epochs: Log a grid whenever ``current_epoch + 1`` is a
            multiple of this value. Must be at least 1.
        num_samples: Number of images kept from the first validation batch.

    Raises:
        ValueError: If ``log_every_n_epochs`` is smaller than 1.
    """

    def __init__(self, log_every_n_epochs: int = 5, num_samples: int = 8):
        super().__init__()
        if log_every_n_epochs < 1:
            # Guards the modulo in on_validation_epoch_end against zero.
            raise ValueError(
                f"log_every_n_epochs must be >= 1, got {log_every_n_epochs}"
            )
        self.log_every_n_epochs = log_every_n_epochs
        self.num_samples = num_samples
        self.sample_batch = None

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx: int = 0
    ):
        """Cache a CPU copy of the first validation images the first time they are seen.

        ``dataloader_idx`` is optional so the hook works whether or not the
        trainer passes it.
        """
        if batch_idx == 0 and dataloader_idx == 0 and self.sample_batch is None:
            self.sample_batch = {
                "image": batch["image"][: self.num_samples].detach().cpu()
            }

    def on_validation_epoch_end(self, trainer, pl_module):
        """Reconstruct the cached images and log an original/reconstruction grid."""
        if self.sample_batch is None:
            return

        # Nothing to log to, or only the pre-training sanity check is running:
        # skip before doing any forward pass.
        if trainer.logger is None or trainer.sanity_checking:
            return

        if (trainer.current_epoch + 1) % self.log_every_n_epochs != 0:
            return

        images = self.sample_batch["image"].to(pl_module.device)

        # Remember the current mode and restore exactly that afterwards,
        # instead of unconditionally switching the module to training mode.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.inference_mode():
                reconstructed = pl_module(images)
                if isinstance(reconstructed, tuple):
                    reconstructed = reconstructed[0]
        finally:
            pl_module.train(was_training)

        images = images.cpu()
        reconstructed = reconstructed.cpu()

        # Interleave (original, reconstruction) pairs; with nrow=2 every grid
        # row then shows one pair side by side.
        comparison = torch.stack([images, reconstructed], dim=1)
        comparison = comparison.reshape(-1, *images.shape[1:])
        grid = vutils.make_grid(
            comparison,
            nrow=2,
            normalize=True,
            value_range=(0, 1),
            padding=2,
        )
        # The grid is channels-first; permute to height x width x channels
        # before handing it to the logger as a numpy array.
        trainer.logger.experiment.log_image(
            grid.permute(1, 2, 0).numpy(),
            name="reconstructions",
            step=trainer.current_epoch,
        )

In [None]:
def visualize_results(
    model,
    data_module,
    num_samples: int = 8,
    output_path: str = "reconstruction_results.png",
):
    """Plot validation images next to their reconstructions and save the figure.

    Takes the first batch of ``data_module.val_dataloader()``, reconstructs up
    to ``num_samples`` images with ``model`` and draws originals in the top row
    and reconstructions in the bottom row.

    Args:
        model: Module mapping an image batch to reconstructions (or to a tuple
            whose first element is the reconstruction). It is switched to
            evaluation mode and left in that mode.
        data_module: Object providing ``val_dataloader()``; batches must hold
            the images under ``"image"``.
        num_samples: Maximum number of images to show. Fewer are shown when
            the first validation batch is smaller.
        output_path: File the figure is written to.
    """
    model.eval()
    val_loader = data_module.val_dataloader()
    batch = next(iter(val_loader))
    images = batch["image"][:num_samples]
    device = next(model.parameters()).device
    images = images.to(device)

    with torch.inference_mode():
        reconstructed = model(images)
        if isinstance(reconstructed, tuple):
            reconstructed = reconstructed[0]

    images = images.cpu()
    reconstructed = reconstructed.cpu()

    # The batch may contain fewer images than requested; plot what is there.
    shown = images.shape[0]
    if shown == 0:
        raise ValueError("No images available to visualize.")

    # squeeze=False keeps `axes` two-dimensional even for a single column,
    # so axes[row, col] indexing is always valid.
    fig, axes = plt.subplots(2, shown, figsize=(20, 5), squeeze=False)

    rows = (("Original", images), ("Reconstructed", reconstructed))
    for row, (title, tensors) in enumerate(rows):
        for col in range(shown):
            # Channels-first tensor -> height x width x channels for imshow.
            picture = tensors[col].permute(1, 2, 0).numpy()
            axes[row, col].imshow(np.clip(picture, 0, 1))
            axes[row, col].axis("off")
        # Only the first column carries the row title.
        axes[row, 0].set_title(title, fontsize=12)

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    print(f"\nResults saved to '{output_path}'")
    plt.show()

In [None]:
# Seed the random number generators (dataloader workers included) before any
# model or data object is created.
pl.seed_everything(seed=CONFIG["experiment"]["seed"], workers=True)

In [None]:
# Build the data module from the "data" section of CONFIG. data_dir is passed
# explicitly so that editing CONFIG["data"]["data_dir"] actually takes effect
# instead of the constructor default being used silently.
data_module = WikiArtDataModule(
    batch_size=CONFIG["data"]["batch_size"],
    num_workers=CONFIG["data"]["num_workers"],
    image_size=CONFIG["data"]["image_size"],
    total_samples=CONFIG["data"]["total_samples"],
    val_split=CONFIG["data"]["val_split"],
    test_split=CONFIG["data"]["test_split"],
    data_dir=CONFIG["data"]["data_dir"],
)

In [None]:
# The keys of CONFIG["model"] (input_channels, latent_channels, learning_rate)
# are exactly the matching constructor arguments, so they are unpacked
# directly; the scheduler settings live in the "training" section.
model = AutoEncoder(
    **CONFIG["model"],
    scheduler_factor=CONFIG["training"]["lr_scheduler_factor"],
    scheduler_patience=CONFIG["training"]["lr_scheduler_patience"],
)

print(model)

In [None]:
# Experiment tracking: Comet logger configured from the COMET_* environment
# variables; the whole CONFIG dict is recorded as hyperparameters.
logger = CometLogger(
    api_key=os.environ["COMET_API_KEY"],
    project=os.environ["COMET_PROJECT_NAME"],
    workspace=os.environ["COMET_WORKSPACE"],
    name=CONFIG["experiment"]["name"],
)
logger.log_hyperparams(CONFIG)
logger.experiment.log_parameter("config_source", "notebook_dict")

# Keep the three checkpoints with the lowest validation loss plus the latest.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=3,
    save_last=True,
    dirpath="checkpoints",
    filename="autoencoder-{epoch:02d}-{val_loss:.4f}",
)

# Stop once the validation loss has not improved for `patience` checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=CONFIG["training"]["early_stopping_patience"],
)

# Periodically log original/reconstruction image grids to the experiment.
recon_logger = ReconstructionLogger(
    num_samples=CONFIG["experiment"]["visualization_samples"],
    log_every_n_epochs=CONFIG["experiment"]["recon_log_every_n_epochs"],
)

In [None]:
# Single-device trainer wired to the callbacks and logger created above;
# epoch budget and gradient clipping come from CONFIG["training"].
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    deterministic=True,
    max_epochs=CONFIG["training"]["max_epochs"],
    gradient_clip_val=CONFIG["training"]["gradient_clip_val"],
    logger=logger,
    log_every_n_steps=10,
    callbacks=[checkpoint_callback, early_stop_callback, recon_logger],
)

In [None]:
# Run training, then report which checkpoint scored best on val_loss.
trainer.fit(model, datamodule=data_module)
print("Best model path:", checkpoint_callback.best_model_path)

In [None]:
# Render the final comparison figure, upload the saved file to the experiment
# and close the Comet run. The output path is spelled out so it visibly
# matches the file that is logged on the next line.
visualize_results(
    model,
    data_module,
    num_samples=CONFIG["experiment"]["visualization_samples"],
    output_path="reconstruction_results.png",
)
logger.experiment.log_image("reconstruction_results.png", name="Final Reconstructions")
logger.experiment.end()