In [None]:
# Colab-only: mount Google Drive so the /content/drive/... checkpoint and data
# paths used further down in this notebook resolve.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
!pip install pydicom
!pip install einops
!pip install diffusers
!pip install lpips
!pip install kornia

Collecting pydicom
  Downloading pydicom-3.0.1-py3-none-any.whl.metadata (9.4 kB)
Downloading pydicom-3.0.1-py3-none-any.whl (2.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pydicom
Successfully installed pydicom-3.0.1
Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=0.4.0->lpips)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=0.4.0->lpips)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=0.4.0->lpips)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=0.4.0->lpips)
  Downloading 

In [None]:
import os, glob, math, functools, random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import datasets, transforms
import math
import numpy as np
from PIL import Image
import pydicom
import importlib
from einops import rearrange, repeat
from torchvision import models
from tqdm.auto import tqdm
from diffusers import AutoencoderKL, DDPMScheduler
import lpips
from torch.utils.data import random_split
import kornia

In [None]:
# Preprocessing applied to every PIL image before it is fed to the VAE:
# force one channel, resize to 256, centre-crop 256x256, convert to a tensor,
# then normalise the single channel with mean 0.5 / std 0.5.
_preprocess_steps = [
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
transform = transforms.Compose(_preprocess_steps)

# AutoEncoder

# Diffusion

## Model

### UNET

In [None]:
class Swish(nn.Module):
    """Element-wise Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

class ResnetBlockTime(nn.Module):
    """Residual block of two 3x3 convolutions, modulated by a time embedding.

    NOTE(review): a second class named ``ResnetBlockTime`` with the same
    interface is defined later in this file and rebinds the name, so this
    earlier definition is shadowed once the later one has executed.

    Args:
        in_ch: channels of the input feature map.
        out_ch: channels of the output feature map.
        time_emb_dim: width of the time-embedding vector passed to ``forward``.
        groups: number of groups for both ``GroupNorm`` layers.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, groups=32):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        # One linear layer emits both the per-channel scale and shift.
        self.time_mlp = nn.Sequential(
            Swish(),
            nn.Linear(time_emb_dim, 2 * out_ch),
        )
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        # 1x1 projection on the skip path only when the channel count changes.
        if in_ch == out_ch:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x, t_emb):
        """Return ``block(x) + skip(x)``; ``t_emb`` is ``[B, time_emb_dim]``."""
        h = self.norm1(x)
        h = Swish()(h)
        h = self.conv1(h)

        modulation = self.time_mlp(t_emb)
        scale, shift = torch.chunk(modulation, 2, dim=1)
        # Broadcast the per-channel scale/shift over the spatial dimensions.
        h = h * (scale[:, :, None, None] + 1) + shift[:, :, None, None]

        h = self.norm2(h)
        h = Swish()(h)
        h = self.conv2(h)
        return h + self.res_conv(x)

class ResnetBlock(nn.Module):
    """Residual block without time conditioning: two (GroupNorm, Swish, 3x3 conv) stages.

    Args:
        in_ch: channels of the input feature map.
        out_ch: channels of the output feature map.
        groups: number of groups for both ``GroupNorm`` layers.
    """

    def __init__(self, in_ch, out_ch, groups=32):
        super().__init__()
        # first stage: in_ch -> out_ch
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

        # second stage: out_ch -> out_ch
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)

        # skip path: 1x1 projection only when the channel count changes
        self.res_conv = (
            nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, kernel_size=1)
        )

    def forward(self, x):
        """Return the two-stage transform of ``x`` plus the (projected) input."""
        h = self.norm1(x)
        h = self.act1(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = self.act2(h)
        h = self.conv2(h)
        return h + self.res_conv(x)

class Downsample(nn.Module):
    """Learned 2x spatial reduction: 4x4 convolution, stride 2, padding 1.

    The channel count ``ch`` is unchanged; for even H and W the output is
    ``H/2 x W/2``.
    """

    def __init__(self, ch):
        super().__init__()
        self.op = nn.Conv2d(ch, ch, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        return self.op(x)

class Upsample(nn.Module):
    """Learned 2x spatial enlargement: 4x4 transposed convolution, stride 2, padding 1.

    The channel count ``ch`` is unchanged; an ``H x W`` input becomes ``2H x 2W``.
    """

    def __init__(self, ch):
        super().__init__()
        self.op = nn.ConvTranspose2d(ch, ch, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        return self.op(x)

class Encoder(nn.Module):
    """Convolutional encoder that emits Gaussian latent statistics.

    Each entry of ``hidden_dims`` adds one stage (ResnetBlock, optional
    AttentionBlock, Downsample), so the input is downsampled once per entry.
    A (ResnetBlock, AttentionBlock, ResnetBlock) middle section follows, then
    a 3x3 convolution producing ``2 * latent_dim`` channels that are split
    into ``mu`` and ``logvar``.

    NOTE(review): ``attn_resolutions`` is documented as a list of spatial
    sizes, but the constructor tests membership of the stage's *channel width*
    in it. With the defaults (widths 64/128/256 vs. [16]) no down stage gets
    attention. Left as is, because changing it would change the set of
    parameters and break existing checkpoints - confirm the intent.

    Args:
        in_channels: channels of the input image.
        hidden_dims: channel width of each down stage.
        latent_dim: channels of ``mu`` / ``logvar``.
        attn_resolutions: see the note above.
        diffusion_embed_dim: width of the intermediate ``quant_conv`` output.
    """

    def __init__(
        self,
        in_channels: int = 1,
        hidden_dims: list[int] = [64, 128, 256],
        latent_dim: int = 16,
        attn_resolutions: list[int] = [16],      # apply attention when H=W equals these
        diffusion_embed_dim: int = 512
    ):
        super().__init__()
        self.latent_dim = latent_dim

        # Stem: lift the image to the first hidden width.
        self.conv_in = nn.Conv2d(in_channels, hidden_dims[0], kernel_size=3, padding=1)

        # Down path: one stage per hidden width.
        self.down_blocks = nn.ModuleList()
        prev_width = hidden_dims[0]
        for width in hidden_dims:
            stage = [ResnetBlock(prev_width, width)]
            if width in attn_resolutions:
                stage.append(AttentionBlock(width))
            stage.append(Downsample(width))
            self.down_blocks.append(nn.ModuleList(stage))
            prev_width = width

        # Middle section at the final width (no further downsampling).
        self.mid_block = nn.ModuleList([
            ResnetBlock(prev_width, prev_width),
            AttentionBlock(prev_width),
            ResnetBlock(prev_width, prev_width),
        ])

        # "double-z" head: mu and logvar stacked along the channel axis.
        self.conv_mu_logvar = nn.Conv2d(prev_width, 2 * latent_dim, kernel_size=3, padding=1)

        # 1x1 bridges: latent_dim -> diffusion_embed_dim -> latent_dim.
        self.quant_conv = nn.Conv2d(latent_dim, diffusion_embed_dim, kernel_size=1)
        self.post_quant_conv = nn.Conv2d(diffusion_embed_dim, latent_dim, kernel_size=1)

    def forward(self, x):
        """Return ``(mu, logvar, post_q)``.

        ``post_q`` is ``mu`` passed through ``quant_conv`` and then
        ``post_quant_conv``.
        """
        pipeline = [layer for stage in self.down_blocks for layer in stage]
        pipeline.extend(self.mid_block)

        h = self.conv_in(x)
        for layer in pipeline:
            h = layer(h)

        mu, logvar = self.conv_mu_logvar(h).chunk(2, dim=1)
        post_q = self.post_quant_conv(self.quant_conv(mu))
        return mu, logvar, post_q

class Decoder(nn.Module):
    """Convolutional decoder mapping a latent tensor to an image in [-1, 1].

    After an initial 3x3 convolution, every consecutive pair of widths in
    ``hidden_dims`` adds one stage (ResnetBlock, optional AttentionBlock,
    Upsample). The output head contains one more Upsample and ends in
    ``Tanh``, so the decoder upsamples ``len(hidden_dims)`` times in total.

    NOTE(review): as in ``Encoder``, ``attn_resolutions`` is matched against
    channel widths rather than spatial sizes; with the defaults no stage gets
    attention. Extra keyword arguments are accepted and ignored.

    Args:
        out_channels: channels of the produced image.
        hidden_dims: channel widths, widest first.
        latent_dim: channels of the input latent.
        attn_resolutions: see the note above.
    """

    def __init__(
        self,
        out_channels: int = 1,
        hidden_dims: list[int] = [256, 128, 64],
        latent_dim: int = 16,
        attn_resolutions: list[int] = [16], **kwargs
    ):
        super().__init__()

        # Lift the latent to the first (widest) hidden width.
        self.initial = nn.Conv2d(latent_dim, hidden_dims[0], kernel_size=3, padding=1)

        # Up path: one stage per transition between consecutive widths.
        self.up_blocks = nn.ModuleList()
        for prev_width, width in zip(hidden_dims[:-1], hidden_dims[1:]):
            stage = [ResnetBlock(prev_width, width)]
            if width in attn_resolutions:
                stage.append(AttentionBlock(width))
            stage.append(Upsample(width))
            self.up_blocks.append(nn.ModuleList(stage))

        # Output head at the last width: refine, upsample once more, project.
        final_width = hidden_dims[-1]
        self.conv_out = nn.Sequential(
            ResnetBlock(final_width, final_width),
            nn.GroupNorm(32, final_width),
            Swish(),
            Upsample(final_width),
            nn.Conv2d(final_width, out_channels, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Decode latent ``z`` into an image tensor."""
        h = self.initial(z)
        layers = [layer for stage in self.up_blocks for layer in stage]
        for layer in layers:
            h = layer(h)
        return self.conv_out(h)

class AttentionBlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Queries, keys, values and the output projection are 1x1 convolutions.
    The block returns only the projected attention output: it applies no
    normalisation and adds no residual connection.

    Args:
        ch: number of channels; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, ch, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        assert ch % num_heads == 0, "ch must be divisible by num_heads"

        self.to_q = nn.Conv2d(ch, ch, kernel_size=1)
        self.to_k = nn.Conv2d(ch, ch, kernel_size=1)
        self.to_v = nn.Conv2d(ch, ch, kernel_size=1)
        self.proj = nn.Conv2d(ch, ch, kernel_size=1)

    def forward(self, x):
        """Attend over all ``H*W`` positions of ``x`` (``[B, C, H, W]``)."""
        batch, channels, height, width = x.shape
        tokens = height * width
        heads = self.num_heads
        dim_head = channels // heads

        def split_heads(conv):
            # [B, C, H, W] -> [B, heads, dim_head, H*W]
            return conv(x).view(batch, heads, dim_head, tokens)

        q = split_heads(self.to_q)
        k = split_heads(self.to_k)
        v = split_heads(self.to_v)

        # Scaled dot-product scores between every pair of positions: [B, heads, N, N].
        scores = torch.einsum('b h d n, b h d m -> b h n m', q, k)
        scores = scores * (dim_head ** -0.5)
        weights = scores.softmax(dim=-1)

        # Weighted sum of values, back in channel-first layout: [B, heads, dim_head, N].
        mixed = torch.einsum('b h n m, b h d m -> b h d n', weights, v)
        mixed = mixed.contiguous().view(batch, channels, height, width)

        return self.proj(mixed)

class CrossAttentionBlock(nn.Module):
    """Multi-head cross-attention from a feature map to a conditioning map.

    Queries come from the group-normalised input ``x``; keys and values come
    from ``cond``. The projected attention output is added back to ``x``.

    Args:
        ch: channels of ``x``; must be divisible by ``num_heads``.
        cond_ch: channels of the conditioning tensor.
        num_heads: number of attention heads.
    """

    def __init__(self, ch, cond_ch, num_heads=4):
        super().__init__()
        self.norm = nn.GroupNorm(32, ch)
        self.num_heads = num_heads
        assert ch % num_heads == 0, "ch must divide evenly"
        self.to_q = nn.Conv2d(ch, ch, kernel_size=1)
        self.to_k = nn.Conv2d(cond_ch, ch, kernel_size=1)
        self.to_v = nn.Conv2d(cond_ch, ch, kernel_size=1)
        self.proj = nn.Conv2d(ch, ch, kernel_size=1)

    def forward(self, x, cond):
        """Return ``x + proj(attention(norm(x), cond, cond))``.

        ``cond`` is bilinearly resized to the spatial size of ``x`` when the
        two differ.
        """
        batch, channels, height, width = x.shape
        heads = self.num_heads
        dim_head = channels // heads
        tokens = height * width

        normed = self.norm(x)
        if cond.shape[-2:] != (height, width):
            cond = F.interpolate(cond, size=(height, width), mode='bilinear', align_corners=False)

        def to_heads(feat):
            # [B, C, H, W] -> [B, heads, N, dim_head]
            return feat.view(batch, heads, dim_head, tokens).transpose(-1, -2)

        q = to_heads(self.to_q(normed))
        k = to_heads(self.to_k(cond))
        v = to_heads(self.to_v(cond))

        scores = torch.matmul(q, k.transpose(-1, -2)) * (dim_head ** -0.5)
        weights = torch.softmax(scores, dim=-1)
        mixed = torch.matmul(weights, v)
        mixed = mixed.transpose(-1, -2).contiguous().view(batch, channels, height, width)
        return x + self.proj(mixed)

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of timesteps.

    For ``half = dim // 2`` the frequencies are
    ``exp(-i * log(10000) / (half - 1))`` for ``i = 0 .. half-1``; the output
    is the concatenation ``[sin(t * freq), cos(t * freq)]``.

    Args:
        dim: width of the returned embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed ``t`` (shape ``[B]``) into a ``[B, dim]`` tensor.

        For even ``dim >= 4`` the result is exactly what the unguarded formula
        gives. Two edge cases are handled explicitly:

        * ``dim`` of 2 or 3 makes ``half - 1`` zero; the divisor is clamped to
          1 so the single frequency is 1 instead of raising
          ``ZeroDivisionError``.
        * an odd ``dim`` would yield only ``dim - 1`` features; one trailing
          zero column is appended so the width always equals ``dim``.
        """
        half = self.dim // 2
        step = math.log(10000) / max(half - 1, 1)
        freq = torch.exp(torch.arange(half, device=t.device) * -step)
        args = t[:, None] * freq[None]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb

class ResnetBlockTime(nn.Module):
    """Residual block of two 3x3 convolutions, modulated by a time embedding.

    The time embedding is mapped to a per-channel ``(scale, shift)`` pair that
    is applied between the two convolutions as ``h * (scale + 1) + shift``.

    NOTE: a class with this name and interface is also defined earlier in this
    file; this later definition rebinds the name and is therefore the one in
    effect for code that runs after it.

    Args:
        in_ch: channels of the input feature map.
        out_ch: channels of the output feature map.
        time_emb_dim: width of the time-embedding vector passed to ``forward``.
        groups: number of groups for both ``GroupNorm`` layers.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, groups=32):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_mlp = nn.Sequential(Swish(), nn.Linear(time_emb_dim, out_ch * 2))
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.res_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        # Shared activation, created once instead of on every forward call.
        # It has no parameters or buffers, so state_dict keys are unchanged
        # and existing checkpoints still load.
        self.act = Swish()

    def forward(self, x, t_emb):
        """Return ``block(x) + skip(x)``; ``t_emb`` is ``[B, time_emb_dim]``."""
        h = self.conv1(self.act(self.norm1(x)))
        scale, shift = self.time_mlp(t_emb).chunk(2, dim=1)
        # Broadcast the per-channel scale/shift over the spatial dimensions.
        h = h * (scale[:, :, None, None] + 1) + shift[:, :, None, None]
        h = self.conv2(self.act(self.norm2(h)))
        return h + self.res_conv(x)


class LatentDiffusionUNetConditional(nn.Module):
    """Time- and image-conditioned U-Net operating on latent feature maps.

    The conditioning latent is used twice: it is concatenated with the noisy
    latent before the first convolution, and it feeds a CrossAttentionBlock
    at every level of the network.

    Level ``i`` of the down path has ``base_ch * ch_mults[i]`` channels.
    Self-attention is added at a level when ``2 ** i`` is in ``attn_res``
    (so with four levels and the default ``(8, 16)`` only the deepest down
    level qualifies). Every level except the last ends with a Downsample; the
    up path mirrors this with one Upsample per skip connection.

    Args:
        in_ch: channels of the noisy latent (and of the output).
        cond_ch: channels of the conditioning latent.
        base_ch: channel width of the first level.
        ch_mults: per-level multipliers of ``base_ch``.
        time_dim: width of the time embedding.
        attn_res: values of ``2 ** level`` at which self-attention is inserted.
    """

    def __init__(
        self,
        in_ch: int = 16,
        cond_ch: int = 16,
        base_ch: int = 128,
        ch_mults: tuple[int,...] = (1,2,4,8),
        time_dim: int = 512,
        attn_res: tuple[int,...] = (8,16)
    ):
        super().__init__()
        self.cond_ch = cond_ch

        # Timestep -> sinusoidal features -> 2-layer MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 4),
            Swish(),
            nn.Linear(time_dim * 4, time_dim),
        )

        # Stem consumes the noisy latent and the condition stacked on channels.
        self.init_conv = nn.Conv2d(in_ch + cond_ch, base_ch, kernel_size=3, padding=1)

        widths = [base_ch * mult for mult in ch_mults]
        last_level = len(widths) - 1

        # ---- Down path --------------------------------------------------
        self.downs = nn.ModuleList()
        prev_ch = base_ch
        for level, width in enumerate(widths):
            stage = nn.ModuleList()
            stage.append(ResnetBlockTime(prev_ch, width, time_dim))
            if 2 ** level in attn_res:
                stage.append(AttentionBlock(width))
            stage.append(CrossAttentionBlock(width, cond_ch))
            if level < last_level:
                stage.append(Downsample(width))
            self.downs.append(stage)
            prev_ch = width

        # ---- Middle -----------------------------------------------------
        mid_ch = widths[-1]
        self.mid = nn.ModuleList([
            ResnetBlockTime(mid_ch, mid_ch, time_dim),
            AttentionBlock(mid_ch),
            CrossAttentionBlock(mid_ch, cond_ch),
            ResnetBlockTime(mid_ch, mid_ch, time_dim),
        ])

        # ---- Up path ----------------------------------------------------
        self.ups = nn.ModuleList()
        for step in range(last_level):
            deep_ch = widths[-1 - step]   # channels arriving from below
            skip_ch = widths[-2 - step]   # channels of the matching skip
            stage = nn.ModuleDict()
            stage['upsample'] = Upsample(deep_ch)
            stage['resnet'] = ResnetBlockTime(deep_ch + skip_ch, skip_ch, time_dim)
            stage['cross'] = CrossAttentionBlock(skip_ch, cond_ch)
            if 2 ** (last_level - 1 - step) in attn_res:
                stage['attn'] = AttentionBlock(skip_ch)
            self.ups.append(stage)

        # ---- Output head ------------------------------------------------
        self.final = nn.Sequential(
            nn.GroupNorm(32, widths[0]),
            Swish(),
            nn.Conv2d(widths[0], in_ch, kernel_size=3, padding=1),
        )

    @staticmethod
    def _run_layer(layer, h, t_emb, cond):
        """Call ``layer`` with the extra argument its type expects."""
        if isinstance(layer, CrossAttentionBlock):
            return layer(h, cond)
        if isinstance(layer, ResnetBlockTime):
            return layer(h, t_emb)
        return layer(h)

    def forward(self, x, t, cond):
        """Run the U-Net.

        Args:
            x: noisy target latent, ``[B, in_ch, H, W]``.
            t: timesteps, one per batch element.
            cond: low-dose (conditioning) latent, ``[B, cond_ch, H, W]``.

        Returns:
            Tensor with ``in_ch`` channels produced by the output head.
        """
        t_emb = self.time_mlp(t)
        h = self.init_conv(torch.cat([x, cond], dim=1))
        skips = []

        # Down: remember the activation right before each Downsample.
        for stage in self.downs:
            for layer in stage:
                if isinstance(layer, Downsample):
                    skips.append(h)
                h = self._run_layer(layer, h, t_emb, cond)

        # Middle
        for layer in self.mid:
            h = self._run_layer(layer, h, t_emb, cond)

        # Up: upsample, concatenate the matching skip, then refine.
        for stage in self.ups:
            h = stage['upsample'](h)
            h = torch.cat([h, skips.pop()], dim=1)
            h = stage['resnet'](h, t_emb)
            if 'attn' in stage:
                h = stage['attn'](h)
            h = stage['cross'](h, cond)

        return self.final(h)

# Test

In [None]:
@torch.no_grad()
def sample_ddim_guided(z_cond, scheduler, diffusion, guidance_scale=5.0):
    """Sample a latent with classifier-free guidance.

    Starting from Gaussian noise shaped like ``z_cond``, every timestep in
    ``scheduler.timesteps`` runs the model twice - once with an all-zero
    condition and once with ``z_cond`` - blends the two predictions, and lets
    ``scheduler.step`` produce the next latent. The update rule itself is
    whatever the supplied scheduler implements.

    Args:
        z_cond: conditioning latent; its shape, dtype and device define those
            of the sample.
        scheduler: object exposing ``timesteps`` and
            ``step(eps, t, z, return_dict=True)`` returning ``.prev_sample``.
        diffusion: callable ``diffusion(z, t_batch, cond=...)`` returning the
            model prediction.
        guidance_scale: weight of the conditional direction; 1.0 reduces to
            the purely conditional prediction.

    Returns:
        The latent after the last scheduler step.
    """
    batch = z_cond.size(0)
    # Take the device from the input instead of relying on a notebook global.
    device = z_cond.device

    # start from pure noise
    z = torch.randn_like(z_cond)
    # "Dropped" condition for the unconditional branch; the same every step.
    null_cond = torch.zeros_like(z_cond)

    for t in scheduler.timesteps:
        t_int = int(t.item() if isinstance(t, torch.Tensor) else t)
        t_batch = torch.full((batch,), t_int, device=device, dtype=torch.long)

        eps_uncond = diffusion(z, t_batch, cond=null_cond)
        eps_cond = diffusion(z, t_batch, cond=z_cond)
        # classifier-free guidance blend
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        z = scheduler.step(eps, t_int, z, return_dict=True).prev_sample

    return z

def load_and_preprocess(path):
    """Read a DICOM file and return it as a preprocessed ``[1, C, H, W]`` tensor.

    The pixel array is min-max scaled to [0, 1] (with a small epsilon so a
    constant image does not divide by zero), quantised to 8 bits, wrapped in a
    PIL image and passed through the notebook-level ``transform`` pipeline.
    A leading batch dimension is added to the result.
    """
    pixels = pydicom.dcmread(path).pixel_array.astype(float)

    lowest = pixels.min()
    highest = pixels.max()
    pixels = (pixels - lowest) / (highest - lowest + 1e-5)

    as_uint8 = (pixels * 255).astype("uint8")
    image = Image.fromarray(as_uint8)
    tensor = transform(image)
    return tensor.unsqueeze(0)

In [None]:
def pred_image(low_img_t, diffusion, vae, scheduler):
    """Predict an image from a low-dose input via the latent diffusion model.

    Steps: replicate the single channel three times, encode with the VAE and
    scale the sampled latent by ``vae.config.scaling_factor``; run guided
    sampling conditioned on that latent; decode; map the decoder output from
    [-1, 1] to [0, 1].

    Args:
        low_img_t: input image tensor, ``[1, 1, H, W]``.
        diffusion: conditional denoising model passed to ``sample_ddim_guided``.
        vae: autoencoder exposing ``encode``, ``decode`` and
            ``config.scaling_factor``.
        scheduler: noise scheduler passed to ``sample_ddim_guided``.

    Returns:
        2-D tensor ``[H, W]`` with values clamped to [0, 1] (first batch
        element, first channel of the decoder output).
    """
    with torch.no_grad():
        # The image is encoded exactly once; an earlier version encoded it a
        # second time and discarded the result.
        low_rgb = low_img_t.repeat(1, 3, 1, 1)
        z_cond = vae.encode(low_rgb).latent_dist.sample() * vae.config.scaling_factor

        z_pred = sample_ddim_guided(z_cond, scheduler, diffusion)

        # NOTE(review): the previous code multiplied the sampled latent by
        # scaling_factor and then divided by it again before decoding, i.e.
        # the decoder effectively received the sampler output unchanged. That
        # behaviour is kept here. If the diffusion model was trained on
        # latents scaled by scaling_factor, this should probably be
        # `vae.decode(z_pred / vae.config.scaling_factor)` - confirm against
        # the training code before changing it.
        dec = vae.decode(z_pred).sample
        pred = (dec / 2 + 0.5).clamp(0, 1)[0, 0]  # [H,W] float in [0,1]
    return pred


def predict(vae, diffusion, scheduler, low_path):
    """Load the DICOM file at ``low_path`` and return the model's prediction.

    The image is preprocessed with ``load_and_preprocess``, moved to the
    notebook-level ``device`` and handed to ``pred_image``.
    """
    low_img_t = load_and_preprocess(low_path)   # [1,1,H,W]
    low_img_t = low_img_t.to(device)
    return pred_image(low_img_t, diffusion, vae, scheduler)

In [None]:
# Use the GPU when present; fall back to CPU so this cell still runs (slowly)
# on a runtime without CUDA instead of failing on the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- VAE ---------------------------------------------------------------
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# Swap the decoder's 3-channel output convolution for a 1-channel one whose
# kernel and bias are the mean over the original output channels. The module
# must have this shape before the fine-tuned checkpoint below can be loaded.
old_out = vae.decoder.conv_out
w2 = old_out.weight.data                             # [3, C, k, k]
w2_gray = w2.mean(dim=0, keepdim=True)                # [1, C, k, k]

new_out = nn.Conv2d(
    in_channels=old_out.in_channels,
    out_channels=1,
    kernel_size=old_out.kernel_size,
    stride=old_out.stride,
    padding=old_out.padding,
    bias=(old_out.bias is not None)
)
new_out.weight.data.copy_(w2_gray)
if old_out.bias is not None:
    new_out.bias.data.fill_(old_out.bias.data.mean())

vae.decoder.conv_out = new_out

# freeze everything
for p in vae.parameters():
    p.requires_grad = False

# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider `weights_only=True` if the installed torch
# supports it and the file holds a plain state dict).
ckpt_path = "/content/drive/MyDrive/CT Models/VAE/pt_vae_2_epoch_100.pth"
state_dict = torch.load(ckpt_path, map_location=device)
vae.load_state_dict(state_dict)
vae.to(device).eval()

# ---- Diffusion U-Net ---------------------------------------------------
latent_ch = vae.config.latent_channels
diffusion_model = LatentDiffusionUNetConditional(
    in_ch   = latent_ch,
    cond_ch = latent_ch,
    base_ch = 128,
    ch_mults= (1,2,4,8),
    time_dim= 512,
    attn_res= (8,16)
).to(device)
ckpt = torch.load("/content/drive/MyDrive/CT Models/Diffusion/pt-diffusion_3_epoch_10.pth", map_location=device)
diffusion_model.load_state_dict(ckpt)
diffusion_model.eval()

# ---- Noise scheduler ---------------------------------------------------
# NOTE(review): with beta_schedule="squaredcos_cap_v2" the betas presumably
# come from the cosine schedule, in which case beta_start / beta_end have no
# effect - confirm against the diffusers documentation.
scheduler = DDPMScheduler(
        beta_start=1e-4,
        beta_end=0.02,
        beta_schedule="squaredcos_cap_v2"
    )
scheduler.set_timesteps(1000, device=device)

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

In [None]:
import json

# Read the list of DICOM paths; the context manager closes the file handle.
with open("/content/drive/MyDrive/CT/image_paths.json") as f:
    image_paths = json.load(f)

ssim_module = kornia.metrics.SSIM(
    window_size=11,        # standard 11×11 Gaussian window
    max_val=1.0,           # your images are in [0,1]
    eps=1e-12,
    padding='same'
).to(device)

# NOTE(review): the output file is named results_2000.json although the slice
# below covers items 1000..1499 - confirm the intended name.
results = []
for low_path in tqdm(image_paths[1_000:1_500], desc="SSIM"):
    # 1) get your H×W prediction in [0,1]
    pred = predict(vae, diffusion_model, scheduler, low_path).to(device)  # [H,W]

    # 2) load & de-normalize the reference image into [1,1,H,W]
    # NOTE(review): the reference is read from the same `low_path` that was
    # used as the model input, so this SSIM compares the prediction with its
    # own input rather than with a separate ground-truth image. If a
    # full-dose counterpart exists, its path should be used here - confirm.
    gt_t = load_and_preprocess(low_path).to(device)                      # [1,1,H,W]
    gt   = (gt_t * 0.5 + 0.5).clamp(0,1)                                  # [1,1,H,W]

    # 3) reshape pred to [1,1,H,W]
    pred_b = pred.unsqueeze(0).unsqueeze(0)                              # [1,1,H,W]

    # 4) compute SSIM map & reduce
    with torch.no_grad():
        ssim_map   = ssim_module(pred_b, gt)                             # [1,1,H,W]
        ssim_score = ssim_map.mean().item()                             # scalar

    results.append({"path": low_path, "ssim": ssim_score})

    # Rewrite the file after every image so partial results survive an
    # interrupted runtime.
    with open("/content/drive/MyDrive/CT/results_2000.json", "w") as f:
        json.dump(results, f)

# now `results` is a list of {"path":…, "ssim":…} for each image