# NARadio vs RADIO Comparison

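This notebook compares the dense feature maps produced by the standard RADIO encoder and by NARadio, a variant that uses Gaussian attention to improve the spatial structure of its features. Both encoders process the same 512×512 image. Each resulting feature map is visualized by projecting it onto its top three principal components and displaying those as an RGB image.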

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import requests
from io import BytesIO

from vlm_eval import EncoderRegistry

# Check for timm
try:
    import timm
except ImportError:
    print("Installing timm...")
    !pip install timm

In [None]:
# Load an example image from COCO val2017 and resize it to 512x512, matching the
# input_size the encoders are created with (note: this ignores the original aspect ratio).
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
response = requests.get(url, timeout=30)
response.raise_for_status()  # fail loudly instead of handing an HTTP error page to PIL
image = Image.open(BytesIO(response.content)).convert("RGB")
image = image.resize((512, 512))

fig, ax = plt.subplots()
ax.imshow(image)
ax.set_title("Input image")
ax.axis("off")
plt.show()

In [None]:
# Prepare input tensor: PIL RGB image -> (H, W, 3) uint8 array -> (1, 3, H, W) float tensor in [0, 1].
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
input_tensor = input_tensor.unsqueeze(0).to(device)

print(f"Input shape: {input_tensor.shape}")

In [None]:
# Initialize encoders on the same device as the input tensor, and switch them to
# eval mode so train-only behaviour (e.g. dropout) is disabled during feature extraction.
print("Loading RADIO...")
radio = (
    EncoderRegistry.get("radio", variant="base", pretrained=True, input_size=512)
    .to(input_tensor.device)
    .eval()
)

print("Loading NARadio...")
naradio = (
    EncoderRegistry.get("naradio", model_version="radio_v2.5-b", lang_model="siglip", input_size=512)
    .to(input_tensor.device)
    .eval()
)


In [None]:
# Run both encoders on the same input; gradients are not needed for feature extraction.
with torch.no_grad():
    radio_features, naradio_features = tuple(
        encoder(input_tensor) for encoder in (radio, naradio)
    )

for encoder_label, encoder_output in (("RADIO", radio_features), ("NARadio", naradio_features)):
    print(f"{encoder_label} features shape: {encoder_output.shape}")

In [None]:
# Visualize PCA of features
from sklearn.decomposition import PCA

def visualize_pca(features, title, batch_index=0):
    """Show the top three principal components of a dense feature map as an RGB image.

    The C-dimensional feature vector at every spatial location of one batch element is
    projected onto the first three principal components (fit on that element's
    locations). Each component is min-max scaled to [0, 1], and the result is displayed
    with matplotlib.

    Args:
        features: Tensor of shape (B, C, H, W).
        title: Title of the figure.
        batch_index: Which batch element to visualize (default: 0).

    Raises:
        ValueError: If ``features`` is not 4-dimensional.
    """
    if features.ndim != 4:
        raise ValueError(
            f"expected features of shape (B, C, H, W), got {tuple(features.shape)}"
        )
    _, C, H, W = features.shape

    # (C, H, W) -> (H*W, C). Cast to float32 before .numpy(): NumPy has no bfloat16 dtype.
    features_flat = (
        features[batch_index].detach().permute(1, 2, 0).reshape(-1, C).float().cpu().numpy()
    )

    # Fixed random_state keeps the projection reproducible if sklearn picks a randomized solver.
    pca = PCA(n_components=3, random_state=0)
    pca_features = pca.fit_transform(features_flat)

    # Normalize each component to [0, 1]; guard constant components against division by zero.
    mins = pca_features.min(axis=0)
    spans = pca_features.max(axis=0) - mins
    spans[spans == 0] = 1.0
    pca_img = ((pca_features - mins) / spans).reshape(H, W, 3)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(pca_img)
    ax.set_title(title)
    ax.axis("off")
    plt.show()

# Render the PCA view of each encoder's features, RADIO first, then NARadio.
for feature_map, plot_title in (
    (radio_features, "RADIO Features PCA"),
    (naradio_features, "NARadio Features PCA"),
):
    visualize_pca(feature_map, plot_title)