In [None]:
!pip install torch-fidelity

Collecting torch-fidelity
  Downloading torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->torch-fidelity)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->torch-fidelity)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->torch-fidelity)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->torch-fidelity)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->torch-fidelity)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->torch-fidel

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch_fidelity import calculate_metrics

In [None]:
!pip install gdown
import gdown
import zipfile

file_id = "1cJyPQzVOzsCZQctNBuHCqxHnOY7v7UiA"
output = "dataset.zip"
gdown.download(f"https://drive.google.com/uc?id={file_id}", output, quiet = False)

with zipfile.ZipFile("dataset.zip", "r") as zip_ref:
  zip_ref.extractall("/content/dataset")



Downloading...
From (original): https://drive.google.com/uc?id=1cJyPQzVOzsCZQctNBuHCqxHnOY7v7UiA
From (redirected): https://drive.google.com/uc?id=1cJyPQzVOzsCZQctNBuHCqxHnOY7v7UiA&confirm=t&uuid=c3609420-c9d5-48b7-8078-ac84e41c92cb
To: /content/dataset.zip
100%|██████████| 304M/304M [00:08<00:00, 34.5MB/s]


In [None]:
data_dir = "/content/dataset/Samples"
# os.listdir returns entries in arbitrary order; sort so the image order (and any
# prefix slice of images_tensor taken later) is reproducible across runs/machines.
npy_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.npy'))
if not npy_files:
    raise FileNotFoundError(f"No .npy files found in {data_dir}")

In [None]:
# Load every sample file and convert it to a float32 tensor, preserving file order.
images = [
    torch.from_numpy(np.load(os.path.join(data_dir, file_name))).float()
    for file_name in npy_files
]

In [None]:
# Stack the per-file tensors along a new leading (sample) dimension and report the value range.
images_tensor = torch.stack(images)
pixel_min = images_tensor.min()
pixel_max = images_tensor.max()
print(f"Dataset loaded. Shape: {images_tensor.shape}, Min: {pixel_min}, Max: {pixel_max}")

Dataset loaded. Shape: torch.Size([10000, 1, 150, 150]), Min: 0.0, Max: 1.0


In [None]:
# Single-tensor dataset: every item is a 1-tuple (image,), so each batch arrives as [images].
dataset = TensorDataset(images_tensor)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class Diffusion(nn.Module):
    """Linear-beta DDPM noise schedule with closed-form forward noising and DDIM sampling.

    Args:
        T: number of diffusion timesteps.
        beta_starts: first value of the linear beta schedule.
        beta_end: last value of the linear beta schedule.
        device: device for the schedule buffers; None means CUDA when available, else CPU.
    """

    def __init__(self, T=1000, beta_starts=1e-4, beta_end=0.02, device=None):
        super().__init__()
        if device is None:
            # Same rule as the notebook's global `device`, resolved here so the class does
            # not depend on a module-level variable existing when it is defined.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.T = T
        self.device = device
        betas = torch.linspace(beta_starts, beta_end, T).to(device)
        alphas = 1 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        # Registered as buffers so they move with .to(...) and are stored in state_dict.
        self.register_buffer('sqrt_alpha_bars', torch.sqrt(alpha_bars))
        self.register_buffer('sqrt_one_minus_alpha_bars', torch.sqrt(1 - alpha_bars))

    def q_sample(self, x0, t, noise=None):
        """Draw x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise.

        Args:
            x0: clean batch; its first dimension is the batch dimension.
            t: integer timestep per batch element, shape (batch,).
            noise: optional noise shaped like x0; sampled from N(0, I) when omitted.

        Returns:
            Noised batch with the same shape as x0.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        # (batch,) -> (batch, 1, ..., 1) so each per-sample coefficient broadcasts over x0.
        coef_shape = (x0.shape[0],) + (1,) * (x0.dim() - 1)
        sqrt_alpha_bar = self.sqrt_alpha_bars[t].view(coef_shape)
        sqrt_one_minus_alpha_bar = self.sqrt_one_minus_alpha_bars[t].view(coef_shape)
        return sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise

    def ddim_sample(self, model, n_samples, image_size, channels, steps=50, batch_size=None):
        """Generate images with deterministic DDIM (eta = 0).

        Args:
            model: noise-prediction network called as model(x, t); switched to eval mode.
            n_samples: number of images to generate.
            image_size: height and width of the square images.
            channels: number of image channels.
            steps: number of denoising steps, evenly spaced over [0, T - 1].
            batch_size: if given, run at most this many images through each reverse pass
                to bound peak memory; None generates all n_samples in a single pass.

        Returns:
            Tensor of shape (n_samples, channels, image_size, image_size) on the
            schedule's device.

        Raises:
            ValueError: if batch_size is given and is smaller than 1.
        """
        if batch_size is not None and batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        model.eval()
        if batch_size is None or batch_size >= n_samples:
            return self._ddim_sample_batch(model, n_samples, image_size, channels, steps)
        chunks = [
            self._ddim_sample_batch(model, min(batch_size, n_samples - start), image_size, channels, steps)
            for start in range(0, n_samples, batch_size)
        ]
        return torch.cat(chunks, dim=0)

    @torch.no_grad()
    def _ddim_sample_batch(self, model, n, image_size, channels, steps):
        """Run one full DDIM reverse trajectory for a batch of n images."""
        # Use the buffers' device so sampling keeps working after diffusion.to(other_device).
        device = self.sqrt_alpha_bars.device
        timesteps = torch.linspace(0, self.T - 1, steps, dtype=torch.long, device=device)
        x = torch.randn(n, channels, image_size, image_size, device=device)

        for i in range(steps - 1, -1, -1):
            t = timesteps[i].repeat(n)
            predicted_noise = model(x, t)
            # Reshape (n,) -> (n, 1, 1, 1): a bare (n,) vector would broadcast against the
            # image width instead of the batch dimension.
            alpha_bar_t = (self.sqrt_alpha_bars[t] ** 2).view(n, 1, 1, 1)
            x0_pred = (x - torch.sqrt(1 - alpha_bar_t) * predicted_noise) / torch.sqrt(alpha_bar_t)
            if i == 0:
                x = x0_pred
            else:
                alpha_bar_next = self.sqrt_alpha_bars[timesteps[i - 1]] ** 2
                # eta = 0: no stochastic term; the predicted noise is re-applied at the next level.
                x = torch.sqrt(alpha_bar_next) * x0_pred + torch.sqrt(1 - alpha_bar_next) * predicted_noise
        return x

In [None]:
class SinusoidalEmbedding(nn.Module):
    """Sinusoidal embedding of scalar diffusion timesteps.

    Args:
        dim: output width. The first dim // 2 columns are sines and the next dim // 2
            are cosines; an odd dim gets one trailing zero column so the output is
            always exactly `dim` wide.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed timesteps t of shape (batch,) into a float tensor of shape (batch, dim)."""
        half_dim = self.dim // 2
        # Geometric frequencies 10000^(-k / half_dim) for k = 0 .. half_dim - 1.
        scale = np.log(10000) / half_dim if half_dim > 0 else 0.0
        freqs = torch.exp(-torch.arange(half_dim, device=t.device).float() * scale)
        angles = t[:, None].float() * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if self.dim % 2 == 1:
            # Pad so layers sized to `dim` (e.g. nn.Linear(dim, ...)) still accept the output.
            emb = F.pad(emb, (0, 1))
        return emb

In [None]:
class UNetBlock(nn.Module):
    """Two 3x3 conv -> GroupNorm -> ReLU stages conditioned on a time embedding.

    The time embedding is projected to one value per output channel and added to the
    first convolution's output before its normalization.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced (must be divisible by the 8 GroupNorm groups).
        time_dim: width of the time embedding passed to forward().
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.time_mlp = nn.Linear(in_features=time_dim, out_features=out_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_channels)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_channels)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        """Apply the block to feature map x using the time embedding t of shape (batch, time_dim)."""
        # (batch, out_channels) -> (batch, out_channels, 1, 1) so it acts as a per-channel bias.
        channel_bias = self.relu(self.time_mlp(t)).unsqueeze(-1).unsqueeze(-1)
        hidden = self.relu(self.norm1(self.conv1(x) + channel_bias))
        return self.relu(self.norm2(self.conv2(hidden)))

In [None]:
class UNet(nn.Module):
    """Two-level U-Net that predicts the noise contained in an image at timestep t.

    Args:
        time_dim: width of the sinusoidal timestep embedding and of its MLP.
        in_channels: channels of the input image and of the predicted noise
            (1 for the grayscale lensing images).
    """

    def __init__(self, time_dim=128, in_channels=1):
        super().__init__()
        self.time_emb = SinusoidalEmbedding(time_dim)
        self.time_mlp = nn.Sequential(nn.Linear(time_dim, time_dim), nn.ReLU())

        # Downsampling (sizes below are for 150x150 inputs; MaxPool2d(2) floors odd sizes)
        self.down1 = UNetBlock(in_channels, 64, time_dim)
        self.pool1 = nn.MaxPool2d(2)  # 150 -> 75
        self.down2 = UNetBlock(64, 128, time_dim)
        self.pool2 = nn.MaxPool2d(2)  # 75 -> 37

        # Bottleneck
        self.bottleneck = UNetBlock(128, 256, time_dim)

        # Upsampling; forward() resizes each upsampled map to its skip connection's size
        self.upconv1 = nn.ConvTranspose2d(256, 128, 2, stride=2)  # 37 -> 74, then resized to 75
        self.up1 = UNetBlock(256, 128, time_dim)  # 128 upsampled + 128 skip channels
        self.upconv2 = nn.ConvTranspose2d(128, 64, 2, stride=2)  # 75 -> 150
        self.up2 = UNetBlock(128, 64, time_dim)  # 64 upsampled + 64 skip channels

        # Output: 1x1 conv back to the image's channel count
        self.out = nn.Conv2d(64, in_channels, 1)

    def forward(self, x, t):
        """Predict the noise for images x (batch, in_channels, H, W) at integer timesteps t (batch,)."""
        t_emb = self.time_mlp(self.time_emb(t))

        skip1 = self.down1(x, t_emb)
        skip2 = self.down2(self.pool1(skip1), t_emb)
        bottom = self.bottleneck(self.pool2(skip2), t_emb)

        up = self.upconv1(bottom)
        # Pooling floors odd sizes, so match the skip map's size before concatenating.
        up = F.interpolate(up, size=skip2.shape[2:], mode='bilinear', align_corners=False)
        up = self.up1(torch.cat([up, skip2], dim=1), t_emb)

        up = self.upconv2(up)
        up = F.interpolate(up, size=skip1.shape[2:], mode='bilinear', align_corners=False)
        up = self.up2(torch.cat([up, skip1], dim=1), t_emb)

        return self.out(up)

In [None]:
# Seed torch's global RNG so weight init, DataLoader shuffling and training noise are repeatable.
torch.manual_seed(42)
diffusion = Diffusion(device=device).to(device)
model = UNet().to(device)
optimizer = Adam(model.parameters(), lr=1e-4)
# torch.cuda.amp.GradScaler() is deprecated in favour of torch.amp.GradScaler("cuda").
# Loss scaling only matters for CUDA mixed precision, so it is disabled on CPU.
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))
num_epochs = 50

  scaler = GradScaler()


In [None]:
# torch.cuda.amp.autocast() is deprecated; torch.autocast is the device-generic replacement.
# Mixed precision is enabled only on CUDA, consistent with how the GradScaler is used.
use_amp = device.type == "cuda"
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for (x0,) in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        x0 = x0.to(device)
        # One random timestep per image; noise x0 to x_t in closed form.
        t = torch.randint(0, diffusion.T, (x0.shape[0],), device=device)
        noise = torch.randn_like(x0)
        xt = diffusion.q_sample(x0, t, noise)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            predicted_noise = model(xt, t)
            loss = F.mse_loss(predicted_noise, noise)
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.6f}")
    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), f'diffusion_model_epoch_{epoch+1}.pth')

  with autocast():
Epoch 1/50: 100%|██████████| 313/313 [01:24<00:00,  3.70it/s]


Epoch 1/50, Avg Loss: 0.061173


Epoch 2/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 2/50, Avg Loss: 0.008365


Epoch 3/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 3/50, Avg Loss: 0.006109


Epoch 4/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 4/50, Avg Loss: 0.004949


Epoch 5/50: 100%|██████████| 313/313 [01:22<00:00,  3.78it/s]


Epoch 5/50, Avg Loss: 0.004373


Epoch 6/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 6/50, Avg Loss: 0.004251


Epoch 7/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 7/50, Avg Loss: 0.003748


Epoch 8/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 8/50, Avg Loss: 0.003838


Epoch 9/50: 100%|██████████| 313/313 [01:22<00:00,  3.80it/s]


Epoch 9/50, Avg Loss: 0.003494


Epoch 10/50: 100%|██████████| 313/313 [01:22<00:00,  3.80it/s]


Epoch 10/50, Avg Loss: 0.003297


Epoch 11/50: 100%|██████████| 313/313 [01:22<00:00,  3.80it/s]


Epoch 11/50, Avg Loss: 0.003168


Epoch 12/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 12/50, Avg Loss: 0.003616


Epoch 13/50: 100%|██████████| 313/313 [01:22<00:00,  3.80it/s]


Epoch 13/50, Avg Loss: 0.003116


Epoch 14/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 14/50, Avg Loss: 0.002822


Epoch 15/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 15/50, Avg Loss: 0.002839


Epoch 16/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 16/50, Avg Loss: 0.002829


Epoch 17/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 17/50, Avg Loss: 0.002805


Epoch 18/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 18/50, Avg Loss: 0.002809


Epoch 19/50: 100%|██████████| 313/313 [01:22<00:00,  3.80it/s]


Epoch 19/50, Avg Loss: 0.002782


Epoch 20/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 20/50, Avg Loss: 0.002889


Epoch 21/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 21/50, Avg Loss: 0.002983


Epoch 22/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 22/50, Avg Loss: 0.002647


Epoch 23/50: 100%|██████████| 313/313 [01:22<00:00,  3.80it/s]


Epoch 23/50, Avg Loss: 0.002530


Epoch 24/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 24/50, Avg Loss: 0.002551


Epoch 25/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 25/50, Avg Loss: 0.002328


Epoch 26/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 26/50, Avg Loss: 0.002647


Epoch 27/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 27/50, Avg Loss: 0.002692


Epoch 28/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 28/50, Avg Loss: 0.002683


Epoch 29/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 29/50, Avg Loss: 0.002508


Epoch 30/50: 100%|██████████| 313/313 [01:22<00:00,  3.78it/s]


Epoch 30/50, Avg Loss: 0.002139


Epoch 31/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 31/50, Avg Loss: 0.002213


Epoch 32/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 32/50, Avg Loss: 0.002151


Epoch 33/50: 100%|██████████| 313/313 [01:22<00:00,  3.80it/s]


Epoch 33/50, Avg Loss: 0.002427


Epoch 34/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 34/50, Avg Loss: 0.002314


Epoch 35/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 35/50, Avg Loss: 0.002134


Epoch 36/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 36/50, Avg Loss: 0.002307


Epoch 37/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 37/50, Avg Loss: 0.001954


Epoch 38/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 38/50, Avg Loss: 0.002340


Epoch 39/50: 100%|██████████| 313/313 [01:22<00:00,  3.78it/s]


Epoch 39/50, Avg Loss: 0.002319


Epoch 40/50: 100%|██████████| 313/313 [01:22<00:00,  3.78it/s]


Epoch 40/50, Avg Loss: 0.001903


Epoch 41/50: 100%|██████████| 313/313 [01:22<00:00,  3.78it/s]


Epoch 41/50, Avg Loss: 0.002443


Epoch 42/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 42/50, Avg Loss: 0.001773


Epoch 43/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 43/50, Avg Loss: 0.002127


Epoch 44/50: 100%|██████████| 313/313 [01:22<00:00,  3.78it/s]


Epoch 44/50, Avg Loss: 0.002165


Epoch 45/50: 100%|██████████| 313/313 [01:22<00:00,  3.78it/s]


Epoch 45/50, Avg Loss: 0.002119


Epoch 46/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 46/50, Avg Loss: 0.001924


Epoch 47/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 47/50, Avg Loss: 0.002269


Epoch 48/50: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Epoch 48/50, Avg Loss: 0.002248


Epoch 49/50: 100%|██████████| 313/313 [01:22<00:00,  3.80it/s]


Epoch 49/50, Avg Loss: 0.002209


Epoch 50/50: 100%|██████████| 313/313 [01:22<00:00,  3.78it/s]

Epoch 50/50, Avg Loss: 0.002010





In [None]:
# Sampling all images in one reverse pass exhausted GPU memory, so generate in
# fixed-size chunks and collect each chunk on the CPU.
N_GENERATED = 1000
SAMPLE_BATCH = 50
if device.type == "cuda":
    torch.cuda.empty_cache()  # release cached blocks left over from training
generated_chunks = []
for start in range(0, N_GENERATED, SAMPLE_BATCH):
    n_chunk = min(SAMPLE_BATCH, N_GENERATED - start)
    chunk = diffusion.ddim_sample(model, n_samples=n_chunk, image_size=150, channels=1, steps=50)
    generated_chunks.append(chunk.cpu())
generated_images = torch.clamp(torch.cat(generated_chunks, dim=0), 0, 1)

OutOfMemoryError: CUDA out of memory. Tried to allocate 5.37 GiB. GPU 0 has a total capacity of 14.74 GiB of which 916.12 MiB is free. Process 5293 has 13.84 GiB memory in use. Of the allocated memory 11.57 GiB is allocated by PyTorch, and 2.14 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
class _UInt8ImageDataset(torch.utils.data.Dataset):
    """Serve pre-converted images one at a time; torch-fidelity takes Dataset inputs, not raw tensors."""

    def __init__(self, images):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.images[index]


def _to_uint8_rgb(images):
    """Convert float images in [0, 1] shaped (N, 1, H, W) to uint8 (N, 3, H, W) on the CPU.

    torch-fidelity's Inception feature extractor expects uint8 pixels in [0, 255]; the
    single grayscale channel is replicated to the three channels Inception consumes.
    """
    return (images.detach().cpu().clamp(0, 1) * 255).round().to(torch.uint8).repeat(1, 3, 1, 1)


generated_images_rgb = _to_uint8_rgb(generated_images)
# Compare against an equally sized slice of the real data.
real_images_rgb = _to_uint8_rgb(images_tensor[:len(generated_images)])
metrics = calculate_metrics(
    input1=_UInt8ImageDataset(generated_images_rgb),
    input2=_UInt8ImageDataset(real_images_rgb),
    fid=True,
    cuda=torch.cuda.is_available(),
    isc=False,
    kid=False,
)
fid = metrics['frechet_inception_distance']
print(f"FID Score: {fid}")

In [None]:
# Show the first five generated samples side by side in grayscale.
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for idx, ax in enumerate(axes):
    sample = generated_images[idx].squeeze().cpu().numpy()
    ax.imshow(sample, cmap='gray')
    ax.axis('off')
fig.suptitle("Generated Strong Lensing Images")
plt.show()