In [None]:
import pickle
import torch
import numpy as np
import gc
import torch, gc
from collections.abc import Mapping, Sequence, Set
import matplotlib.pyplot as plt

In [None]:
def move_to_cpu(obj, _visited=None, path="data"):
    """
    Recursively move any torch.Tensor (or object with .cpu()) in `obj` to CPU,
    and print the full path when .cpu() succeeds.

    Mappings, sequences and sets are rebuilt as new containers; objects that
    expose ``__dict__`` and/or ``__slots__`` are updated in place and returned
    as-is. Anything else is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) Python object to walk.
        _visited: Internal memo shared across the recursion. It maps
            ``id(original)`` to ``(original, converted)``; callers should
            leave it as ``None``.
        path: Human-readable location of `obj`, used only in the printed
            messages.

    Returns:
        The converted counterpart of `obj`. An object that is reachable from
        several places is converted once and the same converted result is
        returned everywhere. An object met again while it is still being
        processed (a reference cycle) is returned unchanged.
    """
    if _visited is None:
        _visited = {}
    obj_id = id(obj)
    seen = _visited.get(obj_id)
    if seen is not None:
        return seen[1]

    # Register a placeholder first so reference cycles terminate. Keeping the
    # original object inside the memo entry pins it in memory, which stops its
    # id() from being reused by another object during the walk.
    _visited[obj_id] = (obj, obj)
    converted = _move_unvisited_to_cpu(obj, _visited, path)
    _visited[obj_id] = (obj, converted)
    return converted


def _move_unvisited_to_cpu(obj, memo, path):
    """Convert one object not seen before; children recurse via move_to_cpu."""
    # 1) Tensors
    if torch.is_tensor(obj):
        new = obj.cpu()
        print(f"Moved tensor at {path}")
        return new

    # 2) Anything with a .cpu() method
    cpu_m = getattr(obj, 'cpu', None)
    if callable(cpu_m):
        try:
            new = cpu_m()
        except Exception:
            # Best effort: a failing .cpu() (e.g. an unbound method on a class
            # object) just means we fall through to the structural handling.
            pass
        else:
            print(f"Called .cpu() on object at {path}")
            return new

    # 3) dicts
    if isinstance(obj, Mapping):
        out = type(obj)()
        for k, v in obj.items():
            # build a nice key path if key is identifier-like, else repr()
            if isinstance(k, str) and k.isidentifier():
                child_path = f"{path}.{k}"
            else:
                child_path = f"{path}[{repr(k)}]"
            out[k] = move_to_cpu(v, memo, child_path)
        return out

    # 4) lists/tuples; text, byte buffers and ranges hold nothing to move and
    #    cannot be rebuilt from an iterable of their items, so they are skipped
    if isinstance(obj, Sequence) and not isinstance(
        obj, (str, bytes, bytearray, range, memoryview)
    ):
        items = [
            move_to_cpu(v, memo, f"{path}[{i}]")
            for i, v in enumerate(obj)
        ]
        if isinstance(obj, tuple):
            if hasattr(obj, '_fields'):
                # namedtuple: constructor takes positional fields
                return type(obj)(*items)
            return tuple(items)
        return type(obj)(items)

    # 5) sets
    if isinstance(obj, Set):
        moved = [move_to_cpu(v, memo, f"{path}{{elem}}") for v in obj]
        if hasattr(obj, 'add'):
            new_set = type(obj)()
            for v in moved:
                new_set.add(v)
            return new_set
        # immutable sets (e.g. frozenset) have no .add(); build in one go
        return type(obj)(moved)

    # 6) custom objects with __dict__
    if hasattr(obj, '__dict__'):
        # snapshot the items so setattr() below never races the iteration
        for name, val in list(vars(obj).items()):
            new_val = move_to_cpu(val, memo, f"{path}.{name}")
            if new_val is not val:
                setattr(obj, name, new_val)

    # 7) custom objects with __slots__; every class in the MRO may declare
    #    its own slots, so inherited ones are visited too
    for klass in type(obj).__mro__:
        slots = vars(klass).get('__slots__', ())
        if isinstance(slots, str):
            slots = (slots,)
        for slot in slots:
            if slot in ('__dict__', '__weakref__'):
                continue
            if hasattr(obj, slot):
                val = getattr(obj, slot)
                new_val = move_to_cpu(val, memo, f"{path}.{slot}")
                if new_val is not val:
                    setattr(obj, slot, new_val)

    # 8) fallback (also the in-place result of steps 6 and 7)
    return obj

In [None]:
# NOTE: unpickling can execute arbitrary code, so only load result files that
# come from a trusted source.
# The `with` block closes the file handle as soon as loading is done.
with open('./results/results.pkl', 'rb') as results_file:
    data = pickle.load(results_file)

In [None]:
# Peek at the nested 'results' mapping stored for entry 6 of the loaded list.
# NOTE(review): presumably this is the "+signed_kv+no_zp_clamp" run, assuming
# the entries follow the order of `names` — confirm.
data[6]['results']['results']

In [None]:
# Labels for the evaluated variants; entry k labels data[k] when the results
# table is built. The second group repeats the first with "+v_sym" added.
names = [
    "base",
    "+zp_int8",
    "+signed_kv",
    "+no_zp_clamp",
    "+zp_int8+signed_kv",
    "+zp_int8+no_zp_clamp",
    "+signed_kv+no_zp_clamp",
    "+zp_int8+signed_kv+no_zp_clamp",
] + [
    "+v_sym",
    "+v_sym+zp_int8",
    "+v_sym+signed_kv",
    "+v_sym+no_zp_clamp",
    "+v_sym+zp_int8+signed_kv",
    "+v_sym+zp_int8+no_zp_clamp",
    "+v_sym+signed_kv+no_zp_clamp",
    "+v_sym+zp_int8+signed_kv+no_zp_clamp",
]

# Metric key to read for each dataset. The pairing is positional: entry j is
# used for the j-th dataset, so the order here must match the dataset order.
metrics = [
    'acc_norm,none',
    'acc,none',
    'acc_norm,none',
    'acc_norm,none',
    'acc,none',
]

In [None]:
# Dataset names come from the first entry of the loaded results, in stored
# order. 'wikitext' is left out of the comparison (presumably because it is
# not scored with the accuracy metrics used for the others — confirm).
# Filtering avoids both the stray 'wikitext' cell output that list.pop()
# produced and a ValueError when the key is absent.
datasets = [
    dataset_name
    for dataset_name in data[0]['results']['results'].keys()
    if dataset_name != 'wikitext'
]

In [None]:
# Build results[variant][dataset] -> score.
# Variant k reads data[k]; dataset j reads the metric key metrics[j]
# (positional pairing, so `metrics` must be ordered like `datasets`).
results = {}
for variant_idx, variant in enumerate(names):
    per_dataset = data[variant_idx]['results']['results']
    results[variant] = {
        dataset: per_dataset[dataset][metrics[pos]]
        for pos, dataset in enumerate(datasets)
    }

In [None]:
configs = list(results.keys())
# The inner keys of `results` are dataset names. They get their own variable:
# rebinding the global `metrics` (the per-dataset metric keys) here would make
# the cell that builds `results` fail with a KeyError when it is re-run.
# The `{}` default simply yields no plots when `results` is empty.
plotted_datasets = list(next(iter(results.values()), {}).keys())

# Plot one bar chart per dataset, with one bar per configuration
for dataset_name in plotted_datasets:
    values = [results[config][dataset_name] for config in configs]
    plt.figure()
    plt.bar(configs, values)
    plt.ylabel(dataset_name)
    plt.xticks(rotation=90)  # configuration labels are long; keep them readable
    plt.title(dataset_name)
    plt.tight_layout()
    plt.show()