Imports

In [None]:
import os
import zipfile
import glob
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
from unet_color import UNet

Import data

In [None]:
# Archive location on Google Drive and the folder it is unpacked into.
ZIP_PATH = "/content/drive/MyDrive/Fall_25/Machine_Learning/Assignments/fake-v2.zip"
EXTRACT_DIR = "/content/my_images"   # Extract in Colab workspace

# Unpack only once: an already-existing EXTRACT_DIR is taken to mean the
# archive was extracted on an earlier run.
if os.path.exists(EXTRACT_DIR):
    print("Images already extracted.")
else:
    print("Extracting image ZIP from Google Drive...")
    with zipfile.ZipFile(ZIP_PATH, mode='r') as archive:
        archive.extractall(EXTRACT_DIR)
    print("Extraction complete.")

# Load images
def load_images_from_folder(folder, size=(64, 64)):
    """Load every PNG/JPEG under ``folder`` into one float32 array.

    The folder is searched recursively. Each image is converted to RGB,
    resized to ``size`` with Lanczos resampling and scaled to [0, 1].
    Files that cannot be read are reported on stdout and skipped.

    Args:
        folder: Root directory to search.
        size: Target ``(width, height)`` handed to ``Image.resize``.
            Defaults to ``(64, 64)``.

    Returns:
        Array of shape ``(n_images, height, width, 3)`` and dtype float32.

    Raises:
        ValueError: If no image under ``folder`` could be loaded.
    """
    allowed = (".png", ".jpg", ".jpeg")
    candidates = glob.glob(os.path.join(folder, "**", "*.*"), recursive=True)
    # glob returns paths in arbitrary order; sort so the array layout is
    # reproducible from run to run.
    image_paths = sorted(p for p in candidates if p.lower().endswith(allowed))

    images = []
    for path in image_paths:
        try:
            # The context manager closes the underlying file as soon as the
            # pixels have been read, so thousands of handles are not left open.
            with Image.open(path) as src:
                img = src.convert("RGB").resize(size, Image.LANCZOS)
            images.append(np.array(img, dtype=np.float32) / 255.0)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole load.
            print(f"Skipping {path}: {e}")

    if not images:
        # np.stack([]) would raise an opaque "need at least one array" error.
        raise ValueError(f"No loadable images found under {folder!r}")
    return np.stack(images)

# Read every extracted image into the in-memory array that sample_batch draws from.
print("Loading images...")
trainX = load_images_from_folder(folder=EXTRACT_DIR)
print(f"Loaded {trainX.shape[0]} images.")

Extracting chunk_1.zip ...
Extracting chunk_3_a.zip ...
Extracting chunk_3_b.zip ...
Extracting chunk_4_a.zip ...
Extracting chunk_4_b.zip ...
Extracting chunk_5.zip ...
All zip files extracted.

Loading images...
Skipping images/12479.jpg: Image size (232748750 pixels) exceeds limit of 178956970 pixels, could be decompression bomb DOS attack.
Loaded 11999 images total.


Sampling

In [None]:
def sample_batch(batch_size, device, data=None):
    """Draw a random mini-batch of images without replacement.

    Args:
        batch_size: Number of images requested. If it exceeds the number of
            available images, all of them are returned (in random order).
        device: Torch device the batch is moved to.
        data: Optional NumPy array laid out as ``(N, H, W, C)``. When omitted,
            the module-level ``trainX`` array is used.

    Returns:
        Tensor of shape ``(min(batch_size, N), C, H, W)`` on ``device``.
    """
    source = trainX if data is None else data
    indices = torch.randperm(source.shape[0])[:batch_size]
    # Index with a NumPy array rather than a torch tensor so the lookup does
    # not depend on implicit tensor-to-array conversion.
    batch = torch.from_numpy(source[indices.numpy()])
    # Channels-last (N, H, W, C) -> channels-first (N, C, H, W).
    return batch.permute(0, 3, 1, 2).to(device)

Diffusion

In [None]:
class DiffusionModel:
    """Training and ancestral sampling for a denoising diffusion model.

    Implements Algorithm 1 (training) and Algorithm 2 (sampling) of
    "Denoising Diffusion Probabilistic Models" with a linear beta schedule.
    Timesteps are 1..T internally; the network is called with the zero-based
    index ``t - 1``.
    """

    def __init__(self, T: int, model: nn.Module, device: str):
        """Store the network and precompute the noise schedule.

        Args:
            T: Number of diffusion steps.
            model: Noise-prediction network, called as ``model(x, t_index)``.
            device: Device the network and schedule tensors live on.
        """
        self.T = T
        self.function_approximator = model.to(device)
        self.device = device

        # Linear schedule from 1e-4 to 0.02, plus the derived alpha terms.
        self.beta = torch.linspace(1e-4, 0.02, T, device=device)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

    def _noise_prediction_loss(self, x0):
        """Return the MSE between true and predicted noise for a batch ``x0``."""
        # Size everything from x0 itself: the batch can be smaller than the
        # requested batch size when the dataset has fewer images.
        n = x0.shape[0]
        t = torch.randint(1, self.T + 1, (n,), device=self.device,
                          dtype=torch.long)
        eps = torch.randn_like(x0)

        # Broadcast the per-sample scalar over the (C, H, W) dimensions.
        alpha_bar_t = self.alpha_bar[t - 1].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * eps
        eps_predicted = self.function_approximator(x_t, t - 1)
        return nn.functional.mse_loss(eps_predicted, eps)

    def training(self, batch_size, optimizer):
        """
        Algorithm 1 in Denoising Diffusion Probabilistic Models

        Draws one mini-batch via ``sample_batch``, takes a single optimizer
        step on the noise-prediction loss and returns that loss as a float.
        """

        x0 = sample_batch(batch_size, self.device)

        # Make sure dropout / batch-norm layers (if any) are in training mode.
        self.function_approximator.train()

        # Take one gradient descent step
        loss = self._noise_prediction_loss(x0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    @torch.no_grad()
    def sampling(self, n_samples=1, image_channels=3, img_size=(64, 64),  # change channels and resolution here
                 use_tqdm=True):
        """
        Algorithm 2 in Denoising Diffusion Probabilistic Models

        Starts from Gaussian noise and denoises for T steps.

        Args:
            n_samples: Number of images to generate.
            image_channels: Channel count of the generated images.
            img_size: ``(height, width)`` of the generated images.
            use_tqdm: Show a progress bar over the T denoising steps.

        Returns:
            Tensor of shape ``(n_samples, image_channels, *img_size)`` on
            ``self.device``.
        """

        x = torch.randn((n_samples, image_channels, img_size[0], img_size[1]),
                        device=self.device)
        progress_bar = tqdm if use_tqdm else lambda it: it

        # Sample in eval mode, then put the network back in whatever mode the
        # caller had it in.
        net = self.function_approximator
        was_training = net.training
        net.eval()
        try:
            for step in progress_bar(range(self.T, 0, -1)):
                # No noise is added on the final step (t == 1).
                z = torch.randn_like(x) if step > 1 else torch.zeros_like(x)
                idx = torch.full((n_samples,), step - 1, dtype=torch.long,
                                 device=self.device)

                beta_t = self.beta[idx].view(-1, 1, 1, 1)
                alpha_t = self.alpha[idx].view(-1, 1, 1, 1)
                alpha_bar_t = self.alpha_bar[idx].view(-1, 1, 1, 1)

                mean = 1 / torch.sqrt(alpha_t) * (x - ((1 - alpha_t) / torch.sqrt(
                    1 - alpha_bar_t)) * net(x, idx))
                sigma = torch.sqrt(beta_t)
                x = mean + sigma * z
        finally:
            net.train(was_training)
        return x

Training

In [None]:
if __name__ == "__main__":
    # Build the network, optimizer and diffusion wrapper.
    # NOTE(review): the recorded run of this cell stopped with a CUDA
    # out-of-memory error on a ~15 GiB GPU; if that recurs, try a smaller
    # batch_size or restart the runtime to release memory from earlier runs.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size = 64
    model = UNet()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
    diffusion = DiffusionModel(T=1000, model=model, device=device)

    # training
    print("Starting training...")
    log_every = 100  # refresh the displayed loss every N steps to keep tqdm cheap
    progress = tqdm(range(100_000))
    for step in progress:
        loss = diffusion.training(batch_size, optimizer)
        if step % log_every == 0:
            # Surface the loss so training progress is visible.
            progress.set_postfix(loss=f"{loss:.4f}")

    # results
    nb_images = 81
    samples = diffusion.sampling(n_samples=nb_images, use_tqdm=False)
    # One device-to-host copy for the whole batch; (N, C, H, W) -> (N, H, W, C).
    images = samples.permute(0, 2, 3, 1).cpu().numpy().clip(0, 1)

    # Smallest square grid that fits every sample (9 x 9 for 81 images).
    grid = int(np.ceil(np.sqrt(nb_images)))
    fig, axes = plt.subplots(grid, grid, figsize=(17, 17), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    for ax, image in zip(axes.flat, images):
        ax.imshow(image)

    os.makedirs("Imgs", exist_ok=True)
    fig.savefig("Imgs/samples.png")
    print("Generated images saved to Imgs/samples.png")

Starting training...


  0%|          | 0/150000 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 54.12 MiB is free. Process 9843 has 14.69 GiB memory in use. Of the allocated memory 13.86 GiB is allocated by PyTorch, and 720.89 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Save Model

In [None]:
# Persist only the network weights (state dict), not the whole module object.
weights = model.state_dict()
torch.save(weights, "unet64_rgb.pth")
print("Model saved to unet64_rgb.pth")

Generate One Image

In [None]:
# ---- Draw a single sample from the trained model and display it ----
print("Generating a single sample...")
one_sample = diffusion.sampling(n_samples=1, use_tqdm=False)

# (1, 3, 64, 64) -> (64, 64, 3): drop the batch axis and move channels last,
# then clamp to the [0, 1] range imshow expects for float RGB data.
img = one_sample[0].permute(1, 2, 0).cpu().numpy().clip(0, 1)

fig, ax = plt.subplots(figsize=(4, 4))
ax.axis('off')
ax.imshow(img)
plt.show()