# Backbone Feature Extraction Experiment

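This notebook compares pretrained CNN and ViT backbones as fixed feature extractors for the EM axon dataset: each backbone's pooled features are used to fit a logistic-regression classifier for the pathology label, and held-out accuracy is reported per backbone.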

## Setup
Load dependencies and define paths.

In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from src.data.datasets import EMAxonDataset
from src.data.transforms import get_em_transforms, get_backbone_image_size

# Location of the metadata table. The notebook expects it to provide the
# columns: filepath, pathology, region, depth.
metadata_csv = 'data/metadata.csv'

# Read the whole table into a DataFrame and report how many rows it holds.
df = pd.read_csv(filepath_or_buffer=metadata_csv)
print(f'Loaded {len(df)} samples')

## Helper functions

In [None]:
import timm

def extract_features(backbone_name, dataloader, device=None):
    """Extract pooled backbone features and pathology labels.

    Creates a pretrained ``timm`` model with ``num_classes=0`` and
    ``global_pool='avg'``, switches it to eval mode, and runs it over every
    batch of ``dataloader`` with gradient tracking disabled.

    Args:
        backbone_name: Model name forwarded to ``timm.create_model``.
        dataloader: Iterable yielding ``(imgs, lbls)`` batches, where
            ``lbls['pathology']`` is a tensor of labels for the batch.
        device: Optional device (string or ``torch.device``) to run the
            model on. When ``None``, CUDA is used if available and the CPU
            otherwise, so the function also works on machines without a GPU.

    Returns:
        A ``(feats, labels)`` tuple of NumPy arrays: the features tensor
        (N, F) and the label vector for the pathology task, in the order the
        dataloader produced them.
    """
    # Fall back to the CPU instead of failing when no CUDA device is present.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    model = timm.create_model(backbone_name, pretrained=True, num_classes=0, global_pool='avg')
    model.eval()
    model.to(device)

    feats = []
    labels = []
    with torch.no_grad():
        for imgs, lbls in dataloader:
            imgs = imgs.to(device)
            out = model(imgs)
            # Move each batch back to host memory right away so accumulated
            # features do not pile up on the accelerator.
            feats.append(out.cpu())
            labels.append(lbls['pathology'])
    feats = torch.cat(feats).numpy()
    labels = torch.cat(labels).numpy()
    return feats, labels

## Feature extraction and evaluation

In [None]:
# Backbones to compare; each name is passed to extract_features and to
# get_backbone_image_size.
backbones = ['efficientnet_b4', 'resnet50', 'vit_base_patch16_224']
results = {}

for name in backbones:
    print(f'\nProcessing {name}')

    # Build a non-training transform pipeline at the image size reported for
    # this backbone, then iterate the dataset without shuffling.
    size = get_backbone_image_size(name)
    eval_transform = get_em_transforms(image_size=size, is_training=False)
    dataset = EMAxonDataset(df, transform=eval_transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)
    X, y = extract_features(name, loader)

    # Hold out 20% of the samples with a fixed seed, fit a logistic-regression
    # classifier on the remaining features, and score it on the held-out part.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
    preds = clf.predict(X_test)
    acc = accuracy_score(y_test, preds)
    results[name] = acc
    print(f'Accuracy: {acc:.4f}')

# Recap the accuracy of every backbone in the order it was processed.
print('\nSummary:')
for k, v in results.items():
    print(f'{k}: {v:.4f}')