In [1]:
import datetime
import itertools
from pathlib import Path
from typing import Tuple

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import WandbLogger
from moviepy.editor import ImageSequenceClip
from torch.utils.data import DataLoader, TensorDataset

import wandb


class PlotEmbeddingsCallback(Callback):
    """Plot the module's embedding table each epoch and stitch the plots into a video.

    After every training epoch the first two columns of
    ``pl_module.embedding.weight`` are drawn as a scatter plot (one point per
    row, annotated with the row index) and saved as a PNG in ``log_dir``.
    When training ends, all saved PNGs are combined into
    ``embedding_evolution.mp4``. Images and the video are also sent to
    Weights & Biases when the trainer's logger is a ``WandbLogger``.

    Args:
        log_dir: Directory that receives the PNG frames and the MP4. It is
            created (including parents) if it does not exist.
    """

    def __init__(self, log_dir: Path):
        super().__init__()
        self.log_dir = log_dir
        self.log_dir.mkdir(parents=True, exist_ok=True)

    def on_train_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule):
        """Save (and optionally log) a scatter plot of the current embeddings.

        The plot uses columns 0 and 1 of the embedding matrix, so the
        embedding must have at least two dimensions.
        """
        embeddings = pl_module.embedding.weight.detach().cpu().numpy()

        # Zero-pad the epoch number so lexicographic filename order matches
        # epoch order when the frames are sorted in ``on_train_end``.
        zero_pad = len(str(trainer.max_epochs))
        filename = str(
            self.log_dir / f"embeddings_epoch_{trainer.current_epoch:0{zero_pad}}.png"
        )

        # Use an explicit figure and always close it, so a failure while
        # drawing or saving cannot leak figures across epochs.
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            ax.scatter(embeddings[:, 0], embeddings[:, 1])

            for i, point in enumerate(embeddings):
                ax.annotate(str(i), (point[0], point[1]))

            ax.set_title(f"Embedding Space at Epoch {trainer.current_epoch}")
            ax.set_xlabel("Dimension 1")
            ax.set_ylabel("Dimension 2")

            fig.savefig(filename)
        finally:
            plt.close(fig)

        if isinstance(trainer.logger, WandbLogger):
            trainer.logger.log_image(key="embeddings", images=[wandb.Image(filename)])

    def on_train_end(self, trainer: L.Trainer, pl_module: L.LightningModule):
        """Combine the per-epoch PNGs into an MP4 and optionally log it.

        Does nothing when no frame was written (for example, when training
        stopped before the first epoch finished), because a video cannot be
        built from an empty frame list.
        """
        image_files = sorted(self.log_dir.glob("embeddings_epoch_*.png"))
        if not image_files:
            return

        video_path = str(self.log_dir / "embedding_evolution.mp4")

        clip = ImageSequenceClip([str(img) for img in image_files], fps=5)
        clip.write_videofile(
            video_path,
            fps=5,
            codec="libx264",
            audio=False,
        )

        # Log the video to wandb as well.
        if isinstance(trainer.logger, WandbLogger):
            trainer.logger.experiment.log(
                {
                    "embedding_evolution": wandb.Video(
                        video_path,
                        fps=5,
                        format="mp4",
                    )
                }
            )


class TupleDataModule(L.LightningDataModule):
    """Data module over every tuple in ``range(range_size) ** tuple_size``.

    The full Cartesian product is enumerated, shuffled once, and split into
    a training part and a validation part.

    Args:
        tuple_size: Number of elements in each tuple.
        range_size: Each element is drawn from ``range(range_size)``.
        batch_size: Batch size used by every dataloader.
        train_fraction: Fraction of all tuples assigned to the training set;
            the remainder becomes the validation set. Defaults to ``0.8``.
        num_workers: Worker processes per dataloader. Defaults to ``4``.
    """

    def __init__(
        self,
        tuple_size: int,
        range_size: int,
        batch_size: int,
        train_fraction: float = 0.8,
        num_workers: int = 4,
    ):
        super().__init__()
        self.tuple_size = tuple_size
        self.range_size = range_size
        self.batch_size = batch_size
        self.train_fraction = train_fraction
        self.num_workers = num_workers

        self.train_data = None
        self.val_data = None
        # Never populated by this class; ``test_dataloader`` only works if a
        # caller assigns a dataset here.
        self.test_data = None

    def setup(self, stage: str | None = None):
        """Build the train/validation split for the ``fit``/``validate`` stages.

        The split is created only once: repeated calls keep the existing
        datasets, so a later ``fit`` or ``validate`` run never sees a
        reshuffled split in which former training tuples land in validation.
        """
        if stage not in ("fit", "validate", None):
            return
        if self.train_data is not None and self.val_data is not None:
            return

        all_combinations = torch.tensor(
            list(itertools.product(range(self.range_size), repeat=self.tuple_size))
        )
        indices = torch.randperm(len(all_combinations))

        train_size = int(len(all_combinations) * self.train_fraction)
        train_indices = indices[:train_size]
        val_indices = indices[train_size:]

        self.train_data = TensorDataset(all_combinations[train_indices])
        self.val_data = TensorDataset(all_combinations[val_indices])

    def train_dataloader(self) -> DataLoader[Tuple[torch.Tensor, ...]]:
        """Return a shuffling dataloader over the training split."""
        assert self.train_data is not None, "call setup() before train_dataloader()"
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader[Tuple[torch.Tensor, ...]]:
        """Return a non-shuffling dataloader over the validation split."""
        assert self.val_data is not None, "call setup() before val_dataloader()"
        return DataLoader(
            self.val_data, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def test_dataloader(self) -> DataLoader[Tuple[torch.Tensor, ...]]:
        """Return a dataloader over ``test_data``.

        ``setup`` does not create a test split, so this fails the assertion
        unless ``test_data`` has been assigned externally.
        """
        assert self.test_data is not None, "test_data has not been provided"
        return DataLoader(
            self.test_data, batch_size=self.batch_size, num_workers=self.num_workers
        )


class TupleAutoencoder(L.LightningModule):
    """Autoencoder that reconstructs integer tuples through a shared embedding.

    Each tuple element is looked up in one shared embedding table, the
    embeddings are concatenated, and a single linear layer predicts a
    distribution over ``range_size`` values for every tuple position.

    Args:
        tuple_length: Number of elements in each input tuple.
        range_size: Number of distinct values an element can take (also the
            number of rows in the embedding table).
        embedding_dim: Size of each element's embedding vector.
    """

    def __init__(self, tuple_length: int, range_size: int, embedding_dim: int):
        super().__init__()
        self.tuple_length = tuple_length
        self.range_size = range_size

        self.embedding = nn.Embedding(range_size, embedding_dim)
        self.linear = nn.Linear(tuple_length * embedding_dim, tuple_length * range_size)

    def _logits(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised scores of shape (batch, tuple_length, range_size)."""
        batch_size = x.size(0)

        embedded = self.embedding(x)
        embedded_flat = embedded.view(batch_size, -1)
        output = self.linear(embedded_flat)
        return output.view(batch_size, self.tuple_length, self.range_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-position probabilities, shape (batch, tuple_length, range_size)."""
        return F.softmax(self._logits(x), dim=-1)

    def _compute_loss(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return the reconstruction loss: summed cross-entropy divided by batch size.

        ``batch`` is the one-element sequence produced by a ``TensorDataset``
        loader; its first item holds the integer tuples, which serve as both
        input and target.
        """
        x = batch[0]
        logits = self._logits(x)

        # Cross-entropy on raw logits equals nll_loss(log(softmax(...))) but is
        # computed via log-softmax, so a probability that underflows to zero
        # cannot turn the loss into inf/NaN.
        loss = F.cross_entropy(
            logits.reshape(-1, self.range_size), x.reshape(-1), reduction="sum"
        )
        return loss / x.size(0)

    def training_step(
        self, batch: Tuple[torch.Tensor, ...], batch_idx: int
    ) -> torch.Tensor:
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, ...], batch_idx: int):
        """Compute the validation loss and log it as ``val_loss``."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters with lr=0.001."""
        return torch.optim.Adam(self.parameters(), lr=0.001)


# Enable autograd anomaly detection for the whole run (debugging aid).
torch.autograd.set_detect_anomaly(True)

# Experiment configuration.
TUPLE_SIZE = 3
RANGE_SIZE = 50
EMBEDDING_DIM = 2
BATCH_SIZE = 2048
NUM_EPOCHS = 20

# Data and model.
datamodule = TupleDataModule(TUPLE_SIZE, RANGE_SIZE, BATCH_SIZE)
model = TupleAutoencoder(
    tuple_length=TUPLE_SIZE,
    range_size=RANGE_SIZE,
    embedding_dim=EMBEDDING_DIM,
)

# The timestamp names both the wandb run and the local output directory.
run_name = f"{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}"
wandb_logger = WandbLogger(name=run_name, project="tuple-autoencoder")

log_dir = Path("logs", run_name)
plot_callback = PlotEmbeddingsCallback(log_dir=log_dir)

trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=NUM_EPOCHS,
    logger=wandb_logger,
    callbacks=[plot_callback],
)

trainer.fit(model, datamodule=datamodule)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mgizmrkv[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type      | Params | Mode 
------------------------------------------------
0 | embedding | Embedding | 100    | train
1 | linear    | Linear    | 1.1 K  | train
------------------------------------------------
1.2 K     Trainable params
0         Non-trainable params
1.2 K     Total params
0.005     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


Epoch 19: 100%|██████████| 49/49 [00:00<00:00, 51.09it/s, v_num=vol3]      

`Trainer.fit` stopped: `max_epochs=20` reached.


Moviepy - Building video logs/2024-10-31_23-54-18/embedding_evolution.mp4.
Moviepy - Writing video logs/2024-10-31_23-54-18/embedding_evolution.mp4

Epoch 19: 100%|██████████| 49/49 [00:01<00:00, 36.72it/s, v_num=vol3]



Moviepy - Done !                                                     
Moviepy - video ready logs/2024-10-31_23-54-18/embedding_evolution.mp4
Epoch 19: 100%|██████████| 49/49 [00:01<00:00, 25.11it/s, v_num=vol3]



Epoch 19: 100%|██████████| 49/49 [00:01<00:00, 25.07it/s, v_num=vol3]
