<a href="https://colab.research.google.com/github/melon771/IMG_TO_PXL/blob/main/medical_SOC_capstone_Proj.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import kagglehub

# Download latest version
# kagglehub.dataset_download is asked for the "awsaf49/brats2020-training-data"
# dataset; whatever it returns is kept in `path` so it can be inspected.
path = kagglehub.dataset_download("awsaf49/brats2020-training-data")

# Print the returned location for reference.
print("Path to dataset files:", path)

In [None]:
import torch
import torch.nn as nn

# Generator
class Generator(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    Args:
        nz: Number of channels of the latent input.
        ngf: Base feature-map width; hidden stages use 8x, 4x, 2x and 1x of it.
        nc: Number of channels of the generated image.

    The first stage uses a stride-1, unpadded 4x4 transposed convolution and
    each of the four later stages uses stride 2 with padding 1, so a
    ``(N, nz, 1, 1)`` input yields a ``(N, nc, 64, 64)`` output. The final
    ``Tanh`` bounds the output to ``[-1, 1]``.
    """

    def __init__(self, nz=100, ngf=64, nc=1):
        super().__init__()

        # Channel widths entering/leaving each ConvTranspose -> BN -> ReLU stage.
        widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # Only the very first stage projects the latent vector (stride 1,
            # no padding); every later stage doubles the spatial size.
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers.append(
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))

        # Output stage: map to image channels, no batch norm.
        layers.append(nn.ConvTranspose2d(widths[-1], nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent tensor through the layer stack and return the image batch."""
        generated = self.main(input)
        return generated

# Discriminator
class Discriminator(nn.Module):
    """DCGAN-style discriminator built from strided convolutions.

    Args:
        nc: Number of channels of the input image.
        ndf: Base feature-map width; hidden stages use 1x, 2x, 4x and 8x of it.

    Four stride-2 convolutions halve the spatial size each time and the last
    unpadded 4x4 convolution collapses a 4x4 map to 1x1, so a
    ``(N, nc, 64, 64)`` input yields a ``(N, 1, 1, 1)`` output. The final
    ``Sigmoid`` keeps the score in ``[0, 1]``.
    """

    def __init__(self, nc=1, ndf=64):
        super().__init__()

        widths = [ndf * factor for factor in (1, 2, 4, 8)]

        # Input stage: no batch norm on the first convolution.
        layers = [
            nn.Conv2d(nc, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Hidden stages: Conv -> BatchNorm -> LeakyReLU, doubling the width.
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # Output stage: single-channel score squashed by a sigmoid.
        layers.extend([
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the image batch through the layer stack and return the scores."""
        scores = self.main(input)
        return scores

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
import os
import numpy as np
from PIL import Image
import nibabel as nib
import warnings
import zipfile
import tempfile
import shutil
warnings.filterwarnings('ignore')

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Hyperparameters
batch_size = 64    # batch size handed to the DataLoader
nz = 100           # channel count of the latent noise fed to the generator
epochs = 50        # number of passes of the training loop over the dataloader
lr = 0.0002        # learning rate shared by both Adam optimizers
image_size = 128   # side length used by the Resize transform and the blank fallback image

# Generator Network
class Generator(nn.Module):
    """DCGAN-style generator producing 128x128 images.

    Args:
        nz: Number of channels of the latent input.
        ngf: Base feature-map width; hidden stages use 8x, 4x, 2x, 1x and
            ``ngf // 2`` channels.
        nc: Number of channels of the generated image.

    For a ``(N, nz, 1, 1)`` input the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128 and the final ``Tanh`` bounds the
    output to ``[-1, 1]``.
    """

    def __init__(self, nz=100, ngf=64, nc=1):
        super().__init__()

        # Channel widths entering/leaving each ConvTranspose -> BN -> ReLU stage.
        widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf, ngf // 2]

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # Stage 0 turns the 1x1 latent into a 4x4 map (stride 1, no
            # padding); all later stages double the spatial resolution.
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers.append(
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))

        # Output stage: (ngf//2) x 64 x 64 -> (nc) x 128 x 128, no batch norm.
        layers.append(nn.ConvTranspose2d(widths[-1], nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent tensor through the layer stack and return the image batch."""
        generated = self.main(input)
        return generated

# Discriminator Network
class Discriminator(nn.Module):
    """DCGAN-style discriminator for 128x128 images.

    Args:
        nc: Number of channels of the input image.
        ndf: Base feature-map width; hidden stages use 1x, 2x, 4x, 8x and 16x
            of it.

    For a ``(N, nc, 128, 128)`` input the spatial size shrinks
    128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 1 and the final ``Sigmoid`` keeps the
    score in ``[0, 1]``.
    """

    def __init__(self, nc=1, ndf=64):
        super().__init__()

        widths = [ndf * factor for factor in (1, 2, 4, 8, 16)]

        # Input stage: (nc) x 128 x 128 -> (ndf) x 64 x 64, no batch norm.
        layers = [
            nn.Conv2d(nc, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Hidden stages: Conv -> BatchNorm -> LeakyReLU, halving the spatial
        # size and doubling the channel count each time.
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # Output stage: (ndf*16) x 4 x 4 -> 1 x 1 x 1 score.
        layers.extend([
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the image batch through the layer stack and return the scores."""
        scores = self.main(input)
        return scores

# Custom Dataset for BraTS2020 with NIfTI support
class BraTSDataset(Dataset):
    """2D grayscale slices drawn from NIfTI volumes (plus plain image files).

    The constructor walks ``root_dir`` recursively and stores every usable file
    path in ``self.image_paths``:

    * NIfTI files named ``*_<modality>.nii`` / ``*_<modality>.nii.gz`` are
      preferred. Only when no such file exists (or ``modality`` is empty) does
      the dataset fall back to *every* NIfTI file it finds.
    * Regular images (png/jpg/jpeg/bmp/tiff) are always included.

    Every stored file contributes ``slice_range[1] - slice_range[0]`` samples.
    This also applies to regular image files, which are therefore returned
    that many times.

    Args:
        root_dir: Directory to scan. If it does not exist an error is printed
            and the dataset stays empty.
        transform: Optional callable applied to the PIL image before it is
            returned.
        slice_range: ``(start, stop)`` slice indices (stop exclusive) sampled
            along the third axis of each volume.
        modality: File-name suffix selecting which NIfTI files to prefer
            (case-insensitive).
    """

    _NIFTI_SUFFIXES = ('.nii.gz', '.nii')
    _IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def __init__(self, root_dir, transform=None, slice_range=(50, 100), modality='flair'):
        self.image_paths = []
        self.transform = transform
        self.slice_range = slice_range
        self.modality = modality.lower()

        print(f"Searching for files in: {root_dir}")
        print(f"Looking for modality: {self.modality}")

        # Debug: Check if directory exists
        if not os.path.exists(root_dir):
            print(f"ERROR: Directory {root_dir} does not exist!")
            return

        # Debug: print a shallow overview of the directory tree
        self._explore_directory(root_dir, root_dir)

        self.image_paths = self._find_files(root_dir)
        print(f"Total files found: {len(self.image_paths)}")

        if not self.image_paths:
            self._print_troubleshooting()

    @staticmethod
    def _explore_directory(path, root_dir, max_depth=4, current_depth=0):
        """Print the first few files/directories below ``path`` (debug aid only)."""
        if current_depth >= max_depth:
            return

        # Computed before anything can fail so the error message below can
        # always use it.
        indent = "  " * current_depth

        try:
            items = os.listdir(path)
        except OSError as e:
            print(f"{indent}Error exploring {path}: {e}")
            return

        rel_path = os.path.relpath(path, root_dir)
        if rel_path == '.':
            print("Contents of root directory:")
        else:
            print(f"{indent}Contents of {rel_path}/:")

        dirs = []
        files = []
        for item in items:
            if os.path.isdir(os.path.join(path, item)):
                dirs.append(item)
            else:
                files.append(item)

        # Show files first
        for file in files[:5]:
            print(f"{indent}  FILE: {file}")
        if len(files) > 5:
            print(f"{indent}  ... and {len(files) - 5} more files")

        # Show directories
        for dir_name in dirs[:5]:
            print(f"{indent}  DIR: {dir_name}")
        if len(dirs) > 5:
            print(f"{indent}  ... and {len(dirs) - 5} more directories")

        # Only descend into the first 3 directories to keep the output short
        for dir_name in dirs[:3]:
            BraTSDataset._explore_directory(
                os.path.join(path, dir_name), root_dir, max_depth, current_depth + 1
            )

    def _find_files(self, root_dir):
        """Return the file paths below ``root_dir`` that this dataset will serve."""
        if self.modality:
            modality_suffixes = tuple(
                f'_{self.modality}{suffix}' for suffix in self._NIFTI_SUFFIXES
            )
        else:
            modality_suffixes = ()

        # (kind, path) pairs in walk order; kind is 'modality', 'nifti' or 'image'.
        candidates = []
        for subdir, dirs, files in os.walk(root_dir):
            # Sorted traversal keeps the index -> file mapping reproducible.
            dirs.sort()

            rel_path = os.path.relpath(subdir, root_dir)
            if rel_path != '.':
                print(f"Processing subdirectory: {rel_path}")

            for file in sorted(files):
                file_lower = file.lower()
                if modality_suffixes and file_lower.endswith(modality_suffixes):
                    kind = 'modality'
                elif file_lower.endswith(self._NIFTI_SUFFIXES):
                    kind = 'nifti'
                elif file_lower.endswith(self._IMAGE_SUFFIXES):
                    kind = 'image'
                else:
                    continue
                candidates.append((kind, os.path.join(subdir, file)))

        if any(kind == 'modality' for kind, _ in candidates):
            # The requested modality exists: drop NIfTI files of other
            # modalities (e.g. segmentation masks) instead of mixing them in.
            selected = [path for kind, path in candidates if kind != 'nifti']
        else:
            # Fallback - any nifti file
            selected = [path for _, path in candidates]

        for path in selected[:5]:  # Show first few files found
            print(f"  Found: {os.path.basename(path)}")

        return selected

    @staticmethod
    def _print_troubleshooting():
        """Print hints for the case where no usable file was found."""
        print("\nTroubleshooting:")
        print("1. Check if the dataset path is correct")
        print("2. The BraTS2020 dataset should have structure like:")
        print("   BraTS20_Training_XXX/")
        print("   ├── BraTS20_Training_XXX_flair.nii.gz")
        print("   ├── BraTS20_Training_XXX_t1.nii.gz")
        print("   ├── BraTS20_Training_XXX_t1ce.nii.gz")
        print("   └── BraTS20_Training_XXX_t2.nii.gz")
        print("3. Try changing modality to 't1', 't1ce', or 't2' if no FLAIR files exist")

    @property
    def _slices_per_file(self):
        """Number of samples every stored file contributes."""
        return self.slice_range[1] - self.slice_range[0]

    def __len__(self):
        # Each file is assumed to provide one sample per index in slice_range.
        return len(self.image_paths) * self._slices_per_file

    @staticmethod
    def _pick_slice(img_data, slice_idx):
        """Return a 2D array taken from ``img_data`` at ``slice_idx`` where possible."""
        if img_data.ndim in (3, 4):
            depth = img_data.shape[2]
            # Middle slice if the requested index is out of range
            z = slice_idx if slice_idx < depth else depth // 2
            if img_data.ndim == 3:
                return img_data[:, :, z]
            # For 4D data, take the first timepoint
            return img_data[:, :, z, 0]

        # Fallback for unexpected dimensions
        squeezed = np.squeeze(img_data)
        if squeezed.ndim > 2:
            squeezed = squeezed[:, :, min(slice_idx, squeezed.shape[2] - 1)]
        return squeezed

    @staticmethod
    def _to_uint8(img_slice):
        """Min-max scale ``img_slice`` to 0..255; a constant slice becomes all zeros."""
        img_slice = np.nan_to_num(img_slice)  # Handle NaN values
        lo = img_slice.min()
        hi = img_slice.max()
        if hi > lo:
            img_slice = (img_slice - lo) / (hi - lo)
        else:
            # Without this a constant non-zero slice would be multiplied by 255
            # unscaled and wrap around in the uint8 cast.
            img_slice = np.zeros_like(img_slice)
        return (img_slice * 255).astype(np.uint8)

    def __getitem__(self, idx):
        """Return the (optionally transformed) grayscale image for sample ``idx``.

        Raises:
            IndexError: If the dataset holds no files.
        """
        if not self.image_paths:
            raise IndexError("BraTSDataset is empty")

        file_idx, offset = divmod(idx, self._slices_per_file)
        slice_idx = offset + self.slice_range[0]

        # Out-of-range indices wrap around instead of failing.
        if file_idx >= len(self.image_paths):
            file_idx = file_idx % len(self.image_paths)

        file_path = self.image_paths[file_idx]

        try:
            if file_path.lower().endswith(self._NIFTI_SUFFIXES):
                # Load NIfTI file and cut out one slice
                img_data = nib.load(file_path).get_fdata()
                img_slice = self._to_uint8(self._pick_slice(img_data, slice_idx))
                img = Image.fromarray(img_slice, mode='L')
            else:
                # Regular image loading
                img = Image.open(file_path).convert('L')
        except Exception as e:
            # Deliberately broad: one unreadable file is reported and replaced
            # by a blank image rather than raising.
            print(f"Error loading {file_path}: {e}")
            # `image_size` is the module-level constant defined with the
            # other hyperparameters.
            img = Image.new('L', (image_size, image_size), 0)

        if self.transform:
            img = self.transform(img)

        return img

# Data loading with improved transforms
# Applied to every PIL image returned by the datasets below:
#   1. resize to image_size x image_size,
#   2. convert to a tensor,
#   3. shift/scale the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1] range
])

# Dataset path - Updated to the correct path with zip files
# NOTE(review): this is a hard-coded location; presumably it matches the
# kagglehub cache directory on Colab - confirm. The `path` value returned by
# kagglehub.dataset_download in the first cell is not used here.
dataset_path = "/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions"

# Function to extract zip files and find NIfTI data
def extract_and_find_data(base_path):
    """Extract every zip archive below ``base_path`` and look for NIfTI files.

    All archives are unpacked into one freshly created temporary directory.

    Args:
        base_path: Directory that is searched recursively for ``*.zip`` files.

    Returns:
        The temporary directory holding the extracted content when it contains
        at least one ``.nii`` / ``.nii.gz`` file. The caller owns that
        directory. ``None`` is returned when no zip file exists, when the
        archives contain no NIfTI file, or when extraction fails
        unexpectedly; in those cases the temporary directory is removed again.
    """
    print(f"Searching in: {base_path}")

    # Find all zip files
    zip_files = []
    for root, _dirs, files in os.walk(base_path):
        for file in files:
            if file.endswith('.zip'):
                zip_files.append(os.path.join(root, file))
                print(f"Found zip file: {file}")

    if not zip_files:
        print("No zip files found!")
        return None

    # Create a temporary directory for extraction
    temp_dir = tempfile.mkdtemp()
    print(f"Created temporary directory: {temp_dir}")

    # Set only on the success path; every other exit removes temp_dir so
    # failed runs do not leave extracted data behind.
    keep_temp_dir = False

    try:
        # Extract zip files
        for zip_path in zip_files:
            print(f"Extracting: {os.path.basename(zip_path)}")
            try:
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(temp_dir)
            except Exception as e:
                # Best effort: a broken archive must not stop the others.
                print(f"Error extracting {zip_path}: {e}")
                continue

        # Look for NIfTI files in extracted content
        nifti_files = [
            os.path.join(root, file)
            for root, _dirs, files in os.walk(temp_dir)
            for file in files
            if file.lower().endswith(('.nii.gz', '.nii'))
        ]

        print(f"Found {len(nifti_files)} NIfTI files after extraction")

        if nifti_files:
            # Show some example files
            for i, nifti_file in enumerate(nifti_files[:5]):
                print(f"  Example {i+1}: {os.path.basename(nifti_file)}")
            keep_temp_dir = True
            return temp_dir

        # Maybe the data needs to be downloaded differently
        print("No NIfTI files found after extraction.")
        print("The zip might contain metadata only.")

        # Let's see what was actually extracted (before it is cleaned up)
        print("Contents of extracted files:")
        for root, _dirs, files in os.walk(temp_dir):
            level = root.replace(temp_dir, '').count(os.sep)
            indent = ' ' * 2 * level
            print(f"{indent}{os.path.basename(root)}/")
            subindent = ' ' * 2 * (level + 1)
            for file in files[:10]:  # Show first 10 files
                print(f"{subindent}{file}")
            if len(files) > 10:
                print(f"{subindent}... and {len(files) - 10} more files")

        return None

    except Exception as e:
        print(f"Error during extraction: {e}")
        return None

    finally:
        if not keep_temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)

# Try to extract and find the data
extracted_path = extract_and_find_data(dataset_path)

if extracted_path is None:
    # No usable NIfTI data: explain what is missing and fall back to a
    # synthetic dataset so the rest of the notebook can still run.
    print("\n" + "="*60)
    print("DATASET SETUP ISSUE")
    print("="*60)
    print("It appears this dataset only contains metadata CSV files.")
    print("For the BraTS2020 dataset, you typically need to:")
    print("1. Register at https://www.med.upenn.edu/cbica/brats2020/registration.html")
    print("2. Download the actual training data (usually multiple GB)")
    print("3. The training data contains folders like BraTS20_Training_001/, BraTS20_Training_002/, etc.")
    print("4. Each folder contains 4 NIfTI files: *_flair.nii.gz, *_t1.nii.gz, *_t1ce.nii.gz, *_t2.nii.gz")
    print("\nFor now, let's create a demo with synthetic data...")

    # Create synthetic data for demonstration
    class SyntheticBrainDataset(Dataset):
        """Procedurally generated 128x128 grayscale "brain-like" images.

        Each sample is a noisy disc of radius 50 centred in the image; about
        30% of the samples also contain one small bright spot. A sample is a
        pure function of its index.

        Args:
            size: Number of samples reported by ``len()``.
            transform: Optional callable applied to the PIL image.
        """

        def __init__(self, size=1000, transform=None):
            self.size = size
            self.transform = transform

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            # A generator private to this sample keeps the image reproducible
            # per index without reseeding the process-wide NumPy RNG (which
            # would disturb every other user of np.random).
            rng = np.random.default_rng(idx)

            # Create a circular brain-like structure
            img = np.zeros((128, 128), dtype=np.float32)

            # Add brain outline
            y, x = np.ogrid[:128, :128]
            center_y, center_x = 64, 64
            mask = (x - center_x)**2 + (y - center_y)**2 < 50**2

            # Add some texture inside the brain
            brain_region = np.clip(rng.normal(0.5, 0.1, (128, 128)), 0, 1)

            # Add some "tumor-like" bright spots randomly
            if rng.random() > 0.7:  # 30% chance of bright spots
                tumor_y = rng.integers(30, 98)
                tumor_x = rng.integers(30, 98)
                tumor_mask = (x - tumor_x)**2 + (y - tumor_y)**2 < 5**2
                brain_region[tumor_mask] = rng.uniform(0.8, 1.0)

            img[mask] = brain_region[mask]

            # Convert to PIL Image
            img = (img * 255).astype(np.uint8)
            img = Image.fromarray(img, mode='L')

            if self.transform:
                img = self.transform(img)

            return img

    print("Creating synthetic brain dataset for demonstration...")
    dataset = SyntheticBrainDataset(size=2000, transform=transform)
    actual_dataset_path = "synthetic_data"
    print("Using synthetic dataset...")

else:
    # Use real extracted dataset
    actual_dataset_path = extracted_path
    dataset = BraTSDataset(actual_dataset_path, transform=transform, modality='flair')

    # Check if dataset is empty and provide fallback (only for real data)
    if len(dataset) == 0:
        print("No files found! Trying with different modalities...")

        # Try different modalities
        for modality in ['t1', 't1ce', 't2', 'seg']:
            print(f"\nTrying modality: {modality}")
            dataset = BraTSDataset(actual_dataset_path, transform=transform, modality=modality)
            if len(dataset) > 0:
                print(f"Success! Found files with modality: {modality}")
                break

        # If still no files, try without specific modality (any .nii.gz file)
        if len(dataset) == 0:
            print("\nTrying to load any .nii.gz files...")
            dataset = BraTSDataset(actual_dataset_path, transform=transform, modality='')

# Final check before creating dataloader
if len(dataset) == 0:
    print("ERROR: No suitable image files found in the dataset!")
    print(f"Please check the dataset path: {actual_dataset_path}")
    print("The dataset should contain .nii.gz files or standard image formats")
    raise ValueError("Dataset is empty - no image files found")

print(f"\nFinal dataset size: {len(dataset)} samples")
# num_workers=0 loads data in the main process, which keeps debugging simple.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)  # Set num_workers=0 for debugging

# Initialize models
def weights_init(m):
    """Re-initialise one module in place; intended for ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left
    untouched.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

# Build both networks with their default widths and move them to the
# selected device.
G = Generator(nz=nz).to(device)
D = Discriminator().to(device)

# Apply weight initialization
# Module.apply calls weights_init on every submodule of each network.
G.apply(weights_init)
D.apply(weights_init)

# Loss and optimizers
# Binary cross-entropy on the discriminator's sigmoid output; one Adam
# optimizer per network, both with the same lr and betas=(0.5, 0.999).
criterion = nn.BCELoss()
optimizerD = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))

# Fixed noise for consistent generation tracking
# The same 16 latent vectors are fed to G after every epoch so the saved
# sample grids can be compared across epochs.
fixed_noise = torch.randn(16, nz, 1, 1, device=device)

# Create output directory for generated images
# (also receives the model checkpoints written by the training loop)
os.makedirs("generated_images", exist_ok=True)

print("Starting training...")

# Training loop with improved logging
# Every iteration performs one discriminator update followed by one generator
# update, both using the same batch of generated images.
for epoch in range(epochs):
    for i, data in enumerate(dataloader):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # Clear D's gradients, including those accumulated by the previous
        # generator step's backward pass through D.
        D.zero_grad()

        # Train with real batch
        real = data.to(device)
        b_size = real.size(0)  # the last batch may be smaller than batch_size
        label = torch.full((b_size,), 1., dtype=torch.float, device=device)  # target 1 = real

        output = D(real).view(-1)  # flatten to one score per sample
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()  # mean D output on real images (logging only)

        # Train with fake batch
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = G(noise)
        label.fill_(0.)  # target 0 = fake

        # detach() keeps this backward pass from propagating into G
        output = D(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()  # mean D output on fakes before D's update
        errD = errD_real + errD_fake  # summed only for logging; gradients were already accumulated
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        G.zero_grad()
        label.fill_(1.)  # fake labels are real for generator cost

        # `fake` is not detached here, so gradients flow back into G through
        # the just-updated D.
        output = D(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()  # mean D output on fakes after D's update
        optimizerG.step()

        # Output training stats
        # Printed every 50 iterations (including iteration 0 of each epoch).
        if i % 50 == 0:
            print(f'[{epoch}/{epochs}][{i}/{len(dataloader)}] '
                  f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                  f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')

    # Save generated images
    # One grid per epoch from the fixed noise, written without tracking
    # gradients; save_image is called with normalize=True and nrow=4.
    with torch.no_grad():
        fake = G(fixed_noise).detach().cpu()
        save_image(fake, f'generated_images/epoch_{epoch:03d}.png', normalize=True, nrow=4)

    # Save model checkpoints every 10 epochs
    # The file name and the stored 'epoch' value use the zero-based epoch
    # index (e.g. 009 after the 10th epoch).
    if (epoch + 1) % 10 == 0:
        torch.save({
            'generator_state_dict': G.state_dict(),
            'discriminator_state_dict': D.state_dict(),
            'optimizerG_state_dict': optimizerG.state_dict(),
            'optimizerD_state_dict': optimizerD.state_dict(),
            'epoch': epoch,
        }, f'generated_images/checkpoint_epoch_{epoch:03d}.pth')

print("Training completed!")