In [None]:
import torch
import functools
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

from models import Latent_UNet_Tranformer
from utils import marginal_prob_std, diffusion_coeff, train_diffusion_model

def visualize_perturbations(image, sigma, timesteps, save_path='perturbed_images.png'):
    """Plot an image next to noise-perturbed copies of it and save the figure.

    For each time ``t`` in ``timesteps``, a perturbed copy is built as
    ``image + z * std``. Here ``z`` is standard-normal noise shaped like
    ``image``, and ``std`` is ``marginal_prob_std(t, sigma=sigma)``.

    Args:
        image: Image tensor to perturb. It is ``squeeze()``-d and shown with a
            grayscale colormap, so presumably a single-channel image such as
            MNIST's ``(1, H, W)`` -- TODO confirm for other inputs.
        sigma: Noise-scale parameter forwarded to ``marginal_prob_std``.
        timesteps: Sequence of times at which to perturb the image. It may be
            empty, in which case only the original image is plotted.
        save_path: File the figure is written to. The default keeps the
            previously hard-coded ``'perturbed_images.png'``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image = image.to(device)

    marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)

    n_panels = len(timesteps) + 1
    # squeeze=False keeps `axes` array-shaped even for a single panel (empty
    # `timesteps`). Without it, subplots returns a bare Axes and indexing fails.
    fig, axes = plt.subplots(1, n_panels, figsize=(3 * n_panels, 3), squeeze=False)
    axes = axes[0]

    # Plot original image
    axes[0].imshow(image.cpu().squeeze(), cmap='gray')
    axes[0].set_title("Original (t=0)")
    axes[0].axis('off')

    # Plot perturbed images, one panel per timestep
    for ax, t in zip(axes[1:], timesteps):
        # float(t) keeps the default float dtype even if an int time (e.g. 1) is passed.
        t_tensor = torch.tensor([float(t)], device=device)
        std = marginal_prob_std_fn(t_tensor)
        z = torch.randn_like(image)
        perturbed_image = image + z * std

        ax.imshow(perturbed_image.cpu().squeeze(), cmap='gray')
        ax.set_title(f"t={t:.2f}")
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(save_path)
    # Close this specific figure (not just the "current" one) to free its memory.
    plt.close(fig)

if __name__ == "__main__":
    # Script entry point: load the MNIST training set, pick one random image,
    # and save a figure of it perturbed at several times via
    # visualize_perturbations.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): `device` and `marginal_prob_std_fn` are not used in the
    # visible part of this block. Presumably the elided remainder below uses
    # them -- confirm before removing either.

    # Define noise fns, params
    sigma = 25.0  # noise-scale parameter bound into marginal_prob_std
    marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)

    # Load MNIST dataset (train split). With download=True it is fetched into
    # ./data when missing. ToTensor is the only transform applied.
    transform = transforms.Compose([transforms.ToTensor()])
    mnist_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
    
    # Select a random image; its label is discarded. No seed is set in the
    # visible code, so the chosen image can differ between runs.
    random_idx = torch.randint(0, len(mnist_dataset), (1,)).item()
    image, _ = mnist_dataset[random_idx]
    
    # Times at which the perturbed image is shown
    timesteps = [0.2, 0.4, 0.6, 0.8, 1.0]
    
    # Visualize perturbations (saves the figure to perturbed_images.png)
    visualize_perturbations(image, sigma, timesteps)

    # ... (rest of the existing code)