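# Transformer Autoencoder

Trains a transformer autoencoder on MNIST. Each image is flattened into a sequence of pixel tokens. Attention blocks and 2x linear downsampling compress it into a short sequence of integer (quantized) codes. The codes are decoded back to pixels and trained with a binary cross-entropy reconstruction loss.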

In [1]:
#@markdown Dependencies.

%%shell

# Installs the third-party packages used below. Versions are NOT pinned, so
# results may drift as PyPI releases change.
# NOTE(review): pin versions (e.g. the datasets 2.17.0 reported in the next
# cell's log) for reproducibility.
# torchvision is imported later but not installed here; presumably it is
# preinstalled in the Colab runtime. Confirm.
pip -q install \
    transformers \
    tokenizers \
    datasets \
    accelerate \
    evaluate \
    bitsandbytes \
    wandb \
    einops

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/536.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━[0m [32m501.8/536.6 kB[0m [31m14.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m536.6/536.6 kB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/280.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m280.0/280.0 kB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.0/105.0 MB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m86.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m



In [2]:
# NOTE(review): pip's own output for this cell reports that the installed
# datasets 2.17.0 requires pyarrow>=12.0.0, so this downgrade creates a
# declared incompatibility. Confirm the pin is still needed; otherwise drop it.
!pip install pyarrow==11.0.0

Collecting pyarrow==11.0.0
  Downloading pyarrow-11.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.9/34.9 MB[0m [31m32.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow
  Attempting uninstall: pyarrow
    Found existing installation: pyarrow 15.0.0
    Uninstalling pyarrow-15.0.0:
      Successfully uninstalled pyarrow-15.0.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 2.17.0 requires pyarrow>=12.0.0, but you have pyarrow 11.0.0 which is incompatible.[0m[31m
[0mSuccessfully installed pyarrow-11.0.0


In [133]:
#@markdown Model.

from typing import List, Tuple
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat, reduce


class LinearUpsample(nn.Module):
    """Double the sequence length with a learned linear projection.

    Every token of width E is projected to width 2E. The first E features of
    all T tokens form the first T output tokens. The last E features of all
    tokens form the following T output tokens. The output length is 2T.

    Example
    -------
    >>> module = LinearUpsample(embedding_dimension=256)
    >>> x = torch.randn((1, 5, 256))
    >>> x = module(x)  # Shape: (1, 10, 256).
    """

    def __init__(self, *, embedding_dimension: int) -> None:
        """Build the projection.

        Parameters
        ----------
        embedding_dimension : int
            Width E of each token, both before and after upsampling.
        """

        super().__init__()

        width = embedding_dimension
        # E -> 2E, bias-free: the two halves become two separate tokens.
        self.linear = nn.Linear(width, 2 * width, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape ``(..., T, E)``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(..., 2T, E)``.
        """

        projected = self.linear(x)
        return rearrange(projected, '... t (n e) -> ... (n t) e', n=2)


class LinearDownsample(nn.Module):
    """Halve the sequence length with a learned linear projection.

    A sequence of length 2T is split into its first and second halves. Token
    t of the first half is concatenated with token t of the second half
    (width 2E) and projected back to width E. The output length is T. This
    is the same token pairing that ``LinearUpsample`` produces.

    Example
    -------
    >>> module = LinearDownsample(embedding_dimension=256)
    >>> x = torch.randn((1, 10, 256))
    >>> x = module(x)  # Shape: (1, 5, 256).
    """

    def __init__(self, *, embedding_dimension: int) -> None:
        """Build the projection.

        Parameters
        ----------
        embedding_dimension : int
            Width E of each token, both before and after downsampling.
        """

        super().__init__()

        width = embedding_dimension
        # 2E -> E, bias-free: merges each pair of tokens into one.
        self.linear = nn.Linear(2 * width, width, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape ``(..., 2T, E)``; the token count must be even.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(..., T, E)``.
        """

        merged = rearrange(x, '... (n t) e -> ... t (n e)', n=2)
        return self.linear(merged)


class RoPE(nn.Module):
    """Rotary positional embedding (RoPE).

    Rotary positional embedding (Su et al., 2023) rotates keys and queries by
    their absolute position such that their dot product depends only on their
    content and *relative position*. Generalized to arbitrary dimensions, RoPE
    divides a D-dimensional space into D//2 two-dimensional subspaces. Here
    the subspaces are adjacent feature pairs (0, 1), (2, 3), ... Subspace i
    at position p is rotated by the angle ``p * base ** (-2i / D)``.

    The position axis is the second-to-last axis of the input. The output has
    the same shape and dtype as the input.

    Example
    -------
    >>> module = RoPE(embedding_dimension=256, base=10_000)
    >>> q = torch.randn((1, 10, 256))
    >>> k = torch.randn((1, 10, 256))
    >>> alignment = torch.einsum('bte,bse->bts', module(q), module(k))
    """

    def __init__(self, *, embedding_dimension: int, base: int) -> None:
        """Initialize the module.

        Parameters
        ----------
        embedding_dimension : int
            The embedding dimension (size of the last input axis).
        base : int
            The base to use for absolute positional encodings.

        Raises
        ------
        ValueError
            If ``embedding_dimension`` is not a positive even integer (the
            features are rotated in pairs).
        """

        super().__init__()

        if embedding_dimension < 2 or embedding_dimension % 2 != 0:
            raise ValueError(
                'embedding_dimension must be a positive even integer, '
                f'got {embedding_dimension}'
            )

        self.embedding_dimension = embedding_dimension
        self.base = base

        # Per-subspace angular frequencies base ** (-2i / D), i in [0, D/2).
        exponent = torch.arange(
            start=0,
            end=embedding_dimension,
            step=2,
            dtype=torch.float,
        ) / embedding_dimension

        theta = 1. / torch.pow(base, exponent)

        # Non-persistent buffer: it moves with `module.to(device)` but is not
        # stored in (or expected by) the state dict.
        self.register_buffer('theta', theta, persistent=False)

    def absolute_positional_encoding(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the rotation angle of every (position, feature) entry.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor; only its device and the size of its
            second-to-last axis (the positions) are used.

        Returns
        -------
        encoding : torch.Tensor
            Float32 tensor of shape ``(positions, embedding_dimension)``. Each
            subspace's angle is repeated for both features of its pair.
        """

        # Always compute angles in float32, even if the module was cast to
        # half precision, so large positions stay accurate.
        theta = self.theta.to(device=x.device, dtype=torch.float)
        positions = torch.arange(x.size(-2), dtype=torch.float, device=x.device)

        encoding = torch.outer(positions, theta)

        # (e) -> (e n) with n=2: duplicate each angle for its feature pair.
        return encoding.repeat_interleave(2, dim=-1)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Map each adjacent feature pair (x1, x2) to (-x2, x1)."""

        x1, x2 = x.unflatten(-1, (-1, 2)).unbind(dim=-1)
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate every feature pair of ``x`` by its position-dependent angle.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape ``(..., positions, embedding_dimension)``.

        Returns
        -------
        torch.Tensor
            Rotated tensor with the same shape and dtype as ``x``.
        """

        encoding = self.absolute_positional_encoding(x)
        # Cast the (float32) trig tables to x's dtype so half/bfloat16 inputs
        # are not silently promoted to float32.
        cos = encoding.cos().to(x.dtype)
        sin = encoding.sin().to(x.dtype)

        return x * cos + self.rotate_half(x) * sin


def linear_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    scale: float = 0.25,
) -> torch.Tensor:
    """Efficient (linear-complexity) attention.

    This follows "Efficient Attention: Attention with Linear Complexities"
    (Shen et al.). Queries are softmax-normalized over their feature axis and
    keys over the sequence axis. The keys are first aggregated with the
    values into a small ``(d, c)`` context, and each query then reads from
    that context. No ``(n, n)`` attention matrix is formed.

    Parameters
    ----------
    q, k : torch.Tensor
        Queries and keys of shape ``(b, h, n, d)``.
    v : torch.Tensor
        Values of shape ``(b, h, n, c)``.
    scale : float, optional
        Temperature applied to ``q`` and ``k`` before their softmaxes
        (default 0.25, the constant previously hard-coded here).

    Returns
    -------
    torch.Tensor
        Attention output of shape ``(b, h, n, c)``. Each output row is a
        convex combination of the value rows.
    """

    q = F.softmax(q * scale, dim=-1)   # normalize each query over features
    k = F.softmax(k * scale, dim=-2)   # normalize each key feature over tokens

    # Context (b, h, d, c): keys weighted-sum the values per key feature.
    context = torch.einsum('bhnk,bhnc->bhkc', k, v)
    # Each query mixes the context rows; the output keeps v's feature axis.
    x = torch.einsum('bhnk,bhkc->bhnc', q, context)

    return x


class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Queries and keys are rotated with ``RoPE`` before attention. By default,
    heads attend with ``linear_attention``. Pass ``linear=False`` to use
    standard softmax attention (``F.scaled_dot_product_attention``) instead.

    Example
    -------
    >>> module = Attention(
    ...    embedding_dimension=256,
    ...    heads=16,
    ... )
    >>> x = torch.randn((1, 10, 256))
    >>> x = module(x)  # Shape: (1, 10, 256).
    """

    def __init__(
        self,
        *,
        embedding_dimension: int,
        heads: int,
        linear: bool = True,
    ) -> None:
        """Initialize the module.

        Parameters
        ----------
        embedding_dimension : int
            The embedding dimension. Must be divisible by ``heads``.
        heads : int
            The number of heads. The per-head dimension
            ``embedding_dimension // heads`` must be even, because RoPE
            rotates feature pairs.
        linear : bool, optional
            If True (default), use ``linear_attention``; otherwise use
            ``F.scaled_dot_product_attention``.

        Raises
        ------
        ValueError
            If ``heads`` does not divide ``embedding_dimension`` or the
            per-head dimension is not a positive even integer.
        """

        super().__init__()

        if heads < 1 or embedding_dimension % heads != 0:
            raise ValueError(
                f'heads ({heads}) must be a positive divisor of '
                f'embedding_dimension ({embedding_dimension})'
            )

        head_dimension = embedding_dimension // heads

        if head_dimension < 2 or head_dimension % 2 != 0:
            raise ValueError(
                'embedding_dimension // heads must be a positive even integer, '
                f'got {head_dimension}'
            )

        self.heads = heads
        self.use_linear_attention = linear

        # Fused query/key/value projection.
        self.linear_1 = nn.Linear(
            in_features=embedding_dimension,
            out_features=embedding_dimension * 3,
            bias=False,
        )

        # Output projection applied after the heads are re-merged.
        self.linear_2 = nn.Linear(
            in_features=embedding_dimension,
            out_features=embedding_dimension,
            bias=False,
        )

        self.rope = RoPE(
            embedding_dimension=head_dimension,
            base=10_000,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward the module.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(batch, sequence, embedding_dimension)``.

        Returns
        -------
        x : torch.Tensor
            Output of the same shape.
        """

        q, k, v = rearrange(self.linear_1(x), 'b s (n h e) -> n b h s e', n=3, h=self.heads)
        q, k = self.rope(q), self.rope(k)

        if self.use_linear_attention:
            x = linear_attention(q, k, v)
        else:
            x = F.scaled_dot_product_attention(q, k, v)

        x = self.linear_2(rearrange(x, 'b h s e -> b s (h e)'))

        return x


class ResidualBlock(nn.Module):
    """Pre-norm transformer block: attention and an MLP, each residual.

    Each sub-layer receives its own LayerNorm of the running stream, and its
    output is added back to the stream.

    Example
    -------
    >>> module = ResidualBlock(
    ...     embedding_dimension=256,
    ...     heads=16,
    ... )
    >>> x = torch.randn((1, 10, 256))
    >>> x = module(x)  # Shape: (1, 10, 256).
    """

    def __init__(self, *, embedding_dimension: int, heads: int) -> None:
        """Initialize the module.

        Parameters
        ----------
        embedding_dimension : int
            The embedding dimension.
        heads : int
            The number of heads.
        """

        super().__init__()

        self.attention = Attention(
            embedding_dimension=embedding_dimension,
            heads=heads,
        )

        # Position-wise MLP with a 3x hidden expansion.
        self.mlp = nn.Sequential(
            nn.Linear(
                in_features=embedding_dimension,
                out_features=embedding_dimension * 3,
            ),
            nn.SiLU(),
            nn.Linear(
                in_features=embedding_dimension * 3,
                out_features=embedding_dimension,
            ),
        )

        # Norm before the attention branch.
        self.layer_norm_1 = nn.LayerNorm(normalized_shape=embedding_dimension)
        # Norm before the MLP branch.
        self.layer_norm_2 = nn.LayerNorm(normalized_shape=embedding_dimension)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward the module.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor of shape ``(batch, sequence, embedding_dimension)``.

        Returns
        -------
        x : torch.Tensor
            The output tensor, same shape as the input.
        """

        x = x + self.attention(self.layer_norm_1(x))
        # The MLP branch must use its own norm, not layer_norm_1.
        x = x + self.mlp(self.layer_norm_2(x))

        return x


class UpBlock(nn.Module):
    """Double the sequence length, then refine it with two residual blocks.

    Example
    -------
    >>> module = UpBlock(
    ...     embedding_dimension=256,
    ...     heads=16,
    ... )
    >>> x = torch.randn((1, 5, 256))
    >>> x = module(x)  # Shape: (1, 10, 256).
    """

    def __init__(self, *, embedding_dimension: int, heads: int) -> None:
        """Create the upsampler and the two residual blocks.

        Parameters
        ----------
        embedding_dimension : int
            Token width, unchanged by this block.
        heads : int
            Attention heads per residual block.
        """

        super().__init__()

        self.linear_upsample = LinearUpsample(
            embedding_dimension=embedding_dimension,
        )

        # Two identically configured blocks, created (and registered) in order.
        self.residual_block_1, self.residual_block_2 = (
            ResidualBlock(embedding_dimension=embedding_dimension, heads=heads)
            for _ in range(2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` from ``(..., T, E)`` to ``(..., 2T, E)`` and refine it.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.

        Returns
        -------
        torch.Tensor
            The upsampled, refined tensor.
        """

        x = self.linear_upsample(x)
        for residual_block in (self.residual_block_1, self.residual_block_2):
            x = residual_block(x)
        return x


class DownBlock(nn.Module):
    """Halve the sequence length, then refine it with two residual blocks.

    Example
    -------
    >>> module = DownBlock(
    ...     embedding_dimension=256,
    ...     heads=16,
    ... )
    >>> x = torch.randn((1, 10, 256))
    >>> x = module(x)  # Shape: (1, 5, 256).
    """

    def __init__(self, *, embedding_dimension: int, heads: int) -> None:
        """Create the downsampler and the two residual blocks.

        Parameters
        ----------
        embedding_dimension : int
            Token width, unchanged by this block.
        heads : int
            Attention heads per residual block.
        """

        super().__init__()

        self.linear_downsample = LinearDownsample(
            embedding_dimension=embedding_dimension,
        )

        # Two identically configured blocks, created (and registered) in order.
        self.residual_block_1, self.residual_block_2 = (
            ResidualBlock(embedding_dimension=embedding_dimension, heads=heads)
            for _ in range(2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample ``x`` from ``(..., 2T, E)`` to ``(..., T, E)`` and refine it.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor; its token count must be even.

        Returns
        -------
        torch.Tensor
            The downsampled, refined tensor.
        """

        x = self.linear_downsample(x)
        for residual_block in (self.residual_block_1, self.residual_block_2):
            x = residual_block(x)
        return x


class Quantizer(nn.Module):
    """Scalar quantizer with a straight-through estimator.

    ``encode`` proceeds in three steps:

    1. Project each token to ``quantizer_dimension`` channels.
    2. Squash the channels with ``tanh`` into ``(-scale, scale)``, where
       ``scale = 2 ** quantizer_bits // 2``.
    3. Floor them to integers in ``[-scale, scale - 1]``, i.e.
       ``2 ** quantizer_bits`` levels per channel.

    The rounding is bypassed in the backward pass (straight-through
    gradient). ``decode`` maps the integer codes back to
    ``embedding_dimension`` features.

    Example
    -------
    >>> module = Quantizer(
    ...     embedding_dimension=256,
    ...     quantizer_dimension=4,
    ...     quantizer_bits=5,
    ... )
    >>> x = torch.randn((1, 10, 256))
    >>> x = module.encode(x)  # Shape: (1, 10, 4).
    >>> x = module.decode(x)  # Shape: (1, 10, 256).
    """

    def __init__(
        self,
        *,
        embedding_dimension: int,
        quantizer_dimension: int,
        quantizer_bits: int,
    ) -> None:
        """Initialize the module.

        Parameters
        ----------
        embedding_dimension : int
            The embedding dimension (width of the input to ``encode`` and of
            the output of ``decode``).
        quantizer_dimension : int
            Number of quantized channels per token.
        quantizer_bits : int
            Bits per quantized channel; each channel takes
            ``2 ** quantizer_bits`` integer values.

        Raises
        ------
        ValueError
            If ``quantizer_bits`` is less than 1.
        """

        super().__init__()

        if quantizer_bits < 1:
            raise ValueError(f'quantizer_bits must be >= 1, got {quantizer_bits}')

        self.scale = (2 ** quantizer_bits) // 2

        self.encoder = nn.Sequential(
            nn.Linear(
                in_features=embedding_dimension,
                out_features=quantizer_dimension,
            ),
            nn.Tanh(),
        )

        self.decoder = nn.Sequential(
            nn.Linear(
                in_features=quantizer_dimension,
                out_features=embedding_dimension,
            ),
            nn.LeakyReLU(),
            nn.LayerNorm(normalized_shape=embedding_dimension),
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a tensor into integer-valued codes.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(..., embedding_dimension)``.

        Returns
        -------
        x : torch.Tensor
            Float tensor of shape ``(..., quantizer_dimension)`` holding
            integers in ``[-scale, scale - 1]``.
        """

        x = self.scale * self.encoder(x)
        # tanh saturates to exactly +/-1.0 in floating point, so the scaled
        # value can equal +scale. Clamp to keep exactly 2 ** bits levels.
        codes = x.floor().clamp(min=-self.scale, max=self.scale - 1)
        # Straight-through estimator: forward uses the codes, backward
        # treats quantization as the identity.
        x = x + (codes - x).detach()

        return x

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode codes back to embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Codes of shape ``(..., quantizer_dimension)``.

        Returns
        -------
        x : torch.Tensor
            Tensor of shape ``(..., embedding_dimension)``.
        """

        x = self.decoder(x)

        return x

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode then decode.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.

        Returns
        -------
        x : torch.Tensor
            The decoded tensor.
        z : torch.Tensor
            The quantized codes.
        """

        z = self.encode(x)
        x = self.decode(z)

        return x, z


@dataclass(frozen=True)
class AutoencoderConfiguration:
    """Hyperparameters for ``Autoencoder``.

    Raises
    ------
    ValueError
        On construction, in any of these cases:

        * A size is non-positive (``layers`` may be 0).
        * ``heads`` does not divide ``embedding_dimension``.
        * The per-head dimension is odd (rotary embeddings rotate feature
          pairs).
    """

    input_dimension: int  # features per input token (e.g. image channels)
    embedding_dimension: int  # width of the transformer token stream
    quantizer_dimension: int  # quantized channels per latent token
    quantizer_bits: int  # bits per latent channel (2 ** bits levels)
    heads: int  # attention heads per residual block
    layers: int  # number of 2x down/upsampling stages

    def __post_init__(self) -> None:
        """Validate the configuration eagerly, before any module is built."""

        for name in (
            'input_dimension',
            'embedding_dimension',
            'quantizer_dimension',
            'quantizer_bits',
            'heads',
        ):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f'{name} must be >= 1, got {value}')

        if self.layers < 0:
            raise ValueError(f'layers must be >= 0, got {self.layers}')

        if self.embedding_dimension % self.heads != 0:
            raise ValueError(
                f'heads ({self.heads}) must divide '
                f'embedding_dimension ({self.embedding_dimension})'
            )

        if (self.embedding_dimension // self.heads) % 2 != 0:
            raise ValueError(
                'embedding_dimension // heads must be even, got '
                f'{self.embedding_dimension // self.heads}'
            )


class Autoencoder(nn.Module):
    """Transformer autoencoder with a quantized bottleneck.

    ``encode`` runs three stages:

    1. Embed each input token.
    2. Halve the sequence length once per layer with ``DownBlock`` modules.
    3. Quantize the result.

    ``decode`` reverses this with ``UpBlock`` modules. It then maps back to
    ``input_dimension`` features, squashed into (0, 1) by a sigmoid.

    Example
    -------
    >>> configuration = AutoencoderConfiguration(
    ...     input_dimension=3,
    ...     embedding_dimension=256,
    ...     quantizer_dimension=4,
    ...     quantizer_bits=5,
    ...     heads=16,
    ...     layers=3,
    ... )
    >>> module = Autoencoder(configuration=configuration)
    >>> x = torch.randn((1, 1024, 3))
    >>> z = module.encode(x)  # Shape: (1, 128, 4).
    >>> x = module.decode(z)  # Shape: (1, 1024, 3).
    """

    def __init__(self, *, configuration: AutoencoderConfiguration) -> None:
        """Initialize the module.

        Parameters
        ----------
        configuration : AutoencoderConfiguration
            The module configuration.
        """

        super().__init__()

        # Number of 2x resampling stages; `encode` uses it to validate inputs.
        self.layers = configuration.layers

        self.embedding = nn.Linear(
            in_features=configuration.input_dimension,
            out_features=configuration.embedding_dimension,
            bias=False,
        )

        self.unembedding = nn.Linear(
            in_features=configuration.embedding_dimension,
            out_features=configuration.input_dimension,
            bias=False,
        )

        self.encoder = nn.Sequential(*[
            DownBlock(
                embedding_dimension=configuration.embedding_dimension,
                heads=configuration.heads,
            ) for _ in range(configuration.layers)
        ])

        self.decoder = nn.Sequential(*[
            UpBlock(
                embedding_dimension=configuration.embedding_dimension,
                heads=configuration.heads,
            ) for _ in range(configuration.layers)
        ])

        self.quantizer = Quantizer(
            embedding_dimension=configuration.embedding_dimension,
            quantizer_dimension=configuration.quantizer_dimension,
            quantizer_bits=configuration.quantizer_bits,
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(batch, sequence, input_dimension)``. The
            sequence length must be divisible by ``2 ** layers``.

        Returns
        -------
        x : torch.Tensor
            Quantized latent of shape
            ``(batch, sequence // 2 ** layers, quantizer_dimension)``.

        Raises
        ------
        ValueError
            If the sequence length is not divisible by ``2 ** layers``.
        """

        factor = 2 ** self.layers
        if x.size(-2) % factor != 0:
            raise ValueError(
                f'sequence length {x.size(-2)} is not divisible by '
                f'2 ** layers = {factor}'
            )

        x = self.embedding(x)
        x = self.encoder(x)
        x = self.quantizer.encode(x)

        return x

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a tensor.

        Parameters
        ----------
        x : torch.Tensor
            Quantized latent of shape ``(batch, length, quantizer_dimension)``.

        Returns
        -------
        x : torch.Tensor
            Reconstruction of shape
            ``(batch, length * 2 ** layers, input_dimension)``, in (0, 1).
        """

        x = self.quantizer.decode(x)
        x = self.decoder(x)
        x = self.unembedding(x)
        x = torch.sigmoid(x)

        return x

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode and decode.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor (see ``encode``).

        Returns
        -------
        x : torch.Tensor
            The reconstruction (see ``decode``).
        z : torch.Tensor
            The quantized latent (see ``encode``).
        """

        z = self.encode(x)
        x = self.decode(z)

        return x, z

In [5]:
#@markdown Dataset.

from typing import Dict
from datasets import load_dataset
from torchvision import transforms

# Side length that images are resized to. The autoencoder halves the
# flattened sequence (resolution ** 2 pixel tokens) once per layer, so
# resolution ** 2 must be divisible by 2 ** layers.
resolution = 32

transform = transforms.Compose([
    transforms.ToTensor(),
    # Nearest-neighbour resampling (the mode selected by the integer 0)
    # creates no new intensity values, so pixels keep the range ToTensor
    # produced. They are later used as binary cross-entropy targets.
    transforms.Resize(
        (resolution, resolution),
        interpolation=transforms.InterpolationMode.NEAREST,
    ),
])


def preprocess(examples: Dict) -> Dict:
    """Apply ``transform`` to every image in a batch of dataset rows.

    Parameters
    ----------
    examples : Dict
        A batch mapping column names to lists of values; only the
        ``'image'`` column is read.

    Returns
    -------
    Dict
        A dict with a single ``'image'`` key holding the transformed images.
        All other columns are dropped.
    """

    converted = list(map(transform, examples['image']))
    return {'image': converted}


# Only the training split is used. `set_transform` registers `preprocess` to
# be applied when rows are accessed (on the fly, per the `datasets` API). As a
# result, the examples yielded to the trainer contain just the 'image' column.
dataset = load_dataset('mnist', split='train')
dataset.set_transform(preprocess)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading data:   0%|          | 0.00/15.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.60M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [118]:
#@markdown Trainer.

from IPython.display import display, clear_output, HTML
from torchvision.utils import save_image


@dataclass(frozen=True)
class TrainerConfiguration:
    """Settings for ``Trainer``.

    Raises
    ------
    ValueError
        On construction, if a count or cadence is less than 1 or
        ``dataloader_workers`` is negative. Otherwise such values would
        surface later as a modulo-by-zero or a loop that never stops.
    """

    steps: int  # stop after this many optimizer steps
    epochs: int  # ...or after this many passes over the dataset
    batch_size: int  # examples per dataloader batch
    batches_per_step: int  # gradient accumulation: dataloader batches per optimizer step
    batches_per_log: int  # logging cadence, in dataloader batches
    batches_per_sample: int  # sample-image cadence, in dataloader batches
    batches_per_checkpoint: int  # checkpoint cadence, in dataloader batches
    dataloader_workers: int  # DataLoader `num_workers`
    input_column: str  # dataset column holding the image tensors
    sample_path: str  # directory for sample images
    checkpoint_path: str  # directory for checkpoints

    def __post_init__(self) -> None:
        """Validate the configuration eagerly."""

        for name in (
            'steps',
            'epochs',
            'batch_size',
            'batches_per_step',
            'batches_per_log',
            'batches_per_sample',
            'batches_per_checkpoint',
        ):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f'{name} must be >= 1, got {value}')

        if self.dataloader_workers < 0:
            raise ValueError(
                f'dataloader_workers must be >= 0, got {self.dataloader_workers}'
            )


@dataclass(frozen=True)
class Trainer:
    """Training loop for ``Autoencoder`` on image datasets.

    The loop works as follows:

    * Each image batch ``(b, c, h, w)`` is flattened to a pixel sequence
      ``(b, h*w, c)``.
    * The model reconstructs it, and the reconstruction is scored with
      binary cross-entropy against the input.
    * Gradients are accumulated over ``batches_per_step`` dataloader
      batches per optimizer step.
    * Progress logs, sample images and checkpoints are written
      periodically, as set by ``configuration``.
    """

    configuration: TrainerConfiguration

    def log(self, epoch: int, batch: int, step: int, loss: float) -> None:
        """Replace the cell output with a one-line progress report.

        Parameters
        ----------
        epoch : int
            Current epoch.
        batch : int
            Batch index within the epoch.
        step : int
            Optimizer steps taken so far.
        loss : float
            Loss value to display.
        """

        percent = int(round(100 * step / self.configuration.steps))

        clear_output(wait=True)
        display(HTML(f'<code>({percent:03d}%) epoch: {epoch:06d}, batch: {batch:06d}, step: {step:06d} - loss: {loss:0.6f}</code>'))

    def _save_samples(self, images: torch.Tensor, reconstruction: torch.Tensor) -> None:
        """Write the latest input and reconstruction batches as image files."""

        save_image(
            reconstruction.detach(),
            f'{self.configuration.sample_path}/reconstruction.png',
        )

        save_image(
            images.detach(),
            f'{self.configuration.sample_path}/input.png',
        )

    def _save_checkpoint(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        step: int,
    ) -> None:
        """Overwrite the checkpoint file with the current training state."""

        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'step': step,
            },
            f'{self.configuration.checkpoint_path}/checkpoint.pt',
        )

    def train(
        self,
        model: Autoencoder,
        optimizer: torch.optim.Optimizer,
        dataset: torch.utils.data.Dataset,
        device: str,
    ) -> None:
        """Train a model.

        Training stops after ``configuration.steps`` optimizer steps or
        ``configuration.epochs`` passes over ``dataset``, whichever comes
        first. A checkpoint is written in either case.

        Parameters
        ----------
        model : Autoencoder
            The model to train; it is moved to ``device``.
        optimizer : torch.optim.Optimizer
            Optimizer over ``model``'s parameters.
        dataset : torch.utils.data.Dataset
            Dataset whose items hold image tensors under
            ``configuration.input_column``.
        device : str
            Device to train on.
        """

        import os  # local import: this notebook cell does not import `os`

        os.makedirs(self.configuration.sample_path, exist_ok=True)
        os.makedirs(self.configuration.checkpoint_path, exist_ok=True)

        model = model.to(device)
        model.train()

        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.configuration.batch_size,
            shuffle=True,
            num_workers=self.configuration.dataloader_workers,
        )

        # Discard gradients left over from an earlier (e.g. interrupted) run.
        optimizer.zero_grad()

        step = 0
        # Micro-batches accumulated since the last optimizer step. Counted
        # across epochs so every step averages exactly `batches_per_step`
        # batches, even when an epoch's length is not a multiple of it.
        accumulated = 0

        for epoch in range(self.configuration.epochs):
            for batch, examples in enumerate(dataloader):

                images = examples[self.configuration.input_column].to(device)
                height = images.size(-2)

                sequence = rearrange(images, 'b c h w -> b (h w) c')
                reconstruction, _ = model(sequence)
                reconstruction = rearrange(reconstruction, 'b (h w) c -> b c h w', h=height)

                loss = F.binary_cross_entropy(reconstruction, images)
                # Scale so the accumulated gradient is the mean over micro-batches.
                (loss / self.configuration.batches_per_step).backward()
                accumulated += 1

                if accumulated == self.configuration.batches_per_step:
                    optimizer.step()
                    optimizer.zero_grad()
                    accumulated = 0

                    step += 1

                    if step == self.configuration.steps:
                        self._save_checkpoint(model, optimizer, step)
                        return

                if (batch + 1) % self.configuration.batches_per_log == 0:
                    # Report this micro-batch's unscaled loss.
                    self.log(
                        epoch=epoch,
                        batch=batch,
                        step=step,
                        loss=loss.detach().item(),
                    )

                if (batch + 1) % self.configuration.batches_per_sample == 0:
                    self._save_samples(images, reconstruction)

                if (batch + 1) % self.configuration.batches_per_checkpoint == 0:
                    self._save_checkpoint(model, optimizer, step)

        self._save_checkpoint(model, optimizer, step)

In [134]:
trainer_configuration = TrainerConfiguration(
    # Training length: whichever limit is reached first ends training.
    steps=500,
    epochs=5,
    # Micro-batches of one example, 8 accumulated per optimizer step.
    batch_size=1,
    batches_per_step=8,
    dataloader_workers=2,
    # Reporting cadence, counted in dataloader batches.
    batches_per_log=8,
    batches_per_sample=64,
    batches_per_checkpoint=256,
    # Data column and output directories.
    input_column='image',
    sample_path='./samples',
    checkpoint_path='./checkpoints',
)

trainer = Trainer(configuration=trainer_configuration)

In [135]:
# Seed torch's global RNG so weight initialisation is reproducible across
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = Autoencoder(
    configuration=AutoencoderConfiguration(
        input_dimension=1,       # single-channel images: one feature per pixel
        embedding_dimension=32,
        quantizer_dimension=4,
        quantizer_bits=5,        # 32 levels per latent channel
        heads=16,                # per-head dimension 32 // 16 = 2
        layers=3,                # 1024 pixels -> 128 latent tokens
    ),
)

In [136]:
# bitsandbytes' 8-bit Adam is meant for CUDA. NOTE(review): confirm whether
# the installed version supports CPU. Without a GPU, fall back to standard
# PyTorch Adam with the same learning rate so the notebook still runs. The
# training cell makes the same `cuda` check.
if torch.cuda.is_available():
    import bitsandbytes as bnb

    optimizer = bnb.optim.adam.Adam8bit(model.parameters(), lr=1e-3)
else:
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [138]:
import os

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Create the output directories from the trainer's own configuration.
# `exist_ok=True` keeps this cell safe to re-run; a bare `mkdir` reports an
# error once the directories exist.
os.makedirs(trainer.configuration.sample_path, exist_ok=True)
os.makedirs(trainer.configuration.checkpoint_path, exist_ok=True)

trainer.train(
    model=model,
    optimizer=optimizer,
    dataset=dataset,
    device=device,
)