<a href="https://colab.research.google.com/github/oelin/sprite-diffusion/blob/main/Sprite_Diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

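# Sprite Diffusion

Trains a DDPM noise-prediction U-Net on the `mwkldeveloper/sprites_1788_16` sprite dataset from the Hugging Face Hub. During training it periodically writes sample images (`./sample-*.png`) and checkpoints (`./checkpoint.pt`), and it saves the final weights to `./model.pt`.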

In [None]:
#@markdown Install dependencies.

!pip -q install einops datasets timm labml labml_helpers

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.8/110.8 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m388.9/388.9 kB[0m [31m14.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━

In [None]:
#@markdown Implement the noise schedule and sampler.

from dataclasses import dataclass
import math

import torch
import torch.nn as nn


def explicit_broadcast(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Append trailing singleton axes to ``x`` so it broadcasts against ``y``.

    Args:
        x: Tensor to reshape, e.g. per-example coefficients of shape ``(B,)``.
        y: Reference tensor whose rank ``x`` should be padded up to.

    Returns:
        A view of ``x`` with shape ``x.shape + (1,) * (y.ndim - x.ndim)``.
        If ``x`` already has at least as many dims as ``y``, the view keeps
        ``x``'s own shape.
    """
    # A non-positive repeat count yields an empty tuple, i.e. no padding.
    padding = (1,) * (y.dim() - x.dim())
    return x.view(tuple(x.shape) + padding)


class DDPMSchedule:
    """Linear beta schedule for a DDPM together with its closed-form forward process.

    ``beta`` holds ``steps`` values spaced linearly from ``start`` to ``end``.
    ``alpha`` is ``1 - beta``, and ``alpha_bar`` is the running product of
    ``alpha``. All three are 1-D tensors of length ``steps``, created on
    torch's default device.
    """

    def __init__(self, start: float, end: float, steps: int) -> None:
        self.start, self.end, self.steps = start, end, steps

        betas = torch.linspace(start, end, steps)
        alphas = 1 - betas
        self.beta = betas
        self.alpha = alphas
        self.alpha_bar = torch.cumprod(alphas, dim=0)

    def forward(
        self,
        x: torch.Tensor,
        noise: torch.Tensor,
        step: torch.Tensor,
    ) -> torch.Tensor:
        """Sample from the forward process ``q(x_t | x_0)``.

        Computes ``sqrt(alpha_bar[step]) * x + sqrt(1 - alpha_bar[step]) * noise``.
        The per-step coefficients are padded with singleton axes so that they
        broadcast over the remaining dimensions of ``x``.

        Args:
            x: Clean inputs ``x_0``.
            noise: Gaussian noise with the same shape as ``x``.
            step: Integer timesteps used to index ``alpha_bar``.

        Returns:
            The noised inputs ``x_t``.
        """
        bar = self.alpha_bar.to(x.device)[step]

        signal_coef = explicit_broadcast(torch.sqrt(bar), x)
        noise_coef = explicit_broadcast(torch.sqrt(1 - bar), x)

        return signal_coef*x + noise_coef*noise


@torch.no_grad()
def sample_ddpm(
    model: nn.Module,
    schedule: "DDPMSchedule",
    x: torch.Tensor,
) -> torch.Tensor:
    """Run ancestral DDPM sampling starting from ``x`` (typically pure noise).

    For ``t = T-1, ..., 0`` this applies the DDPM reverse update

        x <- 1/sqrt(alpha_t) * (x - beta_t / sqrt(1 - alpha_bar_t) * eps(x, t)) + sqrt(beta_t) * z

    where ``z ~ N(0, I)`` for ``t > 0`` and ``z = 0`` at the final step.

    The model is switched to eval mode while sampling, so dropout is disabled.
    Its previous train/eval mode is restored afterwards, even on error.

    Args:
        model: Noise-prediction network, called as ``model(x, t)``, where ``t``
            is a ``(batch,)`` long tensor of step indices on ``x``'s device.
        schedule: Provides ``steps``, ``alpha`` and ``alpha_bar``.
        x: Starting sample ``x_T``. Its shape is preserved.

    Returns:
        The final sample ``x_0``.
    """
    was_training = model.training
    model.eval()
    try:
        for step in reversed(range(schedule.steps)):

            # `randn_like` already allocates on x's device.
            z = torch.randn_like(x) if (step > 0) else 0.

            alpha = schedule.alpha[step].item()
            alpha_bar = schedule.alpha_bar[step].item()
            beta = 1 - alpha
            t = torch.full((x.size(0),), step, dtype=torch.long, device=x.device)
            noise = model(x, t)

            x = 1/math.sqrt(alpha) * (x - beta/math.sqrt(1-alpha_bar)*noise) + math.sqrt(beta)*z
    finally:
        model.train(was_training)

    return x

In [None]:
#@markdown Implement the trainer.

from dataclasses import dataclass
import gc

import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm

from torchvision.utils import save_image


def save_samples(x, prefix: str = './sample', size: int = 256):
    """Write each image in a batch to ``{prefix}-{i:03d}.png``.

    Each image is min-max normalised independently to [0, 1], so the model's
    output range does not matter. It is then upscaled to ``size x size`` with
    nearest-neighbour resampling, which keeps pixel-art edges crisp.

    Args:
        x: Batch of images. Each element is passed to ``ToPILImage``, so it
            is expected to be ``(channels, height, width)``; it may live on
            any device.
        prefix: Output path prefix. The default reproduces the original
            ``./sample-000.png``, ``./sample-001.png``, ... filenames.
        size: Side length, in pixels, of the saved images.
    """
    # Imported locally so this function does not depend on a later notebook
    # cell having put `transforms` into the global namespace.
    from torchvision import transforms

    to_pil = transforms.ToPILImage()
    for i, xi in enumerate(x):
        xi = xi.detach().float().cpu()
        lo, hi = xi.min(), xi.max()
        # A constant image would otherwise divide 0 by 0 and yield NaNs.
        span = (hi - lo).clamp_min(torch.finfo(xi.dtype).eps)
        xi = (xi - lo) / span  # Normalize
        # resample=0 is PIL's NEAREST filter.
        to_pil(xi).resize((size, size), 0).save(f'{prefix}-{i:03d}.png')

@dataclass
class Trainer:
    """Epsilon-prediction trainer for a DDPM model.

    Attributes:
        effective_batch_size: Examples per optimizer step. Gradients are
            accumulated over ``effective_batch_size // dataloader.batch_size``
            micro-batches (at least one).
        sample_steps: Write sample images (via ``save_samples``) every this
            many optimizer steps. ``0`` disables sampling.
        checkpoint_steps: Overwrite ``./checkpoint.pt`` every this many
            optimizer steps. ``0`` disables checkpointing.
    """

    effective_batch_size: int
    sample_steps: int
    checkpoint_steps: int

    def train(
        self,
        model: nn.Module,
        schedule: DDPMSchedule,
        optimizer: torch.optim.Optimizer,
        dataloader: torch.utils.data.DataLoader,
        x_column: str,
        steps: int,
        device: str,
    ) -> None:
        """Train ``model`` to predict the noise injected by ``schedule.forward``.

        For each micro-batch, a uniform random timestep is drawn per example,
        the clean batch is noised, and the injected noise is regressed with an
        MSE loss. Only batches of exactly ``dataloader.batch_size`` examples
        are used.

        Training stops as soon as ``steps`` optimizer steps have been taken,
        mid-epoch if necessary, and the weights are then written to
        ``./model.pt``. The progress bar is closed and the model is put into
        eval mode even if training is interrupted.

        Args:
            model: Network called as ``model(x_t, t)`` that returns predicted
                noise with the same shape as ``x_t``.
            schedule: Noise schedule used for the forward process and sampling.
            optimizer: Optimizer over ``model``'s parameters.
            dataloader: Yields dict-like batches that contain ``x_column``.
            x_column: Key of the image tensor in each batch.
            steps: Number of optimizer steps to take.
            device: Device string, e.g. ``'cuda'`` or ``'cpu'``.

        Raises:
            ValueError: If a full pass over ``dataloader`` yields no batch of
                exactly ``dataloader.batch_size`` examples. Without this check,
                training would loop forever without making progress.
        """
        torch.cuda.empty_cache()
        gc.collect()

        model.to(device)
        model.train()
        # Loss scaling only has an effect on CUDA. Disabling it elsewhere makes
        # scale()/step() plain pass-throughs and avoids the warning GradScaler
        # emits when CUDA is unavailable.
        scaler = torch.cuda.amp.GradScaler(enabled=str(device).startswith('cuda'))

        batch_size = dataloader.batch_size
        batches_per_step = max(1, self.effective_batch_size // batch_size)

        # Start from clean gradients.
        optimizer.zero_grad()
        global_step = 0
        # Counts only the micro-batches actually used, across epochs, so every
        # optimizer step accumulates exactly `batches_per_step` of them.
        micro_batches = 0

        bar = tqdm(total=steps)
        bar.set_description('loss: (pending)')

        try:
            while global_step < steps:
                used_batches = 0
                for examples in dataloader:

                    # The timestep tensor below is sized to `batch_size`, so skip short batches.
                    if len(examples[x_column]) != batch_size:
                        continue
                    used_batches += 1

                    x = examples[x_column].to(device)
                    noise = torch.randn_like(x)
                    step = torch.randint(low=0, high=schedule.steps, size=(batch_size,), device=device)
                    x_step = schedule.forward(x, noise, step)

                    # Divide so the accumulated gradient is the mean over the effective batch.
                    loss = F.mse_loss(model(x_step, step), noise) / batches_per_step
                    scaler.scale(loss).backward()
                    micro_batches += 1

                    if micro_batches % batches_per_step != 0:
                        continue

                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                    global_step += 1
                    bar.set_description(f'loss: {loss.detach().item() * batches_per_step:0.6f}')
                    bar.update(1)

                    # Periodic events fire once per optimizer step, right after it is taken.
                    if self.sample_steps and global_step % self.sample_steps == 0:
                        self._sample_and_save(model, schedule, x.shape, device)

                    if self.checkpoint_steps and global_step % self.checkpoint_steps == 0:
                        torch.save(model.state_dict(), './checkpoint.pt')

                    if global_step >= steps:
                        break

                if used_batches == 0:
                    raise ValueError(
                        f'dataloader yielded no batches with exactly {batch_size} examples; '
                        'training cannot make progress'
                    )

            # Save model.
            torch.save(model.state_dict(), './model.pt')
        finally:
            # Clean up, even if training was interrupted.
            bar.close()
            model.eval()
            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _sample_and_save(model, schedule, x_shape, device, n_samples=4):
        """Draw ``n_samples`` images from pure noise and write them with ``save_samples``."""
        was_training = model.training
        # Sample with dropout disabled, then restore the caller's mode.
        model.eval()
        try:
            x_T = torch.randn((n_samples, *x_shape[1:]), device=device)
            x_0 = sample_ddpm(model, schedule, x_T)
            save_samples(x_0)
            del x_T, x_0
        finally:
            model.train(was_training)

        torch.cuda.empty_cache()
        gc.collect()


In [None]:
#@markdown Implement the U-Net model (adapted from the labml.ai annotated DDPM U-Net).

import math
from typing import Optional, Tuple, Union, List

import torch
from torch import nn

from labml_helpers.module import Module


class Swish(Module):
    r"""
    ### Swish activation function

    $$x \cdot \sigma(x)$$

    Applied element-wise; the output has the same shape as the input.
    """

    def forward(self, x):
        gate = torch.sigmoid(x)
        return gate * x


class TimeEmbedding(nn.Module):
    """
    ### Embeddings for $t$

    Sinusoidal features of the time step, followed by a two-layer MLP with a
    Swish activation in between.
    """

    def __init__(self, n_channels: int):
        """
        * `n_channels` is the number of dimensions in the embedding. The sinusoidal
          features have `2 * (n_channels // 8)` dimensions, which must equal
          `n_channels // 4` (the input size of `lin1`).
        """
        super().__init__()
        self.n_channels = n_channels
        # Sinusoidal features -> hidden
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        self.act = Swish()
        # Hidden -> embedding
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        """
        * `t` holds the time steps; it is assumed to be 1-D `[batch_size]`
        * returns the embedding, `[batch_size, n_channels]` for 1-D `t`
        """
        half_dim = self.n_channels // 8
        # Transformer-style frequencies 10000^{-i / (half_dim - 1)} for i = 0 .. half_dim - 1,
        # i.e. a geometric ladder from 1 down to 1 / 10000.
        log_step = math.log(10_000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -log_step)

        angles = t[:, None] * freqs[None, :]
        sinusoid = torch.cat((angles.sin(), angles.cos()), dim=1)

        hidden = self.act(self.lin1(sinusoid))
        return self.lin2(hidden)


class ResidualBlock(Module):
    """
    ### Residual block

    Two 3x3 convolutions, each preceded by group normalization and Swish. The
    projected time-step embedding is added between them, and a shortcut from the
    input is added at the end. The shortcut is a 1x1 convolution when the channel
    count changes, and the identity otherwise. Spatial size is preserved.
    Each resolution is processed with two residual blocks.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int = 32, dropout: float = 0.1):
        """
        * `in_channels` is the number of input channels
        * `out_channels` is the number of output channels
        * `time_channels` is the number of channels in the time step ($t$) embeddings
        * `n_groups` is the number of groups for group normalization; it must divide
          both `in_channels` and `out_channels`
        * `dropout` is the dropout rate applied before the second convolution
        """
        super().__init__()
        # Creation order fixes the parameter-initialisation RNG sequence and the
        # `state_dict` layout, so keep it stable.
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # Project the shortcut only when the channel count changes.
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
            if in_channels != out_channels
            else nn.Identity()
        )

        # Maps the time embedding to one bias per output channel.
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        * returns `[batch_size, out_channels, height, width]`
        """
        out = self.conv1(self.act1(self.norm1(x)))

        # Per-channel time bias, broadcast over height and width.
        time_bias = self.time_emb(self.time_act(t))
        out = out + time_bias[:, :, None, None]

        out = self.norm2(out)
        out = self.act2(out)
        out = self.dropout(out)
        out = self.conv2(out)

        return out + self.shortcut(x)


class AttentionBlock(Module):
    """
    ### Attention block

    This is similar to [transformer multi-head attention](../../transformers/mha.html).

    Self-attention across all `height * width` positions of a feature map, followed
    by an output projection and a residual connection. The input and output both have
    shape `[batch_size, n_channels, height, width]`.
    """

    def __init__(self, n_channels: int, n_heads: int = 1, d_k: Optional[int] = None, n_groups: int = 32):
        """
        * `n_channels` is the number of channels in the input
        * `n_heads` is the number of heads in multi-head attention
        * `d_k` is the number of dimensions in each head (defaults to `n_channels`)
        * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
        """
        super().__init__()

        # Default `d_k`
        if d_k is None:
            d_k = n_channels
        # Normalization layer.
        # NOTE(review): `self.norm` is created (so its parameters appear in `state_dict`),
        # but `forward` never applies it; the raw `x` is projected instead. Applying it
        # now would change the outputs of already-trained checkpoints, so any fix must
        # be a deliberate, versioned change.
        self.norm = nn.GroupNorm(n_groups, n_channels)
        # Projections for query, key and values (all heads, concatenated)
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        # Linear layer for the final transformation back to `n_channels`
        self.output = nn.Linear(n_heads * d_k, n_channels)
        # Scale for dot-product attention, $\frac{1}{\sqrt{d_k}}$
        self.scale = d_k ** -0.5
        #
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        * returns a tensor with the same shape as `x`
        """
        # `t` is not used. It is accepted only so that this block can be called as
        # `m(x, t)`, like `ResidualBlock`.
        _ = t
        # Get shape
        batch_size, n_channels, height, width = x.shape
        # Change `x` to shape `[batch_size, seq, n_channels]`, where `seq = height * width`
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
        # Get query, key, and values (concatenated) and shape it to `[batch_size, seq, n_heads, 3 * d_k]`
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        # Split query, key, and values. Each of them will have shape `[batch_size, seq, n_heads, d_k]`
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$; `i` indexes queries, `j` keys
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        # Softmax along the key dimension (`j`, dim 2) $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = attn.softmax(dim=2)
        # Multiply by values
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        # Reshape to `[batch_size, seq, n_heads * d_k]`
        # NOTE(review): `.view` requires the einsum result to be contiguous; `.reshape` would be safer.
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
        # Transform to `[batch_size, seq, n_channels]`
        res = self.output(res)

        # Add skip connection (to the un-normalised input)
        res += x

        # Change to shape `[batch_size, in_channels, height, width]`
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)

        #
        return res


class DownBlock(Module):
    """
    ### Down block

    One time-conditioned `ResidualBlock`, optionally followed by an `AttentionBlock`.
    Used in the first (encoder) half of the U-Net at each resolution. When `has_attn`
    is false, the attention stage is an `nn.Identity`.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        return self.attn(self.res(x, t))


class UpBlock(Module):
    """
    ### Up block

    One time-conditioned `ResidualBlock`, optionally followed by an `AttentionBlock`.
    Used in the second (decoder) half of the U-Net at each resolution. When `has_attn`
    is false, the attention stage is an `nn.Identity`.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        # The decoder concatenates the matching encoder output onto its input, so the
        # residual block sees `in_channels + out_channels` channels.
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        return self.attn(self.res(x, t))


class MiddleBlock(Module):
    """
    ### Middle block

    `ResidualBlock` -> `AttentionBlock` -> `ResidualBlock`, all at `n_channels`.
    Applied at the lowest resolution of the U-Net (the bottleneck).
    """

    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        return self.res2(self.attn(self.res1(x, t)), t)


class Upsample(nn.Module):
    r"""
    ### Scale up the feature map by $2 \times$

    A stride-2 transposed convolution (kernel 4, padding 1) that doubles height
    and width while keeping the channel count.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            n_channels, n_channels, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is accepted only so that every U-Net stage can be called as `m(x, t)`;
        # it does not influence the result.
        del t
        return self.conv(x)


class Downsample(nn.Module):
    r"""
    ### Scale down the feature map by $\frac{1}{2} \times$

    A stride-2 3x3 convolution with padding 1. It keeps the channel count and maps
    height/width `H` to `ceil(H / 2)`.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            n_channels, n_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is accepted only so that every U-Net stage can be called as `m(x, t)`;
        # it does not influence the result.
        del t
        return self.conv(x)


class UNet(Module):
    """
    ## U-Net

    Noise-prediction network: maps a noisy image `x` and integer time steps `t` to a
    tensor with `image_channels` channels and the same spatial size as `x`. Every
    intermediate encoder feature map is pushed onto a stack, which the decoder pops
    for its skip connections.
    """

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        """
        * `image_channels` is the number of channels in the image. $3$ for RGB.
        * `n_channels` is number of channels in the initial feature map that we transform the image into
        * `ch_mults` is the list of channel multipliers at each resolution. The multipliers compound:
          resolution `i` has `n_channels * ch_mults[0] * ... * ch_mults[i]` channels
          (64, 64, 128, 256, 1024 across the stem and four resolutions for the defaults)
        * `is_attn` is a list of booleans that indicate whether to use attention at each resolution.
          It is indexed in parallel with `ch_mults`, so it needs at least `len(ch_mults)` entries.
        * `n_blocks` is the number of `DownBlock`s at each resolution. The decoder uses
          `n_blocks + 1` `UpBlock`s per resolution; the extra one reduces the channel count.
        """
        super().__init__()

        # Number of resolutions
        n_resolutions = len(ch_mults)

        # Project image into feature map
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

        # Time embedding layer. Time embedding has `n_channels * 4` channels
        self.time_emb = TimeEmbedding(n_channels * 4)

        # #### First half of U-Net - decreasing resolution
        down = []
        # Number of channels
        out_channels = in_channels = n_channels
        # For each resolution
        for i in range(n_resolutions):
            # Number of output channels at this resolution (multiplies the previous resolution's count)
            out_channels = in_channels * ch_mults[i]
            # Add `n_blocks`
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
            # Down sample at all resolutions except the last
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))

        # Combine the set of modules
        self.down = nn.ModuleList(down)

        # Middle block
        self.middle = MiddleBlock(out_channels, n_channels * 4, )

        # #### Second half of U-Net - increasing resolution
        up = []
        # Number of channels
        in_channels = out_channels
        # For each resolution
        for i in reversed(range(n_resolutions)):
            # `n_blocks` at the same resolution
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            # Final block to reduce the number of channels (undoes this resolution's multiplier)
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            # Up sample at all resolutions except last
            if i > 0:
                up.append(Upsample(in_channels))

        # Combine the set of modules
        self.up = nn.ModuleList(up)

        # Final normalization and convolution layer. The decoder output is back at
        # `n_channels` channels here (`in_channels == n_channels`).
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size]`

        Height and width need to be divisible by `2 ** (len(ch_mults) - 1)`.
        Otherwise `Downsample` rounds odd sizes up, `Upsample` doubles them, and the
        skip concatenation `torch.cat` sees mismatched spatial sizes.
        """

        # Get time-step embeddings
        t = self.time_emb(t)

        # Get image projection
        x = self.image_proj(x)

        # `h` will store outputs at each resolution for skip connection
        h = [x]
        # First half of U-Net
        for m in self.down:
            x = m(x, t)
            h.append(x)

        # Middle (bottom)
        x = self.middle(x, t)

        # Second half of U-Net. Every non-`Upsample` module pops exactly one skip
        # tensor; the constructor builds one `UpBlock` per pushed encoder output.
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                # Get the skip connection from first half of U-Net and concatenate
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                #
                x = m(x, t)

        # Final normalization and convolution
        return self.final(self.act(self.norm(x)))

In [None]:
#@markdown Load the dataset.

from datasets import load_dataset
import torch
from torchvision import transforms


# Training split of the sprite dataset from the Hugging Face Hub.
dataset = load_dataset('mwkldeveloper/sprites_1788_16', split='train')

# PIL image -> float tensor of shape (C, H, W), scaled to [0, 1] for 8-bit images.
# NOTE(review): DDPMs are usually trained on data scaled to [-1, 1]; confirm [0, 1] is intended.
transform = transforms.Compose([transforms.ToTensor()])


def process(examples):
    """Batched on-the-fly transform: convert the `images` column into tensors under `image`."""
    converted = list(map(transform, examples['images']))
    return {'image': converted}


dataset.set_transform(process)

In [None]:
#@markdown Train the model.

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's RNGs (CPU and CUDA) so that weight init, data shuffling and the
# timestep/noise draws repeat across runs (up to nondeterministic CUDA kernels).
SEED = 0
torch.manual_seed(SEED)

# Linear beta schedule with the standard DDPM settings: T = 1000, beta from 1e-4 to 0.02.
schedule = DDPMSchedule(start=1e-4, end=0.02, steps=1000)

model = UNet(
    image_channels=3,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=16, shuffle=True, num_workers=4)

trainer = Trainer(
    effective_batch_size=16,  # equals the loader batch size, so there is no gradient accumulation
    sample_steps=512,         # sampling cadence, in optimizer steps
    checkpoint_steps=1024,    # checkpoint cadence, in optimizer steps
)

trainer.train(
    model=model,
    schedule=schedule,
    optimizer=optimizer,
    dataloader=dataloader,
    x_column='image',
    steps=50_000,
    device=device,
)



  0%|          | 0/50000 [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


KeyboardInterrupt: 

In [None]:
#@markdown Save the model.

# Colab's Drive helper must be imported before `drive.mount`; no earlier cell imports it.
from google.colab import drive

drive.mount('/content/drive')
torch.save(model.state_dict(), '/content/drive/MyDrive/sprite-diffusion-40k-steps.pt')