# Training New Model Architecture on Pavia University Dataset

This notebook trains the new model architecture from `model.py` on the Pavia University hyperspectral dataset.

## Model Architecture
- **newFastViT**: A novel Fast Vision Transformer with:
  - Efficient Attention mechanism
  - Spectral Attention module
  - Transformer blocks with residual connections
  - Optimized for hyperspectral image classification

## Dataset
- **Pavia University**: 610×340 pixels, 103 spectral bands
- **Classes**: 9 land-cover categories
- **Window Size**: 5×5 spatial patches


In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import scipy.io
from torch.utils.data import Dataset
from transformers import TrainingArguments, Trainer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, classification_report, cohen_kappa_score, precision_score, recall_score, f1_score
import time
from einops import rearrange
from thop import profile

# Import custom modules
from data_loader import load_pavia_university, preprocess_data, PaviaUniversityDataset
from model import newFastViT
from utils import (
    calculate_latency_per_image,
    calculate_throughput,
    overall_accuracy,
    average_accuracy,
    kappa_coefficient,
    calculate_f1_precision_recall,
    count_model_parameters,
    calculate_gflops
)

# Report library versions and which accelerator (if any) training will use.
print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    # Device 0 is the one torch.device("cuda") resolves to by default.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


## Step 1: Load Dataset

Load the Pavia University hyperspectral dataset. Update the file paths according to your dataset location.


In [None]:
# Paths to the Pavia University image cube and its ground-truth map.
# Point both at wherever the .mat files live on your machine.
image_file = "/content/PaviaU.mat"  # Update this path
gt_file = "/content/PaviaU_gt.mat"  # Update this path

# load_pavia_university (from data_loader) returns the image array and the
# label map; their shapes are echoed below as a quick sanity check.
image_data, ground_truth = load_pavia_university(image_file, gt_file)
print("\nDataset loaded successfully!")
for array_label, array_value in (("Image data", image_data), ("Ground truth", ground_truth)):
    print(f"{array_label} shape: {array_value.shape}")


## Step 2: Preprocess Data

Preprocess the data to extract spatial-spectral patches and prepare labels.


In [None]:
# Extract window_size x window_size spatial-spectral patches; preprocess_data
# (from data_loader) also hands back the label encoder it used for `y`.
window_size = 5
spatial_spectral_data, y, label_encoder = preprocess_data(
    image_data,
    ground_truth,
    window_size=window_size,
)

print("\nData preprocessed successfully!")
print(f"Spatial-spectral data shape: {spatial_spectral_data.shape}")
print(f"Labels shape: {y.shape}")
print(f"Number of classes: {np.unique(y).size}")
# bincount needs non-negative integer labels, which the encoded `y` provides.
print(f"Class distribution: {np.bincount(y)}")


## Step 3: Split Dataset

Split the dataset into training and testing sets with stratification to maintain class distribution.


In [None]:
# Stratified 80/20 split over sample indices so both subsets keep the original
# class proportions; the fixed seed makes the split reproducible.
all_sample_indices = np.arange(len(y))
train_indices, test_indices = train_test_split(
    all_sample_indices, test_size=0.2, stratify=y, random_state=42
)

train_dataset = PaviaUniversityDataset(spatial_spectral_data[train_indices], y[train_indices])
test_dataset = PaviaUniversityDataset(spatial_spectral_data[test_indices], y[test_indices])

for split_label, split_ds in (("Training", train_dataset), ("Testing", test_dataset)):
    print(f"{split_label} samples: {len(split_ds)}")
for split_label, split_idx in (("Train", train_indices), ("Test", test_indices)):
    print(f"{split_label} class distribution: {np.bincount(y[split_idx])}")


## Step 4: Initialize Model

Initialize the new model architecture from `model.py`.


In [None]:
# Initialize the new model.
# Class and band counts come from the preprocessed data rather than being hard-coded.
num_classes = len(np.unique(y))
# Bands are read from the last axis, assuming channels-last patches
# (103 for Pavia University) — TODO confirm against preprocess_data.
num_channels = spatial_spectral_data.shape[-1]

# Model configuration
patch_size = 4
embed_dim = 192  # Must be divisible by num_heads (192 / 4 = 48)
num_heads = 4
depth = 4

# Validate configuration: attention splits embed_dim evenly across the heads.
if embed_dim % num_heads != 0:
    raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")

# Patches per side, assuming the patch embedding is a Conv2d with
# kernel_size=stride=patch_size and no padding (TODO confirm in model.py):
#   floor((image_size - patch_size) / patch_size) + 1
# For image_size=5, patch_size=4 this yields a single patch in total.
actual_patches_h = (window_size - patch_size) // patch_size + 1
actual_patches_w = (window_size - patch_size) // patch_size + 1
actual_num_patches = actual_patches_h * actual_patches_w

# Under that same assumption, rows/columns past the last full patch are never
# covered by any patch. Surface it so this configuration choice is deliberate.
uncovered_pixels = window_size - actual_patches_h * patch_size
if uncovered_pixels > 0:
    print(
        f"Warning: patch_size={patch_size} does not tile window_size={window_size}; "
        f"the last {uncovered_pixels} row(s)/column(s) of each window are not covered by any patch."
    )

print("Model Configuration:")
print(f"  Image size: {window_size}x{window_size}")
print(f"  Patch size: {patch_size}x{patch_size}")
print(f"  Actual patches: {actual_patches_h}x{actual_patches_w} = {actual_num_patches}")
print(f"  Embed dim: {embed_dim} (divisible by num_heads={num_heads} ✓)")
print(f"  Head dim: {embed_dim // num_heads}")
print(f"  Depth: {depth}")

model = newFastViT(
    image_size=window_size,
    patch_size=patch_size,
    num_channels=num_channels,
    num_classes=num_classes,
    embed_dim=embed_dim,
    depth=depth,
    num_heads=num_heads,
    mlp_ratio=4.0
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print("\nModel initialized successfully!")
print(f"Model device: {device}")
print(f"Number of parameters: {count_model_parameters(model):.2f} M")

# Calculate GFLOPs (best-effort). The estimate is informational only, so a
# failure must not block training. `gflops` is always bound afterwards
# (None on failure), so any code that reads it can check for that.
try:
    gflops = calculate_gflops(model, train_dataset, device)
except Exception as e:
    gflops = None
    print(f"Warning: Could not calculate GFLOPs: {e}")
    print("This is okay, you can proceed with training.")
else:
    print(f"GFLOPs: {gflops:.2f}")


## Step 5: Setup Training

Configure training arguments and create the trainer.


In [None]:
import inspect

# Training arguments.
# transformers renamed `evaluation_strategy` to `eval_strategy` (v4.41), and later
# releases removed the old keyword entirely. Pass whichever one the installed
# TrainingArguments accepts, so the cell works on both old and new versions.
_eval_strategy_kw = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./results_new_model",
    num_train_epochs=20,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_steps=100,
    # load_best_model_at_end requires the save and eval strategies to match.
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to="none",
    save_total_limit=3,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    **{_eval_strategy_kw: "epoch"},
)


def data_collator(data):
    """Stack a list of dataset items into one batch dict.

    Each item is expected to be a dict with tensor values under 'x' and
    'labels'. Trainer forwards the returned keys to the model as keyword
    arguments.
    """
    return {
        'x': torch.stack([d['x'] for d in data]),
        'labels': torch.stack([d['labels'] for d in data])
    }


def compute_metrics(p):
    """Return the accuracy of argmax predictions for a Trainer EvalPrediction.

    ``p.predictions`` holds per-class scores (class axis last) and
    ``p.label_ids`` holds the integer targets.
    """
    predictions = p.predictions.argmax(-1)
    labels = p.label_ids
    accuracy = (predictions == labels).mean()
    return {"accuracy": accuracy}


# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
    data_collator=data_collator
)

print("Trainer setup complete!")


## Step 6: Train Model

Train the model on the training dataset.


In [None]:
# Run the training loop configured in `training_args`; evaluation and
# checkpointing happen on the schedule set there.
print("Starting training...")
train_output = trainer.train()
print("Training completed!")


## Step 7: Evaluate Model

Evaluate the trained model on the test dataset.


In [None]:
# Evaluate on the trainer's eval_dataset (the held-out test split) and print
# every metric Trainer reports, formatted to four decimals.
eval_results = trainer.evaluate()
print("\nEvaluation Results:")
for metric_name, metric_value in eval_results.items():
    print(f"{metric_name}: {metric_value:.4f}")


## Step 8: Generate Predictions and Calculate Metrics

Generate predictions and calculate comprehensive classification metrics.


In [None]:
# Predict on the test split and reduce the model scores to class indices.
predictions = trainer.predict(test_dataset)
y_pred = np.argmax(predictions.predictions, axis=1)
# Ground truth uses the same test_indices that built test_dataset, so order matches.
y_true = y[test_indices]

# Summary metrics from utils; the printout labels F1/precision/recall as weighted.
oa = overall_accuracy(y_true, y_pred)
aa = average_accuracy(y_true, y_pred)
kappa = kappa_coefficient(y_true, y_pred)
f1, precision, recall = calculate_f1_precision_recall(y_true, y_pred)

rule = "=" * 50
print("\n" + rule)
print("Classification Metrics")
print(rule)
print(f"Overall Accuracy (OA):     {oa:.4f}")
print(f"Average Accuracy (AA):      {aa:.4f}")
print(f"Kappa Coefficient:          {kappa:.4f}")
print(f"F1 Score (weighted):         {f1:.4f}")
print(f"Precision (weighted):        {precision:.4f}")
print(f"Recall (weighted):           {recall:.4f}")
print(rule)


## Step 9: Performance Metrics

Calculate model performance metrics: per-image latency, throughput, parameter count, and GFLOPs (shown as N/A if the GFLOPs estimate failed during model initialization).


In [None]:
# Measure inference cost on the test split with the helpers from utils.
# shuffle=False gives a deterministic pass over the same samples every run.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

latency = calculate_latency_per_image(model, test_loader, device)
throughput = calculate_throughput(model, test_loader, device)
params = count_model_parameters(model)

# `gflops` only holds a number if the GFLOPs estimate in the model-setup cell
# succeeded. That cell catches the failure and continues, so show "N/A" here
# instead of crashing with NameError/TypeError at the very end.
_gflops_value = globals().get("gflops")
gflops_text = f"{_gflops_value:.2f}" if _gflops_value is not None else "N/A"

print("\n" + "="*50)
print("Performance Metrics")
print("="*50)
print(f"Latency per image:          {latency:.4f} ms")
print(f"Throughput:                  {throughput:.2f} samples/sec")
print(f"Model Parameters:            {params:.2f} M")
print(f"GFLOPs:                      {gflops_text}")
print("="*50)


## Step 10: Confusion Matrix

Visualize the confusion matrix to understand classification performance per class.


In [None]:
# Generate the confusion matrix (rows = true class, columns = predicted class).
# Pinning `labels` keeps it num_classes x num_classes even if some class appears
# in neither y_true nor y_pred. Index i then always means class i, which the
# per-class code indexed by num_classes relies on.
class_labels = np.arange(num_classes)
cm = confusion_matrix(y_true, y_pred, labels=class_labels)

# Plot confusion matrix
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar_kws={'label': 'Count'})
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.title('Confusion Matrix - New Model Architecture', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Print classification report over the same fixed label set.
print("\nClassification Report:")
print(classification_report(y_true, y_pred, labels=class_labels, digits=4))


## Step 11: Per-Class Accuracy

Visualize per-class accuracy to identify which classes are performing well.


In [None]:
# Per-class accuracy (i.e. per-class recall): correct predictions divided by the
# number of true samples of that class, which is the confusion-matrix diagonal
# over its row sums. A class with no true samples gets 0.0 instead of NaN from 0/0.
class_support = cm.sum(axis=1)
class_accuracies = np.divide(
    cm.diagonal(),
    class_support,
    out=np.zeros(cm.shape[0], dtype=float),
    where=class_support > 0,
)
class_names = [f"Class {i}" for i in range(num_classes)]

# Plot per-class accuracy
plt.figure(figsize=(10, 6))
bars = plt.bar(range(num_classes), class_accuracies, color='steelblue', alpha=0.7)
plt.xlabel('Class', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.title('Per-Class Accuracy - New Model Architecture', fontsize=14, fontweight='bold')
plt.xticks(range(num_classes), class_names, rotation=45, ha='right')
plt.ylim([0, 1])
plt.grid(axis='y', alpha=0.3)

# Add value labels just above each bar.
for bar, acc in zip(bars, class_accuracies):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
             f'{acc:.3f}', ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

print("\nPer-Class Accuracies:")
for name, acc in zip(class_names, class_accuracies):
    print(f"  {name}: {acc:.4f}")


## Summary

This notebook successfully trained the new model architecture (`newFastViT`) from `model.py` on the Pavia University hyperspectral dataset. The model uses:

- **Efficient Attention**: Optimized attention mechanism for faster computation
- **Spectral Attention**: Specialized attention for spectral band processing
- **Transformer Blocks**: 4 layers (`depth = 4`) with residual connections
- **Spatial-Spectral Processing**: 5×5 patches with 103 spectral bands

### Key Results:
- Overall Accuracy, Average Accuracy, and Kappa Coefficient are displayed above
- Confusion matrix shows per-class performance
- Performance metrics include latency, throughput, and model size

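Per-epoch checkpoints are written to the `./results_new_model/` directory (at most three are kept, per `save_total_limit=3`). Because `load_best_model_at_end=True`, the in-memory model after training is the checkpoint with the lowest evaluation loss. Call `trainer.save_model()` to export it explicitly.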
