In [1]:
import torch
from torchvision import datasets, transforms
from pathlib import Path
import os
import sys

In [2]:
# utils_path = os.path.join(os.getcwd(), "winnet")
# utils_path

In [3]:
# sys.path.append(utils_path)

In [4]:
from models.unet import UNet
from models.lifting_denoiser import LiftingDenoiser

In [5]:
init_params = {
    "UNet": {
        "base_channels": 16,
        "channel_multipliers": (1, 2, 4, 8),
        "in_channels": 3 # 1 for Fashion-MNIST, 3 for CIFAR-10
    },
    "LiftingDenoiser": {
        "input_channels": 3,
        "coarse_channels": 3,
        "hidden_channels": 128,
        "num_lifting_steps": 4,
        "lifting_type": "revnet",
        "detail_denoiser": "clista",
        "split_merge_type": "haar_redundant",
        "do_convert_t_to_sigma": True,
        # "split_merge_patch_size": 4,
    },
    "GaussianDiffusion": {
        "image_size": 32, # 28 for raw Fashion-MNIST, 32 after padding/resize
        "timesteps": 1000, # forward diffusion steps
        "sampling_timesteps": 10, # DDIM steps at inference
        "objective": "pred_x0" # or 'pred_noise' or 'pred_v'
    },
    "Trainer": {
        "folder": './cifar_images',
        "results_folder": './results',
        "train_batch_size": 128,
        "train_lr": 1e-6,
        "train_num_steps": 2,
        "ema_update_every": 10,
        "num_samples": 4, 
        "calculate_fid": True, 
        "amp": True,
        "save_and_sample_every": 1,
        "save_best_and_latest_only": True,
        "lr_scheduler_type":"cosine",        # or "step", "exponential", "lambda", None
        "lr_scheduler_kwargs": {"eta_min": 1e-7},
    },
}

In [6]:
use_unet = False

In [7]:
if use_unet:
    model = UNet(**init_params["UNet"])
else:
    model = LiftingDenoiser(**init_params["LiftingDenoiser"])

In [8]:
param_count = sum(p.numel() for p in model.parameters())
param_count

1481513

In [9]:
from diffusion.diffusion import GaussianDiffusion

In [10]:
diffusion_model = GaussianDiffusion(
    model,
    **init_params["GaussianDiffusion"]
)

Updated `sigmas` tensor successfully.


The cell below only needs to be run once; subsequent runs will use the downloaded images.

In [11]:
root      = './data'
save_root = Path('./fashion_images')
save_root.mkdir(parents=True, exist_ok=True)

fm_train = datasets.FashionMNIST(
    root, 
    train=True, 
    download=True, 
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    ])
        
)
for idx, (tensor, _) in enumerate(fm_train):
    # convert 0-1 tensor → PIL L and (optionally) pad to 32×32
    img = transforms.ToPILImage(mode='RGB')(tensor)
    img = transforms.Resize(32, antialias=True)(img)   # keep square, upsample
    img.save(save_root / f'{idx:06}.png')

In [12]:
# root      = './data'
# save_root = Path('./cifar_images')
# save_root.mkdir(parents=True, exist_ok=True)

# cifar_train = datasets.CIFAR10(root, train=True, download=True)
# for idx, (img, _) in enumerate(cifar_train):
#     img.save(save_root / f'{idx:06}.png')   # already 32×32 RGB

In [13]:
from diffusion.trainer import DiffusionTrainer

In [14]:
diffusion_trainer = DiffusionTrainer(
    diffusion_model,
    **init_params["Trainer"]
)

In [15]:
diffusion_trainer.train()

  0%|          | 0/2 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/10 [00:00<?, ?it/s]

Stacking Inception features for 50000 samples from the real dataset.


  0%|          | 0/391 [00:00<?, ?it/s]


KeyboardInterrupt

