In [4]:
%cd /Users/masha/Documents/GSOC/GSoC-Quantum-Diffusion-Model

from utils.post_training import *
from utils.statistics import *
from utils.plotting import *
from utils.angle_encoding_script import angle_encoding
from utils.haar_noising_script import apply_haar_scrambling

import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import scipy.linalg

import pennylane as qml

/Users/masha/Documents/GSOC/GSoC-Quantum-Diffusion-Model


In [12]:
QG_channel = 1
num_samples = 1000

encoded_path = f"data/QG{QG_channel}_64x64_{num_samples}_encoded.pt"
scrambled_path = f"data/QG{QG_channel}_64x64_{num_samples}_scrambled.pt"

encoded_data = torch.load(encoded_path)
scrambled_states = torch.load(scrambled_path)

print(encoded_data.shape)
print(scrambled_states.shape)

train_encoded_data, val_encoded_data, train_scrambled_states, val_scrambled_states = train_test_split(
    encoded_data[:num_samples], scrambled_states, test_size=0.2, random_state=42, shuffle=True
)

torch.Size([1000, 32, 32, 4])
torch.Size([1000, 32, 32, 4])


In [None]:
# Simple DDPM-style diffusion model (UNet) training on the preprocessed quantum data
import math
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Prepare data: ensure shape is (N, C, H, W)
data = encoded_data[:num_samples]
if isinstance(data, np.ndarray):
    data = torch.tensor(data)
if data.ndim == 4:
    # (N, H, W, C) -> (N, C, H, W)
    data = data.permute(0, 3, 1, 2)
elif data.ndim == 3:
    data = data.unsqueeze(1)
data = data.float().to(device)

dataset = TensorDataset(data)
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Small UNet-like model (keeps things simple)
class SmallUNet(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.down1 = nn.Sequential(nn.Conv2d(in_channels, 64, 3, padding=1), nn.ReLU(), nn.Conv2d(64,64,3,padding=1), nn.ReLU())
        self.down2 = nn.Sequential(nn.MaxPool2d(2), nn.Conv2d(64,128,3,padding=1), nn.ReLU())
        self.up1 = nn.Sequential(nn.ConvTranspose2d(128,64,2,stride=2), nn.ReLU())
        self.out = nn.Conv2d(64, in_channels, 1)
        self.time_mlp = nn.Sequential(nn.Linear(1,64), nn.ReLU(), nn.Linear(64,64))
    def forward(self,x,t):
        h1 = self.down1(x)
        h2 = self.down2(h1)
        h = self.up1(h2)
        te = self.time_mlp(t.unsqueeze(-1).float()).unsqueeze(-1).unsqueeze(-1)
        h = h + te
        out = self.out(h)
        return out

# Diffusion schedule and helpers
def get_beta_schedule(T, beta_start=1e-4, beta_end=2e-2):
    return torch.linspace(beta_start, beta_end, T)

T = 1000
betas = get_beta_schedule(T).to(device)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

def q_sample(x0, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x0)
    sqrt_acp = sqrt_alphas_cumprod[t].view(-1,1,1,1)
    sqrt_om = sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
    return sqrt_acp * x0 + sqrt_om * noise

def p_losses(model, x0, t):
    noise = torch.randn_like(x0)
    xt = q_sample(x0, t, noise)
    pred_noise = model(xt, t.float())
    return nn.functional.mse_loss(pred_noise, noise)

# Instantiate model and optimizer
in_channels = data.shape[1]
model_ddpm = SmallUNet(in_channels).to(device)
optim = torch.optim.Adam(model_ddpm.parameters(), lr=1e-3)

num_epochs = 10
for epoch in range(num_epochs):
    model_ddpm.train()
    running = 0.0
    for batch in dataloader:
        x = batch[0].to(device)
        b = x.size(0)
        t = torch.randint(0, T, (b,), device=device)
        loss = p_losses(model_ddpm, x, t)
        optim.zero_grad(); loss.backward(); optim.step()
        running += loss.item() * b
    print(f'Epoch {epoch+1}/{num_epochs}, loss {running/len(dataset):.6f}')

# Simple ancestral sampling function
@torch.no_grad()
def sample(model, n_samples):
    model.eval()
    shape = data.shape[1:]
    x = torch.randn(n_samples, *shape, device=device)
    for i in reversed(range(T)):
        t = torch.full((n_samples,), i, device=device, dtype=torch.long)
        pred_noise = model(x, t.float())
        beta = betas[i]
        alpha = alphas[i]
        acp = alphas_cumprod[i]
        if i > 0:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)
        coef1 = 1 / torch.sqrt(alpha)
        coef2 = (beta / torch.sqrt(1 - acp))
        x = coef1 * (x - coef2 * pred_noise) + torch.sqrt(beta) * noise
    return x

# Generate a few samples and save the model
generated = sample(model_ddpm, 4).cpu()
# convert to (N, H, W, C) for visualization if needed
generated = generated.permute(0,2,3,1).numpy()
print('generated samples shape:', generated.shape)
torch.save(model_ddpm.state_dict(), 'saved_models/ddpm_unet_simple.pth')

Epoch 1/10, loss 322.610773
Epoch 2/10, loss 8.248180
Epoch 3/10, loss 1.240490
Epoch 4/10, loss 1.017753
Epoch 5/10, loss 1.009096
Epoch 6/10, loss 1.010510
Epoch 7/10, loss 1.007079
Epoch 8/10, loss 1.000044
Epoch 9/10, loss 0.994814
Epoch 10/10, loss 0.983745
generated samples shape: (4, 32, 32, 4)
