<a href="https://colab.research.google.com/github/azhgh22/Comparative-analysis-of-Generative-models-on-CIFAR-10/blob/main/experiments/train_ddpm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Set Env**

In [None]:
%%capture
from google.colab import drive
drive.mount('/content/drive')

from google.colab import userdata
token = userdata.get('GITHUB_TOKEN')
user_name = userdata.get('GITHUB_USERNAME')
mail = userdata.get('GITHUB_MAIL')

!git config --global user.name "{user_name}"
!git config --global user.email "{mail}"
!git clone https://{token}@github.com/azhgh22/Comparative-analysis-of-Generative-models-on-CIFAR-10.git

# **Imports**

In [None]:
# Imports
import sys
import os
import torch
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.utils import make_grid

# Add the root directory of the cloned repository to the Python path
sys.path.append('/content/Comparative-analysis-of-Generative-models-on-CIFAR-10')

import importlib
import data.cifar10 as cifar10_module
importlib.reload(cifar10_module)
from data.cifar10 import load_cifar10



In [None]:
train_loader, _ = load_cifar10(batch_size=1024, normalize_inputs=True, pin_memory=True, num_workers=2)

In [None]:
print(f"Batch Size: {train_loader.batch_size}")
print(f"Num Workers: {train_loader.num_workers}")
print(f"Pin Memory: {train_loader.pin_memory}")

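# Tip: if num_workers is 0, try 2 or 4 so batches are prepared in parallel worker processes.
# Tip: pin_memory=True puts batches in page-locked host memory, which speeds up CPU-to-GPU copies
#      and lets .to(device, non_blocking=True) run them asynchronously.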
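
The two tips above can be illustrated with a minimal, self-contained loader. This is a sketch under stated assumptions: it uses random stand-in tensors instead of the repository's `load_cifar10`, and the helper name `iterate_batches` is hypothetical. It enables `pin_memory` only when the target device is a GPU and copies each batch with `non_blocking=True`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for CIFAR-10: 256 random 3x32x32 images scaled to [-1, 1]
dataset = TensorDataset(torch.rand(256, 3, 32, 32) * 2 - 1)


def iterate_batches(dataset, device, batch_size=64, num_workers=2):
    """Run one pass over `dataset`, moving each batch to `device`; returns the sample count."""
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,             # >0: worker processes prepare batches in parallel
        pin_memory=(device.type == "cuda"),  # page-locked buffers only pay off for GPU copies
    )
    n_seen = 0
    for (batch,) in loader:
        # With a pinned source tensor, non_blocking=True lets the copy run asynchronously
        batch = batch.to(device, non_blocking=True)
        n_seen += batch.size(0)
    return n_seen


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

In the notebook, `iterate_batches(dataset, device)` would use two workers, matching `num_workers=2` in the `load_cifar10` call.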

In [None]:
import matplotlib.pyplot as plt

def show_img(img):
    """Display one image (C, H, W), or the first image of a batch, normalized to [-1, 1]."""
    if img.dim() == 4:
        img = img[0]
    img = img.detach().cpu()
    img = ((img + 1) / 2).clamp(0, 1)  # denormalize [-1, 1] -> [0, 1]
    img = img.permute(1, 2, 0)         # CHW -> HWC for plotting
    plt.figure(figsize=(4, 4))
    plt.imshow(img, interpolation='nearest')
    plt.axis('off')
    plt.show()

In [None]:
def show_images(images, title="Images", n_col=4):
    """Display a batch of [-1, 1]-normalized images (N, C, H, W) in a grid with n_col columns."""
    images = images.detach().cpu()
    # Denormalize from [-1, 1] to [0, 1]
    images = ((images + 1) / 2).clamp(0, 1)

    n = len(images)
    n_row = (n + n_col - 1) // n_col  # ceil(n / n_col)

    # squeeze=False always returns a 2-D array of axes, even for a single row or column
    fig, axes = plt.subplots(n_row, n_col, figsize=(n_col * 2, n_row * 2), squeeze=False)

    for i, ax in enumerate(axes.flatten()):
        if i < n:
            ax.imshow(images[i].permute(1, 2, 0).numpy())
        ax.axis('off')

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

In [None]:
train_dataset = train_loader.dataset
img = train_dataset[0][0]
show_img(img)

In [None]:
from models.scorebased_models.ddpm import create_ddpm
from utils.get_device import get_device
from train.train import Train
from utils.checkpointer import Checkpointer

device = get_device()

model = create_ddpm(image_size=32, image_channels=3, timesteps=1000).to(device)
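
`timesteps=1000` sets the length T of the diffusion chain. This notebook does not show what `create_ddpm` does internally, so the following is only a sketch of the standard DDPM forward process, assuming the linear β schedule from 1e-4 to 0.02 used in the original DDPM paper. Because the per-step Gaussian noise composes in closed form, $x_t$ can be drawn directly from $x_0$ as $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$ with $\bar\alpha_t = \prod_{s \le t}(1-\beta_s)$, and training regresses the added noise $\epsilon$. The helper names and the noise-predictor signature `eps_model(x_t, t)` are hypothetical.

```python
import torch
import torch.nn.functional as F

T = 1000  # matches timesteps=1000 above
# Assumed linear beta schedule (original DDPM setting); create_ddpm may differ
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)  # abar_t = prod_{s<=t} alpha_s


def q_sample(x0, t, noise=None):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(abar_t) x_0, (1 - abar_t) I) in one step."""
    if noise is None:
        noise = torch.randn_like(x0)
    abar = alpha_bars.to(x0.device)[t].view(-1, 1, 1, 1)  # broadcast over C, H, W
    return abar.sqrt() * x0 + (1 - abar).sqrt() * noise


def ddpm_loss(eps_model, x0):
    """Simplified DDPM objective: MSE between the true and the predicted noise."""
    t = torch.randint(0, T, (x0.size(0),), device=x0.device)  # one random step per image
    noise = torch.randn_like(x0)
    x_t = q_sample(x0, t, noise)
    return F.mse_loss(eps_model(x_t, t), noise)
```

At t = 0 the image is almost untouched ($\sqrt{1-\bar\alpha_0} = \sqrt{10^{-4}} = 0.01$), while $\bar\alpha_{T-1}$ is close to zero, so $x_{T-1}$ is nearly pure noise.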

In [None]:
checkpoint_dir = "/content/drive/MyDrive/checkpoints_final/ddpm"
checkpointer = Checkpointer(checkpoint_dir, "ddpm", 10, False)
train = Train(model, 200, train_loader, checkpointer, device)
train.load_checkpoint()

In [None]:
train.train()

In [None]:
print(model.model.conv_in.weight.grad.mean())
print(model.model.conv_in.weight.grad.std())
print(model.model.conv_in.weight.grad.abs().mean())
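
The cell above assumes `conv_in.weight.grad` is populated. After `train.train()` returns it may be `None`, for example if the training loop clears gradients with `optimizer.zero_grad()` (which sets them to `None` by default in recent PyTorch versions); whether the repository's `Train` does this is not shown here. A hedged sketch of a helper (hypothetical name `grad_stats`) that reports the same three statistics for every parameter and skips missing gradients:

```python
import torch
import torch.nn as nn


def grad_stats(module):
    """Return {param_name: (mean, std, mean_abs)} for parameters that currently have a gradient."""
    stats = {}
    for name, p in module.named_parameters():
        if p.grad is None:  # no backward pass yet, or gradients were reset to None
            continue
        g = p.grad.detach()
        stats[name] = (g.mean().item(), g.std().item(), g.abs().mean().item())
    return stats
```

Calling `grad_stats(model.model)` right after a backward pass, before the optimizer clears gradients, gives a per-layer view instead of only `conv_in`.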

In [None]:
gen_img = model.sample(16)
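
`model.sample(16)` presumably runs the reverse chain from pure noise back to images; its implementation is not shown in this notebook. A minimal sketch of standard DDPM ancestral sampling, assuming the same linear schedule and a hypothetical noise predictor `eps_model(x_t, t)`, is below. It uses the variance choice $\sigma_t^2 = \beta_t$ and runs under `torch.no_grad()`, since sampling needs no autograd graph and building one would waste GPU memory.

```python
import torch

T = 1000
# Assumed linear beta schedule; the repository's values may differ
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)


@torch.no_grad()  # no autograd graph during sampling
def ddpm_sample(eps_model, n, shape=(3, 32, 32), device="cpu"):
    x = torch.randn(n, *shape, device=device)  # x_T ~ N(0, I)
    for t in reversed(range(T)):
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
        eps = eps_model(x, t_batch)
        # Mean of p(x_{t-1} | x_t): (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
        mean = (x - betas[t] / (1 - alpha_bars[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            x = mean + betas[t].sqrt() * torch.randn_like(x)  # sigma_t^2 = beta_t
        else:
            x = mean  # no noise is added at the final step
    return x.clamp(-1, 1)
```

With T = 1000, generating one batch costs 1000 forward passes of the network, so sampling 16 images takes far longer than a single training step.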

In [None]:
show_images(gen_img)

### Check System RAM and GPU RAM Usage

In [None]:
print('--- System RAM Usage ---')
!free -h

print('\n--- GPU RAM Usage ---')
!nvidia-smi

--- System RAM Usage ---
               total        used        free      shared  buff/cache   available
Mem:            12Gi       2.3Gi       1.6Gi       161Mi       8.8Gi       9.9Gi
Swap:             0B          0B          0B

--- GPU RAM Usage ---
Fri Jan 30 23:49:04 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   70C    P0             30W /   70W |   11360MiB /  15360MiB | 
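
As an alternative to the shell commands in the cell above, the same figures can be read from Python, which makes them easy to log during training. This sketch assumes a Linux host such as Colab (it reads `/proc/meminfo`) and uses PyTorch's `torch.cuda.mem_get_info`, `torch.cuda.memory_allocated`, and `torch.cuda.memory_reserved`; the helper name `memory_report` is hypothetical.

```python
import torch


def memory_report():
    """Collect system-RAM and GPU-memory figures in GiB without shelling out."""
    gib = 1024 ** 3
    report = {}
    # System RAM from /proc/meminfo (Linux only); values there are in kB
    try:
        with open("/proc/meminfo") as f:
            meminfo = {line.split(":")[0]: int(line.split()[1]) for line in f}
        report["ram_total_gib"] = meminfo["MemTotal"] * 1024 / gib
        report["ram_available_gib"] = meminfo["MemAvailable"] * 1024 / gib
    except (OSError, KeyError, ValueError, IndexError):
        pass  # not Linux, or an unexpected format: skip the RAM figures
    if torch.cuda.is_available():
        free, total = torch.cuda.mem_get_info()  # device-wide, includes other processes
        report["gpu_total_gib"] = total / gib
        report["gpu_free_gib"] = free / gib
        # Memory held by this process's PyTorch caching allocator
        report["torch_allocated_gib"] = torch.cuda.memory_allocated() / gib
        report["torch_reserved_gib"] = torch.cuda.memory_reserved() / gib
    return report
```

The gap between `torch_reserved_gib` and `torch_allocated_gib` is memory that PyTorch's caching allocator holds but no tensor currently uses.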

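The output above shows which resource is under pressure. In this snapshot, the system still has 9.9 GiB of RAM available, while the T4 has 11360 MiB of its 15360 MiB (about 74%) in use, so GPU memory is the closer of the two to its limit. Here are some common solutions: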

1.  **Reduce Batch Size**: If your batch size is large (e.g., 1024, as currently set for `train_loader`), try reducing it. A smaller batch size needs less activation memory per iteration.

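2.  **Delete Unused Variables**: Explicitly `del` variables or tensors that are no longer needed, especially large ones. Deleting the last reference to an object frees it right away, but objects caught in reference cycles are only reclaimed when Python's cyclic garbage collector runs, so calling `gc.collect()` afterwards can help.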

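3.  **Clear CUDA Cache**: PyTorch's caching allocator keeps freed GPU blocks reserved for reuse, and `nvidia-smi` reports them as used. `torch.cuda.empty_cache()` returns these unused cached blocks to the driver, which can sometimes help with fragmentation. It cannot free memory still held by live tensors, so delete those references first (item 2).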

4.  **Restart Runtime**: This is the most straightforward way to free up all allocated memory. If you restart the runtime, you'll need to re-run your setup cells.

5.  **Optimize Data Loading**: Ensure your `load_cifar10` function or any data transformations aren't loading the entire dataset into CPU memory unnecessarily or creating many copies.

6.  **Gradient Accumulation (for GPU memory)**: If you need a larger 'effective' batch size but are constrained by GPU memory, you can use gradient accumulation by performing forward/backward passes on smaller batches and only updating weights after several steps.
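
Items 2, 3, and 6 can be sketched in plain PyTorch. This is a generic sketch, not the repository's `Train` class (whose internals are not shown); the helper names `free_memory` and `train_epoch_accumulated` are hypothetical, and `compute_loss(batch)` stands in for whatever loss the model uses (for this notebook, the DDPM noise-prediction loss).

```python
import gc

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def free_memory():
    """Items 2-3: reclaim unreachable objects, then hand cached CUDA blocks back to the driver."""
    gc.collect()                  # frees objects kept alive only by reference cycles
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # releases unused cached blocks; live tensors are untouched


def train_epoch_accumulated(loader, compute_loss, optimizer, accum_steps):
    """Item 6: one optimizer step per `accum_steps` micro-batches; returns the step count.

    Effective batch size = loader batch size * accum_steps. Gradients from trailing
    micro-batches that do not fill a complete group are not applied; they are cleared
    by the zero_grad() at the start of the next call.
    """
    optimizer.zero_grad()
    n_steps = 0
    for i, batch in enumerate(loader):
        # Dividing by accum_steps makes the summed gradients equal the mean over the
        # effective batch (for a mean-reduced loss and equal-sized micro-batches)
        loss = compute_loss(batch) / accum_steps
        loss.backward()           # .grad accumulates across backward() calls
        if (i + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            n_steps += 1
    return n_steps
```

For this notebook's setup, a loader with `batch_size=256` and `accum_steps=4` would keep an effective batch of 1024 while holding activations for only 256 images at a time.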