In [1]:
import os
import numpy as np
import tifffile as tiff
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights
import torchmetrics
from torchmetrics import Accuracy, Precision, Recall, F1Score, AUROC, ROC, ConfusionMatrix
from tqdm import tqdm
from PIL import Image
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.figure_factory as ff
import random
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

In [3]:
# Custom color palette for 6 classes
label_colors = [
    '#01BEFE',  # Class 0
    '#FFDD00',  # Class 1
    '#FF7D00',  # Class 2
    '#FF006D',  # Class 3
    '#ADFF02',  # Class 4
    '#8F00FF'   # Class 5
]

In [4]:
class SeismicPatchDataset(Dataset):
    """Dataset for patch-based seismic facies classification"""
    
    def __init__(self, images_dir, annotations_dir, patch_size=224, 
                 stride=112, transform=None, min_class_fraction=0.1):
        """
        Args:
            patch_size: Size of patches to extract
            stride: Step size when sliding window
            min_class_fraction: Minimum fraction of patch needed to count as class presence
        """
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.patch_size = patch_size
        self.stride = stride
        self.transform = transform
        self.min_class_fraction = min_class_fraction
        
        # Collect image files
        self.image_files = [f for f in os.listdir(images_dir) 
                           if f.lower().endswith(('.tiff', '.tif'))]
        
        # Precompute all valid patches
        self.patches = []
        print("Preprocessing patches...")
        for img_file in tqdm(self.image_files):
            img_path = os.path.join(images_dir, img_file)
            base_name = os.path.splitext(img_file)[0]
            ann_path = os.path.join(annotations_dir, base_name + '.png')
            
            # Load image and annotation
            img = tiff.imread(img_path)
            ann = np.array(Image.open(ann_path))
            
            # Generate patches
            H, W = ann.shape
            for y in range(0, H - patch_size + 1, stride):
                for x in range(0, W - patch_size + 1, stride):
                    patch_ann = ann[y:y+patch_size, x:x+patch_size]
                    
                    # Calculate class presence
                    class_presence = np.zeros(6)
                    total_pixels = patch_size * patch_size
                    for class_id in range(6):
                        class_count = np.sum(patch_ann == class_id)
                        class_presence[class_id] = class_count / total_pixels
                    
                    # Only keep patches with significant features
                    if np.max(class_presence) > min_class_fraction:
                        self.patches.append((img_path, ann_path, y, x, class_presence))
        
        print(f"Generated {len(self.patches)} valid patches")
    
    def __len__(self):
        return len(self.patches)
    
    def __getitem__(self, idx):
        img_path, ann_path, y, x, class_presence = self.patches[idx]
        
        # Load image patch
        img = tiff.imread(img_path)
        patch_img = img[y:y+self.patch_size, x:x+self.patch_size]
        
        # Convert to uint8 if needed
        if patch_img.dtype != np.uint8:
            patch_img = (patch_img - patch_img.min()) / (patch_img.max() - patch_img.min()) * 255
            patch_img = patch_img.astype(np.uint8)
        
        # Convert to PIL for transformations
        patch_img = Image.fromarray(patch_img)
        
        if self.transform:
            patch_img = self.transform(patch_img)
        
        # Convert to multi-label format (1 if class present, 0 otherwise)
        class_labels = (class_presence > self.min_class_fraction).astype(np.float32)
        
        return patch_img, torch.tensor(class_labels, dtype=torch.float)

In [5]:
def create_model(num_classes=6, pretrained=True):
    """Create a ResNet18 model for multi-label classification"""
    if pretrained:
        model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    else:
        model = resnet18(weights=None)
    
    # Modify the final layer for multi-label classification
    # model.fc = nn.Linear(model.fc.in_features, num_classes)
    # In model creation
    model.fc = nn.Sequential(
        nn.Dropout(0.5),  # Add dropout
        nn.Linear(model.fc.in_features, num_classes)
    )
    
    print(f"Model configured for {num_classes}-class multi-label classification")
    
    return model

In [6]:
def train_multilabel_model(model, train_loader, val_loader, num_epochs=25, learning_rate=0.001, device='cuda'):
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    
    model.to(device)
    
    # Initialize metrics
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    
    best_val_loss = float('inf')
    best_model_state = None
    
    print("Starting training...")
    
    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", unit="epoch")
    
    for epoch in epoch_pbar:
        # Training phase
        model.train()
        running_loss = 0.0
        running_correct = 0
        total_samples = 0
        
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training", 
                         leave=False, unit="batch")
        
        for images, labels in train_pbar:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * images.size(0)
            total_samples += images.size(0)
            
            # Calculate accuracy (considering multi-label)
            preds = torch.sigmoid(outputs) > 0.5
            running_correct += torch.sum(preds == labels.byte()).item() / labels.size(1)
            
            train_pbar.set_postfix(loss=f"{loss.item():.4f}")
        
        train_loss = running_loss / total_samples
        train_acc = running_correct / total_samples
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_samples = 0
        
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation", 
                       leave=False, unit="batch")
        
        with torch.no_grad():
            for images, labels in val_pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item() * images.size(0)
                val_samples += images.size(0)
                
                # Calculate accuracy
                preds = torch.sigmoid(outputs) > 0.5
                val_correct += torch.sum(preds == labels.byte()).item() / labels.size(1)
                
                val_pbar.set_postfix(loss=f"{loss.item():.4f}")
        
        val_loss /= val_samples
        val_acc = val_correct / val_samples
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = model.state_dict().copy()
        
        # Store metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        
        scheduler.step()
        
        epoch_pbar.set_postfix({
            'Train Loss': f'{train_loss:.4f}',
            'Train Acc': f'{train_acc:.4f}',
            'Val Loss': f'{val_loss:.4f}',
            'Val Acc': f'{val_acc:.4f}',
            'Best Val Loss': f'{best_val_loss:.4f}'
        })
    
    # Load best model
    if best_model_state:
        model.load_state_dict(best_model_state)
    
    training_history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accuracies': train_accuracies,
        'val_accuracies': val_accuracies,
        'best_val_loss': best_val_loss
    }
    
    print(f"\nTraining completed! Best validation loss: {best_val_loss:.4f}")
    
    return model, training_history

In [7]:
def evaluate_multilabel_model(model, test_loader, device='cuda'):
    """Evaluate multi-label model with metrics"""
    
    accuracy = Accuracy(task="multilabel", num_labels=6)
    precision = Precision(task="multilabel", num_labels=6, average='macro')
    recall = Recall(task="multilabel", num_labels=6, average='macro')
    f1 = F1Score(task="multilabel", num_labels=6, average='macro')
    
    # Per-class metrics
    precision_per_class = Precision(task="multilabel", num_labels=6, average=None)
    recall_per_class = Recall(task="multilabel", num_labels=6, average=None)
    f1_per_class = F1Score(task="multilabel", num_labels=6, average=None)
    
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.cpu()
            outputs = model(images)
            
            # Get predictions and probabilities
            probs = torch.sigmoid(outputs).cpu()
            preds = (probs > 0.5)
            
            all_preds.append(preds)
            all_labels.append(labels)
            all_probs.append(probs)
    
    # Concatenate all predictions and labels
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    all_probs = torch.cat(all_probs)
    
    # Calculate metrics
    acc = accuracy(all_preds, all_labels)
    prec = precision(all_preds, all_labels)
    rec = recall(all_preds, all_labels)
    f1_score = f1(all_preds, all_labels)
    
    # Per-class metrics
    prec_per_class = precision_per_class(all_preds, all_labels)
    rec_per_class = recall_per_class(all_preds, all_labels)
    f1_per_class_scores = f1_per_class(all_preds, all_labels)
    
    return {
        'accuracy': acc.item(),
        'precision': prec.item(),
        'recall': rec.item(),
        'f1_score': f1_score.item(),
        'precision_per_class': prec_per_class.numpy(),
        'recall_per_class': rec_per_class.numpy(),
        'f1_per_class': f1_per_class_scores.numpy(),
        'predictions': all_preds.numpy(),
        'true_labels': all_labels.numpy(),
        'probabilities': all_probs.numpy()
    }

In [8]:
def visualize_sample_patches(dataset, num_samples=12):
    """Visualize sample patches with their annotations and labels"""
    
    indices = np.random.choice(len(dataset), size=min(num_samples, len(dataset)), replace=False)
    
    rows = num_samples // 3
    cols = 3
    
    fig = make_subplots(
        rows=rows, cols=cols,
        subplot_titles=[f'Patch {i+1}' for i in range(num_samples)],
        vertical_spacing=0.08
    )
    
    class_names = [f'Class {i}' for i in range(6)]
    
    for i, idx in enumerate(indices):
        row = i // 3 + 1
        col = i % 3 + 1
        
        # Get patch data
        patch_img, labels = dataset[idx]
        
        # Convert tensor to numpy
        if isinstance(patch_img, torch.Tensor):
            # Denormalize if needed
            patch_img = patch_img.numpy()
            if patch_img.shape[0] == 3:  # CHW format
                patch_img = np.transpose(patch_img, (1, 2, 0))
            # Rescale to 0-255 range
            patch_img = ((patch_img - patch_img.min()) / (patch_img.max() - patch_img.min()) * 255).astype(np.uint8)
        
        # Add image
        fig.add_trace(
            go.Image(z=patch_img),
            row=row, col=col
        )
        
        # Create label text
        active_classes = [class_names[j] for j in range(6) if labels[j] > 0.5]
        label_text = ', '.join(active_classes) if active_classes else 'No significant class'
        
        # Update subplot title
        fig.layout.annotations[i].text = f'Patch {i+1}<br>{label_text}'
    
    fig.update_layout(
        title='Sample Patches with Multi-Label Classifications',
        height=200 * rows,
        showlegend=False
    )
    
    # Remove axes
    for i in range(1, rows + 1):
        for j in range(1, cols + 1):
            fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=i, col=j)
            fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, row=i, col=j)
    
    fig.show()

In [9]:
def plot_class_distribution(labels, class_names):
    """Plot distribution of classes"""

    class_counts = labels.sum(axis=0)
    
    fig = go.Figure()
    
    fig.add_trace(go.Bar(
        x=class_names,
        y=class_counts,
        marker_color=label_colors,
        text=class_counts,
        textposition='outside',
        name='Class Count'
    ))
    
    fig.update_layout(
        title='Class Distribution in Dataset (Multi-Label)',
        xaxis_title='Classes',
        yaxis_title='Number of Patches',
        template='plotly_white',
        height=500
    )
    
    fig.show()
    
    # Create pie chart for class proportions
    total_labels = labels.sum()
    class_proportions = class_counts / total_labels
    
    fig_pie = go.Figure(data=[go.Pie(
        labels=class_names,
        values=class_proportions,
        marker_colors=label_colors,
        textinfo='label+percent',
        textposition='outside'
    )])
    
    fig_pie.update_layout(
        title='Class Distribution Proportions',
        height=500
    )
    
    fig_pie.show()

In [10]:
def plot_training_history(history):
    """Plot training and validation curves"""
    
    epochs = list(range(1, len(history['train_losses']) + 1))
    
    # Create subplots
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=('Training and Validation Loss', 'Training and Validation Accuracy'),
        specs=[[{"secondary_y": False}, {"secondary_y": False}]]
    )
    
    # Loss plot
    fig.add_trace(
        go.Scatter(x=epochs, y=history['train_losses'], mode='lines+markers', 
                  name='Train Loss', line=dict(color='blue', width=2)),
        row=1, col=1
    )
    fig.add_trace(
        go.Scatter(x=epochs, y=history['val_losses'], mode='lines+markers', 
                  name='Val Loss', line=dict(color='red', width=2)),
        row=1, col=1
    )
    
    # Accuracy plot
    fig.add_trace(
        go.Scatter(x=epochs, y=history['train_accuracies'], mode='lines+markers', 
                  name='Train Accuracy', line=dict(color='green', width=2)),
        row=1, col=2
    )
    fig.add_trace(
        go.Scatter(x=epochs, y=history['val_accuracies'], mode='lines+markers', 
                  name='Val Accuracy', line=dict(color='orange', width=2)),
        row=1, col=2
    )
    
    # Add best validation loss marker
    best_epoch = np.argmin(history['val_losses']) + 1
    best_val_loss = min(history['val_losses'])
    
    fig.add_trace(
        go.Scatter(x=[best_epoch], y=[best_val_loss], mode='markers',
                  marker=dict(size=12, color='red', symbol='star'),
                  name=f'Best Val Loss: {best_val_loss:.4f}'),
        row=1, col=1
    )
    
    # Update layout
    fig.update_xaxes(title_text="Epoch", row=1, col=1)
    fig.update_xaxes(title_text="Epoch", row=1, col=2)
    fig.update_yaxes(title_text="Loss", row=1, col=1)
    fig.update_yaxes(title_text="Accuracy", row=1, col=2)
    
    fig.update_layout(
        height=500,
        template='plotly_white',
        title_text="Training History",
        showlegend=True
    )
    
    fig.show()

In [11]:
def plot_confusion_matrix_multilabel(metrics, class_names):
    """Plot confusion matrix for multi-label classification"""
    
    predictions = metrics['predictions']
    true_labels = metrics['true_labels']
    
    # For multi-label, we'll show per-class confusion matrices
    fig = make_subplots(
        rows=2, cols=3,
        subplot_titles=[f'Class {i}: {class_names[i]}' for i in range(6)],
        specs=[[{"type": "heatmap"}, {"type": "heatmap"}, {"type": "heatmap"}],
               [{"type": "heatmap"}, {"type": "heatmap"}, {"type": "heatmap"}]]
    )
    
    for i in range(6):
        row = i // 3 + 1
        col = i % 3 + 1
        
        # Binary confusion matrix for each class
        y_true = true_labels[:, i]
        y_pred = predictions[:, i]
        
        # Calculate confusion matrix
        tn = np.sum((y_true == 0) & (y_pred == 0))
        fp = np.sum((y_true == 0) & (y_pred == 1))
        fn = np.sum((y_true == 1) & (y_pred == 0))
        tp = np.sum((y_true == 1) & (y_pred == 1))
        
        cm = np.array([[tn, fp], [fn, tp]])
        
        fig.add_trace(
            go.Heatmap(
                z=cm,
                x=['Predicted 0', 'Predicted 1'],
                y=['Actual 0', 'Actual 1'],
                colorscale='Blues',
                showscale=False,
                text=cm,
                texttemplate="%{text}",
                textfont={"size": 12}
            ),
            row=row, col=col
        )
    
    fig.update_layout(
        title='Confusion Matrices per Class (Multi-Label)',
        height=600,
        template='plotly_white'
    )
    
    fig.show()

In [12]:
def plot_performance_metrics(metrics, class_names):
    """Plot performance metrics"""
    
    # Overall metrics
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=('Overall Performance Metrics', 'Per-Class Precision', 
                       'Per-Class Recall', 'Per-Class F1-Score'),
        specs=[[{"type": "bar"}, {"type": "bar"}],
               [{"type": "bar"}, {"type": "bar"}]]
    )
    
    # Overall metrics
    overall_metrics = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
    overall_values = [metrics['accuracy'], metrics['precision'], 
                     metrics['recall'], metrics['f1_score']]
    colors = ['skyblue', 'lightcoral', 'lightgreen', 'gold']
    
    fig.add_trace(
        go.Bar(x=overall_metrics, y=overall_values, 
               marker_color=colors, showlegend=False,
               text=[f'{v:.3f}' for v in overall_values],
               textposition='outside'),
        row=1, col=1
    )
    
    # Per-class precision
    fig.add_trace(
        go.Bar(x=class_names, y=metrics['precision_per_class'], 
               marker_color=label_colors, showlegend=False,
               text=[f'{v:.3f}' for v in metrics['precision_per_class']],
               textposition='outside'),
        row=1, col=2
    )
    
    # Per-class recall
    fig.add_trace(
        go.Bar(x=class_names, y=metrics['recall_per_class'], 
               marker_color=label_colors, showlegend=False,
               text=[f'{v:.3f}' for v in metrics['recall_per_class']],
               textposition='outside'),
        row=2, col=1
    )
    
    # Per-class F1-score
    fig.add_trace(
        go.Bar(x=class_names, y=metrics['f1_per_class'], 
               marker_color=label_colors, showlegend=False,
               text=[f'{v:.3f}' for v in metrics['f1_per_class']],
               textposition='outside'),
        row=2, col=2
    )
    
    # Update layout
    fig.update_yaxes(range=[0, 1.1], row=1, col=1)
    fig.update_yaxes(range=[0, 1.1], row=1, col=2)
    fig.update_yaxes(range=[0, 1.1], row=2, col=1)
    fig.update_yaxes(range=[0, 1.1], row=2, col=2)
    
    fig.update_layout(
        height=800,
        template='plotly_white',
        title_text="Comprehensive Performance Metrics"
    )
    
    fig.show()

In [13]:
def plot_roc_curves_multilabel(metrics, class_names):
    """Plot ROC curves for each class"""
    
    fig = go.Figure()
    
    probabilities = metrics['probabilities']
    true_labels = metrics['true_labels']
    
    # Plot ROC curve for each class
    for i in range(6):
        try:
            # Convert to tensors for ROC computation
            probs = torch.tensor(probabilities[:, i], dtype=torch.float32)
            labels = torch.tensor(true_labels[:, i], dtype=torch.float32).long()
            
            # Compute ROC curve
            roc = ROC(task='binary')
            fpr, tpr, _ = roc(probs, labels)
            
            # Calculate AUC
            auroc = AUROC(task='binary')
            auc_score = auroc(probs, labels)
            
            fig.add_trace(
                go.Scatter(
                    x=fpr.numpy(), 
                    y=tpr.numpy(), 
                    mode='lines',
                    name=f'{class_names[i]} (AUC={auc_score:.3f})',
                    line=dict(color=label_colors[i], width=2)
                )
            )
        except Exception as e:
            print(f"Warning: Could not compute ROC for {class_names[i]}: {e}")
            continue
    
    # Add diagonal line
    fig.add_trace(
        go.Scatter(x=[0, 1], y=[0, 1], mode='lines', 
                  name='Random', line=dict(dash='dash', color='black', width=1))
    )
    
    fig.update_layout(
        title='ROC Curves for Multi-Label Classification',
        xaxis_title='False Positive Rate',
        yaxis_title='True Positive Rate',
        template='plotly_white',
        height=600,
        legend=dict(x=0.6, y=0.1)
    )
    
    fig.show()

In [14]:
def create_report(metrics, history, class_names):
    """Create report with metrics"""
    
    print("MODEL EVALUATION REPORT")
    print("="*80)
    
    # 1. Training History
    print("\n1. Training History")
    plot_training_history(history)
    
    # 2. Performance Metrics
    print("\n2. Performance Metrics")
    plot_performance_metrics(metrics, class_names)
    
    # 3. Confusion Matrices
    print("\n3. Confusion Matrices")
    plot_confusion_matrix_multilabel(metrics, class_names)
    
    # 4. ROC Curves
    print("\n4. ROC Curves")
    plot_roc_curves_multilabel(metrics, class_names)
    
    # 5. Summary Statistics
    print("\n5. Summary Statistics")
    print(f"Overall Accuracy: {metrics['accuracy']:.4f}")
    print(f"Overall Precision: {metrics['precision']:.4f}")
    print(f"Overall Recall: {metrics['recall']:.4f}")
    print(f"Overall F1-Score: {metrics['f1_score']:.4f}")
    
    print("\nPer-Class Performance:")
    for i, class_name in enumerate(class_names):
        print(f"  {class_name}:")
        print(f"    Precision: {metrics['precision_per_class'][i]:.4f}")
        print(f"    Recall: {metrics['recall_per_class'][i]:.4f}")
        print(f"    F1-Score: {metrics['f1_per_class'][i]:.4f}")

In [15]:
def main():
    
    # Update to your dataset location
    F3_ROOT_DIR = "/workspaces/Minerva-Dev/shared_data/seismic/f3_segmentation_N"
    
    # Training parameters
    BATCH_SIZE = 32
    NUM_EPOCHS = 25
    LEARNING_RATE = 0.001
    PATCH_SIZE = 224
    STRIDE = 112
    MIN_CLASS_FRACTION = 0.1
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Define transformations
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Create datasets
    train_dataset = SeismicPatchDataset(
        images_dir=os.path.join(F3_ROOT_DIR, "images/train"),
        annotations_dir=os.path.join(F3_ROOT_DIR, "annotations/train"),
        patch_size=PATCH_SIZE,
        stride=STRIDE,
        transform=train_transform,
        min_class_fraction=MIN_CLASS_FRACTION
    )
    
    val_dataset = SeismicPatchDataset(
        images_dir=os.path.join(F3_ROOT_DIR, "images/val"),
        annotations_dir=os.path.join(F3_ROOT_DIR, "annotations/val"),
        patch_size=PATCH_SIZE,
        stride=STRIDE,
        transform=val_transform,
        min_class_fraction=MIN_CLASS_FRACTION
    )
    
    test_dataset = SeismicPatchDataset(
        images_dir=os.path.join(F3_ROOT_DIR, "images/test"),
        annotations_dir=os.path.join(F3_ROOT_DIR, "annotations/test"),
        patch_size=PATCH_SIZE,
        stride=STRIDE,
        transform=val_transform,
        min_class_fraction=MIN_CLASS_FRACTION
    )
    
    print(f"\nDataset sizes:")
    print(f"Train: {len(train_dataset)} patches")
    print(f"Validation: {len(val_dataset)} patches")
    print(f"Test: {len(test_dataset)} patches")
    
    # Visualize sample patches BEFORE training
    print("\nVisualizing sample patches from test set...")
    visualize_sample_patches(test_dataset)
    
    # Collect labels for class distribution
    print("\nAnalyzing class distribution...")
    all_labels = []
    for _, labels in DataLoader(train_dataset, batch_size=128):
        all_labels.append(labels)
    all_labels = torch.cat(all_labels).numpy()
    
    class_names = [f'Class {i}' for i in range(6)]
    plot_class_distribution(all_labels, class_names)
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    # Create and train model
    model = create_model(num_classes=6, pretrained=True)
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    trained_model, history = train_multilabel_model(
        model, train_loader, val_loader, 
        num_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE, device=device
    )
    
    print("\nEvaluating on test set...")
    test_metrics = evaluate_multilabel_model(trained_model, test_loader, device=device)
    
    print("\n" + "="*50)
    print("FINAL TEST RESULTS")
    print("="*50)
    print(f"Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"Precision: {test_metrics['precision']:.4f}")
    print(f"Recall: {test_metrics['recall']:.4f}")
    print(f"F1-Score: {test_metrics['f1_score']:.4f}")
    
    # Generate report
    print("\nGenerating report...")
    create_report(test_metrics, history, class_names)
    
    torch.save(trained_model.state_dict(), 'f3_seismic_patch_classifier.pth')
    print("\nModel saved as 'f3_seismic_patch_classifier.pth'")
    
    return trained_model, test_metrics, history

In [16]:
if __name__ == "__main__":
    print("F3 Seismic Patch-based Classification")
    model, metrics, history = main()

F3 Seismic Patch-based Classification
Using device: cuda
Preprocessing patches...


100%|██████████| 992/992 [00:03<00:00, 249.27it/s]


Generated 3067 valid patches
Preprocessing patches...


100%|██████████| 110/110 [00:00<00:00, 231.76it/s]


Generated 340 valid patches
Preprocessing patches...


100%|██████████| 400/400 [00:01<00:00, 217.96it/s]


Generated 1800 valid patches

Dataset sizes:
Train: 3067 patches
Validation: 340 patches
Test: 1800 patches

Visualizing sample patches from test set...



Analyzing class distribution...


Model configured for 6-class multi-label classification
Total parameters: 11,179,590
Starting training...


Training Progress: 100%|██████████| 25/25 [13:58<00:00, 33.56s/epoch, Train Loss=0.0034, Train Acc=0.9990, Val Loss=0.0036, Val Acc=0.9990, Best Val Loss=0.0025]



Training completed! Best validation loss: 0.0025

Evaluating on test set...


Evaluating: 100%|██████████| 57/57 [00:06<00:00,  8.29it/s]



FINAL TEST RESULTS
Accuracy: 0.9548
Precision: 0.7851
Recall: 0.7677
F1-Score: 0.7713

Generating report...
MODEL EVALUATION REPORT

1. Training History



2. Performance Metrics



3. Confusion Matrices



4. ROC Curves



5. Summary Statistics
Overall Accuracy: 0.9548
Overall Precision: 0.7851
Overall Recall: 0.7677
Overall F1-Score: 0.7713

Per-Class Performance:
  Class 0:
    Precision: 1.0000
    Recall: 1.0000
    F1-Score: 1.0000
  Class 1:
    Precision: 0.8034
    Recall: 0.9696
    F1-Score: 0.8787
  Class 2:
    Precision: 1.0000
    Recall: 1.0000
    F1-Score: 1.0000
  Class 3:
    Precision: 0.9072
    Recall: 0.9167
    F1-Score: 0.9119
  Class 4:
    Precision: 1.0000
    Recall: 0.7199
    F1-Score: 0.8372
  Class 5:
    Precision: 0.0000
    Recall: 0.0000
    F1-Score: 0.0000

Model saved as 'f3_seismic_patch_classifier.pth'
