In [None]:
!pip install x-unet
!pip install scikit-learn

In [1]:
import random
import torch

def set_seed(seed: int = 0):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [2]:
set_seed()

In [None]:
import os

data_path = 'data'
png_folder = os.path.join(data_path, 'images')

checkpoints_path = 'checkpoints'
os.makedirs(checkpoints_path, exist_ok=True)

In [23]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image

class HighFrequencyDataset(Dataset):
    def __init__(self, images_path):
        self.images_path = images_path
        self.images = self.read_images()

    def read_images(self) -> torch.Tensor:
        image_file_paths = [f for f in os.listdir(self.images_path) if f.lower().endswith('.png')]
        images = [Image.open(os.path.join(self.images_path, img_path)) for img_path in image_file_paths]
        return self.transform_images(images)

    def transform_images(self, images: list[Image]) -> torch.Tensor:
        transform = T.Compose([
            T.Resize((256, 256)),
            T.ToTensor(),
        ])
        return torch.stack([transform(image) for image in images])

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx) -> torch.Tensor:
        return self.images[idx]


In [24]:
dataset = HighFrequencyDataset(png_folder)

In [25]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

train_dataset, val_dataset = train_test_split(dataset, test_size=0.2, random_state=0)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [26]:
import torch
from x_unet import XUnet

# Create an instance of the XUnet model
unet = XUnet(
    dim = 64,
    channels = 1,
    dim_mults = (1, 2, 4, 8),
    nested_unet_depths = (7, 4, 2, 1),     # nested unet depths, from unet-squared paper
    consolidate_upsample_fmaps = True,     # whether to consolidate outputs from all upsample blocks, used in unet-squared paper
)

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = unet.to(device)

In [28]:
def add_gaussian_noise(x: torch.Tensor, std: float):
    # x in [-1, 1]; scale to [-1,1] noise as well
    noise = torch.randn_like(x) * std
    return (x + noise).clamp(-1, 1)

In [29]:
LEARNING_RATE = 1e-4
BETAS = (0.9, 0.99)
WEIGHT_DECAY = 1e-4
MIXED_PRECISION = True

# Define optimizer
opt = torch.optim.AdamW(unet.parameters(), lr=LEARNING_RATE, betas=BETAS, weight_decay=WEIGHT_DECAY)

# Define scaler (so we can use mixed precision), otherwise, we'll experience vanishing gradients
scaler = torch.cuda.amp.GradScaler(enabled=MIXED_PRECISION and device.type == "cuda")

  scaler = torch.cuda.amp.GradScaler(enabled=MIXED_PRECISION and device.type == "cuda")


In [40]:
from tqdm import tqdm
import torch.nn.functional as F

NOISE_STD = 0.1
CLIP_GRAD = 1
SAMPLE_EVERY = 1000
EPOCHS = 10
MODEL_OUT_DIR = os.path.join(checkpoints_path, 'vanilla')

os.makedirs(MODEL_OUT_DIR, exist_ok=True)

global_step = 0

best_val = float("inf")

for epoch in range(1, EPOCHS + 1):
    unet.train()
    pbar = tqdm(iter(train_loader), desc=f"Epoch {epoch}/{EPOCHS}")
    running = 0.0

    for x in pbar:
        x = x.to(device)          # clean target in [-1,1]
        x_noisy = add_gaussian_noise(x, NOISE_STD)

        opt.zero_grad(set_to_none=True)

        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=scaler.is_enabled()):
            pred = unet(x_noisy)          # predict clean image directly
            loss = F.mse_loss(pred, x)     # DAE loss

        if scaler.is_enabled():
            scaler.scale(loss).backward()
            if CLIP_GRAD is not None:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(unet.parameters(), CLIP_GRAD)
            scaler.step(opt)
            scaler.update()
        else:
            loss.backward()
            if CLIP_GRAD is not None:
                torch.nn.utils.clip_grad_norm_(unet.parameters(), CLIP_GRAD)
            opt.step()

        running += loss.item()
        global_step += 1
        pbar.set_postfix(loss=f"{loss.item():.4f}")

        # sample preview
        # if global_step % args.sample_every == 0:
        #     with torch.no_grad():
        #         grid = denoise_grid(model, x_noisy[:min(16, x_noisy.size(0))])
        #     vutils.save_image(grid, out_dir / "samples" / f"train_step{global_step:07d}.png")

    train_loss = running / len(train_loader)

    # validation
    unet.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xv in iter(val_loader):
            xv = xv.to(device)
            xv_noisy = add_gaussian_noise(xv, NOISE_STD)
            pv = unet(xv_noisy)
            val_loss += F.mse_loss(pv, xv, reduction='mean').item()
    val_loss /= len(val_loader)

    # track best
    if val_loss < best_val:
        best_val = val_loss
        torch.save(unet.state_dict(), os.path.join(MODEL_OUT_DIR, "best_model.pt"))

    print(f"Epoch {epoch} | train {train_loss:.4f} | val {val_loss:.4f} | best {best_val:.4f}")

Epoch 1/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:20<00:00,  2.29it/s, loss=0.0008]


Epoch 1 | train 0.0011 | val 0.0011 | best 0.0011


Epoch 2/10:  39%|███████████████████████████████████████████▍                                                                   | 18/46 [00:08<00:13,  2.15it/s, loss=0.0008]


KeyboardInterrupt: 