In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
# Import MobileNetV2 specific weights enum
from torchvision.models import MobileNet_V2_Weights
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast # For Mixed Precision
import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt

# --- Configuration ---
# <<< IMPORTANT: point data_base_dir at your binary dataset before running >>>
data_base_dir = '../dataset2'
# 'train' and 'validation' sub-directories derived from the base directory.
train_dir, val_dir = (
    os.path.join(data_base_dir, split_name)
    for split_name in ('train', 'validation')
)

# File the trained binary classifier's weights are written to.
binary_model_save_path = 'mobilenet_binary_classifier_weights.pth'

# Training hyperparameters
num_classes = 2        # binary task: brain_mri vs other_image
batch_size = 64        # MobileNetV2 is light; may allow a larger batch than ResNet50
num_epochs = 20        # tune while watching validation accuracy
learning_rate = 0.001

# Input resolution fed to the network (MobileNetV2 typically uses 224x224).
img_size = 224

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# The cuDNN benchmark flag is only switched on when running on the GPU.
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True
    print("cuDNN Benchmark enabled.")

In [None]:
# Expected label names for the binary task.
# IMPORTANT: these must match the dataset sub-folder names exactly
# ('brain_mri', 'other_image').
class_names = [
    'brain_mri',
    'other_image',
]
print(f"Binary classification classes: {class_names}")

In [None]:
# Preprocessing pipelines, keyed by phase ('train' / 'val').
# Both phases normalise with the standard ImageNet channel statistics because
# the network starts from pre-trained weights.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

data_transforms = dict(
    # Training: random crop / flip / colour jitter / rotation for augmentation.
    train=transforms.Compose([
        transforms.RandomResizedCrop(size=img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.1),
        transforms.RandomRotation(degrees=15),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]),
    # Validation: deterministic resize (slightly oversized) followed by a centre crop.
    val=transforms.Compose([
        transforms.Resize(size=img_size + 32),
        transforms.CenterCrop(size=img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]),
)
print("Data transforms defined.")

In [None]:
# --- Modified Cell 4: Load Data and Split in Code ---
# Reads every image under `data_base_dir` with ImageFolder, splits the samples
# 80/20 into train/validation by shuffled index, and wraps each split in a
# DataLoader. Produces: `dataloaders`, `dataset_sizes`, `class_to_idx`,
# `detected_classes`, `train_indices`, `val_indices`.

print(f"Loading data from base directory: {data_base_dir}")
print("Applying TRAIN transforms during initial load for splitting reference...")
print("Applying VAL transforms during second initial load for validation subset...")


# --- Create TWO ImageFolder datasets pointing to the SAME base directory ---
# A Subset inherits the transform of the dataset it wraps, so giving the train
# and validation splits different transforms needs two dataset instances over
# the same files: one with training transforms, one with validation transforms.
try:
    # Instance applying training transforms (backs the training subset)
    full_dataset_train_transforms = datasets.ImageFolder(data_base_dir, transform=data_transforms['train'])

    # Instance applying validation transforms (backs the validation subset)
    full_dataset_val_transforms = datasets.ImageFolder(data_base_dir, transform=data_transforms['val'])

    print(f"Successfully loaded dataset structure. Total images: {len(full_dataset_train_transforms)}")

    # Sanity check: the shared index split below is only valid if both
    # instances enumerated the same files and classes.
    if len(full_dataset_train_transforms) != len(full_dataset_val_transforms):
        print("Warning: Dataset length mismatch between train/val transform instances. Check data loading.")
    if full_dataset_train_transforms.classes != full_dataset_val_transforms.classes:
        print("Warning: Detected class mismatch between train/val transform instances.")

except FileNotFoundError:
    print(f"ERROR: Base dataset folder not found at {data_base_dir}.")
    print("Please ensure this path points to the directory containing 'brain_mri' and 'other_image' subfolders.")
    raise  # Stop execution if data isn't found
except Exception as e:
    # Report the problem in the notebook output, then let the error propagate.
    print(f"An error occurred loading dataset: {e}")
    raise

# Class names and name->index mapping as detected from the folder names
# (taken from the train-transform instance).
detected_classes = full_dataset_train_transforms.classes
class_to_idx = full_dataset_train_transforms.class_to_idx
print(f"Classes detected: {detected_classes}")
print(f"Class to index mapping: {class_to_idx}")

# Warn (but do not abort) when the folders differ from the expected labels.
if detected_classes != class_names:  # class_names defined in Cell 2
    print(f"Warning: Detected classes {detected_classes} do not match expected {class_names}. Check folder names.")

# --- Define the Split Ratio ---
val_split = 0.2  # fraction of samples held out for validation
dataset_size = len(full_dataset_train_transforms)
val_size = int(val_split * dataset_size)
train_size = dataset_size - val_size
print(f"Splitting dataset: {train_size} training samples, {val_size} validation samples")

# --- Perform the Random Split ---
# Split on indices first, then build each Subset on the dataset instance that
# carries the transform appropriate for that split.
indices = list(range(dataset_size))
np.random.seed(42)  # fixed seed so the split is reproducible across runs
np.random.shuffle(indices)
train_indices, val_indices = indices[:train_size], indices[train_size:]

train_dataset_subset = torch.utils.data.Subset(full_dataset_train_transforms, train_indices)
val_dataset_subset = torch.utils.data.Subset(full_dataset_val_transforms, val_indices)

# --- Create DataLoaders from the Subset datasets ---
dataloaders = {
    'train': DataLoader(train_dataset_subset, batch_size=batch_size, shuffle=True, num_workers=4),
    'val': DataLoader(val_dataset_subset, batch_size=batch_size, shuffle=False, num_workers=4)  # No shuffle for validation
}

dataset_sizes = {'train': len(train_dataset_subset), 'val': len(val_dataset_subset)}
print("DataLoaders created using random split from the base dataset.")

In [None]:
# Build the network: MobileNetV2 with the DEFAULT pre-trained weights when they
# can be obtained, otherwise the same architecture with random initialisation.
try:
    model = models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
except Exception as exc:
    print(f"Could not download pre-trained weights: {exc}. Initializing random weights.")
    model = models.mobilenet_v2(weights=None)
else:
    print("Loaded pre-trained MobileNetV2 weights (DEFAULT).")

# `model.classifier` is a Sequential; its Linear layer sits at index 1.
# Read that layer's input width, then swap in a fresh Linear that emits
# `num_classes` logits (2 for this binary task).
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features=num_ftrs, out_features=num_classes)

# Place the model on the configured device.
model = model.to(device)

print("MobileNetV2 model loaded and final layer modified for binary classification.")
# print(model) # Optional: print model architecture

In [None]:
# Optimisation setup for training.

# Adam over every model parameter, using the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Cross-entropy over the two output logits.
criterion = nn.CrossEntropyLoss()

# Gradient scaler for mixed precision; only enabled when CUDA is available.
scaler = GradScaler(enabled=torch.cuda.is_available())

print("Loss function, optimizer, and GradScaler defined.")

In [None]:
# Train the binary classifier for `num_epochs` epochs. Every epoch runs a
# training pass followed by a validation pass; the weights with the highest
# validation accuracy seen so far are kept in `best_model_wts` and loaded back
# into `model` once training finishes.
print("Starting training for binary classifier...")
start_time = time.time()

# Track the best model based on validation accuracy
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch+1}/{num_epochs}')
    print('-' * 10)

    # Each epoch has a training and a validation phase
    for phase in ['train', 'val']:
        is_train = phase == 'train'
        if is_train:
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0  # plain int so an empty phase cannot break the stats
        samples_seen = 0      # samples that actually contributed to the stats

        for i, (inputs, labels) in enumerate(dataloaders[phase]):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Gradients only exist (and only need clearing) while training
            if is_train:
                optimizer.zero_grad()

            # Track history only in the train phase
            with torch.set_grad_enabled(is_train):
                # Mixed precision covers the forward pass and loss only;
                # it is enabled only on GPU.
                with autocast(enabled=(device.type == 'cuda')):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)  # predicted class per sample

                # Backward + optimize only in the training phase, and outside
                # the autocast context.
                if is_train:
                    if torch.isnan(loss):
                        print(f"WARNING: NaN loss detected at epoch {epoch+1}, batch {i+1} ({phase}). Skipping batch.")
                        continue
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

            # Statistics (skipped batches never reach this point)
            batch_count = inputs.size(0)
            running_loss += loss.item() * batch_count
            running_corrects += (preds == labels).sum().item()
            samples_seen += batch_count

        if samples_seen == 0:
            # Nothing was processed (empty split, or every batch was skipped):
            # there is no meaningful average to report or compare.
            print(f"WARNING: no samples were processed in the {phase} phase of epoch {epoch+1}. Skipping statistics.")
            continue

        # Average over the samples that were actually processed, so skipped
        # batches do not drag the epoch figures down.
        epoch_loss = running_loss / samples_seen
        epoch_acc = running_corrects / samples_seen

        print(f'{phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Deep copy the model if validation accuracy improves
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            print(f'  -> Best Validation Acc: {best_acc:.4f} (saved model weights)')

# --- Training Complete ---
time_elapsed = time.time() - start_time
print(f'\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print(f'Best Validation Accuracy: {best_acc:.4f}')

# Restore the best-performing weights so later cells operate on that model
model.load_state_dict(best_model_wts)

In [None]:
# Write the model's current state dict to `binary_model_save_path`.
torch.save(obj=model.state_dict(), f=binary_model_save_path)
print(f"Best binary classifier model weights saved to {binary_model_save_path}")
# The saved file is the one intended to be loaded by the Flask app for its initial check.