# Eventformer: Frame-Free Vision Transformer

Run this notebook on Google Colab with a T4 GPU: in the menu, go to Runtime > Change runtime type and select T4 GPU.

In [None]:
# Show the GPU attached to this runtime (shell command run through the notebook's `!` escape).
!nvidia-smi
import torch
# Confirm that PyTorch itself can see a CUDA device; a later cell falls back to CPU when it cannot.
print(f'CUDA available: {torch.cuda.is_available()}')

In [None]:
# Start from a clean checkout: delete any copy left by a previous run, then clone again.
!rm -rf Eventformer
!git clone https://github.com/jkinarthur/Eventformer.git
# Work from the repository's code/ directory; later cells use paths relative to it
# ('.' on sys.path, '../data/...', '../figures/...').
%cd Eventformer/code

In [None]:
# Install extra Python packages with pip; -q keeps the installer output short.
# NOTE(review): presumably these are the repository's dependencies that the
# runtime does not already provide — confirm against the repo's requirements.
!pip install -q einops timm h5py tensorboard seaborn

In [None]:
import sys
# Put the current working directory first on the import path so that the
# `models` package imported below is resolved from here.
sys.path.insert(0, '.')

# Import check: each line raises ImportError immediately if the corresponding
# module or name is missing.
from models.ctpe import ContinuousTimePositionalEncoding
from models.paaa import PolarityAwareAsymmetricAttention
from models.asna import ASNABlock
from models import eventformer_tiny, EventformerForClassification

# Only reached when every import above succeeded.
print('All imports successful!')

In [None]:
# Smoke test for the backbone: build the tiny variant and run one random batch
# through it on the GPU when available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = eventformer_tiny().to(device)
# Switch to evaluation mode so that any dropout / normalisation layers use
# their inference behaviour during this forward-only check.
model.eval()
num_params = sum(p.numel() for p in model.parameters())
print(f'Parameters: {num_params:,}')

# Synthetic event stream: B samples with N events each.
B, N = 4, 2048
# Two coordinates per event, drawn uniformly from [0, 346).
# NOTE(review): 346 looks like a sensor resolution — confirm the intended range.
coords = torch.rand(B, N, 2, device=device) * 346
# Timestamps in [0, 1), sorted in ascending order along the event axis.
times = torch.rand(B, N, device=device).sort(dim=1)[0]
# Random polarities mapped from {0, 1} to {-1.0, +1.0}.
pols = torch.randint(0, 2, (B, N), device=device).float() * 2 - 1

# No gradients are needed for a forward-only check.
with torch.no_grad():
    out = model(coords, times, pols)
# NOTE(review): indexing out[0] assumes the model returns a sequence whose
# first element is a tensor — confirm against the model's forward().
print(f'Output shape: {out[0].shape}')
print('Model test passed!')

In [None]:
# NOTE(review): `datasets` is expected to resolve to the repository's local
# package (the working directory is first on sys.path), not a third-party
# package with the same name — confirm.
from datasets import get_dataset
from torch.utils.data import DataLoader

# Make sure the dataset directory exists before passing it to get_dataset.
!mkdir -p ../data/ncaltech101
# Build the training split.
# NOTE(review): num_events=2048 presumably fixes the number of events per
# sample; whether missing data is downloaded or must already be present under
# the given directory is not visible here — confirm.
train_data = get_dataset('ncaltech101', '../data/ncaltech101', split='train', num_events=2048)
print(f'Samples: {len(train_data)}, Classes: {train_data.num_classes}')

# Pull one shuffled batch of 8 to check the batch format: a pair of
# (inputs indexed by name, labels). `loader` is reused by the later training
# and ablation cells.
loader = DataLoader(train_data, batch_size=8, shuffle=True)
batch_inputs, batch_labels = next(iter(loader))
print(f'Batch coords shape: {batch_inputs["coords"].shape}')
print('Dataset test passed!')

In [None]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Training demo: fit a small classification model for a few epochs on the
# training loader and report the running loss / accuracy per epoch.
cls_model = EventformerForClassification(
    num_classes=train_data.num_classes,
    embed_dim=32,
    depths=(2, 2, 4, 2),
    num_heads=(1, 2, 4, 8)
).to(device)

optimizer = optim.AdamW(cls_model.parameters(), lr=1e-3, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()

for epoch in range(5):
    cls_model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for batch_inputs, batch_labels in tqdm(loader, desc=f'Epoch {epoch+1}'):
        coords = batch_inputs['coords'].to(device)
        times = batch_inputs['times'].to(device)
        pols = batch_inputs['polarities'].to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = cls_model(coords, times, pols)
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()

        batch_size = batch_labels.size(0)
        # criterion returns the mean over the batch; weight it by the batch
        # size so a smaller final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size
        _, pred = logits.max(1)
        total += batch_size
        correct += pred.eq(batch_labels).sum().item()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    seen = max(total, 1)
    acc = 100.0 * correct / seen
    avg_loss = total_loss / seen
    print(f'Epoch {epoch+1}: Loss={avg_loss:.4f}, Acc={acc:.1f}%')

print('Training complete!')

In [None]:
# Ablation study: train one model per configuration, each with exactly one of
# the three components switched off (plus the full model as reference).
configs = {
    'Full': {'use_ctpe': True, 'use_paaa': True, 'use_asna': True},
    'No CTPE': {'use_ctpe': False, 'use_paaa': True, 'use_asna': True},
    'No PAAA': {'use_ctpe': True, 'use_paaa': False, 'use_asna': True},
    'No ASNA': {'use_ctpe': True, 'use_paaa': True, 'use_asna': False}
}

# Fixed seed applied before every variant (see the loop below).
ABLATION_SEED = 0


def _forward(net, inp):
    """Move one batch of named inputs to `device` and return the model output."""
    return net(
        inp['coords'].to(device),
        inp['times'].to(device),
        inp['polarities'].to(device),
    )


results = {}
for name, cfg in configs.items():
    # Re-seed before each variant so its result is repeatable across notebook
    # runs and does not depend on how much randomness earlier variants used.
    torch.manual_seed(ABLATION_SEED)
    m = EventformerForClassification(
        num_classes=train_data.num_classes,
        embed_dim=32,
        depths=(2, 2, 4, 2),
        num_heads=(1, 2, 4, 8),
        **cfg
    ).to(device)
    opt = optim.AdamW(m.parameters(), lr=1e-3)

    # Short training run: 3 passes over the training loader.
    m.train()
    for _ in range(3):
        for inp, lab in loader:
            opt.zero_grad()
            out = _forward(m, inp)
            loss = criterion(out, lab.to(device))
            loss.backward()
            opt.step()

    # NOTE: accuracy is measured on the same training loader, because this
    # notebook loads no held-out split; treat the numbers as a sanity check,
    # not as generalisation performance.
    m.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inp, lab in loader:
            pred = _forward(m, inp).argmax(1)
            correct += pred.eq(lab.to(device)).sum().item()
            total += lab.size(0)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    results[name] = 100.0 * correct / max(total, 1)
    print(f'{name}: {results[name]:.1f}%')

print('Ablation complete!')

In [None]:
import os

import matplotlib.pyplot as plt

# Bar chart of the ablation accuracies, one bar per configuration, with the
# value written above each bar.
fig, ax = plt.subplots(figsize=(8, 5))
colors = ['#2E86AB', '#E94F37', '#F39C12', '#27AE60']
bars = ax.bar(results.keys(), results.values(), color=colors)
ax.set_ylabel('Accuracy (%)')
ax.set_title('Ablation Study')
ax.set_ylim(0, 100)
for bar, acc in zip(bars, results.values()):
    height = bar.get_height()
    # Centre the label horizontally on the bar, 5 points above its top.
    ax.annotate(f'{acc:.1f}%',
                xy=(bar.get_x() + bar.get_width()/2, height),
                xytext=(0, 5),
                textcoords='offset points',
                ha='center')
plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists.
os.makedirs('../figures', exist_ok=True)
plt.savefig('../figures/ablation.png', dpi=150)
plt.show()
print('Figure saved!')

# Eventformer: Frame-Free Vision Transformer

Run on Google Colab with T4 GPU.

**Instructions:** In the Colab menu, choose Runtime -> Change runtime type and select T4 GPU as the hardware accelerator.

In [None]:
# Check GPU: show the device attached to this runtime (shell command run
# through the notebook's `!` escape).
!nvidia-smi
import torch
# Confirm that PyTorch itself can see a CUDA device; a later cell falls back to CPU when it cannot.
print(f'CUDA: {torch.cuda.is_available()}')

In [None]:
# Clone repository: delete any copy left by a previous run first so the clone
# always starts from a clean directory.
!rm -rf Eventformer
!git clone https://github.com/jkinarthur/Eventformer.git
# Work from the repository's code/ directory; later cells use paths relative to it
# ('.' on sys.path, '../data/...', '../figures/...').
%cd Eventformer/code

In [None]:
# Install dependencies with pip; -q keeps the installer output short.
# NOTE(review): presumably these are the repository's dependencies that the
# runtime does not already provide — confirm against the repo's requirements.
!pip install -q einops timm h5py tensorboard seaborn

In [None]:
# Test imports: verify that the model modules can be imported.
import sys
# Put the current working directory first on the import path so that the
# `models` package imported below is resolved from here.
sys.path.insert(0, '.')

# Each line raises ImportError immediately if the corresponding module or
# name is missing.
from models.ctpe import ContinuousTimePositionalEncoding
from models.paaa import PolarityAwareAsymmetricAttention
from models.asna import ASNABlock
from models import eventformer_tiny, EventformerForClassification

# Only reached when every import above succeeded.
print('All imports successful!')

In [None]:
# Test model: build the tiny variant and run one random batch through it on
# the GPU when available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = eventformer_tiny().to(device)
# Switch to evaluation mode so that any dropout / normalisation layers use
# their inference behaviour during this forward-only check.
model.eval()
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')

# Forward pass on a synthetic event stream: B samples with N events each.
B, N = 4, 2048
# Two coordinates per event, drawn uniformly from [0, 346).
# NOTE(review): 346 looks like a sensor resolution — confirm the intended range.
coords = torch.rand(B, N, 2, device=device) * 346
# Timestamps in [0, 1), sorted in ascending order along the event axis.
times = torch.rand(B, N, device=device).sort(dim=1)[0]
# Random polarities mapped from {0, 1} to {-1.0, +1.0}.
pols = torch.randint(0, 2, (B, N), device=device).float() * 2 - 1

# No gradients are needed for a forward-only check.
with torch.no_grad():
    out = model(coords, times, pols)
# NOTE(review): indexing out[0] assumes the model returns a sequence whose
# first element is a tensor — confirm against the model's forward().
print(f'Output shape: {out[0].shape}')
print('Model test passed!')

In [None]:
# Test dataset: build the training split and inspect one batch.
# NOTE(review): `datasets` is expected to resolve to the repository's local
# package (the working directory is first on sys.path), not a third-party
# package with the same name — confirm.
from datasets import get_dataset
from torch.utils.data import DataLoader

# Make sure the dataset directory exists before passing it to get_dataset.
!mkdir -p ../data/ncaltech101
# NOTE(review): num_events=2048 presumably fixes the number of events per
# sample; whether missing data is downloaded or must already be present under
# the given directory is not visible here — confirm.
train_data = get_dataset('ncaltech101', '../data/ncaltech101', split='train', num_events=2048)
print(f'Samples: {len(train_data)}, Classes: {train_data.num_classes}')

# Pull one shuffled batch of 8 to check the batch format: a pair of
# (inputs indexed by name, labels). `loader` is reused by the later training
# and ablation cells.
loader = DataLoader(train_data, batch_size=8, shuffle=True)
inputs, labels = next(iter(loader))
print(f'Batch coords: {inputs["coords"].shape}')
print('Dataset test passed!')

In [None]:
# Training demo: fit a small classification model for a few epochs on the
# training loader and report the running loss / accuracy per epoch.
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

cls_model = EventformerForClassification(
    num_classes=train_data.num_classes,
    embed_dim=32,
    depths=(2, 2, 4, 2),
    num_heads=(1, 2, 4, 8)
).to(device)

optimizer = optim.AdamW(cls_model.parameters(), lr=1e-3, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()

for epoch in range(5):
    cls_model.train()
    total_loss, correct, total = 0.0, 0, 0
    for inputs, labels in tqdm(loader, desc=f'Epoch {epoch+1}'):
        coords = inputs['coords'].to(device)
        times = inputs['times'].to(device)
        pols = inputs['polarities'].to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = cls_model(coords, times, pols)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # criterion returns the mean over the batch; weight it by the batch
        # size so a smaller final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size
        _, pred = logits.max(1)
        total += batch_size
        correct += pred.eq(labels).sum().item()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    seen = max(total, 1)
    print(f'Epoch {epoch+1}: Loss={total_loss/seen:.4f}, Acc={100.*correct/seen:.1f}%')

print('Training complete!')

In [None]:
# Ablation study: train one model per configuration, each with exactly one of
# the three components switched off (plus the full model as reference).
configs = {
    'Full': {'use_ctpe': True, 'use_paaa': True, 'use_asna': True},
    'No CTPE': {'use_ctpe': False, 'use_paaa': True, 'use_asna': True},
    'No PAAA': {'use_ctpe': True, 'use_paaa': False, 'use_asna': True},
    'No ASNA': {'use_ctpe': True, 'use_paaa': True, 'use_asna': False}
}

# Fixed seed applied before every variant (see the loop below).
ABLATION_SEED = 0


def _forward(net, inp):
    """Move one batch of named inputs to `device` and return the model output."""
    return net(
        inp['coords'].to(device),
        inp['times'].to(device),
        inp['polarities'].to(device),
    )


results = {}
for name, cfg in configs.items():
    # Re-seed before each variant so its result is repeatable across notebook
    # runs and does not depend on how much randomness earlier variants used.
    torch.manual_seed(ABLATION_SEED)
    m = EventformerForClassification(
        num_classes=train_data.num_classes,
        embed_dim=32,
        depths=(2, 2, 4, 2),
        num_heads=(1, 2, 4, 8),
        **cfg
    ).to(device)
    opt = optim.AdamW(m.parameters(), lr=1e-3)

    # Short training run: 3 passes over the training loader.
    m.train()
    for _ in range(3):
        for inp, lab in loader:
            opt.zero_grad()
            loss = criterion(_forward(m, inp), lab.to(device))
            loss.backward()
            opt.step()

    # NOTE: accuracy is measured on the same training loader, because this
    # notebook loads no held-out split; treat the numbers as a sanity check,
    # not as generalisation performance.
    m.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inp, lab in loader:
            pred = _forward(m, inp).argmax(1)
            correct += pred.eq(lab.to(device)).sum().item()
            total += lab.size(0)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    results[name] = 100.*correct/max(total, 1)
    print(f'{name}: {results[name]:.1f}%')

print('Ablation complete!')

In [None]:
# Plot results: bar chart of the ablation accuracies, one bar per
# configuration, with the value written above each bar.
import os

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 5))
colors = ['#2E86AB', '#E94F37', '#F39C12', '#27AE60']
bars = ax.bar(results.keys(), results.values(), color=colors)
ax.set_ylabel('Accuracy (%)')
ax.set_title('Ablation Study')
ax.set_ylim(0, 100)
for bar, acc in zip(bars, results.values()):
    # Centre the label horizontally on the bar, 5 points above its top.
    ax.annotate(
        f'{acc:.1f}%',
        xy=(bar.get_x() + bar.get_width()/2, bar.get_height()),
        xytext=(0, 5),
        textcoords='offset points',
        ha='center',
    )
plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists.
os.makedirs('../figures', exist_ok=True)
plt.savefig('../figures/ablation.png', dpi=150)
plt.show()
print('Figure saved!')