**Training + Inference**

In [None]:
"""
Conditional Diffusion (DDPM) for Brightfield -> Red/Green
...
"""

# === INSTALL / SETUP ===
# Probe for three of the required packages; if any import fails, install the
# full dependency list with pip. The `!pip` line is an IPython shell escape,
# so this cell only runs inside a notebook kernel, not as a plain .py script.
try:
    import diffusers
    import torchvision
    import accelerate
except ImportError:
    print("Installing required packages...")
    # -q keeps pip quiet. The packages are not re-imported here; the IMPORTS
    # section that follows performs the real imports.
    !pip install -q diffusers transformers accelerate safetensors torchvision scikit-image tqdm

# === IMPORTS ===
import os, json, random
from glob import glob
from tqdm.auto import tqdm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models # For VGG

from diffusers import UNet2DModel, DDPMScheduler
from torch.optim.lr_scheduler import CosineAnnealingLR # New Scheduler

# === PATHS ===
# BASE is the dataset root that gets scanned for sub-folders; OUT_DIR is where
# this run writes its samples and model checkpoints.
BASE = '/content/drive/MyDrive/mayank/Brightfield vs Fluorescent Staining Dataset'
OUT_DIR = '/content/drive/MyDrive/mayank/conditional_diffusion_outputs_v4'

# Create the run directory, then each of its sub-directories.
os.makedirs(OUT_DIR, exist_ok=True)
for _subdir in ('samples', 'models'):
    os.makedirs(os.path.join(OUT_DIR, _subdir), exist_ok=True)

# === 0) PERCEPTUAL LOSS SETUP (VGG) ===
class PerceptualLoss(torch.nn.Module):
    """L1 distance between frozen VGG16 feature maps of two 1-channel images.

    Both inputs are expected as (N, 1, H, W) tensors in [-1, 1]. They are
    tiled to three channels, rescaled to [0, 1] and normalized with the
    ImageNet channel statistics before being passed through layers 0-8 of
    ``vgg16.features``.

    NOTE: the ImageNet normalization changes the magnitude of this loss
    compared with feeding raw [0, 1] images, so the weight applied to it by
    the caller may need re-tuning.
    """

    # Channel statistics the IMAGENET1K_V1 weights were trained with.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
        self.slice = torch.nn.Sequential()
        # Use first 9 layers of VGG for perceptual similarity
        for x in range(9):
            self.slice.add_module(str(x), vgg[x])
        for param in self.slice.parameters():
            param.requires_grad = False  # Freeze VGG

    def _to_vgg_input(self, img):
        """Tile a [-1, 1] single-channel batch to 3 channels in VGG input space."""
        img = (img.repeat(1, 3, 1, 1) + 1) / 2  # [-1, 1] -> [0, 1]
        # new_tensor keeps the statistics on the same device/dtype as the input.
        mean = img.new_tensor(self.IMAGENET_MEAN).view(1, 3, 1, 1)
        std = img.new_tensor(self.IMAGENET_STD).view(1, 3, 1, 1)
        return (img - mean) / std

    def forward(self, pred, target):
        """Return the mean absolute difference between the two feature maps."""
        pred_feat = self.slice(self._to_vgg_input(pred))
        target_feat = self.slice(self._to_vgg_input(target))
        return F.l1_loss(pred_feat, target_feat)

# === 1) LIST SUBFOLDERS & SPLIT ===
# Every directory directly under BASE is treated as a dataset folder. Folders
# whose name contains 'Set_3' form the test set; the rest are train/val.
all_subfolders = sorted(
    entry
    for entry in os.listdir(BASE)
    if os.path.isdir(os.path.join(BASE, entry))
)
test_folders = [name for name in all_subfolders if 'Set_3' in name]
trainval_folders = [name for name in all_subfolders if name not in test_folders]

def discover_triplets_in_folders(base_dir, folders_list):
    """Group the files of each folder into complete 0/1/2 triplets.

    A file takes part in grouping when its stem (basename without extension)
    ends in a digit. That last digit is the suffix; the group key is the stem
    without the digit and without one directly preceding underscore, so
    ``img_0`` and ``img0`` both map to key ``img``.

    Args:
        base_dir: Directory that contains the folders.
        folders_list: Folder names relative to ``base_dir``. Names that do not
            exist on disk are skipped and get no entry in the result.

    Returns:
        Dict mapping folder name to a sorted list of
        ``(path_suffix_0, path_suffix_1, path_suffix_2, folder_name)`` tuples.
        Groups missing any of the three suffixes are dropped.
    """
    folder_items = {}
    for sf in folders_list:
        sfp = os.path.join(base_dir, sf)
        if not os.path.exists(sfp):
            continue

        # Group BF/Red/Green by filename stem
        grouping = {}
        for f in sorted(glob(os.path.join(sfp, '*'))):
            if not os.path.isfile(f):
                continue  # sub-directories are not images
            name = os.path.splitext(os.path.basename(f))[0]
            if not name or not name[-1].isdigit():
                continue
            suffix = name[-1]
            # Slicing (not indexing) keeps one-character stems such as "1"
            # from raising IndexError.
            key = name[:-2] if name[-2:-1] == '_' else name[:-1]
            grouping.setdefault(key, {})[suffix] = f

        # Collect valid triplets only
        items = [
            (v['0'], v['1'], v['2'], sf)
            for v in grouping.values()
            if '0' in v and '1' in v and '2' in v
        ]
        folder_items[sf] = sorted(items)
    return folder_items

folderwise_triplets = discover_triplets_in_folders(BASE, trainval_folders + test_folders)
train_items, val_items, test_items = [], [], []
random.seed(42)

# Test folders go to the test set untouched; every other folder is shuffled
# in place and split 80/20 into train/val.
for folder, items in folderwise_triplets.items():
    if folder in test_folders:
        test_items += items
        continue
    n_total = len(items)
    n_train = int(0.8 * n_total)
    random.shuffle(items)
    train_items += items[:n_train]
    val_items += items[n_train:]

# Shuffle final lists
for _split in (train_items, val_items, test_items):
    random.shuffle(_split)

# Fallback for empty datasets: placeholder triplets whose first path contains
# "dummy" (no fallback is created for the validation split).
if not train_items:
    print("WARNING: No train data found. Using dummy data.")
    train_items = [("dummy_bf.png", "dummy_g.png", "dummy_r.png", "folder")] * 10
if not test_items:
    test_items = [("dummy_bf.png", "dummy_g.png", "dummy_r.png", "folder")] * 2

print(f"Train: {len(train_items)} | Val: {len(val_items)} | Test: {len(test_items)}")

# === 2) DATASET ===
class BFtoFluoDataset(Dataset):
    """Brightfield -> fluorescence samples, two per triplet (red and green).

    Each triplet is ``(bf_path, green_path, red_path, folder)``. Index
    ``2*k`` yields the red target of triplet ``k`` and ``2*k + 1`` the green
    one. Every item is ``(bf, target, cond)`` where ``bf`` is a 3-channel and
    ``target`` a 1-channel tensor normalized to [-1, 1], and ``cond`` is the
    one-hot channel selector ([1, 0] red, [0, 1] green).

    Args:
        triplets: List of path triplets as described above.
        out_size: Side length the images are resized to.
        augment: Apply the same random flips/rotation to BF and target.
    """

    MAX_ROTATION_DEG = 15.0

    def __init__(self, triplets, out_size=256, augment=False):
        self.triplets = triplets
        self.out_size = out_size
        self.augment = augment
        # Placeholder mode: the first path of the first triplet contains "dummy".
        self.is_dummy = "dummy" in triplets[0][0] if triplets else False

        # Normalize BF and Fluorescence to [-1,1]
        self.tf_base = transforms.Compose([
            transforms.Resize((out_size, out_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

        # Unpaired version of the augmentation, kept as an attribute for
        # code that augments a single image. __getitem__ uses _augment_pair
        # so BF and target always receive identical parameters.
        self.aug_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(15)
        ])

    def __len__(self):
        return len(self.triplets) * 2  # Each triplet gives Red + Green

    def _augment_pair(self, bf, target):
        """Apply one random flip/rotation draw to both images.

        The parameters are sampled once and applied through the functional
        API, so the process-wide `random`/`torch` generators are never
        reseeded as a side effect.
        """
        tf_f = transforms.functional
        if torch.rand(1).item() < 0.5:
            bf, target = tf_f.hflip(bf), tf_f.hflip(target)
        if torch.rand(1).item() < 0.5:
            bf, target = tf_f.vflip(bf), tf_f.vflip(target)
        limit = self.MAX_ROTATION_DEG
        angle = float(torch.empty(1).uniform_(-limit, limit).item())
        return tf_f.rotate(bf, angle), tf_f.rotate(target, angle)

    def __getitem__(self, idx):
        triplet_idx = idx // 2
        is_red = (idx % 2 == 0)  # Even index = Red, Odd = Green
        size = self.out_size

        # One-hot condition: [1,0] red, [0,1] green
        cond = torch.tensor([1.0, 0.0]) if is_red else torch.tensor([0.0, 1.0])

        # Dummy fallback for empty dataset scenario: random tensors of the
        # configured size.
        if self.is_dummy:
            return torch.randn(3, size, size), torch.randn(1, size, size), cond

        bf_path, green_path, red_path, _ = self.triplets[triplet_idx]
        target_path = red_path if is_red else green_path

        try:
            # convert() returns a loaded copy, so the file can be closed.
            with Image.open(bf_path) as im:
                bf = im.convert('RGB')
            with Image.open(target_path) as im:
                target = im.convert('L')
        except Exception as exc:
            # Bad/missing image: keep the best-effort all-zero sample (with an
            # all-zero condition) but report it instead of hiding it.
            import warnings
            warnings.warn(f"Could not load {bf_path!r} / {target_path!r}: {exc}")
            return torch.zeros(3, size, size), torch.zeros(1, size, size), torch.zeros(2)

        # Apply identical augmentations
        if self.augment:
            bf, target = self._augment_pair(bf, target)

        return self.tf_base(bf), self.tf_base(target), cond

# Create Dataloaders
# Only the training split is augmented; val/test use the deterministic
# resize + normalize pipeline.
train_dataset = BFtoFluoDataset(train_items, augment=True)
val_dataset = BFtoFluoDataset(val_items)
test_dataset = BFtoFluoDataset(test_items)

# Settings shared by both loaders; only the training loader shuffles.
_loader_opts = dict(batch_size=24, num_workers=2, pin_memory=True, persistent_workers=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)

# === 3) CONDITIONAL DIFFUSION MODEL ===
# Use the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# UNet with attention blocks for better structure preservation.
# The conditioning is done purely by channel concatenation: the training loop
# stacks the noisy target, the brightfield image and a spatially-broadcast
# one-hot channel selector into one 6-channel input.
unet = UNet2DModel(
    sample_size=256,
    in_channels=6, # noisy(1) + BF(3) + cond(2)
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(96, 192, 384, 768),  # Wider network
    down_block_types=(
        "DownBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
    ),
)

unet.to(device)

# Noise schedule shared by training (add_noise) and sampling (step);
# the network is trained to predict the added noise (epsilon).
scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="epsilon")

# AdamW optimizer + Cosine LR schedule
optimizer = torch.optim.AdamW(unet.parameters(), lr=3e-4, weight_decay=1e-5)
# The LR scheduler is stepped once per batch in the training loop, so T_max is
# expressed in optimizer steps.
# NOTE(review): the literal 200 must stay equal to `num_epochs` in the
# training section for the cosine to end exactly at the last step.
T_max = 200 * len(train_loader)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=T_max, eta_min=1e-6)

# Init perceptual loss (frozen VGG16 feature extractor, moved to the same device)
perceptual_loss_fn = PerceptualLoss().to(device)

# === 4) EMA HELPER ===
class EMA:
    """Exponential moving average over a model's trainable parameters.

    ``shadow`` maps parameter name to its averaged tensor. Parameters with
    ``requires_grad == False`` are neither tracked nor overwritten.
    """

    def __init__(self, model, decay=0.995):
        self.model = model
        self.decay = decay
        # Start the average from a detached copy of the current weights.
        self.shadow = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.clone().detach()

    def _trainable(self):
        """Yield (name, parameter) for every parameter that requires grad."""
        return (
            (name, param)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        )

    def update(self):
        """Blend the current weights into the running average."""
        for name, param in self._trainable():
            previous = self.shadow[name]
            self.shadow[name] = self.decay * previous + (1.0 - self.decay) * param.data

    def apply_shadow(self):
        """Overwrite the model's weights in place with the averaged ones."""
        for name, param in self._trainable():
            param.data.copy_(self.shadow[name])

ema = EMA(unet)

# ---
## 5. Training Loop
print(f"Starting training on {device}...")

num_epochs = 200
best_val_loss = float('inf')
history = {'train':[], 'val':[]}

L1_WEIGHT = 1.0
PERCEPTUAL_WEIGHT = 0.01

print(f"Loss Weights: L1={L1_WEIGHT}, Perceptual={PERCEPTUAL_WEIGHT}")

history_path = os.path.join(OUT_DIR, 'training_history.json')


def _diffusion_loss(bf, tgt, cond):
    """Return the combined loss for one batch that is already on `device`.

    Draws noise and a random timestep per sample, noises the target, lets the
    UNet predict the noise from [noisy target, BF, condition map], and returns
    L1_WEIGHT * L1(noise_pred, noise)
    + PERCEPTUAL_WEIGHT * perceptual(x0 reconstructed from the prediction, tgt).
    Shared by the training and validation passes so both measure the same thing.
    """
    # Sample noise level t
    noise = torch.randn_like(tgt)
    timesteps = torch.randint(
        0, scheduler.config.num_train_timesteps, (tgt.shape[0],), device=tgt.device
    ).long()

    # Forward diffusion: add DDPM noise
    noisy_target = scheduler.add_noise(tgt, noise, timesteps)

    # Broadcast the one-hot condition over the spatial grid -> 6-channel input
    cond_map = cond[:, :, None, None].expand(-1, -1, bf.shape[2], bf.shape[3])
    model_input = torch.cat([noisy_target, bf, cond_map], dim=1)

    # Predict noise
    noise_pred = unet(model_input, timesteps).sample
    noise_l1_loss = F.l1_loss(noise_pred, noise)

    # Predict x0 for perceptual loss. Move the schedule to the batch's device
    # explicitly so the indexing does not rely on add_noise() having done it.
    alpha_t = scheduler.alphas_cumprod.to(tgt.device)[timesteps].view(-1, 1, 1, 1)
    x0_pred = (noisy_target - (1.0 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
    perceptual_loss = perceptual_loss_fn(x0_pred, tgt)

    return (L1_WEIGHT * noise_l1_loss) + (PERCEPTUAL_WEIGHT * perceptual_loss)


for epoch in range(1, num_epochs+1):
    unet.train()
    running_train = 0.0
    n_train_steps = 0

    # === TRAINING ===
    for bf, tgt, cond in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
        bf, tgt, cond = bf.to(device), tgt.to(device), cond.to(device)

        total_loss = _diffusion_loss(bf, tgt, cond)

        # Backprop + step (LR schedule and EMA advance once per batch)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        lr_scheduler.step()
        ema.update()  # EMA update

        running_train += total_loss.item()
        n_train_steps += 1

    avg_train_loss = running_train / max(n_train_steps, 1)

    # === VALIDATION ===
    unet.eval()
    running_val = 0.0
    n_val_steps = 0
    with torch.no_grad():
        for bf, tgt, cond in val_loader:
            bf, tgt, cond = bf.to(device), tgt.to(device), cond.to(device)
            running_val += _diffusion_loss(bf, tgt, cond).item()
            n_val_steps += 1

    # The validation split can be empty (e.g. when the dummy-data fallback is
    # active). Record NaN instead of dividing by zero, and fall back to the
    # training loss for checkpoint selection.
    has_val = n_val_steps > 0
    avg_val_loss = running_val / n_val_steps if has_val else float('nan')
    selection_loss = avg_val_loss if has_val else avg_train_loss

    # Save loss history (a NaN entry is written as the token NaN by json.dump)
    history['train'].append(avg_train_loss)
    history['val'].append(avg_val_loss)

    try:
        with open(history_path, 'w') as f:
            json.dump(history, f)
    except Exception as e:
        # Best effort: a failed history write must not abort training.
        print(f"⚠️ WARNING: Failed to save loss history file. Error: {e}")

    print(f"Epoch {epoch}/{num_epochs} | LR: {optimizer.param_groups[0]['lr']:.2e} | Train Loss: {avg_train_loss:.5f} | Val Loss: {avg_val_loss:.5f}")

    # Save best model (raw weights, EMA shadow and optimizer state together)
    if selection_loss < best_val_loss:
        best_val_loss = selection_loss
        torch.save({
            'epoch': epoch,
            'model': unet.state_dict(),
            'ema': ema.shadow,
            'optimizer': optimizer.state_dict()
        }, os.path.join(OUT_DIR, 'models', 'best_model.pt'))
        print(">>> Saved New Best Model")

# End Training
print("\nTraining loop finished.")
print(f"History saved at: {history_path}")

# ---
## 6. Inference
print("\n=== Starting Final Inference ===")

def unnorm(x):
    """Clip a tensor to [-1, 1], then rescale it linearly to [0, 1]."""
    clipped = x.clamp(-1, 1)
    return (clipped + 1) / 2

def inference_loop(bf_img, condition_vec, model, sched, device, num_inference_steps=None):
    """Run the full reverse diffusion process for one brightfield image.

    Args:
        bf_img: Brightfield batch of shape (1, 3, H, W).
        condition_vec: Channel selector of shape (1, 2).
        model: Noise-prediction network; called as ``model(x, t).sample``.
        sched: Scheduler providing ``set_timesteps``, ``timesteps`` and ``step``.
        device: Device the sampling runs on.
        num_inference_steps: Number of reverse steps. Defaults to the
            scheduler's ``config.num_train_timesteps``.

    Returns:
        The generated image of shape (1, H, W), i.e. the first batch element.
    """
    bf_img = bf_img.to(device)
    # Take the spatial size from the input instead of assuming 256x256.
    height, width = bf_img.shape[-2:]
    cond_map = condition_vec[:, :, None, None].expand(-1, -1, height, width).to(device)

    # Start from noise. Sampled on CPU and then moved, so a seeded run draws
    # from the same generator regardless of the target device.
    img = torch.randn((1, 1, height, width)).to(device)

    if num_inference_steps is None:
        num_inference_steps = sched.config.num_train_timesteps
    sched.set_timesteps(num_inference_steps)

    with torch.no_grad():
        for t in sched.timesteps:
            model_input = torch.cat([img, bf_img, cond_map], dim=1)
            noise_pred = model(model_input, t).sample
            img = sched.step(noise_pred, t, img).prev_sample  # DDPM update

    return img[0]

def save_color_panel(bf, red_gt, red_pred, green_gt, green_pred, fname):
    """Save a 1x5 figure: BF input, red GT/prediction, green GT/prediction.

    All tensors are channel-first and in [-1, 1]; they are mapped to [0, 1]
    with ``unnorm`` before plotting. The four fluorescence panels use a fixed
    0..1 color range so ground truth and prediction are drawn on the same
    scale instead of each being stretched to its own min/max.

    Args:
        bf: Brightfield tensor (3, H, W).
        red_gt, red_pred, green_gt, green_pred: Single-channel tensors (1, H, W).
        fname: Output image path.
    """
    def _to_array(tensor):
        arr = unnorm(tensor).permute(1, 2, 0).cpu().numpy()
        # Drop a trailing singleton channel so colormapped panels are plain 2-D.
        return arr[..., 0] if arr.shape[-1] == 1 else arr

    panels = [
        ("BF Input", bf, None),
        ("Red GT", red_gt, 'Reds_r'),
        ("Red Pred", red_pred, 'Reds_r'),
        ("Green GT", green_gt, 'Greens_r'),
        ("Green Pred", green_pred, 'Greens_r'),
    ]

    fig, axs = plt.subplots(1, 5, figsize=(20, 4))
    for ax, (title, tensor, cmap) in zip(axs, panels):
        if cmap is None:
            ax.imshow(_to_array(tensor))
        else:
            ax.imshow(_to_array(tensor), cmap=cmap, vmin=0.0, vmax=1.0)
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(fname)
    plt.close(fig)  # close this figure explicitly, not whichever is current

# Load best model + EMA weights
model_path = os.path.join(OUT_DIR, 'models', 'best_model.pt')
SAMPLES_OUT = os.path.join(OUT_DIR, 'samples', 'final_5_panel_color_v3')
os.makedirs(SAMPLES_OUT, exist_ok=True)

if os.path.exists(model_path):
    checkpoint = torch.load(model_path, map_location=device)
    unet.load_state_dict(checkpoint['model'])

    class SimpleEMA:
        """Copies stored EMA tensors over the matching model parameters."""
        def __init__(self, model):
            self.model = model
        def apply_shadow(self, shadow_dict):
            # Parameters without an entry in shadow_dict keep their loaded values.
            for name, param in self.model.named_parameters():
                if name in shadow_dict:
                    param.data.copy_(shadow_dict[name])

    ema_loader = SimpleEMA(unet)
    ema_loader.apply_shadow(checkpoint['ema'])
    unet.eval()
    print("Loaded EMA model for inference.")

    # The dataset interleaves red/green, so one panel consumes indices i and
    # i+1. Cap the number of panels at max_samples: every panel costs two
    # full reverse-diffusion runs.
    max_samples = 200
    n_panels = min(len(test_dataset) // 2, max_samples)
    indices_to_test = range(0, 2 * n_panels, 2)

    with torch.no_grad():
        for i in tqdm(indices_to_test, total=n_panels, desc="Generating Test Samples"):

            bf_r, tgt_r, cond_r = test_dataset[i]
            bf_g, tgt_g, cond_g = test_dataset[i+1]

            # Generate red + green separately
            pred_r = inference_loop(bf_r.unsqueeze(0), cond_r.unsqueeze(0), unet, scheduler, device)
            pred_g = inference_loop(bf_g.unsqueeze(0), cond_g.unsqueeze(0), unet, scheduler, device)

            out_path = os.path.join(SAMPLES_OUT, f"final_result_{i//2:04d}.png")
            save_color_panel(bf_r, tgt_r, pred_r, tgt_g, pred_g, out_path)

    print(f"\nInference complete. Saved to: {SAMPLES_OUT}")
else:
    print("Model not found — train first.")

print("\nAll done.")


**Inference Only**

In [None]:
"""
Inference-Only Script (No GT needed)
Generates Red & Green fluorescence from Brightfield images
for any images uploaded to /content/
"""

import os
from glob import glob
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torchvision.transforms as transforms
from diffusers import UNet2DModel, DDPMScheduler

# -------------------------
# USER INPUTS
# -------------------------
# NEW_FOLDER is scanned (non-recursively) for brightfield images; results are
# written below OUT_DIR; MODEL_PATH is the checkpoint written by the training
# cell (it holds the 'model' weights and, when present, the 'ema' weights).
NEW_FOLDER = "/content"  # folder with BF images
OUT_DIR = "/content/output_fluorescence"
MODEL_PATH = "/content/drive/MyDrive/mayank/conditional_diffusion_outputs_v4/models/best_model.pt"

os.makedirs(OUT_DIR, exist_ok=True)

# -------------------------
# DEVICE
# Use the GPU when torch can see one, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------
# LOAD MODEL
# -------------------------
# The architecture has to match the one used for training, otherwise the
# checkpoint's state dict will not load.
unet = UNet2DModel(
    sample_size=256,
    in_channels=6,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(96, 192, 384, 768),
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
)

scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="epsilon")

checkpoint = torch.load(MODEL_PATH, map_location=device)
unet.load_state_dict(checkpoint['model'])
unet.to(device)
unet.eval()
print("Loaded trained UNet model.")

# Apply EMA if present
if 'ema' in checkpoint:
    class SimpleEMA:
        """Copies stored EMA tensors over the matching model parameters."""

        def __init__(self, model):
            self.model = model

        def apply_shadow(self, shadow_dict):
            for name, param in self.model.named_parameters():
                if name not in shadow_dict:
                    continue  # no averaged value stored; keep loaded weight
                param.data.copy_(shadow_dict[name])

    ema_loader = SimpleEMA(unet)
    ema_loader.apply_shadow(checkpoint['ema'])
    print("Applied EMA weights.")

# -------------------------
# IMAGE TRANSFORMS
# -------------------------
# Resize to the 256x256 grid the UNet was configured with, convert to a
# [0, 1] tensor, then shift/scale with mean 0.5 and std 0.5 (-> [-1, 1]).
tf = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # BF is RGB; a single mean/std value is given for all three channels.
    # NOTE(review): relies on Normalize broadcasting a length-1 mean/std over
    # the channel axis — confirm for the installed torchvision version.
    transforms.Normalize([0.5], [0.5])   # BF is RGB
])

# -------------------------
# UTILS
# -------------------------
def unnorm(x):
    """Clip a tensor to [-1, 1], then rescale it linearly to [0, 1]."""
    clipped = x.clamp(-1, 1)
    return (clipped + 1) / 2

def inference_loop(bf_img, condition_vec, model, sched, num_inference_steps=None):
    """Runs full DDPM reverse pass for one brightfield image.

    Sampling runs on the module-level ``device``.

    Args:
        bf_img: Brightfield batch of shape (1, 3, H, W).
        condition_vec: Channel selector of shape (1, 2).
        model: Noise-prediction network; called as ``model(x, t).sample``.
        sched: Scheduler providing ``set_timesteps``, ``timesteps`` and ``step``.
        num_inference_steps: Number of reverse steps. Defaults to the
            scheduler's ``config.num_train_timesteps``.

    Returns:
        The generated image of shape (1, H, W), i.e. the first batch element.
    """
    bf_img = bf_img.to(device)
    # Take the spatial size from the input instead of assuming 256x256.
    height, width = bf_img.shape[-2:]
    cond_map = condition_vec[:, :, None, None].expand(-1, -1, height, width).to(device)

    # Start from noise. Sampled on CPU and then moved, so a seeded run draws
    # from the same generator regardless of the target device.
    img = torch.randn((1, 1, height, width)).to(device)

    if num_inference_steps is None:
        num_inference_steps = sched.config.num_train_timesteps
    sched.set_timesteps(num_inference_steps)

    with torch.no_grad():
        for t in sched.timesteps:
            model_input = torch.cat([img, bf_img, cond_map], dim=1)
            noise_pred = model(model_input, t).sample
            img = sched.step(noise_pred, t, img).prev_sample
    return img[0]

# -------------------------
# LOAD BF IMAGES
# -------------------------
def load_bf_images(folder):
    """Return the sorted paths of the image files directly inside `folder`.

    A file qualifies when its name ends in one of the known image extensions,
    compared case-insensitively so that e.g. ``.PNG`` and ``.TIF`` are found
    too. Names starting with a dot and sub-directories are skipped. A folder
    that does not exist yields an empty list.
    """
    IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp")
    if not os.path.isdir(folder):
        return []
    return sorted(
        path
        for path in (
            os.path.join(folder, entry)
            for entry in os.listdir(folder)
            if not entry.startswith(".") and entry.lower().endswith(IMAGE_EXTS)
        )
        if os.path.isfile(path)
    )

bf_images = load_bf_images(NEW_FOLDER)
print(f"Found {len(bf_images)} brightfield images.")

# -------------------------
# RUN INFERENCE
# -------------------------
SAVE_DIR = os.path.join(OUT_DIR, "generated")
os.makedirs(SAVE_DIR, exist_ok=True)

# One-hot conditions for Red and Green (constant, so built once)
cond_r = torch.tensor([[1.0, 0.0]], device=device)
cond_g = torch.tensor([[0.0, 1.0]], device=device)

for idx, bf_path in enumerate(tqdm(bf_images, desc="Generating Fluorescence")):
    # Load BF image; convert() returns a loaded copy, so the file can be closed.
    try:
        with Image.open(bf_path) as im:
            bf = im.convert("RGB")
    except OSError as exc:
        # One unreadable file should not abort the whole batch.
        print(f"Skipping {bf_path}: {exc}")
        continue
    bf_t = tf(bf).unsqueeze(0).to(device)

    # Predict Red & Green channels
    pred_r = inference_loop(bf_t, cond_r, unet, scheduler)
    pred_g = inference_loop(bf_t, cond_g, unet, scheduler)

    # Save 3-panel image (BF + Red Pred + Green Pred)
    bf_unnorm = unnorm(bf_t[0]).permute(1,2,0).cpu().numpy()
    r_pred = unnorm(pred_r).squeeze().cpu().numpy()
    g_pred = unnorm(pred_g).squeeze().cpu().numpy()

    fig, axs = plt.subplots(1, 3, figsize=(15,5))
    axs[0].imshow(bf_unnorm); axs[0].set_title("BF")
    # Fixed 0..1 color range: without it imshow stretches every prediction to
    # its own min/max, which hides intensity differences between images.
    axs[1].imshow(r_pred, cmap="Reds_r", vmin=0.0, vmax=1.0); axs[1].set_title("Red Pred")
    axs[2].imshow(g_pred, cmap="Greens_r", vmin=0.0, vmax=1.0); axs[2].set_title("Green Pred")
    for a in axs: a.axis("off")
    fig.tight_layout()
    fig.savefig(os.path.join(SAVE_DIR, f"result_{idx:04d}.png"))
    plt.close(fig)

print(f"\nDONE! All predicted fluorescence saved in:\n{SAVE_DIR}")
#

Loaded trained UNet model.
Applied EMA weights.
Found 3 brightfield images.


Generating Fluorescence:   0%|          | 0/3 [00:00<?, ?it/s]