# Multilabel Classification of Tiger Beetles Using Deep Learning

This notebook implements taxonomic classification of tiger beetle specimens in the genus *Cicindela* using fastai and PyTorch. The approach uses a two-stage training strategy:

1. **Single-label training**: Train a model to predict the most specific taxonomic level available (species or subspecies) as a single concatenated label
2. **Multi-label transfer learning**: Transfer learned features to a multi-label model that can predict both species and subspecies separately

This methodology allows the model to learn hierarchical taxonomic relationships while handling cases where only species-level identification is possible. The approach follows methodologies established in taxonomic classification literature.

**Author**: B. de Medeiros, 2025

## Setup and Dependencies

Import required libraries and configure training parameters:

In [1]:
import os, gc
from pathlib import Path
import pandas as pd
from fastai.vision.all import *
from fastai.distributed import *
from accelerate import notebook_launcher
from timm.loss import AsymmetricLossMultiLabel
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_curve, auc, precision_score, recall_score
from collections import defaultdict, Counter
from IPython.display import display, HTML
import torch
from torch.utils.data import WeightedRandomSampler


export_path = Path("exported_fastai_models")
batch_size = 8
resize_dim = 448
n_workers = 4
p_memory = False
p_workers = False
arch = 'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k'
#arch = 'resnet18'

os.makedirs(export_path, exist_ok=True)
%matplotlib inline

# Define the paths to your data
train_path = Path("images/train")
valid_path = Path("images/valid")
test_path = Path("images/test")

# Function to create labels from directory names
def extract_labels(dir_name):
    """
    Extract both species and subspecies labels from directory names.
    For example:
    - abbreviata_var_baliensis -> ['abbreviata', 'abbreviata_var_baliensis']
    - alba_ -> ['alba']
    """
    parts = dir_name.rstrip('_').split('_')
    species = parts[0]
    
    labels = [species]
    if len(parts) > 1:
        # Include the full name as a subspecies label
        subspecies = '_'.join(parts)
        labels.append(subspecies)
    
    return labels

# Function to create a dataframe with file paths and labels
def create_dataframe(data_path):
    rows = []
    
    for dir_name in os.listdir(data_path):
        # Skip any non-directory items or cache files
        if not os.path.isdir(data_path/dir_name) or dir_name == 'labels.cache':
            continue
            
        # Get the labels for this directory
        labels = extract_labels(dir_name)
        
        # Get all image files in the directory
        img_files = list(Path(data_path/dir_name).glob('*.*'))
        for img_file in img_files:
            if img_file.suffix.lower() in ['.jpg', '.jpeg', '.png']:
                rows.append({
                    'path': img_file,
                    'labels': labels
                })
    
    return pd.DataFrame(rows)

def remove_redundant_subspecies(df):
    """
    Process the dataframe to remove redundant subspecies labels.
    
    A subspecies is considered redundant if:
    1. All instances of a species that have a subspecies label have the SAME subspecies
    2. The species always has a subspecies (no records with just the species)
    
    Args:
        df: DataFrame with a 'labels' column containing semicolon-separated labels
        
    Returns:
        DataFrame with redundant subspecies labels removed
    """
    # Create mappings to track species and their subspecies
    species_to_subspecies = {}  # Map species to set of associated subspecies
    species_counts = {}         # Count total occurrences of each species
    species_with_subspecies = {}  # Count how often a species has a subspecies
    
    # Process each row to build the mappings
    for _, row in df.iterrows():
        labels = row['labels']
        species = labels[0]
        
        # Update total species count
        species_counts[species] = species_counts.get(species, 0) + 1
        
        # If there's a subspecies, track it
        if len(labels) > 1:
            subspecies = labels[1]
            
            # Update subspecies tracking
            if species not in species_to_subspecies:
                species_to_subspecies[species] = set()
            species_to_subspecies[species].add(subspecies)
            
            # Update count of species with subspecies
            species_with_subspecies[species] = species_with_subspecies.get(species, 0) + 1
    
    # Find truly redundant subspecies (single subspecies AND all instances have it)
    redundant_subspecies = {}
    for species, subspecies_set in species_to_subspecies.items():
        # Check if this species only has one subspecies
        if len(subspecies_set) == 1:
            # Check if all instances of this species have a subspecies
            if species_counts[species] == species_with_subspecies[species]:
                redundant_subspecies[species] = next(iter(subspecies_set))
    
    # Process the dataframe to remove redundant subspecies
    updated_df = df.copy()
    
    for idx, row in updated_df.iterrows():
        labels = row['labels']
        
        if len(labels) > 1:
            species = labels[0]
            subspecies = labels[1]
            
            # If this is a redundant subspecies, remove it
            if species in redundant_subspecies and subspecies == redundant_subspecies[species]:
                updated_df.at[idx, 'labels'] = [species]
    
    # Print summary of removed redundant subspecies
    if redundant_subspecies:
        print(f"Removed {len(redundant_subspecies)} redundant subspecies labels:")
        for species, subspecies in redundant_subspecies.items():
            print(f"  - {subspecies} (always associated with {species})")
    else:
        print("No redundant subspecies labels found.")
        
    return updated_df

## Data Organization and Label Processing

This section sets up the data processing pipeline:

1. **Directory structure**: Images are organized in subdirectories named by taxonomic labels (e.g., "repanda_repanda", "sexguttata")
2. **Label extraction**: The `extract_labels()` function creates hierarchical labels from directory names:
   - Single names (e.g., "sexguttata") → species only: ["sexguttata"]  
   - Compound names (e.g., "repanda_repanda") → species + subspecies: ["repanda", "repanda_repanda"]
3. **Redundancy removal**: `remove_redundant_subspecies()` eliminates subspecies labels that provide no additional information beyond the species name
4. **Data merging**: Training and validation sets are combined with an `is_valid` flag for splitting

## Dataset Analysis and Statistics

Analyze the dataset composition and class distribution:

In [2]:
# Create dataframes for train and validation sets
train_df = create_dataframe(train_path)
valid_df = create_dataframe(valid_path)

# Combine into a single dataframe with a 'is_valid' column for splitting
df = pd.concat([
    train_df.assign(is_valid=False),
    valid_df.assign(is_valid=True)
], ignore_index=True)

# Remove redundant subspecies
df = remove_redundant_subspecies(df)

# ----- Add these statistics printing sections -----

# Print basic dataset statistics
print(f"Total images: {len(df)}")
print(f"Training images: {len(df[~df['is_valid']])}")
print(f"Validation images: {len(df[df['is_valid']])}")

# Expand the labels column into a list for analysis
all_labels = []
for label_str in df['labels']:
    all_labels.extend(label_str)

# Count unique labels
unique_labels = pd.Series(all_labels).unique()
print(f"Total unique labels: {len(unique_labels)}")
print(f"Species labels: {len([l for l in unique_labels if '_' not in l])}")
print(f"Subspecies labels: {len([l for l in unique_labels if '_' in l])}")

# Show label distribution
label_counts = pd.Series(all_labels).value_counts()
print("\nTop 10 most common labels:")
print(label_counts.head(10))
print("\nBottom 10 least common labels:")
print(label_counts.tail(10))

# Print examples of labels for verification
print("\nSample of extracted labels for verification:")
sample_rows = df.sample(min(5, len(df)))
for _, row in sample_rows.iterrows():
    print(f"Path: {row['path'].name}, Labels: {row['labels']}")

# Check for extreme cases - species with many or few images
species_counts = pd.Series([l.split('_')[0] for l in all_labels]).value_counts()
print("\nSpecies with most images:")
print(species_counts.head(5))
print("\nSpecies with fewest images:")
print(species_counts.tail(5))

# ----- End of statistics section -----

Removed 8 redundant subspecies labels:
  - princeps_ducalis (always associated with princeps)
  - aurulenta_aurulenta (always associated with aurulenta)
  - lacteola_lacteola (always associated with lacteola)
  - roseiventris_mexicana (always associated with roseiventris)
  - coerulea_nitida (always associated with coerulea)
  - clypeata_clypeata (always associated with clypeata)
  - varians_lecontei (always associated with varians)
  - soluta_soluta (always associated with soluta)
Total images: 6279
Training images: 4883
Validation images: 1396
Total unique labels: 248
Species labels: 141
Subspecies labels: 107

Top 10 most common labels:
repanda             644
punctulata          642
repanda_repanda     461
sexguttata          421
formosa             355
scutellaris         355
oregona             343
oregona_oregona     281
ocellata            257
formosa_generosa    245
Name: count, dtype: int64

Bottom 10 least common labels:
dysenterica                    1
hydrophoba_atroreduct

## Single-Label DataBlock Configuration

Configure the data processing pipeline for single-label training:

- **CategoryBlock**: Uses single categorical labels (most specific taxonomic level)
- **Image preprocessing**: Resize to 448px with padding to maintain aspect ratio
- **Augmentation**: Geometric transformations (rotation, zoom, warp) and photometric changes (lighting)
- **Normalization**: Uses ImageNet statistics for pretrained model compatibility

Note: Although we plan to do multi-label learning, we first train a single-label model to establish good feature representations.

In [3]:
# Define the DataBlock
#dblock_sl = DataBlock(
#    blocks=(ImageBlock, CategoryBlock),
#    get_x=ColReader('path'),
#    get_y=ColReader('labels'),
#    splitter=ColSplitter('is_valid'),
#    item_tfms=[Resize(int(resize_dim*1.5), method='pad', pad_mode='border'),
#               DihedralItem(p=0.5)],
#    batch_tfms=[
#        *aug_transforms(do_flip=False, 
#                        max_rotate=45, 
#                        max_warp=0.1,
#                        max_lighting=0.3, 
#                        min_zoom = 0.7,
#                        max_zoom = 1,
#                        p_lighting=0.2,
#                        p_affine=0.2,
#                        pad_mode='border'),
#        RandomResizedCrop(size = resize_dim,min_scale=0.2, max_scale=2,p=0.7),
#        Normalize.from_stats(*imagenet_stats, cuda=False)
#    ]
#)

## Class Balancing Through Weighted Sampling

Address severe class imbalance using weighted sampling:

1. **Weight calculation**: Compute inverse frequency weights for each label using square root scaling to prevent extreme values
2. **Sample weighting**: For multi-label samples, use the mean weight of constituent labels
3. **Normalization**: Scale weights so the minimum weight equals 1.0

This approach up-samples rare species/subspecies during training while down-sampling common ones, helping the model learn from all taxonomic groups.

In [4]:
# Calculate multilabel-aware weights
def calculate_multilabel_weights_for_weighted_dl(df):
    # Get training data only
    train_df = df[df['is_valid'] == False].copy()
    
    # Count frequency of each individual label across all samples
    all_labels = []
    for label_list in train_df['labels']:
        for label in label_list:
            if label == 'a':
                print(label_list)
                break
        all_labels.extend(label_list)
        
    
    print("a" in all_labels)
    print(label_counts)
    total_samples = len(train_df)
    
    # Calculate inverse frequency weights for each label
    label_weights = {}
    for label, count in label_counts.items():
        label_weights[label] = np.sqrt(total_samples / count)
    
    # Normalize weights to prevent extreme values
    min_weight = min(label_weights.values())
    label_weights = {k: v/min_weight for k, v in label_weights.items()}
    
    print("Individual label weights:")
    for label, weight in sorted(label_weights.items(), key=lambda x: x[1]):
        count = label_counts[label]
        print(f"  {label}: {count} occurrences (weight: {weight:.3f})")
    
    # Calculate sample weights (mean of constituent label weights)
    sample_weights = []
    for label_list in train_df['labels']:
        sample_label_weights = [label_weights[label] for label in label_list]
        sample_weight = np.mean(sample_label_weights)
        sample_weights.append(sample_weight)
    
    print(f"\nSample weight stats:")
    print(f"  Min: {min(sample_weights):.3f}")
    print(f"  Max: {max(sample_weights):.3f}")
    print(f"  Mean: {np.mean(sample_weights):.3f}")
    print(f"  Number of weights: {len(sample_weights)}")
    
    return sample_weights

# Calculate weights based on original multilabel data
sample_weights = calculate_multilabel_weights_for_weighted_dl(df)



False
repanda                        644
punctulata                     642
repanda_repanda                461
sexguttata                     421
formosa                        355
                              ... 
carthagena_colossea              1
tranquebarica_parallelonota      1
parowana_wallisi                 1
klugii                           1
nigrocoerulea_bowditchi          1
Name: count, Length: 248, dtype: int64
Individual label weights:
  repanda: 644 occurrences (weight: 1.000)
  punctulata: 642 occurrences (weight: 1.002)
  repanda_repanda: 461 occurrences (weight: 1.182)
  sexguttata: 421 occurrences (weight: 1.237)
  formosa: 355 occurrences (weight: 1.347)
  scutellaris: 355 occurrences (weight: 1.347)
  oregona: 343 occurrences (weight: 1.370)
  oregona_oregona: 281 occurrences (weight: 1.514)
  ocellata: 257 occurrences (weight: 1.583)
  formosa_generosa: 245 occurrences (weight: 1.621)
  tranquebarica: 245 occurrences (weight: 1.621)
  trifasciata: 241 occurrence

In [5]:
# Prepare the processed dataframe for single-label training
df_processed = df.assign(labels=df['labels'].apply(lambda x: '_'.join(x)))

In [6]:
# Create CPU-only dataloaders for visualization and batch inspection
#dls_cpu = dblock_sl.dataloaders(
#    df_processed,
#    bs=batch_size,
#    num_workers=0,  # CPU only
#    device='cpu'
#)

#print(f"\nCreated CPU dataloaders for visualization with {len(df_processed[~df_processed['is_valid']])} training samples")

In [7]:
# Display a batch of training images with their labels
#dls_cpu.show_batch()

## Single-Label Model Training with Distributed Data Parallel

Create a distributed training function that uses fastai's `distrib_ctx()` for multi-GPU training:

1. **Distributed setup**: Use `distrib_ctx()` to enable distributed data parallel training
2. **Architecture**: Vision transformer (EVA02-Large) pretrained on ImageNet  
3. **Optimization**: Mixed precision training (`to_fp16()`) for efficiency
4. **Regularization**: MixUp augmentation to improve generalization
5. **Metrics**: Standard accuracy and balanced accuracy to handle class imbalance

This approach automatically handles multi-GPU coordination and gradient synchronization.

In [None]:
def train_single_label(df_processed, 
                       sample_weights, 
                       resize_dim, 
                       batch_size,
                       n_workers,
                       p_memory,
                       
                       arch):
    
    class DistributedNormalize(Normalize):
        """FastAI Normalize class fixed for distributed training"""
        def encodes(self, x: TensorImage):
            # Move mean and std to the same device as input tensor
            device = x.device
            mean = self.mean.to(device) if self.mean.device != device else self.mean
            std = self.std.to(device) if self.std.device != device else self.std
            return (x - mean) / std

    
    # DataBlock with manual normalization
    dblock_sl = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_x=ColReader('path'),
        get_y=ColReader('labels'),
        splitter=ColSplitter('is_valid'),
        item_tfms=[
            Resize(int(resize_dim*1.5), method='pad', pad_mode='border'),
            DihedralItem(p=0.5)
        ],
        batch_tfms=[
            *aug_transforms(
                do_flip=False, 
                max_rotate=45, 
                max_warp=0.1,
                max_lighting=0.3, 
                min_zoom=0.7,
                max_zoom=1,
                p_lighting=0.2,
                p_affine=0.2,
                pad_mode='border'
            ),
            RandomResizedCropGPU(size=resize_dim, min_scale=0.2, max_scale=2, p=0.7),
            DistributedNormalize.from_stats(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
) 
        ]
    )
    
    dls_sl = dblock_sl.weighted_dataloaders(
        df_processed,
        wgts=sample_weights,
        bs=batch_size,
        num_workers=n_workers,
        pin_memory=p_memory,
        persistent_workers=p_workers,
        prefetch_factor=6,
        drop_last=True
    )
    
    # Create learner 
    learn_sl = vision_learner(
        dls_sl, 
        arch,
        metrics=[accuracy, BalancedAccuracy()],
        normalize=False, 
        cbs=[MixUp()]
    ).to_fp16()
    
    with learn_sl.distrib_ctx(in_notebook=True, sync_bn=True):
        learn_sl.freeze()
        base_lr = learn_sl.lr_find()
        learn_sl.fine_tune(epochs=5, freeze_epochs=5, base_lr=base_lr.valley)
    
    learn_sl.export(f"model_fixed_{arch}.pkl")
    return learn_sl
    
# Run it
notebook_launcher(train_single_label, (df_processed, sample_weights, resize_dim, batch_size, n_workers, p_memory, arch), num_processes=4)

Launching training on 4 GPUs.
Training Learner...


  self.autocast,self.learn.scaler,self.scales = autocast(dtype=dtype),GradScaler(**self.kwargs),L()
  self.autocast,self.learn.scaler,self.scales = autocast(dtype=dtype),GradScaler(**self.kwargs),L()
  self.autocast,self.learn.scaler,self.scales = autocast(dtype=dtype),GradScaler(**self.kwargs),L()
  self.autocast,self.learn.scaler,self.scales = autocast(dtype=dtype),GradScaler(**self.kwargs),L()
  self.autocast,self.learn.scaler,self.scales = autocast(dtype=dtype),GradScaler(**self.kwargs),L()
  self.autocast,self.learn.scaler,self.scales = autocast(dtype=dtype),GradScaler(**self.kwargs),L()
  self.autocast,self.learn.scaler,self.scales = autocast(dtype=dtype),GradScaler(**self.kwargs),L()
  self.autocast,self.learn.scaler,self.scales = autocast(dtype=dtype),GradScaler(**self.kwargs),L()


## Multi-Label Model Setup and Transfer Learning with Distributed Training

Configure and initialize the multi-label model with distributed data parallel:

1. **Architecture change**: Switch from CategoryBlock to MultiCategoryBlock for multi-label classification
2. **Loss function**: Use AsymmetricLossMultiLabel optimized for imbalanced multi-label problems
3. **Metrics**: Micro-averaged precision, recall, and ROC-AUC for multi-label evaluation
4. **Transfer learning**: Load compatible weights from the single-label model to initialize features
5. **Weight filtering**: Only transfer layers with matching dimensions between models
6. **Distributed training**: Use `distrib_ctx()` for multi-GPU coordination

This leverages the hierarchical features learned in single-label training for multi-label prediction.

In [ ]:
def train_multi_label(df, 
                      sample_weights, 
                      resize_dim, 
                      batch_size,
                      n_workers,
                      p_memory,
                      arch):
    """Distributed training function for multi-label model"""
    
    class DistributedNormalize(Normalize):
        """FastAI Normalize class fixed for distributed training"""
        def encodes(self, x: TensorImage):
            # Move mean and std to the same device as input tensor
            device = x.device
            mean = self.mean.to(device) if self.mean.device != device else self.mean
            std = self.std.to(device) if self.std.device != device else self.std
            return (x - mean) / std
    
    # Define the multi-label DataBlock inside the function
    dblock_ml = DataBlock(
        blocks=(ImageBlock, MultiCategoryBlock),
        get_x=ColReader('path'),
        get_y=ColReader('labels'),
        splitter=ColSplitter('is_valid'),
        item_tfms=[
            Resize(int(resize_dim*1.5), method='pad', pad_mode='border'),
            DihedralItem(p=0.5)
        ],
        batch_tfms=[
            *aug_transforms(
                do_flip=False, 
                max_rotate=45, 
                max_warp=0.1,
                max_lighting=0.3, 
                min_zoom=0.7,
                max_zoom=1,
                p_lighting=0.2,
                p_affine=0.2,
                pad_mode='border'
            ),
            RandomResizedCropGPU(size=resize_dim, min_scale=0.2, max_scale=2, p=0.7),
            DistributedNormalize.from_stats(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ]
    )

    # Create weighted dataloaders inside the training function
    dls_ml = dblock_ml.weighted_dataloaders(
        df, 
        wgts=sample_weights,
        bs=batch_size, 
        num_workers=n_workers,
        pin_memory=p_memory,
        persistent_workers=p_workers,
        prefetch_factor=6,
        drop_last=True
    )

    # Create learner
    learn_ml = vision_learner(
        dls_ml, 
        arch,
        metrics=[PrecisionMulti(average='micro', thresh=0.5),
                RecallMulti(average='micro', thresh=0.5),
                RocAuc(average='micro')],
        loss_func=AsymmetricLossMultiLabel(gamma_neg=2,
                                          gamma_pos=2,
                                          clip=0.05),
        normalize=False,
        cbs=[MixUp()]
    ).to_fp16()

    # Transfer weights from single-label model
    learn_sl = load_learner(f"model_fixed_{arch}.pkl").detach_parallel()
    source_state_dict = learn_sl.model.state_dict()
    target_state_dict = learn_ml.model.state_dict()

    # Only load layers with matching shapes
    filtered_state_dict = {}
    for k, v in source_state_dict.items():
        if k in target_state_dict and target_state_dict[k].shape == v.shape:
            filtered_state_dict[k] = v

    learn_ml.model.load_state_dict(filtered_state_dict, strict=False)

    # Clean up single-label model
    del(learn_sl)
    gc.collect()
    torch.cuda.empty_cache()

    # Use distributed context manager for multi-GPU training
    with learn_ml.distrib_ctx(in_notebook=True, sync_bn=True):
        # Freeze the pretrained layers
        learn_ml.freeze()

        # Find optimal learning rate
        base_lr = learn_ml.lr_find()

        # Fine-tune the model
        learn_ml.fine_tune(epochs=10, freeze_epochs=3, base_lr=base_lr.valley)

    # Export the model after distributed training completes
    learn_ml.export(export_path/f"{arch}_ml.pkl")
    return learn_ml

# Launch distributed training for multi-label model with consistent parameters
notebook_launcher(train_multi_label, (df, sample_weights, resize_dim, batch_size, n_workers, p_memory, arch), num_processes=4)

## Multi-Label Model Training Complete

The multi-label model training is now complete and handled entirely within the distributed training function above. The model has been exported and is ready for evaluation.

## Test Set Evaluation

Evaluate the final model on the held-out test set:

1. **Model loading**: Load the trained multi-label model
2. **Test data preparation**: Process test images with same preprocessing pipeline
3. **Vocabulary filtering**: Ensure test labels match the model's vocabulary
4. **Test dataloader**: Create dataloader for test set with labels for evaluation

In [None]:
learn = load_learner(export_path/f"{arch}_ml.pkl", cpu=False)

test_df = create_dataframe(test_path)

vocab_list = learn.dls.vocab
def filter_labels(label_list, vocab):
    return [label for label in label_list if label in vocab]
test_df['labels'] = test_df['labels'].apply(lambda x: filter_labels(x, vocab_list))
test_df['is_valid'] = True

test_dl = learn.dls.test_dl(test_df, with_labels = True)

## Model Predictions

Generate predictions on the test set for detailed analysis:

In [None]:
# Get predictions and targets
preds, targs = learn.get_preds(dl = test_dl, with_loss=False, with_targs=True, act = torch.sigmoid)

# Detailed Performance Analysis by Taxonomic Class

Calculate comprehensive evaluation metrics for each taxonomic class:

1. **Binary classification metrics**: Convert multi-label predictions to binary decisions using 0.5 threshold
2. **Per-class evaluation**: Calculate precision, recall, true/false positives/negatives for each class
3. **Training set mapping**: Include training set counts to understand data availability
4. **Confusion analysis**: Identify which classes are commonly confused with each other
5. **Results formatting**: Create publication-ready table with performance metrics

This analysis reveals model performance across the taxonomic hierarchy and identifies challenging classification cases.

In [None]:
pred_threshold = 0.5
# Get vocabulary
vocab = learn.dls.vocab
if isinstance(vocab, list):
    vocab = vocab[-1]  # As in ClassificationInterpretation.__init__

# Get training set counts for each term from train_df
train_counts = {}

# Initialize counts for all classes
for i in range(len(vocab)):
    train_counts[i] = 0

# Count occurrences in training DataFrame
for labels_list in train_df['labels']:
    # Convert string representation to actual list if needed
    if isinstance(labels_list, str):
        # Remove brackets and quotes, then split
        labels_list = labels_list.strip('[]').replace("'", "").split(', ')
    
    # Count each label
    for label in labels_list:
        label = label.strip()  # Remove any whitespace
        # Use vocab.o2i to get the index instead of .index()
        if hasattr(vocab, 'o2i') and label in vocab.o2i:
            idx = vocab.o2i[label]
            train_counts[idx] += 1
        elif label in vocab:
            # Fallback if o2i doesn't exist
            for i, v in enumerate(vocab):
                if v == label:
                    train_counts[i] += 1
                    break

# Determine if this is multi-label or multi-class
is_multilabel = len(targs.shape) > 1

# Convert tensors to numpy
preds_np = preds.cpu().numpy()
targs_np = targs.cpu().numpy()

# For multi-label, threshold predictions
if is_multilabel:
    preds_binary = (preds_np > pred_threshold).astype(int)
else:
    # For multi-class, convert to one-hot encoding
    preds_binary = np.zeros_like(preds_np)
    for i in range(len(preds_np)):
        preds_binary[i, np.argmax(preds_np[i])] = 1

# Dictionary to track confusions
confusions = defaultdict(list)

# Calculate confusions - for each sample, identify false positives when true labels are present
for i in range(len(targs_np)):
    if is_multilabel:
        # For each sample, get true labels and predicted labels
        true_labels = np.where(targs_np[i] == 1)[0]
        pred_labels = np.where(preds_binary[i] == 1)[0]
        
        # For each true label, find false positives
        for true_idx in true_labels:
            false_pos_indices = [idx for idx in pred_labels if idx != true_idx and idx not in true_labels]
            if false_pos_indices:  # If there are false positives
                confusions[true_idx].extend(false_pos_indices)
    else:
        # For multi-class, the true label is confused with the predicted label if they differ
        true_label = np.argmax(targs_np[i]) if len(targs_np.shape) > 1 else targs_np[i]
        pred_label = np.argmax(preds_np[i])
        if true_label != pred_label:
            confusions[true_label].append(pred_label)

# Create a list to store results
results = []

# Calculate metrics for each class
for i in range(preds.shape[1]):
    # Get class name
    if i < len(vocab):
        class_name = vocab[i]
    else:
        class_name = f"Class_{i}"
    
    # Extract binary predictions and targets for this class
    if is_multilabel:
        pred_i = preds_binary[:, i]
        targ_i = targs_np[:, i]
    else:
        # For multi-class, create one-hot encoding
        pred_i = (np.argmax(preds_np, axis=1) == i).astype(int)
        targ_i = (targs_np == i).astype(int) if len(targs_np.shape) == 1 else (np.argmax(targs_np, axis=1) == i).astype(int)
    
    # Calculate N (number of samples with this label in test/validation set)
    N = targ_i.sum()
    
    # Get training count for this class
    N_train = int(train_counts.get(i, 0))
    
    # Skip if N is zero
    if N == 0:
        continue
    
    # Calculate metrics
    true_positives = np.sum((pred_i == 1) & (targ_i == 1))
    false_positives = np.sum((pred_i == 1) & (targ_i == 0))
    false_negatives = np.sum((pred_i == 0) & (targ_i == 1))
    
    precision = true_positives / (true_positives + false_positives) if true_positives + false_positives > 0 else 1
    recall = true_positives / (true_positives + false_negatives) if true_positives + false_negatives > 0 else 0
    
    # Process confusions for this class
    confused_with = []
    if i in confusions and confusions[i]:
        # Count occurrences of each confusion
        confusion_counts = {}
        for idx in confusions[i]:
            if idx < len(vocab):
                label = vocab[idx]
                confusion_counts[label] = confusion_counts.get(label, 0) + 1
            else:
                label = f"Class_{idx}"
                confusion_counts[label] = confusion_counts.get(label, 0) + 1
        
        # Sort by count and format as "label (count)"
        confused_with = [f"{label} ({count})" for label, count in 
                         sorted(confusion_counts.items(), key=lambda x: x[1], reverse=True)]
    
    # Add to results
    results.append({
        'Term': class_name,
        'N_train': N_train,
        'N': int(N),
        'micro_precision': precision,
        'micro_recall': recall,
        'true_positives': int(true_positives),
        'false_positives': int(false_positives),
        'false_negatives': int(false_negatives),
        'confused_with': ', '.join(confused_with) if confused_with else ''
    })

# Create DataFrame and sort by micro_precision first, then by N (both descending)
df = pd.DataFrame(results)
df = df.sort_values(['micro_precision', 'N'], ascending=[False, False])

# Format DataFrame for display
pd.set_option('display.float_format', '{:.4f}'.format)  # Format floats to 4 decimal places

# Format precision and recall as percentages
df['micro_precision'] = df['micro_precision'] * 100
df['micro_recall'] = df['micro_recall'] * 100

# Create a styled version of the DataFrame
from IPython.display import display, HTML

# Style the DataFrame
styled_df = df.style.format({
    'N_train': '{:,d}',
    'N': '{:,d}',
    'micro_precision': '{:.2f}%',
    'micro_recall': '{:.2f}%',
    'true_positives': '{:,d}',
    'false_positives': '{:,d}',
    'false_negatives': '{:,d}'
}).background_gradient(subset=['micro_precision', 'micro_recall'], cmap='RdYlGn')

# Add a title and display
display(HTML("<h2>Model Performance by Term</h2>"))
display(styled_df)

In [None]:
df.to_csv(f"test_results_{arch}.csv")

# Evaluation in unknown samples

To finalize, let's try to identify the unknown samples

In [None]:
unknown_df = pd.DataFrame({'path':get_image_files('unknowns'),'labels':None,'is_valid':True})
test_dl = learn.dls.test_dl(unknown_df, with_labels = False)
unknown_preds, _ = learn.get_preds(dl = test_dl, with_loss=False, with_targs=False, act=torch.sigmoid)

In [None]:
# Convert predictions to binary predictions
binary_preds = (unknown_preds > 0.5).int()

# Use fastai's decode method if available
try:
    pred_labels = [learn.dls.decode(row) for row in binary_preds]
    unknown_df['predictions'] = pred_labels
except:
    # Fallback to manual method
    pred_labels = []
    for row in binary_preds:
        indices = torch.nonzero(row, as_tuple=True)[0]
        labels = [learn.dls.vocab[i] for i in indices.tolist()]
        pred_labels.append(labels)
    unknown_df['predictions'] = pred_labels

In [None]:
unknown_df