In [1]:
# https://google.github.io/mediapy/mediapy.html
# https://einops.rocks/
!pip install mediapy einops --quiet

In [2]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.transforms import ToTensor
import mediapy as media
from tqdm import tqdm
import matplotlib.pyplot as plt

In [3]:
## utils

def inspect(label, im):
    """Print basic statistics of a tensor.

    Args:
        label: Name printed as a header line above the statistics.
        im: torch.Tensor to summarize, or None (nothing is printed then).

    Shape, dtype, max and min are always printed. Mean and std are printed
    for every floating-point dtype (not only float32); they are skipped for
    integer tensors.
    """
    if im is None:
        return
    print()
    print(label + ':')
    print('shape:', im.shape)
    print('dtype:', im.dtype)
    print('max:', torch.max(im))
    print('min:', torch.min(im))
    if im.is_floating_point():
        print('mean:', torch.mean(im))
        print('std:', torch.std(im))
    print()

def ctime_as_fname():
    """Return the current time.ctime() string made usable as a file name.

    Every space is replaced by '_' and every ':' by '.'.
    """
    substitutions = str.maketrans({' ': '_', ':': '.'})
    return time.ctime().translate(substitutions)

def torch2np(im):
    """Convert a torch tensor to a NumPy array, moving channels last.

    Args:
        im: torch.Tensor with 1 to 4 dimensions.

    Returns:
        numpy.ndarray detached from autograd and located on the CPU:
            4-D input: axes reordered from (0, 1, 2, 3) to (0, 2, 3, 1).
            3-D input: axes reordered from (0, 1, 2) to (1, 2, 0).
            1-D or 2-D input: axis order unchanged.

    Raises:
        ValueError: if the tensor has any other number of dimensions.
    """
    ndim = im.dim()
    if ndim == 4:
        axes = (0, 2, 3, 1)
    elif ndim == 3:
        axes = (1, 2, 0)
    elif ndim in (1, 2):
        axes = None  # nothing to reorder
    else:
        raise ValueError(
            f'torch2np expects a tensor with 1 to 4 dimensions, got shape {tuple(im.shape)}'
        )

    arr = im.detach()
    if axes is not None:
        arr = arr.permute(*axes)
    return arr.cpu().numpy()

def show_tensor(x, max_n_horiz=9, height=75):
    """Display a single image or a batch of images with mediapy.

    Args:
        x: torch.Tensor of shape (B, C, H, W) or (C, H, W); it is converted
            with torch2np first. Anything that is not a torch.Tensor is passed
            to mediapy as-is (presumably already channels-last - confirm with
            callers).
        max_n_horiz: maximum number of images per displayed row for a 4-D
            input. None shows the whole batch in one show_images call (with a
            border). Ignored for non-batched input.
        height: display height forwarded to mediapy.

    Raises:
        ValueError: if a 4-D input is given with max_n_horiz < 1 (the chunking
            loop could never advance).
    """
    if isinstance(x, torch.Tensor):
        x = torch2np(x)

    if len(x.shape) != 4:
        media.show_image(x, height=height)
        return

    if max_n_horiz is None:
        media.show_images(x, border=True, height=height)
        return

    if max_n_horiz < 1:
        raise ValueError(f'max_n_horiz must be None or >= 1, got {max_n_horiz}')

    # One show_images call per chunk of at most max_n_horiz images.
    for start in range(0, x.shape[0], max_n_horiz):
        media.show_images(x[start:start + max_n_horiz], height=height)

def get_video_frames(traj, nrows, ncols, height=200, fps=100):
    """Tile N per-sample trajectories into a list of grid-shaped video frames.

    Args:
        traj: torch.Tensor of shape (N, T, H, W). Values are clipped to
            [0, 1] and scaled to uint8 in [0, 255].
        nrows: number of grid rows.
        ncols: number of grid columns; N must equal nrows * ncols. Sample n is
            placed at row n // ncols, column n % ncols.
        height: unused; kept so existing callers keep working.
        fps: unused; kept so existing callers keep working.

    Returns:
        List of T uint8 NumPy arrays, each of shape (nrows * H, ncols * W).

    Raises:
        AssertionError: if N is not divisible by nrows.
        ValueError: if N != nrows * ncols.
    """
    N, T, H, W = traj.shape
    assert (N % nrows) == 0, 'N must be divisible by nrows'
    if N != nrows * ncols:
        raise ValueError(
            f'expected nrows * ncols == N, got nrows={nrows}, ncols={ncols}, N={N}'
        )

    pixels = (torch.clip(traj, 0, 1) * 255).byte().cpu().numpy()
    # (N, T, H, W) -> (row, col, T, H, W) -> (T, row, H, col, W) -> (T, rows*H, cols*W)
    grid = pixels.reshape(nrows, ncols, T, H, W).transpose(2, 0, 3, 1, 4)
    grid = grid.reshape(T, nrows * H, ncols * W)
    return list(grid)

In [4]:
# Pick the compute device: Apple MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

mps


In [5]:
# MNIST training split stored under ./data (downloaded when missing);
# ToTensor() is applied to every image.
train_set = MNIST('./data', train=True, transform=ToTensor(), download=True)
print(train_set)

# Look at the first (image, label) pair of the dataset.
x, y = train_set[0]
inspect('x', x)
media.show_image(x.squeeze(0), height=150)

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: ToTensor()

x:
shape: torch.Size([1, 28, 28])
dtype: torch.float32
max: tensor(1.)
min: tensor(0.)
mean: tensor(0.1377)
std: tensor(0.3125)



In [6]:
""" nn.Modules """

class Conv(nn.Module):
    """3x3 convolution (stride 1, padding 1) + batch norm + activation.

    The spatial size is preserved; only the channel count changes from
    `in_channels` to `out_channels`.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        activation: zero-argument callable that builds the activation layer.
    """

    def __init__(self, in_channels: int, out_channels: int, activation=nn.GELU):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> batch norm -> activation to a (B, in_channels, H, W) tensor."""
        return self.activation(self.bn(self.conv(x)))


class DownConv(nn.Module):
    """Strided 3x3 convolution (stride 2, padding 1) + batch norm + activation.

    Halves the spatial resolution (rounding up for odd sizes) while mapping
    `in_channels` to `out_channels`.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        activation: zero-argument callable that builds the activation layer.
    """

    def __init__(self, in_channels: int, out_channels: int, activation=nn.GELU):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply strided conv -> batch norm -> activation."""
        return self.activation(self.bn(self.conv(x)))


class UpConv(nn.Module):
    """Transposed 4x4 convolution (stride 2, padding 1) + batch norm + activation.

    Doubles the spatial resolution while mapping `in_channels` to
    `out_channels`.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        activation: zero-argument callable that builds the activation layer.
    """

    def __init__(self, in_channels: int, out_channels: int, activation=nn.GELU):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply transposed conv -> batch norm -> activation."""
        return self.activation(self.bn(self.conv(x)))


class Flatten(nn.Module):
    """Average-pool with a `kernel_size` window, then apply an activation.

    When `kernel_size` equals the spatial size of the input, each channel is
    reduced to a single value, i.e. the output is (B, C, 1, 1).

    Args:
        activation: zero-argument callable that builds the activation layer.
        kernel_size: pooling window passed to nn.AvgPool2d.
    """

    def __init__(self, activation=nn.GELU, kernel_size=7):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size)
        self.activation = activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply average pooling followed by the activation."""
        return self.activation(self.pool(x))


class Unflatten(nn.Module):
    """Transposed convolution that expands each spatial position to a `kernel_size` patch.

    Kernel size and stride are both `kernel_size` (no padding), so a
    (B, C, 1, 1) input becomes (B, C, *kernel_size). The channel count is
    unchanged. Batch norm and the activation follow.

    Args:
        in_channels: number of input (and output) channels.
        activation: zero-argument callable that builds the activation layer.
        kernel_size: kernel size and stride of the transposed convolution.
    """

    def __init__(self, in_channels: int, activation=nn.GELU, kernel_size=7):
        super().__init__()
        self.in_channels = in_channels
        self.conv = nn.ConvTranspose2d(
            in_channels, in_channels, kernel_size=kernel_size, stride=kernel_size, padding=0
        )
        self.bn = nn.BatchNorm2d(in_channels)
        self.activation = activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply transposed conv -> batch norm -> activation."""
        return self.activation(self.bn(self.conv(x)))


class ConvBlock(nn.Module):
    """Two `Conv` layers, the second one wrapped in a residual connection.

    Computes h = conv1(x) and returns h + conv2(h); the spatial size is kept
    and the channel count goes from `in_channels` to `out_channels`.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        activation: activation factory forwarded to both `Conv` layers.
    """

    def __init__(self, in_channels: int, out_channels: int, activation=nn.GELU):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv1 = Conv(in_channels, out_channels, activation=activation)
        self.conv2 = Conv(out_channels, out_channels, activation=activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return conv1(x) + conv2(conv1(x))."""
        hidden = self.conv1(x)
        return hidden + self.conv2(hidden)


class DownBlock(nn.Module):
    """A `Conv` layer followed by a `DownConv` layer.

    The first layer maps `in_channels` to `out_channels`; the second keeps
    the channel count and downsamples.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        activation: activation factory forwarded to both layers.
    """

    def __init__(self, in_channels: int, out_channels: int, activation=nn.GELU):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv1 = Conv(in_channels, out_channels, activation=activation)
        self.conv2 = DownConv(out_channels, out_channels, activation=activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return conv2(conv1(x))."""
        return self.conv2(self.conv1(x))


class UpBlock(nn.Module):
    """A `Conv` layer followed by an `UpConv` layer.

    The first layer maps `in_channels` to `out_channels`; the second keeps
    the channel count and upsamples.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        activation: activation factory forwarded to both layers.
    """

    def __init__(self, in_channels: int, out_channels: int, activation=nn.GELU):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv1 = Conv(in_channels, out_channels, activation=activation)
        self.conv2 = UpConv(out_channels, out_channels, activation=activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return conv2(conv1(x))."""
        return self.conv2(self.conv1(x))


class FCBlock(nn.Module):
    """Two-layer MLP: Linear -> activation -> Linear.

    Args:
        in_channels: size of the last dimension of the input.
        out_channels: width of both linear layers' outputs.
        activation: zero-argument callable that builds the activation layer.
    """

    def __init__(self, in_channels: int, out_channels: int, activation=nn.GELU):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, out_channels)
        self.activation = activation()
        self.fc2 = nn.Linear(out_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (..., in_channels) to (..., out_channels)."""
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)





class TimeConditionalUNet(nn.Module):
    """Small UNet whose decoder is conditioned on a scalar time per sample.

    With D = num_hiddens and a (B, C, H, W) input the feature maps are:
        cb1 -> (B, D, H, W), db1 -> (B, D, H/2, W/2), db2 -> (B, 2D, H/4, W/4),
        flatten -> (B, 2D, 1, 1), unflatten -> (B, 2D, H/4, W/4),
        ub2 -> (B, D, H/2, W/2), ub1 -> (B, D, H, W), cb2 -> (B, D, H, W),
        conv_out -> (B, C_out, H, W).
    Each decoder stage is concatenated with the encoder feature map of the
    same resolution. The time embedding is added after `unflatten` and after
    `ub2`.

    Args:
        in_channels: number of input channels C.
        H: input height; must be a positive multiple of 4.
        W: input width; must be a positive multiple of 4.
        out_channels: number of output channels; defaults to `in_channels`.
        num_hiddens: base channel width D.
        activation: activation factory forwarded to all sub-blocks.

    Raises:
        ValueError: if H or W is not a positive multiple of 4. Two stride-2
            stages followed by two exact 2x upsamplings only line up with the
            skip connections in that case.
    """

    def __init__(
        self,
        in_channels: int,
        H: int,
        W: int,
        out_channels: int = None,
        num_hiddens: int = 64,
        activation=nn.GELU,
    ):
        super().__init__()
        if H <= 0 or W <= 0 or H % 4 != 0 or W % 4 != 0:
            raise ValueError(f'H and W must be positive multiples of 4, got H={H}, W={W}')

        self.D = D = num_hiddens
        self.C = C = in_channels
        self.C_out = C_out = out_channels if out_channels is not None else in_channels
        self.H = H
        self.W = W

        # Spatial size of the bottleneck feature map produced by db2.
        bottleneck_hw = (H // 4, W // 4)

        self.cb1 = ConvBlock(in_channels=C, out_channels=D, activation=activation)
        self.db1 = DownBlock(in_channels=D, out_channels=D, activation=activation)
        self.db2 = DownBlock(in_channels=D, out_channels=2*D, activation=activation)
        self.flatten = Flatten(activation=activation, kernel_size=bottleneck_hw)
        self.unflatten = Unflatten(in_channels=2*D, activation=activation, kernel_size=bottleneck_hw)
        self.ub2 = UpBlock(in_channels=4*D, out_channels=D, activation=activation)
        self.ub1 = UpBlock(in_channels=2*D, out_channels=D, activation=activation)
        self.cb2 = ConvBlock(in_channels=2*D, out_channels=D, activation=activation)
        # Plain convolution without activation so the output range is unrestricted.
        self.conv_out = nn.Conv2d(in_channels=D, out_channels=C_out, kernel_size=3, stride=1, padding=1)

        # Time conditioning layers
        self.fc_unflat_t = FCBlock(in_channels=1, out_channels=2*D, activation=activation)
        self.fc_ub2_t = FCBlock(in_channels=1, out_channels=D, activation=activation)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W) input tensor.
            t: (B,) normalized time tensor.

        Returns:
            (B, C_out, H, W) output tensor.
        """
        B, C, H, W = x.shape
        assert (C, H, W) == (self.C, self.H, self.W), (
            f'expected input of shape (B, {self.C}, {self.H}, {self.W}), got {tuple(x.shape)}'
        )

        # (B,) -> (B, 1) so each sample's time is a 1-feature input to the FC blocks.
        t = t.unsqueeze(-1)
        unflat_t = self.fc_unflat_t(t).reshape(B, 2*self.D, 1, 1)
        ub2_t = self.fc_ub2_t(t).reshape(B, self.D, 1, 1)

        # Encoder
        x0 = self.cb1(x)
        x1 = self.db1(x0)
        x2 = self.db2(x1)
        flat = self.flatten(x2)

        # Decoder with time embeddings and skip connections
        lat2 = self.unflatten(flat) + unflat_t
        lat2 = torch.concat((x2, lat2), dim=1)

        lat1 = self.ub2(lat2) + ub2_t
        lat1 = torch.concat((x1, lat1), dim=1)

        lat0 = self.ub1(lat1)
        lat0 = torch.concat((x0, lat0), dim=1)
        lat0 = self.cb2(lat0)
        return self.conv_out(lat0)







# class TimeConditionalUNet(nn.Module):
#     def __init__(
#         self,
#         C: int,
#         H: int,
#         W: int,
#         num_hiddens: int = 64,
#         activation=nn.GELU,
#     ):
#         super().__init__()
#         self.D = D = num_hiddens
#         self.C = C
#         self.H = H
#         self.W = W
        
#         self.cb1 = ConvBlock(in_channels=C, out_channels=D, activation=activation)
#         self.db1 = DownBlock(in_channels=D, out_channels=D, activation=activation)
#         self.db2 = DownBlock(in_channels=D, out_channels=2*D, activation=activation)
#         self.flatten = Flatten(activation=activation, kernel_size=(H//4, W//4))
#         self.unflatten = Unflatten(in_channels=2*D, activation=activation, kernel_size=(H//4, W//4))
#         self.ub2 = UpBlock(in_channels=4*D, out_channels=D, activation=activation)
#         self.ub1 = UpBlock(in_channels=2*D, out_channels=D, activation=activation)
#         self.cb2 = ConvBlock(in_channels=2*D, out_channels=D, activation=activation)
#         # self.conv_out = Conv(in_channels=D, out_channels=1, activation=activation) ## this line is wrong, the line below is correct (think: we don't want to apply an activation to the model outputs; we would like the model to be able to predict arbitrary outputs, not just outputs above min_x GELU(x), which is about -0.17)
#         self.conv_out = nn.Conv2d(in_channels=D, out_channels=C, kernel_size=3, stride=1, padding=1)

#         ## add the time conditional layers
#         self.fc_unflat_t = FCBlock(in_channels=1, out_channels=2*D, activation=activation)
#         self.fc_ub2_t = FCBlock(in_channels=1, out_channels=D, activation=activation)


#     def forward(
#         self,
#         x: torch.Tensor,
#         t: torch.Tensor,
#     ) -> torch.Tensor:
#         """
#         Args:
#             x: (N, C, H, W) input tensor.
#             t: (N,) normalized time tensor.

#         Returns:
#             (N, C, H, W) output tensor.
#         """
#         # inspect('x', x)
#         # assert x.shape[-2:] == (28, 28), "Expect input shape to be (28, 28)."
#         N, C, H, W = x.shape
#         t = t.unsqueeze(-1)

#         unflat_t = self.fc_unflat_t(t).reshape(N, 2*self.D, 1, 1)
#         ub2_t = self.fc_ub2_t(t).reshape(N, self.D, 1, 1)

#         x0 = self.cb1(x)
#         x1 = self.db1(x0)
#         x2 = self.db2(x1)
#         flat = self.flatten(x2)
#         lat2 = self.unflatten(flat)
#         lat2 = lat2 + unflat_t
#         lat2 = torch.concat((x2, lat2), dim=1)
#         lat1 = self.ub2(lat2)
#         lat1 = lat1 + ub2_t
#         lat1 = torch.concat((x1, lat1), dim=1)
#         lat0 = self.ub1(lat1)
#         lat0 = torch.concat((x0, lat0), dim=1)
#         lat0 = self.cb2(lat0)
#         out = self.conv_out(lat0)
#         return out





# class ImageConditionalUNet(nn.Module):
#     def __init__(
#         self,
#         in_channels: int,
#         cond_channels: int,
#         num_hiddens: int = 64,
#         activation=nn.GELU,
#     ):
#         super().__init__()
#         D = num_hiddens
#         self.D = D
        
#         # Main input path
#         self.cb1 = ConvBlock(in_channels=in_channels, out_channels=D, activation=activation)
#         self.db1 = DownBlock(in_channels=D, out_channels=D, activation=activation)
#         self.db2 = DownBlock(in_channels=D, out_channels=2*D, activation=activation)
#         self.flatten = Flatten(activation=activation)
#         self.unflatten = Unflatten(in_channels=2*D, activation=activation)
#         self.ub2 = UpBlock(in_channels=4*D, out_channels=D, activation=activation)
#         self.ub1 = UpBlock(in_channels=2*D, out_channels=D, activation=activation)
#         self.cb2 = ConvBlock(in_channels=2*D, out_channels=D, activation=activation)
#         self.conv_out = nn.Conv2d(in_channels=D, out_channels=1, kernel_size=3, stride=1, padding=1)

#         # Conditioning path for image conditioning
#         self.cond_cb1 = ConvBlock(in_channels=cond_channels, out_channels=D, activation=activation)
#         self.cond_db1 = DownBlock(in_channels=D, out_channels=D, activation=activation)
#         self.cond_db2 = DownBlock(in_channels=D, out_channels=2*D, activation=activation)
#         self.cond_flatten = Flatten(activation=activation)

#         # Time conditioning layers
#         self.fc_unflat_t = FCBlock(in_channels=1, out_channels=2*D, activation=activation)
#         self.fc_ub2_t = FCBlock(in_channels=1, out_channels=D, activation=activation)

#     def forward(
#         self,
#         x: torch.Tensor,
#         t: torch.Tensor,
#         cond: torch.Tensor,
#     ) -> torch.Tensor:
#         """
#         Args:
#             x: (N, C, H, W) input tensor.
#             t: (N,) normalized time tensor.
#             cond: (N, C_cond, H, W) conditioning tensor.

#         Returns:
#             (N, C, H, W) output tensor.
#         """
#         # Process main input
#         x0 = self.cb1(x)
#         x1 = self.db1(x0)
#         x2 = self.db2(x1)
#         flat_x = self.flatten(x2)

#         # Process conditioning input
#         cond0 = self.cond_cb1(cond)
#         cond1 = self.cond_db1(cond0)
#         cond2 = self.cond_db2(cond1)
#         flat_cond = self.cond_flatten(cond2)

#         # Combine input and conditioning paths
#         fused_flat = flat_x + flat_cond  # Alternatively, try concatenation

#         # Apply time conditioning
#         unflat_t = self.fc_unflat_t(t).reshape(-1, 2*self.D, 1, 1)
#         lat2 = self.unflatten(fused_flat) + unflat_t
#         lat2 = torch.concat((x2, lat2), dim=1)

#         lat1 = self.ub2(lat2)
#         ub2_t = self.fc_ub2_t(t).reshape(-1, self.D, 1, 1)
#         lat1 = lat1 + ub2_t
#         lat1 = torch.concat((x1, lat1), dim=1)

#         lat0 = self.ub1(lat1)
#         lat0 = torch.concat((x0, lat0), dim=1)
#         lat0 = self.cb2(lat0)

#         out = self.conv_out(lat0)
#         return out


In [None]:
# Smoke test: push a random batch through a freshly built UNet and print input/output stats.
N, C, H, W = 5, 3, 64, 64

x = torch.randn(N, C, H, W).to(device)
unet = TimeConditionalUNet(in_channels=C, H=H, W=W).to(device)
t = torch.ones(N, dtype=torch.int64).to(device)  # one integer time index per sample
out = unet(x=x, t=t.float())
inspect('x', x)
inspect('out', out)

In [None]:
# N = 5
# C = 5
# H = 64
# W = 64




# # x = torch.randn(N, C, H, W).to(device)
# # unet = ImageConditionalUNet(in_channels=C, cond_channels=C).to(device)
# # # out = unet(x, torch.zeros(N).expand(-1))
# # t = torch.ones(N, dtype=int).to(device)
# # out = unet(x=x, cond=x, t=t.to(torch.float32))
# # inspect('x', x)
# # inspect('out', out)


# x = torch.randn(N, 1, 28, 28).to(device)  # Main input
# cond = torch.randn(N, 1, 28, 28).to(device)  # Conditioning image
# # t = torch.rand(N)  # Normalized time step
# t = torch.ones(N, dtype=torch.float32).to(device)

# model = ImageConditionalUNet(in_channels=1, cond_channels=1, num_hiddens=64)
# model = model.to(device)
# output = model(x, t, cond)

In [None]:
### DDPM ###
    
# @torch.inference_mode()
# def tc_ddpm_sample(
#     unet: TimeConditionalUNet,
#     ddpm_schedule: dict,
#     img_wh: tuple[int, int],
#     num_ts: int,
#     seed: int = 0,
#     num_samples: int = 1,
# ) -> torch.Tensor:
#     """Algorithm 2 of the DDPM paper with classifier-free guidance.

#     Args:
#         unet: TimeConditionalUNet
#         ddpm_schedule: dict
#         img_wh: (H, W) output image width and height.
#         num_ts: int, number of timesteps.
#         seed: int, random seed.

#     Returns:
#         (N, C, H, W) final sample.
#     """
#     unet.eval()
#     torch.manual_seed(seed)
#     N, C, H, W = num_samples, 1, img_wh[0], img_wh[1]
#     x_t = torch.randn(N, C, H, W).to(device)
#     traj = [x_t]
#     for t_scalar in torch.arange(num_ts - 1, 0, -1).to(device):
#         torch.manual_seed(seed)
#         if t_scalar > 1:
#             z = torch.randn_like(x_t).to(device)
#         else:
#             z = torch.zeros_like(x_t).to(device)
#         t = torch.ones(N, dtype=int).to(device) * t_scalar

#         noise_pred = unet(x=x_t, t=t.to(torch.float32)/num_ts)

#         a = ddpm_schedule['alphas'].to(device)
#         ab = ddpm_schedule['alpha_bars'].to(device)
#         b = ddpm_schedule['betas'].to(device)

#         clean_est = (x_t - torch.sqrt(1 - ab[t]) * noise_pred) / torch.sqrt(ab[t])

#         a0 = torch.sqrt(ab[t-1]) * b[t]
#         a1 = torch.sqrt(a[t]) * (1 - ab[t-1])
#         x_t = (a0 * clean_est + a1 * x_t) / (1 - ab[t])
#         x_t = x_t + torch.sqrt(b[t]) * z
#         traj.append(x_t)

#         seed += 1
#     return x_t, torch.cat(traj, dim=1)



class DDPM(nn.Module):
    """DDPM training loss and ancestral sampler around a noise-predicting network.

    Args:
        unet: module called as `unet(x=..., t=...)` that predicts the noise
            contained in `x`. It must expose the attributes `C`, `H`, `W`
            (image shape it operates on), e.g. a TimeConditionalUNet.
        betas: (first, last) value of the linear beta schedule.
        num_ts: number of diffusion timesteps.

    The schedule tensors are registered as non-persistent buffers named
    `betas`, `alphas` and `alpha_bars` (each of shape (num_ts, 1, 1, 1)), so
    they follow the module through `.to(...)` and are not part of the
    state_dict. Both `forward` and `sample` read these buffers.
    """

    def __init__(
            self,
            unet: nn.Module,
            betas: tuple[float, float] = (1e-4, 0.02),
            num_ts: int = 300,
        ):
        super().__init__()

        self.C, self.H, self.W = unet.C, unet.H, unet.W
        self.unet = unet
        self.num_ts = num_ts

        # Create the schedule on the device the network already lives on, so the
        # module is usable even without a later `.to(...)` call.
        first_param = next(unet.parameters(), None)
        schedule_device = first_param.device if first_param is not None else None
        schedule = DDPM.get_schedule(betas[0], betas[1], num_ts, device=schedule_device)

        for k, v in schedule.items():
            self.register_buffer(k, v, persistent=False)

    @property
    def schedule(self) -> dict:
        """Dict view of the schedule buffers (keys: betas, alphas, alpha_bars)."""
        return {
            "betas": self.betas,
            "alphas": self.alphas,
            "alpha_bars": self.alpha_bars,
        }

    def forward(self, x_0: torch.Tensor) -> torch.Tensor:
        """ Algorithm 1 of the DDPM paper.

        Puts the network into training mode, noises `x_0` at a random timestep
        per sample and returns the MSE between the true and predicted noise.

        Args:
            x_0: (N, C, H, W) clean input tensor.

        Returns:
            (,) diffusion loss.
        """
        self.unet.train()

        N = x_0.shape[0]
        t = torch.randint(low=0, high=self.num_ts, size=(N,), device=x_0.device)
        noise = torch.randn_like(x_0)

        ab = self.alpha_bars[t]  # (N, 1, 1, 1), broadcasts over C, H, W
        x_t = torch.sqrt(ab) * x_0 + torch.sqrt(1 - ab) * noise

        # The network receives the timestep normalized to [0, 1).
        noise_pred = self.unet(x=x_t, t=t.to(torch.float32) / self.num_ts)
        return F.mse_loss(noise_pred, noise)

    @torch.inference_mode()
    def sample(
        self,
        seed: int = 0,
        num_samples: int = 1,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Algorithm 2 of the DDPM paper (ancestral sampling).

        Puts the network into eval mode and seeds the global torch RNG once
        with `seed`. Re-seeding inside the loop is avoided on purpose: it made
        the first step's noise identical to the initial sample on CPU.

        Args:
            seed: int, random seed.
            num_samples: int, number of images N to generate.

        Returns:
            Tuple (x, traj):
                x: (N, C, H, W) final sample.
                traj: the initial noise and every intermediate sample
                    concatenated along dim 1, i.e. (N, num_ts * C, H, W).
        """
        self.unet.eval()
        torch.manual_seed(seed)

        dev = self.betas.device
        N, C, H, W = num_samples, self.C, self.H, self.W
        a, ab, b = self.alphas, self.alpha_bars, self.betas

        x_t = torch.randn(N, C, H, W).to(dev)
        traj = [x_t]
        for step in range(self.num_ts - 1, 0, -1):
            # No noise is added on the last step.
            z = torch.randn_like(x_t) if step > 1 else torch.zeros_like(x_t)
            t = torch.full((N,), step, dtype=torch.long, device=dev)

            noise_pred = self.unet(x=x_t, t=t.to(torch.float32) / self.num_ts)

            # Estimate of the clean image implied by the predicted noise.
            clean_est = (x_t - torch.sqrt(1 - ab[t]) * noise_pred) / torch.sqrt(ab[t])

            # Posterior mean: weighted combination of clean estimate and current sample.
            a0 = torch.sqrt(ab[t - 1]) * b[t]
            a1 = torch.sqrt(a[t]) * (1 - ab[t - 1])
            x_t = (a0 * clean_est + a1 * x_t) / (1 - ab[t])
            x_t = x_t + torch.sqrt(b[t]) * z
            traj.append(x_t)

        return x_t, torch.cat(traj, dim=1)

    @classmethod
    def get_schedule(cls, beta1: float, beta2: float, num_ts: int, device=None) -> dict:
        """Constants for DDPM training and sampling.

        Arguments:
            beta1: float, starting beta value.
            beta2: float, ending beta value.
            num_ts: int, number of timesteps.
            device: optional device for the returned tensors; None keeps them
                where torch.linspace creates them.

        Returns:
            dict with keys (each tensor has shape (num_ts, 1, 1, 1)):
                betas: linear schedule of betas from beta1 to beta2.
                alphas: 1 - betas.
                alpha_bars: cumulative product of alphas.
        """
        assert beta1 < beta2 < 1.0, "Expect beta1 < beta2 < 1.0."
        betas = torch.linspace(beta1, beta2, num_ts).reshape(num_ts, 1, 1, 1).to(torch.float32)
        if device is not None:
            betas = betas.to(device)
        alphas = 1 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)
        return {
            "betas": betas,
            "alphas": alphas,
            "alpha_bars": alpha_bars,
        }

In [None]:
""" Create TimeConditionalUNet DDPM model and train it, or load weights from saved checkpoint. """

num_epochs = 1
num_hiddens = 16
batch_size = 128

load_from_saved = False
epoch_number_str = str(num_epochs)
checkpoint_fname_stem = f'data/tcond_unet_optim_epoch' # {epoch_number_str}.pth'
checkpoint_fname = f'{checkpoint_fname_stem}{epoch_number_str}.pth'

model = TimeConditionalUNet(
    in_channels=1,
    H=28,
    W=28,
    num_hiddens=num_hiddens,
).to(device)

ddpm = DDPM(
    unet=model,
    betas=(1e-4, 0.2),
    num_ts=30,
).to(device)

gamma = 0.1 ** (1.0 / num_epochs)

subset = torch.utils.data.Subset(train_set, list(range(10_000)))

dataloader = DataLoader(
    dataset=subset,
    batch_size=batch_size,
    shuffle=True,
)

optimizer = torch.optim.Adam(ddpm.parameters(), lr=1e-2)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=gamma)

cur_epoch = 0
tcond_ddpm_train_losses = []

if load_from_saved:

    checkpoint = torch.load(checkpoint_fname, weights_only=True)
    cur_epoch = checkpoint['epoch'] + 1
    tcond_ddpm_train_losses = checkpoint['train_losses']
    ddpm.load_state_dict(checkpoint['ddpm_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])

for epoch in range(cur_epoch, num_epochs):
    epoch_loss = 0
    batch_count = 0
    for x, y in tqdm(dataloader):
        optimizer.zero_grad() ## .. i initially forgot this line and ended up with train curves that have a distincly oscillitory (underdamped) motion.. interesting..
        x = x.to(device)
        # x = x*2 - 1 ### affine scaling
        loss = ddpm(x)
        loss.backward()
        optimizer.step()
        tcond_ddpm_train_losses.append(loss.item())
        epoch_loss += loss.item()
        # batch_count += 1
        # if batch_count > 100:
        #     break
    average_loss = epoch_loss / len(dataloader)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss:.4f}')

    lr_scheduler.step()

    if (epoch % 5 == 0) or (epoch == num_epochs - 1):
        checkpoint = {
            'epoch': epoch,
            'ddpm_state_dict': ddpm.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler_state_dict': lr_scheduler.state_dict(),
            'train_losses': tcond_ddpm_train_losses,
        }
        path = f'{checkpoint_fname_stem}{epoch}.pth'
        # path = f'/content/drive/My Drive/cs180_project5/tcond_unet_optim_epoch{epoch}.pth'
        if not os.path.exists('data'):
            os.mkdir('data')

        torch.save(checkpoint, path)


%matplotlib inline
plt.plot(tcond_ddpm_train_losses, label='Training Loss')
plt.yscale('log')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.grid('true', which='both')
plt.title('Training Loss Over Time')
plt.legend()
# plt.savefig('/content/drive/My Drive/cs180_project5/tcond_ddpm_train_curve.png')
plt.show()

In [None]:
### load time conditional ddpm from checkpoint. create sample generation figures and animation
# epoch_number_str = '2'
# checkpoint_fname = f'data/tcond_unet_optim_epoch{epoch_number_str}.pth'

# model = TimeConditionalUNet(
#     C=1,
#     H=28,
#     W=28,
#     num_hiddens=num_hiddens,
# ).to(device)

# ddpm = DDPM(
#     unet=model,
#     betas=(1e-4, 0.02),
#     num_ts=200,
# ).to(device)

# checkpoint = torch.load(checkpoint_fname, weights_only=True)
# ddpm.load_state_dict(checkpoint['ddpm_state_dict'])

# Draw an nrows x ncols grid of samples from the DDPM currently in memory.
nrows = ncols = 10
num_samples = nrows * ncols

x_0, traj = ddpm.sample(seed=101, num_samples=num_samples)

# Clamp to the displayable [0, 1] range and move the results to the CPU.
x_0 = x_0.clamp(0, 1).detach().cpu()
traj = traj.clamp(0, 1).detach().cpu()

inspect('x0', x_0)
inspect('traj', traj)

# Final samples as a still grid, then the whole denoising trajectory as a video.
show_tensor(x_0, max_n_horiz=ncols, height=75)
media.show_video(
    get_video_frames(traj, nrows=nrows, ncols=ncols),
    fps=ddpm.num_ts / 2,
    height=75 * nrows,
)
    
# media.write_video(f'/content/drive/My Drive/cs180_project5/tcond_ddpm_sample_trajectories_epoch{epoch_number_str}.mp4', frames)

In [None]:
# Scratch check: display the device of a tensor created with `device=None`.
torch.randn(3, 3, device=None).device

In [None]:
# Scratch check: does `x.to(None)` return the very same tensor object as `x`?
x.to(None) is x

In [None]:
# Scratch check: compare two tuples holding the same elements with `==`.
(1, 2) == (1, 2)