In [2]:
import os
import re
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
from datetime import datetime

# === Dataset to load .npy volumes ===
class BraTSDataset(Dataset):
    def __init__(self, npy_dir):
        self.npy_dir = npy_dir
        self.npy_files = [f for f in os.listdir(npy_dir) if f.endswith('.npy')]
        self.npy_files.sort()
    
    def __len__(self):
        return len(self.npy_files)
    
    def __getitem__(self, idx):
        npy_path = os.path.join(self.npy_dir, self.npy_files[idx])
        volume = np.load(npy_path).astype(np.float32)

        # If volume has 4 dims (e.g., 128x128x128x3), convert to grayscale
        if volume.ndim == 4 and volume.shape[-1] == 3:
            volume = np.mean(volume, axis=-1)

        # Normalize and add channel dimension
        volume = (volume - volume.min()) / (volume.max() - volume.min() + 1e-8)
        volume = torch.tensor(volume).unsqueeze(0)  # shape: (1, D, H, W)
        return volume, self.npy_files[idx]

# === Generator Network ===
class Generator3D(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim
        self.fc = nn.Linear(latent_dim, 64 * 16 * 16 * 16)
        self.gen = nn.Sequential(
            nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(True),
            nn.ConvTranspose3d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(True),
            nn.ConvTranspose3d(16, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, z):
        x = self.fc(z)
        x = x.view(-1, 64, 16, 16, 16)
        return self.gen(x)

# === Training Loop ===
def train_gon(generator, dataloader, epochs=10, lr=1e-4, device='cuda'):
    optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
    criterion = nn.MSELoss()
    generator.to(device)

    for epoch in range(epochs):
        generator.train()
        total_loss = 0
        for real_volumes, _ in dataloader:
            real_volumes = real_volumes.to(device)
            z = torch.randn(real_volumes.size(0), generator.latent_dim).to(device)
            fake_volumes = generator(z)
            loss = criterion(fake_volumes, real_volumes)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * real_volumes.size(0)

        avg_loss = total_loss / len(dataloader.dataset)
        print(f" Epoch [{epoch+1}/{epochs}] | Loss: {avg_loss:.6f}")

# === Save Generated .nii Volume ===
def save_generated_nii(volume_tensor, filename, npy_filename, nii_root_dir):
    volume_np = volume_tensor.squeeze().cpu().detach().numpy()

    match = re.search(r'image_(\d+)\.npy', npy_filename)
    if not match:
        print(f" Could not extract index from filename: {npy_filename}")
        return
    
    index = match.group(1)
    subject_folder = f"BraTS20_Training_{index}"
    subject_path = os.path.join(nii_root_dir, subject_folder)

    nii_file = None
    for modality in ['_t1.nii.gz', '_t1.nii', '_flair.nii.gz', '_flair.nii']:
        possible_path = os.path.join(subject_path, subject_folder + modality)
        if os.path.exists(possible_path):
            nii_file = possible_path
            break

    if nii_file:
        ref_nii = nib.load(nii_file)
        new_nii = nib.Nifti1Image(volume_np, affine=ref_nii.affine, header=ref_nii.header)
        print(f" Using affine from: {nii_file}")
    else:
        print(f" No reference NIfTI found in {subject_path}. Using identity affine.")
        new_nii = nib.Nifti1Image(volume_np, affine=np.eye(4))

    nib.save(new_nii, filename)
    print(f" Saved generated volume as: {filename}")

# === Main ===
if __name__ == "__main__":
    npy_dir = r"C:\Users\prajw\Downloads\MP_Dataset\BraTS2020_TrainingData\input_data_128\train\images"
    nii_root_dir = r"C:\Users\prajw\Downloads\MP_Dataset\BraTS2020_TrainingData\MICCAI_BraTS2020_TrainingData"

    dataset = BraTSDataset(npy_dir)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f" Using device: {device}")
    generator = Generator3D(latent_dim=128)

    #  Train
    train_gon(generator, dataloader, epochs=10, lr=1e-4, device=device)

    #  Generate & Save
    generator.eval()
    z = torch.randn(1, 128).to(device)
    fake_volume = generator(z)

    _, npy_filename = dataset[0]
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_nii_path = f"generated_brain_{timestamp}.nii"
    save_generated_nii(fake_volume, output_nii_path, npy_filename, nii_root_dir)

🖥️ Using device: cpu
📘 Epoch [1/10] | Loss: 0.095786
📘 Epoch [2/10] | Loss: 0.071323
📘 Epoch [3/10] | Loss: 0.052491
📘 Epoch [4/10] | Loss: 0.034203
📘 Epoch [5/10] | Loss: 0.023184
📘 Epoch [6/10] | Loss: 0.018922
📘 Epoch [7/10] | Loss: 0.016861
📘 Epoch [8/10] | Loss: 0.015597
📘 Epoch [9/10] | Loss: 0.014848
📘 Epoch [10/10] | Loss: 0.014398
⚠️ No reference NIfTI found in C:\Users\prajw\Downloads\MP_Dataset\BraTS2020_TrainingData\MICCAI_BraTS2020_TrainingData\BraTS20_Training_0. Using identity affine.
✅ Saved generated volume as: generated_brain_20250608_152113.nii


In [4]:
import napari
import nibabel as nib
import numpy as np

# Path to your generated .nii file
nii_path = r"generated_brain_20250608_152113.nii"

# Load the volume
nii_img = nib.load(nii_path)
volume = nii_img.get_fdata()

# Optional: normalize for viewing
volume = (volume - np.min(volume)) / (np.max(volume) - np.min(volume) + 1e-8)

# Launch napari viewer
viewer = napari.Viewer()
viewer.add_image(volume, name='Generated Brain', colormap='gray')
napari.run()