In [35]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import tqdm
from torchvision.utils import save_image, make_grid
import math

In [48]:
# Module-level switch read by DoubleConvBlock.forward on every call: while it is
# True the block returns straight after its two conv stages and the time
# embedding is never used. Later cells reassign it before training/sampling.
IGNORE_TIME_EMB = True

class MLP(nn.Module):
    """Project an embedding with a ReLU followed by one linear layer.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # The ReLU is the FIRST layer, so an in-place variant would overwrite
        # the tensor handed in by the caller. That tensor is the time embedding
        # shared by every block of the network, so it must stay untouched.
        self.ln = nn.Sequential(
            nn.ReLU(),
            nn.Linear(input_dim, output_dim),
        )

    def forward(self, x):
        """Return ``Linear(ReLU(x))`` without modifying ``x``."""
        return self.ln(x)

class DoubleConvBlock(nn.Module):
    """Two (3x3 conv -> GroupNorm(8) -> ReLU) stages with optional time conditioning.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both stages (GroupNorm uses 8 groups).
        emb_dim: width of the time embedding passed to ``forward``.
        is_debug: stored on the instance; not read inside this class.
    """

    def __init__(self, in_channels, out_channels, emb_dim, is_debug=False):
        super().__init__()

        def conv_stage(stage_in):
            # One conv -> norm -> activation stage; spatial size is preserved
            # by the padding of 1 around the 3x3 kernel.
            return nn.Sequential(
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.GroupNorm(8, out_channels),
                nn.ReLU(inplace=True),
            )

        self.conv_1 = conv_stage(in_channels)
        self.conv_2 = conv_stage(out_channels)
        self.time_emb = MLP(input_dim=emb_dim, output_dim=out_channels)
        self.is_debug = is_debug

    def forward(self, x, t, has_attn = False):
        """Run both stages; add the projected time embedding between them
        unless the global ``IGNORE_TIME_EMB`` is set.

        ``has_attn`` is accepted but not used.
        """
        hidden = self.conv_1(x)
        if not IGNORE_TIME_EMB:
            projected = self.time_emb(t)
            n_rows, n_channels = projected.shape
            # Broadcast the per-sample vector over the spatial dimensions.
            hidden = hidden + projected.view(n_rows, n_channels, 1, 1)
        return self.conv_2(hidden)
    
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal lookup table indexed by integer timestep.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)`` with
    ``w_k = exp(-2k * ln(10000) / output_dim)``.

    Args:
        T: number of rows in the table (valid indices are ``0 .. T-1``).
        output_dim: embedding width; both even and odd values are supported.
    """

    def __init__(self, T: int, output_dim: int) -> None:
        super().__init__()
        self.output_dim = output_dim
        position = torch.arange(T).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, output_dim, 2) * (-math.log(10000.0) / output_dim))
        angles = position * div_term
        pe = torch.zeros(T, output_dim)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd output_dim there is one fewer odd column than even column,
        # so the cosine half is truncated to fit instead of raising a shape error.
        pe[:, 1::2] = torch.cos(angles[:, : output_dim // 2])
        # Buffer (not parameter): follows .to(device) but is never trained.
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor):
        """Look up the embeddings for integer indices ``x``.

        The result is reshaped to ``(x.shape[0], output_dim)``, so ``x`` may be
        shaped ``(batch,)`` or ``(batch, 1)``.
        """
        return self.pe[x].reshape(x.shape[0], self.output_dim)

# print(PositionalEmbedding(2,4)(torch.tensor([[1], [0], [0]])))
# torch.tensor([[1,2,3], [4,5,6], [7,8,9]])[ torch.tensor([[1], [0], [1]]) ].reshape(3, 3)

# print( DoubleConvBlock(1, 16, 4)(torch.rand(2, 1, 28, 28), torch.rand(2, 4)) )

In [37]:

class DownBlock(nn.Module):
    """Encoder stage: a DoubleConvBlock followed by a 2x2 max-pool."""

    def __init__(self, in_channels, out_channels, emb_dim, is_debug=False):
        super().__init__()
        self.is_debug = is_debug
        self.conv = DoubleConvBlock(
            in_channels,
            out_channels,
            emb_dim=emb_dim,
            is_debug=is_debug,
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x, t):
        """Convolve ``x`` (conditioned on ``t``) and return the pooled result."""
        features = self.conv(x, t)
        downsampled = self.pool(features)
        return downsampled
    
class UpBlock(nn.Module):
    """Decoder stage: concatenate with the skip tensor, upsample 2x, then convolve.

    ``in_channels`` is the channel count AFTER concatenation, because the
    transposed convolution is applied to ``cat([x, skip])``.
    """

    def __init__(self, in_channels, out_channels, emb_dim, is_debug=False):
        super().__init__()
        self.is_debug = is_debug
        # Channel-preserving transposed conv with kernel 2 and stride 2.
        self.upscale = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
        self.conv = DoubleConvBlock(
            in_channels,
            out_channels,
            emb_dim=emb_dim,
            is_debug=is_debug,
        )

    def forward(self, x, skip, t):
        """Merge ``x`` with ``skip`` along dim 1, upsample and run the conv block."""
        merged = torch.cat((x, skip), dim=1)
        upsampled = self.upscale(merged)
        return self.conv(upsampled, t)


In [38]:
# Scratch cell: prints the float literal 0.0 and defines no names.
print(0.)

0.0


In [39]:
class UNet(nn.Module):
    """Two-level U-Net that predicts noise from an image and a timestep index.

    Layout: two DownBlocks, a DoubleConvBlock bottleneck, two UpBlocks and a
    final 3x3 convolution. Each UpBlock receives the concatenation of the
    running tensor and the matching DownBlock output, which is why ``up_2``
    takes ``4 * hid_size`` channels and ``up_3`` takes ``2 * hid_size``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the prediction.
        T: number of rows in the positional-embedding table.
        hid_size: base channel width.
        is_debug: when True, ``forward`` prints intermediate shapes.
    """

    def __init__(self, in_channels, out_channels, T, hid_size = 256, is_debug = False):
        super().__init__()
        self.is_debug = is_debug

        emb_width = 4 * hid_size
        self.time_embedding = nn.Sequential(
            PositionalEmbedding(T=T, output_dim=hid_size),
            nn.Linear(hid_size, emb_width),
            nn.ReLU(inplace=True),
            nn.Linear(emb_width, emb_width),
        )

        # Encoder.
        self.down_1 = DownBlock(in_channels=in_channels, out_channels=hid_size, emb_dim=emb_width)
        self.down_2 = DownBlock(in_channels=hid_size, out_channels=2 * hid_size, emb_dim=emb_width)

        # Bottleneck.
        self.backbone = DoubleConvBlock(in_channels=2 * hid_size, out_channels=2 * hid_size, emb_dim=emb_width)

        # Decoder and output head.
        self.up_2 = UpBlock(in_channels=4 * hid_size, out_channels=hid_size, emb_dim=emb_width)
        self.up_3 = UpBlock(in_channels=2 * hid_size, out_channels=hid_size, emb_dim=emb_width)
        self.out = nn.Conv2d(hid_size, out_channels, kernel_size=3, stride=1, padding=1)

    def _debug_shape(self, label, tensor):
        # Print "<label> <shape>" only when the debug flag is on.
        if self.is_debug:
            print(label, tensor.shape)

    def forward(self, x, t):
        """Return the network output for images ``x`` at timestep indices ``t``."""
        time_emb = self.time_embedding(t)

        skip_1 = self.down_1(x, time_emb)
        self._debug_shape("Down 1 shape:", skip_1)
        skip_2 = self.down_2(skip_1, time_emb)
        self._debug_shape("Down 2 shape:", skip_2)

        hidden = self.backbone(skip_2, time_emb)

        hidden = self.up_2(hidden, skip_2, time_emb)
        self._debug_shape("Up 2 shape:", hidden)
        hidden = self.up_3(hidden, skip_1, time_emb)
        self._debug_shape("Up 1 shape:", hidden)

        return self.out(hidden)
    
# test_unet = UNet(in_channels=1, out_channels=1, T=100, is_debug=True)
# test_unet(torch.rand(3, 1, 28, 28), torch.randint(1, 100, (3,1)))

In [40]:
import gc

import torch

# Drop unreachable Python objects first, then ask PyTorch to release the
# unused blocks held by its CUDA caching allocator.
gc.collect()
torch.cuda.empty_cache()

In [41]:
# Dataset preprocessing: ToTensor only, no Normalize step is applied.
transform_to_tensor = transforms.Compose([transforms.ToTensor()])

class UnNormalize(torch.nn.Module):
    """Undo a per-channel ``(x - mean) / std`` normalization.

    Args:
        mean: per-channel means (sequence) or a single scalar.
        std: per-channel standard deviations (sequence) or a single scalar.
    """

    def __init__(self, mean, std) -> None:
        super().__init__()
        # reshape(-1) turns a scalar into a one-element vector so the
        # per-channel broadcasting in forward() works for both input styles.
        self.mean = torch.tensor(mean).reshape(-1)
        self.std = torch.tensor(std).reshape(-1)

    def forward(self, tensor):
        """Return ``tensor * std + mean`` as a new tensor.

        The statistics are broadcast as ``(C, 1, 1)``, which fits ``(C, H, W)``
        and ``(B, C, H, W)`` inputs. The input is left unmodified: the
        arithmetic below already allocates its own result.
        """
        mean = self.mean.to(tensor.device).reshape(-1, 1, 1)
        std = self.std.to(tensor.device).reshape(-1, 1, 1)
        return tensor * std + mean
    
# Tensor -> PIL image for display; no un-normalization step is applied.
transform_to_pil = transforms.Compose([transforms.ToPILImage()])

# plt.imshow(transform_to_pil(train_dataset[0][0]), cmap='gray')

In [2]:
# Training split: reshuffled every epoch.
train_dataset = torchvision.datasets.MNIST(
    root="./datasets", train=True, download=True, transform=transform_to_tensor
)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)

# Held-out split: used only to accumulate a validation loss, so it is iterated
# in a fixed order instead of being shuffled.
test_dataset = torchvision.datasets.MNIST(
    root="./datasets", train=False, download=True, transform=transform_to_tensor
)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

NameError: name 'torchvision' is not defined

In [43]:
import torch.backends
import torch.backends.mps

# Backend preference: MPS wins whenever it is available, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
    print("USE MPS")
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# import os
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

USE MPS


In [44]:
class DDPM(nn.Module):
    """Diffusion wrapper around a noise-prediction network.

    ``forward`` returns the MSE between the true and the predicted noise for a
    batch of images; ``sample`` runs the reverse process starting from noise.

    Args:
        T: number of diffusion steps; timesteps are the integers ``1 .. T``.
        nn_model: network called as ``nn_model(noisy_images, t)`` with ``t``
            shaped ``(batch, 1)``; it is moved to the global ``device``.
    """

    def __init__(self, T: int, nn_model: nn.Module):
        super().__init__()
        self.T = T
        self.nn_model = nn_model.to(device)
        # T + 1 entries so that every schedule can be indexed directly by t in 1..T.
        beta_schedule = torch.linspace(1e-4, 0.02, T + 1, device=device)
        alpha_t_schedule = 1 - beta_schedule
        # The cumulative product is evaluated on the CPU and moved back.
        bar_alpha_t_schedule = torch.cumprod(alpha_t_schedule.detach().cpu(), 0).to(device)
        sqrt_bar_alpha_t_schedule = torch.sqrt(bar_alpha_t_schedule)
        sqrt_minus_bar_alpha_t_schedule = torch.sqrt(1 - bar_alpha_t_schedule)
        self.register_buffer("beta_schedule", beta_schedule)
        self.register_buffer("alpha_t_schedule", alpha_t_schedule)
        self.register_buffer("bar_alpha_t_schedule", bar_alpha_t_schedule)
        self.register_buffer("sqrt_bar_alpha_t_schedule", sqrt_bar_alpha_t_schedule)
        self.register_buffer("sqrt_minus_bar_alpha_t_schedule", sqrt_minus_bar_alpha_t_schedule)
        self.criterion = nn.MSELoss()

    def forward(self, imgs):
        """Noise ``imgs`` at random timesteps and return the noise-prediction loss.

        Args:
            imgs: image batch with four dimensions, batch first.

        Returns:
            Scalar MSE between the sampled noise and the network's prediction.
        """
        batch_size = imgs.shape[0]
        # Random tensors are created next to the schedule buffers / the input
        # instead of on a module-level global, so the module keeps working
        # after being moved with .to().
        t = torch.randint(low=1, high=self.T + 1, size=(batch_size,), device=self.beta_schedule.device)
        noise = torch.randn_like(imgs)
        signal_scale = self.sqrt_bar_alpha_t_schedule[t].view(batch_size, 1, 1, 1)
        noise_scale = self.sqrt_minus_bar_alpha_t_schedule[t].view(batch_size, 1, 1, 1)
        noise_imgs = signal_scale * imgs + noise_scale * noise

        pred_noise = self.nn_model(noise_imgs, t.unsqueeze(1))

        return self.criterion(pred_noise, noise)

    def sample(self, n_samples, size):
        """Generate ``n_samples`` images of per-sample shape ``size``.

        Switches the module to eval mode, iterates ``t = T .. 1`` and returns
        ``1 - x_0`` (the final intensities are inverted).
        """
        self.eval()
        schedule_device = self.beta_schedule.device
        with torch.no_grad():
            x_t = torch.randn(n_samples, *size, device=schedule_device)
            for t in range(self.T, 0, -1):
                # No fresh noise on the last step (t == 1). The previous
                # `t > 0` test was always true inside this loop, so the
                # returned sample still carried sqrt(beta_1) noise.
                z = torch.randn_like(x_t) if t > 1 else torch.zeros_like(x_t)
                t_tensor = torch.tensor([t], device=schedule_device).repeat(x_t.shape[0], 1)
                pred_noise = self.nn_model(x_t, t_tensor)
                x_t = 1 / torch.sqrt(self.alpha_t_schedule[t]) * \
                    (x_t - pred_noise * (1 - self.alpha_t_schedule[t]) / self.sqrt_minus_bar_alpha_t_schedule[t]) + \
                    torch.sqrt(self.beta_schedule[t]) * z
            x_t = x_t*-1 + 1
            return x_t

In [45]:
def train(model: DDPM, optimizer: torch.optim.Optimizer, epochs: int):
    """Train ``model`` on ``train_dataloader`` and evaluate it on ``test_dataloader``.

    Args:
        model: module whose call on an image batch returns a scalar loss.
        optimizer: optimizer already bound to ``model``'s parameters.
        epochs: number of passes over the training loader.

    Returns:
        ``(training_losses, val_losses)``: one mean per-batch loss per epoch
        for each loader.
    """
    training_losses = []
    val_losses = []
    for epoch in range(epochs):
        model.train(True)
        training_loss = 0
        pbar = tqdm.tqdm(train_dataloader)
        for index, (imgs, _) in enumerate(pbar):
            optimizer.zero_grad()

            loss = model(imgs.to(device))

            loss.backward()
            optimizer.step()
            training_loss += loss.item()
            pbar.set_description(f"loss for epoch {epoch}: {training_loss / (index + 1):.4f}")

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for imgs, _ in test_dataloader:
                val_loss += model(imgs.to(device)).item()

        # Each accumulated value is already a per-batch mean, so the epoch
        # average divides by the number of BATCHES (matching the progress bar),
        # not by the number of samples. max(..., 1) guards an empty loader.
        training_losses.append(training_loss / max(len(train_dataloader), 1))
        val_losses.append(val_loss / max(len(test_dataloader), 1))
    return training_losses, val_losses

In [49]:
EPOCHS = 10

T = 400


def _train_fresh_ddpm():
    """Build a DDPM around a new UNet, train it for EPOCHS epochs and return
    the model together with its per-epoch validation losses."""
    ddpm = DDPM(T = T, nn_model=UNet(in_channels=1, out_channels=1, T=T+1))
    adam = torch.optim.Adam(params=ddpm.parameters(), lr=2e-4)
    _, per_epoch_val = train(ddpm, adam, EPOCHS)
    return ddpm, per_epoch_val


# Run 1: the conv blocks skip the time embedding.
IGNORE_TIME_EMB = True
model_without_time, val_losses_without = _train_fresh_ddpm()

# Run 2: the conv blocks add the time embedding.
IGNORE_TIME_EMB = False
model_with_time, val_losses = _train_fresh_ddpm()

# Compare the validation curves of both runs.
plt.plot(val_losses_without, label="without time")
plt.plot(val_losses, label="with time")

plt.legend()

loss for epoch 0: 1.4783:   4%|▎         | 17/469 [00:20<09:15,  1.23s/it]


KeyboardInterrupt: 

In [None]:
# The conv blocks read this flag at call time, so it must match the way
# model_without_time was trained.
IGNORE_TIME_EMB = True

n_samples = 10
x_t = model_without_time.sample(n_samples=n_samples, size=train_dataset[0][0].shape)
grid = make_grid(x_t, nrow=10)
save_image(grid, "epoch_without.png")

cols = 5
rows = (n_samples + cols - 1) // cols  # ceil(n_samples / cols)
# squeeze=False keeps `axs` two-dimensional even when only one row is needed,
# so the [row, col] indexing below works for any n_samples >= 1.
fig, axs = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
for i in range(x_t.shape[0]):
    row, col = divmod(i, cols)
    axs[row, col].imshow(x_t[i].permute(1,2,0).detach().cpu().numpy(), cmap='gray')

In [None]:
# The conv blocks read this flag at call time, so it must match the way
# model_with_time was trained.
IGNORE_TIME_EMB = False

n_samples = 10
x_t = model_with_time.sample(n_samples=n_samples, size=train_dataset[0][0].shape)
grid = make_grid(x_t, nrow=10)
# Own file name: this cell previously wrote "epoch_without.png" and thereby
# overwrote the grid saved for the model trained without the time embedding.
save_image(grid, "epoch_with.png")

cols = 5
rows = (n_samples + cols - 1) // cols  # ceil(n_samples / cols)
# squeeze=False keeps `axs` two-dimensional even when only one row is needed,
# so the [row, col] indexing below works for any n_samples >= 1.
fig, axs = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
for i in range(x_t.shape[0]):
    row, col = divmod(i, cols)
    axs[row, col].imshow(x_t[i].permute(1,2,0).detach().cpu().numpy(), cmap='gray')