In [15]:
%cd /Users/masha/Documents/GSOC/GSoC-Quantum-Diffusion-Model

from utils.post_training import *
from utils.statistics import *
from utils.plotting import *
from utils.encodings import *
from utils.statistics import calculate_statistics, calculate_fid, ssim
from utils.haar_noising_script import apply_global_haar_scrambling, fast_haar_scramble
from utils.quantum_diffusion import *

import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import scipy.linalg
import random

import pennylane as qml

C:\Users\realc\OneDrive\Documents\GSOC


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [26]:
# Seed every RNG this notebook draws from (NumPy, PyTorch, Python's `random`) with one value.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)

# Data selection: the channel number is interpolated into the name of the file
# that the loading cell opens with h5py.
QG_channel = 1
filename = f"data/QG{QG_channel}_64x64_1k"
num_train_samples = None   # None keeps every sample; an int keeps only the first `num_train_samples`
num_samples_for_scramble = 100   # how many samples to produce scrambled versions for
num_scrambles = 4   # number of scrambled augmentations per sample
n_qubits = 12   # flattened vectors are padded/truncated to 2**n_qubits entries
scramble_depth = 8   # forwarded as `depth` to generate_scrambled_dataset
shared_unitaries = True   # True: one scrambler seed per augmentation slot, reused for every sample
batch_size = 16
lr = 2e-4   # learning rate handed to the Adam optimizer
num_epochs = 50
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# output directories (created up front; checkpoints are written here by the training cell)
out_dir = "results_qg_cond_fm"
os.makedirs(out_dir, exist_ok=True)

In [None]:
# Read the image stack into memory. The context manager closes the HDF5 handle as
# soon as the copy is made, so re-running this cell does not leak open files.
with h5py.File(filename, "r") as f:
    data_X = np.array(f['X'])
print("Raw data shape:", data_X.shape)  # expect (N, H, W) or (N, 1, H, W)

if data_X.ndim == 4:
    # (N, C, H, W): already channel-first, use as is
    data = data_X
else:
    # (N, H, W) -> (N, 1, H, W): insert a singleton channel axis
    data = data_X[:, None, :, :]

# optional subsample for faster debug runs
if num_train_samples is not None:
    data = data[:num_train_samples]

print("Prepared data shape:", data.shape)

Raw data shape: (1000, 64, 64)
Prepared data shape: (1000, 1, 64, 64)


In [13]:
# Helpers that flatten each image into a unit-norm vector of length 2**n_qubits and build scrambled copies of it.
from utils.angle_encoding_script import angle_encoding

def pad_to_power_of_two(vec, n_qubits):
    """Fit `vec` to length 2**n_qubits and scale it to unit L2 norm.

    Shorter inputs are zero-padded at the end, longer inputs are truncated.
    The input is copied first, so the caller's array is never modified.

    Args:
        vec: 1-D array-like of numbers.
        n_qubits: exponent of the target length (target is 2**n_qubits).

    Returns:
        A NumPy array of length 2**n_qubits. It has unit norm unless every entry
        is zero, in which case the zero vector is returned unscaled.
    """
    target_len = 2 ** n_qubits
    arr = np.copy(vec)
    missing = target_len - len(arr)
    if missing > 0:
        arr = np.pad(arr, (0, missing), mode='constant')
    else:
        arr = arr[:target_len]

    # An all-zero vector cannot be normalized; hand it back untouched.
    magnitude = np.linalg.norm(arr)
    return arr if magnitude == 0 else arr / magnitude

def generate_scrambled_dataset(
    data,                     # numpy array (N, C, H, W) or (N, H, W)
    sample_indices=None,      # list/array of indices to augment (None -> use all)
    num_scrambles=4,
    n_qubits=8,
    depth=8,
    seed=42,
    shared_unitaries=True
):
    """Flatten selected images and produce `num_scrambles` scrambled copies of each.

    Every selected image is reduced to a single 2-D plane (the first channel when
    the sample is 3-D), flattened to float32, and passed through
    `pad_to_power_of_two` so it has length 2**n_qubits. Each resulting vector is
    then handed to `fast_haar_scramble` once per augmentation.

    Args:
        data: array of images, indexed along its first axis.
        sample_indices: indices into `data` to process; None processes all of them.
        num_scrambles: number of scrambled copies generated per sample.
        n_qubits: the encoded vectors have length 2**n_qubits.
        depth: forwarded to `fast_haar_scramble`.
        seed: base seed. With shared unitaries it seeds the generator that draws
            the per-slot scrambler seeds; otherwise each (sample, slot) pair gets
            the seed `seed + index * num_scrambles + slot`.
        shared_unitaries: if True, slot k uses the same scrambler seed for every
            sample, so slot k means the same scrambler across the dataset.

    Returns:
        (encoded, scrambled) where `encoded` is an array of shape
        (len(sample_indices), 2**n_qubits) and `scrambled` is a list with one
        entry per sample, each a list of `num_scrambles` values returned by
        `fast_haar_scramble`.
    """
    rng = np.random.default_rng(seed)
    if sample_indices is None:
        sample_indices = np.arange(data.shape[0])

    # Shared mode: draw one scrambler seed per augmentation slot up front.
    # An integer bound is used instead of the float literal 1e9.
    if shared_unitaries:
        print(f"Generating a set of {num_scrambles} shared unitaries")
        shared_seeds = [int(rng.integers(0, 1_000_000_000)) for _ in range(num_scrambles)]
    else:
        shared_seeds = None

    encoded_images = []            # one padded, normalized vector per sample
    scrambled_versions_all = []    # one list of scrambled vectors per sample

    for s_idx in sample_indices:
        img = data[s_idx]
        if img.ndim == 3:
            # (C, H, W): single- or multi-channel, keep the first channel only
            img2d = img[0]
        elif img.ndim == 2:
            img2d = img
        else:
            img2d = img.squeeze()

        # The flattened pixel values are the encoding. The previous try/except
        # wrapped two identical statements, so it has been dropped.
        encoded_flat = img2d.flatten().astype(np.float32)

        # pad/truncate to 2**n_qubits entries and normalize
        flat_encoded = pad_to_power_of_two(encoded_flat, n_qubits)
        encoded_images.append(flat_encoded)

        if shared_unitaries:
            slot_seeds = shared_seeds
        else:
            slot_seeds = [int(seed + s_idx * num_scrambles + k) for k in range(num_scrambles)]

        scrambled_versions_all.append([
            fast_haar_scramble(flat_encoded, n_qubits=n_qubits, depth=depth, seed=slot_seed)
            for slot_seed in slot_seeds
        ])

    # The reshape is a no-op for non-empty selections and keeps the documented
    # 2-D shape (0, 2**n_qubits) when nothing was selected.
    encoded_array = np.array(encoded_images).reshape(-1, 2 ** n_qubits)
    return encoded_array, scrambled_versions_all

# quick test (small): run the generator on at most the first 10 samples with
# 2 scrambles each and print the resulting sizes as a sanity check.
encs, scrs = generate_scrambled_dataset(data, sample_indices=np.arange(min(10, data.shape[0])),
                                       num_scrambles=2, n_qubits=n_qubits, depth=scramble_depth,
                                       seed=seed, shared_unitaries=shared_unitaries)
print("encoded shape:", encs.shape, "scrambles samples:", len(scrs))


Generating a set of 2 shared unitaries
encoded shape: (10, 4096) scrambles samples: 10


In [16]:
class QGConditionalDataset(Dataset):
    """
    Image-level dataset pairing each original image with one scrambled augmentation.

    Each item is a triple:
      - x_input: the min-max normalized original image (no noise is added here)
      - cond: one randomly chosen scrambled vector, laid out as a (1, H, W) image
              and min-max normalized
      - target: the min-max normalized original image (supervision)
    All three are torch.float32 tensors with values in [0, 1], channel-first.

    With use_concat_condition=True the first element is instead
    torch.cat([cond, x_input], dim=0), i.e. shape (1 + C, H, W).
    """
    def __init__(self, originals, scrambled_lists, use_concat_condition=False):
        # originals: numpy array (N, C, H, W)
        # scrambled_lists: list length N where each element is list of scrambled vectors (num_scrambles)
        self.orig = originals
        self.scrambled_lists = scrambled_lists
        self.N = len(originals)
        self.use_concat_condition = use_concat_condition

    def __len__(self):
        return self.N

    @staticmethod
    def _min_max_normalize(arr):
        """Rescale `arr` to [0, 1]; a (near-)constant array maps to all zeros."""
        lo = arr.min()
        hi = arr.max()
        if hi - lo <= 1e-8:
            return np.zeros_like(arr)
        return (arr - lo) / (hi - lo)

    def __getitem__(self, idx):
        img = self.orig[idx].astype(np.float32)  # (C, H, W)

        # choose one scrambled augmentation at random (global NumPy RNG)
        scr_list = self.scrambled_lists[idx]
        scr_vec = np.asarray(scr_list[np.random.randint(0, len(scr_list))])

        # Complex vectors used to be min-max normalized as complex numbers and then
        # cast to float32, which silently dropped the imaginary part. Reduce them
        # to real magnitudes explicitly before any further processing.
        # NOTE(review): magnitude is used as the real-valued representation;
        # confirm this is the intended conditioning signal.
        if np.iscomplexobj(scr_vec):
            scr_vec = np.abs(scr_vec)

        # Lay the 1-D vector out as a single-channel image of the original's size.
        _, H, W = img.shape
        n_pixels = H * W
        if scr_vec.shape[0] == n_pixels:
            scr_img = scr_vec.reshape(1, H, W)
        else:
            # Too long: keep the first H*W entries. Too short: repeat the vector
            # until H*W entries are available, then cut to size.
            repeats = int(np.ceil(n_pixels / len(scr_vec)))
            scr_img = np.tile(scr_vec[:n_pixels], repeats)[:n_pixels].reshape(1, H, W)

        # Bring both images to [0, 1] independently.
        img_n = self._min_max_normalize(img)
        scr_n = self._min_max_normalize(scr_img)

        # The model input is a copy of the normalized image so that input and
        # target never share memory.
        x_input = torch.from_numpy(img_n.copy()).float()
        cond = torch.from_numpy(scr_n).float()
        target = torch.from_numpy(img_n).float()

        # Optionally stack the condition in front of the input along the channel
        # axis for models without a separate conditioning argument.
        if self.use_concat_condition:
            concat_input = torch.cat([cond, x_input], dim=0)  # (1 + C, H, W)
            return concat_input, cond, target

        return x_input, cond, target

# Build dataset (choose using first N samples)
num_use = min(data.shape[0], num_samples_for_scramble)
sample_indices = np.arange(num_use)
encoded_images, scrambled_lists = generate_scrambled_dataset(data, sample_indices=sample_indices,
                                                            num_scrambles=num_scrambles,
                                                            n_qubits=n_qubits,
                                                            depth=scramble_depth,
                                                            seed=seed,
                                                            shared_unitaries=shared_unitaries)

# We pass the original images aligned with scrambled augmentations:
# sample_indices is 0..num_use-1, so the first num_use images line up one-to-one.
orig_subset = data[:num_use]  # (N, C, H, W)
assert len(scrambled_lists) == len(orig_subset), "originals and scrambled augmentations are misaligned"
dataset = QGConditionalDataset(orig_subset, scrambled_lists, use_concat_condition=False)

# Pinned host memory only helps host-to-GPU copies; requesting it on a CPU-only
# machine just triggers a warning, so enable it only when training on CUDA.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0,
                        pin_memory=(device.type == "cuda"))
print("Dataset size:", len(dataset), "Example shapes:", next(iter(dataloader))[0].shape)


Generating a set of 4 shared unitaries
Dataset size: 100 Example shapes: torch.Size([16, 1, 64, 64])


  cond = torch.from_numpy(scr_n).float()


In [None]:
# MODEL DEFINITION HERE
# The model consumes and produces flattened vectors of length 2**n_qubits, the
# same length pad_to_power_of_two gives the encoded images.
input_dim = 2 ** n_qubits
output_dim = input_dim

# Number of conditioning unitaries (= num_scrambles)
num_unitaries = num_scrambles     

# -------------------------
# Base quantum model
# -------------------------
# NOTE(review): the training cell's saved output is
#   RuntimeError: shape '[-1, 341, 12]' is invalid for input of size 65536
# 65536 = batch_size (16) * input_dim (4096) and 341 = 4096 // 12, so it looks like
# QuantumDiffusionModel reshapes its input into groups of n_qubits values, which
# would require input_dim to be divisible by n_qubits. 2**12 is not divisible by 12.
# Confirm against utils.quantum_diffusion before training.
hidden_dim = 128
n_layers = 2

# Build the base model and move its parameters to the selected device.
base_model = QuantumDiffusionModel(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
    n_qubits=n_qubits,
    n_layers=n_layers
).to(device)


# ============================================================
# Conditional wrapper for unitary indices
# ============================================================

class ConditionedQD(nn.Module):
    """
    Wraps a base model and conditions it on a scrambler index.

    The index is looked up in a learned embedding table, projected to the input
    dimension, squashed with tanh, and added to the input vector before the
    base model is called.
    """
    def __init__(self, base_model, num_unitaries, input_dim, embed_dim=64):
        super().__init__()
        self.base = base_model
        # one learned vector per scrambler index
        self.embed = nn.Embedding(num_unitaries, embed_dim)
        # lifts the embedding to the same width as the input vector
        self.project = nn.Linear(embed_dim, input_dim)
        self.act = nn.Tanh()

    def forward(self, x_vec, unitary_idx):
        """
        Args:
            x_vec:       (batch, input_dim) input vectors
            unitary_idx: (batch,) integer scrambler indices

        Returns:
            Whatever the wrapped base model returns for the shifted input.
        """
        # index -> embedding -> input-sized offset bounded to [-1, 1] by tanh
        offset = self.act(self.project(self.embed(unitary_idx)))
        return self.base(x_vec + offset)


# Wrap the base model with the index-conditioning layer and move it to the device.
embed_dim = 64
cond_model = ConditionedQD(base_model=base_model, num_unitaries=num_unitaries,
                           input_dim=input_dim, embed_dim=embed_dim)
cond_model = cond_model.to(device)

# Mean-squared-error objective, optimized with Adam over every parameter of the
# wrapper (which includes the base model's parameters).
criterion = nn.MSELoss()
optimizer = optim.Adam(cond_model.parameters(), lr=lr)

In [None]:
class QGVectorDataset(Dataset):
    """Vector-level dataset: encoded image, a randomly drawn scrambler index, and target.

    Each item is a tuple (input, index, target). Input and target are separate
    float32 tensors holding the same encoded vector; index is a 0-d long tensor
    drawn uniformly from the number of scrambles stored for that sample.

    NOTE(review): the scrambled vectors themselves are never returned. Only
    their count per sample is used to draw the index. Confirm this is intended.
    """
    def __init__(self, encoded_images, scrambled_lists):
        self.N = len(encoded_images)
        self.encoded = encoded_images        # indexable, one vector per sample
        self.scrambled = scrambled_lists     # per-sample list of scrambled vectors

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        vec = self.encoded[idx]

        # draw which scrambler slot this item is labelled with (global NumPy RNG)
        n_slots = len(self.scrambled[idx])
        slot = np.random.randint(0, n_slots)

        # identity reconstruction: the target is the input vector itself
        features = torch.tensor(vec, dtype=torch.float32)
        label = torch.tensor(slot, dtype=torch.long)
        target = torch.tensor(vec, dtype=torch.float32)
        return features, label, target

# build the vector dataset
vec_dataset = QGVectorDataset(encoded_images, scrambled_lists)

from torch.utils.data import random_split, DataLoader

# Hold out a fixed fraction for validation; the remainder is used for training.
val_fraction = 0.1
val_size = int(len(vec_dataset) * val_fraction)
train_size = len(vec_dataset) - val_size

# Use a dedicated, seeded generator so the split is identical every time this
# cell runs, regardless of how much of the global torch RNG stream earlier
# cells (e.g. model initialization) have consumed.
split_rng = torch.Generator().manual_seed(seed)
train_ds, val_ds = random_split(vec_dataset, [train_size, val_size], generator=split_rng)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False)

In [27]:
# Sample counts are fixed for the whole run. A split can be empty (val_size
# rounds down to 0 for very small datasets), so the averages below report NaN
# instead of raising ZeroDivisionError after an epoch of training.
n_train = len(train_loader.dataset)
n_val = len(val_loader.dataset)

for epoch in range(num_epochs):
    # ---- training pass: one optimizer step per batch ----
    cond_model.train()
    total_loss = 0.0

    for xb, idxb, yb in train_loader:
        xb, yb, idxb = xb.to(device), yb.to(device), idxb.to(device)
        optimizer.zero_grad()
        out = cond_model(xb, idxb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight by batch size so the epoch
        # average is per sample even when the last batch is smaller
        total_loss += loss.item() * xb.size(0)

    avg_train_loss = total_loss / n_train if n_train else float("nan")

    # ---- validation pass: no gradients, no parameter updates ----
    cond_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, idxb, yb in val_loader:
            xb, yb, idxb = xb.to(device), yb.to(device), idxb.to(device)
            val_loss += criterion(cond_model(xb, idxb), yb).item() * xb.size(0)

    avg_val_loss = val_loss / n_val if n_val else float("nan")
    print(f"Epoch {epoch+1}/{num_epochs} - Train: {avg_train_loss:.5f} | Val: {avg_val_loss:.5f}")


RuntimeError: shape '[-1, 341, 12]' is invalid for input of size 65536

In [None]:
# ============================================================
#   TRAINING LOOP (clean, with validation)
# ============================================================

# Reuse the train/val loaders built right after `vec_dataset` was created.
# Re-splitting here would draw a different random partition, so samples the
# model may already have been trained on could end up in the validation set.
n_train = len(train_loader.dataset)
n_val = len(val_loader.dataset)

print(f"Training samples: {n_train}, Validation samples: {n_val}")


# Main loop
for epoch in range(num_epochs):
    # ------------------------------------
    # Training
    # ------------------------------------
    cond_model.train()
    total_loss = 0.0

    for xb, unitary_idxb, yb in train_loader:
        xb = xb.to(device)                  # (batch, input_dim)
        yb = yb.to(device)                  # (batch, input_dim)
        unitary_idxb = unitary_idxb.to(device)

        optimizer.zero_grad()

        out = cond_model(xb, unitary_idxb)  # conditioned forward pass
        loss = criterion(out, yb)

        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight by batch size so the epoch
        # average is per sample even when the last batch is smaller
        total_loss += loss.item() * xb.size(0)

    # An empty split reports NaN instead of raising ZeroDivisionError.
    avg_train_loss = total_loss / n_train if n_train else float("nan")


    # ------------------------------------
    # Validation
    # ------------------------------------
    cond_model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for xb, unitary_idxb, yb in val_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            unitary_idxb = unitary_idxb.to(device)

            out = cond_model(xb, unitary_idxb)
            val_loss += criterion(out, yb).item() * xb.size(0)

    avg_val_loss = val_loss / n_val if n_val else float("nan")


    # ------------------------------------
    # Logging
    # ------------------------------------
    print(f"Epoch {epoch+1}/{num_epochs} "
          f"- Train: {avg_train_loss:.5f} | Val: {avg_val_loss:.5f}")


    # Optional checkpointing: save the weights every 10th epoch into out_dir
    if (epoch + 1) % 10 == 0:
        torch.save(cond_model.state_dict(),
                   os.path.join(out_dir, f"cond_model_ep{epoch+1}.pt"))
