In [None]:
# Mount Google Drive at /content/drive; the dataset and checkpoint paths used
# further down all live under this mount point (Colab-only cell).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Install the third-party packages needed on a fresh runtime.
# NOTE(review): pytorch-lightning is not imported anywhere in the visible notebook,
# and diffusers is imported below but not installed here — presumably it is
# preinstalled on the runtime; confirm. Versions are unpinned.
!pip install pydicom
!pip install einops
!pip install pytorch-lightning
!pip install lpips

In [None]:
import os, math, random
from pathlib import Path
import matplotlib.pyplot as plt


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pydicom
from diffusers import AutoencoderKL
from lpips import LPIPS
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import random_split

# Dataset

In [None]:
class CTDataset(Dataset):
    """
    Dataset of individual CT slices stored as DICOM ``.IMA`` files.

    The root directory is scanned for ``quarter_*`` folders and their
    ``full_*`` counterparts, each holding one sub-folder per patient. Images
    from both sides are gathered into ONE flat list, so every item is a single
    image tensor (there is no label and no paired target).

    Args:
        root_dir: Directory containing ``quarter_3mm``/``full_3mm`` and/or
            ``quarter_1mm``/``full_1mm``.
        transform: Optional callable applied to the 8-bit PIL image.
    """
    def __init__(self, root_dir, transform=None):
        # Flat, interleaved list of file paths: [quarter_0, full_0, quarter_1, full_1, ...]
        self.pairs = self.collect_pairs_by_position(root_dir)
        self.transform = transform

    def __len__(self):
        return len(self.pairs)

    def collect_pairs_by_position(self, root: str, sort: bool = True):
        """
        Walk quarter_3mm ↔ full_3mm and quarter_1mm ↔ full_1mm in parallel,
        and match the i-th image in each patient directory by position.

        Args:
            root: Dataset root directory.
            sort: Sort the file lists of both sides before matching them up.

        Returns:
            A flat list of paths in which each matched quarter image is
            immediately followed by its full counterpart. Folders or patients
            that are missing on either side are skipped; if the two sides hold
            a different number of images the surplus is dropped (``zip``).
        """
        root = Path(root)
        mapping = {
            "quarter_3mm": "full_3mm",
            "quarter_1mm": "full_1mm",
        }

        pairs = []
        for small_name, full_name in mapping.items():
            small_root = root / small_name
            full_root  = root / full_name

            # A dataset may only ship one of the two thickness folders; skip
            # the absent one instead of failing inside iterdir().
            if not small_root.is_dir() or not full_root.is_dir():
                continue

            # each subfolder under small_root is a patient ID
            for patient_dir in sorted(small_root.iterdir()):
                if not patient_dir.is_dir():
                    continue

                # match the same patient under the full folder
                full_patient_dir = full_root / patient_dir.name
                if not full_patient_dir.is_dir():
                    continue

                # grab all images anywhere under that patient (rglob)
                small_imgs = list(patient_dir.rglob("*.IMA"))
                full_imgs  = list(full_patient_dir.rglob("*.IMA"))

                if sort:
                    small_imgs.sort()
                    full_imgs.sort()

                # match by index, keeping both files as separate samples
                for small_img, full_img in zip(small_imgs, full_imgs):
                    pairs.extend((small_img, full_img))

        return pairs

    def __getitem__(self, idx):
        """Load the idx-th file and return it as a (transformed) image."""
        # Read DICOM file
        q_path = self.pairs[idx]
        ds_q = pydicom.dcmread(q_path)
        img_q = ds_q.pixel_array.astype(np.float32)

        # Per-image min-max scaling to [0, 1]; the epsilon guards constant images
        img_q = (img_q - img_q.min()) / (img_q.max() - img_q.min() + 1e-5)

        # Convert to 8-bit and PIL Image
        img_q = (img_q * 255).astype(np.uint8)
        img_q = Image.fromarray(img_q)

        # Apply transforms (e.g. ToTensor, Normalize, etc.)
        if self.transform:
            img_q = self.transform(img_q)

        # Only the image is returned (no label)
        return img_q


In [None]:
# Preprocessing applied to every 8-bit PIL slice before it reaches the model:
# single channel -> fixed 256x256 -> tensor in [0, 1] -> rescaled to [-1, +1].
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

In [None]:
# Build the full dataset and hold out 10 % of the slices for validation.
ds = CTDataset("/content/drive/MyDrive/CT/", transform=transform)
n_val   = int(len(ds) * 0.1)
n_train = len(ds) - n_val

# Seed the split so the same slices end up in the validation set on every run;
# without a fixed generator, resuming from a checkpoint would reshuffle
# train/val membership.
# NOTE(review): the split is per slice, not per patient, so slices of one
# patient can land on both sides — confirm this is acceptable.
train_ds, val_ds = random_split(
    ds, [n_train, n_val], generator=torch.Generator().manual_seed(42)
)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True,  num_workers=4)
val_loader   = DataLoader(val_ds,   batch_size=8, shuffle=False, num_workers=4)

# AutoEncoder

In [None]:
class PatchDiscriminator(nn.Module):
    """
    Convolutional discriminator that outputs a map of real/fake logits.

    Layout: one stride-2 conv + LeakyReLU stem, then ``n_layers - 1`` blocks of
    (stride-2 conv that doubles the width, GroupNorm with 32 groups, LeakyReLU),
    and a final 4x4 conv down to a single logit channel.

    Args:
        in_channels: Channels of the input image.
        ch: Width of the first convolution.
        n_layers: Number of stride-2 convolutions (stem included).
    """
    def __init__(self, in_channels=1, ch=64, n_layers=3):
        super().__init__()

        # Stem: no normalisation on the first layer
        blocks = [
            nn.Conv2d(in_channels, ch, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Down-sampling blocks, each doubling the channel count
        width = ch
        for _ in range(n_layers - 1):
            wider = width * 2
            blocks.append(nn.Conv2d(width, wider, kernel_size=4, stride=2, padding=1))
            blocks.append(nn.GroupNorm(32, wider))
            blocks.append(nn.LeakyReLU(0.2, inplace=True))
            width = wider

        # Head: one logit per spatial location
        blocks.append(nn.Conv2d(width, 1, kernel_size=4, padding=1))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Return the logit map for the image batch ``x``."""
        return self.model(x)

## Train

In [None]:
from torch.optim import Adam

In [None]:
# 1) Inverse of Normalize([0.5], [0.5]): maps [-1, +1] back to [0, 1].
#    Normalize computes (x - mean) / std, so mean=-1, std=2 gives (x + 1) / 2.
inv_normalize = transforms.Normalize([-1.0], [2.0])

def load_and_preprocess(path, tfm=None):
    """
    Read one DICOM slice and turn it into a batch of size 1.

    The pixel data is min-max scaled to [0, 1], quantised to 8 bits, wrapped in
    a PIL image and passed through the preprocessing pipeline.

    Args:
        path: Location of the DICOM (.IMA) file.
        tfm: Preprocessing callable applied to the PIL image. When omitted,
            the notebook-level ``transform`` is used.

    Returns:
        The transformed image with a leading batch dimension added.
    """
    if tfm is None:
        tfm = transform
    ds  = pydicom.dcmread(path)
    arr = ds.pixel_array.astype(float)
    # epsilon guards against a constant image (max == min)
    arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-5)
    pil = Image.fromarray((arr * 255).astype("uint8"))
    return tfm(pil).unsqueeze(0)

# 2) Function to reconstruct and display
def show_reconstruction(image_path: str, vae: torch.nn.Module, transform, device):
    """
    Encode and decode one slice with ``vae`` and plot input vs. reconstruction.

    The model is switched to eval mode for the forward pass and put back into
    whatever mode it was in on entry, so calling this from inside a training
    loop does not silently leave the model in eval mode.

    Args:
        image_path: DICOM (.IMA) file to visualise.
        vae: Autoencoder exposing ``encode(...).latent_dist``,
            ``decode(...).sample`` and ``config.scaling_factor``.
        transform: Preprocessing pipeline applied to the loaded image
            (``None`` falls back to the notebook-level ``transform``).
        device: Device on which the forward pass runs.
    """
    # Load & preprocess
    x = load_and_preprocess(image_path, transform)  # grayscale
    x = x.to(device)                                # normalized to [-1,+1]
    x_rgb = x.repeat(1, 3, 1, 1)                    # encoder takes 3-channel input

    # VAE forward (remember the caller's train/eval mode and restore it)
    was_training = vae.training
    vae.eval()
    try:
        with torch.no_grad():
            enc_out     = vae.encode(x_rgb)
            latent_dist = enc_out.latent_dist
            # NOTE(review): the latent is multiplied by scaling_factor before
            # decoding, mirroring train_vae_accum; stock diffusers pipelines
            # divide by it before decode — confirm this is intended.
            z           = latent_dist.sample() * vae.config.scaling_factor
            dec_out     = vae.decode(z)
            recon_rgb   = dec_out.sample
            recon       = recon_rgb[:, :1, :, :]    # keep channel 0 only
    finally:
        vae.train(was_training)

    # Postprocess: undo normalization, clamp to [0,1]
    recon_denorm = inv_normalize(recon.cpu().squeeze(0))
    orig_denorm  = inv_normalize(x.cpu().squeeze(0))

    recon_img = torch.clamp(recon_denorm, 0.0, 1.0)
    orig_img  = torch.clamp(orig_denorm,  0.0, 1.0)

    # 3) Plot
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].imshow(orig_img.squeeze(), cmap="gray")
    axes[0].set_title("Original")
    axes[0].axis("off")
    axes[1].imshow(recon_img.squeeze(), cmap="gray")
    axes[1].set_title("Reconstruction")
    axes[1].axis("off")
    plt.tight_layout()
    plt.show()

In [None]:
def train_vae_accum(
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader:   DataLoader,
    epochs: int = 50,
    lr: float = 1e-4,
    beta: float = 1e-3,
    lambda_perc: float = 1.0,
    lambda_adv: float = 1e-2,
    save_dir: str = "/content/drive/MyDrive/CT Models/VAE",
    accumulate_steps: int = 4,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
):
    """
    Fine-tune a KL autoencoder with reconstruction, KL, perceptual (LPIPS) and
    adversarial losses, using gradient accumulation.

    Every batch is replicated to 3 channels for the encoder and the first
    channel of the decoder output is compared against the input. Gradients are
    accumulated over ``accumulate_steps`` batches before both optimizers step.
    After each epoch the autoencoder and discriminator weights are written to
    ``save_dir`` and the reconstruction of a fixed reference slice is plotted.

    Args:
        model: Autoencoder exposing ``encode(...).latent_dist``,
            ``decode(...).sample`` and ``config.scaling_factor``.
        train_loader: Yields single-channel image batches.
        val_loader: Currently unused; kept so existing calls remain valid.
        epochs: Number of passes over ``train_loader`` (also the cosine
            schedule length).
        lr: Initial learning rate of both Adam optimizers.
        beta: Weight of the KL term.
        lambda_perc: Weight of the LPIPS term.
        lambda_adv: Weight of the generator's adversarial term.
        save_dir: Directory for the per-epoch checkpoints (created if needed).
        accumulate_steps: Batches to accumulate before an optimizer step.
        device: Device used for training.
    """
    model.to(device)
    disc = PatchDiscriminator(in_channels=1).to(device)

    # LPIPS is a fixed metric: keep it in eval mode (train mode would enable
    # its dropout layers and make the loss noisy) and freeze its weights so no
    # gradients are accumulated on them.
    perceptual = LPIPS(net='vgg').to(device).eval()
    perceptual.requires_grad_(False)

    opt_G = Adam(model.parameters(), lr=lr)
    opt_D = Adam(disc.parameters(),  lr=lr)
    sched_G = CosineAnnealingLR(opt_G, T_max=epochs)
    sched_D = CosineAnnealingLR(opt_D, T_max=epochs)
    os.makedirs(save_dir, exist_ok=True)

    for ep in range(1, epochs + 1):
        # Re-assert train mode every epoch: the end-of-epoch preview may have
        # switched the model to eval mode.
        model.train()
        disc.train()

        metrics = {'loss':0., 'rec':0., 'kl':0., 'perc':0., 'adv':0.}
        opt_G.zero_grad(); opt_D.zero_grad()

        batch_i = 0  # stays 0 if the loader is empty
        for batch_i, x in enumerate(train_loader, start=1):
            x = x.to(device)
            x_rgb = x.repeat(1, 3, 1, 1)

            # Encode
            enc_out     = model.encode(x_rgb)
            latent_dist = enc_out.latent_dist
            mu, logvar  = latent_dist.mean, latent_dist.logvar
            # NOTE(review): the latent is multiplied by scaling_factor before
            # decoding; stock diffusers pipelines divide by it before decode —
            # confirm this is intended (existing checkpoints were trained so).
            z           = latent_dist.sample() * model.config.scaling_factor

            # Decode
            dec_out   = model.decode(z)
            recon_rgb = dec_out.sample
            recon     = recon_rgb[:, :1, :, :]

            # Discriminator (reconstruction detached: no gradient into the VAE)
            pred_real = disc(x)
            pred_fake = disc(recon.detach())
            loss_D = 0.5 * (
                F.binary_cross_entropy_with_logits(pred_real, torch.ones_like(pred_real)) +
                F.binary_cross_entropy_with_logits(pred_fake, torch.zeros_like(pred_fake))
            )
            (loss_D / accumulate_steps).backward()

            # Generator: freeze the discriminator weights so the adversarial
            # term only sends gradients into the VAE. Otherwise opt_D.step()
            # would also apply the generator's gradient to the discriminator.
            disc.requires_grad_(False)
            loss_rec    = F.mse_loss(recon, x)
            loss_kl     = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss_perc   = perceptual(recon, x).mean()
            pred_fake_G = disc(recon)
            loss_adv_G  = F.binary_cross_entropy_with_logits(pred_fake_G, torch.ones_like(pred_fake_G))

            loss_G = (
                loss_rec
                + beta        * loss_kl
                + lambda_perc * loss_perc
                + lambda_adv  * loss_adv_G
            )
            (loss_G / accumulate_steps).backward()
            disc.requires_grad_(True)

            if batch_i % accumulate_steps == 0:
                opt_D.step(); opt_G.step()
                opt_D.zero_grad(); opt_G.zero_grad()

            for k,v in [('loss', loss_G), ('rec', loss_rec), ('kl', loss_kl), ('perc', loss_perc), ('adv', loss_adv_G)]:
                metrics[k] += v.item()

        # final step for a trailing, incomplete accumulation window
        if batch_i % accumulate_steps != 0:
            opt_D.step(); opt_G.step()
            opt_D.zero_grad(); opt_G.zero_grad()

        sched_G.step(); sched_D.step()

        n = max(len(train_loader), 1)  # avoid dividing by zero on an empty loader
        print(f"Epoch {ep:2d} | Loss {metrics['loss']/n:.4f} | Rec {metrics['rec']/n:.4f} | "
              f"KL {metrics['kl']/n:.4f} | Perc {metrics['perc']/n:.4f} | Adv {metrics['adv']/n:.4f}")

        # save
        torch.save(model.state_dict(), os.path.join(save_dir, f"pt_vae_2_epoch_{ep}.pth"))
        torch.save(disc.state_dict(),  os.path.join(save_dir, f"pt_disc_2_epoch_{ep}.pth"))
        print(f"Saved checkpoints for epoch {ep}")
        show_reconstruction(
            "/content/drive/MyDrive/CT/full_3mm/L067/full_3mm/L067_FD_3_1.CT.0002.0001.2015.12.22.18.12.07.5968.358090401.IMA",
            model,
            transform,
            device
        )

In [None]:
# Start from the pretrained Stable Diffusion VAE and make every weight trainable.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.requires_grad_(True)
vae.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

In [None]:
# Replace the decoder's final convolution with a single-output-channel one whose
# filter and bias are the mean over the original output channels.
old_out = vae.decoder.conv_out
w2 = old_out.weight.data
w2_gray = w2.mean(dim=0, keepdim=True)
new_out = nn.Conv2d(
    old_out.in_channels,
    1,
    old_out.kernel_size,
    old_out.stride,
    old_out.padding,
    bias=old_out.bias is not None,
)
with torch.no_grad():
    new_out.weight.copy_(w2_gray)
    if old_out.bias is not None:
        new_out.bias.fill_(old_out.bias.mean())
vae.decoder.conv_out = new_out

In [None]:
# Optimizer over the trainable VAE weights.
# NOTE(review): train_vae_accum builds its own optimizers, so `opt` does not
# appear to be used anywhere in the visible notebook — confirm before relying on it.
opt = torch.optim.Adam(
    (p for p in vae.parameters() if p.requires_grad),
    lr=1e-5,
)

In [None]:
# Resume from a previously saved state dict of the single-channel VAE.
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt_path = "/content/drive/MyDrive/CT Models/VAE/pt_vae_2_epoch_100.pth"
# weights_only=True restricts unpickling to tensors/containers, which is all a
# state dict needs and avoids executing arbitrary pickled code.
state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
vae.load_state_dict(state_dict)
vae.to(device).eval()

In [None]:
# Launch fine-tuning; hyper-parameters are spelled out explicitly for visibility.
train_vae_accum(
    model=vae,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=50,
    lr=1e-4,
    beta=1e-3,
    lambda_perc=1.0,
    lambda_adv=1e-2,
    accumulate_steps=4,
    save_dir="/content/drive/MyDrive/CT Models/VAE",
)

# Inference

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Rebuild the architecture used for training (pretrained VAE + 1-channel output
# conv) so that the fine-tuned state dict can be loaded into it.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.to(device).eval()

old_out = vae.decoder.conv_out            # original final decoder convolution
w2 = old_out.weight.data                  # [out_ch, in_ch, k, k]
# average the output filters to make a grayscale output
w2_gray = w2.mean(dim=0, keepdim=True)    # [1, in_ch, k, k]

new_out = nn.Conv2d(
    in_channels=old_out.in_channels,
    out_channels=1,
    kernel_size=old_out.kernel_size,
    stride=old_out.stride,
    padding=old_out.padding,
    bias=(old_out.bias is not None)
)
new_out.weight.data.copy_(w2_gray)
if old_out.bias is not None:
    # average the biases too
    new_out.bias.data.fill_(old_out.bias.data.mean())

# replace the single conv
vae.decoder.conv_out = new_out
# NOTE(review): train_vae_accum saves files named "pt_vae_2_epoch_{ep}.pth";
# this name lacks the "_2" — presumably an earlier run, confirm the file exists.
ckpt_path = "/content/drive/MyDrive/CT Models/VAE/pt_vae_epoch_11.pth"  # or whichever epoch
# weights_only=True restricts unpickling to tensors/containers, which is all a
# state dict needs and avoids executing arbitrary pickled code.
state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
vae.load_state_dict(state_dict)
vae.to(device).eval()

In [None]:
# Inference-time preprocessing: single channel -> shorter side resized to 256 ->
# central 256x256 crop -> tensor in [0, 1] -> rescaled to [-1, +1].
# NOTE(review): this rebinds the notebook-level `transform`; the earlier
# definition used Resize((256, 256)) with no crop, so the two pipelines
# presumably agree only for square inputs — confirm.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(size=256),
    transforms.CenterCrop(size=256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

In [None]:
import matplotlib.pyplot as plt
from torchvision import transforms

# 1) Inverse of Normalize([0.5], [0.5]): maps [-1, +1] back to [0, 1].
#    Normalize computes (x - mean) / std, so mean=-1, std=2 gives (x + 1) / 2.
inv_normalize = transforms.Normalize([-1.0], [2.0])

def load_and_preprocess(path, tfm=None):
    """
    Read one DICOM slice and turn it into a batch of size 1.

    The pixel data is min-max scaled to [0, 1], quantised to 8 bits, wrapped in
    a PIL image and passed through the preprocessing pipeline.

    Args:
        path: Location of the DICOM (.IMA) file.
        tfm: Preprocessing callable applied to the PIL image. When omitted,
            the notebook-level ``transform`` is used.

    Returns:
        The transformed image with a leading batch dimension added.
    """
    if tfm is None:
        tfm = transform
    ds  = pydicom.dcmread(path)
    arr = ds.pixel_array.astype(float)
    # epsilon guards against a constant image (max == min)
    arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-5)
    pil = Image.fromarray((arr * 255).astype("uint8"))
    return tfm(pil).unsqueeze(0)

# 2) Function to reconstruct and display
def show_reconstruction(image_path: str, vae: torch.nn.Module, transform, device):
    """
    Encode and decode one slice with ``vae`` and plot input vs. reconstruction.

    The model is switched to eval mode for the forward pass and put back into
    whatever mode it was in on entry.

    Args:
        image_path: DICOM (.IMA) file to visualise.
        vae: Autoencoder exposing ``encode(...).latent_dist``,
            ``decode(...).sample`` and ``config.scaling_factor``.
        transform: Preprocessing pipeline applied to the loaded image
            (``None`` falls back to the notebook-level ``transform``).
        device: Device on which the forward pass runs.
    """
    # Load & preprocess
    x = load_and_preprocess(image_path, transform)  # grayscale
    x = x.to(device)                                # normalized to [-1,+1]
    x_rgb = x.repeat(1, 3, 1, 1)                    # encoder takes 3-channel input

    # VAE forward (remember the caller's train/eval mode and restore it)
    was_training = vae.training
    vae.eval()
    try:
        with torch.no_grad():
            enc_out     = vae.encode(x_rgb)
            latent_dist = enc_out.latent_dist
            # NOTE(review): the latent is multiplied by scaling_factor before
            # decoding, mirroring train_vae_accum; stock diffusers pipelines
            # divide by it before decode — confirm this is intended.
            z           = latent_dist.sample() * vae.config.scaling_factor
            dec_out     = vae.decode(z)
            recon_rgb   = dec_out.sample
            recon       = recon_rgb[:, :1, :, :]    # keep channel 0 only
    finally:
        vae.train(was_training)

    # Postprocess: undo normalization, clamp to [0,1]
    recon_denorm = inv_normalize(recon.cpu().squeeze(0))
    orig_denorm  = inv_normalize(x.cpu().squeeze(0))

    recon_img = torch.clamp(recon_denorm, 0.0, 1.0)
    orig_img  = torch.clamp(orig_denorm,  0.0, 1.0)

    # 3) Plot
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].imshow(orig_img.squeeze(), cmap="gray")
    axes[0].set_title("Original")
    axes[0].axis("off")
    axes[1].imshow(recon_img.squeeze(), cmap="gray")
    axes[1].set_title("Reconstruction")
    axes[1].axis("off")
    plt.tight_layout()
    plt.show()

# 4) Usage example: reconstruct one reference slice with the loaded model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `vae` and `transform` must already have been defined by the cells above
show_reconstruction(
    image_path="/content/drive/MyDrive/CT/full_3mm/L067/full_3mm/L067_FD_3_1.CT.0002.0001.2015.12.22.18.12.07.5968.358090401.IMA",
    vae=vae,
    transform=transform,
    device=device,
)