# Wielomodalny Autoenkoder

In [35]:
from typing import Dict, List, Type

import pandas as pd
import torch
from sklearn.decomposition import PCA
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl
from abc import abstractmethod
import os

## Dataset

In [31]:
class MyDataset(Dataset):
    """Map-style dataset that serves the rows of a pandas DataFrame.

    Every item is a single row, returned as a plain ``dict`` keyed by column
    label. Values are handed back exactly as stored in the frame.
    NOTE(review): the return annotation promises tensors; presumably the
    conversion happens later during DataLoader collation -- confirm.
    """

    def __init__(self, df: pd.DataFrame):
        super().__init__()
        self._df = df

    def __len__(self):
        # One sample per DataFrame row.
        return len(self._df)

    def __getitem__(
        self,
        index: int,
    ) -> Dict[str, torch.Tensor]:
        # Positional lookup, so the frame's index labels do not matter.
        row = self._df.iloc[index]
        return {column: row[column] for column in row.keys()}


class DataModule(pl.LightningDataModule):
    """Lightning data module backed by two pickled DataFrames.

    Three splits are exposed: ``"train"``, ``"test"`` (used for validation)
    and ``"all"`` (train followed by test).

    Args:
        train_path: Path of the pickled training DataFrame.
        test_path: Path of the pickled test DataFrame.
        batch_size: Batch size used by every data loader.
        seed: Accepted for interface compatibility; not read by this class.
    """

    def __init__(
        self,
        train_path: str,
        test_path: str,
        batch_size: int = 64,
        seed: int = 42,
    ):
        super().__init__()

        # NOTE: pd.read_pickle unpickles arbitrary objects -- only point
        # these paths at trusted files.
        train_df, test_df = (
            pd.read_pickle(path) for path in (train_path, test_path)
        )

        splits = {"train": train_df, "test": test_df}
        splits["all"] = pd.concat([train_df, test_df])

        self.df = splits
        self.batch_size = batch_size

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled loader over the training split."""
        return self._dataloader("train")

    def val_dataloader(self) -> DataLoader:
        """Return the loader over the test split, which doubles as validation."""
        return self._dataloader("test")

    def all_dataloader(self) -> DataLoader:
        """Return the unshuffled loader over train + test."""
        return self._dataloader("all")

    def _dataloader(self, split: str) -> DataLoader:
        """Build a DataLoader for ``split``; only ``"train"`` is shuffled."""
        dataset = MyDataset(self.df[split])
        is_train_split = split == "train"
        # Worker count is taken from the NUM_WORKERS environment variable
        # and falls back to 0 when it is not set.
        worker_count = int(os.environ.get("NUM_WORKERS", 0))
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=is_train_split,
            num_workers=worker_count,
        )


## Model

In [14]:
class MultimodalEncoder(nn.Module):
    """One independent MLP encoder per modality.

    Each encoder is ``Linear -> ReLU`` repeated for every entry of
    ``hidden_dims``, followed by a final ``Linear`` to ``out_dim`` and
    ``last_activation``.

    Args:
        modality_names: Modalities to encode; also fixes the output order.
        in_dims: Input feature size for each modality name.
        hidden_dims: Widths of the hidden layers (at least one entry).
        out_dim: Size of every per-modality embedding.
        last_activation: Module class instantiated (without arguments) as
            the final layer of each encoder.
    """

    def __init__(
        self,
        modality_names: List[str],
        in_dims: Dict[str, int],
        hidden_dims: List[int],
        out_dim: int,
        last_activation: Type[nn.Module],
    ):
        super().__init__()

        self.modality_names = modality_names

        self.modality_to_encoder_map = nn.ModuleDict()
        for modality_name in modality_names:
            self.modality_to_encoder_map[modality_name] = nn.Sequential(
                nn.Linear(in_dims[modality_name], hidden_dims[0]),
                nn.ReLU(inplace=True),
                *[
                    layer
                    for idx in range(len(hidden_dims) - 1)
                    for layer in (
                        nn.Linear(hidden_dims[idx], hidden_dims[idx + 1]),
                        nn.ReLU(inplace=True),
                    )
                ],
                nn.Linear(hidden_dims[-1], out_dim),
                last_activation(),
            )

    def forward(self, x: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
        """Encode every modality.

        Args:
            x: Mapping that contains one entry per name in
                ``modality_names``; other keys are ignored.

        Returns:
            The embeddings as a list ordered like ``modality_names``.
        """
        return [
            self.modality_to_encoder_map[modality_name](x[modality_name])
            for modality_name in self.modality_names
        ]

    @classmethod
    def from_hparams(cls, hparams):
        """Build an encoder from a hyper-parameter mapping.

        Reads the keys ``"modality_names"``, ``"data_dims"``,
        ``"hidden_dims"`` and ``"emb_dim"``; the last activation is always
        ``nn.Tanh``.

        This is a classmethod so that a subclass selected through
        ``hparams["encoder_cls"]`` constructs itself rather than the base
        class.
        """
        return cls(
            modality_names=hparams["modality_names"],
            in_dims=hparams["data_dims"],
            hidden_dims=hparams["hidden_dims"],
            out_dim=hparams["emb_dim"],
            last_activation=nn.Tanh,
        )


In [15]:
class AvgFusion(nn.Module):
    """Parameter-free fusion: element-wise mean of the modality embeddings."""

    def forward(self, h: List[torch.Tensor]) -> torch.Tensor:
        # Accumulate from 0 so the result matches the built-in sum().
        total = 0
        for embedding in h:
            total = total + embedding
        return total / len(h)
    
    
class MLPFusion(nn.Module):
    """Fuse modality embeddings by concatenation followed by an MLP.

    Args:
        modality_dim: Size of each individual modality embedding.
        num_modalities: Number of embeddings that get concatenated.
        hidden_dims: Widths of the hidden layers (at least one entry).
        out_dim: Size of the fused embedding.
        last_activation: Module class instantiated (without arguments) as
            the final layer.
    """

    def __init__(
        self,
        modality_dim: int,
        num_modalities: int,
        hidden_dims: List[int],
        out_dim: int,
        last_activation: Type[nn.Module],
    ):
        super().__init__()

        # Widths from the concatenated input through every hidden layer.
        widths = [modality_dim * num_modalities, *hidden_dims]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU(inplace=True))
        stack.append(nn.Linear(hidden_dims[-1], out_dim))
        stack.append(last_activation())

        self.mlp = nn.Sequential(*stack)

    def forward(self, h: List[torch.Tensor]) -> torch.Tensor:
        # Embeddings are joined along dim 1 before entering the MLP.
        return self.mlp(torch.cat(h, dim=1))
        

In [22]:
class MultimodalDecoder(nn.Module):
    """One independent MLP decoder per modality, all fed the same latent.

    Args:
        modality_names: Modalities to reconstruct.
        in_dim: Size of the latent vector given to every decoder.
        hidden_dims: Widths of the hidden layers (at least one entry).
        out_dims: Output feature size for each modality name.
        last_activation: Module class instantiated (without arguments) as
            the final layer of each decoder.
    """

    def __init__(
        self,
        modality_names: List[str],
        in_dim: int,
        hidden_dims: List[int],
        out_dims: Dict[str, int],
        last_activation: Type[nn.Module],
    ):
        super().__init__()

        self.modality_names = modality_names

        # The attribute keeps its "encoder" name although it holds decoders:
        # the name is part of the state-dict keys.
        self.modality_to_encoder_map = nn.ModuleDict()
        for modality_name in modality_names:
            self.modality_to_encoder_map[modality_name] = self._make_head(
                in_dim, hidden_dims, out_dims[modality_name], last_activation
            )

    @staticmethod
    def _make_head(in_dim, hidden_dims, out_dim, last_activation):
        """Assemble Linear/ReLU pairs, a final Linear and the last activation."""
        widths = [in_dim, *hidden_dims]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True)]
        modules.append(nn.Linear(hidden_dims[-1], out_dim))
        modules.append(last_activation())
        return nn.Sequential(*modules)

    def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Decode ``z`` once per modality and return the results by name."""
        reconstructions = {}
        for modality_name in self.modality_names:
            head = self.modality_to_encoder_map[modality_name]
            reconstructions[modality_name] = head(z)
        return reconstructions

In [26]:
class BaseAE(pl.LightningModule):
    """Common training scaffold for autoencoder LightningModules.

    Subclasses implement ``_common_step``, which turns a batch into a scalar
    loss. This base class wires that loss into the training and validation
    hooks, logs the per-epoch mean, and configures an AdamW optimizer.

    Args:
        hparams: Hyper-parameter mapping; stored through
            ``save_hyperparameters`` and read for ``"lr"`` and
            ``"weight_decay"``.
        encoder: Encoder module, stored as ``self.encoder``.
        decoder: Decoder module, stored as ``self.decoder``.

    NOTE(review): the ``*_epoch_end`` hooks used here belong to the
    Lightning 1.x API and were removed in Lightning 2.0.
    """

    def __init__(self, hparams, encoder: nn.Module, decoder: nn.Module):
        super().__init__()

        self.save_hyperparameters(hparams)

        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx: int):
        """Return the training loss for one batch."""
        return {"loss": self._common_step(batch)}

    def training_epoch_end(self, outputs):
        """Log the mean training loss of the finished epoch."""
        avg_loss = self._summarize_outputs(outputs)

        # The epoch number is logged under the key "step" -- presumably so
        # the logger plots epoch-level metrics against epochs.
        self.log("step", self.trainer.current_epoch)
        self.log("train/loss", avg_loss, on_epoch=True, on_step=False)

    def validation_step(self, batch, batch_idx: int):
        """Return the validation loss for one batch."""
        return {"loss": self._common_step(batch)}

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss of the finished epoch."""
        avg_loss = self._summarize_outputs(outputs)

        self.log("step", self.trainer.current_epoch)
        self.log("val/loss", avg_loss, on_epoch=True, on_step=False)

    @abstractmethod
    def _common_step(self, batch) -> torch.Tensor:
        """Compute the loss for ``batch``; must be overridden by subclasses."""
        # @abstractmethod is not enforced here because the class does not
        # use ABCMeta, so fail loudly instead of silently returning None.
        raise NotImplementedError

    @staticmethod
    def _summarize_outputs(outputs):
        """Average the ``"loss"`` entries of the per-step outputs.

        Returns a 0-dim tensor; NaN when ``outputs`` is empty.
        """
        # Pure torch: the previous implementation called ``np.mean`` although
        # numpy is never imported in this file (NameError at epoch end).
        losses = [out["loss"].detach() for out in outputs]
        if not losses:
            return torch.tensor(float("nan"))

        return torch.stack(losses).mean()

    def configure_optimizers(self):
        """Create AdamW over all parameters with the configured lr / decay."""
        return torch.optim.AdamW(
            params=self.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )

In [27]:
from torch.nn.functional import mse_loss


class MultimodalAE(BaseAE):
    """Autoencoder that encodes each modality, fuses them, and decodes all.

    Hyper-parameters read from ``hparams``:
        ``encoder_cls``: class exposing ``from_hparams(hparams)``.
        ``modality_names``, ``data_dims``, ``hidden_dims``, ``emb_dim``:
            architecture description; the decoder uses ``hidden_dims``
            reversed and an identity output activation.
        ``fusion``: ``"Avg"`` or ``"MLP"``.

    Raises:
        ValueError: If ``hparams["fusion"]`` is neither ``"Avg"`` nor
            ``"MLP"``.
    """

    def __init__(self, hparams):
        encoder_cls = hparams["encoder_cls"]

        super().__init__(
            hparams=hparams,
            encoder=encoder_cls.from_hparams(hparams),
            decoder=MultimodalDecoder(
                modality_names=hparams["modality_names"],
                in_dim=hparams["emb_dim"],
                hidden_dims=hparams["hidden_dims"][::-1],
                out_dims=hparams["data_dims"],
                last_activation=nn.Identity,
            ),
        )

        if hparams["fusion"] == "Avg":
            self.fusion = AvgFusion()
        elif hparams["fusion"] == "MLP":
            self.fusion = MLPFusion(
                modality_dim=hparams["emb_dim"],
                num_modalities=len(hparams["modality_names"]),
                hidden_dims=[hparams["emb_dim"], hparams["emb_dim"]],
                out_dim=hparams["emb_dim"],
                last_activation=nn.Tanh,
            )
        else:
            raise ValueError(f"Unknown fusion module: \"{hparams['fusion']}\"")

    def forward(self, batch) -> torch.Tensor:
        """Return the fused embedding of ``batch``."""
        encoded = self.encoder(batch)
        return self.fusion(encoded)

    def _common_step(self, batch) -> torch.Tensor:
        """Return the MSE reconstruction loss averaged over modalities."""
        z = self.forward(batch)
        x_rec = self.decoder(z)

        total = 0
        for modality_name, reconstruction in x_rec.items():
            # mse_loss(input, target): reconstruction first, data second.
            total = total + mse_loss(reconstruction, batch[modality_name])

        # Average over the reconstructed modalities. ``batch`` carries every
        # DataFrame column, so ``len(batch)`` would mis-scale the loss
        # whenever the data holds columns that are not modalities.
        return total / len(x_rec)

## Training loop

In [20]:
def train_model(
    model_cls: Type[pl.LightningModule],
    hparams,
    datamodule: DataModule,
    accelerator="gpu",
):
    """Train ``model_cls(hparams)`` on ``datamodule`` with Lightning.

    The global seed is fixed to 42. The checkpoint callback monitors
    ``val/loss`` (minimised) and writes to
    ``./data/checkpoints/<name>/`` with the file name ``model``;
    TensorBoard logs go under ``./data/logs/<name>``.

    Args:
        model_cls: LightningModule class taking ``hparams`` as its only
            constructor argument.
        hparams: Mapping that provides at least ``"name"`` and
            ``"num_epochs"`` in addition to what the model reads.
        datamodule: Source of the train / validation loaders.
        accelerator: Forwarded to ``pl.Trainer``.
    """
    pl.seed_everything(42)

    model = model_cls(hparams)
    run_name = hparams["name"]

    checkpoint_callback = ModelCheckpoint(
        dirpath=f"./data/checkpoints/{run_name}/",
        filename="model",
        monitor="val/loss",
        mode="min",
        verbose=True,
    )
    tensorboard_logger = TensorBoardLogger(
        save_dir="./data/logs",
        name=run_name,
        default_hp_metric=False,
    )

    trainer = pl.Trainer(
        logger=tensorboard_logger,
        callbacks=[checkpoint_callback],
        num_sanity_val_steps=0,
        log_every_n_steps=1,
        max_epochs=hparams["num_epochs"],
        accelerator=accelerator,
    )
    trainer.fit(model=model, datamodule=datamodule)


@torch.no_grad()
def extract_embeddings(
    model_cls: Type[pl.LightningModule],
    name: str,
    datamodule: DataModule,
):
    """Embed every sample of ``datamodule`` with the checkpointed model.

    Loads ``./data/checkpoints/<name>/model.ckpt`` through
    ``model_cls.load_from_checkpoint``, switches the model to eval mode and
    runs it over ``datamodule.all_dataloader()`` with gradients disabled.

    Returns:
        The per-batch model outputs concatenated along dim 0.
    """
    checkpoint_file = f"./data/checkpoints/{name}/model.ckpt"
    best_model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint_file)
    best_model.eval()

    embeddings = [best_model(batch) for batch in datamodule.all_dataloader()]

    return torch.cat(embeddings, dim=0)

In [17]:
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [18]:
# Hyper-parameters shared by every experiment below; each run adds its own
# "name" and "fusion" entries on top of these.
default_hparams = dict(
    # Architecture
    encoder_cls=MultimodalEncoder,
    modality_names=["img_emb", "text_emb"],
    data_dims={"img_emb": 2048, "text_emb": 384},
    # Data loading / schedule
    batch_size=64,
    num_epochs=30,
    # Layer sizes
    hidden_dims=[256, 256, 256],
    emb_dim=128,
    # Optimizer (AdamW)
    lr=1e-3,
    weight_decay=5e-4,
)

In [33]:
# Data module over the pre-processed train / test pickles in data/cub; the
# batch size follows the shared hyper-parameters.
datamodule = DataModule(
    batch_size=default_hparams["batch_size"],
    train_path="data/cub/preprocessed_train.pkl",
    test_path="data/cub/preprocessed_test.pkl",
)

In [36]:
# Experiment 1: autoencoder with parameter-free average fusion.
train_model(
    MultimodalAE,
    {
        "name": "ImageTextAvgAE",
        "fusion": "Avg",
        **default_hparams,
    },
    datamodule=datamodule,
)

Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: ./data/logs/ImageTextAvgAE
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params
----------------------------------------------
0 | encoder | MultimodalEncoder | 952 K 
1 | decoder | MultimodalDecoder | 954 K 
2 | fusion  | AvgFusion         | 0     
----------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.625     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# Experiment 2: autoencoder with learned MLP fusion.
train_model(
    MultimodalAE,
    {
        "name": "ImageTextMLPAE",
        "fusion": "MLP",
        **default_hparams,
    },
    datamodule=datamodule,
)