# Fruits-360 Classification with TverskyReduceBackbone

This notebook demonstrates parameter-efficient fruit classification using **TverskyReduceBackbone** with the **GlobalFeature bank** for feature sharing. The model combines CNN feature extraction with Tversky similarity-based projections that share feature matrices across layers. This reduces the number of trainable parameters relative to the same architecture without sharing, and Section 6 measures the difference.

<a href="https://colab.research.google.com/github/madch3m/tverskysimilaritygrad/blob/main/tverskycv/notebooks/Classification_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


## 1. Setup and Installation

Install required packages and set up the environment.

In [None]:
# Clone the repository (for Colab users)
# 
# INSTRUCTIONS FOR COLAB:
# 1. Uncomment the git clone line below
# 2. The repository will be cloned to /content/tverskysimilaritygrad
# 3. Run this cell to clone the repository
# 4. The import cell will automatically detect and use the cloned repository

# Uncomment the line below:
# !git clone https://github.com/madch3m/tverskysimilaritygrad.git

# For local development, skip this cell
import os
import sys

# Check if we're in Colab (the google.colab module is preloaded in Colab runtimes)
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    repo_path = '/content/tverskysimilaritygrad'
    if os.path.exists(repo_path):
        print(f"✓ Repository found at {repo_path}")
        os.chdir(repo_path)
        # Add to sys.path immediately
        if repo_path not in sys.path:
            sys.path.insert(0, repo_path)
        print(f"✓ Changed to: {os.getcwd()}")
        print(f"✓ Added to sys.path")
        
        # Verify tverskycv exists
        if os.path.exists(os.path.join(repo_path, 'tverskycv')):
            print(f"✓ tverskycv folder found")
        else:
            print(f"⚠ tverskycv folder not found at {repo_path}/tverskycv")
    else:
        print("⚠ Repository not found.")
        print("  Please uncomment and run the git clone command above.")
        print(f"  Expected location: {repo_path}")
else:
    print("✓ Running locally - repository should already be available")
    # For local, try to find project root
    current = os.getcwd()
    if os.path.exists(os.path.join(current, 'tverskycv')):
        print(f"✓ Found tverskycv in current directory: {current}")
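    elif os.path.exists(os.path.join(current, '..', 'tverskycv')):
        parent = os.path.abspath('..')
        print(f"✓ Found tverskycv in parent directory: {parent}")
        if parent not in sys.path:
            sys.path.insert(0, parent)
    else:
        print("⚠ tverskycv not found here or in the parent directory; the import cell below searches further")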


In [None]:
# Install required packages
%pip install torch torchvision transformers numpy matplotlib seaborn tqdm datasets Pillow

## 2. Import Libraries and Setup Path

Import necessary modules and set up the project path.

In [None]:
# Import Libraries and Setup Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from datasets import load_dataset
import sys
import os
from collections import Counter

# Automatic path setup for both Colab and local development
print("Setting up paths...")
_current_dir = os.getcwd()
print(f"Current directory: {_current_dir}")

_project_root = None

# Try to find project root (directory containing tverskycv folder)
# Check multiple possible locations
_search_paths = [
    _current_dir,  # Current directory
    os.path.join(_current_dir, 'tverskysimilaritygrad'),  # If we're in /content
    '/content/tverskysimilaritygrad',  # Colab default after clone
    os.path.join(_current_dir, '..'),
    os.path.join(_current_dir, '..', '..'),
    os.path.join(_current_dir, '..', '..', '..'),
    os.path.abspath('.'),  # Absolute current
    os.path.abspath('..'),  # Parent
]

print("\nSearching for project root...")
for path in _search_paths:
    abs_path = os.path.abspath(path)
    tverskycv_path = os.path.join(abs_path, 'tverskycv')
    if os.path.exists(tverskycv_path) and os.path.isdir(tverskycv_path):
        _project_root = abs_path
        print(f"  ✓ Found tverskycv at: {tverskycv_path}")
        break
    else:
        print(f"  ✗ Not found: {tverskycv_path}")

# Change to project root if found
if _project_root:
    os.chdir(_project_root)
    if _project_root not in sys.path:
        sys.path.insert(0, _project_root)
    print(f"\n✓ Project root: {_project_root}")
    print(f"✓ Working directory: {os.getcwd()}")
    print(f"✓ Added to sys.path")
else:
    # Fallback: try adding current directory and common paths
    print("\n⚠ Project root not found. Trying fallback paths...")
    _fallback_paths = [
        os.path.abspath('.'),
        os.path.abspath('..'),
        os.path.abspath('../..'),
        '/content',
        '/content/tverskysimilaritygrad',
    ]
    for path in _fallback_paths:
        abs_path = os.path.abspath(path)
        if abs_path not in sys.path and os.path.exists(abs_path):
            sys.path.insert(0, abs_path)
            print(f"  Added to sys.path: {abs_path}")
    print(f"\n⚠ Current directory: {os.getcwd()}")
    print(f"⚠ If imports fail, make sure you've:")
    print(f"   1. Run the git clone cell above")
    print(f"   2. In Colab, the repo should be at: /content/tverskysimilaritygrad")

# Verify tverskycv can be found
print("\nVerifying tverskycv module location...")
try:
    import tverskycv
    print(f"✓ tverskycv found at: {tverskycv.__file__}")
except ImportError:
    print("✗ tverskycv module not found in Python path")
    print(f"  sys.path entries: {sys.path[:10]}")

# Now import Tversky modules
print("\nImporting Tversky modules...")
try:
    from tverskycv.models.backbones.shared_tversky import GlobalFeature
    print("  ✓ GlobalFeature imported")
    from tverskycv.models.backbones.tversky_reduce_backbone import (
        TverskyReduceBackbone,
        SharedTverskyCompact,
        SharedTverskyInterpretable
    )
    print("  ✓ TverskyReduceBackbone components imported")
    print("\n✓ All imports successful!")
except ImportError as e:
    print(f"\n✗ Import error: {e}")
    print("\nTroubleshooting steps:")
    print("  1. Make sure you've run the git clone cell above")
    print("  2. In Colab, uncomment and run: !git clone https://github.com/madch3m/tverskysimilaritygrad.git")
    print("  3. After cloning, the repo should be at: /content/tverskysimilaritygrad")
    print(f"  4. Current working directory: {os.getcwd()}")
    print(f"  5. Check if tverskycv exists: {os.path.exists('tverskycv')}")
    print(f"  6. sys.path entries: {sys.path[:10]}")
    print("\nIf still failing, try:")
    print("  - Restart the runtime after cloning")
    print("  - Manually add: sys.path.insert(0, '/content/tverskysimilaritygrad')")


## 3. Understanding TverskyReduceBackbone

**TverskyReduceBackbone** combines CNN feature extraction with Tversky projection layers that use the **GlobalFeature bank** for parameter sharing:

- **CNN Feature Extraction**: Standard convolutional layers extract spatial features from images
- **Tversky Projection**: Compact or interpretable Tversky layers score similarity with a Tversky-style measure that weighs shared features against distinctive ones
- **Feature Matrix Sharing**: Multiple layers share the same feature transformation matrix via GlobalFeature
- **Parameter Reduction**: Fewer trainable parameters than the same architecture without sharing
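For reference, Tversky's (1977) contrast model scores the similarity of objects $a$ and $b$, with feature sets $A$ and $B$, as

$$S(a,b) = \theta f(A \cap B) - \alpha f(A \setminus B) - \beta f(B \setminus A)$$

where $f$ measures feature salience. The `alpha` and `beta` arguments used later in this notebook presumably play the roles of $\alpha$ and $\beta$. The pure-Python sketch below shows the classic set form and one possible soft relaxation over a feature bank (ReLU dot-product memberships, `min` for the intersection, ReLU of the difference for distinctive features). The relaxation and all function names are illustrative assumptions, not the `tverskycv` implementation.

```python
def contrast_similarity(a_feats, b_feats, theta=1.0, alpha=1.0, beta=1.0):
    """Classic set form with f = set size: theta*|A & B| - alpha*|A - B| - beta*|B - A|."""
    common = len(a_feats & b_feats)
    a_only = len(a_feats - b_feats)
    b_only = len(b_feats - a_feats)
    return theta * common - alpha * a_only - beta * b_only


def memberships(x, feature_bank):
    """Assumed soft membership of vector x in each bank feature: ReLU(<x, f_k>)."""
    return [max(0.0, sum(xi * fi for xi, fi in zip(x, f))) for f in feature_bank]


def soft_tversky(x, y, feature_bank, theta=1.0, alpha=1.0, beta=1.0):
    """Assumed relaxation: intersection ~ min, differences ~ ReLU(a - b), summed over features."""
    a = memberships(x, feature_bank)
    b = memberships(y, feature_bank)
    common = sum(min(ai, bi) for ai, bi in zip(a, b))
    a_only = sum(max(0.0, ai - bi) for ai, bi in zip(a, b))
    b_only = sum(max(0.0, bi - ai) for ai, bi in zip(a, b))
    return theta * common - alpha * a_only - beta * b_only


bank = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical two-feature bank
print(contrast_similarity({"red", "round", "sweet"}, {"red", "round", "sour"}))
print(soft_tversky([2.0, 1.0], [1.0, 3.0], bank))
```

With `alpha != beta` the score is asymmetric: `contrast_similarity({"red", "round"}, {"red", "round", "sweet", "small"}, alpha=1.0, beta=0.2)` is about 1.6, while swapping the arguments gives 0.0.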

### Key Benefits:
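- **Parameter Efficiency**: Shared features reduce the total parameter count; the size of the reduction depends on the configuration and is measured for this model in Section 6
- **Memory Efficiency**: Fewer parameters mean a smaller footprint for weights and optimizer state during training and inference
- **Competitive Performance**: Shared features can still learn effective representations; the training results below show how well this holds on Fruits-360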
- **Two Variants**: Compact (minimal params) or Interpretable (visualization-friendly)
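The sharing idea can be sketched without PyTorch: a bank hands out one matrix per key, and every layer that asks for the same key receives the same object. `FeatureBank`, `get_or_create` and `ProjectionLayer` are hypothetical stand-ins that only mimic the behavior described above; they are not the `GlobalFeature` API.

```python
class FeatureBank:
    """Hypothetical key -> matrix registry (a stand-in for a shared feature bank)."""

    def __init__(self):
        self._store = {}

    def get_or_create(self, key, rows, cols):
        # Create the matrix on first request; later requests return the same object.
        if key not in self._store:
            self._store[key] = [[0.0] * cols for _ in range(rows)]
        return self._store[key]

    def clear(self):
        self._store.clear()

    def __len__(self):
        return len(self._store)


class ProjectionLayer:
    """Hypothetical layer that fetches its feature matrix from the bank."""

    def __init__(self, bank, key, rows, cols, share=True, layer_id=0):
        # With sharing, all layers use one key; without it, each layer gets its own key.
        name = key if share else f"{key}/layer{layer_id}"
        self.features = bank.get_or_create(name, rows, cols)


bank = FeatureBank()
shared_layers = [ProjectionLayer(bank, "fruits", 64, 128, share=True, layer_id=i) for i in range(3)]
print(len(bank))  # 1 matrix backs all three layers

bank.clear()
separate_layers = [ProjectionLayer(bank, "fruits", 64, 128, share=False, layer_id=i) for i in range(3)]
print(len(bank))  # 3 independent matrices
```

In this sketch, clearing the bank does not take matrices away from layers that already hold them; it only means the next request for a key creates a fresh matrix.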


## 4. Load Fruits-360 Dataset

Load the fruits-360 dataset from Hugging Face and prepare it for training.

In [None]:
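# Environment settings. OMP_NUM_THREADS is read when torch starts up, so setting it
# after `import torch` may not change torch's thread count (torch.set_num_threads does).
# TOKENIZERS_PARALLELISM only affects Hugging Face tokenizers.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['OMP_NUM_THREADS'] = '1'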

# Load fruits-360 dataset from Hugging Face
print("Loading fruits-360 dataset...")
ds = load_dataset("PedroSampaio/fruits-360")

print(f"\nDataset splits: {list(ds.keys())}")
print(f"Train samples: {len(ds['train'])}")
if 'test' in ds:
    print(f"Test samples: {len(ds['test'])}")
if 'validation' in ds:
    print(f"Validation samples: {len(ds['validation'])}")

# Get number of classes
if 'train' in ds:
    labels = ds['train']['label']
    num_classes = len(set(labels))
    print(f"\nNumber of classes: {num_classes}")
    print(f"Sample labels: {sorted(set(labels))[:10]}...")
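`num_classes = len(set(labels))` is only safe when the label IDs are exactly `0..C-1`: `nn.CrossEntropyLoss` expects targets in `[0, C)`, so a gap in the IDs would leave the largest label out of range. The helpers below (hypothetical names, plain Python) check that assumption and remap IDs if needed. If the `label` column is a Hugging Face `ClassLabel`, `ds['train'].features['label'].num_classes` is another way to read the class count.

```python
def check_label_space(labels):
    """Return the class count C if the labels are exactly 0..C-1, else raise ValueError."""
    unique = sorted(set(labels))
    if unique != list(range(len(unique))):
        raise ValueError(f"labels are not 0..{len(unique) - 1}; first values: {unique[:5]}")
    return len(unique)


def remap_labels(labels):
    """Map arbitrary label IDs onto 0..C-1, preserving their sorted order."""
    mapping = {old: new for new, old in enumerate(sorted(set(labels)))}
    return [mapping[label] for label in labels], mapping


print(check_label_space([0, 2, 1, 0]))  # 3
print(remap_labels([3, 7, 3]))          # ([0, 1, 0], {3: 0, 7: 1})
```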

In [None]:
# Prepare data transformations
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # Resize to 64x64 for faster training
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])

def transform_dataset(examples):
    """Transform images in the dataset."""
    images = examples['image']
    transformed_images = []
    
    for img in images:
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        if isinstance(img, Image.Image):
            # Force 3 channels so grayscale or RGBA images match the 3-channel Normalize
            img_tensor = transform(img.convert('RGB'))
        else:
            img_tensor = img if isinstance(img, torch.Tensor) else torch.tensor(img)
        transformed_images.append(img_tensor.numpy())
    
    examples['image'] = transformed_images
    return examples

# Apply transformations
print("Transforming dataset...")
train_data_transformed = ds['train'].map(
    transform_dataset, 
    batched=True, 
    batch_size=100,
    remove_columns=[col for col in ds['train'].column_names if col not in ['image', 'label']]
)

# Use the validation split as the test set if there is no test split
eval_split = 'test' if 'test' in ds else 'validation'
test_data_transformed = ds[eval_split].map(
    transform_dataset,
    batched=True,
    batch_size=100,
    remove_columns=[col for col in ds[eval_split].column_names if col not in ['image', 'label']]
)

print("✓ Dataset transformation complete!")
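`ToTensor()` scales 8-bit pixel values to `[0, 1]`, and `Normalize` then computes `(x - mean) / std` per channel with the ImageNet statistics above. The plain-Python sketch below (hypothetical helper name) reproduces that arithmetic for a single RGB pixel, which shows the input range the model actually sees: roughly -2.12 to 2.64.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Mimic ToTensor() + Normalize(mean, std) for one pixel given as 0-255 ints."""
    scaled = [c / 255.0 for c in rgb]  # ToTensor: uint8 -> [0, 1]
    return [(c - m) / s for c, m, s in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]


white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print([round(v, 4) for v in white])  # [2.2489, 2.4286, 2.64]
print([round(v, 4) for v in black])  # [-2.1179, -2.0357, -1.8044]
```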

In [None]:
# Create PyTorch Dataset wrapper
class Fruits360Dataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        image = item['image']
        label = item['label']
        
        # Convert to tensor if needed
        if not isinstance(image, torch.Tensor):
            if isinstance(image, list):
                image = torch.tensor(image)
            elif isinstance(image, np.ndarray):
                image = torch.from_numpy(image)
            else:
                image = torch.tensor(image)
        
        # Ensure image is in (C, H, W) format for backbone
        if len(image.shape) == 1:
            # Flattened image, reshape to (C, H, W)
            image = image.view(3, 64, 64)
        elif len(image.shape) == 3 and image.shape[0] != 3:
            # Might be (H, W, C), transpose to (C, H, W)
            if image.shape[2] == 3:
                image = image.permute(2, 0, 1)
        
        return image, label

# Create datasets
train_dataset = Fruits360Dataset(train_data_transformed)
test_dataset = Fruits360Dataset(test_data_transformed)

# Create data loaders
batch_size = 32
train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=0,  # Set to 0 for Colab compatibility
    pin_memory=False
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=0,
    pin_memory=False
)

print(f"✓ Data loaders created!")
print(f"  Batch size: {batch_size}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Test batches: {len(test_loader)}")

# Get image dimensions
sample_image, _ = train_dataset[0]
print(f"  Image shape: {sample_image.shape}")
print(f"  Image size: {sample_image.shape[1]}x{sample_image.shape[2]}")
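With the default `drop_last=False`, `len(DataLoader)` for a map-style dataset is `ceil(N / batch_size)`, and the final batch holds whatever is left over. The small helper below (a hypothetical name, plain Python) computes both numbers, which is handy when reading the batch counts printed above.

```python
import math


def loader_stats(num_samples, batch_size, drop_last=False):
    """Return (number of batches, size of the last batch) for sequential batching."""
    if drop_last:
        n_batches = num_samples // batch_size
        return n_batches, (batch_size if n_batches else 0)
    n_batches = math.ceil(num_samples / batch_size)
    last = num_samples - (n_batches - 1) * batch_size if n_batches else 0
    return n_batches, last


print(loader_stats(1000, 32))                  # (32, 8)
print(loader_stats(1000, 32, drop_last=True))  # (31, 32)
```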


## 5. Create Model with TverskyReduceBackbone

Create a classification model using TverskyReduceBackbone with feature sharing.


In [None]:
class FruitsClassifierWithTverskyBackbone(nn.Module):
    """Fruit classifier using TverskyReduceBackbone with feature sharing."""
    
    def __init__(
        self,
        num_classes,
        in_channels=3,
        img_size=64,
        variant='compact',  # 'compact' or 'interpretable'
        out_dim=128,
        n_features=64,  # For interpretable variant
        feature_key='fruits',
        share_features=True,
        alpha=1.0,
        beta=1.0
    ):
        super().__init__()
        
        # TverskyReduceBackbone extracts features
        self.backbone = TverskyReduceBackbone(
            out_dim=out_dim,
            in_channels=in_channels,
            img_size=img_size,
            variant=variant,
            n_features=n_features,
            feature_key=feature_key,
            share_features=share_features,
            alpha=alpha,
            beta=beta
        )
        
        # Classification head (simple linear layer)
        self.classifier = nn.Linear(out_dim, num_classes)
    
    def forward(self, x):
        # Extract features using Tversky backbone
        features = self.backbone(x)  # (B, out_dim)
        
        # Classify
        logits = self.classifier(features)  # (B, num_classes)
        return logits

# Model configuration
img_size = 64
in_channels = 3
variant = 'compact'  # Use 'compact' for efficiency or 'interpretable' for visualization
out_dim = 128  # Output dimension from backbone
feature_key = 'fruits_shared'  # Shared feature key

# Clear GlobalFeature bank before creating model
gf = GlobalFeature()
gf.clear()

# Create model with feature sharing
model_shared = FruitsClassifierWithTverskyBackbone(
    num_classes=num_classes,
    in_channels=in_channels,
    img_size=img_size,
    variant=variant,
    out_dim=out_dim,
    feature_key=feature_key,
    share_features=True  # Enable feature sharing
)

print(f"✓ Model with TverskyReduceBackbone created!")
print(f"  Variant: {variant}")
print(f"  Image size: {img_size}x{img_size}")
print(f"  Input channels: {in_channels}")
print(f"  Backbone output dim: {out_dim}")
print(f"  Number of classes: {num_classes}")

# Count parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

params_shared = count_parameters(model_shared)
print(f"  Total parameters: {params_shared:,}")
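`count_parameters` sums over `model.parameters()`, which in PyTorch yields each `Parameter` object once even if several submodules hold it; that is why a shared matrix is counted only once. The plain-Python sketch below imitates that identity-based deduplication with a hypothetical `FakeParam` stand-in and contrasts it with a naive per-layer sum that would double-count.

```python
class FakeParam:
    """Hypothetical stand-in for a tensor: it only knows its element count."""

    def __init__(self, numel):
        self.numel = numel


def naive_count(params):
    # Sums every reference, so a shared object is counted once per layer that uses it.
    return sum(p.numel for p in params)


def dedup_count(params):
    # Counts each distinct object once, keyed by identity (like Module.parameters()).
    seen = set()
    total = 0
    for p in params:
        if id(p) not in seen:
            seen.add(id(p))
            total += p.numel
    return total


shared = FakeParam(64 * 128)     # one feature matrix reused by three layers
head = FakeParam(128 * 10 + 10)  # hypothetical 10-class linear head (weights + bias)
refs = [shared, shared, shared, head]
print(naive_count(refs), dedup_count(refs))  # 25866 9482
```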


## 6. Parameter Efficiency Analysis

Compare models with and without feature sharing to demonstrate parameter reduction.
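The two summary numbers printed below are tied together: if sharing removes a fraction $p$ of the parameters, the unshared model is $1/(1-p)$ times larger, so a 75% reduction corresponds to a 4x size ratio, not 75x. The sketch uses made-up parameter counts purely for illustration, and `sharing_stats` is a hypothetical helper.

```python
def sharing_stats(n_no_sharing, n_shared):
    """Return (parameters saved, reduction in %, size ratio no-sharing / sharing)."""
    saved = n_no_sharing - n_shared
    reduction_pct = 100.0 * saved / n_no_sharing
    ratio = n_no_sharing / n_shared
    return saved, reduction_pct, ratio


# Hypothetical counts, not results from this notebook
saved, pct, ratio = sharing_stats(200_000, 50_000)
print(saved, pct, ratio)  # 150000 75.0 4.0
```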

In [None]:
# Create model WITHOUT feature sharing for comparison
gf.clear()  # Clear before creating non-shared model

model_no_sharing = FruitsClassifierWithTverskyBackbone(
    num_classes=num_classes,
    in_channels=in_channels,
    img_size=img_size,
    variant=variant,
    out_dim=out_dim,
    feature_key='fruits_no_sharing',
    share_features=False  # Disable feature sharing
)

params_no_sharing = count_parameters(model_no_sharing)

# Calculate reduction
reduction = ((params_no_sharing - params_shared) / params_no_sharing) * 100
params_saved = params_no_sharing - params_shared

print("Parameter Efficiency Comparison:")
print("=" * 70)
print(f"{'Model with feature sharing:':<31}{params_shared:,} parameters")
print(f"{'Model without feature sharing:':<31}{params_no_sharing:,} parameters")
print(f"{'Parameters saved:':<31}{params_saved:,}")
print(f"{'Reduction:':<31}{reduction:.2f}%")
print(f"{'Size ratio:':<31}{params_no_sharing / params_shared:.2f}x (no-sharing model / shared model)")
print("=" * 70)

# Analyze shared features
print("\nShared Feature Analysis:")
print("=" * 70)
# Rebuild a shared model on a freshly cleared bank so the bank holds only its shared features
gf.clear()
model_shared_check = FruitsClassifierWithTverskyBackbone(
    num_classes=num_classes,
    in_channels=in_channels,
    img_size=img_size,
    variant=variant,
    out_dim=out_dim,
    feature_key=feature_key,
    share_features=True
)
shared_features = gf._feature_matrices  # internal attribute of GlobalFeature; may change between versions
total_shared_params = 0

for key, value in shared_features.items():
    if isinstance(value, nn.Parameter):
        param_count = value.numel()
        total_shared_params += param_count
        print(f"  {key}: {param_count:,} parameters")
    elif isinstance(value, dict):
        param_count = sum(p.numel() for p in value.values() if isinstance(p, nn.Parameter))
        total_shared_params += param_count
        print(f"  {key}: {param_count:,} parameters (Tversky params)")

print(f"\nTotal shared parameters: {total_shared_params:,}")
if params_shared > 0:
    print(f"Shared percentage: {(total_shared_params / params_shared * 100):.2f}%")
print("=" * 70)


### Compare with Traditional CNN

Now let's compare with a traditional convolutional neural network to see the parameter difference.
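Parameter counts for this baseline follow directly from the layer shapes: a `Conv2d` with bias has $C_{out}(C_{in}k^2 + 1)$ parameters and a `Linear` layer has $d_{out}(d_{in} + 1)$. The sketch below applies those formulas to the layers defined in the next cell (plain Python, hypothetical helper names). It shows that `fc1`, which maps the $128 \times 8 \times 8 = 8192$ flattened activations to 512 units, holds about 94% of the 4,452,288 parameters that come before the final layer, which adds $129 \times$ `num_classes` more.

```python
def conv2d_params(c_in, c_out, k, bias=True):
    # weight: c_out x c_in x k x k, plus one bias per output channel
    return c_out * (c_in * k * k + (1 if bias else 0))


def linear_params(d_in, d_out, bias=True):
    # weight: d_out x d_in, plus one bias per output unit
    return d_out * (d_in + (1 if bias else 0))


def traditional_cnn_params(num_classes, img_size=64):
    side = img_size // 8  # three 2x2 max-pools: 64 -> 32 -> 16 -> 8
    convs = conv2d_params(3, 32, 3) + conv2d_params(32, 64, 3) + conv2d_params(64, 128, 3)
    fc1 = linear_params(128 * side * side, 512)
    rest = linear_params(512, 256) + linear_params(256, 128) + linear_params(128, num_classes)
    return {"conv": convs, "fc1": fc1, "total": convs + fc1 + rest}


counts = traditional_cnn_params(num_classes=10)  # 10 classes is only an example
print(counts)  # {'conv': 93248, 'fc1': 4194816, 'total': 4453578}
```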

In [None]:
# Traditional CNN Model for comparison
class TraditionalCNN(nn.Module):
    """Traditional convolutional neural network for fruit classification."""
    
    def __init__(self, num_classes, img_size=64):
        super().__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        
        # Pooling layers
        self.pool = nn.MaxPool2d(2, 2)
        
        # Calculate flattened size after convolutions
        # Three 2x2 max-pools halve each side: 64 -> 32 -> 16 -> 8
        self.img_size = img_size
        self.flattened_size = 128 * (img_size // 8) ** 2
        
        # Fully connected layers
        self.fc1 = nn.Linear(self.flattened_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, num_classes)
        
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        batch_size = x.size(0)
        
        # Accept (B, C, H, W) images; reshape flattened (B, 3*H*W) input if given
        if x.dim() == 2:
            x = x.view(batch_size, 3, self.img_size, self.img_size)
        
        # Convolutional layers with pooling
        x = self.pool(F.relu(self.conv1(x)))  # 64x64 -> 32x32
        x = self.pool(F.relu(self.conv2(x)))  # 32x32 -> 16x16
        x = self.pool(F.relu(self.conv3(x)))  # 16x16 -> 8x8
        
        # Flatten
        x = x.view(batch_size, -1)
        
        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        
        return x

# Create traditional CNN model
model_cnn = TraditionalCNN(num_classes=num_classes, img_size=64)
params_cnn = count_parameters(model_cnn)

print("Traditional CNN Model:")
print(f"  Total parameters: {params_cnn:,}")
print(f"  Architecture: Conv2d layers + MaxPool + Fully Connected layers")
print(f"  Conv layers: 3x3 convs with 32, 64, 128 channels")
print(f"  FC layers: {model_cnn.flattened_size} -> 512 -> 256 -> 128 -> {num_classes}")


In [None]:
# Three-way Parameter Comparison
print("=" * 80)
print("COMPREHENSIVE PARAMETER COMPARISON")
print("=" * 80)
print(f"\n1. TverskyReduceBackbone with Feature Sharing:")
print(f"   Parameters: {params_shared:,} ({params_shared/1e6:.2f}M)")
print(f"   Architecture: CNN + TverskyCompact/Interpretable with GlobalFeature bank")
print(f"   Feature: Parameter sharing across layers via GlobalFeature")

print(f"\n2. TverskyReduceBackbone without Feature Sharing:")
print(f"   Parameters: {params_no_sharing:,} ({params_no_sharing/1e6:.2f}M)")
print(f"   Architecture: CNN + TverskyCompact/Interpretable without sharing")
print(f"   Feature: Each layer has its own features")

print(f"\n3. Traditional CNN Model:")
print(f"   Parameters: {params_cnn:,} ({params_cnn/1e6:.2f}M)")
print(f"   Architecture: Conv2d + MaxPool + Linear layers")
print(f"   Feature: Standard convolutional neural network")

# Calculate reductions
reduction_vs_no_sharing = ((params_no_sharing - params_shared) / params_no_sharing) * 100
reduction_vs_cnn = ((params_cnn - params_shared) / params_cnn) * 100
reduction_cnn_vs_no_sharing = ((params_no_sharing - params_cnn) / params_no_sharing) * 100

print(f"\n" + "=" * 80)
print("PARAMETER REDUCTION ANALYSIS")
print("=" * 80)
print(f"\nFeature Sharing vs No Sharing:")
print(f"  Reduction: {reduction_vs_no_sharing:.2f}%")
print(f"  Parameters saved: {params_no_sharing - params_shared:,}")
print(f"  Size ratio: {params_no_sharing / params_shared:.2f}x (no-sharing model / shared model)")

print(f"\nFeature Sharing vs Traditional CNN:")
print(f"  Reduction: {reduction_vs_cnn:.2f}%")
print(f"  Parameters saved: {params_cnn - params_shared:,}")
print(f"  Size ratio: {params_cnn / params_shared:.2f}x (CNN / shared model)")

print(f"\nTraditional CNN vs No Sharing:")
if params_cnn < params_no_sharing:
    print(f"  CNN has {params_no_sharing - params_cnn:,} fewer parameters ({reduction_cnn_vs_no_sharing:.2f}% reduction)")
else:
    print(f"  No Sharing model has {params_cnn - params_no_sharing:,} fewer parameters")

print("=" * 80)


In [None]:
# Visualize three-way parameter comparison
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Bar chart comparison (all three models)
ax1 = axes[0]
models = ['Tversky\n(Shared)', 'Tversky\n(No Share)', 'Traditional\nCNN']
params = [params_shared, params_no_sharing, params_cnn]
colors = ['#2ecc71', '#e74c3c', '#3498db']

bars = ax1.bar(models, params, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
ax1.set_ylabel('Number of Parameters', fontsize=12, fontweight='bold')
ax1.set_title('Parameter Count: All Model Types Comparison', fontsize=14, fontweight='bold')
ax1.grid(axis='y', alpha=0.3, linestyle='--')

# Add value labels
for bar, param in zip(bars, params):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height,
            f'{param:,}\n({param/1e6:.2f}M)',
            ha='center', va='bottom', fontsize=9, fontweight='bold')

# Reduction percentages comparison
ax2 = axes[1]
reductions = [reduction_vs_no_sharing, reduction_vs_cnn]
reduction_labels = ['vs No Sharing', 'vs Traditional CNN']
colors_reduction = ['#e74c3c', '#3498db']

bars2 = ax2.bar(reduction_labels, reductions, color=colors_reduction, alpha=0.7, edgecolor='black', linewidth=2)
ax2.set_ylabel('Parameter Reduction (%)', fontsize=12, fontweight='bold')
ax2.set_title('TverskyReduceBackbone: Parameter Reduction', fontsize=14, fontweight='bold')
ax2.grid(axis='y', alpha=0.3, linestyle='--')

# Add value labels
for bar, reduction in zip(bars2, reductions):
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height,
            f'{reduction:.2f}%',
            ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.show()

# Summary table
print("\n" + "=" * 80)
print("SUMMARY TABLE")
print("=" * 80)
def _fmt_params(n):
    return f"{n:,} ({n/1e6:.2f}M)"

reduction_no_sharing_vs_cnn = ((params_cnn - params_no_sharing) / params_cnn) * 100
print(f"{'Model Type':<30} {'Parameters':>24} {'Reduction vs CNN':>18}")
print("-" * 80)
print(f"{'TverskyReduce (Shared)':<30} {_fmt_params(params_shared):>24} {reduction_vs_cnn:>17.2f}%")
print(f"{'TverskyReduce (No Share)':<30} {_fmt_params(params_no_sharing):>24} {reduction_no_sharing_vs_cnn:>17.2f}%")
print(f"{'Traditional CNN':<30} {_fmt_params(params_cnn):>24} {'Baseline':>18}")
print("=" * 80)


## 7. Training Setup

Set up training configuration and utilities.

In [None]:
# Training configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Use the shared model for training
gf.clear()
model = FruitsClassifierWithTverskyBackbone(
    num_classes=num_classes,
    in_channels=in_channels,
    img_size=img_size,
    variant=variant,
    out_dim=out_dim,
    feature_key=feature_key,
    share_features=True
).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

print(f"\n✓ Training setup complete!")
print(f"  Model: TverskyReduceBackbone ({variant} variant)")
print(f"  Loss function: CrossEntropyLoss")
print(f"  Optimizer: Adam (lr=0.001)")
print(f"  Scheduler: StepLR (step_size=5, gamma=0.5)")
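`StepLR(step_size=5, gamma=0.5)` with one `scheduler.step()` per epoch halves the learning rate every 5 epochs, i.e. $\eta_e = \eta_0 \cdot 0.5^{\lfloor e/5 \rfloor}$ for 0-based epoch $e$. The sketch below (plain Python, hypothetical helper) lists the schedule for the 10 epochs trained below. Because the loop calls `scheduler.step()` before printing, the value it prints after epoch $e$ is the rate for epoch $e+1$.

```python
def step_lr(base_lr, epoch, step_size=5, gamma=0.5):
    """Learning rate in effect during 0-based `epoch` under StepLR stepped once per epoch."""
    return base_lr * gamma ** (epoch // step_size)


schedule = [step_lr(1e-3, e) for e in range(10)]
print(schedule)  # five epochs at 0.001, then five at 0.0005
```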


## 8. Training Loop

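Train the model with feature sharing. Note that the test split doubles as the validation set here, so the "Val" metrics below and the final evaluation in Section 10 use the same data.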

In [None]:
# Training parameters
num_epochs = 10

# Track metrics
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

print(f"\n{'='*70}")
print(f"Starting Training for {num_epochs} epochs")
print(f"{'='*70}\n")

for epoch in range(num_epochs):
    # Training phase
    model.train()
    epoch_train_loss = 0.0
    epoch_train_correct = 0
    epoch_train_total = 0
    
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        epoch_train_loss += loss.item()
        preds = torch.argmax(logits, dim=-1)
        epoch_train_correct += (preds == labels).sum().item()
        epoch_train_total += labels.size(0)
        
        # Print progress
        if (batch_idx + 1) % 100 == 0:
            current_loss = epoch_train_loss / (batch_idx + 1)
            current_acc = epoch_train_correct / epoch_train_total
            print(f"  Batch {batch_idx+1}/{len(train_loader)}: Loss={current_loss:.4f}, Acc={current_acc:.4f}")
    
    # Calculate average training metrics
    train_loss = epoch_train_loss / len(train_loader)
    train_acc = epoch_train_correct / epoch_train_total
    
    # Validation phase
    model.eval()
    epoch_val_loss = 0.0
    epoch_val_correct = 0
    epoch_val_total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            logits = model(images)
            loss = criterion(logits, labels)
            
            epoch_val_loss += loss.item()
            preds = torch.argmax(logits, dim=-1)
            epoch_val_correct += (preds == labels).sum().item()
            epoch_val_total += labels.size(0)
    
    # Calculate average validation metrics
    val_loss = epoch_val_loss / len(test_loader)
    val_acc = epoch_val_correct / epoch_val_total
    
    # Store metrics
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    
    # Update learning rate
    scheduler.step()
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{num_epochs}:")
    print(f"  Train - Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
    print(f"  Val   - Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")
    print(f"  Next-epoch LR: {scheduler.get_last_lr()[0]:.6f}")
    print(f"  {'-'*60}")

print(f"\n{'='*70}")
print(f"Training Complete!")
print(f"{'='*70}")
print(f"Final Training Accuracy: {train_accuracies[-1]:.4f}")
print(f"Final Validation Accuracy: {val_accuracies[-1]:.4f}")
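The loop averages `loss.item()` over batches. Since `CrossEntropyLoss` uses `reduction='mean'` by default, each value is already a per-sample mean, so a smaller final batch gets the same weight as a full one. The difference is usually small, but a sample-weighted average (multiply each batch loss by `labels.size(0)` and divide by the total sample count) gives the exact per-sample mean. The plain-Python sketch below contrasts the two with made-up numbers; both helper names are hypothetical.

```python
def batch_mean(losses):
    # What the training loop computes: every batch counts equally.
    return sum(losses) / len(losses)


def sample_weighted_mean(losses, sizes):
    # Exact per-sample mean when each loss is a mean over its own batch.
    return sum(loss * n for loss, n in zip(losses, sizes)) / sum(sizes)


losses = [1.0, 3.0]  # hypothetical per-batch mean losses
sizes = [32, 8]      # a full batch followed by a short final batch
print(batch_mean(losses), sample_weighted_mean(losses, sizes))  # 2.0 1.4
```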

## 9. Visualize Training Results

Plot training curves and analyze model performance.

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Epochs are 1-based to match the training log
epochs = range(1, len(train_losses) + 1)

# Loss plot
axes[0].plot(epochs, train_losses, label='Train Loss', marker='o', linewidth=2)
axes[0].plot(epochs, val_losses, label='Val Loss', marker='s', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Accuracy plot
axes[1].plot(epochs, train_accuracies, label='Train Acc', marker='o', linewidth=2)
axes[1].plot(epochs, val_accuracies, label='Val Acc', marker='s', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy', fontsize=12)
axes[1].set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 10. Final Evaluation

Evaluate the trained model on the test set.

In [None]:
# Final test evaluation
model.eval()
test_loss = 0.0
test_correct = 0
test_total = 0
all_preds = []
all_labels = []

print("Evaluating on test set...")
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        
        logits = model(images)
        loss = criterion(logits, labels)
        
        test_loss += loss.item()
        preds = torch.argmax(logits, dim=-1)
        test_correct += (preds == labels).sum().item()
        test_total += labels.size(0)
        
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

test_loss = test_loss / len(test_loader)
test_acc = test_correct / test_total

print(f"\n{'='*70}")
print(f"Final Test Results")
print(f"{'='*70}")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f} ({test_correct}/{test_total})")
print(f"{'='*70}")

# Per-class accuracy: compute it for every class, then show the 10 most accurate
correct_by_class = Counter()
total_by_class = Counter()

for pred, label in zip(all_preds, all_labels):
    total_by_class[label] += 1
    if pred == label:
        correct_by_class[label] += 1

class_accuracies = {cls: correct_by_class[cls] / total_by_class[cls]
                    for cls in total_by_class}
top_classes = sorted(class_accuracies.items(), key=lambda x: x[1], reverse=True)[:10]

print(f"\nPer-class accuracy (10 most accurate classes):")
for cls, acc in top_classes:
    print(f"  Class {cls}: {acc:.4f} ({correct_by_class[cls]}/{total_by_class[cls]})")

## 11. Summary

### Key Takeaways:

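1. **Parameter Efficiency**: TverskyReduceBackbone with feature sharing has fewer trainable parameters than the same backbone without sharing (see the comparison in Section 6)
2. **Performance**: Whether accuracy holds up should be judged from the training curves and test results above; a lower parameter count alone does not guarantee it
3. **Memory Efficiency**: Fewer parameters mean smaller weight and optimizer-state memory, which helps on resource-constrained devices
4. **Scalability**: The savings should grow as more layers reuse the same shared matrices, since each additional sharing layer adds no new feature-matrix parameters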

### TverskyReduceBackbone Benefits:

- **CNN + Tversky Architecture**: Combines convolutional feature extraction with Tversky similarity
- **Shared Feature Matrices**: Multiple layers share the same feature transformation via GlobalFeature
- **Shared Tversky Parameters**: Alpha and beta parameters are shared across layers
- **Two Variants**: Compact (efficient) or Interpretable (visualizable)
- **GlobalFeature Bank**: Centralized parameter storage for efficient sharing

### Next Steps:

- Experiment with different variants (compact vs interpretable)
- Try different backbone output dimensions
- Train the `TraditionalCNN` baseline with the same loop to compare accuracy, not just parameter count
- Fine-tune Tversky parameters (alpha, beta) for better performance
- Visualize learned prototypes (with interpretable variant)
