<a href="https://colab.research.google.com/github/digitaldaimyo/FastSAM/blob/main/ai_exp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Write and Zip

In [12]:

import os

# Create the directory if it doesn't exist
os.makedirs('ai_exp', exist_ok=True)

In [13]:

%%writefile ai_exp/tap.py

from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Set, Any, Optional, Literal, Iterable, Union, Mapping
import contextvars
from copy import deepcopy
import torch
import torch.nn as nn

# Global tap stack shared across all Tap.Modules (thread/Task-local via contextvars)
_TAP_STACK: contextvars.ContextVar[List["Tap._State"]] = contextvars.ContextVar("_TAP_STACK", default=[])


class Tap:
    """
    Unified forward-only tap system for model introspection.

    Key ideas:
      - Cheap capture by default (device='keep'): `tap()` only detaches.
      - Explicit materialization via CaptureHandle.snapshot('cpu'|'clone'|'keep'|'meta').
      - Global, contextvars-backed tap stack: no submodule rebinding; nesting composes.
    """

    @dataclass
    class _State:
        active: bool = False
        data: Dict[str, List[Any]] = field(default_factory=dict)
        allowed_keys: Optional[Set[str]] = None
        # What Module.tap does at write-time
        capture_device: Literal["keep", "cpu", "clone", "meta"] = "keep"

    class Context:
        class _CaptureHandle(Mapping):
            def __init__(self, data_ref: Dict[str, List[Any]]):
                self._data = data_ref

            def __getitem__(self, k): return self._data[k]
            def __iter__(self): return iter(self._data)
            def __len__(self): return len(self._data)

            def snapshot(self, tensor: Literal["keep","clone","cpu","meta"] = "keep") -> Dict[str, List[Any]]:
                """
                Take a deep copy for safe, post-context use.
                - 'keep': leave tensors as-is (already detached by tap()).
                - 'clone': clone tensors on current device.
                - 'cpu':   clone tensors to CPU.
                - 'meta':  replace tensors with {shape,dtype,device} dicts.
                """
                def tf(t: torch.Tensor):
                    if tensor == "keep":
                        return t
                    if tensor == "clone":
                        return t.clone()
                    if tensor == "cpu":
                        return t.detach().cpu().clone()
                    if tensor == "meta":
                        return {"shape": tuple(t.shape), "dtype": str(t.dtype), "device": str(t.device)}
                    raise ValueError("tensor must be one of {'keep','clone','cpu','meta'}")

                def walk(o):
                    if isinstance(o, torch.Tensor): return tf(o)
                    if isinstance(o, dict):  return {k: walk(v) for k, v in o.items()}
                    if isinstance(o, list):  return [walk(v) for v in o]
                    if isinstance(o, tuple): return tuple(walk(v) for v in o)
                    if isinstance(o, set):   return {walk(v) for v in o}
                    try: return deepcopy(o)
                    except Exception: return o

                return walk(self._data)

            __call__ = snapshot  # small convenience alias
            @property
            def live(self): return self._data

            def getone(self, key: str, idx: int = -1):
                lst = self._data.get(key, [])
                return lst[idx] if lst else None

        def __init__(
            self,
            keys: Union[str, Iterable[str], None] = None,
            device: Literal['keep','cpu','clone','meta'] = 'keep'
        ):
            self.keys = set(keys) if keys is not None and not isinstance(keys, str) else {keys} if isinstance(keys, str) else None
            self.device = device
            self.state = Tap._State(active=True, allowed_keys=self.keys, capture_device=self.device)
            self._token = None

        def __enter__(self) -> 'Tap.Context._CaptureHandle':
            stack = _TAP_STACK.get()
            new_stack = stack + [self.state]
            self._token = _TAP_STACK.set(new_stack)
            # Fresh buffer each entry
            self.state.data.clear()
            return Tap.Context._CaptureHandle(self.state.data)

        def __exit__(self, exc_type, exc_value, traceback):
            # Pop this state from the stack
            stack = _TAP_STACK.get()
            if stack and stack[-1] is self.state:
                _TAP_STACK.set(stack[:-1])
            else:
                # Defensive: remove if mis-nested
                _TAP_STACK.set([s for s in stack if s is not self.state])
            # Clean up
            self.state.active = False
            self.state.allowed_keys = None
            self.state.capture_device = 'keep'
            self.state.data.clear()

    @staticmethod
    def tap(*keys: str):
        if not keys:
            raise ValueError("Tap.tap requires at least one key")
        def decorator(fn):
            fn._tap_keys = list(keys)
            return fn
        return decorator

    class _KeyRecorder(dict):
        """Records keys written via __setitem__ without storing values."""
        def __init__(self):
            super().__init__()
            self._seen = set()
        def __setitem__(self, k, v):
            self._seen.add(k)
        def keys(self):
            return list(self._seen)

    class Module(nn.Module):
        def tap(self, key: str, value: Any):
            stack = _TAP_STACK.get()
            if not stack:
                return
            st: Tap._State = stack[-1]
            if not st.active or (st.allowed_keys is not None and key not in st.allowed_keys):
                return
            if key not in st.data:
                st.data[key] = []
            # Tensor handling per capture policy
            if isinstance(value, torch.Tensor):
                value = value.detach()
                if st.capture_device == 'cpu':
                    value = value.to('cpu', copy=True)
                elif st.capture_device == 'clone':
                    value = value.clone()
                elif st.capture_device == 'meta':
                    value = {"shape": tuple(value.shape), "dtype": str(value.dtype), "device": str(value.device)}
                # 'keep' => already detached; store reference
            st.data[key].append(value)

        def capture(
            self,
            keys: Union[str, Iterable[str], None] = None,
            device: Literal['keep','cpu','clone','meta'] = 'keep'
        ) -> 'Tap.Context':
            # Context is global; not bound to this module specifically.
            return Tap.Context(keys=keys, device=device)

        def list_taps(self, *args, **kwargs) -> List[str]:
            """
            Run a forward pass with a key recorder to discover emitted tap keys.
            Accepts arbitrary forward(*args, **kwargs).
            """
            # Push a temporary state using a KeyRecorder
            state = Tap._State(active=True, allowed_keys=None, capture_device='keep')
            recorder = Tap._KeyRecorder()
            state.data = recorder  # type: ignore[assignment]
            stack = _TAP_STACK.get()
            _TAP_STACK.set(stack + [state])
            try:
                self(*args, **kwargs)
                return list(recorder.keys())
            finally:
                # Pop the temporary state
                cur = _TAP_STACK.get()
                if cur and cur[-1] is state:
                    _TAP_STACK.set(cur[:-1])
                else:
                    _TAP_STACK.set([s for s in cur if s is not state])

        @classmethod
        def available_taps(cls) -> List[str]:
            seen, taps = set(), []
            for base in reversed(cls.__mro__):
                if not issubclass(base, Tap.Module): continue
                for obj in base.__dict__.values():
                    if callable(obj) and hasattr(obj, "_tap_keys"):
                        for key in obj._tap_keys:  # type: ignore[attr-defined]
                            if key not in seen:
                                seen.add(key); taps.append(key)
            return taps

Writing ai_exp/tap.py


In [14]:
%%writefile ai_exp/utils.py

from __future__ import annotations
import os, json, hashlib, torch

from typing import Dict, Any

def stable_hash(obj: Dict[str, Any]) -> str:
    canon = json.dumps(obj, sort_keys=True, separators=(",", ":"))
    return hashlib.sha1(canon.encode()).hexdigest()[:10]

def save_ckpt_atomic(path: str, payload: Dict[str, Any]):
    """
    Durable atomic write:
      - write to tmp
      - fsync file
      - os.replace
      - fsync directory
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        torch.save(payload, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, path)
    dir_fd = os.open(os.path.dirname(path), os.O_DIRECTORY)
    try:
        os.fsync(dir_fd)
    finally:
        os.close(dir_fd)

Writing ai_exp/utils.py


In [15]:

%%writefile ai_exp/policy.py
from __future__ import annotations
from typing import Callable, List, Dict, Any
from dataclasses import dataclass
import traceback as _tb

@dataclass
class Context:
    epoch: int
    step: int
    model: Any
    device: Any
    cfg: Any
    train_loader: Any
    eval_loader: Any
    run: Any
    logger: Any

class Policy:
    _REGISTRY: Dict[str, Callable] = {}

    @classmethod
    def register(cls, name: str):
        def wrap(fn: Callable):
            if name in cls._REGISTRY:
                raise ValueError(f"Collector '{name}' already registered.")
            cls._REGISTRY[name] = fn
            fn._collector_name = name  # type: ignore[attr-defined]
            return fn
        return wrap

    @classmethod
    def get(cls, name: str) -> Callable:
        if name not in cls._REGISTRY:
            raise KeyError(f"Collector '{name}' not found.")
        return cls._REGISTRY[name]

    @classmethod
    def list_collectors(cls) -> List[str]:
        return sorted(cls._REGISTRY.keys())

    def __init__(self, name: str = "default"):
        self.name = name
        self._config: Dict[str, Any] = {"every_epoch": [], "every_k_epochs": []}

    def _normalize_every_epoch(self, items: List[Any]) -> List[Dict[str, Any]]:
        norm = []
        for it in items:
            if isinstance(it, str):
                self.get(it)
                norm.append({"name": it, "params": {}})
            elif isinstance(it, dict) and "name" in it:
                self.get(it["name"])
                norm.append({"name": str(it["name"]), "params": dict(it.get("params", {}))})
            else:
                raise ValueError("Invalid every_epoch spec")
        return norm

    def _normalize_every_k(self, items: List[Any]) -> List[Dict[str, Any]]:
        norm = []
        for it in items:
            if not (isinstance(it, dict) and "name" in it and "k" in it):
                raise ValueError("every_k_epochs must have {'name','k',...}")
            self.get(it["name"])
            norm.append({"name": str(it["name"]), "k": int(it["k"]), "params": dict(it.get("params", {}))})
        return norm

    def configure(self, *, every_epoch=None, every_k_epochs=None):
        if every_epoch is not None:
            self._config["every_epoch"] = self._normalize_every_epoch(every_epoch)
        if every_k_epochs is not None:
            self._config["every_k_epochs"] = self._normalize_every_k(every_k_epochs)
        return self  # chaining

    def to_manifest(self) -> Dict[str, Any]:
        return {
            "name": self.name,
            "every_epoch": [{"name": e["name"], "params": e.get("params", {})} for e in self._config["every_epoch"]],
            "every_k_epochs": [{"name": e["name"], "k": e["k"], "params": e.get("params", {})} for e in self._config["every_k_epochs"]],
            "available_collectors": self.list_collectors(),
        }

    def _safe_call(self, ctx: Context, name: str, **params):
        try:
            fn = self.get(name)
            fn(ctx, **params)
        except Exception as e:
            ctx.logger.log_event(
                "collector_error",
                epoch=ctx.epoch, name=name,
                msg=str(e),
                tb="".join(_tb.format_exc(limit=5))
            )

    def run_collectors(self, ctx: Context, edge_epoch_tick: bool = True):
        for spec in self._config["every_epoch"]:
            self._safe_call(ctx, spec["name"], **spec.get("params", {}))

        for spec in self._config["every_k_epochs"]:
            if ctx.epoch % spec["k"] == 0:
                self._safe_call(ctx, spec["name"], **spec.get("params", {}))

        if edge_epoch_tick and ctx.epoch + 1 == ctx.cfg.train.max_epochs:
            ctx.logger.log_event("run_complete", epoch=ctx.epoch)

Writing ai_exp/policy.py


In [16]:

%%writefile ai_exp/logger.py
from __future__ import annotations
import os, json, time, numpy as np
import torch

class RunLogger:
    def __init__(self, run):
        self.run_dir = run.run_dir
        self.art_dir = os.path.join(run.run_dir, "artifacts")
        self.log_dir = os.path.join(run.run_dir, "logs")
        os.makedirs(self.art_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)

        self._epoch_f = open(os.path.join(self.log_dir, "epoch.jsonl"), "a", buffering=1)
        self._events_f = open(os.path.join(self.log_dir, "events.jsonl"), "a", buffering=1)
        self._lock_path = os.path.join(run.run_dir, "_LOCK")
        # Write PID to help detect stale locks
        with open(self._lock_path, "w") as f:
            try:
                f.write(str(os.getpid()))
            except Exception:
                f.write("unknown")

    def _print(self, msg: str):
        try:
            print(msg)
        except Exception:
            pass

    def log_epoch(self, to_console: bool = False, **kv):
        rec = {"t": time.time(), **kv}
        self._epoch_f.write(json.dumps(rec) + "\n")
        if to_console:
            ep = rec.get("epoch")
            if isinstance(ep, int):
                msg = f"[epoch {ep:03d}]"
            else:
                msg = "[epoch ?]"
            lr = rec.get("lr")
            tl = rec.get("train_loss")
            l1 = rec.get("eval_l1")
            if isinstance(lr, (int, float)): msg += f" lr={lr:.2e}"
            if isinstance(tl, (int, float)): msg += f" train={tl:.4f}"
            if isinstance(l1, (int, float)): msg += f" evalL1={l1:.4f}"
            self._print(msg)

    def log_event(self, kind: str, to_console: bool = False, **kv):
        rec = {"t": time.time(), "kind": kind, **kv}
        self._events_f.write(json.dumps(rec) + "\n")
        if to_console:
            parts = [f"{k}={v}" for k, v in sorted(kv.items())]
            self._print(f"[event:{kind}] {' '.join(parts)}")

    def save_npz(self, name: str, to_console: bool = False, **arrays):
        path = os.path.join(self.art_dir, name)
        np.savez_compressed(path, **{
            k: (v.detach().cpu().numpy() if torch.is_tensor(v) else v)
            for k, v in arrays.items()
        })
        if to_console:
            self._print(f"[artifact] saved {name}")
        return path

    def close(self):
        try:
            self._epoch_f.close()
            self._events_f.close()
        except Exception:
            pass
        try:
            os.remove(self._lock_path)
        except Exception:
            pass

Writing ai_exp/logger.py


In [17]:

%%writefile ai_exp/run.py
from __future__ import annotations
import os, time, json, torch, random, numpy as np
from typing import Optional
from dataclasses import asdict
from .utils import stable_hash, save_ckpt_atomic
from .logger import RunLogger
from .policy import Policy, Context

class Run:
    def __init__(self, cfg, tag: Optional[str] = None, policy: Optional[Policy] = None):
        self.cfg = cfg
        self.tag = (tag or "").strip()
        self.policy = policy
        self.logger: Optional[RunLogger] = None
        self.started = False
        self.device: Optional[torch.device] = None
        self.dtype: Optional[torch.dtype] = None
        self.timestamp: Optional[str] = None
        self.requested_device: Optional[str] = None  # preserve original intent

    def _resolve_device(self) -> torch.device:
        req = getattr(self.cfg.train, "device", "auto")
        self.requested_device = req
        if req == "auto":
            dev = "cuda" if torch.cuda.is_available() else "cpu"
            resolved = dev
        else:
            resolved = req
        # record the resolved choice but keep the requested for manifest
        self.cfg.train.device = resolved
        return torch.device(resolved)

    def _resolve_dtype(self) -> torch.dtype:
        p = self.cfg.train.precision.lower()
        if p == "fp32": return torch.float32
        if p == "fp16": return torch.float16
        if p == "bf16": return torch.bfloat16
        raise ValueError(f"Unknown precision: {self.cfg.train.precision}")

    def _seed_everything(self):
        sd = self.cfg.train.seed
        random.seed(sd); np.random.seed(sd); torch.manual_seed(sd)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(sd)
        # Determinism knobs
        torch.use_deterministic_algorithms(self.cfg.run.deterministic)
        # Optional: TF32 off for stricter determinism
        try:
            torch.backends.cuda.matmul.allow_tf32 = False
            torch.backends.cudnn.allow_tf32 = False
        except Exception:
            pass
        # Optional CUDA workspace for full determinism (PyTorch will warn if needed)
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    def _is_writable(self, path: str) -> bool:
        try:
            os.makedirs(path, exist_ok=True)
            p = os.path.join(path, ".touch")
            with open(p, "w") as f: f.write("ok")
            os.remove(p)
            return True
        except Exception:
            return False

    def _resolve_paths(self):
        root = self.cfg.run.log_root if self._is_writable(self.cfg.run.log_root) else self.cfg.run.local_fallback
        self.timestamp = time.strftime("%Y%m%d-%H%M%S")
        if not hasattr(self.cfg, "exp_hash") or not self.cfg.exp_hash:
            self.cfg.finalize()
        tag = self.tag or "default"
        rd = os.path.join(root, tag, self.cfg.exp_hash[:8], self.timestamp)
        for sub in ("", "artifacts", "logs", "checkpoints"):
            os.makedirs(os.path.join(rd, sub), exist_ok=True)
        self.cfg.run.run_dir = rd
        return rd

    def _write_manifest(self):
        cuda_info = {}
        if torch.cuda.is_available():
            try:
                cuda_info = {
                    "device_name": torch.cuda.get_device_name(0),
                    "cc": ".".join(map(str, torch.cuda.get_device_capability(0))),
                    "cuda": torch.version.cuda,
                }
            except Exception:
                cuda_info = {}
        man = {
            "schema": 1,
            "created_at": time.time(),
            "timestamp": self.timestamp,
            "tag": self.tag,
            "exp_hash": self.cfg.exp_hash,
            "cfg": asdict(self.cfg),
            "policy": (self.policy.to_manifest() if self.policy else None),
            "paths": {
                "run_dir": self.cfg.run.run_dir,
                "artifacts": os.path.join(self.cfg.run.run_dir, "artifacts"),
                "logs": os.path.join(self.cfg.run.run_dir, "logs"),
                "checkpoints": os.path.join(self.cfg.run.run_dir, "checkpoints"),
            },
            "env": {"torch": torch.__version__, **cuda_info},
            "requested_device": self.requested_device,
            "resolved_device": self.cfg.train.device,
        }
        with open(os.path.join(self.cfg.run.run_dir, "manifest.json"), "w") as f:
            json.dump(man, f, indent=2, default=str)

    def _print_header(self):
        header = [
            f"Run started",
            f" tag: {self.tag}",
            f" device: {self.device.type}",
            f" precision: {str(self.dtype).split('.')[-1]}",
            f" seed: {self.cfg.train.seed}",
            f" exp_hash: {self.cfg.exp_hash[:8]}",
            f" timestamp: {self.timestamp}",
            f" run_dir: {self.cfg.run.run_dir}",
            f" checkpoints: {os.path.join(self.cfg.run.run_dir, 'checkpoints')}",
            f" artifacts: {os.path.join(self.cfg.run.run_dir, 'artifacts')}",
        ]
        print("\n".join(header))

    def start(self):
        assert not self.started, "Run already started"
        self.cfg.finalize()
        self.device = self._resolve_device()
        self.dtype = self._resolve_dtype()
        self._seed_everything()
        self._resolve_paths()
        self._write_manifest()
        self.logger = RunLogger(self)
        self._print_header()
        self.started = True
        return self.device, self.logger

    def finish(self):
        if self.logger: self.logger.close()
        self.started = False

    def apply_policy(self, ctx: Context, edge_epoch_tick: bool = True):
        if self.policy is not None:
            self.policy.run_collectors(ctx, edge_epoch_tick)

    @property
    def run_dir(self) -> str:
        assert self.cfg.run.run_dir
        return self.cfg.run.run_dir

    @property
    def checkpoints_dir(self) -> str:
        return os.path.join(self.run_dir, "checkpoints")

    @property
    def artifacts_dir(self) -> str:
        return os.path.join(self.run_dir, "artifacts")

Writing ai_exp/run.py


In [18]:

%%writefile ai_exp/collectors.py
from __future__ import annotations
import os
import io
from typing import Dict, Any, Iterable, Optional, Tuple
import numpy as np
import torch

from .policy import Policy, Context
from .tap import Tap


def _flatten_named_params(model: torch.nn.Module, names: Optional[Iterable[str]] = None):
    """
    Yields (name, param) pairs. If `names` is provided, filters to those names (exact match or suffix match).
    """
    wanted = set(names) if names is not None else None
    for n, p in model.named_parameters():
        if wanted is None:
            yield n, p
        else:
            if n in wanted or any(n.endswith("/" + w) or n.endswith("." + w) for w in wanted):
                yield n, p


def _safe_tensor_to_np(t: torch.Tensor):
    try:
        return t.detach().cpu().numpy()
    except Exception:
        return None


def _summarize_model_text(model: torch.nn.Module) -> str:
    # Param counts
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable = total - trainable

    # Small one-line-per-submodule tree
    lines = [f"Total params: {total:,} (trainable={trainable:,}, frozen={non_trainable:,})"]
    for name, module in model.named_modules():
        if name == "":
            continue
        params = sum(p.numel() for p in module.parameters(recurse=False))
        lines.append(f"{name}: {module.__class__.__name__} (params={params:,})")
    return "\n".join(lines)


def _ensure_dir(path: str):
    os.makedirs(os.path.dirname(path), exist_ok=True)


def _write_text_artifact(ctx: Context, name: str, text: str) -> str:
    path = os.path.join(ctx.run.artifacts_dir, name)
    _ensure_dir(path)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
    ctx.logger.log_event("artifact_text", name=name)
    return path


def _stack_or_meta(seq):
    """
    Try to stack a list of tensors. If incompatible, return a meta dict.
    """
    if not seq:
        return None
    if all(isinstance(x, torch.Tensor) for x in seq):
        shapes = [tuple(x.shape) for x in seq]
        dtypes = [str(x.dtype) for x in seq]
        if len(set(shapes)) == 1 and len(set(dtypes)) == 1:
            try:
                return torch.stack(seq, dim=0)
            except Exception:
                pass
        # fallthrough: meta
        return {
            "shapes": shapes,
            "dtypes": dtypes,
            "count": len(seq),
        }
    # Mixed content: return counts/meta
    return {"types": [type(x).__name__ for x in seq], "count": len(seq)}


# ----------------------
# Collectors
# ----------------------

@Policy.register("grad_norms")
def collect_grad_norms(ctx: Context, group_by: str = "global"):
    """
    Saves gradient L2 norms.
      group_by: "global" (single number) or "per_param"
    """
    norms = []
    per_param = {}
    for n, p in ctx.model.named_parameters():
        if p.grad is None:
            continue
        g = p.grad.detach()
        val = g.norm(2)
        norms.append(val)
        if group_by == "per_param":
            per_param[n] = val

    if norms:
        gsum = torch.stack(norms).norm(2)
        arrs: Dict[str, Any] = {"grad_norm_global": gsum}
        if group_by == "per_param":
            # Convert to a structured array (names truncated for npz keys)
            for k, v in per_param.items():
                arrs[f"grad_norm__{k}"] = v
        ctx.logger.save_npz(f"grad_norms_epoch{ctx.epoch:04d}.npz", **arrs)
        ctx.logger.log_event("grad_norms", epoch=ctx.epoch, global_l2=float(gsum.detach().cpu()))
    else:
        ctx.logger.log_event("grad_norms", epoch=ctx.epoch, note="no_grads")


@Policy.register("param_hists")
def collect_param_histograms(ctx: Context, names: Optional[Iterable[str]] = None, bins: int = 64, max_params: int = 256):
    """
    Saves compact histograms for parameters. If `names` is None, samples up to `max_params` parameters.
    """
    arrs = {}
    count = 0
    for n, p in _flatten_named_params(ctx.model, names):
        if p.numel() == 0:
            continue
        # Sample to a budget to avoid giant artifacts
        if names is None and count >= max_params:
            break
        x = _safe_tensor_to_np(p)
        if x is None:
            continue
        # Flatten and histogram
        x = x.ravel()
        try:
            h, edges = np.histogram(x, bins=bins)
            arrs[f"{n}__hist"] = h
            arrs[f"{n}__edges"] = edges
            count += 1
        except Exception:
            continue

    if arrs:
        ctx.logger.save_npz(f"param_hists_epoch{ctx.epoch:04d}.npz", **arrs)
        ctx.logger.log_event("param_hists", epoch=ctx.epoch, params=count)
    else:
        ctx.logger.log_event("param_hists", epoch=ctx.epoch, note="no_params_or_failed")


@Policy.register("gpu_mem")
def collect_gpu_memory(ctx: Context):
    """
    Logs CUDA memory (in MB) if available.
    """
    if torch.cuda.is_available():
        try:
            alloc = torch.cuda.memory_allocated() / (1024 ** 2)
            reserv = torch.cuda.memory_reserved() / (1024 ** 2)
            ctx.logger.log_event("gpu_mem", epoch=ctx.epoch, allocated_mb=round(alloc, 2), reserved_mb=round(reserv, 2))
        except Exception as e:
            ctx.logger.log_event("gpu_mem_error", epoch=ctx.epoch, msg=str(e))
    else:
        ctx.logger.log_event("gpu_mem", epoch=ctx.epoch, note="cpu_only")


@Policy.register("model_summary")
def collect_model_summary(ctx: Context, filename: str = "model_summary.txt"):
    """
    Writes a simple text summary + param counts.
    """
    try:
        text = _summarize_model_text(ctx.model)
        _write_text_artifact(ctx, filename, text)
        ctx.logger.log_event("model_summary", epoch=ctx.epoch, file=filename)
    except Exception as e:
        ctx.logger.log_event("model_summary_error", epoch=ctx.epoch, msg=str(e))


@Policy.register("eval_sample")
def collect_eval_sample(ctx: Context, batches: int = 1, filename: str = "eval_sample_epoch{epoch:04d}.npz"):
    """
    Runs a few eval batches and saves outputs (and targets if present) to NPZ.
    Tries to be generic: supports dataloaders yielding:
      - (inputs, targets, *rest), or
      - dict with 'inputs'/'targets' keys, or
      - just inputs.
    """
    if ctx.eval_loader is None:
        ctx.logger.log_event("eval_sample", epoch=ctx.epoch, note="no_eval_loader")
        return

    model = ctx.model
    model_device = next((p.device for p in model.parameters() if p is not None), ctx.device)
    saved = False
    arrays: Dict[str, Any] = {}
    k = 0

    def _split_batch(b) -> Tuple[torch.Tensor, Optional[Any]]:
        if isinstance(b, dict):
            x = b.get("inputs", b.get("x", b.get("data")))
            y = b.get("targets", b.get("y", None))
            return x, y
        if isinstance(b, (tuple, list)) and len(b) >= 1:
            x = b[0]
            y = b[1] if len(b) >= 2 else None
            return x, y
        return b, None

    model.eval()
    with torch.no_grad():
        for b in ctx.eval_loader:
            x, y = _split_batch(b)
            try:
                if isinstance(x, torch.Tensor):
                    x = x.to(model_device, non_blocking=True)
                out = model(x)
                # Save a small slice to keep artifact sizes tame
                if isinstance(out, torch.Tensor):
                    arrays[f"out_{k}"] = out.detach().cpu()
                elif isinstance(out, (list, tuple)) and out and isinstance(out[0], torch.Tensor):
                    arrays[f"out_{k}"] = out[0].detach().cpu()
                if isinstance(x, torch.Tensor):
                    arrays[f"in_{k}"] = x.detach().cpu()
                if isinstance(y, torch.Tensor):
                    arrays[f"tgt_{k}"] = y.detach().cpu()
                saved = True
            except Exception as e:
                ctx.logger.log_event("eval_sample_error", epoch=ctx.epoch, msg=str(e))
            k += 1
            if k >= batches:
                break

    if saved:
        fname = filename.format(epoch=ctx.epoch)
        ctx.logger.save_npz(fname, **arrays)
        ctx.logger.log_event("eval_sample", epoch=ctx.epoch, batches=k, file=fname)
    else:
        ctx.logger.log_event("eval_sample", epoch=ctx.epoch, note="nothing_saved")


@Policy.register("tap_eval")
def collect_tap_eval(ctx: Context, keys: Optional[Iterable[str]] = None, filename: str = "tap_eval_epoch{epoch:04d}.npz", device: str = "keep"):
    """
    If the model inherits from Tap.Module, run one eval batch under a capture and persist tap values.
    - Tries to stack per-key lists of tensors; if not possible, saves meta.
    - `device`: 'keep'|'cpu'|'clone'|'meta' for capture-time write behavior.
    """
    if not isinstance(ctx.model, Tap.Module):
        ctx.logger.log_event("tap_eval", epoch=ctx.epoch, note="model_not_TapModule")
        return
    if ctx.eval_loader is None:
        ctx.logger.log_event("tap_eval", epoch=ctx.epoch, note="no_eval_loader")
        return

    model: Tap.Module = ctx.model
    model_device = next((p.device for p in model.parameters() if p is not None), ctx.device)

    # Get one batch
    try:
        batch = next(iter(ctx.eval_loader))
    except Exception as e:
        ctx.logger.log_event("tap_eval_error", epoch=ctx.epoch, msg=f"loader_iter: {e}")
        return

    def _get_inputs(b):
        if isinstance(b, dict):
            return b.get("inputs", b.get("x", b.get("data", b)))
        if isinstance(b, (tuple, list)) and len(b) >= 1:
            return b[0]
        return b

    x = _get_inputs(batch)
    if isinstance(x, torch.Tensor):
        x = x.to(model_device, non_blocking=True)

    arrays: Dict[str, Any] = {}
    meta: Dict[str, Any] = {}

    model.eval()
    with torch.no_grad():
        with model.capture(keys=keys, device=device) as cap:
            try:
                _ = model(x)
            except Exception as e:
                ctx.logger.log_event("tap_eval_error", epoch=ctx.epoch, msg=f"forward: {e}")
                return
            snap = cap.snapshot("cpu")  # materialize to CPU for saving

    # Try to convert per-key lists into stackable arrays
    for k, seq in snap.items():
        st = _stack_or_meta(seq)
        if st is None:
            continue
        if isinstance(st, torch.Tensor):
            arrays[f"{k}"] = st  # tensor will be moved to np by save_npz
        else:
            meta[k] = st

    fname = filename.format(epoch=ctx.epoch)
    if arrays:
        ctx.logger.save_npz(fname, **arrays)
    # Save meta (if any) as a sidecar JSON
    if meta:
        sidecar = fname.replace(".npz", ".meta.json")
        path = os.path.join(ctx.run.artifacts_dir, sidecar)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        import json
        with open(path, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)
    ctx.logger.log_event("tap_eval", epoch=ctx.epoch, keys=list(snap.keys()), file=fname, meta=("yes" if meta else "no"))

Writing ai_exp/collectors.py


In [20]:
import os

init_file = 'ai_exp/__init__.py'
with open(init_file, 'w') as f:
    pass
print(f"Created {init_file}")

Created ai_exp/__init__.py


In [21]:

package_name = 'ai_exp'
requirements = ['numpy', 'torch']
requirements_file_path = os.path.join(package_name, "requirements.txt")

import os

init_file = requirements_file_path
with open(requirements_file_path, 'w') as f:
    for req in requirements:
        f.write(f"{req}\n")

print(f"Created {requirements_file_path}")

Created ai_exp/requirements.txt


In [23]:
import shutil
# Zip the 'ai_exp' directory
shutil.make_archive(package_name, 'zip', root_dir='.', base_dir=package_name)
print(f"Created {package_name}.zip")

Created ai_exp.zip


# Extract

In [2]:
import zipfile
import os

zip_file_path = 'ai_exp.zip' # Replace with the path to your zip file
# Extract to a subdirectory named 'ai_exp_extracted' inside the current directory
extract_dir = 'ai_exp_extracted/ai_exp_files' # Modify this to your desired extraction path

if os.path.exists(zip_file_path):
    # Create the extraction directory if it doesn't exist
    os.makedirs(extract_dir, exist_ok=True)
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        # Extract all contents to the specified directory
        zip_ref.extractall(extract_dir)
    print(f"Extracted {zip_file_path} to {extract_dir}")
else:
    print(f"Error: {zip_file_path} not found.")

Extracted ai_exp.zip to ai_exp


# Test

In [24]:

# --- cell 1: setup & config ---
from __future__ import annotations
import os, json, time
from dataclasses import dataclass, asdict
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# your framework (ensure xai_exp/ is importable)
from ai_exp.tap import Tap
from ai_exp.policy import Policy, Context
from ai_exp.collectors import *  # registers collectors
from ai_exp.run import Run
from ai_exp.utils import stable_hash, save_ckpt_atomic

@dataclass
class TrainCfg:
    device: str = "auto"            # "auto" | "cpu" | "cuda"
    precision: str = "fp32"         # "fp32" | "fp16" | "bf16"
    seed: int = 1337
    batch_size: int = 256
    max_epochs: int = 5
    lr: float = 0.2
    momentum: float = 0.9
    weight_decay: float = 5e-4
    amp: bool = True                # use autocast if fp16/bf16 on cuda

@dataclass
class RunCfg:
    log_root: str = "/content/xai_runs"
    local_fallback: str = "/content/xai_runs"
    deterministic: bool = True
    run_dir: Optional[str] = None

@dataclass
class DataCfg:
    num_workers: int = 2

@dataclass
class Cfg:
    train: TrainCfg
    run: RunCfg
    data: DataCfg
    exp_hash: str = ""

    def finalize(self):
        canon = {
            "train": asdict(self.train),
            "run": {k:v for k,v in asdict(self.run).items() if k != "run_dir"},
            "data": asdict(self.data),
        }
        self.exp_hash = stable_hash(canon)

cfg = Cfg(train=TrainCfg(), run=RunCfg(), data=DataCfg())

In [25]:

# --- cell 2: dataset & loaders ---
import torchvision
import torchvision.transforms as T

def make_loaders(cfg: Cfg):
    # Standard CIFAR-10 stats
    MEAN = (0.4914, 0.4822, 0.4465)
    STD  = (0.2470, 0.2435, 0.2616)

    train_tf = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(MEAN, STD),
    ])
    eval_tf = T.Compose([
        T.ToTensor(),
        T.Normalize(MEAN, STD),
    ])

    train_set = torchvision.datasets.CIFAR10(root="/content/data", train=True, download=True, transform=train_tf)
    eval_set  = torchvision.datasets.CIFAR10(root="/content/data", train=False, download=True, transform=eval_tf)

    train_loader = DataLoader(train_set, batch_size=cfg.train.batch_size, shuffle=True,
                              num_workers=cfg.data.num_workers, pin_memory=True)
    eval_loader  = DataLoader(eval_set, batch_size=512, shuffle=False,
                              num_workers=cfg.data.num_workers, pin_memory=True)
    return train_loader, eval_loader

train_loader, eval_loader = make_loaders(cfg)

100%|██████████| 170M/170M [00:12<00:00, 14.1MB/s]


In [26]:

# --- cell 3: model with taps ---
class BasicBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)
        self.down  = None
        if stride != 1 or in_ch != out_ch:
            self.down = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch)
            )

    def forward(self, x):
        idt = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        if self.down is not None:
            idt = self.down(idt)
        x = F.relu(x + idt)
        return x

class CIFARNet(Tap.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.layer1 = BasicBlock(32, 64, stride=2)
        self.layer2 = BasicBlock(64, 128, stride=2)
        self.layer3 = BasicBlock(128, 128, stride=1)
        self.pool   = nn.AdaptiveAvgPool2d((1,1))
        self.fc     = nn.Linear(128, num_classes)

        # declare tap keys for discoverability
        @Tap.tap("stem_in","stem_out","l1_out","l2_out","l3_out","head_in","logits")
        def _taps(): pass
        self._taps = _taps

    def forward(self, x):
        self.tap("stem_in", x)
        x = self.stem(x)
        self.tap("stem_out", x)
        x = self.layer1(x)
        self.tap("l1_out", x)
        x = self.layer2(x)
        self.tap("l2_out", x)
        x = self.layer3(x)
        self.tap("l3_out", x)
        h = self.pool(x).flatten(1)
        self.tap("head_in", h)
        logits = self.fc(h)
        self.tap("logits", logits)
        return logits

In [27]:

# --- cell 4: policy, run, training/eval ---
# Collectors policy exercising all mechanisms
policy = Policy("cifar10_demo").configure(
    every_epoch=[
        "gpu_mem",
        {"name":"grad_norms","params":{"group_by":"global"}},
        {"name":"param_hists","params":{"bins":64,"max_params":64}},
        {"name":"eval_sample","params":{"batches":1}},
        "model_summary",
        {"name":"tap_eval","params":{"device":"keep"}},  # taps via cheap capture
    ],
    every_k_epochs=[
        {"name":"param_hists","k":3,"params":{"bins":128,"max_params":128}},
    ]
)

model = CIFARNet(num_classes=10)
criterion = nn.CrossEntropyLoss()

run = Run(cfg=cfg, tag="cifar10_tap_demo", policy=policy)
device, logger = run.start()
model.to(device)

# Optimizer & simple cosine schedule
optimizer = torch.optim.SGD(model.parameters(), lr=cfg.train.lr,
                            momentum=cfg.train.momentum, weight_decay=cfg.train.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.train.max_epochs)

# AMP setup (optional)
use_amp = (cfg.train.amp and device.type == "cuda" and cfg.train.precision.lower() in {"fp16","bf16"})
autocast_dtype = torch.float16 if cfg.train.precision.lower()=="fp16" else torch.bfloat16
scaler = torch.cuda.amp.GradScaler(enabled=use_amp and autocast_dtype==torch.float16)

def evaluate(model, loader, device):
    model.eval()
    correct = total = 0
    l1 = 0.0
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp, dtype=autocast_dtype):
        for x, y in loader:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()
            l1 += logits.abs().mean().item()
    return correct / max(1,total), l1 / max(1,len(loader))

global_step = 0
for epoch in range(cfg.train.max_epochs):
    model.train()
    running_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp, dtype=autocast_dtype):
            logits = model(x)
            loss = criterion(logits, y)
        if scaler.is_enabled():
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        running_loss += float(loss.item())
        global_step += 1

    scheduler.step()
    epoch_loss = running_loss / max(1, len(train_loader))
    acc, l1 = evaluate(model, eval_loader, device)

    logger.log_epoch(to_console=True,
                     epoch=epoch,
                     lr=scheduler.get_last_lr()[0],
                     train_loss=epoch_loss,
                     eval_acc=acc,
                     eval_l1=l1)

    # Run collectors (drops artifacts incl. tap snapshots)
    ctx = Context(epoch=epoch, step=global_step, model=model, device=device,
                  cfg=cfg, train_loader=train_loader, eval_loader=eval_loader,
                  run=run, logger=logger)
    run.apply_policy(ctx, edge_epoch_tick=True)

    # Save checkpoint atomically
    ckpt = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optim": optimizer.state_dict(),
        "sched": scheduler.state_dict(),
        "cfg": asdict(cfg),
    }
    save_ckpt_atomic(os.path.join(run.checkpoints_dir, f"epoch_{epoch:04d}.pt"), ckpt)

run.finish()
print("Run dir:", cfg.run.run_dir)
print("Artifacts:", os.listdir(os.path.join(cfg.run.run_dir, "artifacts")))
print("Logs:", os.listdir(os.path.join(cfg.run.run_dir, "logs")))

Run started
 tag: cifar10_tap_demo
 device: cuda
 precision: float32
 seed: 1337
 exp_hash: fc1c5f3c
 timestamp: 20251031-200645
 run_dir: /content/xai_runs/cifar10_tap_demo/fc1c5f3c/20251031-200645
 checkpoints: /content/xai_runs/cifar10_tap_demo/fc1c5f3c/20251031-200645/checkpoints
 artifacts: /content/xai_runs/cifar10_tap_demo/fc1c5f3c/20251031-200645/artifacts


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp and autocast_dtype==torch.float16)
  with torch.cuda.amp.autocast(enabled=use_amp, dtype=autocast_dtype):
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp, dtype=autocast_dtype):


[epoch 000] lr=1.81e-01 train=1.6608 evalL1=1.7620
[epoch 001] lr=1.31e-01 train=1.2030 evalL1=2.3557
[epoch 002] lr=6.91e-02 train=0.9396 evalL1=2.3013
[epoch 003] lr=1.91e-02 train=0.7401 evalL1=2.8518
[epoch 004] lr=0.00e+00 train=0.5997 evalL1=2.9267
Run dir: /content/xai_runs/cifar10_tap_demo/fc1c5f3c/20251031-200645
Artifacts: ['param_hists_epoch0002.npz', 'eval_sample_epoch0003.npz', 'tap_eval_epoch0001.npz', 'eval_sample_epoch0000.npz', 'model_summary.txt', 'param_hists_epoch0003.npz', 'tap_eval_epoch0000.npz', 'grad_norms_epoch0003.npz', 'param_hists_epoch0001.npz', 'eval_sample_epoch0001.npz', 'param_hists_epoch0000.npz', 'eval_sample_epoch0004.npz', 'grad_norms_epoch0004.npz', 'grad_norms_epoch0001.npz', 'tap_eval_epoch0003.npz', 'grad_norms_epoch0002.npz', 'grad_norms_epoch0000.npz', 'param_hists_epoch0004.npz', 'eval_sample_epoch0002.npz', 'tap_eval_epoch0004.npz', 'tap_eval_epoch0002.npz']
Logs: ['events.jsonl', 'epoch.jsonl']


In [None]:

# --- cell 5: analysis utilities ---
from glob import glob
import numpy as np

print("Available taps (class):", CIFARNet.available_taps())

# list_taps via a dry forward
dummy = torch.randn(2,3,32,32).to(device)
print("list_taps(*):", model.list_taps(dummy))

# Cheap capture during eval, then explicit CPU snapshot
model.eval()
with torch.no_grad():
    xb, yb = next(iter(eval_loader))
    xb = xb.to(device)
    with model.capture(device='keep') as cap:
        _ = model(xb)
    snap = cap.snapshot('cpu')

print("Captured keys:", list(snap.keys()))
for k, seq in list(snap.items())[:4]:
    if isinstance(seq, list) and seq and torch.is_tensor(seq[0]):
        print(f"  {k:10s}: {len(seq)} tensors; first shape={tuple(seq[0].shape)}")

# Inspect tap_eval artifact
tap_npzs = sorted(glob(os.path.join(cfg.run.run_dir, "artifacts", "tap_eval_epoch*.npz")))
print("tap_eval artifacts:", [os.path.basename(p) for p in tap_npzs][-3:])
if tap_npzs:
    npz = np.load(tap_npzs[-1], allow_pickle=True)
    print("npz keys:", list(npz.keys())[:8])
    k0 = list(npz.keys())[0]
    print(k0, "->", npz[k0].shape, npz[k0].dtype)

# Look at a few events
with open(os.path.join(cfg.run.run_dir, "logs", "events.jsonl"), "r") as f:
    lines = f.readlines()[-10:]
print("\nLast 10 events:")
for ln in lines:
    print(ln.strip())

# Show a couple histogram keys
hist_npzs = sorted(glob(os.path.join(cfg.run.run_dir, "artifacts", "param_hists_epoch*.npz")))
if hist_npzs:
    npz = np.load(hist_npzs[-1], allow_pickle=True)
    some = [k for k in npz.files if k.endswith("__hist")][:5]
    print("\nHistogram keys:", some)