In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class SimpleUNet(nn.Module):
    """Minimal convolutional noise predictor with a single additive skip.

    Every layer uses a 3x3 kernel with padding 1 (default stride), so the
    spatial size of the input is preserved through the whole network.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels in the returned tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden, bottleneck = 64, 128
        conv_opts = dict(kernel_size=3, padding=1)
        # Creation order is kept so seeded initialisation stays reproducible.
        self.down1 = nn.Conv2d(in_channels, hidden, **conv_opts)
        self.down2 = nn.Conv2d(hidden, bottleneck, **conv_opts)
        self.up1 = nn.ConvTranspose2d(bottleneck, hidden, **conv_opts)
        self.up2 = nn.ConvTranspose2d(hidden, out_channels, **conv_opts)

    def forward(self, x, t):
        """Run the network on `x`.

        `t` is accepted so the call signature matches the diffusion wrapper,
        but it is not used by this model.
        """
        skip = F.relu(self.down1(x))
        deep = F.relu(self.down2(skip))
        merged = F.relu(self.up1(deep)) + skip
        return self.up2(merged)

class DDPM(nn.Module):
    """Denoising diffusion model with a linear beta schedule around SimpleUNet.

    Args:
        num_timesteps: Number of diffusion steps T.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.

    Attributes:
        beta, alpha, alpha_bar: 1-D schedule tensors of length `num_timesteps`.
            They are registered as non-persistent buffers so that
            `.to(device)` moves them together with the model while the
            `state_dict` keys stay exactly as before.
        model: The noise-prediction network.
    """

    def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.num_timesteps = num_timesteps
        beta = torch.linspace(beta_start, beta_end, num_timesteps)
        alpha = 1 - beta
        # Plain tensor attributes are NOT moved by `.to(device)`; buffers are.
        self.register_buffer("beta", beta, persistent=False)
        self.register_buffer("alpha", alpha, persistent=False)
        self.register_buffer("alpha_bar", torch.cumprod(alpha, dim=0), persistent=False)

        self.model = SimpleUNet(1, 1)

    def forward(self, x, t):
        """Predict the noise contained in `x` at timestep(s) `t`."""
        return self.model(x, t)

    def get_loss(self, x_0):
        """Return the MSE between sampled noise and the model's prediction.

        A random timestep is drawn per sample, `x_0` is noised to that step
        and the network is asked to recover the noise.
        """
        # Draw timesteps on the data's device so indexing the buffers works on GPU.
        t = torch.randint(0, self.num_timesteps, (x_0.shape[0],), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = self.q_sample(x_0, t, noise)
        predicted_noise = self(x_t, t)
        loss = F.mse_loss(noise, predicted_noise)
        return loss

    def q_sample(self, x_0, t, noise):
        """Noise `x_0` to step `t`: sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise."""
        alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)
        return torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise

    @torch.no_grad()
    def p_sample(self, x, t):
        """Take one reverse-diffusion step from `x` at timestep(s) `t`.

        Args:
            x: Current samples; the coefficients are reshaped to
                (-1, 1, 1, 1), i.e. a 4-D batch is expected.
            t: Either a single int/0-d tensor or a 1-D tensor with one
                timestep per sample.

        Returns:
            A tensor shaped like `x`. Samples whose timestep is 0 receive no
            added noise.
        """
        t = torch.as_tensor(t, device=x.device, dtype=torch.long).reshape(-1)
        if t.numel() == 1:
            t = t.expand(x.shape[0])

        # Per-sample coefficients must broadcast over (C, H, W), not over W.
        beta_t = self.beta[t].reshape(-1, 1, 1, 1)
        alpha_t = self.alpha[t].reshape(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - self.alpha_bar[t]).reshape(-1, 1, 1, 1)

        model_mean = (1 / torch.sqrt(alpha_t)) * (
            x - (beta_t / sqrt_one_minus_alpha_bar_t) * self(x, t)
        )

        # `if t == 0` is ambiguous for a batch tensor; test explicitly instead.
        if not bool((t > 0).any()):
            return model_mean

        noise = torch.randn_like(x)
        add_noise = (t > 0).to(x.dtype).reshape(-1, 1, 1, 1)
        return model_mean + add_noise * torch.sqrt(beta_t) * noise

    @torch.no_grad()
    def sample(self, num_samples, size, device):
        """Generate `num_samples` single-channel samples of spatial `size`.

        Starts from Gaussian noise and applies `p_sample` for every timestep
        from `num_timesteps - 1` down to 0.
        """
        x = torch.randn(num_samples, 1, *size).to(device)
        for t in reversed(range(self.num_timesteps)):
            x = self.p_sample(x, torch.full((num_samples,), t, device=device, dtype=torch.long))
        return x

# Training loop
def train(model, dataloader, num_epochs, device):
    """Train `model` with Adam (lr=1e-3) by minimising `model.get_loss`.

    Args:
        model: Module exposing `get_loss(x)` that returns a scalar loss.
        dataloader: Iterable of batches; element 0 of each batch is the input.
        num_epochs: Number of passes over `dataloader`.
        device: Device the inputs are moved to before the forward pass.

    Prints the mean batch loss once per epoch. An empty `dataloader` no
    longer raises; the epoch is reported as `nan` instead.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(num_epochs):
        model.train()
        running_loss, num_batches = 0.0, 0
        for batch in dataloader:
            optimizer.zero_grad()
            x = batch[0].to(device)
            loss = model.get_loss(x)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Report the epoch mean rather than whichever batch happened to be last.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

# Visualization
def visualize_samples(model, num_samples, size, device):
    """Draw `num_samples` images from `model.sample` in a single row.

    Args:
        model: Object exposing `sample(num_samples, size, device)`; channel 0
            of each returned sample is shown in grayscale.
        num_samples: Number of images to generate and display.
        size: Spatial size forwarded to `model.sample`.
        device: Device forwarded to `model.sample`.
    """
    samples = model.sample(num_samples, size, device)
    samples = samples.cpu().numpy()

    fig, axes = plt.subplots(1, num_samples, figsize=(num_samples*3, 3))
    # For a 1x1 grid plt.subplots returns a bare Axes, which is not iterable.
    if num_samples == 1:
        axes = [axes]
    for i, ax in enumerate(axes):
        ax.imshow(samples[i, 0], cmap='gray')
        ax.axis('off')
    plt.show()

# Main execution
# Entry point for the basic DDPM demo: build the model, train it on MNIST,
# then plot a handful of generated digits.
if __name__ == "__main__":
    # Use the GPU when one is available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DDPM().to(device)

    # Load the MNIST training split (downloaded to ./data on first use).
    # Normalize with mean 0.5 / std 0.5 is applied after ToTensor.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

    # Train the model
    num_epochs = 10
    train(model, dataloader, num_epochs, device)

    # Visualize results: five single-channel 28x28 samples
    visualize_samples(model, num_samples=5, size=(28, 28), device=device)

In [None]:
from tqdm import tqdm

class SinusoidalPositionEmbeddings(nn.Module):
    """Map a batch of timesteps to fixed sinusoidal feature vectors.

    Args:
        dim: Target embedding width. `dim // 2` frequencies are used, so the
            output has `2 * (dim // 2)` features: all sines, then all cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return the embedding for each entry of the 1-D tensor `time`."""
        n_freqs = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000.
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=time.device) * -log_step)
        angles = time.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

class Block(nn.Module):
    """Two-conv U-Net stage with additive time conditioning and resampling.

    Args:
        in_ch: Input channels. When `up=True` the first convolution takes
            `2 * in_ch` channels (the caller concatenates a skip tensor).
        out_ch: Output channels of the stage.
        time_emb_dim: Width of the time embedding passed to `forward`.
        up: If True the stage ends with a stride-2 transposed convolution
            (upsampling); otherwise with a stride-2 convolution (downsampling).
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Layer creation order is kept so seeded initialisation is unchanged.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        first_in = 2 * in_ch if up else in_ch
        resample_cls = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(first_in, out_ch, 3, padding=1)
        self.transform = resample_cls(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t, ):
        """Apply conv -> add time embedding -> conv -> resample."""
        features = self.conv1(x)
        features = self.bnorm1(self.relu(features))

        # Project the time embedding to out_ch and broadcast over H and W.
        t_proj = self.relu(self.time_mlp(t))
        features = features + t_proj[..., None, None]

        features = self.conv2(features)
        features = self.bnorm2(self.relu(features))
        return self.transform(features)

class ImprovedUNet(nn.Module):
    """Time-conditioned U-Net noise predictor assembled from `Block` stages.

    Three downsampling stages halve the spatial size each, three upsampling
    stages double it again; every upsampling stage receives the matching
    downsampling output concatenated along the channel axis.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the predicted noise.
        time_emb_dim: Width of the sinusoidal time embedding.
    """

    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=32):
        super().__init__()

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        # Initial projection
        self.conv0 = nn.Conv2d(in_channels, 64, 3, padding=1)

        # Downsample: outputs have 128, 256 and 256 channels respectively.
        self.downs = nn.ModuleList([
            Block(64, 128, time_emb_dim),
            Block(128, 256, time_emb_dim),
            Block(256, 256, time_emb_dim),
        ])
        # Upsample. Block(up=True) already doubles `in_ch` for the
        # concatenated skip, so `in_ch` is the channel count of `x` BEFORE
        # concatenation: 256 (+256 skip), 256 (+256 skip), 128 (+128 skip).
        # The previous values (384, 192) made conv1 expect 768 / 384 channels
        # while it received 512 / 256, so forward() always raised.
        self.ups = nn.ModuleList([
            Block(256, 256, time_emb_dim, up=True),
            Block(256, 128, time_emb_dim, up=True),
            Block(128, 64, time_emb_dim, up=True),
        ])

        self.output = nn.Conv2d(64, out_channels, 1)

    def forward(self, x, timestep):
        """Predict noise for `x` at the given per-sample `timestep` tensor."""
        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # U-Net body: remember every down-stage output for the skip links.
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)

class ImprovedDDPM(nn.Module):
    """DDPM with a cosine beta schedule wrapped around `ImprovedUNet`.

    All schedule-derived tensors are registered as buffers so they follow the
    module across devices.

    Args:
        num_timesteps: Number of diffusion steps T.
    """

    def __init__(self, num_timesteps=1000):
        super().__init__()
        self.num_timesteps = num_timesteps
        self.model = ImprovedUNet()

        # Cosine noise schedule
        betas = cosine_beta_schedule(num_timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # Quantities needed by p_sample. They were missing before
        # (`posterior_variance` raised AttributeError). Non-persistent so the
        # set of state_dict keys is unchanged.
        self.register_buffer('sqrt_recip_alphas', torch.sqrt(1. / alphas), persistent=False)
        self.register_buffer(
            'posterior_variance',
            betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod),
            persistent=False,
        )

    def q_sample(self, x_start, t, noise=None):
        """Noise `x_start` to step `t`; fresh Gaussian noise is used if none is given."""
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_losses(self, x_start, t, noise=None):
        """Return the MSE between the true noise and the model's prediction at step `t`."""
        if noise is None:
            noise = torch.randn_like(x_start)

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        predicted_noise = self.model(x_noisy, t)

        loss = F.mse_loss(noise, predicted_noise)

        return loss

    @torch.no_grad()
    def p_sample(self, x, t, t_index):
        """Take one reverse step from `x`.

        Args:
            x: Current noisy batch.
            t: Per-sample timestep tensor.
            t_index: The same timestep as a Python int; no noise is added
                when it is 0.
        """
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        # The mean is scaled by 1/sqrt(alpha_t), NOT 1/sqrt(alpha_bar_t);
        # the latter blows up as alpha_bar_t approaches 0 at large t.
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)

        # Equation 11 in the paper
        # Use our model (noise predictor) to predict the mean
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * self.model(x, t) / sqrt_one_minus_alphas_cumprod_t
        )

        if t_index == 0:
            return model_mean
        else:
            posterior_variance_t = extract(self.posterior_variance, t, x.shape)
            noise = torch.randn_like(x)
            # Algorithm 2 line 4:
            return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def p_sample_loop(self, shape):
        """Run the full reverse process and return every intermediate batch.

        Returns:
            A list with one NumPy array per timestep; the last entry is the
            final sample.
        """
        device = next(self.model.parameters()).device

        b = shape[0]
        # start from pure noise (for each example in the batch)
        img = torch.randn(shape, device=device)
        imgs = []

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), i)
            imgs.append(img.cpu().numpy())
        return imgs

    @torch.no_grad()
    def sample(self, image_size, batch_size=16, channels=3):
        """Sample a batch of square images; see `p_sample_loop` for the return value."""
        return self.p_sample_loop(shape=(batch_size, channels, image_size, image_size))

def cosine_beta_schedule(timesteps, s=0.008):
    """Build the cosine beta schedule from https://arxiv.org/abs/2102.09672.

    Args:
        timesteps: Number of diffusion steps; also the length of the result.
        s: Small offset added to the normalised time before the cosine.

    Returns:
        1-D tensor of `timesteps` betas, clipped to [0.0001, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    cumulative = cumulative / cumulative[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clamp(1 - step_ratio, min=0.0001, max=0.9999)

def extract(a, t, x_shape):
    """Look up `a[t]` per sample and reshape it to broadcast against `x_shape`.

    Args:
        a: 1-D table of per-timestep coefficients.
        t: 1-D integer tensor with one timestep index per sample.
        x_shape: Shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape `(len(t), 1, ..., 1)` with `len(x_shape)` dims, on
        `t`'s device.
    """
    batch_size = t.shape[0]
    # Index on the table's own device. The tables are registered buffers, so
    # after `.to("cuda")` they live on the GPU and `t.cpu()` would make
    # gather fail with a device mismatch.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

# Training loop
def train(model, dataloader, num_epochs, device, lr=2e-4):
    """Optimise `model.p_losses` with AdamW and a cosine-annealed learning rate.

    Args:
        model: Module exposing `num_timesteps` and `p_losses(x, t)`.
        dataloader: Sized iterable of batches; element 0 is the input tensor.
        num_epochs: Number of epochs (also the scheduler's `T_max`).
        device: Device inputs and sampled timesteps are placed on.
        lr: Initial learning rate.

    Prints the mean batch loss after every epoch.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    def run_batch(batch):
        """Perform one optimisation step and return the loss as a float."""
        optimizer.zero_grad()
        images = batch[0].to(device)
        # One uniformly random timestep per sample.
        steps = torch.randint(0, model.num_timesteps, (images.shape[0],), device=device).long()
        batch_loss = model.p_losses(images, steps)
        batch_loss.backward()
        optimizer.step()
        return batch_loss.item()

    for epoch in range(num_epochs):
        model.train()
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        total_loss = sum(run_batch(batch) for batch in progress)

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
        scheduler.step()

# Visualization
def visualize_samples(model, num_samples, image_size, channels, device):
    """Sample from `model` and show the final images in one row.

    Args:
        model: Module exposing `sample(image_size, batch_size, channels)` that
            returns a list of NumPy batches; the last entry is displayed.
        num_samples: Number of images to generate.
        image_size: Side length forwarded to `model.sample`.
        channels: Channel count forwarded to `model.sample`.
        device: Unused here; kept for signature compatibility.
    """
    model.eval()
    samples = model.sample(image_size=image_size, batch_size=num_samples, channels=channels)
    samples = torch.from_numpy(samples[-1])  # Get the last step of the sampling process

    fig, axes = plt.subplots(1, num_samples, figsize=(num_samples*3, 3))
    # For a 1x1 grid plt.subplots returns a bare Axes, which is not iterable.
    if num_samples == 1:
        axes = [axes]
    for i, ax in enumerate(axes):
        img = samples[i].permute(1, 2, 0).cpu().numpy()
        img = (img + 1) / 2  # Rescale from [-1, 1] to [0, 1]
        # Generated values are not guaranteed to stay inside [-1, 1]; clamp so
        # imshow receives floats in its valid [0, 1] range.
        img = img.clip(0, 1)
        ax.imshow(img)
        ax.axis('off')
    plt.show()

# Main execution
# Entry point for the cosine-schedule DDPM demo: build the model, train it
# on CIFAR-10, then plot generated samples.
if __name__ == "__main__":
    # Use the GPU when one is available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ImprovedDDPM().to(device)

    # Load CIFAR-10 dataset (or any other dataset you prefer).
    # Normalize with per-channel mean 0.5 / std 0.5 is applied after ToTensor.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

    # Train the model
    num_epochs = 100  # You might want to increase this for better results
    train(model, dataloader, num_epochs, device)

    # Visualize results: five 3-channel 32x32 samples
    visualize_samples(model, num_samples=5, image_size=32, channels=3, device=device)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from functools import partial
from tqdm import tqdm

class SinusoidalPositionEmbeddings(nn.Module):
    """Map a batch of timesteps to fixed sinusoidal feature vectors.

    Args:
        dim: Target embedding width. `dim // 2` frequencies are used, so the
            output has `2 * (dim // 2)` features: all sines, then all cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return the embedding for each entry of the 1-D tensor `time`."""
        n_freqs = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000.
        log_step = torch.log(torch.tensor(10000)) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=time.device) * -log_step)
        angles = time.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

class Block(nn.Module):
    """Two-conv U-Net stage with additive time conditioning and resampling.

    Args:
        in_ch: Input channels. When `up=True` the first convolution takes
            `2 * in_ch` channels (the caller concatenates a skip tensor).
        out_ch: Output channels of the stage.
        time_emb_dim: Width of the time embedding passed to `forward`.
        up: If True the stage ends with a stride-2 transposed convolution
            (upsampling); otherwise with a stride-2 convolution (downsampling).
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Layer creation order is kept so seeded initialisation is unchanged.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        first_in = 2 * in_ch if up else in_ch
        resample_cls = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(first_in, out_ch, 3, padding=1)
        self.transform = resample_cls(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        """Apply conv -> add time embedding -> conv -> resample."""
        features = self.conv1(x)
        features = self.bnorm1(self.relu(features))

        # Project the time embedding to out_ch and broadcast over H and W.
        t_proj = self.relu(self.time_mlp(t))
        features = features + t_proj[..., None, None]

        features = self.conv2(features)
        features = self.bnorm2(self.relu(features))
        return self.transform(features)

class UNet(nn.Module):
    """Time-conditioned U-Net noise predictor assembled from `Block` stages.

    Three downsampling stages halve the spatial size each and three
    upsampling stages double it again, with skip tensors concatenated along
    the channel axis. Inputs whose height or width is not a multiple of
    `2 ** len(self.downs)` are zero-padded on the bottom/right before the
    network runs and the output is cropped back to the original size.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the predicted noise.
        time_emb_dim: Width of the sinusoidal time embedding.
    """

    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=32):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        self.conv0 = nn.Conv2d(in_channels, 64, 3, padding=1)
        # Down-stage outputs have 128, 256 and 256 channels respectively.
        self.downs = nn.ModuleList([
            Block(64, 128, time_emb_dim),
            Block(128, 256, time_emb_dim),
            Block(256, 256, time_emb_dim),
        ])
        # Block(up=True) already doubles `in_ch` for the concatenated skip,
        # so `in_ch` is the channel count of `x` BEFORE concatenation:
        # 256 (+256 skip), 256 (+256 skip), 128 (+128 skip). The previous
        # values (384, 192) made conv1 expect 768 / 384 channels while it
        # received 512 / 256, so forward() always raised.
        self.ups = nn.ModuleList([
            Block(256, 256, time_emb_dim, up=True),
            Block(256, 128, time_emb_dim, up=True),
            Block(128, 64, time_emb_dim, up=True),
        ])
        self.output = nn.Conv2d(64, out_channels, 1)

    def forward(self, x, timestep):
        """Predict noise for `x` at the given per-sample `timestep` tensor."""
        t = self.time_mlp(timestep)

        # Each down stage halves H and W, so they must be divisible by
        # 2**len(downs); otherwise a small input (e.g. 4x4 latents) shrinks
        # below the 4x4 resampling kernel. Pad here, crop at the end.
        height, width = x.shape[-2:]
        factor = 2 ** len(self.downs)
        pad_h, pad_w = (-height) % factor, (-width) % factor
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h))

        x = self.conv0(x)
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)[..., :height, :width]

class Autoencoder(nn.Module):
    """Convolutional autoencoder with three stride-2 stages in each direction.

    The encoder maps 3 -> 64 -> 128 -> 256 channels (ReLU after every conv);
    the decoder mirrors it with transposed convolutions and ends in Tanh.
    """

    def __init__(self):
        super().__init__()
        widths = [3, 64, 128, 256]

        # Encoder: conv + ReLU per stage.
        encoder_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            encoder_layers.append(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1))
            encoder_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: transposed conv + ReLU per stage, Tanh after the last one.
        reversed_widths = widths[::-1]
        decoder_layers = []
        for c_in, c_out in zip(reversed_widths[:-1], reversed_widths[1:]):
            decoder_layers.append(
                nn.ConvTranspose2d(c_in, c_out, 3, stride=2, padding=1, output_padding=1)
            )
            decoder_layers.append(nn.ReLU())
        decoder_layers[-1] = nn.Tanh()
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return the latent produced by the encoder for `x`."""
        return self.encoder(x)

    def decode(self, x):
        """Return the reconstruction produced by the decoder for latent `x`."""
        return self.decoder(x)

    def forward(self, x):
        """Encode then decode `x`."""
        latent = self.encode(x)
        return self.decode(latent)

class LatentDiffusion(nn.Module):
    """Diffusion process (cosine schedule) run by `unet` on autoencoder latents.

    Args:
        autoencoder: Module stored as `self.autoencoder`; it is not called by
            any method of this class.
        unet: Noise predictor called as `unet(x, t)`.
        num_timesteps: Number of diffusion steps T.
    """

    def __init__(self, autoencoder, unet, num_timesteps=1000):
        super().__init__()
        self.autoencoder = autoencoder
        self.unet = unet
        self.num_timesteps = num_timesteps

        betas = cosine_beta_schedule(num_timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # Needed by p_sample for the reverse-step mean. Non-persistent so the
        # set of state_dict keys is unchanged.
        self.register_buffer('sqrt_recip_alphas', torch.sqrt(1. / alphas), persistent=False)

    def q_sample(self, x_start, t, noise=None):
        """Noise `x_start` to step `t`; fresh Gaussian noise is used if none is given."""
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_losses(self, x_start, t, noise=None):
        """Return the MSE between the true noise and the U-Net's prediction at step `t`."""
        if noise is None:
            noise = torch.randn_like(x_start)

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        predicted_noise = self.unet(x_noisy, t)

        loss = F.mse_loss(noise, predicted_noise)

        return loss

    @torch.no_grad()
    def p_sample(self, x, t, t_index):
        """Take one reverse step from `x`.

        Args:
            x: Current noisy latent batch.
            t: Per-sample timestep tensor.
            t_index: The same timestep as a Python int; no noise is added
                when it is 0.
        """
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        # The mean is scaled by 1/sqrt(alpha_t), NOT 1/sqrt(alpha_bar_t);
        # the latter blows up as alpha_bar_t approaches 0 at large t.
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)

        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * self.unet(x, t) / sqrt_one_minus_alphas_cumprod_t
        )

        if t_index == 0:
            return model_mean
        else:
            # beta_t is used as the reverse-step variance here.
            posterior_variance_t = extract(self.betas, t, x.shape)
            noise = torch.randn_like(x)
            return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def p_sample_loop(self, shape):
        """Run the full reverse process and return every intermediate batch.

        Returns:
            A list with one NumPy array per timestep; the last entry is the
            final latent sample.
        """
        device = next(self.unet.parameters()).device

        b = shape[0]
        img = torch.randn(shape, device=device)
        imgs = []

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), i)
            imgs.append(img.cpu().numpy())
        return imgs

    @torch.no_grad()
    def sample(self, image_size, batch_size=16, channels=3):
        """Sample latents of spatial size `image_size // 8`; see `p_sample_loop`."""
        return self.p_sample_loop(shape=(batch_size, channels, image_size // 8, image_size // 8))

def cosine_beta_schedule(timesteps, s=0.008):
    """Return `timesteps` betas following a squared-cosine cumulative schedule.

    Args:
        timesteps: Number of diffusion steps; also the length of the result.
        s: Small offset added to the normalised time before the cosine.

    Returns:
        1-D tensor of betas, clipped to [0.0001, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    cumulative = cumulative / cumulative[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clamp(1 - step_ratio, min=0.0001, max=0.9999)

def extract(a, t, x_shape):
    """Look up `a[t]` per sample and reshape it to broadcast against `x_shape`.

    Args:
        a: 1-D table of per-timestep coefficients.
        t: 1-D integer tensor with one timestep index per sample.
        x_shape: Shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape `(len(t), 1, ..., 1)` with `len(x_shape)` dims, on
        `t`'s device.
    """
    batch_size = t.shape[0]
    # Index on the table's own device. The tables are registered buffers, so
    # after `.to("cuda")` they live on the GPU and `t.cpu()` would make
    # gather fail with a device mismatch.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

# Training loop
def train(ldm, autoencoder, dataloader, num_epochs, device, lr=1e-4):
    """Train the latent diffusion model on latents produced by `autoencoder`.

    Args:
        ldm: Module exposing `num_timesteps` and `p_losses(x_latent, t)`.
        autoencoder: Module whose `encode` maps inputs to latents; it is
            called under `torch.no_grad()`.
        dataloader: Sized iterable of batches; element 0 is the input tensor.
        num_epochs: Number of passes over `dataloader`.
        device: Device inputs and sampled timesteps are placed on.
        lr: AdamW learning rate.

    Prints the mean batch loss after every epoch.
    """
    optimizer = torch.optim.AdamW(ldm.parameters(), lr=lr)

    def run_batch(batch):
        """Perform one optimisation step and return the loss as a float."""
        optimizer.zero_grad()
        images = batch[0].to(device)

        # Encode the input to latent space without tracking gradients.
        with torch.no_grad():
            latents = autoencoder.encode(images)

        steps = torch.randint(0, ldm.num_timesteps, (images.shape[0],), device=device).long()
        batch_loss = ldm.p_losses(latents, steps)

        batch_loss.backward()
        optimizer.step()
        return batch_loss.item()

    for epoch in range(num_epochs):
        ldm.train()
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        total_loss = sum(run_batch(batch) for batch in progress)

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

# Visualization
def visualize_samples(ldm, autoencoder, num_samples, image_size, channels, device):
    """Sample latents from `ldm`, decode them and show the images in one row.

    Args:
        ldm: Module exposing `sample(image_size, batch_size, channels)` that
            returns a list of NumPy batches; the last entry is decoded.
        autoencoder: Module whose `decode` maps latents to images.
        num_samples: Number of images to generate.
        image_size: Side length forwarded to `ldm.sample`.
        channels: Latent channel count forwarded to `ldm.sample`.
        device: Device the latents are moved to before decoding.
    """
    ldm.eval()
    autoencoder.eval()

    latent_samples = ldm.sample(image_size=image_size, batch_size=num_samples, channels=channels)
    latent_samples = torch.from_numpy(latent_samples[-1]).to(device)  # Get the last step of the sampling process

    # Decode the latent samples
    with torch.no_grad():
        samples = autoencoder.decode(latent_samples)

    fig, axes = plt.subplots(1, num_samples, figsize=(num_samples*3, 3))
    # For a 1x1 grid plt.subplots returns a bare Axes, which is not iterable.
    if num_samples == 1:
        axes = [axes]
    for i, ax in enumerate(axes):
        img = samples[i].permute(1, 2, 0).cpu().numpy()
        img = (img + 1) / 2  # Rescale from [-1, 1] to [0, 1]
        # Guard against decoders whose output leaves [-1, 1].
        img = img.clip(0, 1)
        ax.imshow(img)
        ax.axis('off')
    plt.show()

# Main execution
# Entry point for the latent-diffusion demo: train an autoencoder on
# CIFAR-10, train a diffusion model on its latents, then decode and plot
# generated samples.
if __name__ == "__main__":
    # Use the GPU when one is available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize models. The U-Net works on 256-channel tensors, matching the
    # 256 output channels of the autoencoder's encoder.
    # NOTE(review): the encoder applies three stride-2 convolutions, so 32x32
    # inputs give 4x4 latents, and the U-Net downsamples three more times --
    # confirm the U-Net supports inputs that small.
    autoencoder = Autoencoder().to(device)
    unet = UNet(in_channels=256, out_channels=256, time_emb_dim=32).to(device)
    ldm = LatentDiffusion(autoencoder, unet).to(device)

    # Load CIFAR-10 dataset.
    # Normalize with per-channel mean 0.5 / std 0.5 is applied after ToTensor.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

    # Train the autoencoder with a plain MSE reconstruction loss.
    autoencoder_optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)
    num_epochs_ae = 10
    for epoch in range(num_epochs_ae):
        autoencoder.train()
        total_loss = 0
        for batch in tqdm(dataloader, desc=f"Autoencoder Epoch {epoch+1}/{num_epochs_ae}"):
            x = batch[0].to(device)
            autoencoder_optimizer.zero_grad()
            x_recon = autoencoder(x)
            loss = F.mse_loss(x_recon, x)
            loss.backward()
            autoencoder_optimizer.step()
            total_loss += loss.item()
        # Mean reconstruction loss over the epoch's batches.
        avg_loss = total_loss / len(dataloader)
        print(f"Autoencoder Epoch {epoch+1}/{num_epochs_ae}, Average Loss: {avg_loss:.4f}")

    # Train the latent diffusion model
    num_epochs_ldm = 50
    train(ldm, autoencoder, dataloader, num_epochs_ldm, device)

    # Visualize results; `channels` here is the latent channel count.
    visualize_samples(ldm, autoencoder, num_samples=5, image_size=32, channels=256, device=device)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from functools import partial
from tqdm import tqdm

class SinusoidalPositionEmbeddings(nn.Module):
    """Map a batch of timesteps to fixed sinusoidal feature vectors.

    Args:
        dim: Target embedding width. `dim // 2` frequencies are used, so the
            output has `2 * (dim // 2)` features: all sines, then all cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return the embedding for each entry of the 1-D tensor `time`."""
        n_freqs = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000.
        log_step = torch.log(torch.tensor(10000)) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=time.device) * -log_step)
        angles = time.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

class Block(nn.Module):
    """Two-conv U-Net stage conditioned on time and class, with resampling.

    Args:
        in_ch: Input channels. When `up=True` the first convolution takes
            `2 * in_ch` channels (the caller concatenates a skip tensor).
        out_ch: Output channels of the stage.
        time_emb_dim: Width of the time embedding passed to `forward`.
        class_emb_dim: Width of the class embedding passed to `forward`.
        up: If True the stage ends with a stride-2 transposed convolution
            (upsampling); otherwise with a stride-2 convolution (downsampling).
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, class_emb_dim, up=False):
        super().__init__()
        # Layer creation order is kept so seeded initialisation is unchanged.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.class_mlp = nn.Linear(class_emb_dim, out_ch)
        first_in = 2 * in_ch if up else in_ch
        resample_cls = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(first_in, out_ch, 3, padding=1)
        self.transform = resample_cls(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t, c):
        """Apply conv -> add time and class embeddings -> conv -> resample."""
        features = self.conv1(x)
        features = self.bnorm1(self.relu(features))

        # Project both embeddings to out_ch and broadcast them over H and W.
        t_proj = self.relu(self.time_mlp(t))[..., None, None]
        c_proj = self.relu(self.class_mlp(c))[..., None, None]
        features = features + t_proj + c_proj

        features = self.conv2(features)
        features = self.bnorm2(self.relu(features))
        return self.transform(features)

class ConditionalUNet(nn.Module):
    """Class- and time-conditioned U-Net noise predictor built from `Block`.

    Three downsampling stages halve the spatial size each, three upsampling
    stages double it again; every upsampling stage receives the matching
    downsampling output concatenated along the channel axis.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the predicted noise.
        time_emb_dim: Width of both the time and the class embedding.
        num_classes: Number of distinct class labels.
    """

    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=32, num_classes=10):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        # Learned class embedding, same width as the time embedding.
        self.class_emb = nn.Embedding(num_classes, time_emb_dim)

        self.conv0 = nn.Conv2d(in_channels, 64, 3, padding=1)
        # Down-stage outputs have 128, 256 and 256 channels respectively.
        self.downs = nn.ModuleList([
            Block(64, 128, time_emb_dim, time_emb_dim),
            Block(128, 256, time_emb_dim, time_emb_dim),
            Block(256, 256, time_emb_dim, time_emb_dim),
        ])
        # Block(up=True) already doubles `in_ch` for the concatenated skip,
        # so `in_ch` is the channel count of `x` BEFORE concatenation:
        # 256 (+256 skip), 256 (+256 skip), 128 (+128 skip). The previous
        # values (384, 192) made conv1 expect 768 / 384 channels while it
        # received 512 / 256, so forward() always raised.
        self.ups = nn.ModuleList([
            Block(256, 256, time_emb_dim, time_emb_dim, up=True),
            Block(256, 128, time_emb_dim, time_emb_dim, up=True),
            Block(128, 64, time_emb_dim, time_emb_dim, up=True),
        ])
        self.output = nn.Conv2d(64, out_channels, 1)

    def forward(self, x, timestep, class_label):
        """Predict noise for `x` given per-sample timesteps and class labels."""
        t = self.time_mlp(timestep)
        c = self.class_emb(class_label)
        x = self.conv0(x)
        # Remember every down-stage output for the skip links.
        residual_inputs = []
        for down in self.downs:
            x = down(x, t, c)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t, c)
        return self.output(x)

class ConditionalDDPM(nn.Module):
    """Class-conditional DDPM with a cosine beta schedule.

    Args:
        model: Noise predictor called as `model(x, t, class_labels)`.
        num_classes: Number of class labels; used to draw random labels when
            `sample` is called without any.
        num_timesteps: Number of diffusion steps T.
    """

    def __init__(self, model, num_classes, num_timesteps=1000):
        super().__init__()
        self.model = model
        self.num_classes = num_classes
        self.num_timesteps = num_timesteps

        betas = cosine_beta_schedule(num_timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # Needed by p_sample for the reverse-step mean. Non-persistent so the
        # set of state_dict keys is unchanged.
        self.register_buffer('sqrt_recip_alphas', torch.sqrt(1. / alphas), persistent=False)

    def q_sample(self, x_start, t, noise=None):
        """Noise `x_start` to step `t`; fresh Gaussian noise is used if none is given."""
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_losses(self, x_start, t, class_labels, noise=None):
        """Return the MSE between the true noise and the class-conditioned prediction."""
        if noise is None:
            noise = torch.randn_like(x_start)

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        predicted_noise = self.model(x_noisy, t, class_labels)

        loss = F.mse_loss(noise, predicted_noise)

        return loss

    @torch.no_grad()
    def p_sample(self, x, t, t_index, class_labels):
        """Take one reverse step from `x`.

        Args:
            x: Current noisy batch.
            t: Per-sample timestep tensor.
            t_index: The same timestep as a Python int; no noise is added
                when it is 0.
            class_labels: Per-sample class labels forwarded to the model.
        """
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        # The mean is scaled by 1/sqrt(alpha_t), NOT 1/sqrt(alpha_bar_t);
        # the latter blows up as alpha_bar_t approaches 0 at large t.
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)

        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * self.model(x, t, class_labels) / sqrt_one_minus_alphas_cumprod_t
        )

        if t_index == 0:
            return model_mean
        else:
            # beta_t is used as the reverse-step variance here.
            posterior_variance_t = extract(self.betas, t, x.shape)
            noise = torch.randn_like(x)
            return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, class_labels):
        """Run the full reverse process and return every intermediate batch.

        Returns:
            A list with one NumPy array per timestep; the last entry is the
            final sample.
        """
        device = next(self.model.parameters()).device

        b = shape[0]
        img = torch.randn(shape, device=device)
        imgs = []

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), i, class_labels)
            imgs.append(img.cpu().numpy())
        return imgs

    @torch.no_grad()
    def sample(self, image_size, batch_size=16, channels=3, class_labels=None):
        """Sample a batch of square images.

        When `class_labels` is None, labels are drawn uniformly from
        `range(num_classes)` on the model's device.
        """
        if class_labels is None:
            class_labels = torch.randint(0, self.num_classes, (batch_size,), device=next(self.model.parameters()).device)
        return self.p_sample_loop(shape=(batch_size, channels, image_size, image_size), class_labels=class_labels)

def cosine_beta_schedule(timesteps, s=0.008):
    """Return `timesteps` betas following a squared-cosine cumulative schedule.

    Args:
        timesteps: Number of diffusion steps; also the length of the result.
        s: Small offset added to the normalised time before the cosine.

    Returns:
        1-D tensor of betas, clipped to [0.0001, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    cumulative = cumulative / cumulative[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clamp(1 - step_ratio, min=0.0001, max=0.9999)

def extract(a, t, x_shape):
    """Look up `a[t]` per sample and reshape it to broadcast against `x_shape`.

    Args:
        a: 1-D table of per-timestep coefficients.
        t: 1-D integer tensor with one timestep index per sample.
        x_shape: Shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape `(len(t), 1, ..., 1)` with `len(x_shape)` dims, on
        `t`'s device.
    """
    batch_size = t.shape[0]
    # Index on the table's own device. The tables are registered buffers, so
    # after `.to("cuda")` they live on the GPU and `t.cpu()` would make
    # gather fail with a device mismatch.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

# Training loop
def train(cddpm, dataloader, num_epochs, device, lr=2e-4):
    """Train a class-conditional diffusion model with AdamW.

    Args:
        cddpm: Module exposing `num_timesteps` and `p_losses(x, t, labels)`.
        dataloader: Sized iterable yielding `(images, labels)` pairs.
        num_epochs: Number of passes over `dataloader`.
        device: Device inputs, labels and sampled timesteps are placed on.
        lr: AdamW learning rate.

    Prints the mean batch loss after every epoch.
    """
    optimizer = torch.optim.AdamW(cddpm.parameters(), lr=lr)

    def run_batch(images, targets):
        """Perform one optimisation step and return the loss as a float."""
        optimizer.zero_grad()
        images = images.to(device)
        targets = targets.to(device)
        # One uniformly random timestep per sample.
        steps = torch.randint(0, cddpm.num_timesteps, (images.shape[0],), device=device).long()
        batch_loss = cddpm.p_losses(images, steps, targets)
        batch_loss.backward()
        optimizer.step()
        return batch_loss.item()

    for epoch in range(num_epochs):
        cddpm.train()
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        total_loss = sum(run_batch(batch, labels) for batch, labels in progress)

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

# Visualization
def visualize_samples(cddpm, num_samples, image_size, channels, device, class_labels=None):
    """Sample images from the model and show them in a single row.

    Args:
        cddpm: diffusion wrapper exposing ``num_classes`` and ``sample(...)``.
        num_samples: number of images to generate and plot.
        image_size: height and width of the generated images.
        channels: number of image channels.
        device: device used for randomly drawn class labels.
        class_labels: optional tensor of class indices, one per sample; drawn
            uniformly at random when omitted.
    """
    cddpm.eval()
    if class_labels is None:
        class_labels = torch.randint(0, cddpm.num_classes, (num_samples,), device=device)
    samples = cddpm.sample(image_size=image_size, batch_size=num_samples, channels=channels, class_labels=class_labels)
    samples = torch.from_numpy(samples[-1])  # Get the last step of the sampling process

    # squeeze=False keeps `axes` a 2-D array, so num_samples == 1 still gives
    # an iterable row instead of a bare Axes object.
    fig, axes = plt.subplots(1, num_samples, figsize=(num_samples*3, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        img = samples[i].permute(1, 2, 0).cpu().numpy()
        img = (img + 1) / 2  # Rescale from [-1, 1] to [0, 1]
        ax.imshow(img)
        ax.set_title(f"Class: {class_labels[i].item()}")
        ax.axis('off')
    plt.show()

# Main execution
if __name__ == "__main__":
    # Use the GPU when available; models, batches and labels below all go to `device`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 10  # CIFAR-10 has 10 classes
    # The noise-prediction network is wrapped by the diffusion module, which
    # provides the loss and sampling routines used by train/visualize_samples.
    model = ConditionalUNet(num_classes=num_classes).to(device)
    cddpm = ConditionalDDPM(model, num_classes=num_classes).to(device)

    # Load CIFAR-10 dataset
    # Normalizing each channel with mean 0.5 and std 0.5 puts pixels in [-1, 1],
    # which is the range visualize_samples undoes before plotting.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # download=True fetches the data into ./data if it is not already there.
    dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

    # Train the model
    num_epochs = 100  # You might want to increase this for better results
    train(cddpm, dataloader, num_epochs, device)

    # Visualize results
    # No class_labels passed: visualize_samples draws random classes.
    visualize_samples(cddpm, num_samples=5, image_size=32, channels=3, device=device)

    # Generate samples for specific classes
    class_labels = torch.tensor([0, 1, 2, 3, 4], device=device)  # Generate one sample for each of the first 5 classes
    visualize_samples(cddpm, num_samples=5, image_size=32, channels=3, device=device, class_labels=class_labels)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from functools import partial
from tqdm import tqdm

class SinusoidalPositionEmbeddings(nn.Module):
    """Map a batch of scalar timesteps to sine/cosine feature vectors.

    The output has ``2 * (dim // 2)`` features per timestep: the sines of the
    scaled timestep first, then the cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed a 1-D tensor of timesteps.

        Args:
            time: 1-D tensor with one timestep per batch element.

        Returns:
            Tensor of shape ``(len(time), 2 * (dim // 2))``.
        """
        half_dim = self.dim // 2
        # Frequencies fall geometrically from 1 down to 1/10000.
        log_step = torch.log(torch.tensor(10000)) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

class Block(nn.Module):
    """Two 3x3 convolutions with time/class conditioning, then a resampling conv.

    With ``up=False`` the final layer is a kernel-4, stride-2 convolution.
    With ``up=True`` the first convolution expects ``2 * in_ch`` input channels
    and the final layer is a kernel-4, stride-2 transposed convolution.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, class_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.class_mlp = nn.Linear(class_emb_dim, out_ch)
        # Up blocks take the concatenated (features, skip) tensor, hence 2 * in_ch.
        first_conv_in = 2 * in_ch if up else in_ch
        resample_cls = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(first_conv_in, out_ch, 3, padding=1)
        self.transform = resample_cls(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t, c):
        """Apply the block to feature map ``x`` with time embedding ``t`` and class embedding ``c``."""
        features = self.conv1(x)
        features = self.bnorm1(self.relu(features))
        # Project both embeddings to out_ch and add two trailing axes so they
        # broadcast over the spatial dimensions.
        time_bias = self.relu(self.time_mlp(t))[..., None, None]
        class_bias = self.relu(self.class_mlp(c))[..., None, None]
        features = features + time_bias + class_bias
        features = self.bnorm2(self.relu(self.conv2(features)))
        return self.transform(features)

class ConditionalImprovedUNet(nn.Module):
    """Small class- and time-conditioned U-Net used as the noise predictor.

    Three down blocks (each halving the spatial size) are mirrored by three up
    blocks (each doubling it). Every up block receives the current features
    concatenated with the matching down-block output.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the predicted tensor.
        time_emb_dim: width of both the timestep and the class embedding.
        num_classes: number of distinct class labels.
    """

    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=32, num_classes=10):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        self.class_emb = nn.Embedding(num_classes, time_emb_dim)

        self.conv0 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.downs = nn.ModuleList([
            Block(64, 128, time_emb_dim, time_emb_dim),
            Block(128, 256, time_emb_dim, time_emb_dim),
            Block(256, 256, time_emb_dim, time_emb_dim),
        ])
        # An up Block builds its first conv with 2 * in_ch input channels, and
        # forward() concatenates the running features with a skip tensor of the
        # same width. So in_ch must be that shared width: 256, 256, then 128.
        self.ups = nn.ModuleList([
            Block(256, 256, time_emb_dim, time_emb_dim, up=True),
            Block(256, 128, time_emb_dim, time_emb_dim, up=True),
            Block(128, 64, time_emb_dim, time_emb_dim, up=True),
        ])
        self.output = nn.Conv2d(64, out_channels, 1)

    def forward(self, x, timestep, class_label):
        """Predict a tensor with ``out_channels`` channels from a noisy input.

        Args:
            x: input batch with ``in_channels`` channels.
            timestep: 1-D tensor of diffusion timesteps, one per batch element.
            class_label: 1-D integer tensor of class indices, one per batch element.

        Returns:
            The output of the final 1x1 convolution.
        """
        t = self.time_mlp(timestep)
        c = self.class_emb(class_label)
        x = self.conv0(x)
        residual_inputs = []
        for down in self.downs:
            x = down(x, t, c)
            residual_inputs.append(x)
        for up in self.ups:
            # Skips are consumed in reverse order so resolutions line up.
            residual_x = residual_inputs.pop()
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t, c)
        return self.output(x)

class ConditionalImprovedDDPM(nn.Module):
    """Class-conditional DDPM wrapper around a noise-prediction network.

    Precomputes a cosine beta schedule and the derived coefficients as buffers
    (so they follow the module across devices), and provides the forward
    noising step, the training loss and the reverse sampling loop.

    Args:
        model: network called as ``model(x_noisy, t, class_labels)``.
        num_classes: number of class labels used when sampling random labels.
        num_timesteps: length of the diffusion chain.
    """

    def __init__(self, model, num_classes, num_timesteps=1000):
        super().__init__()
        self.model = model
        self.num_classes = num_classes
        self.num_timesteps = num_timesteps

        betas = cosine_beta_schedule(num_timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        # Cumulative product shifted right by one step, with 1.0 for the first step.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        # Log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to timestep ``t``.

        Returns ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``;
        ``noise`` is drawn from a standard normal when not supplied.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_losses(self, x_start, t, class_labels, noise=None):
        """Return the MSE between the injected noise and the model's prediction."""
        if noise is None:
            noise = torch.randn_like(x_start)

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        predicted_noise = self.model(x_noisy, t, class_labels)

        loss = F.mse_loss(noise, predicted_noise)

        return loss

    @torch.no_grad()
    def p_sample(self, x, t, t_index, class_labels):
        """Take one reverse-diffusion step from ``x`` at timestep ``t``.

        Args:
            x: current noisy batch.
            t: 1-D tensor holding the timestep for every batch element.
            t_index: the same timestep as a Python int; 0 means the final step.
            class_labels: class indices passed through to the model.

        Returns:
            The model mean at ``t_index == 0``, otherwise the mean plus noise
            scaled by the posterior standard deviation.
        """
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        # The mean is scaled by 1 / sqrt(alpha_t) with alpha_t = 1 - beta_t.
        # The cumulative-product buffer (1 / sqrt(alpha_bar_t)) is the wrong
        # factor here and grows without bound towards the end of the chain.
        sqrt_recip_alphas_t = torch.sqrt(1. / (1. - betas_t))

        # Equation 11 in the paper
        # Use our model (noise predictor) to predict the mean
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * self.model(x, t, class_labels) / sqrt_one_minus_alphas_cumprod_t
        )

        if t_index == 0:
            return model_mean
        else:
            posterior_variance_t = extract(self.posterior_variance, t, x.shape)
            noise = torch.randn_like(x)
            # Algorithm 2 line 4:
            return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, class_labels):
        """Run the whole reverse chain starting from Gaussian noise.

        Returns:
            List with one NumPy array per timestep (from the last timestep down
            to 0); the final entry is the finished sample batch.
        """
        device = next(self.model.parameters()).device

        b = shape[0]
        img = torch.randn(shape, device=device)
        imgs = []

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), i, class_labels)
            imgs.append(img.cpu().numpy())
        return imgs

    @torch.no_grad()
    def sample(self, image_size, batch_size=16, channels=3, class_labels=None):
        """Generate ``batch_size`` square images; labels are random when not given."""
        if class_labels is None:
            class_labels = torch.randint(0, self.num_classes, (batch_size,), device=next(self.model.parameters()).device)
        return self.p_sample_loop(shape=(batch_size, channels, image_size, image_size), class_labels=class_labels)

def cosine_beta_schedule(timesteps, s=0.008):
    """Return ``timesteps`` betas following a squared-cosine cumulative schedule.

    Args:
        timesteps: number of diffusion steps.
        s: offset added to the normalized step inside the cosine.

    Returns:
        1-D tensor of betas limited to the range [0.0001, 0.9999].
    """
    n_points = timesteps + 1
    positions = torch.linspace(0, timesteps, n_points)
    cumulative = torch.cos(((positions / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    cumulative = cumulative / cumulative[0]
    # Each beta is one minus the ratio of consecutive cumulative values.
    raw_betas = 1 - (cumulative[1:] / cumulative[:-1])
    return torch.clip(raw_betas, 0.0001, 0.9999)

def extract(a, t, x_shape):
    """Pick the schedule value for each timestep in ``t`` and make it broadcastable.

    Args:
        a: 1-D tensor of per-timestep coefficients (for example a registered buffer).
        t: 1-D integer tensor of timestep indices, one per batch element.
        x_shape: shape of the tensor the result will be combined with.

    Returns:
        Tensor of shape (batch, 1, ..., 1) with ``len(x_shape)`` dimensions,
        on the same device as ``t``.
    """
    batch_size = t.shape[0]
    # The index must sit on the same device as `a`. The coefficient buffers
    # move to the GPU with the module, so a CPU index would make gather fail.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

# Training loop
def train(ciddpm, dataloader, num_epochs, device, lr=2e-4):
    """Train the diffusion wrapper with AdamW and a cosine-annealed learning rate.

    Args:
        ciddpm: module exposing ``num_timesteps`` and ``p_losses(x, t, labels)``.
        dataloader: iterable yielding ``(images, labels)`` batches.
        num_epochs: number of epochs; also the annealing period of the scheduler.
        device: device the batches and sampled timesteps are placed on.
        lr: initial AdamW learning rate.

    Prints the mean batch loss after every epoch.
    """
    optimizer = torch.optim.AdamW(ciddpm.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    for epoch in range(num_epochs):
        ciddpm.train()
        epoch_loss = 0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for images, targets in progress:
            optimizer.zero_grad()
            images, targets = images.to(device), targets.to(device)
            # Draw an independent random timestep for every image.
            steps = torch.randint(0, ciddpm.num_timesteps, (images.shape[0],), device=device).long()
            step_loss = ciddpm.p_losses(images, steps, targets)
            step_loss.backward()
            optimizer.step()
            epoch_loss += step_loss.item()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
        # The learning rate is updated once per epoch, after the report.
        scheduler.step()

# Visualization
def visualize_samples(ciddpm, num_samples, image_size, channels, device, class_labels=None):
    """Sample images from the model and show them in a single row.

    Args:
        ciddpm: diffusion wrapper exposing ``num_classes`` and ``sample(...)``.
        num_samples: number of images to generate and plot.
        image_size: height and width of the generated images.
        channels: number of image channels.
        device: device used for randomly drawn class labels.
        class_labels: optional tensor of class indices, one per sample; drawn
            uniformly at random when omitted.
    """
    ciddpm.eval()
    if class_labels is None:
        class_labels = torch.randint(0, ciddpm.num_classes, (num_samples,), device=device)
    samples = ciddpm.sample(image_size=image_size, batch_size=num_samples, channels=channels, class_labels=class_labels)
    samples = torch.from_numpy(samples[-1])  # Get the last step of the sampling process

    # squeeze=False always returns a 2-D axes array. Without it a single
    # sample yields one Axes object, which cannot be iterated.
    fig, axes = plt.subplots(1, num_samples, figsize=(num_samples*3, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        img = samples[i].permute(1, 2, 0).cpu().numpy()
        img = (img + 1) / 2  # Rescale from [-1, 1] to [0, 1]
        ax.imshow(img)
        ax.set_title(f"Class: {class_labels[i].item()}")
        ax.axis('off')
    plt.show()

# Main execution
if __name__ == "__main__":
    # Use the GPU when available; models, batches and labels below all go to `device`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 10  # CIFAR-10 has 10 classes
    model = ConditionalImprovedUNet(num_classes=num_classes).to(device)
    ciddpm = ConditionalImprovedDDPM(model, num_classes=num_classes).to(device)

    # Load CIFAR-10 dataset
    # Normalizing each channel with mean 0.5 and std 0.5 puts pixels in [-1, 1],
    # which is the range the plotting code undoes before display.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

    # Train the model
    num_epochs = 100  # You might want to increase this for better results
    train(ciddpm, dataloader, num_epochs, device)

    # Visualize results
    visualize_samples(ciddpm, num_samples=5, image_size=32, channels=3, device=device)

    # Generate samples for specific classes
    class_labels = torch.tensor([0, 1, 2, 3, 4], device=device)  # Generate one sample for each of the first 5 classes
    visualize_samples(ciddpm, num_samples=5, image_size=32, channels=3, device=device, class_labels=class_labels)

    # Function to generate and save samples
    def generate_and_save_samples(ciddpm, num_samples, image_size, channels, device, save_path):
        """Save a grid with one row per class and ``num_samples`` images per row.

        Args:
            ciddpm: diffusion wrapper exposing ``num_classes`` and ``sample(...)``.
            num_samples: images generated for every class (grid columns).
            image_size: height and width of the generated images.
            channels: number of image channels.
            device: device the class labels are placed on.
            save_path: file the figure is written to.
        """
        ciddpm.eval()
        # The grid reads sample i*num_samples + j as class i, so labels must be
        # grouped per class (0,0,..,1,1,..). A plain repeat() would cycle
        # 0..C-1 and mislabel the rows.
        class_labels = torch.arange(ciddpm.num_classes).repeat_interleave(num_samples).to(device)
        samples = ciddpm.sample(image_size=image_size, batch_size=num_samples * ciddpm.num_classes, channels=channels, class_labels=class_labels)
        samples = torch.from_numpy(samples[-1])  # Get the last step of the sampling process

        # squeeze=False keeps `axes` 2-D even for a single row or column.
        fig, axes = plt.subplots(ciddpm.num_classes, num_samples, figsize=(num_samples*3, ciddpm.num_classes*3), squeeze=False)
        for i in range(ciddpm.num_classes):
            for j in range(num_samples):
                img = samples[i*num_samples + j].permute(1, 2, 0).cpu().numpy()
                img = (img + 1) / 2  # Rescale from [-1, 1] to [0, 1]
                axes[i, j].imshow(img)
                axes[i, j].axis('off')
            axes[i, 0].set_title(f"Class {i}")

        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()

    # Generate and save a grid of samples
    generate_and_save_samples(ciddpm, num_samples=5, image_size=32, channels=3, device=device, save_path='ciddpm_samples.png')

    print("Training and sampling completed. Check 'ciddpm_samples.png' for generated samples.")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from functools import partial
from tqdm import tqdm

class SinusoidalPositionEmbeddings(nn.Module):
    """Sine/cosine embedding of scalar timesteps.

    For a 1-D input of length B the result has shape ``(B, 2 * (dim // 2))``,
    with the sine features in the first half and the cosine features in the second.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return the embedding for a 1-D tensor of timesteps."""
        n_freqs = self.dim // 2
        # Exponent step so that the frequencies span 1 ... 1/10000.
        decay = torch.log(torch.tensor(10000)) / (n_freqs - 1)
        frequencies = torch.exp(torch.arange(n_freqs, device=time.device) * -decay)
        scaled = time[:, None] * frequencies[None, :]
        sin_part = scaled.sin()
        cos_part = scaled.cos()
        return torch.cat((sin_part, cos_part), dim=-1)

class Block(nn.Module):
    """Conditioned conv block followed by a stride-2 down- or up-sampling layer.

    ``up=False``: first conv takes ``in_ch`` channels, last layer is a
    kernel-4, stride-2 convolution.
    ``up=True``: first conv takes ``2 * in_ch`` channels, last layer is a
    kernel-4, stride-2 transposed convolution.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, class_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.class_mlp = nn.Linear(class_emb_dim, out_ch)
        # Submodules are created in the same order for both variants so that
        # parameter ordering does not depend on `up`.
        self.conv1 = nn.Conv2d(2 * in_ch if up else in_ch, out_ch, 3, padding=1)
        resampler = nn.ConvTranspose2d if up else nn.Conv2d
        self.transform = resampler(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t, c):
        """Process ``x`` using time embedding ``t`` and class embedding ``c``."""
        hidden = self.bnorm1(self.relu(self.conv1(x)))
        # Each embedding is projected to out_ch and given two trailing axes so
        # it is added to every spatial position.
        t_term = self.relu(self.time_mlp(t))
        c_term = self.relu(self.class_mlp(c))
        hidden = hidden + t_term[..., None, None] + c_term[..., None, None]
        hidden = self.relu(self.conv2(hidden))
        hidden = self.bnorm2(hidden)
        return self.transform(hidden)

class ConditionalUNet(nn.Module):
    """Small class- and time-conditioned U-Net used as the noise predictor.

    Three down blocks are mirrored by three up blocks; every up block receives
    the current features concatenated with the matching down-block output.

    NOTE: each down Block halves height and width with a kernel-4, stride-2
    convolution. The input therefore needs spatial dimensions of at least 8,
    and divisible by 8, for the three down steps and skip concatenations to work.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the predicted tensor.
        time_emb_dim: width of both the timestep and the class embedding.
        num_classes: number of distinct class labels.
    """

    def __init__(self, in_channels=4, out_channels=4, time_emb_dim=32, num_classes=10):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        self.class_emb = nn.Embedding(num_classes, time_emb_dim)

        self.conv0 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.downs = nn.ModuleList([
            Block(64, 128, time_emb_dim, time_emb_dim),
            Block(128, 256, time_emb_dim, time_emb_dim),
            Block(256, 256, time_emb_dim, time_emb_dim),
        ])
        # An up Block builds its first conv with 2 * in_ch input channels, and
        # forward() concatenates the running features with a skip tensor of the
        # same width. So in_ch must be that shared width: 256, 256, then 128.
        self.ups = nn.ModuleList([
            Block(256, 256, time_emb_dim, time_emb_dim, up=True),
            Block(256, 128, time_emb_dim, time_emb_dim, up=True),
            Block(128, 64, time_emb_dim, time_emb_dim, up=True),
        ])
        self.output = nn.Conv2d(64, out_channels, 1)

    def forward(self, x, timestep, class_label):
        """Predict a tensor with ``out_channels`` channels from a noisy input.

        Args:
            x: input batch with ``in_channels`` channels.
            timestep: 1-D tensor of diffusion timesteps, one per batch element.
            class_label: 1-D integer tensor of class indices, one per batch element.

        Returns:
            The output of the final 1x1 convolution.
        """
        t = self.time_mlp(timestep)
        c = self.class_emb(class_label)
        x = self.conv0(x)
        residual_inputs = []
        for down in self.downs:
            x = down(x, t, c)
            residual_inputs.append(x)
        for up in self.ups:
            # Skips are consumed in reverse order so resolutions line up.
            residual_x = residual_inputs.pop()
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t, c)
        return self.output(x)

class Autoencoder(nn.Module):
    """Convolutional autoencoder mapping 3-channel images to a 256-channel latent.

    The encoder applies three stride-2 convolutions (3 -> 64 -> 128 -> 256
    channels), each followed by ReLU. The decoder mirrors it with three
    stride-2 transposed convolutions and ends in Tanh.
    """

    def __init__(self):
        super().__init__()

        # Encoder
        encoder_layers = []
        for c_in, c_out in ((3, 64), (64, 128), (128, 256)):
            encoder_layers.append(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1))
            encoder_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder
        decoder_layers = []
        for c_in, c_out in ((256, 128), (128, 64)):
            decoder_layers.append(nn.ConvTranspose2d(c_in, c_out, 3, stride=2, padding=1, output_padding=1))
            decoder_layers.append(nn.ReLU())
        # The last stage squashes the reconstruction into [-1, 1] with Tanh.
        decoder_layers.append(nn.ConvTranspose2d(64, 3, 3, stride=2, padding=1, output_padding=1))
        decoder_layers.append(nn.Tanh())
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return the latent produced by the encoder for image batch ``x``."""
        return self.encoder(x)

    def decode(self, x):
        """Return the reconstruction produced by the decoder for latent ``x``."""
        return self.decoder(x)

    def forward(self, x):
        """Encode then decode ``x``."""
        latent = self.encode(x)
        return self.decode(latent)

class ConditionalLatentDiffusion(nn.Module):
    """Class-conditional diffusion model operating on autoencoder latents.

    Holds the autoencoder and the noise-prediction U-Net, precomputes a cosine
    beta schedule as buffers, and provides the forward noising step, the
    training loss and the reverse sampling loop in latent space.

    Args:
        autoencoder: module stored as ``self.autoencoder`` (not called by the
            methods of this class).
        unet: network called as ``unet(x_noisy, t, class_labels)``.
        num_classes: number of class labels used when sampling random labels.
        num_timesteps: length of the diffusion chain.
    """

    def __init__(self, autoencoder, unet, num_classes, num_timesteps=1000):
        super().__init__()
        self.autoencoder = autoencoder
        self.unet = unet
        self.num_classes = num_classes
        self.num_timesteps = num_timesteps

        betas = cosine_beta_schedule(num_timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        # Cumulative product shifted right by one step, with 1.0 for the first step.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to timestep ``t``.

        Returns ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``;
        ``noise`` is drawn from a standard normal when not supplied.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_losses(self, x_start, t, class_labels, noise=None):
        """Return the MSE between the injected noise and the U-Net's prediction."""
        if noise is None:
            noise = torch.randn_like(x_start)

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        predicted_noise = self.unet(x_noisy, t, class_labels)

        loss = F.mse_loss(noise, predicted_noise)

        return loss

    @torch.no_grad()
    def p_sample(self, x, t, t_index, class_labels):
        """Take one reverse-diffusion step from latent ``x`` at timestep ``t``.

        Args:
            x: current noisy latent batch.
            t: 1-D tensor holding the timestep for every batch element.
            t_index: the same timestep as a Python int; 0 means the final step.
            class_labels: class indices passed through to the U-Net.

        Returns:
            The model mean at ``t_index == 0``, otherwise the mean plus noise
            with variance ``beta_t``.
        """
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        # The mean is scaled by 1 / sqrt(alpha_t) with alpha_t = 1 - beta_t.
        # The cumulative-product buffer (1 / sqrt(alpha_bar_t)) is the wrong
        # factor here and grows without bound towards the end of the chain.
        sqrt_recip_alphas_t = torch.sqrt(1. / (1. - betas_t))

        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * self.unet(x, t, class_labels) / sqrt_one_minus_alphas_cumprod_t
        )

        if t_index == 0:
            return model_mean
        else:
            # This class uses beta_t itself as the sampling variance.
            posterior_variance_t = betas_t
            noise = torch.randn_like(x)
            return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, class_labels):
        """Run the whole reverse chain starting from Gaussian noise.

        Returns:
            List with one NumPy array per timestep (from the last timestep down
            to 0); the final entry is the finished latent batch.
        """
        device = next(self.unet.parameters()).device

        b = shape[0]
        img = torch.randn(shape, device=device)
        imgs = []

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), i, class_labels)
            imgs.append(img.cpu().numpy())
        return imgs

    @torch.no_grad()
    def sample(self, image_size, batch_size=16, channels=256, class_labels=None):
        """Sample latents of spatial size ``image_size // 8``; labels are random when not given."""
        if class_labels is None:
            class_labels = torch.randint(0, self.num_classes, (batch_size,), device=next(self.unet.parameters()).device)
        return self.p_sample_loop(shape=(batch_size, channels, image_size // 8, image_size // 8), class_labels=class_labels)

def cosine_beta_schedule(timesteps, s=0.008):
    """Compute per-step betas from a squared-cosine cumulative-alpha curve.

    Args:
        timesteps: number of diffusion steps (length of the returned tensor).
        s: offset added to the normalized step inside the cosine.

    Returns:
        1-D tensor of betas, each limited to [0.0001, 0.9999].
    """
    ticks = torch.linspace(0, timesteps, timesteps + 1)
    angle = ((ticks / timesteps) + s) / (1 + s) * torch.pi * 0.5
    curve = torch.cos(angle) ** 2
    curve = curve / curve[0]  # first entry becomes exactly 1
    later, earlier = curve[1:], curve[:-1]
    betas = 1 - (later / earlier)
    return torch.clip(betas, 0.0001, 0.9999)

def extract(a, t, x_shape):
    """Look up ``a[t]`` per batch element and reshape it for broadcasting.

    Args:
        a: 1-D tensor of per-timestep coefficients (for example a registered buffer).
        t: 1-D integer tensor of timestep indices, one per batch element.
        x_shape: shape of the tensor the result will be combined with.

    Returns:
        Tensor of shape (batch, 1, ..., 1) with ``len(x_shape)`` dimensions,
        on the same device as ``t``.
    """
    batch_size = t.shape[0]
    # Move the index to wherever `a` lives: gather rejects a CPU index when
    # the coefficient buffer has been moved to the GPU with its module.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

# Training loop
def train(cldm, autoencoder, dataloader, num_epochs, device, lr=1e-4):
    """Train the latent diffusion model on latents produced by ``autoencoder``.

    Args:
        cldm: module exposing ``num_timesteps`` and ``p_losses(x, t, labels)``.
        autoencoder: module whose ``encode`` maps image batches to latents.
        dataloader: iterable yielding ``(images, labels)`` batches.
        num_epochs: number of passes over ``dataloader``.
        device: device the batches and sampled timesteps are placed on.
        lr: AdamW learning rate.

    Prints the mean batch loss after every epoch.
    """
    optimizer = torch.optim.AdamW(cldm.parameters(), lr=lr)

    for epoch in range(num_epochs):
        cldm.train()
        epoch_loss = 0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for images, targets in progress:
            optimizer.zero_grad()
            images = images.to(device)
            targets = targets.to(device)

            # Encoding runs without gradient tracking, so no gradients reach
            # the autoencoder through the diffusion loss.
            with torch.no_grad():
                latents = autoencoder.encode(images)

            steps = torch.randint(0, cldm.num_timesteps, (images.shape[0],), device=device).long()
            step_loss = cldm.p_losses(latents, steps, targets)

            step_loss.backward()
            optimizer.step()
            epoch_loss += step_loss.item()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

# Visualization
def visualize_samples(cldm, autoencoder, num_samples, image_size, channels, device, class_labels=None):
    """Sample latents, decode them to images and show them in a single row.

    Args:
        cldm: latent diffusion wrapper exposing ``num_classes`` and ``sample(...)``.
        autoencoder: module whose ``decode`` maps latents back to images.
        num_samples: number of images to generate and plot.
        image_size: image height/width passed to ``cldm.sample``.
        channels: number of latent channels passed to ``cldm.sample``.
        device: device used for the class labels and for decoding.
        class_labels: optional tensor of class indices, one per sample; drawn
            uniformly at random when omitted.
    """
    cldm.eval()
    autoencoder.eval()

    if class_labels is None:
        class_labels = torch.randint(0, cldm.num_classes, (num_samples,), device=device)
    latent_samples = cldm.sample(image_size=image_size, batch_size=num_samples, channels=channels, class_labels=class_labels)
    latent_samples = torch.from_numpy(latent_samples[-1]).to(device)  # Get the last step of the sampling process

    # Decode the latent samples
    with torch.no_grad():
        samples = autoencoder.decode(latent_samples)

    # squeeze=False always returns a 2-D axes array. Without it a single
    # sample yields one Axes object, which cannot be iterated.
    fig, axes = plt.subplots(1, num_samples, figsize=(num_samples*3, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        img = samples[i].permute(1, 2, 0).cpu().numpy()
        img = (img + 1) / 2  # Rescale from [-1, 1] to [0, 1]
        ax.imshow(img)
        ax.set_title(f"Class: {class_labels[i].item()}")
        ax.axis('off')
    plt.show()

# Main execution
if __name__ == '__main__':
    # Use the GPU when available; models, batches and labels below all go to `device`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 10  # CIFAR-10 has 10 classes

    # Initialize models
    # NOTE(review): Autoencoder applies three stride-2 convolutions, so 32x32
    # images become 4x4 latents. ConditionalUNet's three down blocks each use
    # a kernel-4, stride-2 convolution, so the third one would see a 1x1 map.
    # Confirm the U-Net can process 4x4 inputs before starting the long run.
    autoencoder = Autoencoder().to(device)
    unet = ConditionalUNet(in_channels=256, out_channels=256, num_classes=num_classes).to(device)
    cldm = ConditionalLatentDiffusion(autoencoder, unet, num_classes=num_classes).to(device)

    # Load CIFAR-10 dataset
    # Normalizing each channel with mean 0.5 and std 0.5 puts pixels in [-1, 1],
    # matching the Tanh output range of the autoencoder's decoder.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

    # Train the autoencoder
    # Plain reconstruction training (MSE between input and output); labels are ignored.
    autoencoder_optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)
    num_epochs_ae = 10
    for epoch in range(num_epochs_ae):
        autoencoder.train()
        total_loss = 0
        for batch, _ in tqdm(dataloader, desc=f"Autoencoder Epoch {epoch+1}/{num_epochs_ae}"):
            x = batch.to(device)
            autoencoder_optimizer.zero_grad()
            x_recon = autoencoder(x)
            loss = F.mse_loss(x_recon, x)
            loss.backward()
            autoencoder_optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f"Autoencoder Epoch {epoch+1}/{num_epochs_ae}, Average Loss: {avg_loss:.4f}")

    # Train the conditional latent diffusion model
    num_epochs_cldm = 100  # You might want to increase this for better results
    train(cldm, autoencoder, dataloader, num_epochs_cldm, device)

    # Visualize results
    visualize_samples(cldm, autoencoder, num_samples=5, image_size=32, channels=256, device=device)

    # Generate samples for specific classes
    class_labels = torch.tensor([0, 1, 2, 3, 4], device=device)  # Generate one sample for each of the first 5 classes
    visualize_samples(cldm, autoencoder, num_samples=5, image_size=32, channels=256, device=device, class_labels=class_labels)

    # Function to generate and save samples
    def generate_and_save_samples(cldm, autoencoder, num_samples, image_size, channels, device, save_path):
        """Save a grid with one row per class and ``num_samples`` decoded images per row.

        Args:
            cldm: latent diffusion wrapper exposing ``num_classes`` and ``sample(...)``.
            autoencoder: module whose ``decode`` maps latents back to images.
            num_samples: images generated for every class (grid columns).
            image_size: image height/width passed to ``cldm.sample``.
            channels: number of latent channels passed to ``cldm.sample``.
            device: device used for the class labels and for decoding.
            save_path: file the figure is written to.
        """
        cldm.eval()
        autoencoder.eval()
        # The grid reads sample i*num_samples + j as class i, so labels must be
        # grouped per class (0,0,..,1,1,..). A plain repeat() would cycle
        # 0..C-1 and mislabel the rows.
        class_labels = torch.arange(cldm.num_classes).repeat_interleave(num_samples).to(device)
        latent_samples = cldm.sample(image_size=image_size, batch_size=num_samples * cldm.num_classes, channels=channels, class_labels=class_labels)
        latent_samples = torch.from_numpy(latent_samples[-1]).to(device)  # Get the last step of the sampling process

        with torch.no_grad():
            samples = autoencoder.decode(latent_samples)

        # squeeze=False keeps `axes` 2-D even for a single row or column.
        fig, axes = plt.subplots(cldm.num_classes, num_samples, figsize=(num_samples*3, cldm.num_classes*3), squeeze=False)
        for i in range(cldm.num_classes):
            for j in range(num_samples):
                img = samples[i*num_samples + j].permute(1, 2, 0).cpu().numpy()
                img = (img + 1) / 2  # Rescale from [-1, 1] to [0, 1]
                axes[i, j].imshow(img)
                axes[i, j].axis('off')
            axes[i, 0].set_title(f"Class {i}")

        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()

    # Generate and save a grid of samples
    generate_and_save_samples(cldm, autoencoder, num_samples=5, image_size=32, channels=256, device=device, save_path='cldm_samples.png')

    print("Training and sampling completed. Check 'cldm_samples.png' for generated samples.")