In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from transformers import ViTForImageClassification, AutoImageProcessor
from datasets import load_dataset
from torch.profiler import profile, ProfilerActivity, record_function
import cProfile
import pstats
import io
from collections import defaultdict

from faeyon.io import load
from faeyon import X

# Hugging Face Hub id of the ViT checkpoint; reused below for both the
# reference model and its image processor.
repo = "google/vit-base-patch16-224"

In [None]:
# Reference implementation: the Hugging Face ViT classifier, switched to
# inference mode and moved to the GPU. `eval()` and `cuda()` both return the
# module itself, so the calls can be chained; binding the result to a name
# also keeps the cell from echoing the model repr.
hf_model = ViTForImageClassification.from_pretrained(repo).eval().cuda()

In [None]:
# Load the same checkpoint through faeyon's loader, then call eval() and
# cuda() on it, mirroring the Hugging Face model setup.
# NOTE(review): `cache=True` presumably reuses a previously fetched copy of the
# weights — confirm against the faeyon.io.load documentation.
model = load("vit/vit-base-patch16-224", cache=True)
model.eval()
model.cuda()
# No-op last statement: keeps the cell from echoing the return value of cuda().
pass

In [None]:
# Build the checkpoint's image processor and preprocess one ImageNet sample.
image_processor = AutoImageProcessor.from_pretrained(repo)
# NOTE(review): trust_remote_code=True allows the dataset repository's own
# loading script to execute locally — confirm this is acceptable. Only the
# first training example is used below.
imagenet = load_dataset("ILSVRC/imagenet-1k", trust_remote_code=True)
# Request NumPy output; the array is converted to a torch tensor in a later cell.
inputs = image_processor(
    images=imagenet["train"][0]["image"],
    return_tensors="np"
)

In [None]:
# Copy the preprocessed pixel array into a torch tensor and move it to the GPU;
# this is the single input fed to both models in the cells below.
img = torch.tensor(inputs["pixel_values"]).cuda()

In [None]:
%%timeit

# Time a forward pass of the Hugging Face model.
# NOTE(review): CUDA work is queued asynchronously and this cell neither calls
# torch.cuda.synchronize() nor runs under torch.no_grad(), so the reading may
# reflect launch/Python overhead plus autograd bookkeeping rather than pure GPU
# execution time — confirm that is the intended measurement.
hf_y = hf_model(img)


In [None]:
%%timeit
# Time a forward pass of the faeyon-loaded model on the same input.
# NOTE(review): as with any un-synchronized CUDA timing, this may measure
# launch/Python overhead rather than GPU execution time — confirm intent.
y = model(img)

In [None]:
# NOTE(review): presumably the ratio of the two %%timeit readings above, typed
# in by hand; the literals go stale whenever the timings are re-run — confirm
# which reading belongs to which model.
3.4/3.28

In [None]:
# Sanity check: run both models on the same input and test whether the Hugging
# Face logits and the faeyon output agree within torch.allclose's default
# tolerances. The forward passes are repeated here because names assigned
# inside a %%timeit cell are not kept in the notebook namespace.
hf_y = hf_model(img)
y = model(img)
torch.allclose(hf_y.logits,  y)

## Profiling

In [None]:
# Analyze and compare
def analyze_profile(pr, name):
    """Summarize raw cProfile entries, grouped by source-file name.

    Args:
        pr: Object exposing ``getstats()`` (e.g. a ``cProfile.Profile``). Each
            returned entry must provide ``code``, ``callcount``, ``totaltime``
            and ``inlinetime``.
        name: Label for the profiled run. Unused by the current implementation.

    Returns:
        A dict with two keys:

        * ``"total_time"``: sum of every entry's ``inlinetime`` (self time),
          so each profiled interval is counted exactly once.
        * ``"by_module"``: plain dict mapping the last path component of each
          code object's ``co_filename`` to a dict with ``"time"`` (sum of the
          entries' cumulative ``totaltime``; nested calls inside one file
          overlap), ``"calls"`` (sum of ``callcount``) and ``"functions"``
          (entries whose ``totaltime`` exceeds 0.1 ms, slowest first).
    """
    stats = pr.getstats()

    # Self time excludes callees. Summing cumulative ``totaltime`` instead would
    # count a nested call once for every enclosing frame and inflate the total.
    total_time = sum(stat.inlinetime for stat in stats)

    by_module = defaultdict(lambda: {"time": 0.0, "calls": 0, "functions": []})

    for stat in stats:
        code = stat.code
        # Entries whose ``code`` is a plain string carry no ``co_filename`` to
        # attribute them to, so they are left out of the per-file grouping.
        if not code or isinstance(code, str):
            continue

        filename = code.co_filename
        # Accept both POSIX and Windows style paths.
        separator = "/" if "/" in filename else "\\"
        module = filename.split(separator)[-1]

        bucket = by_module[module]
        bucket["time"] += stat.totaltime
        bucket["calls"] += stat.callcount
        if stat.totaltime > 0.0001:  # Only list functions taking > 0.1 ms
            bucket["functions"].append({
                "name": code.co_name,
                "time": stat.totaltime,
                "calls": stat.callcount,
            })

    # Slowest functions first within each file.
    for bucket in by_module.values():
        bucket["functions"].sort(key=lambda fn: fn["time"], reverse=True)

    return {
        "total_time": total_time,
        "by_module": dict(by_module),
    }

# Untimed warm-up call, then wait for queued GPU work to finish so it does not
# spill into the profiled region.
_ = model(img)
torch.cuda.synchronize()

# Profile the faeyon-loaded `model` (not `hf_model`) with cProfile.
with cProfile.Profile() as pr:
    y = model(img)

# sort=0 is one of pstats' legacy integer sort keys; per the pstats
# documentation it orders the report by call count.
pr.print_stats(sort=0)

In [None]:
# Scratch experiment with faeyon's `X` placeholder and the `>>` operator.
# NOTE(review): presumably this feeds the list into the expression on the
# right-hand side — confirm the semantics in the faeyon documentation. The cell
# is unrelated to the profiling above.
[1, 2, 3] >> X[X[0]]

In [None]:
# One (code, callcount, inlinetime) row per profiler entry, most-called first.
data = [(entry.code, entry.callcount, entry.inlinetime) for entry in pr.getstats()]
data.sort(key=lambda row: row[1], reverse=True)
data

In [None]:
from faeyon import X

In [None]:
def square(x):
    """Return ``x * x``.

    Only the ``*`` operator is applied, so any operand supporting it is
    accepted; the cell below calls it with faeyon's ``X`` placeholder.
    """
    return x * x

# Scratch experiments with faeyon's `X` placeholder.
# NOTE(review): presumably `square(X)` builds a deferred expression that `>>`
# then applies to 10 — confirm against the faeyon documentation. Only the last
# expression of the cell is displayed, so this result is not shown.
10 >> square(X)

# NOTE(review): `X` is not a tensor; whether torch.mean accepts it depends on
# faeyon internals not visible here — confirm this cell runs.
torch.mean(X)


In [None]:
# NOTE(review): neither operand involves faeyon's `X`, and plain Python lists
# do not implement `>>`, so this presumably raises TypeError — confirm the
# intent or drop the cell.
[1, 2] >> [3, 4]

In [None]:
# Raw entries from the cProfile run above. `stats` was previously only a local
# inside analyze_profile, so on a fresh kernel this cell failed with NameError.
stats = pr.getstats()

# Keep only entries backed by a code object (entries whose `code` is a plain
# string have no filename) and whose source path mentions the package of interest.
spells = [stat for stat in stats if not isinstance(stat.code, str) and "spells" in stat.code.co_filename]
faek = [stat for stat in stats if not isinstance(stat.code, str) and "faek" in stat.code.co_filename]

# NOTE(review): `reccallcount` counts recursive calls only, so this key is 0
# for every non-recursive function and the ordering is then just the input
# order. Presumably `inlinetime` alone (or `callcount`) was intended — confirm
# before relying on the ranking.
spells = sorted(spells, key=lambda x: x.inlinetime * x.reccallcount, reverse=True)

In [None]:
# Inspect the `calls` attribute (sub-call records) of the first entry in
# `spells`; the trailing comment is a leftover alternative attribute lookup.
# Raises IndexError if no profiler entry matched "spells".
spells[0].calls#.code.co_firstlineno

In [None]:
# Find the top-level function (usually the model's __call__ or forward): the
# entry with the largest cumulative time.
# `stats` is presumably the list returned by `pr.getstats()` — confirm it is
# assigned before this cell runs.
top_level = max(stats, key=lambda s: s.totaltime if s.code else 0)
total_time = top_level.totaltime  # This is the actual total time (cumulative time of the top-level entry)

In [None]:
from collections import Counter

# Tally how many profiler entries share each qualified function name.
lines = [entry.code.co_qualname for entry in spells]

Counter(lines)
