<a href="https://colab.research.google.com/github/changeden/289A-Unsupervised-Learning/blob/main/LithoDiffusion_MemoryAnalysis_MetalSet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LithoDiffusion with Shape-Focused Training & Memory Analysis

This notebook includes:
1. Training a latent DDPM on the MetalSet dataset (both `printed` and `target` layouts) with a deeper U-Net, an additional pixel-space L1 reconstruction loss, and no entropy regularization (10 epochs).
2. Generating samples using the trained model.
3. SSIM & Entropy-based memorization analysis.

In [None]:
# Install dependencies
!pip install diffusers accelerate transformers einops scikit-image --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m82.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m63.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m34.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import torch, os, zipfile, shutil
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from diffusers import UNet2DModel, DDPMScheduler
from diffusers.models import AutoencoderKL
from skimage.metrics import structural_similarity as ssim
from skimage.color import rgb2gray
from skimage.transform import resize

# Seed every RNG the notebook relies on (DataLoader shuffling, noise/timestep
# sampling, DDPM reverse process) so a Restart & Run All reproduces the results.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
np.random.seed(SEED)

# Prefer the GPU when PyTorch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('✅ Using device:', device)

✅ Using device: cuda


In [None]:
# Attach Google Drive at /content/drive; the MetalSet archives are read from MyDrive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Unzip MetalSet ZIP files: each (archive, destination) pair is extracted in turn.
for archive_path, extract_dir in (
    ('/content/drive/MyDrive/Litho_dataset/MetalSet/target.zip', '/content/target'),
    ('/content/drive/MyDrive/Litho_dataset/MetalSet/printed.zip', '/content/printed'),
):
    with zipfile.ZipFile(archive_path, mode='r') as z:
        z.extractall(extract_dir)

In [None]:
# Prepare data folders and copy images into the ImageFolder layout
# /content/data/{target,printed}/*.png.
# Pure Python instead of `!mkdir`/`!cp`: a failing shell command does not stop the
# notebook, whereas a missing source folder here raises immediately.
from pathlib import Path

for split in ('target', 'printed'):
    src_dir = Path('/content') / split / split  # the archives extract into a nested folder
    dst_dir = Path('/content/data') / split
    dst_dir.mkdir(parents=True, exist_ok=True)
    png_files = sorted(src_dir.glob('*.png'))
    assert png_files, f'No PNG files found in {src_dir}'
    for png in png_files:
        shutil.copy(png, dst_dir / png.name)

In [None]:
# Load dataset. ImageFolder treats every sub-folder of /content/data as a class,
# so the images from all sub-folders end up in one 128x128 training set.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])
dataset = ImageFolder(transform=transform, root='/content/data')
dataloader = DataLoader(
    dataset,
    num_workers=2,
    shuffle=True,
    batch_size=16,
)
print('✅ Dataset loaded:', len(dataset), 'images')

✅ Dataset loaded: 32894 images


In [None]:
# Load the pretrained SD autoencoder and a deeper UNet that denoises in its latent space.
# NOTE: AutoencoderKL is a KL-regularised VAE, not a VQ-VAE; the variable keeps the
# name `vqvae` because later cells refer to it.
vqvae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse').to(device).eval()

IMAGE_SIZE = 128  # must match transforms.Resize in the dataset cell
# Spatial downsampling of the VAE encoder (same formula diffusers pipelines use).
vae_downscale = 2 ** (len(vqvae.config.block_out_channels) - 1)

model = UNet2DModel(
    # Resolution of the latents the UNet actually sees: IMAGE_SIZE-pixel images
    # encoded by the VAE (16x16 for 128 px and a factor-8 VAE), not 32x32.
    sample_size=IMAGE_SIZE // vae_downscale,
    in_channels=vqvae.config.latent_channels,
    out_channels=vqvae.config.latent_channels,
    layers_per_block=2,
    block_out_channels=(256, 256, 512, 512),
    down_block_types=('DownBlock2D','AttnDownBlock2D','AttnDownBlock2D','AttnDownBlock2D'),
    up_block_types=('AttnUpBlock2D','AttnUpBlock2D','AttnUpBlock2D','UpBlock2D')
).to(device)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

In [None]:
# Training loop: noise-prediction MSE + pixel-space L1 reconstruction loss (10 epochs)
import torch.nn.functional as F

# The VAE stays frozen, but the L1 term must back-propagate *through* its decoder into
# the UNet's noise prediction. Freezing its weights also stops gradients from piling up
# in VAE parameters that the optimizer never zeroes.
vqvae.requires_grad_(False)


def compute_l1_reconstruction_loss(latents_noisy, pred_noise, original_images, timesteps=None):
    """L1 distance between the decoded clean-latent estimate and the input images.

    Args:
        latents_noisy: noisy latents x_t that were fed to the UNet.
        pred_noise: the UNet's noise prediction for ``latents_noisy``.
        original_images: the images whose encoding produced the clean latents.
        timesteps: per-sample diffusion step t (same tensor passed to ``add_noise``).
            When given, x_0 is estimated by inverting the DDPM forward process and the
            loss is differentiable w.r.t. ``pred_noise``. When None, the legacy
            gradient-free estimate is used (kept only for backward compatibility).

    Returns:
        Scalar loss tensor.
    """
    if timesteps is None:
        # Legacy path: ignores the noise level and passes no gradient to the UNet.
        with torch.no_grad():
            sigma = noise_scheduler.init_noise_sigma
            latents_approx = (latents_noisy - pred_noise * sigma).detach()
            recon = vqvae.decode(latents_approx).sample
        return F.l1_loss(recon, original_images)

    alpha_bar = noise_scheduler.alphas_cumprod.to(latents_noisy.device)[timesteps]
    alpha_bar = alpha_bar.view(-1, 1, 1, 1)
    # x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps  =>  solve for x_0.
    latents_approx = (latents_noisy - (1 - alpha_bar).sqrt() * pred_noise) / alpha_bar.sqrt()
    # NOTE: decoding with autograd is memory-heavy; lower batch_size if this OOMs.
    recon = vqvae.decode(latents_approx).sample
    per_sample_l1 = (recon - original_images).abs().mean(dim=(1, 2, 3))
    # The x_0 estimate amplifies noise-prediction error by 1/sqrt(a_bar), which explodes
    # near t = T; weighting by a_bar keeps those samples from dominating the gradient.
    return (alpha_bar.view(-1) * per_sample_l1).mean()


optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
num_train_timesteps = noise_scheduler.config.num_train_timesteps
epochs = 10
model.train()
for epoch in range(epochs):
    total_loss = 0.0
    for images, _ in dataloader:
        images = images.to(device)
        with torch.no_grad():
            latents = vqvae.encode(images).latent_dist.mode()
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, num_train_timesteps, (latents.size(0),), device=device).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        pred_noise = model(noisy_latents, timesteps).sample
        loss_mse = F.mse_loss(pred_noise, noise)
        loss_l1 = compute_l1_reconstruction_loss(noisy_latents, pred_noise, images, timesteps)
        loss = loss_mse + 0.5 * loss_l1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{epochs} — Avg Loss: {total_loss/len(dataloader):.4f}")
print("✅ Training finished.")

Epoch 1/10 — Avg Loss: 0.1493
Epoch 2/10 — Avg Loss: 0.1180


In [None]:
# Generate samples (20 images) with the full DDPM reverse process.
NUM_SAMPLES = 20
IMAGE_SIZE = 128  # training resolution; must match transforms.Resize in the dataset cell
# The UNet was trained on VAE encodings of IMAGE_SIZE-pixel images, so sampling must
# use that latent resolution (16x16 for a factor-8 VAE), not 32x32.
vae_downscale = 2 ** (len(vqvae.config.block_out_channels) - 1)
latent_size = IMAGE_SIZE // vae_downscale

model.eval()
with torch.no_grad():
    latents = torch.randn(
        (NUM_SAMPLES, model.config.in_channels, latent_size, latent_size), device=device
    )
    # Denoise the whole batch at once rather than one image at a time.
    for t in reversed(range(noise_scheduler.config.num_train_timesteps)):
        timestep = torch.full((NUM_SAMPLES,), t, device=device, dtype=torch.long)
        noise_pred = model(latents, timestep).sample
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
    recon = vqvae.decode(latents).sample

# One channels-first float array clamped to [0, 1] per sample, as the analysis cells expect.
generated_images = list(recon.clamp(0, 1).cpu().numpy())
print(f"✅ Generated {len(generated_images)} images.")

In [None]:
# Load and preprocess the training images (128x128 grayscale) that serve as the
# memorization reference set.
# The model was trained on every image ImageFolder found under /content/data (both the
# `printed` and `target` sub-folders), so the nearest-neighbour search must cover all
# of them; comparing against `printed` alone would under-report memorization.
import glob
from PIL import Image
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((128,128)),
    transforms.ToTensor()
])
printed_paths = sorted(glob.glob('/content/data/printed/*.png'))
# Exactly the files the training DataLoader drew from, in ImageFolder's sorted order.
train_paths = [path for path, _ in dataset.samples]
train_images = []
for p in train_paths:
    with Image.open(p) as img:  # close each file handle promptly
        train_images.append(preprocess(img).squeeze(0).numpy())
print(f"✅ Loaded {len(train_images)} training images.")

In [None]:
# Define preprocess + memory analysis (SSIM + Entropy)
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.color import rgb2gray
from skimage.transform import resize

def preprocess_image(img, shape=(128,128)):
    """Return ``img`` as a 2-D grayscale array resized (anti-aliased) to ``shape``.

    A 3-D input is treated as channels-first ``(C, H, W)``: its channel axis is moved
    last and it is converted with ``rgb2gray``. A 2-D input is only resized.
    """
    gray = rgb2gray(np.moveaxis(img, 0, -1)) if img.ndim == 3 else img
    return resize(gray, shape, anti_aliasing=True)

# For each generated image: best SSIM against any training image, that nearest image,
# and the Shannon entropy (bits) of its 256-bin intensity histogram.
assert train_images, 'train_images is empty; run the training-image loading cell first'
ssim_scores, entropies, matches = [], [], []
for gen in generated_images:
    g = preprocess_image(gen)
    # SSIM lies in [-1, 1]; start below any attainable score so a match is always kept
    # (starting at 0.0 left best_match as None whenever every score was negative).
    best_ssim, best_match = -np.inf, None
    for tr in train_images:
        t = tr if tr.shape == g.shape else resize(tr, g.shape, anti_aliasing=True)
        sc = ssim(g, t, data_range=1.0)
        if sc > best_ssim:
            best_ssim, best_match = sc, t
    ssim_scores.append(best_ssim)
    matches.append(best_match)
    h, _ = np.histogram(g, bins=256, range=(0, 1))
    p = h[h > 0] / h.sum()  # PMF over occupied bins only; empty bins contribute 0
    entropies.append(float(-np.sum(p * np.log2(p))))
labels = ['memorized' if s > 0.98 else 'novel' for s in ssim_scores]

In [None]:
# Plot SSIM vs Entropy, one labelled series per memorization class so the colours
# are explained in the legend.
fig, ax = plt.subplots(figsize=(6, 4))
for group, color in (('memorized', 'red'), ('novel', 'blue')):
    members = [i for i, lab in enumerate(labels) if lab == group]
    ax.scatter(
        [ssim_scores[i] for i in members],
        [entropies[i] for i in members],
        c=color, alpha=0.7, label=f'{group} (n={len(members)})',
    )
ax.axvline(0.98, color='gray', linestyle='--', label='SSIM=0.98')
ax.set_xlabel('SSIM vs Closest Training')
ax.set_ylabel('Entropy (bits)')
ax.set_title('Memory Analysis: SSIM vs Entropy')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Visualize generated samples next to their closest training image (first 20 pairs)
MAX_PAIRS = 20

def show_pair(gen, match, score, label):
    """Display a generated sample beside its nearest training image.

    Args:
        gen: generated image in any format accepted by ``preprocess_image``.
        match: 2-D grayscale nearest training image, or None if none was recorded.
        score: SSIM between the two images.
        label: memorization label for the pair ('memorized' or 'novel').
    """
    fig, (ax_gen, ax_match) = plt.subplots(1, 2, figsize=(4, 2))
    # Fixed [0, 1] display range so both panels share one intensity scale instead of
    # each being contrast-stretched independently.
    ax_gen.imshow(preprocess_image(gen), cmap='gray', vmin=0, vmax=1)
    ax_gen.set_title('Generated')
    if match is not None:
        ax_match.imshow(match, cmap='gray', vmin=0, vmax=1)
    ax_match.set_title(f"{label}\nSSIM={score:.3f}")
    for ax in (ax_gen, ax_match):
        ax.axis('off')
    fig.tight_layout()
    plt.show()

for i in range(min(len(generated_images), MAX_PAIRS)):
    show_pair(generated_images[i], matches[i], ssim_scores[i], labels[i])