In [None]:
!pip install x-unet
!pip install scikit-learn

In [14]:
import random
import torch

def set_seed(seed: int = 0):
    """Seed Python's `random` module plus torch's CPU and CUDA generators.

    Args:
        seed: value applied to every generator (defaults to 0).
    """
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

In [15]:
set_seed(seed=0)  # explicit seed, equal to the default of the set_seed defined above

In [None]:
import os

# Dataset root (relative to the kernel's working directory unless overridden). Set the
# DATA_PATH environment variable to point elsewhere without editing the notebook.
# PNG images are read from its `images/` sub-folder.
data_path = os.environ.get('DATA_PATH', 'data')
png_folder = os.path.join(data_path, 'images')

In [17]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image

class HighFrequencyDataset(Dataset):
    """In-memory dataset of every ``.png`` image in a folder, resized and converted to tensors.

    All images are decoded once at construction and stacked into a single tensor, so
    ``__getitem__`` is just an index into ``self.images``.

    Args:
        images_path: folder scanned (non-recursively) for files ending in ``.png``
            (case-insensitive).
        image_size: ``(height, width)`` passed to ``T.Resize``; defaults to ``(256, 256)``.
        mode: optional PIL mode (e.g. ``"L"`` or ``"RGB"``) every image is converted to before
            resizing. ``None`` keeps each file's own mode, in which case all files must decode
            to the same number of channels for ``torch.stack`` to succeed.

    NOTE(review): ``T.ToTensor()`` yields values in [0, 1], while the noise helper and grid
    visualisation later in this notebook clamp/normalise with a [-1, 1] range -- confirm
    which range training expects.
    """

    def __init__(self, images_path, image_size=(256, 256), mode=None):
        self.images_path = images_path
        self.image_size = image_size
        self.mode = mode
        self.images = self.read_images()

    def _list_png_files(self) -> list[str]:
        """Return the ``.png`` file names in ``images_path`` in sorted order."""
        # os.listdir order is filesystem-dependent; sorting keeps sample indices (and therefore
        # any seeded train/val split over them) reproducible across machines.
        return sorted(f for f in os.listdir(self.images_path) if f.lower().endswith('.png'))

    def read_images(self) -> torch.Tensor:
        """Decode every PNG in ``images_path`` and return them as one stacked tensor.

        Raises:
            RuntimeError: if the folder contains no ``.png`` files.
        """
        file_names = self._list_png_files()
        if not file_names:
            raise RuntimeError(f"No .png files found in {self.images_path!r}")
        images = []
        for name in file_names:
            # The context manager closes each file handle immediately; convert()/copy() return an
            # independent, fully loaded image that stays usable after the file is closed.
            with Image.open(os.path.join(self.images_path, name)) as img:
                images.append(img.convert(self.mode) if self.mode is not None else img.copy())
        return self.transform_images(images)

    def transform_images(self, images: list["Image.Image"]) -> torch.Tensor:
        """Resize each image to ``self.image_size``, convert it to a tensor, and stack into (N, C, H, W)."""
        transform = T.Compose([
            T.Resize(self.image_size),
            T.ToTensor(),
        ])
        return torch.stack([transform(image) for image in images])

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.images[idx]


In [18]:
# Load every PNG under png_folder into memory once.
dataset = HighFrequencyDataset(images_path=png_folder)

In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

# Split sample *indices* rather than the Dataset object: given the same n, test_size and
# random_state, train_test_split picks the same positions either way, but indexing lets both
# splits stay torch Datasets (Subset) instead of being materialised into Python lists.
train_indices, val_indices = train_test_split(
    list(range(len(dataset))), test_size=0.2, random_state=0
)
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [8]:
import torch
from x_unet import XUnet

# Create an instance of the XUnet model
unet = XUnet(
    channels=1,                          # 1 image channel
    dim=64,                              # base feature width
    dim_mults=(1, 2, 4, 8),              # width multiplier per UNet level
    nested_unet_depths=(7, 4, 2, 1),     # nested UNet depth per level, from the unet-squared paper
    consolidate_upsample_fmaps=True,     # consolidate outputs of all upsample blocks (unet-squared paper)
)

In [None]:
# Notebook-level counterpart of the script's --cpu flag. `args` only exists inside the CLI
# entry point further down, so it cannot be referenced here.
force_cpu = False  # set to True to run on CPU even when CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() and not force_cpu else "cpu")
unet = unet.to(device)

In [None]:
import random

def set_seed(seed: int = 42):
    """Seed torch (all CUDA devices and CPU) and Python's `random` module.

    NOTE(review): this redefines the `set_seed` declared near the top of the notebook and
    changes its default from 0 to 42; calls without an explicit seed made after this cell
    therefore use 42.
    """
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    random.seed(seed)

def add_gaussian_noise(x, std, low=-1.0, high=1.0):
    """Add zero-mean Gaussian noise to ``x`` and clamp the result to ``[low, high]``.

    Args:
        x: floating-point tensor of clean values. The default bounds assume data scaled to
            [-1, 1]; pass ``low=0.0, high=1.0`` for data in [0, 1].
        std: standard deviation of the noise, in the same units as ``x``.
        low: lower clamp bound for the noisy result (default -1.0).
        high: upper clamp bound for the noisy result (default 1.0).

    Returns:
        A new tensor with the same shape as ``x``: ``(x + std * N(0, 1)).clamp(low, high)``.

    Raises:
        ValueError: if ``low > high``.
    """
    if low > high:
        raise ValueError(f"low ({low}) must not exceed high ({high})")
    noise = torch.randn_like(x) * std
    return (x + noise).clamp(low, high)

In [None]:

import os, math, argparse, random
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, utils as vutils

from tqdm import tqdm

# ----- utilities -----


def to_device(batch, device):
    """Recursively move every tensor inside ``batch`` to ``device``.

    - lists and tuples are returned as lists (each element moved);
    - dicts are returned as dicts with the same keys (each value moved);
    - any object with a ``.to`` method (e.g. a tensor) is moved with ``non_blocking=True``;
    - anything else (strings, ints, file names, ...) is returned unchanged rather than
      raising AttributeError.
    """
    if isinstance(batch, (list, tuple)):
        return [to_device(item, device) for item in batch]
    if isinstance(batch, dict):
        return {key: to_device(value, device) for key, value in batch.items()}
    if hasattr(batch, "to"):
        return batch.to(device, non_blocking=True)
    return batch

@torch.no_grad()
def denoise_grid(model, noisy, *, nrow=8, clamp=True):
    """Run ``model`` on a batch of noisy images and tile the predictions into one grid image.

    The model is put in eval mode for the forward pass and then restored to the mode it was
    in before the call. This way, calling this from inside a training loop does not leave the
    model in eval mode for the remaining training steps.

    Args:
        model: module mapping noisy images to denoised images.
        noisy: batch of noisy images, presumably scaled to [-1, 1] (the clamp and the
            grid's ``value_range`` assume that range).
        nrow: number of images per row of the grid.
        clamp: clip predictions to [-1, 1] before building the grid.

    Returns:
        The grid tensor produced by ``torchvision.utils.make_grid``.
    """
    was_training = model.training
    model.eval()
    try:
        pred = model(noisy)
    finally:
        model.train(was_training)
    if clamp:
        pred = pred.clamp(-1, 1)
    return vutils.make_grid(pred, nrow=nrow, normalize=True, value_range=(-1, 1))

# ----- training -----
def _train_one_epoch(model, loader, opt, scaler, device, args, epoch, global_step, sample_dir):
    """Run one epoch of denoising training; return ``(mean batch loss, updated global_step)``.

    Every ``args.sample_every`` steps a denoised preview of the current batch is saved to
    ``sample_dir`` (``sample_every == 0`` disables previews).
    """
    model.train()
    running, n_batches = 0.0, 0
    pbar = tqdm(loader, desc=f"Epoch {epoch}/{args.epochs}")

    for (x, _) in pbar:
        x = to_device(x, device)          # clean target in [-1,1]
        x_noisy = add_gaussian_noise(x, args.noise_std)

        opt.zero_grad(set_to_none=True)

        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=scaler.is_enabled()):
            pred = model(x_noisy)          # predict clean image directly
            loss = F.mse_loss(pred, x)     # DAE loss

        if scaler.is_enabled():
            scaler.scale(loss).backward()
            if args.clip_grad is not None:
                scaler.unscale_(opt)       # clip the true (unscaled) gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            scaler.step(opt)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            opt.step()

        loss_value = loss.item()
        running += loss_value
        n_batches += 1
        global_step += 1
        pbar.set_postfix(loss=f"{loss_value:.4f}")

        # sample preview
        if args.sample_every and global_step % args.sample_every == 0:
            grid = denoise_grid(model, x_noisy[:16])
            vutils.save_image(grid, sample_dir / f"train_step{global_step:07d}.png")
            # denoise_grid switches the model to eval mode; resume training in train mode.
            model.train()

    # max(..., 1) avoids ZeroDivisionError when drop_last leaves the loader empty.
    return running / max(n_batches, 1), global_step


@torch.no_grad()
def _mean_val_loss(model, loader, device, noise_std):
    """Return the per-element MSE of ``model`` on noisy copies of every image in ``loader``.

    Losses are summed and divided by the total element count, so a smaller final batch
    (the validation loader uses ``drop_last=False``) is weighted by its real size instead of
    counting as much as a full batch. Returns ``nan`` for an empty loader.
    """
    model.eval()
    total, count = 0.0, 0
    for (xv, _) in loader:
        xv = to_device(xv, device)
        pv = model(add_gaussian_noise(xv, noise_std))
        total += F.mse_loss(pv, xv, reduction="sum").item()
        count += xv.numel()
    return total / count if count else float("nan")


@torch.no_grad()
def _save_val_previews(model, loader, device, noise_std, sample_dir, epoch):
    """Save noisy / denoised / clean grids for up to the first 16 images of ``loader``."""
    first = next(iter(loader), None)
    if first is None:  # empty validation set: nothing to visualise
        return
    xv = to_device(first[0], device)[:16]
    xv_noisy = add_gaussian_noise(xv, noise_std)
    grids = {
        "noisy": vutils.make_grid(xv_noisy, nrow=8, normalize=True, value_range=(-1, 1)),
        "denoised": denoise_grid(model, xv_noisy, nrow=8),
        "clean": vutils.make_grid(xv, nrow=8, normalize=True, value_range=(-1, 1)),
    }
    for kind, grid in grids.items():
        vutils.save_image(grid, sample_dir / f"val_{kind}_epoch{epoch:03d}.png")


def train(args):
    """Train an XUnet as a Gaussian denoiser, checkpointing and saving previews every epoch.

    ``args`` is the argparse namespace built in the ``__main__`` block of this cell (fields:
    data, out_dir, img_size, epochs, batch_size, workers, lr, dim, noise_std, sample_every,
    clip_grad, seed, cpu, no_amp).

    Writes to ``args.out_dir``:
    - one checkpoint per epoch;
    - ``best_model.pt`` (the model state with the lowest validation loss so far);
    - preview PNGs under ``samples/``.
    """
    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")

    out_dir = Path(args.out_dir)
    sample_dir = out_dir / "samples"
    sample_dir.mkdir(parents=True, exist_ok=True)

    # data
    # NOTE(review): get_cifar10 is not defined anywhere in this notebook. It presumably returns
    # (train_set, test_set) of (image, label) pairs scaled to [-1, 1] -- define/confirm before running.
    train_set, test_set = get_cifar10(args.data, img_size=args.img_size)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=True, drop_last=True)
    val_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.workers, pin_memory=True, drop_last=False)

    # model: XUnet (dim is the base channel count, dim_mults define the UNet levels)
    model = XUnet(
        dim=args.dim,                 # base feature dimension
        channels=3,                   # RGB
        dim_mults=(1, 2, 4, 8),       # 4-level UNet
    ).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=1e-4)
    scaler = torch.cuda.amp.GradScaler(enabled=not args.no_amp and device.type == "cuda")

    global_step = 0
    best_val = float("inf")

    for epoch in range(1, args.epochs + 1):
        train_loss, global_step = _train_one_epoch(
            model, train_loader, opt, scaler, device, args, epoch, global_step, sample_dir)
        val_loss = _mean_val_loss(model, val_loader, device, args.noise_std)

        # save checkpoint
        torch.save({
            "epoch": epoch,
            "model": model.state_dict(),
            "opt": opt.state_dict(),
            "scaler": scaler.state_dict(),
            "args": vars(args),
            "val_loss": val_loss,
        }, out_dir / f"epoch{epoch:03d}_val{val_loss:.4f}.pt")

        # track best
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), out_dir / "best_model.pt")

        print(f"Epoch {epoch} | train {train_loss:.4f} | val {val_loss:.4f} | best {best_val:.4f}")

        _save_val_previews(model, val_loader, device, args.noise_std, sample_dir, epoch)


if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(description="Train X-UNet (lucidrains) for Gaussian denoising")
    parser.add_argument("--data", type=str, default="./data", help="dataset root")
    parser.add_argument("--out-dir", type=str, default="./runs/xunet_denoise", help="output directory")
    parser.add_argument("--img-size", type=int, default=32, help="image side length")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--dim", type=int, default=64, help="X-UNet base channels")
    parser.add_argument("--noise-std", type=float, default=0.2, help="Gaussian noise std (in [-1,1] scale)")
    parser.add_argument("--sample-every", type=int, default=500, help="steps between training previews")
    parser.add_argument("--clip-grad", type=float, default=None, help="e.g., 1.0 for grad clipping")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--cpu", action="store_true", help="force CPU")
    parser.add_argument("--no-amp", action="store_true", help="disable mixed precision")
    # In a Jupyter kernel __name__ is also "__main__", but sys.argv holds the kernel's own
    # launch arguments (typically "-f <connection file>"). parse_args() would reject those
    # with SystemExit, so use the defaults there and parse the real command line otherwise.
    in_notebook = "ipykernel" in sys.modules
    args = parser.parse_args([] if in_notebook else None)
    train(args)
