In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Paths
# DATA_ROOT holds one sub-folder per sample; SAVE_DIR receives the finetuned
# model and the per-epoch reconstruction previews.
DATA_ROOT = "/content/drive/MyDrive/400x400_paired_data"
SAVE_DIR = "/content/drive/MyDrive/vae_finetuning/models"

# Create the output directory from Python rather than shelling out, so this
# step does not depend on IPython's `!` syntax or on a POSIX `mkdir`.
import os
os.makedirs(SAVE_DIR, exist_ok=True)

# Dependencies
!pip install -q diffusers==0.30.0 transformers accelerate torch torchvision tqdm pillow termcolor

# Imports
import os
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from diffusers import AutoencoderKL
from termcolor import colored

# config
SEED = 42                   # fixes the DataLoader shuffle order across runs
EPOCHS = 10
BATCH_SIZE = 2              # lowered for diagnostics (paper uses 16)
LR = 1e-5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
GRAD_CLIP = 1.0             # max gradient norm passed to clip_grad_norm_

# Seed PyTorch's global RNG so that a Restart & Run All reproduces the same
# batch order.
torch.manual_seed(SEED)


# Dataset

class WatermarkDataset(Dataset):
    """Paired image dataset read from one sub-folder per sample.

    For every non-hidden directory ``<root>/<name>`` the dataset looks for
    ``<name>_hidden.png`` and ``<name>_hidden_inv.png``. A sample is kept only
    when both files exist. Items are returned as ``(hidden, hidden_inv)``.
    """

    def __init__(self, root, transform=None):
        """Scan ``root`` and collect the path pairs.

        Args:
            root: directory containing the per-sample sub-folders.
            transform: optional callable applied to each loaded image.
        """
        self.root = root
        self.tf = transform
        # Entries starting with "." are ignored; sorting keeps the order stable.
        entries = sorted(name for name in os.listdir(root) if not name.startswith("."))
        self.pairs = []
        for name in entries:
            sample_dir = os.path.join(root, name)
            if os.path.isdir(sample_dir):
                hidden_path = os.path.join(sample_dir, f"{name}_hidden.png")
                inverse_path = os.path.join(sample_dir, f"{name}_hidden_inv.png")
                if os.path.exists(hidden_path) and os.path.exists(inverse_path):
                    self.pairs.append((hidden_path, inverse_path))
        # The folder count includes every non-hidden entry, not only directories.
        print(f"Loaded {len(self.pairs)} paired samples from {len(entries)} folders.")

    def __len__(self):
        """Return the number of complete pairs found."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load pair ``idx`` as RGB images and apply the transform, if any."""
        images = [Image.open(path).convert("RGB") for path in self.pairs[idx]]
        if self.tf:
            # The transform is applied to each image independently.
            images = [self.tf(image) for image in images]
        return tuple(images)


# --- Transform ---
# Every image is resized to 400x400 and converted to a float tensor.
# NOTE(review): diffusers VAEs are usually fed images scaled to [-1, 1];
# ToTensor() yields [0, 1]. Confirm this range is intended for the finetune.
_preprocess_steps = [
    transforms.Resize((400, 400)),
    transforms.ToTensor(),   # range [0, 1]
]
transform = transforms.Compose(_preprocess_steps)

dataset = WatermarkDataset(DATA_ROOT, transform=transform)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# Model
# Load the pretrained SDXL VAE, move it to the target device and put it in
# training mode. All of its parameters are handed to the optimizer.
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
vae = vae.to(DEVICE)
vae.train()
optimizer = torch.optim.Adam(vae.parameters(), lr=LR)
criterion = nn.MSELoss()

# Finetuning Loop (Algorithm 1 + diagnostics)

# Each epoch: encode the first image of every pair (xw), decode the posterior
# mean, regress the decoded image onto the second image of the pair (xi) with
# MSE, then save a 3-image preview (input | reconstruction | target).
for epoch in range(EPOCHS):
    vae.train()
    total_loss = 0.0
    samples_seen = 0  # samples that actually contributed to total_loss

    for step, (xw, xi) in enumerate(tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}"), start=1):
        xw, xi = xw.to(DEVICE, non_blocking=True), xi.to(DEVICE, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        # --- Input stats ---
        # Printed once per epoch, on the first batch only.
        if step == 1:
            print(
                f"[Epoch {epoch+1}] input stats | "
                f"xw: shape={tuple(xw.shape)} min={xw.min().item():.4f} "
                f"max={xw.max().item():.4f} mean={xw.mean().item():.4f} | "
                f"xi: shape={tuple(xi.shape)} min={xi.min().item():.4f} "
                f"max={xi.max().item():.4f} mean={xi.mean().item():.4f}"
            )

        # --- Encode + Decode ---
        # A failure here only leaves the inner loop; the epoch summary and
        # preview below still run.
        try:
            posterior = vae.encode(xw)
            z = posterior.latent_dist.mean  # posterior mean, no sampling

            # Decode (paper step 6)
            # NOTE(review): z comes straight from encode() and was never
            # multiplied by scaling_factor; diffusers pipelines only divide
            # latents that were scaled after encoding. Confirm this division
            # is intended (the evaluation cell uses the same convention).
            x_inv_hat = vae.decode(z / vae.config.scaling_factor).sample
        except Exception as e:
            print(colored(f"Forward crash: {e}", "red"))
            break

        # --- Compute loss (paper Eq. 1) ---
        try:
            loss = criterion(x_inv_hat, xi)
        except Exception as e:
            print(colored(f"Loss computation crash: {e}", "red"))
            break

        # --- Backpropagation (Algorithm 1, step 9) ---
        if torch.isnan(loss):
            print(colored("NaN loss detected – skipping batch", "red"))
            continue

        # Abort before touching the weights if the latent or output is
        # already corrupted, instead of after the optimizer step.
        if torch.isnan(z).any() or torch.isnan(x_inv_hat).any():
            print(colored(f"NaN detected in latent or output at step {step}", "red"))
            raise SystemExit

        loss.backward()
        torch.nn.utils.clip_grad_norm_(vae.parameters(), GRAD_CLIP)
        optimizer.step()

        batch_size = xw.size(0)
        total_loss += loss.item() * batch_size
        samples_seen += batch_size

        if step % 100 == 0:
            print(f"[Epoch {epoch+1} | Step {step}] Loss = {loss.item():.6f}")

    # Average over the samples that were actually accumulated, so skipped
    # batches or an early break do not deflate the reported loss.
    avg_loss = total_loss / max(samples_seen, 1)
    print(colored(f"Epoch [{epoch+1}/{EPOCHS}] - Avg MSE Loss: {avg_loss:.6f}", "cyan"))

    # --- Save sample reconstructions ---
    # Preview is always built from the first dataset item; train mode is
    # restored at the top of the next epoch.
    vae.eval()
    with torch.no_grad():
        xw_vis, xi_vis = dataset[0]
        xw_vis = xw_vis.unsqueeze(0).to(DEVICE)
        xi_vis = xi_vis.unsqueeze(0).to(DEVICE)
        z_vis = vae.encode(xw_vis).latent_dist.mean
        x_hat_vis = vae.decode(z_vis / vae.config.scaling_factor).sample

        # Clamp to [0, 1] before saving: input | reconstruction | target.
        preview = torch.cat([xw_vis, x_hat_vis, xi_vis], dim=0).clamp(0, 1)
        save_image(preview, os.path.join(SAVE_DIR, f"epoch{epoch+1}_recon.png"), nrow=3)
        print(f"Saved reconstruction preview → epoch{epoch+1}_recon.png")


# Save final model
# Writes the finetuned VAE into SAVE_DIR, the same directory that holds the
# per-epoch reconstruction previews.
vae.save_pretrained(SAVE_DIR)
done_message = f"Finetuning complete! Model saved to {SAVE_DIR}"
print(colored(done_message, "green"))


# Experiment: compare bits decoded from the finetuned VAE's output against each sample's secret and its inverse

In [18]:
import torch
from torchvision import transforms
from PIL import Image
import os

def load_secret(path):
    """Read a string of digits from a text file and return it as a list of ints.

    Whitespace anywhere in the file (trailing newline, spaces between digits,
    digits split over several lines) is ignored.

    Args:
        path: path to the text file.

    Returns:
        A list with one int per non-whitespace character, in file order.

    Raises:
        ValueError: if a non-whitespace character is not a digit.
    """
    with open(path, "r") as f:
        return [int(c) for c in f.read() if not c.isspace()]

def directional_decode(xw_tensor, x_hat_tensor, num_bits=100):
    """
    Rough heuristic to get directional Hamming bits.

    The element-wise difference ``x_hat_tensor - xw_tensor`` is flattened and
    cut into ``num_bits`` consecutive regions of equal size. A region yields
    bit 1 when its mean difference is positive and 0 otherwise. Elements left
    over after the last full region are ignored.

    Args:
        xw_tensor: reference tensor.
        x_hat_tensor: tensor compared against the reference; must be
            subtractable from ``xw_tensor``.
        num_bits: number of regions, i.e. number of bits returned.

    Returns:
        A list of ``num_bits`` ints, each 0 or 1.

    Raises:
        ValueError: if ``num_bits`` is not positive, or if the flattened
            difference has fewer than ``num_bits`` elements (every region
            would be empty and its mean undefined).
    """
    if num_bits <= 0:
        raise ValueError(f"num_bits must be positive, got {num_bits}")

    diff = (x_hat_tensor - xw_tensor).flatten()  # 1D tensor
    region_size = diff.shape[0] // num_bits
    if region_size == 0:
        raise ValueError(
            f"cannot split {diff.shape[0]} elements into {num_bits} non-empty regions"
        )

    # One row per region; a single reduction replaces a per-region .item() call.
    regions = diff[: region_size * num_bits].reshape(num_bits, region_size)
    region_means = regions.mean(dim=1)
    return [int(flag) for flag in (region_means > 0).tolist()]


# --- Example usage ---
# Loads the finetuned VAE, reconstructs the first few samples and reports the
# Hamming distance of the heuristically decoded bits to each sample's secret
# and to its inverse.
DATA_ROOT = "/content/drive/MyDrive/400x400_paired_data"
MODEL_PATH = "/content/drive/MyDrive/vae_finetuning/models"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_BITS = 100     # number of bits decoded per image and expected per secret file
NUM_SAMPLES = 5    # how many sample folders to evaluate

from diffusers import AutoencoderKL
vae = AutoencoderKL.from_pretrained(MODEL_PATH).to(DEVICE)
vae.eval()

transform = transforms.Compose([
    transforms.Resize((400, 400)),
    transforms.ToTensor(),
])

# Only real sample directories are considered, so stray files in DATA_ROOT
# cannot take up one of the evaluated slots.
folders = sorted(
    d for d in os.listdir(DATA_ROOT)
    if not d.startswith(".") and os.path.isdir(os.path.join(DATA_ROOT, d))
)[:NUM_SAMPLES]

for folder in folders:
    fpath = os.path.join(DATA_ROOT, folder)
    xw_path = os.path.join(fpath, f"{folder}_hidden.png")
    secret_path = os.path.join(fpath, f"{folder}_secret.txt")
    secret_inv_path = os.path.join(fpath, f"{folder}_secret_inv.txt")

    # Skip incomplete folders with a message instead of crashing mid-loop.
    missing = [p for p in (xw_path, secret_path, secret_inv_path) if not os.path.exists(p)]
    if missing:
        print(f"{folder} | skipped, missing: {', '.join(os.path.basename(p) for p in missing)}")
        continue

    with Image.open(xw_path) as img:
        xw = transform(img.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        z = vae.encode(xw).latent_dist.mean
        # NOTE(review): z was never multiplied by scaling_factor after
        # encoding; the division mirrors the training cell. Confirm it is
        # intended.
        x_hat = vae.decode(z / vae.config.scaling_factor).sample

    decoded_bits = directional_decode(xw, x_hat, num_bits=NUM_BITS)
    secret = load_secret(secret_path)
    secret_inv = load_secret(secret_inv_path)

    # zip() would silently truncate to the shorter sequence and the reported
    # "/NUM_BITS" denominator would then be wrong.
    if len(secret) != NUM_BITS or len(secret_inv) != NUM_BITS:
        print(
            f"{folder} | skipped, expected {NUM_BITS}-bit secrets "
            f"but got {len(secret)} and {len(secret_inv)}"
        )
        continue

    hamming_to_secret = sum(a != b for a, b in zip(decoded_bits, secret))
    hamming_to_inv = sum(a != b for a, b in zip(decoded_bits, secret_inv))

    print(f"{folder} | Hamming to secret: {hamming_to_secret}/{NUM_BITS} | Hamming to inverse: {hamming_to_inv}/{NUM_BITS}")

000000 | Hamming to secret: 49/100 | Hamming to inverse: 51/100
000001 | Hamming to secret: 44/100 | Hamming to inverse: 56/100
000002 | Hamming to secret: 58/100 | Hamming to inverse: 42/100
000003 | Hamming to secret: 50/100 | Hamming to inverse: 50/100
000004 | Hamming to secret: 43/100 | Hamming to inverse: 57/100
