In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import tqdm
from torchvision.utils import save_image, make_grid
import math
import einops
from typing import List, Tuple

In [2]:
class MLP(nn.Module):
    """Small projection head: ReLU applied to the input, then one linear layer.

    Args:
        input_dim: size of the last dimension of the incoming tensor.
        output_dim: size of the last dimension of the result.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # The attribute name and the layer positions define the state_dict keys
        # ("ln.1.weight", "ln.1.bias"), so both are kept stable.
        layers = [nn.ReLU(), nn.Linear(input_dim, output_dim)]
        self.ln = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``Linear(ReLU(x))``."""
        projected = self.ln(x)
        return projected
    
class PositionalEmbedding(nn.Module):
    """Fixed (non-trainable) sinusoidal embedding table indexed by integer position.

    Row ``p`` of the table holds ``sin(p * w_k)`` in the even columns and
    ``cos(p * w_k)`` in the odd columns, with ``w_k = 10000 ** (-2k / output_dim)``.

    Args:
        T: number of rows in the table; valid indices are ``0 .. T - 1``.
        output_dim: embedding width. Both even and odd widths are supported.
    """

    def __init__(self, T: int, output_dim: int) -> None:
        super().__init__()
        self.output_dim = output_dim
        position = torch.arange(T).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, output_dim, 2) * (-math.log(10000.0) / output_dim))
        angles = position * div_term  # (T, ceil(output_dim / 2))
        pe = torch.zeros(T, output_dim)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd width there is one cosine column fewer than sine columns,
        # so trim the angles to the number of odd-indexed slots.
        pe[:, 1::2] = torch.cos(angles[:, : output_dim // 2])
        # Buffer: moves with the module across devices but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor):
        """Look up one embedding per batch row.

        Args:
            x: integer index tensor holding exactly one index per batch element
               (the result is reshaped to ``(x.shape[0], output_dim)``).

        Returns:
            Tensor of shape ``(x.shape[0], output_dim)``.
        """
        return self.pe[x].reshape(x.shape[0], self.output_dim)
    
# print(PositionalEmbedding(2,4)(torch.tensor([[1], [0], [0]])))
# torch.tensor([[1,2,3], [4,5,6], [7,8,9]])[ torch.tensor([[1], [0], [1]]) ].reshape(3, 3)

In [3]:
class MultiheadAttention(nn.Module):
    """Pre-norm multi-head self-attention over the spatial positions of a feature map,
    followed by a position-wise MLP. Both sub-layers are wrapped in residual connections.

    Args:
        n_heads: number of attention heads; must divide ``emb_dim``.
        emb_dim: total width of the query/key/value projections (``n_heads * head_dim``).
        input_dim: number of channels of the incoming ``(batch, channels, h, w)`` map.
    """

    def __init__(self, n_heads: int, emb_dim: int, input_dim: int) -> None:
        super().__init__()
        assert emb_dim % n_heads == 0
        head_dim = emb_dim // n_heads
        # NOTE(review): every projection is initialised from uniform [0, 1), i.e. all
        # weights start positive. Left unchanged here; consider a zero-centred init.
        self.K_W = nn.Parameter(torch.rand(n_heads, input_dim, head_dim))
        self.K_b = nn.Parameter(torch.rand(n_heads, head_dim))
        self.Q_W = nn.Parameter(torch.rand(n_heads, input_dim, head_dim))
        self.Q_b = nn.Parameter(torch.rand(n_heads, head_dim))
        self.V_W = nn.Parameter(torch.rand(n_heads, input_dim, head_dim))
        self.V_b = nn.Parameter(torch.rand(n_heads, head_dim))
        self.O_W = nn.Parameter(torch.rand(n_heads, head_dim, input_dim))
        self.O_b = nn.Parameter(torch.rand(input_dim))
        self.norm = nn.LayerNorm([input_dim])
        self.mlp = nn.Sequential(
            nn.LayerNorm([input_dim]),
            nn.Linear(input_dim, input_dim),
            nn.GELU(),
            nn.Linear(input_dim, input_dim)
        )

    def forward(self, x, t, condition, mask):
        """Apply self-attention across all ``h * w`` positions of ``x``.

        Args:
            x: feature map of shape ``(batch, input_dim, h, w)``.
            t, condition, mask: accepted so the block can sit in a
                ``SequenceWithTimeEmbedding``; they are not used here.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, channels, h, w = x.shape
        # (b, c, h, w) -> (b, h*w, c): one token per spatial position.
        tokens = x.flatten(2).transpose(1, 2)
        normed = self.norm(tokens)

        # einsum letters: b=batch, s=query position, t=key position,
        #                 c=channel, n=head, d=per-head dim.
        k = torch.einsum("bsc,ncd->bsnd", normed, self.K_W) + self.K_b
        q = torch.einsum("bsc,ncd->bsnd", normed, self.Q_W) + self.Q_b
        v = torch.einsum("bsc,ncd->bsnd", normed, self.V_W) + self.V_b

        # Scaled dot-product attention: scale by sqrt of the per-head key dimension.
        scores = torch.einsum("bsnd,btnd->bnst", q, k) / (q.shape[-1] ** 0.5)
        weights = scores.softmax(-1)

        # Query positions (s) and key positions (t) carry distinct letters so that
        # each output position is a weighted sum over *all* key positions.
        mixed = torch.einsum("bnst,btnd->bsnd", weights, v)
        out = torch.einsum("bsnd,ndc->bsc", mixed, self.O_W) + self.O_b

        out = out + tokens            # residual around attention
        out = self.mlp(out) + out     # residual around the position-wise MLP
        # (b, h*w, c) -> (b, c, h, w)
        return out.transpose(1, 2).reshape(batch, channels, h, w)

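# Smoke test (forward takes x, t, condition, mask; the last three are unused by this block):
# MultiheadAttention(16, 64, 128)(torch.rand((6, 128, 7, 7)), None, None, None).shape  # -> torch.Size([6, 128, 7, 7])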

In [36]:
class ResnetBlock(nn.Module):
    """Two (GroupNorm -> ReLU -> 3x3 conv) stages with additive time and condition biases.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        time_emb_dim: width of the time and condition embeddings fed to ``forward``.
        is_residual: when True, a (1x1-projected if needed) copy of the input is added
            to the output of the conditioned path.
        is_debug: stored on the instance; not read inside this class.
    """

    @staticmethod
    def _norm_act_conv(n_in, n_out):
        # GroupNorm(8 groups) -> ReLU -> 3x3 "same" convolution.
        return nn.Sequential(
            nn.GroupNorm(8, n_in),
            nn.ReLU(),
            nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=3, padding=1),
        )

    def __init__(self, in_channels, out_channels, time_emb_dim, is_residual=False, is_debug=False):
        super().__init__()
        self.conv_1 = self._norm_act_conv(in_channels, out_channels)
        self.conv_2 = self._norm_act_conv(out_channels, out_channels)
        self.time_emb = MLP(input_dim=time_emb_dim, output_dim=out_channels)
        self.condition_emb = MLP(input_dim=time_emb_dim, output_dim=out_channels)
        # Skip path: project with a 1x1 conv only when the channel count changes.
        self.conv_3 = (
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )
        self.is_debug = is_debug
        self.is_residual = is_residual

    def forward(self, x, t, condition, mask):
        """Run the block.

        Args:
            x: feature map ``(batch, in_channels, h, w)``.
            t: time embedding ``(batch, time_emb_dim)``, or None to skip all conditioning.
            condition: condition embedding; multiplied by ``mask[:, None]`` after projection.
            mask: per-sample multiplier that switches the condition bias on or off.

        Returns:
            Feature map ``(batch, out_channels, h, w)``.
        """
        hidden = self.conv_1(x)
        if t is None:
            # Unconditioned path: the skip connection is always added here.
            return self.conv_3(x) + self.conv_2(hidden)

        time_bias = self.time_emb(t)
        cond_bias = self.condition_emb(condition) * mask[:, None]
        n_batch, n_channels = time_bias.shape
        per_channel = (n_batch, n_channels, 1, 1)  # broadcast over h and w
        out = self.conv_2(hidden + time_bias.view(per_channel) + cond_bias.view(per_channel))

        if not self.is_residual:
            return out
        return self.conv_3(x) + out

# print(ResnetBlock(8, 16, 4)(torch.rand(2, 8, 32, 32), torch.rand(2, 4), torch.rand(2,4), torch.zeros(2)).shape )

In [5]:

class DownBlock(nn.Module):
    """Halves the spatial resolution with non-overlapping 2x2 average pooling."""

    def __init__(self):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=2)

    def forward(self, x, t, condition, mask):
        """Pool ``x``; ``t``, ``condition`` and ``mask`` are accepted only to match
        the call signature shared by the other blocks and are ignored."""
        downsampled = self.pool(x)
        return downsampled
    
class UpBlock(nn.Module):
    """Doubles the spatial resolution with a kernel-2, stride-2 transposed convolution.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the upsampled feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upscale = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=2,
            stride=2,
        )

    def forward(self, x, t, condition, mask):
        """Upsample ``x``; ``t``, ``condition`` and ``mask`` are accepted only to match
        the call signature shared by the other blocks and are ignored."""
        upsampled = self.upscale(x)
        return upsampled


In [6]:
class SequenceWithTimeEmbedding(nn.Module):
    """Sequential container whose children all take ``(x, t, cond, mask)``.

    Only ``x`` is threaded from one child to the next; ``t``, ``cond`` and ``mask``
    are passed unchanged to every child.
    """

    def __init__(self, blocks):
        super().__init__()
        self.models = nn.ModuleList(blocks)

    def forward(self, x, t, cond, mask):
        """Feed ``x`` through every child in registration order and return the result."""
        out = x
        for stage in self.models:
            out = stage(out, t, cond, mask)
        return out

In [9]:
class UNet(nn.Module):
    """Conditional U-Net built from ResnetBlock / MultiheadAttention stages.

    The network has a down path (one stage per entry of ``steps``, separated by
    ``DownBlock``s), a bottleneck ("backbone"), and a mirrored up path (stages
    separated by ``UpBlock``s). The output of every down stage is kept and
    concatenated channel-wise onto the input of the matching up stage.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            T,
            num_classes,
            steps=(1, 2, 4),
            hid_size = 128,
            attn_steps = [2],
            use_time_emb=True,
            has_residuals=True,
            use_self_attention=False,
            num_resolution_blocks=2,
            is_debug = False
        ):
        """Build all sub-modules.

        Args:
            in_channels: channels of the input image (input of the first convolution).
            out_channels: channels produced by the final convolution.
            T: number of rows of the sinusoidal timestep table.
            num_classes: number of rows of the class embedding table.
            steps: channel multiplier of each resolution level; level ``i`` uses
                ``steps[i] * hid_size`` channels.
            hid_size: base channel width; also the width of the raw time/class embeddings.
            attn_steps: multipliers (values of ``steps``) at which an attention block
                follows every resnet block. If non-empty, the backbone also gets one.
                NOTE(review): mutable default list; it is only read here, never mutated.
            use_time_emb: stored on the instance; not read anywhere else in this class.
            has_residuals: forwarded as ``is_residual`` to the down/up resnet blocks
                (the backbone blocks use ResnetBlock's own default).
            use_self_attention: not referenced in this class.
            num_resolution_blocks: resnet blocks per resolution level.
            is_debug: stored on the instance; not read anywhere else in this class.
        """
        super().__init__()

        self.use_time_emb = use_time_emb

        # Width of the conditioning vectors handed to every ResnetBlock.
        time_emb_dim = hid_size * 4
        self.time_embedding = nn.Sequential(
            PositionalEmbedding(T=T, output_dim=hid_size),
            nn.Linear(hid_size, time_emb_dim),
            nn.ReLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )

        # Same shape of MLP as the time embedding, but starting from a learned class table.
        self.cond_embedding = nn.Sequential(
            nn.Embedding(num_classes, hid_size),
            nn.Linear(hid_size, time_emb_dim),
            nn.ReLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )

        self.first_conv = nn.Conv2d(in_channels, steps[0] * hid_size, 3, padding=1)
        # ---- Down path: [stage, DownBlock, stage, DownBlock, ..., stage] ----
        self.down_blocks = nn.ModuleList()
        prev_hid_size = steps[0] * hid_size
        for (index, step) in enumerate(steps):
            res_blocks = []
            for block in range(num_resolution_blocks):
                res_blocks.append(
                    ResnetBlock(
                        # Only the first block of a level changes the channel count.
                        in_channels=prev_hid_size if block == 0 else step * hid_size,
                        out_channels=step * hid_size,
                        time_emb_dim=time_emb_dim,
                        is_residual=has_residuals
                    )
                )
                if step in attn_steps:
                    res_blocks.append(
                        MultiheadAttention(
                            n_heads=4,
                            emb_dim=step * hid_size,
                            input_dim=step * hid_size
                        )
                    )
            self.down_blocks.append(
                SequenceWithTimeEmbedding(res_blocks)
            )
            # No downsampling after the last (coarsest) level.
            if index != len(steps) - 1:
                self.down_blocks.append(DownBlock())
            prev_hid_size = step * hid_size
        # ---- Bottleneck at the coarsest resolution ----
        if len(attn_steps) > 0:
            self.backbone = SequenceWithTimeEmbedding([
                ResnetBlock(steps[-1] * hid_size, steps[-1] * hid_size, time_emb_dim=time_emb_dim),
                MultiheadAttention(n_heads=4, emb_dim=steps[-1] * hid_size, input_dim=steps[-1] * hid_size),
                ResnetBlock(steps[-1] * hid_size, steps[-1] * hid_size, time_emb_dim=time_emb_dim),
            ])
        else:
            self.backbone = SequenceWithTimeEmbedding([
                ResnetBlock(steps[-1] * hid_size, steps[-1] * hid_size, time_emb_dim=time_emb_dim),
                ResnetBlock(steps[-1] * hid_size, steps[-1] * hid_size, time_emb_dim=time_emb_dim),
            ])

        # ---- Up path: mirrors the down path, coarsest level first ----
        self.up_blocks = nn.ModuleList()
        reverse_steps = list(reversed(steps))
        for (index, step) in enumerate(reverse_steps):
            res_blocks = []
            for block in range(num_resolution_blocks):
                # Each up stage already reduces to the width of the next finer level
                # (the finest level keeps its own width).
                next_hid_size = reverse_steps[index + 1] * hid_size if index != len(steps) - 1 else step * hid_size
                res_blocks.append(
                    ResnetBlock(
                        # "* 2": forward() concatenates the skip feature along the channel axis.
                        in_channels=prev_hid_size * 2 if block == 0 else next_hid_size,
                        out_channels=next_hid_size,
                        time_emb_dim=time_emb_dim,
                        is_residual=has_residuals
                    )
                )
                if step in attn_steps:
                    res_blocks.append(
                        MultiheadAttention(
                            n_heads=4,
                            emb_dim=next_hid_size,
                            input_dim=next_hid_size
                        )
                    )
            self.up_blocks.append(
                SequenceWithTimeEmbedding(res_blocks)
            )
            # NOTE(review): next_hid_size is only bound inside the inner loop, so
            # num_resolution_blocks == 0 would raise NameError here.
            if index != len(steps) - 1:
                self.up_blocks.append(UpBlock(next_hid_size, next_hid_size))
            prev_hid_size = next_hid_size

        self.is_debug = is_debug
        # Output head: GroupNorm -> ReLU -> 3x3 conv down to out_channels.
        self.out = nn.Sequential(*[
            nn.GroupNorm(8, steps[0] * hid_size),
            nn.ReLU(),
            nn.Conv2d(in_channels=steps[0] * hid_size, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        ])

    def forward(self, x, t, cond, mask):
        """Run the U-Net.

        Args:
            x: input image batch ``(batch, in_channels, h, w)``.
            t: integer timestep indices, one per batch row (they index the
                sinusoidal table) -- presumably shape ``(batch, 1)``; confirm with caller.
            cond: integer class indices fed to the class embedding table.
            mask: passed unchanged to every block; ResnetBlock multiplies the
                projected condition embedding by it.

        Returns:
            Tensor produced by the output head, ``out_channels`` channels wide.
        """
        time_emb = self.time_embedding(t)
        cond_emb = self.cond_embedding(cond)

        x = self.first_conv(x)
        # Skip features: the output of every down *stage* (DownBlock outputs are not stored).
        hx = []
        for down_block in self.down_blocks:
            x = down_block(x, time_emb, cond_emb, mask)
            if not isinstance(down_block, DownBlock):
                hx.append(x)
        x = self.backbone(x, time_emb, cond_emb, mask)

        # Consume the skip features in reverse order (coarsest first).
        ind = len(hx) - 1
        for up_block in self.up_blocks:
            if not isinstance(up_block, UpBlock):
                x = up_block(torch.cat([x, hx[ind]], 1), time_emb, cond_emb, mask)
                ind -= 1
            else:
                x = up_block(x, time_emb, cond_emb, mask)
        x = self.out(x)

        return x
    
# test_unet = UNet(in_channels=1, out_channels=1, T=100, num_classes=10, is_debug=True, use_self_attention=True)
# test_unet(torch.rand(3, 1, 28, 28), torch.randint(1, 100, (3,1)), torch.randint(1,10, (3,1)), torch.ones((3,1)))

In [10]:
import gc

# Drop unreachable Python objects first so that any tensors they still own are released.
gc.collect()

# Return cached, currently-unused accelerator memory to the system.
torch.cuda.empty_cache()
# This notebook runs on Apple's MPS backend when it is available; `torch.mps`
# only exists in PyTorch >= 2.0, hence the hasattr guard.
if hasattr(torch, "mps") and torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [11]:
# Dataset transform: only converts each image to a tensor; no mean/std normalisation.
# (Optional, currently disabled: transforms.Normalize(mean=[0.5], std=[0.5]).)
transform_to_tensor = transforms.Compose([transforms.ToTensor()])

class UnNormalize(torch.nn.Module):
    """Undo a per-channel normalisation: returns ``tensor * std + mean``.

    Args:
        mean: per-channel means (sequence or tensor), or a single scalar.
        std: per-channel standard deviations (sequence or tensor), or a single scalar.
    """

    def __init__(self, mean, std) -> None:
        super().__init__()
        # Non-persistent buffers follow the module through .to()/.cuda() while
        # leaving state_dict() empty, as it was with plain attributes.
        self.register_buffer("mean", self._as_channel_vector(mean), persistent=False)
        self.register_buffer("std", self._as_channel_vector(std), persistent=False)

    @staticmethod
    def _as_channel_vector(values):
        # Accept scalars as well as sequences: always store a 1-D tensor so that
        # the [:, None, None] broadcast in forward() is valid.
        return torch.as_tensor(values).detach().clone().reshape(-1)

    def forward(self, tensor):
        """Return the un-normalised copy of ``tensor``; the input is left untouched.

        ``tensor`` is expected to have its channel axis third from the end,
        e.g. ``(C, H, W)`` or ``(B, C, H, W)``.
        """
        mean = self.mean.to(tensor.device)
        std = self.std.to(tensor.device)
        # Out-of-place arithmetic already produces a new tensor, so no clone is needed.
        return tensor * std[:, None, None] + mean[:, None, None]
    
# Display transform: turns an image tensor back into a PIL image.
# (If Normalize is enabled above, prepend UnNormalize(mean=[0.5], std=[0.5]) here.)
transform_to_pil = transforms.Compose([transforms.ToPILImage()])

# plt.imshow(transform_to_pil(train_dataset[0][0]), cmap='gray')

In [12]:
# Training split: no `train` argument is passed, so the dataset's default split is used.
train_dataset = torchvision.datasets.MNIST(
    root="./datasets",
    download=True,
    transform=transform_to_tensor,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

# Held-out split (train=False), used for the per-epoch validation loss.
test_dataset = torchvision.datasets.MNIST(
    root="./datasets",
    download=True,
    transform=transform_to_tensor,
    train=False,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

In [16]:
import torch.backends
import torch.backends.mps

# Backend priority: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
    print("USE MPS")
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# If an operator turns out to be unsupported on MPS, the CPU fallback can be enabled with:
# import os
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

USE MPS


In [44]:
class DDPM(nn.Module):
    """Denoising diffusion wrapper with classifier-free guidance around a noise predictor.

    ``nn_model`` is called as ``nn_model(x_t, t, cond, mask)`` and must return the
    predicted noise with the same shape as ``x_t``.

    Args:
        T: number of diffusion steps; timesteps run over ``1 .. T``.
        nn_model: the noise-prediction network (moved to the notebook-level ``device``).
        p_cond: probability of *dropping* the condition for a training sample
            (the mask is 1 where a uniform draw exceeds ``p_cond``).
    """

    def __init__(self, T: int, nn_model: nn.Module, p_cond=0.2):
        super().__init__()
        self.T = T
        # `device` is the notebook-level device string selected in an earlier cell.
        self.nn_model = nn_model.to(device)
        # T + 1 entries so that the schedules can be indexed directly with t in 1..T.
        beta_schedule = torch.linspace(1e-4, 0.02, T + 1, device=device)
        # Non-persistent buffer: follows .to()/.cuda() without changing state_dict().
        self.register_buffer("p_cond", torch.tensor([p_cond]).to(device), persistent=False)
        alpha_t_schedule = 1 - beta_schedule
        # cumprod is evaluated on the CPU and the result moved back to `device`.
        bar_alpha_t_schedule = torch.cumprod(alpha_t_schedule.detach().cpu(), 0).to(device)
        sqrt_bar_alpha_t_schedule = torch.sqrt(bar_alpha_t_schedule)
        sqrt_minus_bar_alpha_t_schedule = torch.sqrt(1 - bar_alpha_t_schedule)
        self.register_buffer("beta_schedule", beta_schedule)
        self.register_buffer("alpha_t_schedule", alpha_t_schedule)
        self.register_buffer("bar_alpha_t_schedule", bar_alpha_t_schedule)
        self.register_buffer("sqrt_bar_alpha_t_schedule", sqrt_bar_alpha_t_schedule)
        self.register_buffer("sqrt_minus_bar_alpha_t_schedule", sqrt_minus_bar_alpha_t_schedule)
        self.criterion = nn.MSELoss()

    def forward(self, imgs: torch.Tensor, conds: torch.Tensor):
        """Compute the noise-prediction MSE for one batch.

        Args:
            imgs: clean images ``(batch, channels, h, w)``.
            conds: integer class labels ``(batch,)``.

        Returns:
            Scalar loss tensor.
        """
        batch_size = imgs.shape[0]
        # Random draws follow the input's device instead of a notebook global.
        t = torch.randint(low=1, high=self.T + 1, size=(batch_size,), device=imgs.device)
        noise = torch.randn_like(imgs)
        signal_scale = self.sqrt_bar_alpha_t_schedule[t].view((batch_size, 1, 1, 1))
        noise_scale = self.sqrt_minus_bar_alpha_t_schedule[t].view((batch_size, 1, 1, 1))
        noise_imgs = signal_scale * imgs + noise_scale * noise

        conds = conds.unsqueeze(1)
        # mask == 1 keeps the condition, mask == 0 drops it.
        mask = torch.rand_like(conds, dtype=torch.float32) > self.p_cond

        pred_noise = self.nn_model(noise_imgs, t.unsqueeze(1), conds, mask.long())

        return self.criterion(pred_noise, noise)

    def sample(self, n_samples: int, size, classes: List[int], w: float):
        """Generate images by ancestral sampling with classifier-free guidance.

        Args:
            n_samples: number of images to generate.
            size: shape of a single image, e.g. ``(channels, h, w)``.
            classes: one class label per sample (``len(classes) == n_samples``).
            w: guidance weight; the noise estimate is
               ``(1 + w) * eps_cond - w * eps_uncond``.

        Returns:
            Tensor ``(n_samples, *size)``. The intensities are inverted
            (``1 - x``) before returning. Note the module is left in eval mode.
        """
        self.eval()
        assert len(classes) == n_samples
        # Sample on whatever device the schedules live on.
        dev = self.beta_schedule.device
        with torch.no_grad():
            x_t = torch.randn(n_samples, *size, device=dev)
            cond = torch.tensor(classes).unsqueeze(1).to(dev)
            mask_ones = torch.ones_like(cond)
            mask_zeros = torch.zeros_like(cond)
            w = torch.tensor(w).to(dev)
            for t in range(self.T, 0, -1):
                # No fresh noise on the final step (t == 1): the last update is deterministic.
                z = torch.randn_like(x_t) if t > 1 else torch.zeros_like(x_t)
                t_tensor = torch.tensor([t], device=dev).repeat(x_t.shape[0], 1)
                pred_noise_cond = self.nn_model(x_t, t_tensor, cond, mask_ones)
                pred_noise_zero = self.nn_model(x_t, t_tensor, cond, mask_zeros)
                pred_noise = (1 + w) * pred_noise_cond - w * pred_noise_zero
                x_t = 1 / torch.sqrt(self.alpha_t_schedule[t]) * \
                    (x_t - pred_noise * (1 - self.alpha_t_schedule[t]) / self.sqrt_minus_bar_alpha_t_schedule[t]) + \
                    torch.sqrt(self.beta_schedule[t]) * z
            # Invert intensities (x -> 1 - x) before handing the images back.
            x_t = x_t*-1 + 1
            return x_t

In [24]:
def train(model: "DDPM", optimizer: torch.optim.Optimizer, epochs: int):
    """Train ``model`` and evaluate it after every epoch.

    Uses the notebook-level ``train_dataloader``, ``test_dataloader`` and ``device``.

    Args:
        model: module whose ``model(imgs, labels)`` call returns a scalar loss.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over the training loader.

    Returns:
        ``(training_losses, val_losses)``: two lists with one entry per epoch,
        each the mean per-batch loss of that epoch.
    """
    training_losses = []
    val_losses = []
    for epoch in range(epochs):
        model.train(True)
        training_loss = 0.0
        pbar = tqdm.tqdm(train_dataloader)
        for index, (imgs, labels) in enumerate(pbar):
            optimizer.zero_grad()

            loss = model(imgs.to(device), labels.to(device))

            loss.backward()
            optimizer.step()
            training_loss += loss.item()
            pbar.set_description(f"loss for epoch {epoch}: {training_loss / (index + 1):.4f}")

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for (imgs, labels) in test_dataloader:
                val_loss += model(imgs.to(device), labels.to(device)).item()

        # Each accumulated term is already a batch mean, so average over the number
        # of batches (not the number of samples); this matches the progress-bar value.
        training_losses.append(training_loss / max(len(train_dataloader), 1))
        val_losses.append(val_loss / max(len(test_dataloader), 1))
    return training_losses, val_losses

In [None]:
EPOCHS = 10

T = 400

# The UNet's timestep table gets T + 1 rows because DDPM draws timesteps from 1..T inclusive.
denoiser = UNet(in_channels=1, out_channels=1, T=T + 1, num_classes=10)
model_with_time = DDPM(T=T, nn_model=denoiser)
optimizer = torch.optim.Adam(params=model_with_time.parameters(), lr=2e-4)

# Only the validation curve is plotted; the training losses are discarded.
_, val_losses = train(model_with_time, optimizer, EPOCHS)

plt.plot(val_losses, label="with time")

plt.legend()

loss for epoch 0: 0.2100:  11%|█         | 51/469 [01:12<09:55,  1.42s/it] 


KeyboardInterrupt: 

In [45]:
n_samples = 10
# One sample per digit class, guidance weight w=2.
x_t = model_with_time.sample(n_samples=n_samples, size=train_dataset[0][0].shape, classes=[0,1,2,3,4,5,6,7,8,9], w=2)
grid = make_grid(x_t, nrow=10)
save_image(grid, "epoch_without.png")

cols = 5
rows = math.ceil(n_samples / cols)
# squeeze=False keeps `axs` two-dimensional even when there is a single row or column,
# so the [row, col] indexing below is always valid.
fig, axs = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
for i in range(x_t.shape[0]):
    row = i // cols
    # (C, H, W) -> (H, W, C) for imshow.
    axs[row, i % cols].imshow(x_t[i].permute(1,2,0).detach().cpu().numpy(), cmap='gray')

torch.Size([10, 1])
torch.Size([10, 1])
torch.Size([10, 1])


KeyboardInterrupt: 