In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Sampler:
    """Forward (noising) and reverse (denoising) steps on a linear beta schedule.

    The schedule tensors are created once in ``__init__`` on the default
    device and are looked up per call with the supplied timesteps.

    Args:
        num_steps: Number of diffusion steps in the schedule.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
    """

    def __init__(self, num_steps=1000, beta_start=0.0001, beta_end=0.02):
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_schedule = self.linear_beta_schedule()
        self.alpha = 1 - self.beta_schedule
        # Running product alpha_1 * ... * alpha_t for every step t.
        self.alpha_cummulative_prod = torch.cumprod(self.alpha, dim=-1)

    def linear_beta_schedule(self):
        """Return ``num_steps`` betas evenly spaced in [beta_start, beta_end]."""
        return torch.linspace(self.beta_start, self.beta_end, self.num_steps)

    def _repeated_unsqueeze(self, target, tensor):
        """Append trailing singleton dims to ``tensor`` until it has ``target``'s rank."""
        while target.dim() > tensor.dim():
            tensor = tensor.unsqueeze(-1)
        return tensor

    def _lookup(self, schedule, timesteps, device):
        """Index ``schedule`` with ``timesteps`` and move the result to ``device``.

        The index is moved to the schedule's device first: indexing a CPU
        tensor with an index tensor that lives on another device raises.
        Non-tensor timesteps (int, list) are converted to a tensor.
        """
        if not torch.is_tensor(timesteps):
            timesteps = torch.as_tensor(timesteps)
        return schedule[timesteps.to(schedule.device)].to(device)

    def add_noise(self, image, timesteps):
        """Noise ``image`` to the given timesteps.

        Args:
            image: Tensor to corrupt; any rank is accepted, per-timestep
                coefficients are broadcast over the trailing dimensions.
            timesteps: Integer indices into the schedule (tensor, or anything
                ``torch.as_tensor`` accepts).

        Returns:
            ``(noisy_image, noise)`` where ``noise`` is the standard-normal
            sample that was mixed in.
        """
        device = image.device
        alpha_bar_t = self._lookup(self.alpha_cummulative_prod, timesteps, device)
        mean_coeff = self._repeated_unsqueeze(image, alpha_bar_t**0.5)
        var_coeff = self._repeated_unsqueeze(image, (1 - alpha_bar_t) ** 0.5)
        noise = torch.randn_like(image)
        noisy_image = mean_coeff * image + var_coeff * noise
        return noisy_image, noise

    def remove_noise(self, image, timesteps, predicted_noise):
        """Take one reverse step from ``timesteps`` (Algorithm 2, line 4).

        Args:
            image: Current noisy tensor; any rank is accepted.
            timesteps: Integer indices into the schedule (tensor, or anything
                ``torch.as_tensor`` accepts).
            predicted_noise: Noise estimate with the same shape as ``image``.

        Returns:
            The tensor after one denoising step. For entries with timestep 0
            the added variance is zero, so the result is the mean alone.
        """
        device = image.device
        if not torch.is_tensor(timesteps):
            timesteps = torch.as_tensor(timesteps)

        beta_t = self._lookup(self.beta_schedule, timesteps, device)
        alpha_t = self._lookup(self.alpha, timesteps, device)
        alpha_bar_t = self._lookup(self.alpha_cummulative_prod, timesteps, device)

        # There is no step before t == 0; the cumulative product of an empty
        # range is 1, which makes the variance below vanish on the final step.
        # torch.where (instead of an in-place masked write) guarantees the
        # shared schedule tensor is never modified.
        is_first_step = (timesteps == 0).to(device)
        previous_index = (timesteps - 1).clamp(min=0)
        alpha_bar_prev = self._lookup(
            self.alpha_cummulative_prod, previous_index, device
        )
        alpha_bar_prev = torch.where(
            is_first_step, torch.ones_like(alpha_bar_prev), alpha_bar_prev
        )

        # Element z in line 4 of Algorithm 2 (Sampling).
        noise = torch.randn_like(image)

        # beta_t_hat in formula (7).
        variance = beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
        variance = self._repeated_unsqueeze(image, variance)
        sigma_t_z = (variance**0.5) * noise

        # Coefficient of the predicted noise; the paper writes beta_t as
        # (1 - alpha_t).
        noise_coff = self._repeated_unsqueeze(
            image, beta_t / (1 - alpha_bar_t) ** 0.5
        )
        reciprocal_root_alpha_t = self._repeated_unsqueeze(image, alpha_t ** (-0.5))

        # Final formula in Algorithm 2 (Sampling).
        mean = reciprocal_root_alpha_t * (image - noise_coff * predicted_noise)
        return mean + sigma_t_z


# Build a sampler with the default schedule arguments of Sampler.__init__.
sampler = Sampler()
# The triple-quoted block below is a disabled smoke test for remove_noise.
# It is a bare string expression: nothing inside it is executed.
"""rand = torch.randn(4, 3, 64, 64)
pred_noise = torch.randn_like(rand)
randtime = torch.randint(0, 1000, (4,))
sampler.remove_noise(image=rand, timesteps=randtime, predicted_noise=pred_noise)"""

'rand = torch.randn(4, 3, 64, 64)\npred_noise = torch.randn_like(rand)\nrandtime = torch.randint(0, 1000, (4,))\nsampler.remove_noise(image=rand, timesteps=randtime, predicted_noise=pred_noise)'

In [3]:
class SelfAttention(nn.Module):
    """Multi-head self-attention over a ``(batch, seq_len, in_channels)`` input.

    Args:
        in_channels: Embedding width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        attn_p: Dropout probability on the attention weights (training only).
        proj_p: Dropout probability after the output projection.

    Raises:
        ValueError: If ``in_channels`` is not divisible by ``num_heads``.
    """

    def __init__(self, in_channels, num_heads=12, attn_p=0, proj_p=0):
        super().__init__()
        # Fail early with a clear message instead of an opaque reshape error
        # on the first forward pass.
        if in_channels % num_heads != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = in_channels // num_heads
        # 1 / sqrt(d). Kept as an attribute for callers that read it;
        # scaled_dot_product_attention applies this same factor by default.
        self.scale = self.head_dim ** (-0.5)
        self.query = nn.Linear(in_channels, in_channels)
        self.key = nn.Linear(in_channels, in_channels)
        self.value = nn.Linear(in_channels, in_channels)
        self.attn_p = attn_p
        self.proj = nn.Linear(in_channels, in_channels)
        self.proj_drop = nn.Dropout(proj_p)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape ``(B, L, C)`` to ``(B, num_heads, L, head_dim)``."""
        return projected.reshape(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention and return a tensor with the same shape as ``x``."""
        batch_size, seq_len, embed_dim = x.shape
        q = self._split_heads(self.query(x), batch_size, seq_len)
        k = self._split_heads(self.key(x), batch_size, seq_len)
        v = self._split_heads(self.value(x), batch_size, seq_len)
        # The functional API applies dropout whenever dropout_p > 0, whatever
        # the module's mode, so it has to be switched off by hand for eval.
        dropout_p = self.attn_p if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        # Merge the heads back into a single embedding axis.
        x = x.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MLP(nn.Module):
    """Feed-forward block: widen, GELU, narrow, with dropout after each linear layer.

    Args:
        in_channels: Input and output feature width.
        mlp_ratio: Hidden width is ``in_channels * mlp_ratio``.
        mlp_p: Dropout probability used for both dropout layers.
    """

    def __init__(self, in_channels, mlp_ratio=4, mlp_p=0):
        super().__init__()
        hidden_channels = in_channels * mlp_ratio
        self.fc_1 = nn.Linear(in_channels, hidden_channels)
        self.act = nn.GELU()
        self.drop_1 = nn.Dropout(mlp_p)
        self.fc_2 = nn.Linear(hidden_channels, in_channels)
        self.drop_2 = nn.Dropout(mlp_p)

    def forward(self, x):
        """Return the transformed features; the last dimension is unchanged."""
        hidden = self.drop_1(self.act(self.fc_1(x)))
        return self.drop_2(self.fc_2(hidden))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block applied to a ``(B, C, H, W)`` feature map.

    The spatial grid is flattened into ``H * W`` tokens of width ``C``, passed
    through residual self-attention and a residual MLP, then reshaped back.

    Args:
        in_channels: Channel count of the feature map (token width).
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-width multiplier of the MLP.
        proj_p: Dropout after the attention output projection.
        attn_p: Dropout on the attention weights.
        mlp_p: Dropout inside the MLP.
    """

    def __init__(
        self, in_channels, num_heads=4, mlp_ratio=2, proj_p=0, attn_p=0, mlp_p=0
    ):
        super().__init__()
        # eps is the small constant LayerNorm adds to the variance before the
        # square root, to avoid dividing by zero.
        self.norm_1 = nn.LayerNorm(in_channels, eps=1e-6)
        self.attn = SelfAttention(
            in_channels=in_channels,
            num_heads=num_heads,
            attn_p=attn_p,
            proj_p=proj_p,
        )
        self.norm_2 = nn.LayerNorm(in_channels, eps=1e-6)
        self.mlp = MLP(in_channels=in_channels, mlp_ratio=mlp_ratio, mlp_p=mlp_p)

    def forward(self, x):
        """Return a tensor with the same ``(B, C, H, W)`` shape as ``x``."""
        batch, channels, height, width = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = x.reshape(batch, channels, height * width).permute(0, 2, 1)
        tokens = tokens + self.attn(self.norm_1(tokens))
        tokens = tokens + self.mlp(self.norm_2(tokens))
        # (B, H*W, C) -> (B, C, H, W).
        return tokens.permute(0, 2, 1).reshape(batch, channels, height, width)


# Smoke test: push a random (4, 64, 14, 14) feature map through one block.
rand = torch.randn(4, 64, 14, 14)
t = TransformerBlock(in_channels=64, num_heads=4)
# Disabled shape check, kept as an inert string expression.
"""t(rand).shape"""
# TransformerBlock.forward reshapes back to (b, c, h, w), so `output` has the
# same shape as `rand`.
output = t(rand)

In [4]:
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal timestep features followed by a small trainable MLP.

    Args:
        time_embed_dim: Width of the raw sin/cos features; must be even
            because half the features are sines and half are cosines.
        scaled_time_embed_dim: Output width of the MLP.

    Raises:
        ValueError: If ``time_embed_dim`` is odd.
    """

    def __init__(self, time_embed_dim, scaled_time_embed_dim):
        super().__init__()
        # An odd width would give sin/cos features one element wider than the
        # first Linear expects, and only fail later inside forward().
        if time_embed_dim % 2 != 0:
            raise ValueError(
                f"time_embed_dim must be even, got {time_embed_dim}"
            )

        # Fixed frequencies. Registered as a buffer rather than a frozen
        # Parameter so that it still follows .to()/.state_dict() under the
        # same key, but is not returned by parameters() and cannot be made
        # trainable by a blanket requires_grad_(True).
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, time_embed_dim, 2).float() / time_embed_dim)
        )
        self.register_buffer("inv_freq", inv_freq)

        # Trainable projection of the fixed features.
        self.time_mlp = nn.Sequential(
            nn.Linear(time_embed_dim, scaled_time_embed_dim),
            nn.SiLU(),
            nn.Linear(scaled_time_embed_dim, scaled_time_embed_dim),
            nn.SiLU(),
        )

    def forward(self, timesteps: torch.Tensor):
        """Embed a 1-D tensor of timesteps.

        Returns:
            Tensor of shape ``(len(timesteps), scaled_time_embed_dim)``.
        """
        # Outer product: one row per timestep, one column per frequency.
        timestep_freqs = timesteps.unsqueeze(1) * self.inv_freq.unsqueeze(0)
        embeddings = torch.cat(
            [torch.sin(timestep_freqs), torch.cos(timestep_freqs)], dim=-1
        )
        embeddings = self.time_mlp(embeddings)
        return embeddings


# Smoke test: embed three integer timesteps.
s = SinusoidalTimeEmbedding(time_embed_dim=128, scaled_time_embed_dim=256)
timesteps = torch.tensor([1, 2, 3])
# NOTE: rebinds the module-level name `output` used by the previous cell.
output = s(timesteps)

In [5]:
class ResidualBlock(nn.Module):
    """Two GroupNorm -> SiLU -> 3x3 conv stages with a time bias and a skip path.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        groupnorm_num_groups: Group count for both GroupNorm layers.
        time_embed_dim: Width of the incoming time embedding.
    """

    def __init__(self, in_channels, out_channels, groupnorm_num_groups, time_embed_dim):
        super().__init__()
        conv_kwargs = {"kernel_size": 3, "padding": "same"}

        # Projects the time embedding to one bias value per output channel.
        self.time_expand = nn.Linear(time_embed_dim, out_channels)
        self.groupnorm_1 = nn.GroupNorm(groupnorm_num_groups, in_channels)
        self.conv_1 = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, **conv_kwargs
        )
        self.groupnorm_2 = nn.GroupNorm(groupnorm_num_groups, out_channels)
        self.conv_2 = nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels, **conv_kwargs
        )

        # The skip path needs a 1x1 conv only when the channel count changes.
        if in_channels != out_channels:
            self.resize_channels = nn.Conv2d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1
            )
        else:
            self.resize_channels = nn.Identity()

    def forward(self, x, time_embeddings):
        """Return the block output; spatial size matches ``x``."""
        shortcut = self.resize_channels(x)
        # Two trailing singleton axes let the per-channel bias broadcast over
        # the spatial grid.
        time_bias = self.time_expand(time_embeddings)[..., None, None]

        hidden = self.conv_1(F.silu(self.groupnorm_1(x)))
        hidden = hidden + time_bias
        hidden = self.conv_2(F.silu(self.groupnorm_2(hidden)))
        return hidden + shortcut


# Smoke test: run a random batch and random time embeddings through one block.
iamges = torch.randn(4, 64, 128, 128)
time_embeddings = torch.randn(4, 256)
# in_channels != out_channels here, so the block's skip path is a 1x1 conv.
rb = ResidualBlock(
    in_channels=64, out_channels=512, groupnorm_num_groups=16, time_embed_dim=256
)
# NOTE: rebinds the module-level name `output` used by earlier cells.
output = rb(iamges, time_embeddings)

In [6]:
class UpSampleBlock(nn.Module):
    """Double the spatial resolution, then apply a 3x3 convolution.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2)
        # "same" padding keeps the upsampled spatial size through the conv.
        smooth = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding="same",
        )
        self.upsample = nn.Sequential(enlarge, smooth)

    def forward(self, x):
        """Return the upsampled and convolved feature map."""
        return self.upsample(x)