# Tumor classification

The classification code was written by Chitipotu Kushwanth.

In [None]:
import os
import torch
import nibabel as nib

# Function to classify tumor grade based on tumor subregion volumes
def classify_tumor_grade(segmented_image, necrotic_label=1, edema_label=2, necrotic_threshold=0.2):
    """
    Classify a tumor as high-grade (1) or low-grade (0) from its label map.

    Heuristic: count the voxels carrying ``necrotic_label`` and
    ``edema_label``, and report high-grade when the necrotic share of those two
    regions strictly exceeds ``necrotic_threshold``.

    Args:
        segmented_image: Label map as a torch tensor of any shape. Array-likes
            (numpy arrays, nested lists) are also accepted and converted with
            ``torch.as_tensor``.
        necrotic_label: Voxel value marking the necrotic region (default 1).
        edema_label: Voxel value marking the edema region (default 2).
        necrotic_threshold: Necrotic proportion above which the tumor is
            classified as high-grade (default 0.2; equality is low-grade).

    Returns:
        int: 1 for high-grade, 0 for low-grade. A map with neither label
        present is reported as low-grade (0).
    """
    labels = torch.as_tensor(segmented_image)

    # Volumes are plain voxel counts per subregion.
    necrotic_volume = torch.sum(labels == necrotic_label).item()
    edema_volume = torch.sum(labels == edema_label).item()

    # NOTE(review): only the necrotic and edema labels enter the total; any other
    # label values in the map are ignored. Confirm this is the intended rule.
    total_tumor_volume = necrotic_volume + edema_volume

    # Avoid division by zero when neither region is present.
    if total_tumor_volume == 0:
        return 0

    necrotic_proportion = necrotic_volume / total_tumor_volume
    return 1 if necrotic_proportion > necrotic_threshold else 0

# Function to process and classify tumor grade, then save the result in the same folder
def process_and_classify(patient_folder):
    """
    Grade one patient's tumor and record the result inside the patient folder.

    The patient ID is the folder's base name. If ``<patient_id>_seg.nii``
    exists in the folder, it is loaded with nibabel, converted to a float32
    tensor, graded with ``classify_tumor_grade`` and written to
    ``classification.txt`` in the same folder as two lines
    (``Patient ID: ...`` then ``Tumor grade: ...``). A missing segmentation is
    reported with a printed message rather than an exception.
    """
    pid = os.path.basename(patient_folder)
    seg_path = os.path.join(patient_folder, f"{pid}_seg.nii")

    # Guard clause: nothing to grade without a segmentation file.
    if not os.path.exists(seg_path):
        print(f"Segmentation file not found for patient {pid}")
        return

    voxels = nib.load(seg_path).get_fdata()
    label_map = torch.tensor(voxels, dtype=torch.float32)

    label_text = "High-grade" if classify_tumor_grade(label_map) == 1 else "Low-grade"

    out_path = os.path.join(patient_folder, "classification.txt")
    with open(out_path, "w") as out:
        out.write(f"Patient ID: {pid}\n")
        out.write(f"Tumor grade: {label_text}\n")

    print(f"Classification saved for patient {pid}: {label_text}")

# Mount Google Drive (Colab only) so the processed dataset is reachable.
from google.colab import drive
drive.mount('/content/drive')

# Every sub-folder of input_dir is one patient; grade each in listdir order.
input_dir = "/content/drive/MyDrive/Processed_BraTS2021/val"
for patient_folder in os.listdir(input_dir):
    full_path = os.path.join(input_dir, patient_folder)
    if not os.path.isdir(full_path):
        continue  # skip stray files; only patient folders are processed
    process_and_classify(full_path)


# EfficientNet3D

The rest of this code is from Sai Satwik's classifier code.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EfficientNet3D(nn.Module):
    """
    Small 3D CNN classifier for volumetric input.

    Layout:
      * stem: Conv3d(3x3x3, padding 1) -> BatchNorm3d -> ReLU, 32 channels;
      * three stages built by ``_make_block`` (32->64->128->256 channels),
        each ending in MaxPool3d(2), so every spatial dim is halved three
        times (each must be at least 8 for the last pool to be non-empty);
      * global average pooling to a 256-wide feature vector;
      * head: Linear(256, 512) -> ReLU -> Dropout(0.5) -> Linear(512, num_classes).

    NOTE(review): despite the name, this is a plain conv stack, not the
    EfficientNet MBConv / compound-scaling design.

    Args:
        input_channels: Channels of the input volume (default 1).
        num_classes: Number of output logits (default 2).
    """

    def __init__(self, input_channels=1, num_classes=2):
        super().__init__()

        # Attribute names and creation order fix the state_dict keys and the
        # parameter order that saved model/optimizer checkpoints rely on.

        # Stem: lift the input to 32 feature maps at full resolution.
        self.conv1 = nn.Conv3d(input_channels, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm3d(32)
        self.relu1 = nn.ReLU()

        # Downsampling stages, each doubling the channel count.
        self.block1 = self._make_block(32, 64)
        self.block2 = self._make_block(64, 128)
        self.block3 = self._make_block(128, 256)

        # Collapse whatever spatial extent remains to 1x1x1 so the head always
        # sees 256 features regardless of input size.
        self.global_avg_pool = nn.AdaptiveAvgPool3d(1)

        # Classification head; the final width follows num_classes.
        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(0.5)

    def _make_block(self, in_channels, out_channels):
        """Return Conv3d(3x3x3, same padding) -> BatchNorm3d -> ReLU -> MaxPool3d(2)."""
        layers = [
            nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
            nn.MaxPool3d(2),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, input_channels, D, H, W) tensor to (batch, num_classes) logits."""
        out = self.relu1(self.bn1(self.conv1(x)))
        for stage in (self.block1, self.block2, self.block3):
            out = stage(out)

        features = self.global_avg_pool(out).flatten(1)
        hidden = self.dropout(F.relu(self.fc1(features)))

        # Raw logits: nn.CrossEntropyLoss applies log-softmax itself.
        return self.fc2(hidden)



# Dataset

The rest of the code is by Chitipotu Kushwanth.

In [None]:
import os
import torch
import nibabel as nib
from torch.utils.data import Dataset
from torchvision import transforms

class BRATS2021Dataset(Dataset):
    """
    Segmentation volumes paired with binary tumor-grade labels.

    Every immediate sub-directory of ``root_dir`` is one patient, and its name
    is used as the patient ID. A patient folder must contain
    ``<patient_id>_seg.nii`` (loaded with nibabel) and ``classification.txt``
    whose grade is ``Low-grade`` (label 0) or ``High-grade`` (label 1).
    Items are ``(segmentation_tensor, label)`` tuples.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory whose sub-folders are patient folders,
                each containing a segmentation (.nii) and classification.txt.
            transform (callable, optional): Applied to the segmentation tensor.
        """
        self.root_dir = root_dir
        # Only directories are patients: stray files in root_dir (notes,
        # archives, ...) would otherwise be indexed and crash __getitem__.
        # Sorted so that indices are stable across runs.
        self.patients = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.patients)

    def __getitem__(self, idx):
        # The folder name doubles as the patient ID (e.g. "BraTS2021_00005").
        patient_id = self.patients[idx]
        patient_dir = os.path.join(self.root_dir, patient_id)

        segmentation_path = os.path.join(patient_dir, f"{patient_id}_seg.nii")
        if not os.path.exists(segmentation_path):
            raise FileNotFoundError(f"Segmentation file not found: {segmentation_path}")

        # get_fdata() returns the voxel data as a numpy array; cast to float32.
        segmentation = torch.tensor(nib.load(segmentation_path).get_fdata(), dtype=torch.float32)

        # The classifier's Conv3d layers need a channel axis per sample
        # (C, D, H, W). Give bare 3D volumes a singleton channel; volumes that
        # already carry more dimensions are passed through unchanged.
        if segmentation.ndim == 3:
            segmentation = segmentation.unsqueeze(0)

        classification_file = os.path.join(patient_dir, "classification.txt")
        if not os.path.exists(classification_file):
            raise FileNotFoundError(f"Classification file not found: {classification_file}")
        tumor_label = self._read_tumor_label(classification_file)

        if self.transform:
            segmentation = self.transform(segmentation)

        return segmentation, tumor_label

    @staticmethod
    def _read_tumor_label(classification_file):
        """
        Parse a classification.txt file into 0 (Low-grade) or 1 (High-grade).

        The line starting with ``Tumor grade:`` is used wherever it appears;
        if there is none, the second line is used. Raises ValueError when no
        grade line is available or the grade is not recognised.
        """
        prefix = "Tumor grade:"
        with open(classification_file, "r", encoding="utf-8") as f:
            lines = [line.strip() for line in f]

        for line in lines:
            if line.startswith(prefix):
                tumor_grade = line[len(prefix):].strip()
                break
        else:
            if len(lines) < 2:
                raise ValueError(f"No tumor grade found in {classification_file}")
            tumor_grade = lines[1].replace("Tumor grade: ", "").strip()

        grade_to_label = {"Low-grade": 0, "High-grade": 1}
        if tumor_grade not in grade_to_label:
            raise ValueError(f"Unknown tumor grade: {tumor_grade}")
        return grade_to_label[tumor_grade]


# Transform applied to every segmentation tensor by the datasets below.
# Normalize computes (x - 0.5) / 0.5. On a label map this shifts the label
# values (0 -> -1, 1 -> 1, 2 -> 3) rather than standardising intensities.
transform = transforms.Compose([transforms.Normalize(mean=[0.5], std=[0.5])])




In [None]:
# Per-split data locations (machine-specific; replace with the actual paths).
train_dir = r"C:\Users\intel5\Downloads\Processed_BraTS2021-20241121T100710Z-001\Processed_BraTS2021\train"   # Replace with the actual path
val_dir = r"C:\Users\intel5\Downloads\Processed_BraTS2021-20241121T100710Z-001\Processed_BraTS2021\val"   # Replace with the actual path
test_dir = r"C:\Users\intel5\Downloads\Processed_BraTS2021-20241121T100710Z-001\Processed_BraTS2021\test"   # Replace with the actual path

# One named dataset per split (the old code rebound a single `dataset` name).
train_dataset = BRATS2021Dataset(root_dir=train_dir, transform=transform)
val_dataset = BRATS2021Dataset(root_dir=val_dir, transform=transform)
test_dataset = BRATS2021Dataset(root_dir=test_dir, transform=transform)

# Shuffle only the training data. Evaluation accuracy does not depend on batch
# order, and a fixed order keeps evaluation runs reproducible.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=2, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False)

# The original cell left `dataset` bound to the test split; keep that name
# available for any later cell that still refers to it.
dataset = test_dataset

# Training Loop

In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import os

# Training setup; relies on train_loader / val_loader from the data cell.

# Two-class 3D classifier for single-channel volumes.
model = EfficientNet3D(input_channels=1, num_classes=2)

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# Cross-entropy over the model's raw logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Where per-epoch checkpoints are written and read back.
checkpoint_dir = r'C:\Users\intel5\Downloads\Processed_BraTS2021-20241121T100710Z-001\Processed_BraTS2021\checkpoints'

def save_checkpoint(model, optimizer, epoch, loss, checkpoint_dir):
    """
    Write model and optimizer state for a finished epoch to ``checkpoint_dir``.

    ``epoch`` is the 0-based loop index. The file is named with the 1-based
    number (``checkpoint_epoch_{epoch + 1}.pth``), while the stored ``'epoch'``
    field keeps the 0-based value. The directory is created if missing.
    """
    human_epoch = epoch + 1
    ckpt_name = f"checkpoint_epoch_{human_epoch}.pth"

    os.makedirs(checkpoint_dir, exist_ok=True)
    target = os.path.join(checkpoint_dir, ckpt_name)

    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        },
        target,
    )
    print(f"Checkpoint saved at epoch {human_epoch}, filename: {ckpt_name}")

# Function to load checkpoints
def load_checkpoint(model, optimizer, checkpoint_dir, filename="checkpoint_epoch_1.pth"):
    """
    Restore model and optimizer state from ``checkpoint_dir/filename``.

    Returns:
        ``(model, optimizer, start_epoch, loss)``. ``start_epoch`` is the
        stored 0-based epoch plus one, i.e. the next epoch to run. If the file
        does not exist, nothing is loaded and ``(model, optimizer, 0, None)``
        is returned.
    """
    checkpoint_path = os.path.join(checkpoint_dir, filename)
    if not os.path.isfile(checkpoint_path):
        print("No checkpoint found, starting from scratch.")
        return model, optimizer, 0, None

    print(f"Loading checkpoint from {checkpoint_path}")
    # map_location="cpu" lets a checkpoint written on a GPU machine load without
    # CUDA. load_state_dict then copies the tensors onto the devices of the
    # existing model and optimizer parameters.
    # weights_only=True: the file holds only state dicts and plain numbers, so
    # the restricted unpickler suffices. It also avoids running arbitrary pickled
    # code and silences the FutureWarning that newer PyTorch emits.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # resume with the epoch after the saved one
    loss = checkpoint['loss']
    print(f"Checkpoint loaded. Starting from epoch {start_epoch}. Last loss: {loss}")
    return model, optimizer, start_epoch, loss

def _latest_checkpoint_filename(checkpoint_dir):
    """Return the ``checkpoint_epoch_N.pth`` name in ``checkpoint_dir`` with the largest N, or None."""
    if not os.path.isdir(checkpoint_dir):
        return None
    prefix, suffix = "checkpoint_epoch_", ".pth"
    best_epoch, best_name = None, None
    for name in os.listdir(checkpoint_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        number = name[len(prefix):-len(suffix)]
        if number.isdigit() and (best_epoch is None or int(number) > best_epoch):
            best_epoch, best_name = int(number), name
    return best_name


# Training function with tqdm and checkpoint saving
def train_model(model, train_loader, criterion, optimizer, device, epochs=10, checkpoint_dir=None):
    """
    Train ``model`` until ``epochs`` epochs are done, resuming from checkpoints.

    When ``checkpoint_dir`` is given, the most recent checkpoint in it (the
    ``checkpoint_epoch_N.pth`` with the largest N) is loaded first. Training
    then continues with the epoch after it, and a new checkpoint is written
    after every epoch. Loading always from epoch 1 would discard later progress
    and overwrite later checkpoints. Per-epoch mean batch loss and accuracy (%)
    are printed.
    """
    model.train()
    start_epoch = 0
    if checkpoint_dir:
        latest = _latest_checkpoint_filename(checkpoint_dir)
        if latest is None:
            # No checkpoint yet: load_checkpoint reports that and returns epoch 0.
            model, optimizer, start_epoch, _ = load_checkpoint(model, optimizer, checkpoint_dir)
        else:
            model, optimizer, start_epoch, _ = load_checkpoint(
                model, optimizer, checkpoint_dir, filename=latest
            )

    for epoch in range(start_epoch, epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            num_batches += 1

        # Guard an empty loader, which would otherwise raise ZeroDivisionError.
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        epoch_accuracy = 100 * correct / total if total else 0.0

        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")

        # One checkpoint per epoch, named with the 1-based epoch number.
        if checkpoint_dir:
            save_checkpoint(model, optimizer, epoch, epoch_loss, checkpoint_dir)

def evaluate_model(model, test_loader, device):
    """
    Measure classification accuracy of ``model`` over ``test_loader``.

    Puts the model in eval mode (it stays there afterwards) and runs every
    batch without gradient tracking. The arg-max logit is taken as the
    prediction. The accuracy is printed as a percentage and returned as the
    value from sklearn's ``accuracy_score``.
    """
    model.eval()
    predictions, targets = [], []

    with torch.no_grad():
        progress = tqdm(test_loader, desc="Evaluating", unit="batch")
        for batch_inputs, batch_labels in progress:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            _, batch_preds = torch.max(logits, 1)

            predictions.extend(batch_preds.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

    accuracy = accuracy_score(targets, predictions)
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    return accuracy

# Train (checkpointing into checkpoint_dir), then score on the validation split.
train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    epochs=10,
    checkpoint_dir=checkpoint_dir,
)
evaluate_model(model=model, test_loader=val_loader, device=device)


  checkpoint = torch.load(checkpoint_path)


Loading checkpoint from C:\Users\intel5\Downloads\Processed_BraTS2021-20241121T100710Z-001\Processed_BraTS2021\checkpoints\checkpoint_epoch_1.pth
Checkpoint loaded. Starting from epoch 1. Last loss: 0.6794873289253613


Epoch 2/10: 100%|█████████████████████████████████████████████████████████████████| 438/438 [52:19<00:00,  7.17s/batch]


Epoch 2/10, Loss: 0.6634, Accuracy: 62.29%
Checkpoint saved at epoch 2, filename: checkpoint_epoch_2.pth


Epoch 3/10: 100%|█████████████████████████████████████████████████████████████████| 438/438 [51:42<00:00,  7.08s/batch]


Epoch 3/10, Loss: 0.6675, Accuracy: 63.20%
Checkpoint saved at epoch 3, filename: checkpoint_epoch_3.pth


Epoch 4/10:   2%|█                                                                  | 7/438 [00:50<52:04,  7.25s/batch]