In [1]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os
import glob
import random
from tqdm import tqdm

In [2]:
# Custom Dataset class
class BrainTumorDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
            
        return image, torch.tensor(label, dtype=torch.float32)

In [3]:
def load_balanced_data():
    """
    Load and balance validation data, and create augmented training data
    """
    # Define data directories
    base_dir = os.path.join(os.path.dirname(os.getcwd()), 'datasets/brain-tumor')
    train_img_dir = os.path.join(base_dir, 'train/images')
    train_label_dir = os.path.join(base_dir, 'train/labels')
    val_img_dir = os.path.join(base_dir, 'valid/images')
    val_label_dir = os.path.join(base_dir, 'valid/labels')
    
    print(f"Loading data from: {base_dir}")
    
    # Process validation data first and create balanced set
    val_image_paths = sorted(glob.glob(os.path.join(val_img_dir, '*')))
    val_tumor_paths = []
    val_non_tumor_paths = []
    
    # Separate validation images by class
    for img_path in val_image_paths:
        img_filename = os.path.basename(img_path)
        base_name = os.path.splitext(img_filename)[0]
        label_path = os.path.join(val_label_dir, f"{base_name}.txt")
        
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                line = f.readline().strip()
                if line:
                    parts = line.split()
                    class_id = int(parts[0])
                    if class_id == 1:
                        val_tumor_paths.append(img_path)
                    else:
                        val_non_tumor_paths.append(img_path)
    
    # Randomly sample non-tumor images to match tumor images count
    num_tumor = len(val_tumor_paths)
    val_non_tumor_paths_balanced = random.sample(val_non_tumor_paths, num_tumor)
    
    # Combine and shuffle balanced validation data
    val_image_paths_balanced = val_tumor_paths + val_non_tumor_paths_balanced
    val_labels_balanced = [1] * num_tumor + [0] * num_tumor
    combined = list(zip(val_image_paths_balanced, val_labels_balanced))
    random.shuffle(combined)
    val_image_paths_balanced, val_labels_balanced = zip(*combined)
    
    # Process training data and create augmented versions
    train_image_paths = sorted(glob.glob(os.path.join(train_img_dir, '*')))
    augmented_train_paths = []
    augmented_train_labels = []
    
    # Create directory for augmented images
    augmented_dir = "augmented_train_data"
    if not os.path.exists(augmented_dir):
        os.makedirs(augmented_dir)
    
    # Define augmentation transforms
    flip_transform = transforms.RandomHorizontalFlip(p=1.0)  # p=1.0 means always flip
    brightness_transform = transforms.ColorJitter(brightness=0.2)
    
    print("Creating augmented training dataset...")
    for img_path in tqdm(train_image_paths):
        # Get original label
        img_filename = os.path.basename(img_path)
        base_name = os.path.splitext(img_filename)[0]
        label_path = os.path.join(train_label_dir, f"{base_name}.txt")
        
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                line = f.readline().strip()
                if line:
                    parts = line.split()
                    label = int(parts[0])
                else:
                    label = 0
        else:
            print(f"Warning: No label found for {img_path}")
            label = 0
        
        # Load and process original image
        img = Image.open(img_path).convert('RGB')
        
        # Save original image and label
        original_save_path = os.path.join(augmented_dir, f"{base_name}_original.png")
        img.save(original_save_path)
        augmented_train_paths.append(original_save_path)
        augmented_train_labels.append(label)
        
        # Create and save flipped version
        flipped_img = flip_transform(img)
        flip_save_path = os.path.join(augmented_dir, f"{base_name}_flipped.png")
        flipped_img.save(flip_save_path)
        augmented_train_paths.append(flip_save_path)
        augmented_train_labels.append(label)
        
        # Create and save brightness adjusted version
        bright_img = brightness_transform(img)
        bright_save_path = os.path.join(augmented_dir, f"{base_name}_bright.png")
        bright_img.save(bright_save_path)
        augmented_train_paths.append(bright_save_path)
        augmented_train_labels.append(label)
    
    print(f"Original training dataset size: {len(train_image_paths)}")
    print(f"Augmented training dataset size: {len(augmented_train_paths)}")
    print(f"Balanced validation dataset size: {len(val_image_paths_balanced)} ({num_tumor} tumor, {num_tumor} non-tumor)")
    
    return augmented_train_paths, augmented_train_labels, list(val_image_paths_balanced), list(val_labels_balanced)

In [4]:
import torch
import torch.nn as nn
from torchvision.models import vgg16, VGG16_Weights

class VGG16Transfer(nn.Module):
    def __init__(self):
        super(VGG16Transfer, self).__init__()
        
        # Load pretrained VGG16 model with IMAGENET1K_V1 weights
        self.vgg16 = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
        
        # Freeze all layers in the VGG16 model
        for param in self.vgg16.parameters():
            param.requires_grad = False
            
        # Replace the classifier with our custom classifier
        self.vgg16.classifier = nn.Sequential(
            nn.Linear(25088, 256),  # VGG16's last conv layer outputs 25088 features
            nn.ReLU(),
            nn.Dropout(0.3),        # Dropout after dense layer and ReLU
            nn.Linear(256, 1),      # Binary classification (1 output neuron)
            nn.Sigmoid()            # Sigmoid activation for binary classification
        )
        
        # Unfreeze only the new dense layer (256 units)
        for param in self.vgg16.classifier[0].parameters():  # First Linear layer (256 units)
            param.requires_grad = True
            
    def forward(self, x):
        return self.vgg16(x)

# Get the transforms for preprocessing
data_transforms = VGG16_Weights.IMAGENET1K_V1.transforms()

# Create model instance
model = VGG16Transfer()

# Print model summary
print(model)

# Verify which layers are trainable
print("\nTrainable parameters:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.shape}")

VGG16Transfer(
  (vgg16): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)
      

In [5]:
# Load and prepare the data
train_paths, train_labels, val_paths, val_labels = load_balanced_data()

# Create datasets using your existing BrainTumorDataset class and transforms
train_dataset = BrainTumorDataset(train_paths, train_labels, transform=data_transforms)
val_dataset = BrainTumorDataset(val_paths, val_labels, transform=data_transforms)
    
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Initialize model
model = VGG16Transfer()
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

# Define loss function and optimizer
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

Loading data from: E:\SDS-CP024-neurovision\submissions-team\andy-chen\model_training\datasets/brain-tumor
Creating augmented training dataset...


100%|██████████| 878/878 [00:51<00:00, 17.00it/s]


Original training dataset size: 878
Augmented training dataset size: 2634
Balanced validation dataset size: 162 (81 tumor, 81 non-tumor)


In [6]:
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import numpy as np

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=30):
    best_val_acc = 0.0
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_preds = []
        train_labels_all = []
        
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            
            outputs = model(inputs)
            loss = criterion(outputs, labels.view(-1, 1))
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            predicted = (outputs > 0.5).float()
            train_preds.extend(predicted.cpu().numpy())
            train_labels_all.extend(labels.cpu().numpy())
        
        train_loss = train_loss / len(train_loader)
        train_preds = np.array(train_preds).flatten()
        train_labels_all = np.array(train_labels_all)
        
        # Calculate training metrics
        train_acc = 100 * np.mean(train_preds == train_labels_all)
        train_precision = precision_score(train_labels_all, train_preds)
        train_recall = recall_score(train_labels_all, train_preds)
        train_f1 = f1_score(train_labels_all, train_preds)
        train_cm = confusion_matrix(train_labels_all, train_preds)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_preds = []
        val_labels_all = []
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels.view(-1, 1))
                
                val_loss += loss.item()
                predicted = (outputs > 0.5).float()
                val_preds.extend(predicted.cpu().numpy())
                val_labels_all.extend(labels.cpu().numpy())
        
        val_loss = val_loss / len(val_loader)
        val_preds = np.array(val_preds).flatten()
        val_labels_all = np.array(val_labels_all)
        
        # Calculate validation metrics
        val_acc = 100 * np.mean(val_preds == val_labels_all)
        val_precision = precision_score(val_labels_all, val_preds)
        val_recall = recall_score(val_labels_all, val_preds)
        val_f1 = f1_score(val_labels_all, val_preds)
        val_cm = confusion_matrix(val_labels_all, val_preds)
        
        # Print statistics
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print('\nTraining Metrics:')
        print(f'Loss: {train_loss:.4f}')
        print(f'Accuracy: {train_acc:.2f}%')
        print(f'Precision: {train_precision:.4f}')
        print(f'Recall: {train_recall:.4f}')
        print(f'F1-Score: {train_f1:.4f}')
        print('Confusion Matrix:')
        print(train_cm)
        
        print('\nValidation Metrics:')
        print(f'Loss: {val_loss:.4f}')
        print(f'Accuracy: {val_acc:.2f}%')
        print(f'Precision: {val_precision:.4f}')
        print(f'Recall: {val_recall:.4f}')
        print(f'F1-Score: {val_f1:.4f}')
        print('Confusion Matrix:')
        print(val_cm)
        print('-' * 50)
        
        # Save best model based on validation accuracy
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_precision': val_precision,
                'val_recall': val_recall,
                'val_f1': val_f1,
            }, 'best_model.pth')

In [7]:
# Start training (epoch 30, lr: 0.001)
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=30)

Epoch 1/30:

Training Metrics:
Loss: 0.6725
Accuracy: 57.63%
Precision: 0.5788
Recall: 0.6964
F1-Score: 0.6322
Confusion Matrix:
[[559 698]
 [418 959]]

Validation Metrics:
Loss: 0.6621
Accuracy: 56.17%
Precision: 0.5391
Recall: 0.8519
F1-Score: 0.6603
Confusion Matrix:
[[22 59]
 [12 69]]
--------------------------------------------------
Epoch 2/30:

Training Metrics:
Loss: 0.6295
Accuracy: 67.92%
Precision: 0.6757
Recall: 0.7429
F1-Score: 0.7077
Confusion Matrix:
[[ 766  491]
 [ 354 1023]]

Validation Metrics:
Loss: 0.6672
Accuracy: 62.35%
Precision: 0.6471
Recall: 0.5432
F1-Score: 0.5906
Confusion Matrix:
[[57 24]
 [37 44]]
--------------------------------------------------
Epoch 3/30:

Training Metrics:
Loss: 0.5991
Accuracy: 71.91%
Precision: 0.7296
Recall: 0.7349
F1-Score: 0.7323
Confusion Matrix:
[[ 882  375]
 [ 365 1012]]

Validation Metrics:
Loss: 0.6543
Accuracy: 62.96%
Precision: 0.6400
Recall: 0.5926
F1-Score: 0.6154
Confusion Matrix:
[[54 27]
 [33 48]]
--------------------

KeyboardInterrupt: 

In [None]:
# new training starting here, changed metrics display, learning rate 0.01, epoch 20
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [10]:
!pip install seaborn

Collecting seaborn
  Using cached seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Collecting pandas>=1.2 (from seaborn)
  Using cached pandas-2.2.3-cp312-cp312-win_amd64.whl.metadata (19 kB)
Collecting pytz>=2020.1 (from pandas>=1.2->seaborn)
  Using cached pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting tzdata>=2022.7 (from pandas>=1.2->seaborn)
  Using cached tzdata-2025.2-py2.py3-none-any.whl.metadata (1.4 kB)
Using cached seaborn-0.13.2-py3-none-any.whl (294 kB)
Using cached pandas-2.2.3-cp312-cp312-win_amd64.whl (11.5 MB)
Using cached pytz-2025.2-py2.py3-none-any.whl (509 kB)
Using cached tzdata-2025.2-py2.py3-none-any.whl (347 kB)
Installing collected packages: pytz, tzdata, pandas, seaborn
Successfully installed pandas-2.2.3 pytz-2025.2 seaborn-0.13.2 tzdata-2025.2



[notice] A new release of pip is available: 24.0 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [11]:
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=20):
    best_val_acc = 0.0
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Lists to store metrics for plotting
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels.view(-1, 1))
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            predicted = (outputs > 0.5).float()
            train_total += labels.size(0)
            train_correct += (predicted.view(-1) == labels).sum().item()
        
        train_loss = train_loss / len(train_loader)
        train_acc = 100 * train_correct / train_total
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_preds = []
        val_labels_all = []
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels.view(-1, 1))
                
                val_loss += loss.item()
                predicted = (outputs > 0.5).float()
                val_total += labels.size(0)
                val_correct += (predicted.view(-1) == labels).sum().item()
                
                val_preds.extend(predicted.cpu().numpy())
                val_labels_all.extend(labels.cpu().numpy())
        
        val_loss = val_loss / len(val_loader)
        val_acc = 100 * val_correct / val_total
        
        # Store metrics for plotting
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)
        
        # Print simple progress
        print(f'Epoch [{epoch+1}/{num_epochs}] - '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% - '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
            }, 'best_model.pth')
    
    # Plot training history
    plt.figure(figsize=(12, 4))
    
    # Plot Loss
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    
    # Plot Accuracy
    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    
    plt.tight_layout()
    plt.show()
    
    # Plot confusion matrix
    val_preds = np.array(val_preds).flatten()
    val_labels_all = np.array(val_labels_all)
    cm = confusion_matrix(val_labels_all, val_preds)
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()
    
    # Return final metrics
    return {
        'train_losses': train_losses,
        'train_accuracies': train_accuracies,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies,
        'best_val_acc': best_val_acc
    }

In [12]:
# Start training
history = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=20)

Epoch [1/20] - Train Loss: 0.4240, Train Acc: 80.75% - Val Loss: 0.8198, Val Acc: 54.32%
Epoch [2/20] - Train Loss: 0.3270, Train Acc: 86.10% - Val Loss: 0.7513, Val Acc: 61.11%
Epoch [3/20] - Train Loss: 0.2475, Train Acc: 90.70% - Val Loss: 0.7597, Val Acc: 64.81%
Epoch [4/20] - Train Loss: 0.1873, Train Acc: 93.70% - Val Loss: 0.8597, Val Acc: 62.35%
Epoch [5/20] - Train Loss: 0.1705, Train Acc: 94.27% - Val Loss: 0.8932, Val Acc: 61.73%
Epoch [6/20] - Train Loss: 0.0955, Train Acc: 97.91% - Val Loss: 0.9459, Val Acc: 66.05%
Epoch [7/20] - Train Loss: 0.0826, Train Acc: 98.29% - Val Loss: 1.3320, Val Acc: 51.23%
Epoch [8/20] - Train Loss: 0.0584, Train Acc: 99.13% - Val Loss: 1.0623, Val Acc: 64.81%
Epoch [9/20] - Train Loss: 0.0570, Train Acc: 98.90% - Val Loss: 0.9798, Val Acc: 64.20%
Epoch [10/20] - Train Loss: 0.0397, Train Acc: 99.58% - Val Loss: 1.0493, Val Acc: 60.49%
Epoch [11/20] - Train Loss: 0.0308, Train Acc: 99.89% - Val Loss: 1.1214, Val Acc: 63.58%
Epoch [12/20] - Tra

KeyboardInterrupt: 

In [13]:
# third attempt with early stopping
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=20, patience=5):
    best_val_acc = 0.0
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Lists to store metrics for plotting
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    
    # Early stopping variables
    patience_counter = 0
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels.view(-1, 1))
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            predicted = (outputs > 0.5).float()
            train_total += labels.size(0)
            train_correct += (predicted.view(-1) == labels).sum().item()
        
        train_loss = train_loss / len(train_loader)
        train_acc = 100 * train_correct / train_total
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_preds = []
        val_labels_all = []
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels.view(-1, 1))
                
                val_loss += loss.item()
                predicted = (outputs > 0.5).float()
                val_total += labels.size(0)
                val_correct += (predicted.view(-1) == labels).sum().item()
                
                val_preds.extend(predicted.cpu().numpy())
                val_labels_all.extend(labels.cpu().numpy())
        
        val_loss = val_loss / len(val_loader)
        val_acc = 100 * val_correct / val_total
        
        # Store metrics for plotting
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)
        
        # Print progress
        print(f'Epoch [{epoch+1}/{num_epochs}] - '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% - '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        
        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_acc': val_acc,
                    'val_loss': val_loss,
                }, 'best_model.pth')
        else:
            patience_counter += 1
            print(f'Early stopping counter: {patience_counter}/{patience}')
            
        # If validation loss hasn't improved for {patience} epochs, stop training
        if patience_counter >= patience:
            print(f'\nEarly stopping triggered after epoch {epoch+1}')
            break
    
    # Plot training history
    plt.figure(figsize=(12, 4))
    
    # Plot Loss
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    
    # Plot Accuracy
    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    
    plt.tight_layout()
    plt.show()
    
    # Plot confusion matrix
    val_preds = np.array(val_preds).flatten()
    val_labels_all = np.array(val_labels_all)
    cm = confusion_matrix(val_labels_all, val_preds)
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()
    
    # Return final metrics
    return {
        'train_losses': train_losses,
        'train_accuracies': train_accuracies,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies,
        'best_val_acc': best_val_acc,
        'stopped_epoch': epoch + 1
    }

In [14]:
# Start training with early stopping
history = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=20, patience=5)

Epoch [1/20] - Train Loss: 0.0138, Train Acc: 100.00% - Val Loss: 1.2272, Val Acc: 62.35%


KeyboardInterrupt: 