In [None]:
# Mount Google Drive so the CALAMITI project folder is reachable under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive/')

In [None]:
# Work from the project checkout on Drive (IPython magic; changes the kernel's cwd).
%cd /content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti



In [None]:
!pip install nibabel tqdm
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118  # For GPU


In [None]:
from zipfile import ZipFile
import os

# Unpack the uploaded project archive under /content (not Drive) and work from there.
uploaded_path = '/content/neuroimage_2021_calamiti.zip'
extract_dir = '/content/calamiti'

with ZipFile(uploaded_path, mode='r') as archive:
    archive.extractall(path=extract_dir)

os.chdir(extract_dir)


In [None]:
# Switch the kernel's working directory to the extracted archive.
import os
os.chdir("/content/calamiti")


In [None]:
# Reinstall torch/torchvision from the CUDA 11.8 wheel index.
# NOTE(review): unpinned -- this resolves to whatever is newest there, which is why the
# following cells uninstall torchvision and reinstall a pinned 0.15.2 build.
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118

In [None]:
!pip uninstall -y torchvision


In [None]:
# Pinned torchvision build for CUDA 11.8 (installed together with torch 2.0.1 in a later setup cell).
!pip install torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118


In [None]:
# Make the project's modules (multiorientation, network, utils, fusion) importable.
import sys
sys.path.append('/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules')


In [None]:
# Kill the kernel process with signal 9 (SIGKILL), presumably so Colab restarts the runtime
# and the packages installed above are loaded fresh.
# NOTE(review): all in-memory state is lost here, including the sys.path entry appended in the
# previous cell; cells after this point must re-add it before importing project modules.
import os
os.kill(os.getpid(), 9)


In [None]:
# Pin numpy to 1.24.4 (the same pin is repeated in a later setup cell).
!pip install numpy==1.24.4


In [None]:
import sys

# The kernel kill above wipes earlier sys.path additions, so make the project modules
# importable here; guarded so re-running does not add duplicate entries.
MODULES_DIR = '/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules'
if MODULES_DIR not in sys.path:
    sys.path.append(MODULES_DIR)

from multiorientation import MultiOrientationImages

dataset_path = '/content/drive/MyDrive/CALAMITI_Project/sample_dataset'
dataset = MultiOrientationImages(dataset_path, data_name='T1w', mode='train')

# Smoke-test one training sample. Samples may come back as None (the training code filters
# None out before collating), so check before indexing into it.
sample = dataset[0]
if sample is None:
    print("Sample 0 is None; check dataset_path and the T1w files.")
else:
    print(sample['input'].shape)   # Expecting [1, 3, H, W]
    print(sample['target'].shape)  # Expecting [1, H, W]


In [None]:
from multiorientation import MultiOrientationImages

# Same smoke test as the previous cell: build the T1w training split and report the
# tensor shapes of its first sample.
dataset_path = '/content/drive/MyDrive/CALAMITI_Project/sample_dataset'
dataset = MultiOrientationImages(dataset_path, data_name='T1w', mode='train')

sample = dataset[0]
# Expected shapes: input -> [1, 3, H, W], target -> [1, H, W]
for key in ('input', 'target'):
    print(sample[key].shape)


In [None]:
# Mount Google Drive again (presumably re-run after the runtime restart above).
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# Re-add the project modules directory to sys.path (unguarded: re-running appends a duplicate).
import sys
sys.path.append('/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules')


In [None]:
# Pinned dependency set: torch 2.0.1 / torchvision 0.15.2 (CUDA 11.8 wheels) and numpy 1.24.4.
!pip install nibabel tqdm
!pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
!pip install numpy==1.24.4


In [None]:
# Run fusion.py as a standalone script in a subprocess (inherits the kernel's current cwd).
!python /content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py


In [None]:
# Work from the modules directory so relative paths such as ./output resolve there.
%cd /content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/


In [None]:
!python fusion.py


In [None]:
# NOTE(review): permanently deletes code/modules/__init__.py on Drive without a prompt --
# presumably so the files are imported as top-level modules rather than as a package.
# Confirm nothing else relies on the package form before re-running.
!rm -f /content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/__init__.py


In [None]:
import sys
import importlib.util

fusion_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"


def _load_module_from_path(module_name, source_path):
    """Execute the Python file at `source_path` as module `module_name` and return it.

    The module is registered in sys.modules before its code runs, following the
    importlib "importing a source file directly" recipe.
    """
    module_spec = importlib.util.spec_from_file_location(module_name, source_path)
    new_module = importlib.util.module_from_spec(module_spec)
    sys.modules[module_name] = new_module
    module_spec.loader.exec_module(new_module)
    return new_module


fusion = _load_module_from_path("fusion", fusion_path)

# Kick off training with the settings hard-coded in fusion.run_training()
fusion.run_training()


In [None]:
# Forget the previously registered "fusion" module so the next import re-executes fusion.py.
# (Relies on `sys` having been imported by an earlier cell.)
if "fusion" in sys.modules:
    del sys.modules["fusion"]


In [None]:
# Inspect the first 20 lines of fusion.py on Drive.
!head -n 20 /content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py


In [None]:
!python /content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py


In [None]:
%cd /content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules
!ls


In [None]:
# Append a marker print to the end of fusion.py.
file_path = '/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py'
marker = '\nprint("Fusion script updated!")\n'

# Read contents
with open(file_path, 'r') as file:
    content = file.read()

# Append only once; without this check every re-run adds another copy of the print.
if marker not in content:
    content += marker
    # Write it back
    with open(file_path, 'w') as file:
        file.write(content)


In [None]:
# Regenerate fusion.py with a self-contained FusionNetwork trainer.
# file_path is defined here rather than inherited from an earlier cell, so this cell also
# works on a fresh kernel.
file_path = '/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py'

# Raw string: escape sequences in the generated source (e.g. the newline escape at the end of
# the training-log f-string) must reach fusion.py verbatim. In a plain string literal Python
# turns them into real newlines, which leaves an unterminated f-string in fusion.py and makes
# the module fail to import with a SyntaxError.
new_code = r'''
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from tqdm import tqdm
import numpy as np

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
import nibabel as nib

# custom_collate is defined at the bottom of this module; importing it from
# multiorientation as well would only be shadowed by that local definition.
from multiorientation import MultiOrientationImages
from network import FusionNet
from utils import mkdir_p


class FusionNetwork:
    """Trainer / inference wrapper around a FusionNet model.

    Args:
        pretrained_model: optional checkpoint path written by save_model(). Its
            'fusion_net' weights are loaded here; initialize_training() restores the
            optimizer state and resumes after the saved epoch.
        gpu: use the default CUDA device when gpu >= 0 and CUDA is available, else CPU.
        data_path, data_name, batch_size: stored on the instance for callers;
            load_dataset() takes its own arguments.
    """

    def __init__(self, pretrained_model=None, gpu=-1, data_path="", data_name="", batch_size=1):
        self.data_path = data_path
        self.data_name = data_name
        self.batch_size = batch_size
        self.pretrained_model = pretrained_model

        self.gpu = gpu
        self.device = torch.device("cuda" if gpu >= 0 and torch.cuda.is_available() else "cpu")

        self.fusion_net = FusionNet(in_ch=1, out_ch=1).to(self.device)
        self.checkpoint = None
        self.start_epoch = 0

        if self.pretrained_model is not None:
            self.checkpoint = torch.load(self.pretrained_model, map_location=self.device)
            self.fusion_net.load_state_dict(self.checkpoint['fusion_net'])

    def load_dataset(self, dataset_dir, data_name, batch_size):
        """Build the train/valid datasets and their loaders.

        A loader is set to None (with a warning) when its split has no samples.
        """
        self.train_dataset = MultiOrientationImages(dataset_dir=dataset_dir, data_name=data_name, mode="train")
        self.valid_dataset = MultiOrientationImages(dataset_dir=dataset_dir, data_name=data_name, mode="valid")

        if len(self.train_dataset) == 0:
            print("[WARNING] No training samples found. Skipping training.")
            self.train_loader = None
        else:
            self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=custom_collate)

        if len(self.valid_dataset) == 0:
            print("[WARNING] No validation samples found. Skipping validation.")
            self.valid_loader = None
        else:
            # No shuffling for validation: the averaged loss does not need it, and a fixed
            # order makes "first validation batch" visualizations reproducible.
            self.valid_loader = DataLoader(self.valid_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=custom_collate)

        print("Train dataset path:", self.train_dataset.data_path)
        print("Valid dataset path:", self.valid_dataset.data_path)

    def initialize_training(self, out_dir, lr):
        """Create the output folders, the L1 loss and the AdamW optimizer.

        When a checkpoint was loaded, the optimizer state is restored and training
        resumes at the epoch after the saved one.
        """
        self.out_dir = out_dir
        mkdir_p(self.out_dir)
        mkdir_p(os.path.join(out_dir, 'results'))
        mkdir_p(os.path.join(out_dir, 'models'))

        self.l1_loss = nn.L1Loss(reduction='none')
        self.optim_fusion_net = torch.optim.AdamW(self.fusion_net.parameters(), lr=lr)

        if self.checkpoint is not None:
            self.start_epoch = self.checkpoint['epoch']
            self.optim_fusion_net.load_state_dict(self.checkpoint['optim_fusion_net'])

        self.start_epoch += 1

    def train(self, epochs):
        """Train for epochs start_epoch..epochs (inclusive), validating after each one.

        Every epoch saves <out_dir>/models/fusion_epochNNN.pth and appends the line
        "epoch,avg_train_loss,avg_valid_loss" to <out_dir>/training_log.txt.
        Requires load_dataset() and initialize_training() to have been called.
        """
        if self.train_loader is None:
            print("[SKIP] Training skipped due to empty dataset.")
            return

        log_file = os.path.join(self.out_dir, 'training_log.txt')

        for epoch in range(self.start_epoch, epochs + 1):
            self.fusion_net.train()
            train_loss_sum = 0.0
            num_train_imgs = 0

            for batch_id, batch in tqdm(enumerate(self.train_loader), total=len(self.train_loader), desc=f"Epoch {epoch} [Train]"):
                if batch is None or not all(k in batch for k in ['input', 'target']):
                    print(f"[SKIP] Missing data in batch {batch_id}")
                    continue

                input_tensor = batch["input"].to(self.device)  # [B, 1, 3, H, W]
                ori_img = batch["target"].to(self.device)     # [B, 1, H, W]

                curr_batch_size = ori_img.size(0)

                self.optim_fusion_net.zero_grad()
                syn_img = self.fusion_net(input_tensor)
                loss = self.cal_loss(syn_img, ori_img)
                loss.backward()
                self.optim_fusion_net.step()

                train_loss_sum += loss.item() * curr_batch_size
                num_train_imgs += curr_batch_size

            if num_train_imgs > 0:
                avg_train_loss = train_loss_sum / num_train_imgs
            else:
                # Every batch was skipped (None or missing keys): avoid ZeroDivisionError and
                # log 0.0, matching what the validation branch does for an empty pass.
                print(f"[WARNING] Epoch {epoch}: no usable training batches.")
                avg_train_loss = 0.0
            print(f"[Epoch {epoch}] Avg Train Loss: {avg_train_loss:.4f}")

            if self.valid_loader is None:
                print("[SKIP] Validation skipped due to empty dataset.")
                avg_valid_loss = 0.0
            else:
                self.fusion_net.eval()
                valid_loss_sum = 0.0
                num_valid_imgs = 0

                with torch.no_grad():
                    for batch_id, batch in tqdm(enumerate(self.valid_loader), total=len(self.valid_loader), desc=f"Epoch {epoch} [Valid]"):
                        if batch is None or not all(k in batch for k in ['input', 'target']):
                            print(f"[SKIP] Missing data in batch {batch_id}")
                            continue

                        input_tensor = batch["input"].to(self.device)
                        ori_img = batch["target"].to(self.device)
                        curr_batch_size = ori_img.size(0)

                        syn_img = self.fusion_net(input_tensor)
                        loss = self.cal_loss(syn_img, ori_img)

                        valid_loss_sum += loss.item() * curr_batch_size
                        num_valid_imgs += curr_batch_size

                if num_valid_imgs > 0:
                    avg_valid_loss = valid_loss_sum / num_valid_imgs
                else:
                    avg_valid_loss = 0.0

                print(f"[Epoch {epoch}] Avg Valid Loss: {avg_valid_loss:.4f}")

            checkpoint_path = os.path.join(self.out_dir, 'models', f"fusion_epoch{epoch:03d}.pth")
            self.save_model(checkpoint_path, epoch)

            with open(log_file, 'a') as f:
                f.write(f"{epoch},{avg_train_loss:.6f},{avg_valid_loss:.6f}\n")

    def save_model(self, file_name, epoch):
        """Save epoch number, network weights and optimizer state to `file_name`."""
        state = {
            'epoch': epoch,
            'fusion_net': self.fusion_net.state_dict(),
            'optim_fusion_net': self.optim_fusion_net.state_dict()
        }
        torch.save(state, file_name)
        print(f"[SAVED] Model checkpoint saved at {file_name}")

    def test(self, imgs, out_dir, prefix, img_affine, img_hdr, norm=1000):
        """Fuse `imgs` and save the result as <out_dir>/<prefix>_fusion.nii.gz.

        `imgs` is a sequence of tensors concatenated along dim 1; the output is
        rescaled by norm / 0.25 before saving.
        NOTE(review): assumes the squeezed output is 3-D (it is permuted (1, 2, 0) and
        then (1, 0, 2)) and that 0.25 matches the training-time intensity scaling -- confirm.
        """
        self.fusion_net.eval()
        with torch.no_grad():
            imgs = tuple([img.to(self.device) for img in imgs])
            imgs = torch.cat(imgs, dim=1)
            fuse_img = self.fusion_net(imgs)

            img_save = np.array(fuse_img.cpu().squeeze().permute(1, 2, 0).permute(1, 0, 2))
            img_save = img_save * ( norm / 0.25)
            img_save = nib.Nifti1Image(img_save, img_affine, img_hdr)
            file_name = os.path.join(out_dir, f'{prefix}_fusion.nii.gz')
            nib.save(img_save, file_name)

    def cal_loss(self, syn_img, ori_img):
        """Mean absolute error: element-wise L1 averaged over all elements."""
        loss = self.l1_loss(syn_img, ori_img)
        return loss.mean()

def custom_collate(batch):
    """Drop None samples, then collate; return None when nothing is left."""
    batch = [sample for sample in batch if sample is not None]
    if len(batch) == 0:
        return None
    return default_collate(batch)

def run_training():
    """Train a FusionNetwork with the hard-coded settings in Args and return it.

    Returning the network lets notebook callers inspect or visualize the trained model
    (e.g. fusion_model = run_training()).
    """
    class Args:
        dataset_dir = "/content/drive/MyDrive/CALAMITI_Project/sample_dataset/slices/default"
        modality = "T1w"
        output_dir = "./output"
        batch_size = 1
        epochs = 20
        lr = 0.0001
        gpu = 0 if torch.cuda.is_available() else -1
        mode = 'train'
        checkpoint = None

    args = Args()

    network = FusionNetwork(
        pretrained_model=args.checkpoint,
        gpu=args.gpu,
        data_path=args.dataset_dir,
        data_name=args.modality,
        batch_size=args.batch_size
    )

    network.load_dataset(args.dataset_dir, args.modality, args.batch_size)
    network.initialize_training(args.output_dir, lr=args.lr)

    if args.mode == 'train':
        network.train(epochs=args.epochs)
    else:
        print("[SKIP] Only training mode is implemented.")

    return network

#no entry point, we'll import and call run_training() manually
'''

with open(file_path, 'w') as f:
    f.write(new_code)


In [None]:
# Sanity-check the header of the freshly written fusion.py.
!head -n 15 /content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py


In [None]:
import sys
import importlib.util

# Remove cached version
if "fusion" in sys.modules:
    del sys.modules["fusion"]

fusion_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"

spec = importlib.util.spec_from_file_location("fusion", fusion_path)
fusion = importlib.util.module_from_spec(spec)
sys.modules["fusion"] = fusion
spec.loader.exec_module(fusion)

# The fusion.py written by this notebook has no main() (it ends with "no entry point");
# its entry point is run_training(). Prefer main() only if a version of the file defines one,
# instead of failing with AttributeError.
entry_point = getattr(fusion, "main", None) or fusion.run_training
entry_point()


In [None]:
import sys
import importlib.util

# Clean import (no cache): drop any previously registered "fusion" module first.
# `sys` is imported in this cell so it also runs on a fresh kernel.
# NOTE: this does not evict cached network / multiorientation / utils modules.
sys.modules.pop("fusion", None)

fusion_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"
spec = importlib.util.spec_from_file_location("fusion", fusion_path)
fusion = importlib.util.module_from_spec(spec)
sys.modules["fusion"] = fusion
spec.loader.exec_module(fusion)

# ✅ Run the training function
fusion.run_training()


In [None]:
import sys, importlib.util

# ✅ Drop cached copies of the project modules so the next import re-reads them from disk
# (names that were never imported are simply skipped)
for mod in ["fusion", "network", "multiorientation", "utils"]:
    if mod in sys.modules:
        del sys.modules[mod]


In [None]:
# Put the modules directory on sys.path once (guarded, so re-running adds no duplicates).
import sys
module_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules"
if module_path not in sys.path:
    sys.path.append(module_path)



In [None]:
# Load fusion.py straight from its file path and register it as "fusion".
# (Uses `sys` imported in an earlier cell.)
import importlib.util

fusion_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"

spec = importlib.util.spec_from_file_location("fusion", fusion_path)
fusion = importlib.util.module_from_spec(spec)
sys.modules["fusion"] = fusion
spec.loader.exec_module(fusion)


In [None]:
# Show the source of FusionNet as seen through the fusion module (fusion.py imports it
# from network.py).
import inspect
print(inspect.getsource(fusion.FusionNet))



In [None]:
# Train with the settings hard-coded inside run_training().
fusion.run_training()


In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
import importlib.util
import sys

# Re-execute fusion.py from Drive and register it under the name "fusion".
FUSION_FILE = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"
fusion_spec = importlib.util.spec_from_file_location("fusion", FUSION_FILE)
fusion = importlib.util.module_from_spec(fusion_spec)
sys.modules["fusion"] = fusion
fusion_spec.loader.exec_module(fusion)

# 🔁 Rebuild the network around the epoch-20 checkpoint (path relative to the kernel's cwd).
# NOTE(review): data_path here is .../neuroimage_2021_calamiti/sample_dataset/..., whereas
# run_training() trains on .../CALAMITI_Project/sample_dataset/... -- confirm which is intended.
use_gpu = torch.cuda.is_available()
network = fusion.FusionNetwork(
    pretrained_model="./output/models/fusion_epoch020.pth",
    gpu=0 if use_gpu else -1,
    data_path="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default",
    data_name="T1w",
    batch_size=1,
)

network.load_dataset(network.data_path, network.data_name, batch_size=1)


Trained fusion output on T1w (one validation batch)

In [None]:
import matplotlib.pyplot as plt
import torch

# Fail with a clear message instead of "'NoneType' object is not iterable" when the
# validation split is empty (FusionNetwork.load_dataset sets valid_loader to None then).
if network.valid_loader is None:
    raise RuntimeError("network.valid_loader is None: no validation samples found, nothing to visualize.")

# Get one batch
val_batch = next(iter(network.valid_loader))

# Move tensors to GPU/CPU
# NOTE(review): FusionNetwork.train() feeds batch["input"] to the model; this cell instead
# stacks separate 'axial'/'sagittal'/'coronal' entries -- confirm the dataset provides them.
device = network.device
axial = val_batch['axial'].to(device)
sagittal = val_batch['sagittal'].to(device)
coronal = val_batch['coronal'].to(device)
target = val_batch['target'].to(device)

# Create input tensor and run inference
input_tensor = torch.cat([axial, sagittal, coronal], dim=1)
network.fusion_net.eval()
with torch.no_grad():
    fused_output = network.fusion_net(input_tensor)

# 🖼️ Visualize channel 0 of the first sample
fig, (ax_fused, ax_truth) = plt.subplots(1, 2, figsize=(12, 4))

ax_fused.imshow(fused_output[0, 0].cpu().numpy(), cmap='gray')
ax_fused.set_title("Fused Output")
ax_fused.axis("off")

ax_truth.imshow(target[0, 0].cpu().numpy(), cmap='gray')
ax_truth.set_title("Ground Truth")
ax_truth.axis("off")

plt.show()


In [None]:
import matplotlib.pyplot as plt
import torch

# Set model to evaluation mode
network.fusion_net.eval()

# An empty validation split leaves valid_loader as None; say so explicitly.
if network.valid_loader is None:
    raise RuntimeError("network.valid_loader is None: no validation samples found, nothing to visualize.")

# Get one batch
val_batch = next(iter(network.valid_loader))
device = network.device

# Extract and move inputs to device
axial = val_batch['axial'].to(device)
sagittal = val_batch['sagittal'].to(device)
coronal = val_batch['coronal'].to(device)
target = val_batch['target'].to(device)

# Fuse inputs
input_tensor = torch.cat([axial, sagittal, coronal], dim=1)
with torch.no_grad():
    fused_output = network.fusion_net(input_tensor)

# (title, tensor) per panel, left to right; each shows channel 0 of the first sample.
panels = [
    ("Axial View", axial),
    ("Sagittal View", sagittal),
    ("Coronal View", coronal),
    ("Fused Output", fused_output),
    ("Ground Truth", target),
]

fig, axes = plt.subplots(1, len(panels), figsize=(15, 4))
for ax, (title, tensor) in zip(axes, panels):
    ax.imshow(tensor[0, 0].cpu().numpy(), cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# module_from_spec() sets __file__ from the spec's origin for file-based specs, so the
# module itself reports the path it was loaded from (no need for the fusion_path global).
print(fusion.__file__)


In [None]:
# Make the modules directory importable (unguarded append: re-running adds a duplicate entry).
import sys
sys.path.append("/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules")



In [None]:
# Import the project pieces fusion.py depends on (presumably to check they resolve from sys.path).
from multiorientation import MultiOrientationImages, custom_collate
from network import FusionNet
from utils import mkdir_p


In [None]:
import importlib.util

# Path to the training module on Drive.
FUSION_SOURCE = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"

# Build, register, then execute the module (uses `sys` imported in an earlier cell).
fusion_spec = importlib.util.spec_from_file_location("fusion", FUSION_SOURCE)
fusion = importlib.util.module_from_spec(fusion_spec)
sys.modules.update({"fusion": fusion})
fusion_spec.loader.exec_module(fusion)

fusion.run_training()


In [None]:
import matplotlib.pyplot as plt
import torch

# `fusion` is the imported module, not a trained model: it has no fusion_net, valid_loader or
# device attributes, and the training cell above discards run_training()'s return value.
# Rebuild a FusionNetwork from the final checkpoint run_training() saves (Args.epochs = 20,
# Args.output_dir = "./output") on the same dataset_dir it trained on.
eval_network = fusion.FusionNetwork(
    pretrained_model="./output/models/fusion_epoch020.pth",
    gpu=0 if torch.cuda.is_available() else -1,
    data_path="/content/drive/MyDrive/CALAMITI_Project/sample_dataset/slices/default",
    data_name="T1w",
    batch_size=1,
)
eval_network.load_dataset(eval_network.data_path, eval_network.data_name, batch_size=1)
if eval_network.valid_loader is None:
    raise RuntimeError("No validation samples found; nothing to visualize.")

# Set your model to evaluation mode
eval_network.fusion_net.eval()

# Get one batch from validation set
sample_batch = next(iter(eval_network.valid_loader))

# Move inputs to the same device as the model
axial = sample_batch['axial'].to(eval_network.device)
sagittal = sample_batch['sagittal'].to(eval_network.device)
coronal = sample_batch['coronal'].to(eval_network.device)
ori_img = sample_batch['target'].to(eval_network.device)

# Concatenate multi-orientation input
input_tensor = torch.cat([axial, sagittal, coronal], dim=1)

# Run model
with torch.no_grad():
    syn_img = eval_network.fusion_net(input_tensor)

# Plot fused vs. original for the first sample in the batch
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(syn_img[0, 0].cpu().numpy(), cmap='gray')
axes[0].set_title("Fused Output")
axes[1].imshow(ori_img[0, 0].cpu().numpy(), cmap='gray')
axes[1].set_title("Ground Truth")
plt.tight_layout()
plt.show()


In [None]:
import importlib.util
import sys

import torch

# Load the fusion module
spec = importlib.util.spec_from_file_location("fusion", "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py")
fusion_module = importlib.util.module_from_spec(spec)
sys.modules["fusion"] = fusion_module
spec.loader.exec_module(fusion_module)

# 🔁 Train and capture the model instance
fusion_model = fusion_module.run_training()

# run_training() yields None if the loaded fusion.py has no `return network`. Fall back to the
# final checkpoint it saves (./output, 20 epochs) so the plotting cell below has a model.
if fusion_model is None:
    fusion_model = fusion_module.FusionNetwork(
        pretrained_model="./output/models/fusion_epoch020.pth",
        gpu=0 if torch.cuda.is_available() else -1,
        data_path="/content/drive/MyDrive/CALAMITI_Project/sample_dataset/slices/default",
        data_name="T1w",
        batch_size=1,
    )
    fusion_model.load_dataset(fusion_model.data_path, fusion_model.data_name, batch_size=1)


In [None]:
import matplotlib.pyplot as plt
import torch

# Run the freshly trained model on one validation batch and compare it with the ground truth.
fusion_model.fusion_net.eval()
sample_batch = next(iter(fusion_model.valid_loader))

model_device = fusion_model.device
axial, sagittal, coronal, ori_img = (
    sample_batch[name].to(model_device) for name in ('axial', 'sagittal', 'coronal', 'target')
)

input_tensor = torch.cat([axial, sagittal, coronal], dim=1)

with torch.no_grad():
    syn_img = fusion_model.fusion_net(input_tensor)

# Channel 0 of the first sample, prediction on the left.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, image, title in zip(axes, (syn_img, ori_img), ("Fused Output", "Ground Truth")):
    ax.imshow(image[0, 0].cpu().numpy(), cmap='gray')
    ax.set_title(title)
plt.tight_layout()
plt.show()


In [None]:
# Copy the epoch-20 checkpoint to the top of MyDrive (source path is relative to the kernel's cwd).
!cp ./output/models/fusion_epoch020.pth /content/drive/MyDrive/fusion_epoch020.pth


**FINE-TUNING**

In [None]:
import sys
import importlib.util

# ✅ Add module directory to Python path (unguarded: re-running appends a duplicate)
sys.path.append('/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules')

# 🔁 Load fusion module from its file and register it as "fusion"
spec = importlib.util.spec_from_file_location("fusion", "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py")
fusion = importlib.util.module_from_spec(spec)
sys.modules["fusion"] = fusion
spec.loader.exec_module(fusion)


In [None]:
from multiorientation import MultiOrientationImages, custom_collate
from network import FusionNet
from utils import mkdir_p


In [None]:
# Notebook-widget progress bars for loops in this notebook.
# NOTE(review): this rebinds `tqdm` only in the notebook namespace; modules that do their own
# `from tqdm import tqdm` (the fusion.py written above does) keep the console version.
from tqdm.notebook import tqdm


In [None]:
import importlib.util
import sys

sys.path.append('/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules')

spec = importlib.util.spec_from_file_location("fusion", "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py")
fusion = importlib.util.module_from_spec(spec)
sys.modules["fusion"] = fusion
spec.loader.exec_module(fusion)

# Run fine-tuning
# NOTE(review): run_finetuning() is not part of the fusion.py this notebook writes, so it must
# come from a later edit of that file. The checkpoint path points at neuroimage_2021_calamiti/output/,
# while run_training() saves under "./output" relative to the kernel's cwd (presumably
# code/modules after the %cd cells) -- confirm the file exists at this path.
trained_network = fusion.run_finetuning(
    checkpoint_path="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/output/models/fusion_epoch020.pth",
    fine_tune=True
)


In [None]:
# Same as the previous cell: reload fusion.py, then fine-tune again from the epoch-20 checkpoint.
import importlib.util
import sys

sys.path.append('/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules')

spec = importlib.util.spec_from_file_location("fusion", "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py")
fusion = importlib.util.module_from_spec(spec)
sys.modules["fusion"] = fusion
spec.loader.exec_module(fusion)

# Run fine-tuning
trained_network = fusion.run_finetuning(
    checkpoint_path="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/output/models/fusion_epoch020.pth",
    fine_tune=True
)


In [None]:
# Batch counts per split. FusionNetwork.load_dataset sets a loader to None when its split has
# no samples, and len(None) would raise TypeError, so report 0 in that case.
for split_name, split_loader in (("Train", trained_network.train_loader), ("Valid", trained_network.valid_loader)):
    print(f"{split_name} batches:", 0 if split_loader is None else len(split_loader))


In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm

# Rebuild a plain shuffled loader over the fine-tuned network's training split and report
# the first few collated batches (custom_collate may return None for a batch).
loader = DataLoader(
    trained_network.train_dataset,
    collate_fn=custom_collate,
    num_workers=0,
    shuffle=True,
    batch_size=1,
)

print("Checking manual batches:")
for i, batch in enumerate(tqdm(loader)):
    print(f"❌ Batch {i} is None" if batch is None else f"✅ Batch {i}: keys = {batch.keys()}")
    # Stop after batch index 5, i.e. six batches in total.
    if i > 4:
        break


In [None]:
# Peek at the first training batch (despite the name, `sample_loader` holds a batch) and print
# each entry's shape.
# NOTE(review): raises TypeError if train_loader is None (empty training split), and
# AttributeError if a batch value has no .shape (i.e. is not a tensor/array).
sample_loader = next(iter(trained_network.train_loader))

if sample_loader is None:
    print("❌ Batch is None")
else:
    print("✅ Keys in batch:", sample_loader.keys())
    for k, v in sample_loader.items():
        print(f"{k}: shape = {v.shape}")


In [None]:
# Same check on the raw dataset, bypassing the DataLoader/collate step.
ds = trained_network.train_dataset
sample = ds[0]

if sample is None:
    print("❌ Sample is None — data likely missing or failed to load")
else:
    print("✅ Sample loaded:", sample.keys())
    for k, v in sample.items():
        print(f"{k}: shape = {v.shape}")


In [None]:
import importlib.util
import sys

# Add the module directory to sys.path (NOTE: appended unconditionally -- re-running adds a duplicate)
sys.path.append('/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules')

# Reload fusion
spec = importlib.util.spec_from_file_location("fusion", "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py")
fusion = importlib.util.module_from_spec(spec)
sys.modules["fusion"] = fusion
spec.loader.exec_module(fusion)


In [None]:
# Fine-tune from the epoch-20 checkpoint and keep the returned network.
trained_network = fusion.run_finetuning(
    checkpoint_path="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/output/models/fusion_epoch020.pth",
    fine_tune=True
)


In [None]:
# NOTE(review): repeats the previous cell but discards the return value.
fusion.run_finetuning(
    checkpoint_path="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/output/models/fusion_epoch020.pth",
    fine_tune=True
)


In [None]:
import importlib.util
import sys

spec = importlib.util.spec_from_file_location("fusion", "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py")
fusion = importlib.util.module_from_spec(spec)
sys.modules["fusion"] = fusion
spec.loader.exec_module(fusion)

# 🚀 Now this should work
# Called without checkpoint_path, so run_finetuning()'s own default (not visible here) is used.
fusion.run_finetuning(fine_tune=True)


In [None]:
import matplotlib.pyplot as plt

# Load loss logs. FusionNetwork.train() appends one "epoch,train_loss,valid_loss" line per
# epoch; the fine-tuning log is presumably written in the same format -- confirm.
original_log = "./output/training_log.txt"
finetune_log = "./output_finetune/training_log.txt"


def load_loss(path):
    """Parse a training log into (epochs, train_losses, valid_losses) lists.

    Each non-blank line must be "epoch,train_loss,valid_loss". Blank lines are skipped
    instead of making the tuple unpacking raise ValueError.
    """
    epochs, train_losses, valid_losses = [], [], []
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            e, t, v = line.split(',')
            epochs.append(int(e))
            train_losses.append(float(t))
            valid_losses.append(float(v))
    return epochs, train_losses, valid_losses


e1, tr1, val1 = load_loss(original_log)
e2, tr2, val2 = load_loss(finetune_log)

fig, ax = plt.subplots()
ax.plot(e1, val1, label="Before Fine-Tuning")
ax.plot(e2, val2, label="After Fine-Tuning")
ax.set_xlabel("Epoch")
ax.set_ylabel("Validation Loss")
ax.set_title("Validation Loss Comparison")
ax.legend()
ax.grid(True)
plt.show()


In [None]:
import torch
import matplotlib.pyplot as plt

# NOTE(review): `network` is presumably the FusionNetwork rebuilt from
# ./output/models/fusion_epoch020.pth earlier in the notebook -- the pre-fine-tuning model,
# not `trained_network`. Confirm this is the model meant to be inspected here.

# 🔄 First validation batch
val_batch = next(iter(network.valid_loader))

# 🎯 Move the three views and the target onto the model's device
views = [val_batch[name].to(network.device) for name in ('axial', 'sagittal', 'coronal')]
axial, sagittal, coronal = views
target = val_batch['target'].to(network.device)

input_tensor = torch.cat(views, dim=1)

# 🔍 Inference without gradient tracking
network.fusion_net.eval()
with torch.no_grad():
    output = network.fusion_net(input_tensor)

# 📊 Side-by-side comparison (channel 0 of the first sample)
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
for panel, (image, title) in zip(axs, [(output, "Fused Output"), (target, "Ground Truth")]):
    panel.imshow(image[0, 0].cpu().numpy(), cmap='gray')
    panel.set_title(title)
plt.tight_layout()
plt.show()


In [None]:
# Reload fusion.py from disk and register it as "fusion".
import importlib.util
import sys

spec = importlib.util.spec_from_file_location("fusion", "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py")
fusion = importlib.util.module_from_spec(spec)
sys.modules["fusion"] = fusion
spec.loader.exec_module(fusion)


In [None]:
# `import fusion` returns the module object already registered in sys.modules above,
# rather than searching sys.path again.
import fusion
fusion.run_finetuning()


In [None]:
%cd /content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules

**CYCLEGAN**

In [None]:
import os

folder_path = '/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules'  # change this to your folder
file_path = os.path.join(folder_path, 'cyclegan_fusion.py')

# Source of cyclegan_fusion.py. This is a plain (non-raw) string, so escape sequences that must
# survive into the generated file are written doubled (the "\\n" in the epoch banner).
code = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm
import os

from network import FusionNet  # This will act as Generator G
from network import Unet       # This can be used for Generator F
from utils import mkdir_p
from multiorientation import MultiOrientationImages


def custom_collate(batch):
    """Drop None samples, then collate; return None when nothing is left.

    Plain default_collate raises on a None sample, which made the `batch is None`
    check in CycleGANFusionNetwork.train() unreachable.
    """
    batch = [sample for sample in batch if sample is not None]
    if len(batch) == 0:
        return None
    return default_collate(batch)


class Discriminator(nn.Module):
    """Conv(4, s2) -> LeakyReLU -> Conv(4, s2) -> InstanceNorm -> LeakyReLU -> Conv(4, s1) to 1 channel."""

    def __init__(self, in_ch):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_ch, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, 4, 1, 1),
        )

    def forward(self, x):
        return self.model(x)


class CycleGANFusionNetwork:
    """Cycle-consistent training of a view-fusion generator.

    G maps the stacked axial/sagittal/coronal views (3 channels) to a fused image
    (1 channel); F maps back. D_Y scores fused images, D_X scores stacked views.
    Adversarial terms use MSE against all-ones / all-zeros targets; cycle terms use L1
    weighted by lambda_cycle. Any gpu >= 0 selects CUDA when it is available.
    """

    def __init__(self, dataset_dir, modality, batch_size=1, lr=2e-4, lambda_cycle=10.0, gpu=0):
        self.device = torch.device("cuda" if torch.cuda.is_available() and gpu >= 0 else "cpu")

        # Generators
        self.G = FusionNet(in_ch=3, out_ch=1).to(self.device)      # Axial/Sagittal/Coronal -> Fused
        self.F = Unet(in_ch=1, out_ch=3).to(self.device)            # Fused -> Axial/Sagittal/Coronal

        # Discriminators
        self.D_Y = Discriminator(1).to(self.device)
        self.D_X = Discriminator(3).to(self.device)

        # Losses
        self.adv_loss = nn.MSELoss()
        self.cycle_loss = nn.L1Loss()

        # Optimizers
        self.opt_G = torch.optim.Adam(list(self.G.parameters()) + list(self.F.parameters()), lr=lr, betas=(0.5, 0.999))
        self.opt_D_Y = torch.optim.Adam(self.D_Y.parameters(), lr=lr, betas=(0.5, 0.999))
        self.opt_D_X = torch.optim.Adam(self.D_X.parameters(), lr=lr, betas=(0.5, 0.999))

        # Dataset (None samples are filtered out by custom_collate)
        self.train_dataset = MultiOrientationImages(dataset_dir=dataset_dir, data_name=modality, mode="train")
        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=custom_collate)
        self.lambda_cycle = lambda_cycle

    def train(self, epochs):
        """Alternate generator and discriminator updates for `epochs` passes over the data.

        After each epoch, prints the mean generator / D_Y / D_X losses over the usable
        batches. Batches that are None or lack any of the axial/sagittal/coronal/target
        keys are skipped.
        """
        for epoch in range(1, epochs + 1):
            print(f"\\nEpoch {epoch}/{epochs}")
            sum_loss_G = sum_loss_D_Y = sum_loss_D_X = 0.0
            num_batches = 0

            for batch in tqdm(self.train_loader, desc="Training"):
                if batch is None or not all(k in batch for k in ['axial', 'sagittal', 'coronal', 'target']):
                    continue

                real_X = torch.cat([batch['axial'], batch['sagittal'], batch['coronal']], dim=1).to(self.device)
                real_Y = batch['target'].to(self.device)

                # ----------------------
                # Train Generators
                # ----------------------
                self.opt_G.zero_grad()

                fake_Y = self.G(real_X)
                rec_X = self.F(fake_Y)

                fake_X = self.F(real_Y)
                rec_Y = self.G(fake_X)

                # Score each fake once and reuse that prediction for the target's shape,
                # rather than running the discriminator a second time just for ones_like.
                score_fake_Y = self.D_Y(fake_Y)
                score_fake_X = self.D_X(fake_X)
                loss_G_adv_Y = self.adv_loss(score_fake_Y, torch.ones_like(score_fake_Y))
                loss_G_adv_X = self.adv_loss(score_fake_X, torch.ones_like(score_fake_X))

                loss_cycle_X = self.cycle_loss(rec_X, real_X)
                loss_cycle_Y = self.cycle_loss(rec_Y, real_Y)

                total_loss_G = loss_G_adv_Y + loss_G_adv_X + self.lambda_cycle * (loss_cycle_X + loss_cycle_Y)
                total_loss_G.backward()
                self.opt_G.step()

                # ----------------------
                # Train Discriminators
                # ----------------------
                # zero_grad() also clears the gradients the generator pass left on D_Y / D_X.
                self.opt_D_Y.zero_grad()
                score_real_Y = self.D_Y(real_Y)
                score_fake_Y = self.D_Y(fake_Y.detach())
                loss_D_Y = 0.5 * (
                    self.adv_loss(score_real_Y, torch.ones_like(score_real_Y))
                    + self.adv_loss(score_fake_Y, torch.zeros_like(score_fake_Y))
                )
                loss_D_Y.backward()
                self.opt_D_Y.step()

                self.opt_D_X.zero_grad()
                score_real_X = self.D_X(real_X)
                score_fake_X = self.D_X(fake_X.detach())
                loss_D_X = 0.5 * (
                    self.adv_loss(score_real_X, torch.ones_like(score_real_X))
                    + self.adv_loss(score_fake_X, torch.zeros_like(score_fake_X))
                )
                loss_D_X.backward()
                self.opt_D_X.step()

                sum_loss_G += total_loss_G.item()
                sum_loss_D_Y += loss_D_Y.item()
                sum_loss_D_X += loss_D_X.item()
                num_batches += 1

            # Without a usable batch the loss variables would be undefined (NameError).
            if num_batches == 0:
                print(f"[Epoch {epoch}] No usable batches; skipped.")
                continue

            print(f"[Epoch {epoch}] Generator Loss: {sum_loss_G / num_batches:.4f}, D_Y Loss: {sum_loss_D_Y / num_batches:.4f}, D_X Loss: {sum_loss_D_X / num_batches:.4f}")
'''

with open(file_path, 'w') as f:
    f.write(code)


In [None]:
import sys
import importlib.util

# Step 1: Add the folder path to sys.path -- once, so re-running the cell adds no duplicates
module_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules"
if module_path not in sys.path:
    sys.path.append(module_path)

# Step 2: Load the cyclegan_fusion module
spec = importlib.util.spec_from_file_location("cyclegan_fusion", f"{module_path}/cyclegan_fusion.py")
cyclegan_fusion = importlib.util.module_from_spec(spec)
sys.modules["cyclegan_fusion"] = cyclegan_fusion
spec.loader.exec_module(cyclegan_fusion)


**Reload the `cyclegan_fusion` module from disk**

In [None]:
import importlib.util
import sys

# Re-execute cyclegan_fusion.py from disk so edits to the file take effect in this kernel.
CYCLEGAN_FILE = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/cyclegan_fusion.py"
cyclegan_spec = importlib.util.spec_from_file_location("cyclegan_fusion", CYCLEGAN_FILE)
cyclegan_fusion = importlib.util.module_from_spec(cyclegan_spec)
sys.modules.update({"cyclegan_fusion": cyclegan_fusion})
cyclegan_spec.loader.exec_module(cyclegan_fusion)


**Initialize and Train**

In [None]:
# Build the CycleGAN fusion trainer on the T1w slices.
# NOTE(review): the next cell immediately constructs a new `model` with the same settings, so
# this instance (and its dataset scan) is discarded.
model = cyclegan_fusion.CycleGANFusionNetwork(
    dataset_dir='/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default',
    modality='T1w',
    batch_size=1,
    lr=2e-4,
    lambda_cycle=10.0,
    gpu=0  # any value >= 0 uses CUDA when it is available, otherwise CPU
)


In [None]:
# Rebuild the trainer (lambda_cycle left at its default of 10.0) and train for 10 epochs.
model = cyclegan_fusion.CycleGANFusionNetwork(
    dataset_dir="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default",
    modality="T1w",
    batch_size=1,
    lr=0.0002,
    gpu=0
)

model.train(epochs=10)


In [None]:
import sys
import importlib.util

# Register the modules folder on the import path (idempotent across re-runs),
# then load cyclegan_fusion.py from that folder by explicit file path.
module_dir = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules"
cyclegan_path = f"{module_dir}/cyclegan_fusion.py"

if module_dir in sys.path:
    pass  # already importable
else:
    sys.path.append(module_dir)

spec = importlib.util.spec_from_file_location("cyclegan_fusion", cyclegan_path)
cyclegan_fusion = importlib.util.module_from_spec(spec)
# Register before executing so `import cyclegan_fusion` elsewhere reuses this object.
sys.modules["cyclegan_fusion"] = cyclegan_fusion
spec.loader.exec_module(cyclegan_fusion)


In [None]:
import os
import torch  # imported here so this cell does not depend on an earlier cell's import

# Persist both generators' weights (model.G and model.F) to Drive, so the
# inference cell can reload them without retraining.
models_dir = "/content/drive/MyDrive/CALAMITI_Project/output/models"
os.makedirs(models_dir, exist_ok=True)

torch.save(model.G.state_dict(), os.path.join(models_dir, "generator_G.pth"))
torch.save(model.F.state_dict(), os.path.join(models_dir, "generator_F.pth"))


In [None]:
import shutil
import os

# Copy every generated reconstruction (*.nii.gz) into the training T1w folder.
# NOTE(review): the evaluation cells read ground truth from this same folder
# using the pattern "*_axial_recon.nii.gz". The copied
# "*_axial_recon_axial_recon.nii.gz" files also match that pattern, so
# reconstructions end up mixed into the ground truth / training data.
# Confirm this is intended.
src_folder = "/content/drive/MyDrive/CALAMITI_Project/generated_reconstructions"
dst_folder = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default/train/T1w"

for file in os.listdir(src_folder):
    if not file.endswith(".nii.gz"):
        continue  # ignore anything that is not a compressed NIfTI volume
    shutil.copy(os.path.join(src_folder, file), os.path.join(dst_folder, file))


In [None]:
import sys

# Add the folder where network.py and multiorientation.py are stored.
# Append only once, so re-running this cell does not grow sys.path.
modules_dir = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules"
if modules_dir not in sys.path:
    sys.path.append(modules_dir)


In [None]:
from network import FusionNet, Unet
from multiorientation import MultiOrientationImages


In [None]:
import os
import torch
import nibabel as nib
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

from network import FusionNet, Unet
from multiorientation import MultiOrientationImages

# === CONFIGURATION ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_dir = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default"
modality = "T1w"
models_dir = "/content/drive/MyDrive/CALAMITI_Project/output/models"
output_dir = "/content/drive/MyDrive/CALAMITI_Project/generated_reconstructions"
os.makedirs(output_dir, exist_ok=True)

# Suffix used both by the slice file names and by the reconstruction names the
# evaluation cells expect: "<base>_axial_recon_axial_recon.nii.gz".
RECON_SUFFIX = "_axial_recon"

# === LOAD MODELS ===
# G maps the 3 stacked orientations to 1 channel; F maps 1 channel back to 3.
G = FusionNet(in_ch=3, out_ch=1).to(device)
F = Unet(in_ch=1, out_ch=3).to(device)

# map_location lets weights saved on a GPU runtime load on a CPU-only one.
G.load_state_dict(torch.load(os.path.join(models_dir, "generator_G.pth"), map_location=device))
F.load_state_dict(torch.load(os.path.join(models_dir, "generator_F.pth"), map_location=device))

G.eval()
F.eval()

# === HELPER TO SAVE .nii.gz ===
def tensor_to_nii(tensor, reference_nii_path, save_path):
    """Write ``tensor`` to ``save_path`` as NIfTI, reusing the reference file's affine and header.

    The tensor is detached, moved to CPU and squeezed (all size-1 axes removed)
    before saving.
    """
    data = tensor.detach().cpu().squeeze().numpy()
    ref = nib.load(reference_nii_path)
    img = nib.Nifti1Image(data, affine=ref.affine, header=ref.header)
    nib.save(img, save_path)

# === LOAD DATA ===
dataset = MultiOrientationImages(dataset_dir=dataset_dir, data_name=modality, mode="train")
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=default_collate)

# === RUN INFERENCE ===
with torch.no_grad():
    for batch in tqdm(loader, desc="Generating Cycle Reconstructions"):
        if batch is None or not all(k in batch for k in ['axial', 'sagittal', 'coronal', 'target']):
            continue

        real_X = torch.cat([batch['axial'], batch['sagittal'], batch['coronal']], dim=1).to(device)
        # NOTE(review): assumes the dataset also returns 'axial_path'; a KeyError here means it does not.
        real_X_path = batch['axial_path'][0]

        # A → B → A
        fake_Y = G(real_X)
        cycle_X = F(fake_Y)

        # Strip ".nii.gz" and any trailing "_axial_recon" before re-appending the
        # suffix twice. Without this, "<base>_axial_recon.nii.gz" inputs produce a
        # tripled suffix that no evaluation cell can find.
        filename = os.path.basename(real_X_path).replace('.nii.gz', '')
        if filename.endswith(RECON_SUFFIX):
            filename = filename[: -len(RECON_SUFFIX)]
        save_path = os.path.join(output_dir, f"{filename}{RECON_SUFFIX}{RECON_SUFFIX}.nii.gz")
        tensor_to_nii(cycle_X, real_X_path, save_path)



In [None]:
import os
import numpy as np
import nibabel as nib
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

# === CONFIGURATION ===
ground_truth_dir = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default/train/T1w"
reconstructed_dir = "/content/drive/MyDrive/CALAMITI_Project/generated_reconstructions"

threshold = 0.5  # Pixel threshold (applied after min-max normalization) to binarize
all_gt = []
all_pred = []

# === LIST ORIGINAL FILES ===
gt_files = sorted([f for f in os.listdir(ground_truth_dir) if f.endswith("_axial_recon.nii.gz")])

print(f"Found {len(gt_files)} ground truth axial slices.")

# === LOOP THROUGH FILES ===
for gt_file in tqdm(gt_files, desc="Generating confusion data"):
    try:
        gt_path = os.path.join(ground_truth_dir, gt_file)

        # Corresponding cycle-reconstructed file name
        base_name = gt_file.replace("_axial_recon.nii.gz", "")
        recon_name = f"{base_name}_axial_recon_axial_recon.nii.gz"
        recon_path = os.path.join(reconstructed_dir, recon_name)

        if not os.path.exists(recon_path):
            print(f"[WARNING] Missing reconstruction for {gt_file}")
            continue

        # Load images and drop singleton axes so their shapes can be compared.
        gt_img = np.squeeze(nib.load(gt_path).get_fdata())
        recon_img = np.squeeze(nib.load(recon_path).get_fdata())

        # 3-channel reconstructions: keep channel 0 (axial), matching the other evaluation cells.
        if recon_img.shape[0] == 3:
            recon_img = recon_img[0]

        # Pixel-wise comparison needs identical shapes. Otherwise the flattened
        # arrays misalign, or end up with different lengths after concatenation.
        if gt_img.shape != recon_img.shape:
            print(f"[WARNING] Shape mismatch for {gt_file}: {gt_img.shape} vs {recon_img.shape}, skipping")
            continue

        # Min-max normalize to [0, 1]; the epsilon guards constant images.
        gt_img = (gt_img - gt_img.min()) / (gt_img.max() - gt_img.min() + 1e-8)
        recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min() + 1e-8)

        # Binarize
        gt_bin = (gt_img > threshold).astype(np.uint8).flatten()
        recon_bin = (recon_img > threshold).astype(np.uint8).flatten()

        # Store
        all_gt.append(gt_bin)
        all_pred.append(recon_bin)

    except Exception as e:
        print(f"[ERROR] Processing {gt_file}: {e}")
        continue

if not all_gt:
    raise RuntimeError("No ground-truth/reconstruction pairs were loaded; check the directories and file names.")

# === STACK AND GENERATE CONFUSION MATRIX ===
all_gt = np.concatenate(all_gt)
all_pred = np.concatenate(all_pred)

print("Calculating confusion matrix...")
# labels=[0, 1] keeps the matrix 2x2 even when one class never occurs.
conf_matrix = confusion_matrix(all_gt, all_pred, labels=[0, 1])
print(conf_matrix)


In [None]:
# Count the files in the reconstruction output folder (ls prints one name per line; wc -l counts them).
!ls "/content/drive/MyDrive/CALAMITI_Project/generated_reconstructions" | wc -l


In [None]:
# Show the first 10 file names in the reconstruction folder to check the naming pattern.
!ls "/content/drive/MyDrive/CALAMITI_Project/generated_reconstructions" | head -10


In [None]:
import os

folder = "/content/drive/MyDrive/CALAMITI_Project/generated_reconstructions"

# Collapse the tripled "_axial_recon" suffix down to two copies, so the names
# match the "<base>_axial_recon_axial_recon.nii.gz" pattern the evaluation cells look up.
triple_suffix = "_axial_recon_axial_recon_axial_recon"
double_suffix = "_axial_recon_axial_recon"

for filename in os.listdir(folder):
    if not filename.endswith("_axial_recon_axial_recon_axial_recon.nii.gz"):
        continue
    new_filename = filename.replace(triple_suffix, double_suffix)
    os.rename(os.path.join(folder, filename), os.path.join(folder, new_filename))

print("✅ Renamed all files successfully.")


In [None]:
# Second pass over the ground-truth/reconstruction pairs, with empty-array and
# shape checks. Relies on gt_files, ground_truth_dir, reconstructed_dir and
# threshold from the confusion-data cell above.
# Start from fresh lists: that cell leaves all_gt/all_pred as concatenated
# ndarrays (which have no .append), and reusing them would double-count pixels.
all_gt = []
all_pred = []

for gt_file in tqdm(gt_files, desc="Generating confusion data"):
    try:
        gt_path = os.path.join(ground_truth_dir, gt_file)

        base_name = gt_file.replace("_axial_recon.nii.gz", "")
        recon_name = f"{base_name}_axial_recon_axial_recon.nii.gz"
        recon_path = os.path.join(reconstructed_dir, recon_name)

        if not os.path.exists(recon_path):
            print(f"[WARNING] Missing reconstruction for {gt_file}")
            continue

        # Load images, dropping singleton axes so shapes are comparable
        gt_img = np.squeeze(nib.load(gt_path).get_fdata())
        recon_img = np.squeeze(nib.load(recon_path).get_fdata())

        if recon_img.shape[0] == 3:
            recon_img = recon_img[0]  # keep the axial channel

        # Pixel-wise comparison is only meaningful when the shapes agree.
        if gt_img.shape != recon_img.shape:
            print(f"[WARNING] Shape mismatch for {gt_file}, skipping...")
            continue

        # Min-max normalize to [0, 1]
        gt_img = (gt_img - gt_img.min()) / (gt_img.max() - gt_img.min() + 1e-8)
        recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min() + 1e-8)

        # Binarize
        gt_bin = (gt_img > threshold).astype(np.uint8).flatten()
        recon_bin = (recon_img > threshold).astype(np.uint8).flatten()

        # ✅ Only add if both arrays are non-empty
        if gt_bin.size > 0 and recon_bin.size > 0:
            all_gt.append(gt_bin)
            all_pred.append(recon_bin)

    except Exception as e:
        print(f"[ERROR] Processing {gt_file}: {e}")
        continue


In [None]:
def shapes_match(gt_img, recon_img, gt_file):
    """Return True if ``gt_img`` and ``recon_img`` have identical shapes.

    On a mismatch, print a warning naming ``gt_file`` and return False. A bare
    ``continue`` at cell top level (outside any loop) is a SyntaxError, so the
    skip belongs at the call site inside the per-file loop::

        if not shapes_match(gt_img, recon_img, gt_file):
            continue
    """
    if gt_img.shape != recon_img.shape:
        print(f"[WARNING] Shape mismatch for {gt_file}, skipping...")
        return False
    return True


In [None]:
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

# Flatten every normalized reconstruction (scores) and ground-truth slice (labels)
# into 1-D pixel vectors. recon_images / gt_images come from a loading cell.
all_pred_prob = np.concatenate([recon.flatten() for recon in recon_images])
all_gt_flat = np.concatenate([gt.flatten() for gt in gt_images])

# precision_recall_curve needs binary labels, so threshold the normalized ground truth.
threshold_gt = 0.5
all_gt_flat = (all_gt_flat > threshold_gt).astype(np.uint8)

precision, recall, thresholds = precision_recall_curve(all_gt_flat, all_pred_prob)

plt.plot(recall, precision, marker='.')
plt.title('Precision-Recall Curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.grid()
plt.show()


In [None]:
from sklearn.metrics import auc, f1_score
from sklearn.metrics import precision_score, recall_score

# 1. AUC Calculation (trapezoidal area under the precision-recall points)
pr_auc = auc(recall, precision)
print(f"\n📈 AUC (Precision-Recall): {pr_auc:.4f}")

# 2. F1 Scores across thresholds.
# precision[i] / recall[i] from precision_recall_curve already describe the
# predictions `score >= thresholds[i]`, so F1 follows in closed form. The old
# approach re-scored every pixel once per threshold (O(pixels x thresholds)).
# The final (precision=1, recall=0) entry has no threshold and is dropped.
prec_at_t = precision[:-1]
rec_at_t = recall[:-1]
pr_sum = prec_at_t + rec_at_t
f1_scores = np.divide(2 * prec_at_t * rec_at_t, pr_sum,
                      out=np.zeros_like(pr_sum, dtype=float), where=pr_sum > 0)

best_index = int(np.argmax(f1_scores))
best_threshold = thresholds[best_index]
best_f1 = f1_scores[best_index]

print(f"🏆 Best Threshold: {best_threshold:.4f}")
print(f"🔥 Best F1 Score: {best_f1:.4f}")

# Optional: Precision & Recall at Best Threshold
preds_best = (all_pred_prob >= best_threshold).astype(int)
p = precision_score(all_gt_flat, preds_best)
r = recall_score(all_gt_flat, preds_best)

print(f"✔️ Precision at Best Threshold: {p:.4f}")
print(f"✔️ Recall at Best Threshold:    {r:.4f}")


In [None]:
# (Re)mount Google Drive so the /content/drive/MyDrive/... paths used below resolve.
from google.colab import drive
drive.mount('/content/drive')


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
from tqdm import tqdm
from sklearn.metrics import precision_recall_curve, confusion_matrix, ConfusionMatrixDisplay, f1_score, precision_score, recall_score, auc, roc_curve, roc_auc_score


In [None]:
# === CONFIGURATION ===
ground_truth_dir = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default/train/T1w"
reconstructed_dir = "/content/drive/MyDrive/CALAMITI_Project/generated_reconstructions"

# === LOAD FILES ===
gt_files = sorted([f for f in os.listdir(ground_truth_dir) if f.endswith("_axial_recon.nii.gz")])
recon_images = []
gt_images = []

for gt_file in tqdm(gt_files, desc="Loading GT and recon"):
    try:
        gt_path = os.path.join(ground_truth_dir, gt_file)
        base_name = gt_file.replace("_axial_recon.nii.gz", "")
        recon_name = f"{base_name}_axial_recon_axial_recon.nii.gz"
        recon_path = os.path.join(reconstructed_dir, recon_name)

        if not os.path.exists(recon_path):
            continue

        # Drop singleton axes so the two shapes can be compared directly.
        gt_img = np.squeeze(nib.load(gt_path).get_fdata())
        recon_img = np.squeeze(nib.load(recon_path).get_fdata())

        if recon_img.shape[0] == 3:
            recon_img = recon_img[0]  # use axial slice

        # Mismatched pairs would misalign pixels once everything is flattened
        # and concatenated downstream, so skip them explicitly.
        if gt_img.shape != recon_img.shape:
            print(f"[WARNING] Shape mismatch for {gt_file}: {gt_img.shape} vs {recon_img.shape}, skipping")
            continue

        # Normalize
        gt_img = (gt_img - gt_img.min()) / (gt_img.max() - gt_img.min() + 1e-8)
        recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min() + 1e-8)

        recon_images.append(recon_img)
        gt_images.append(gt_img)
    except Exception as e:
        # Report instead of silently dropping the file. A bare `except:` would
        # also swallow KeyboardInterrupt.
        print(f"[ERROR] Processing {gt_file}: {e}")
        continue


In [None]:
from sklearn.metrics import precision_recall_curve, confusion_matrix, ConfusionMatrixDisplay, f1_score, precision_score, recall_score


In [None]:
from sklearn.metrics import auc, roc_curve, roc_auc_score


In [None]:
# === Subsample full images before flattening ===
np.random.seed(42)
# choice(..., replace=False) raises if asked for more images than exist.
n_images = min(20, len(recon_images))
sample_indices = np.random.choice(len(recon_images), size=n_images, replace=False)

recon_sampled = [recon_images[i] for i in sample_indices]
gt_sampled = [gt_images[i] for i in sample_indices]

# Now flatten
all_pred_prob = np.concatenate([img.flatten() for img in recon_sampled])
all_gt_flat = np.concatenate([img.flatten() for img in gt_sampled])
assert all_pred_prob.shape == all_gt_flat.shape, "recon/GT pixel counts differ; check image shapes"

# Subsample 10% of all pixels from the sampled images
subsample_size = int(0.1 * len(all_gt_flat))
pixel_indices = np.random.choice(len(all_gt_flat), size=subsample_size, replace=False)

all_pred_prob = all_pred_prob[pixel_indices]
all_gt_flat = all_gt_flat[pixel_indices]
# Binarize the normalized ground truth into 0/1 labels for the curve metrics.
all_gt_flat = (all_gt_flat > 0.5).astype(np.uint8)


In [None]:
# === PR Curve + ROC + Confusion Matrix ===
import numpy as np
from sklearn.metrics import (
    precision_recall_curve, confusion_matrix, ConfusionMatrixDisplay,
    f1_score, roc_curve, roc_auc_score, auc
)
import matplotlib.pyplot as plt

precision, recall, thresholds = precision_recall_curve(all_gt_flat, all_pred_prob)
pr_auc = auc(recall, precision)

# F1 at every threshold straight from the curve: precision[i]/recall[i] describe
# the predictions `score >= thresholds[i]`, and the trailing (1, 0) point has no
# threshold. This avoids one full f1_score pass over all pixels per threshold.
prec_at_t, rec_at_t = precision[:-1], recall[:-1]
pr_sum = prec_at_t + rec_at_t
f1_scores = np.divide(2 * prec_at_t * rec_at_t, pr_sum,
                      out=np.zeros_like(pr_sum, dtype=float), where=pr_sum > 0)
best_index = int(np.argmax(f1_scores))
best_threshold = thresholds[best_index]

preds_best = (all_pred_prob >= best_threshold).astype(int)
# labels=[0, 1] keeps cm 2x2 (matching the two display labels) even if a class is absent.
cm = confusion_matrix(all_gt_flat, preds_best, labels=[0, 1])

# === Plot Confusion Matrix ===
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Background", "Region"])
fig, ax = plt.subplots(figsize=(5, 5))
disp.plot(ax=ax, cmap='Blues', values_format='d')
plt.title(f"Confusion Matrix at Best Threshold ({best_threshold:.2f})")
plt.grid(False)
plt.show()

# === Plot ROC Curve ===
fpr, tpr, _ = roc_curve(all_gt_flat, all_pred_prob)
roc_auc = roc_auc_score(all_gt_flat, all_pred_prob)

plt.figure(figsize=(7, 6))
plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC Curve (AUC = {roc_auc:.4f})')
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.grid(True)
plt.show()

# === Print Summary ===
print(f"✅ AUC (PR): {pr_auc:.4f}")
print(f"✅ AUC (ROC): {roc_auc:.4f}")
print(f"✅ Best Threshold (max F1): {best_threshold:.4f}")
print(f"✅ Best Threshold (max F1): {best_threshold:.4f}")


In [None]:
import os
import numpy as np
import nibabel as nib
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import (
    precision_recall_curve, auc, f1_score, precision_score, recall_score,
    confusion_matrix, ConfusionMatrixDisplay
)

# === CONFIGURATION ===
ground_truth_dir = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default/train/T1w"
reconstructed_dir = "/content/drive/MyDrive/CALAMITI_Project/generated_reconstructions"

# === LOAD SAMPLE FILES (subset for speed) ===
gt_files = sorted([f for f in os.listdir(ground_truth_dir) if f.endswith("_axial_recon.nii.gz")])
np.random.seed(42)
# choice(..., replace=False) raises if asked for more files than exist.
gt_files = np.random.choice(gt_files, size=min(20, len(gt_files)), replace=False)

recon_images = []
gt_images = []

for gt_file in tqdm(gt_files, desc="Loading 20 images"):
    try:
        gt_path = os.path.join(ground_truth_dir, gt_file)
        base_name = gt_file.replace("_axial_recon.nii.gz", "")
        recon_name = f"{base_name}_axial_recon_axial_recon.nii.gz"
        recon_path = os.path.join(reconstructed_dir, recon_name)
        if not os.path.exists(recon_path):
            continue

        # Drop singleton axes so the shapes can be compared directly.
        gt_img = np.squeeze(nib.load(gt_path).get_fdata())
        recon_img = np.squeeze(nib.load(recon_path).get_fdata())

        if recon_img.shape[0] == 3:
            recon_img = recon_img[0]  # axial channel

        # Mismatched pairs would misalign pixels after flattening.
        if gt_img.shape != recon_img.shape:
            print(f"[WARNING] Shape mismatch for {gt_file}: {gt_img.shape} vs {recon_img.shape}, skipping")
            continue

        # Normalize
        gt_img = (gt_img - gt_img.min()) / (gt_img.max() - gt_img.min() + 1e-8)
        recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min() + 1e-8)

        recon_images.append(recon_img)
        gt_images.append(gt_img)
    except Exception as e:
        # Report instead of a bare `except:` that hides every failure.
        print(f"[ERROR] Processing {gt_file}: {e}")
        continue

if not gt_images:
    raise RuntimeError("No ground-truth/reconstruction pairs were loaded; check the directories and file names.")

# === FLATTEN & SUBSAMPLE 10% Pixels ===
all_pred_prob = np.concatenate([img.flatten() for img in recon_images])
all_gt_flat = np.concatenate([img.flatten() for img in gt_images])

indices = np.random.choice(len(all_gt_flat), size=int(0.1 * len(all_gt_flat)), replace=False)
all_pred_prob = all_pred_prob[indices]
all_gt_flat = (all_gt_flat[indices] > 0.5).astype(np.uint8)

# === PR CURVE + BEST THRESHOLD ===
precision, recall, thresholds = precision_recall_curve(all_gt_flat, all_pred_prob)
pr_auc = auc(recall, precision)

# F1 per threshold in closed form: precision[i]/recall[i] correspond to
# `score >= thresholds[i]`; the trailing (1, 0) point has no threshold.
prec_at_t, rec_at_t = precision[:-1], recall[:-1]
pr_sum = prec_at_t + rec_at_t
f1_scores = np.divide(2 * prec_at_t * rec_at_t, pr_sum,
                      out=np.zeros_like(pr_sum, dtype=float), where=pr_sum > 0)
best_index = int(np.argmax(f1_scores))
best_threshold = thresholds[best_index]

# === CONFUSION MATRIX ===
preds_best = (all_pred_prob >= best_threshold).astype(int)
# labels=[0, 1] keeps the matrix 2x2 to match the two display labels.
cm = confusion_matrix(all_gt_flat, preds_best, labels=[0, 1])

# === DISPLAY CONFUSION MATRIX ===
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Background", "Region"])
fig, ax = plt.subplots(figsize=(5, 5))
disp.plot(ax=ax, cmap='Blues', values_format='d')
plt.title(f"Confusion Matrix at Best Threshold ({best_threshold:.4f})")
plt.grid(False)
plt.show()

# === PRINT METRICS ===
f1 = f1_score(all_gt_flat, preds_best)
prec = precision_score(all_gt_flat, preds_best)
rec = recall_score(all_gt_flat, preds_best)

print(f"📊 Metrics @ Best Threshold ({best_threshold:.4f}):")
print(f"🔹 F1 Score   : {f1:.4f}")
print(f"🔹 Precision  : {prec:.4f}")
print(f"🔹 Recall     : {rec:.4f}")
print(f"🔹 PR AUC     : {pr_auc:.4f}")


In [None]:
# Imported here so the cell does not depend on which earlier cell happened to import them.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    precision_recall_curve, auc, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, roc_auc_score
)

# === FLATTEN FOR EVALUATION ===
all_pred_prob = np.concatenate([img.flatten() for img in recon_images])
all_gt_flat = np.concatenate([img.flatten() for img in gt_images])

# 🛑 Subsample
np.random.seed(42)
indices = np.random.choice(len(all_gt_flat), size=int(0.1 * len(all_gt_flat)), replace=False)

all_pred_prob = all_pred_prob[indices]
all_gt_flat = all_gt_flat[indices]

all_gt_flat = (all_gt_flat > 0.5).astype(np.uint8)

# === PR Curve + ROC + Confusion Matrix ===
precision, recall, thresholds = precision_recall_curve(all_gt_flat, all_pred_prob)
pr_auc = auc(recall, precision)

# F1 per threshold derived from the curve (precision[i]/recall[i] correspond to
# `score >= thresholds[i]`), instead of one full f1_score pass per threshold.
prec_at_t, rec_at_t = precision[:-1], recall[:-1]
pr_sum = prec_at_t + rec_at_t
f1_scores = np.divide(2 * prec_at_t * rec_at_t, pr_sum,
                      out=np.zeros_like(pr_sum, dtype=float), where=pr_sum > 0)
best_index = int(np.argmax(f1_scores))
best_threshold = thresholds[best_index]

preds_best = (all_pred_prob >= best_threshold).astype(int)
# labels=[0, 1] keeps cm 2x2 to match the two display labels.
cm = confusion_matrix(all_gt_flat, preds_best, labels=[0, 1])

# === Plot confusion matrix ===
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Background", "Region"])
fig, ax = plt.subplots(figsize=(5, 5))
disp.plot(ax=ax, cmap='Blues', values_format='d')
plt.title(f"Confusion Matrix at Best Threshold ({best_threshold:.2f})")
plt.grid(False)
plt.show()

# === Plot ROC Curve ===
fpr, tpr, _ = roc_curve(all_gt_flat, all_pred_prob)
roc_auc = roc_auc_score(all_gt_flat, all_pred_prob)

plt.figure(figsize=(7, 6))
plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC Curve (AUC = {roc_auc:.4f})')
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.grid(True)
plt.show()


In [None]:
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

labels = ["Background", "Target"]
# Row-normalize so each cell is the fraction of its true class (each row sums
# to 1), which is what the title promises. Rows with no samples stay at 0.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)
df_cm = pd.DataFrame(cm_normalized, index=labels, columns=labels)

sns.heatmap(df_cm, annot=True, fmt='.3f', cmap="Blues")
plt.title(f"Normalized Confusion Matrix at Threshold {best_threshold:.2f}")
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()


In [None]:
import os
import numpy as np
import nibabel as nib
from tqdm import tqdm

ground_truth_dir = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default/train/T1w"
reconstructed_dir = "/content/drive/MyDrive/CALAMITI_Project/generated_reconstructions"

gt_files = sorted([f for f in os.listdir(ground_truth_dir) if f.endswith("_axial_recon.nii.gz")])

recon_images = []
gt_images = []

for gt_file in tqdm(gt_files):
    try:
        gt_path = os.path.join(ground_truth_dir, gt_file)
        base_name = gt_file.replace("_axial_recon.nii.gz", "")
        recon_name = f"{base_name}_axial_recon_axial_recon.nii.gz"
        recon_path = os.path.join(reconstructed_dir, recon_name)

        if not os.path.exists(recon_path):
            continue

        # Drop singleton axes so the shapes can be compared directly.
        gt_img = np.squeeze(nib.load(gt_path).get_fdata())
        recon_img = np.squeeze(nib.load(recon_path).get_fdata())

        if recon_img.shape[0] == 3:
            recon_img = recon_img[0]  # 👈 Take axial only

        # Mismatched pairs would misalign pixels after flattening.
        if gt_img.shape != recon_img.shape:
            print(f"[WARNING] Shape mismatch for {gt_file}: {gt_img.shape} vs {recon_img.shape}, skipping")
            continue

        gt_img = (gt_img - gt_img.min()) / (gt_img.max() - gt_img.min() + 1e-8)
        recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min() + 1e-8)

        recon_images.append(recon_img)
        gt_images.append(gt_img)

    except Exception as e:
        # Report instead of silently discarding the error.
        print(f"[ERROR] Processing {gt_file}: {e}")
        continue

# Now for precision-recall curve:
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

all_pred_prob = np.concatenate([img.flatten() for img in recon_images])
# precision_recall_curve rejects continuous targets, so binarize the
# normalized ground truth at 0.5, as the other evaluation cells do.
all_gt_flat = (np.concatenate([img.flatten() for img in gt_images]) > 0.5).astype(np.uint8)

precision, recall, thresholds = precision_recall_curve(all_gt_flat, all_pred_prob)

plt.plot(recall, precision, marker='.')
plt.title('Precision-Recall Curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.grid()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, ConfusionMatrixDisplay

# all_gt / all_pred are filled by the confusion-data loops. They stay lists of
# per-image arrays until concatenated, so flatten them here when needed.
y_true = np.concatenate(all_gt) if isinstance(all_gt, list) else np.asarray(all_gt)
y_pred = np.concatenate(all_pred) if isinstance(all_pred, list) else np.asarray(all_pred)

# === METRICS ===
# The *_val names avoid overwriting the `precision` / `recall` curve arrays
# that the PR/AUC cells read.
accuracy = accuracy_score(y_true, y_pred)
precision_val = precision_score(y_true, y_pred)
recall_val = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)

print("\n📊 Evaluation Metrics:")
print(f"Accuracy : {accuracy:.4f}")
print(f"Precision: {precision_val:.4f}")
print(f"Recall   : {recall_val:.4f}")
print(f"F1 Score : {f1:.4f}")

# === VISUALIZATION ===
# Recompute from the same arrays so the heatmap agrees with the metrics above;
# labels=[0, 1] keeps it 2x2 to match the tick labels.
conf_matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
plt.figure(figsize=(6, 5))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap="Blues", cbar=False,
            xticklabels=["Pred 0", "Pred 1"], yticklabels=["Actual 0", "Actual 1"])
plt.title("Confusion Matrix Heatmap")
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from scipy.ndimage import zoom

# === CONFIG ===
gt_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default/train/T1w/SAMPLE_T1w_SUB001_AXIAL_SLICE010_axial_recon.nii.gz"
recon_path = "/content/drive/MyDrive/CALAMITI_Project/generated_reconstructions/SAMPLE_T1w_SUB001_AXIAL_SLICE010_axial_recon_axial_recon.nii.gz"

# === LOAD ===
gt_img = nib.load(gt_path).get_fdata()
recon_img = nib.load(recon_path).get_fdata()

# === Reduce a 3-channel reconstruction to channel 0 ===
# Channel 0 is the axial one (the generator input is stacked axial, sagittal,
# coronal). The metric cells evaluate this same channel, so the difference map
# shows the error they measure rather than a mean over orientations.
if recon_img.shape[0] == 3:
    recon_img = recon_img[0]

# === Normalize both to [0, 1] ===
gt_img = (gt_img - gt_img.min()) / (gt_img.max() - gt_img.min() + 1e-8)
recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min() + 1e-8)

# === Squeeze if needed ===
gt_img = np.squeeze(gt_img)
recon_img = np.squeeze(recon_img)

# === Resize if shape mismatch (linear interpolation, any matching rank) ===
if gt_img.shape != recon_img.shape:
    zoom_factors = tuple(g / r for g, r in zip(gt_img.shape, recon_img.shape))
    recon_img = zoom(recon_img, zoom_factors, order=1)

# === PLOT ===
panels = [
    (gt_img, "Ground Truth", 'gray'),
    (recon_img, "Reconstructed (CycleGAN)", 'gray'),
    (np.abs(gt_img - recon_img), "Difference Map", 'hot'),
]
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, (img, title, cmap) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')

fig.suptitle("CycleGAN Output Visualization", fontsize=14)
fig.tight_layout()
plt.show()


In [None]:
# NOTE(review): albumentations is not imported by any cell in this part of the notebook.
# Consider dropping this install, or pin a version (e.g. `%pip install -q albumentations==X.Y.Z`).
!pip install albumentations


In [None]:
import cv2

# Grow the binary reconstruction mask by one pixel (3x3 all-ones kernel, one
# dilation pass), then flatten it again so recon_bin keeps its 1-D form.
# NOTE(review): recon_bin is left over from the last iteration of an earlier
# loop, while gt_img is whichever image was loaded most recently. Confirm they
# are the same slice; otherwise the reshape fails or mixes images.
dilation_kernel = np.ones((3, 3), np.uint8)
recon_mask_2d = recon_bin.reshape(gt_img.shape).astype(np.uint8)
recon_bin = cv2.dilate(recon_mask_2d, dilation_kernel, iterations=1).flatten()


In [None]:
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

# Convert predictions to float scores (before thresholding)
all_pred_prob = np.concatenate([recon.flatten() for recon in recon_images])
# precision_recall_curve rejects continuous targets, so binarize the
# normalized ground truth at 0.5, as the other evaluation cells do.
all_gt_flat = (np.concatenate([gt.flatten() for gt in gt_images]) > 0.5).astype(np.uint8)

precision, recall, thresholds = precision_recall_curve(all_gt_flat, all_pred_prob)

plt.plot(recall, precision, marker='.')
plt.title("Precision-Recall Curve")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.grid()
plt.show()


In [None]:
from multiorientation import MultiOrientationImages, custom_collate
from torch.utils.data import DataLoader

# Build the test split (mode="test") of the sliced T1w sample dataset and an
# ordered, batch-size-1 loader that batches with the project's custom_collate.
test_dataset = MultiOrientationImages(
    mode="test",
    data_name="T1w",
    dataset_dir="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default",
)

test_loader = DataLoader(test_dataset, collate_fn=custom_collate, shuffle=False, batch_size=1)


In [None]:
from multiorientation import MultiOrientationImages, custom_collate
from torch.utils.data import DataLoader


In [None]:
# Test split of the T1w sample slices, iterated in order one sample at a time.
test_dataset = MultiOrientationImages(
    data_name="T1w",
    mode="test",
    dataset_dir="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default",
)

test_loader = DataLoader(
    test_dataset,
    shuffle=False,
    collate_fn=custom_collate,
    batch_size=1,
)


In [None]:
# Run the model's test routine on the loader the model object itself exposes
# (model.test_loader). Results presumably go under ./cyclegan_test_output.
model.test(model.test_loader, output_dir="./cyclegan_test_output")


In [None]:
# Same test routine, but fed the standalone test_loader built from
# MultiOrientationImages(mode="test") with custom_collate.
model.test(test_loader, output_dir="./cyclegan_test_output")


In [None]:
# Alias `model` under a more descriptive name. This is the same object, not a
# copy, so the call below is equivalent to model.test(...).
cyclegan_model = model
cyclegan_model.test(test_loader, output_dir="./cyclegan_test_output")


In [None]:
# Re-run the test routine with the same loader and output directory.
cyclegan_model.test(test_loader, output_dir="./cyclegan_test_output")


In [None]:
import importlib.util
import sys

# Load (or reload) cyclegan_fusion.py from its absolute path and register it
# under its module name, so later imports return this fresh copy.
cyclegan_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/cyclegan_fusion.py"
spec = importlib.util.spec_from_file_location("cyclegan_fusion", cyclegan_path)

cyclegan_fusion = importlib.util.module_from_spec(spec)
sys.modules["cyclegan_fusion"] = cyclegan_fusion
spec.loader.exec_module(cyclegan_fusion)


In [None]:
import importlib.util
import sys

# Load multiorientation.py directly from its file and register it in
# sys.modules, then bind the two names the following cells use.
path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/multiorientation.py"
spec = importlib.util.spec_from_file_location("multiorientation", path)
multiorientation = importlib.util.module_from_spec(spec)
sys.modules["multiorientation"] = multiorientation
spec.loader.exec_module(multiorientation)

MultiOrientationImages = multiorientation.MultiOrientationImages
custom_collate = multiorientation.custom_collate


In [None]:
# Rebuild the model from the freshly loaded cyclegan_fusion module (class
# defaults for any unlisted arguments) and train it for 10 epochs.
model = cyclegan_fusion.CycleGANFusionNetwork(
    dataset_dir="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default",
    gpu=0,
    modality="T1w",
    lr=0.0002,
    batch_size=1,
)

model.train(epochs=10)


In [None]:
# Construct the model without training it (e.g. to inspect or load weights).
model = cyclegan_fusion.CycleGANFusionNetwork(
    batch_size=1,
    modality="T1w",
    gpu=0,
    dataset_dir="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default",
    lr=0.0002,
)


**Visualize the Output**

In [None]:
import sys

# Directory containing cyclegan_fusion.py; append it once so re-runs don't grow sys.path.
modules_dir = '/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules'
if modules_dir not in sys.path:
    sys.path.append(modules_dir)
from cyclegan_fusion import CycleGANFusionNetwork  # module name = file name without .py


In [None]:
# Modality sub-folder and sliced sample-dataset root used to build the model below.
modality = "T1w"
dataset_dir = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default"

In [None]:
# Construct the model with CycleGANFusionNetwork's own defaults for every other
# argument (batch size, lr, gpu, ...).
# NOTE(review): presumably freshly initialized weights — load a checkpoint
# before reading the visualizations as results.
model = CycleGANFusionNetwork(dataset_dir=dataset_dir, modality=modality)


In [None]:
import matplotlib.pyplot as plt
import torch

# Show the fused generator output next to the ground truth for the first
# n_samples test batches, one row per sample.
n_samples = 5
test_loader = model.test_loader
plt.figure(figsize=(10, 2 * n_samples))

# Run G in inference mode (changes layers such as BatchNorm/Dropout, if G has
# any), then restore its previous mode so later training is unaffected.
was_training = model.G.training
model.G.eval()
try:
    for i, batch in enumerate(test_loader):
        if i >= n_samples:
            break

        axial = batch['axial'].to(model.device)
        sagittal = batch['sagittal'].to(model.device)
        coronal = batch['coronal'].to(model.device)
        target = batch['target'].to(model.device)

        input_tensor = torch.cat([axial, sagittal, coronal], dim=1)
        with torch.no_grad():
            fused = model.G(input_tensor)

        plt.subplot(n_samples, 2, 2 * i + 1)
        plt.imshow(fused[0, 0].cpu().numpy(), cmap='gray')
        plt.title(f"Fused Output #{i+1}")
        plt.axis("off")

        plt.subplot(n_samples, 2, 2 * i + 2)
        plt.imshow(target[0, 0].cpu().numpy(), cmap='gray')
        plt.title(f"Ground Truth #{i+1}")
        plt.axis("off")
finally:
    model.G.train(was_training)

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt

# With Jupyter's default inline backend, the previous cell's figure is already
# closed after it is shown. A savefig here would then write a blank image, so
# only save while a figure is still open. Otherwise call plt.savefig(...)
# before plt.show() in the plotting cell.
if plt.get_fignums():
    plt.savefig("fusion_vs_target.png", dpi=300, bbox_inches='tight')
else:
    print("No open figure to save; call plt.savefig(...) before plt.show() in the plotting cell.")


In [None]:
import matplotlib.pyplot as plt
import torch

# Step 1: Get a batch from test_loader
val_batch = next(iter(model.test_loader))

# Step 2: Send input slices to device
axial = val_batch['axial'].to(model.device)
sagittal = val_batch['sagittal'].to(model.device)
coronal = val_batch['coronal'].to(model.device)
target = val_batch['target'].to(model.device)

# Step 3: Forward pass through the generator in inference mode (affects
# BatchNorm/Dropout, if G has any); the previous mode is restored afterwards.
input_tensor = torch.cat([axial, sagittal, coronal], dim=1)
was_training = model.G.training
model.G.eval()
try:
    with torch.no_grad():
        fused = model.G(input_tensor)
finally:
    model.G.train(was_training)

# Step 4: Convert the first sample's first channel of each tensor to numpy
axial_np = axial[0, 0].cpu().numpy()
sagittal_np = sagittal[0, 0].cpu().numpy()
coronal_np = coronal[0, 0].cpu().numpy()
fused_np = fused[0, 0].cpu().numpy()
target_np = target[0, 0].cpu().numpy()

# Step 5: Plot all views side by side
panels = [
    (axial_np, "Axial Input"),
    (sagittal_np, "Sagittal Input"),
    (coronal_np, "Coronal Input"),
    (fused_np, "Fused Output"),
    (target_np, "Ground Truth"),
]
fig, axes = plt.subplots(1, len(panels), figsize=(15, 4))
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
    ax.axis("off")

fig.tight_layout()
plt.show()


In [None]:
# Inputs plus fused output for the batch computed in the single-batch cell above
# (needs axial, sagittal, coronal and fused from that cell). fused_img is derived
# here; on a fresh kernel it would otherwise not exist yet, because the only
# other cell that defines it comes later.
fused_img = fused[0, 0].cpu().numpy()

panels = [
    (axial[0, 0].cpu().numpy(), "Axial Input"),
    (sagittal[0, 0].cpu().numpy(), "Sagittal Input"),
    (coronal[0, 0].cpu().numpy(), "Coronal Input"),
    (fused_img, "Fused Output"),
]
fig, axes = plt.subplots(1, len(panels), figsize=(15, 3))
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
    ax.axis("off")

fig.tight_layout()
# Save before plt.show(): the inline backend closes the figure once it is shown.
fig.savefig("multi_input_fusion.png", dpi=300, bbox_inches='tight')
plt.show()


In [None]:
import matplotlib.pyplot as plt
import torch

# Get one test batch
val_batch = next(iter(model.test_loader))

axial = val_batch['axial'].to(model.device)
sagittal = val_batch['sagittal'].to(model.device)
coronal = val_batch['coronal'].to(model.device)
target = val_batch['target'].to(model.device)

input_tensor = torch.cat([axial, sagittal, coronal], dim=1)

# Run inference with G in eval mode, restoring its previous mode afterwards.
was_training = model.G.training
model.G.eval()
try:
    with torch.no_grad():
        fused = model.G(input_tensor)
finally:
    model.G.train(was_training)

# Convert to NumPy (first sample, first channel)
fused_img = fused[0, 0].cpu().numpy()
target_img = target[0, 0].cpu().numpy()

# Plot side by side
fig, (ax_fused, ax_target) = plt.subplots(1, 2, figsize=(10, 5))

ax_fused.imshow(fused_img, cmap='gray')
ax_fused.set_title("CycleGAN Fused Output")
ax_fused.axis("off")

ax_target.imshow(target_img, cmap='gray')
ax_target.set_title("Ground Truth")
ax_target.axis("off")

fig.tight_layout()
plt.show()


In [None]:
import torch

# Load the fusion generator weights written by the save-generators cell
# (output/models/generator_G.pth); '/path/to/fused_model.pth' was a placeholder.
# map_location puts the tensors on the model's own device.
generator_path = "/content/drive/MyDrive/CALAMITI_Project/output/models/generator_G.pth"
model.G.load_state_dict(torch.load(generator_path, map_location=model.device))
model.G.eval()


In [None]:
# -------------------- 1. Import Required Modules --------------------
from cyclegan_fusion import CycleGANFusionNetwork
import matplotlib.pyplot as plt
import torch

# -------------------- 2. Set Dataset Path and Modality --------------------
# Root of the sliced sample dataset (the same root every other cell uses),
# not the cyclegan_fusion.py source file.
dataset_dir = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default"
modality = "T1w"  # modality folder name (T1w, T2w, FLAIR, etc.)

# -------------------- 3. Initialize the CycleGAN Model --------------------
model = CycleGANFusionNetwork(dataset_dir=dataset_dir, modality=modality)

# -------------------- 4. Get a Batch for Visualization --------------------
val_batch = next(iter(model.test_loader))

# -------------------- 5. Prepare Tensors and Run Model --------------------
axial = val_batch['axial'].to(model.device)
sagittal = val_batch['sagittal'].to(model.device)
coronal = val_batch['coronal'].to(model.device)
target = val_batch['target'].to(model.device)

input_tensor = torch.cat([axial, sagittal, coronal], dim=1)

# Inference in eval mode; restore the generator's previous mode afterwards.
was_training = model.G.training
model.G.eval()
try:
    with torch.no_grad():
        fused = model.G(input_tensor)
finally:
    model.G.train(was_training)

# -------------------- 6. Plot Results --------------------
fig, (ax_fused, ax_target) = plt.subplots(1, 2, figsize=(12, 4))

ax_fused.imshow(fused[0, 0].cpu().numpy(), cmap='gray')
ax_fused.set_title("CycleGAN Fused Output")
ax_fused.axis("off")

ax_target.imshow(target[0, 0].cpu().numpy(), cmap='gray')
ax_target.set_title("Ground Truth")
ax_target.axis("off")

fig.tight_layout()
plt.show()


In [None]:
# Show the last lines of fusion.py with their real line numbers.
# Fix: print with rstrip() instead of strip() so Python indentation stays visible,
# and number lines by their position in the file rather than 1..10.
FUSION_PY_PATH = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"
TAIL_LINES = 10

with open(FUSION_PY_PATH, "r") as file:
    lines = file.readlines()

first_lineno = max(len(lines) - TAIL_LINES, 0) + 1
for lineno, line in enumerate(lines[-TAIL_LINES:], start=first_lineno):
    print(f"{lineno}: {line.rstrip()}")

# To patch the file: edit `lines`, then write it back with file.writelines(lines) in "w" mode.


In [None]:
# Source of a run_training() helper to add to fusion.py. It calls FusionNetwork and
# torch, which are presumably defined/imported inside fusion.py — confirm.
code = """
def run_training():
    class Args:
        dataset_dir = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default"
        modality = "T1w"
        output_dir = "./output"
        batch_size = 1
        epochs = 20
        lr = 0.0001
        gpu = 0 if torch.cuda.is_available() else -1
        checkpoint = None

    args = Args()

    network = FusionNetwork(
        pretrained_model=args.checkpoint,
        gpu=args.gpu,
        data_path=args.dataset_dir,
        data_name=args.modality,
        batch_size=args.batch_size
    )

    network.load_dataset(args.dataset_dir, args.modality, args.batch_size)
    network.initialize_training(args.output_dir, lr=args.lr)
    print(f"🚀 Starting training for {args.epochs} epochs...")
    network.train(epochs=args.epochs)
    print(f"✅ Training complete. Model saved to {args.output_dir}/models/")
    return network
"""

FUSION_PY_PATH = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"

# Fix: make the cell idempotent. Appending unconditionally added another copy of
# run_training() to fusion.py on every re-run.
with open(FUSION_PY_PATH, "r") as f:
    existing_source = f.read()

if code in existing_source:
    print("run_training() block already present in fusion.py; not appending again")
else:
    with open(FUSION_PY_PATH, "a") as f:
        f.write("\n" + code)


In [None]:
import importlib.util
import sys

# Import fusion.py straight from its file path under the module name "fusion_sr",
# register it in sys.modules before executing it, then start training.
script_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"
spec = importlib.util.spec_from_file_location(name="fusion_sr", location=script_path)
sys.modules["fusion_sr"] = fusion_sr = importlib.util.module_from_spec(spec)
spec.loader.exec_module(fusion_sr)

# ✅ Train using the defaults defined in fusion.py's run_training()
network = fusion_sr.run_training()


In [None]:
import matplotlib.pyplot as plt
import torch


In [None]:
# Take the first validation batch and move the three views plus the target onto the
# network's device; the views are then concatenated along dim 1 as model input.
val_batch = next(iter(network.valid_loader))

axial, sagittal, coronal, target = (
    val_batch[view].to(network.device)
    for view in ('axial', 'sagittal', 'coronal', 'target')
)

input_tensor = torch.cat([axial, sagittal, coronal], dim=1)


In [None]:
# Put the fusion net in inference mode (.eval() is train(False)) and run a
# forward pass without autograd bookkeeping.
network.fusion_net.train(False)
with torch.no_grad():
    output = network.fusion_net(input_tensor)


In [None]:
# Side-by-side: fused prediction vs. ground truth (item 0, channel 0 of each batch).
plt.figure(figsize=(12, 4))

for pos, (batch_tensor, caption) in enumerate(
    [(output, "Fused Output"), (target, "Ground Truth")], start=1
):
    plt.subplot(1, 2, pos)
    plt.imshow(batch_tensor[0, 0].cpu(), cmap='gray')
    plt.title(caption)
    plt.axis("off")

plt.tight_layout()
plt.show()


In [None]:
import sys

# Make the project's modules importable. Guarded so re-running the cell does not
# append another duplicate sys.path entry each time.
MODULES_DIR = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules"
if MODULES_DIR not in sys.path:
    sys.path.append(MODULES_DIR)


In [None]:
# Test split of the T1w slice dataset. MultiOrientationImages must already be
# imported by an earlier cell.
test_dataset = MultiOrientationImages(
    **{
        "dataset_dir": "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default",
        "data_name": "T1w",
        "mode": "test",
    }
)


In [None]:
import os

test_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default/test/T1w"
count = 0

# Each base axial slice (.nii.gz, not itself a *_recon file) should have one
# reconstruction per orientation next to it; list the bases that are incomplete.
files = sorted(os.listdir(test_path))
axial_files = [
    name for name in files
    if "AXIAL_SLICE" in name and name.endswith(".nii.gz") and "_recon" not in name
]

print(f"\n🔍 Found {len(axial_files)} axial slice base files.\n")

for f in axial_files:
    base = f.replace(".nii.gz", "")
    recon_files = [f"{base}_{plane}_recon.nii.gz" for plane in ("axial", "sagittal", "coronal")]
    missing = [r for r in recon_files if not os.path.exists(os.path.join(test_path, r))]
    if missing:
        print(f"❌ Missing for {base}: {missing}")
    else:
        count += 1

print(f"\n✅ Total valid samples: {count}")


In [None]:
# NOTE(review): this rebuilds the same T1w test dataset as an earlier cell;
# keep only one of the two.
t1w_slices_root = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default"
test_dataset = MultiOrientationImages(dataset_dir=t1w_slices_root, data_name="T1w", mode="test")


In [None]:
import importlib.util
import sys

# Load fusion.py from its file path as module "fusion" and register it in sys.modules
# before executing it. `import sys` added: the cell uses sys.modules but previously
# relied on sys having been imported by another cell.
script_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"
spec = importlib.util.spec_from_file_location("fusion", script_path)
fusion = importlib.util.module_from_spec(spec)
sys.modules["fusion"] = fusion
spec.loader.exec_module(fusion)


In [None]:
import importlib.util
import sys

# 🔁 Re-import fusion.py from disk (picks up edits) under the name "fusion"
script_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"
spec = importlib.util.spec_from_file_location(name="fusion", location=script_path)
sys.modules["fusion"] = fusion = importlib.util.module_from_spec(spec)
spec.loader.exec_module(fusion)

# ✅ Evaluate the epoch-20 checkpoint on the T1w test split (swap in a best-model path if preferred)
test_kwargs = dict(
    checkpoint_path="./output/models/fusion_epoch020.pth",
    modality="T1w",
    dataset_dir="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default",
    output_dir="./output_test",
)
fusion.run_testing(**test_kwargs)


In [None]:
# Train a separate T2w fusion network on the same slice root (from scratch unless a
# checkpoint path is given as pretrained_model).
t2_data_root = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default"
t2_gpu = 0 if torch.cuda.is_available() else -1

network_t2 = fusion_sr.FusionNetwork(
    pretrained_model=None,  # or a checkpoint path to fine-tune
    gpu=t2_gpu,
    data_path=t2_data_root,
    data_name="T2w",
    batch_size=1,
)
network_t2.load_dataset(dataset_dir=t2_data_root, data_name="T2w", batch_size=1)
network_t2.initialize_training(out_dir="./output_T2w", lr=0.0001)

# 🚀 Start training
network_t2.train(epochs=20)


In [None]:
!code /content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py


In [None]:
!pip install scikit-image


In [None]:
import importlib.util

# Load multiorientation.py directly from its path as the module object `multi`
# (this cell does not add it to sys.modules).
spec = importlib.util.spec_from_file_location(
    name="multiorientation",
    location="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/multiorientation.py",
)
multi = importlib.util.module_from_spec(spec)
loader = spec.loader
loader.exec_module(multi)


In [None]:
from torch.utils.data import DataLoader

# T1w test split wrapped in a sequential, batch-size-1 loader, using the dataset class
# and collate function from the `multi` module loaded in the previous cell.
# Fix: import DataLoader here. In this part of the notebook it is otherwise first
# imported in a later cell, so Restart & Run All hit a NameError.
test_dataset = multi.MultiOrientationImages(dataset_dir="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default", data_name="T1w", mode="test")
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=multi.custom_collate)


In [None]:
# Step 1: (Re)load fusion.py from disk as module "fusion"
import importlib.util
import sys

module_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"
spec = importlib.util.spec_from_file_location(name="fusion", location=module_path)
sys.modules["fusion"] = fusion = importlib.util.module_from_spec(spec)
spec.loader.exec_module(fusion)

# Step 2: Build the network from the epoch-20 T1w checkpoint (GPU 0 if available, else CPU)
eval_data_root = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default"
network = fusion.FusionNetwork(
    pretrained_model="./output/models/fusion_epoch020.pth",
    gpu=0 if torch.cuda.is_available() else -1,
    data_path=eval_data_root,
    data_name="T1w",
    batch_size=1,
)

# Step 3: Evaluate on the test split
network.run_testing(dataset_dir=eval_data_root, modality="T1w", output_dir="./test_results")


In [None]:
import torch
from torch.utils.data import DataLoader

from multiorientation import MultiOrientationImages, custom_collate

# Evaluate the epoch-20 T1w checkpoint over the whole test split and compute metrics.
# `fusion` is the module loaded from fusion.py by an earlier cell.
# NOTE: this rebinds `model`, which earlier cells used for the CycleGAN network.

# Load test dataset
test_dataset = MultiOrientationImages(
    dataset_dir="/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/sample_dataset/slices/default",
    data_name="T1w", mode="test"
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=custom_collate)

# Load model.
# Fix: fall back to CPU (-1) when no GPU is present, matching the other FusionNetwork
# cells. A hard-coded gpu=0 presumably fails on a CPU-only runtime.
model = fusion.FusionNetwork(
    pretrained_model="./output/models/fusion_epoch020.pth",
    gpu=0 if torch.cuda.is_available() else -1
)

# Run test and compute metrics
model.test_on_loader(test_loader)


In [None]:
import importlib.util
import sys

# NOTE(review): this repeats an earlier "load fusion_sr + run_training()" cell.
# Running it again retrains from scratch and presumably overwrites the checkpoints
# in run_training()'s output directory; confirm this is intended.
script_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"
spec = importlib.util.spec_from_file_location(name="fusion_sr", location=script_path)
fusion_sr = importlib.util.module_from_spec(spec)
sys.modules.update({"fusion_sr": fusion_sr})
spec.loader.exec_module(fusion_sr)

# 🚀 Now run training
network = fusion_sr.run_training()


In [None]:
import importlib.util
import sys

# Reload fusion.py under the module name "fusion_sr" without starting training;
# the module is registered in sys.modules before its code runs.
script_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"
spec = importlib.util.spec_from_file_location("fusion_sr", location=script_path)
fusion_sr = importlib.util.module_from_spec(spec)
sys.modules["fusion_sr"] = fusion_sr
loader = spec.loader
loader.exec_module(fusion_sr)


In [None]:
import sys
import importlib.util

# ✅ Step 1: Make the module directory importable (only once; re-running the cell
# would otherwise append duplicate sys.path entries).
modules_dir = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules"
if modules_dir not in sys.path:
    sys.path.append(modules_dir)

# ✅ Step 2: Load fusion.py as module "fusion".
# Fix: the freshly created module used to be bound to `fusion_sr`, while
# sys.modules and exec_module used whatever `fusion` already was. That is a
# NameError on a fresh kernel, and `fusion_sr` was left pointing at an empty,
# never-executed module.
script_path = "/content/drive/MyDrive/CALAMITI_Project/neuroimage_2021_calamiti/code/modules/fusion.py"
spec = importlib.util.spec_from_file_location("fusion", script_path)
fusion = importlib.util.module_from_spec(spec)
sys.modules["fusion"] = fusion
spec.loader.exec_module(fusion)

# ✅ Step 3: Now run training (optional)
network = fusion.run_training()


Cleaning the notebook for GitHub (strip outputs before committing)

In [None]:
pip install nbstripout


In [None]:
!nbstripout CALAMITI_NNK.ipynb -o CALAMITI_NNK_clean.ipynb