In [2]:
import abc
import torch as t
import nnsight 

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
class Lens(abc.ABC, t.nn.Module):
    """Common interface shared by every lens.

    A lens stores an unembed operation and exposes two abstract hooks that
    concrete subclasses must implement: `transform_hidden` and `forward`.
    """

    def __init__(self, unembed: nnsight.module):
        """Store the unembed operation on the lens.

        Args:
            unembed: The unembed operation to use.
        """
        super().__init__()
        self.unembed = unembed

    @abc.abstractmethod
    def transform_hidden(self, h: t.Tensor, idx: int) -> t.Tensor:
        """Map a hidden state to the final hidden state fed to the unembedding.

        Args:
            h: The hidden state to convert.
            idx: The layer of the transformer these hidden states come from.
        """

    @abc.abstractmethod
    def forward(self, h: t.Tensor, idx: int) -> t.Tensor:
        """Turn hidden states into logits.

        Args:
            h: The hidden state to decode.
            idx: The layer of the transformer these hidden states come from.
        """

In [47]:
class LogitLens(Lens):
    """Unembeds the residual stream into logits."""

    def __init__(
        self,
        unembed: nnsight.module,
    ):
        """Create a Logit Lens.

        Args:
            unembed: The unembed operation to use.
        """
        super().__init__(unembed)

    @classmethod
    def from_model(
        cls,
        # Quoted forward reference: `PreTrainedModel` is not imported in this
        # notebook, so a bare annotation would raise NameError as soon as the
        # class body is executed on a fresh kernel.
        model: "PreTrainedModel",
    ) -> "LogitLens":
        """Create a LogitLens from a pretrained model.

        Args:
            model: A pretrained model from the transformers library you wish to inspect.

        Returns:
            A LogitLens wrapping ``Unembed(model)``.
        """
        # NOTE(review): `Unembed` is looked up at call time. The placeholder
        # `Unembed` defined later in this notebook declares no constructor that
        # accepts a model -- confirm the real implementation is in scope before
        # relying on this factory.
        unembed = Unembed(model)
        return cls(unembed)

    def transform_hidden(self, h: t.Tensor, idx: int) -> t.Tensor:
        """For the LogitLens, this is the identity function.

        Args:
            h: The hidden state to pass through.
            idx: Unused; accepted for interface compatibility with `Lens`.
        """
        del idx
        return h

    def forward(self, h: t.Tensor, idx: int) -> t.Tensor:
        """Decode a hidden state into logits.

        Args:
            h: The hidden state to decode.
            idx: the layer of the transformer these hidden states come from.
                Unused by the logit lens.
        """
        del idx
        return self.unembed.forward(h)

In [8]:
"""Load lens artifacts from the hub or locally storage."""
import os
from pathlib import Path
from typing import Optional
from typing import Set, Tuple

from huggingface_hub import HfFileSystem, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError


def available_lens_artifacts(
    repo_id: str,
    repo_type: str,
    revision: str = "main",
    config_file: str = "config.json",
    ckpt_file: str = "params.pt",
    subfolder: str = "lens",
) -> Set[str]:
    """List the lens folders on the hub that hold both a config and a checkpoint.

    Args:
        repo_id: The repository to search.
        repo_type: The repository type; an "s" is appended when it does not
            already end in one to form the leading path component.
        revision: The revision passed to the glob calls.
        config_file: The config file name to look for.
        ckpt_file: The checkpoint file name to look for.
        subfolder: The folder inside the repository to search under.

    Returns:
        The posix paths, relative to the search root, of every folder that
        contains both files.
    """
    filesystem = HfFileSystem()

    if repo_type.endswith("s"):
        plural_type = repo_type
    else:
        plural_type = repo_type + "s"

    root = Path(plural_type, repo_id, subfolder)

    def _folders_containing(file_name: str) -> Set[Path]:
        # Glob recursively below the root for the given file name and keep
        # only the directories in which it was found.
        pattern = (root / "**" / file_name).as_posix()
        hits = filesystem.glob(pattern, revision=revision)  # type: ignore
        return {Path(hit).parent for hit in hits}

    # The config glob is issued first, then the checkpoint glob.
    config_dirs = _folders_containing(config_file)
    ckpt_dirs = _folders_containing(ckpt_file)

    complete_dirs = ckpt_dirs & config_dirs
    return {folder.relative_to(root).as_posix() for folder in complete_dirs}


def load_lens_artifacts(
    resource_id: str,
    repo_id: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: str = "main",
    config_file: str = "config.json",
    ckpt_file: str = "params.pt",
    subfolder: str = "lens",
    cache_dir: Optional[str] = None,
) -> Tuple[Path, Path]:
    """First checks for lens resource locally then tries to download it from the hub.

    Args:
        resource_id: The id of the lens resource.
        repo_id: The repository to download the lens from. Defaults to
            'AlignmentResearch/tuned-lens'. However, this default can be overridden by
            setting the TUNED_LENS_REPO_ID environment variable.
        repo_type: The type of repository to download the lens from. Defaults to
            'space'. However, this default can be overridden by setting the
            TUNED_LENS_REPO_TYPE environment variable.
        revision: The revision of the lens to download.
        config_file: The name of the config file in the folder containing the lens.
        ckpt_file: The name of the checkpoint file in the folder containing the lens.
        subfolder: The subfolder of the repository to download the lens from.
        cache_dir: The directory to cache the lens in.

    Returns:
        * The path to the config.json file
        * The path to the params.pt file

    Raises:
        ValueError: if the lens resource could not be found.
    """
    # An unset or empty environment variable falls back to the built-in default.
    if repo_id is None:
        repo_id = os.environ.get("TUNED_LENS_REPO_ID") or "AlignmentResearch/tuned-lens"

    if repo_type is None:
        repo_type = os.environ.get("TUNED_LENS_REPO_TYPE") or "space"

    # First check if the resource id is a path to a folder that exists
    local_path = Path(resource_id)
    local_config = local_path / config_file
    local_ckpt = local_path / ckpt_file
    if local_config.exists() and local_ckpt.exists():
        return local_config, local_ckpt

    resource_folder = "/".join((subfolder, resource_id))
    try:
        params_path = hf_hub_download(
            filename=ckpt_file,
            repo_id=repo_id,
            repo_type=repo_type,
            revision=revision,
            subfolder=resource_folder,
            cache_dir=cache_dir,
        )

        config_path = hf_hub_download(
            filename=config_file,
            repo_id=repo_id,
            repo_type=repo_type,
            revision=revision,
            subfolder=resource_folder,
            cache_dir=cache_dir,
        )
    except EntryNotFoundError as err:
        available_lenses = available_lens_artifacts(
            repo_id=repo_id,
            repo_type=repo_type,
            revision=revision,
            config_file=config_file,
            ckpt_file=ckpt_file,
            subfolder=subfolder,
        )
        # The trailing space after "lens" keeps the two literal pieces from
        # fusing into "lensresources"; sorting makes the listing stable.
        message = (
            "Could not find lens at the specified resource id. Available lens "
            f"resources are: {', '.join(sorted(available_lenses))}"
        )
        # Chain the hub error so the missing entry stays visible in tracebacks.
        raise ValueError(message) from err

    # Defensive guard kept from the original in case a download yields None.
    if config_path is not None and params_path is not None:
        return Path(config_path), Path(params_path)

    raise ValueError("Could not find lens resource locally or on the hf hub.")

In [48]:
# Resolve the config and checkpoint paths for the "gpt2" lens resource.
# NOTE(review): `ckpt_pat` looks like a typo for `ckpt_path`; the later
# `t.load(ckpt_pat)` cell uses the same spelling, so rename both together.
config_path, ckpt_pat =load_lens_artifacts("gpt2")

import inspect
import logging

from typing import Dict, Optional
from dataclasses import dataclass, asdict
from copy import deepcopy

logger = logging.getLogger(__name__)

@dataclass
class TunedLensConfig:
    """Settings that describe a TunedLens.

    Attributes:
        base_model_name_or_path: Name of the base model this lens was tuned for.
        d_model: Hidden size of the base model.
        num_hidden_layers: Number of layers in the base model.
        bias: Whether the linear translators use a bias.
        base_model_revision: Revision of the base model this lens was tuned for.
        unembed_hash: Hash of the base model's unembed this lens was tuned for.
        lens_type: Name of the lens type.
    """

    base_model_name_or_path: str
    d_model: int
    num_hidden_layers: int
    bias: bool = True
    base_model_revision: Optional[str] = None
    unembed_hash: Optional[str] = None
    lens_type: str = "linear_tuned_lens"

    def to_dict(self):
        """Return the config fields as a plain dictionary."""
        return asdict(self)

    @classmethod
    def from_dict(cls, config_dict: Dict):
        """Build a config from a dictionary, warning about and dropping unknown keys.

        The input dictionary is deep-copied first, so it is never modified.
        """
        accepted = set(inspect.getfullargspec(cls).args)
        kwargs = deepcopy(config_dict)

        # Any key the constructor does not accept is logged and discarded.
        for key in set(kwargs) - accepted:
            logger.warning(f"Ignoring config key '{key}'")
            kwargs.pop(key)

        return cls(**kwargs)

In [61]:
class Unembed(t.nn.Module):
    """Placeholder unembed module with an empty body.

    It adds nothing on top of `t.nn.Module`; in this cell it only provides the
    `Unembed` name used in type annotations.
    NOTE(review): other cells call `Unembed(model)` and
    `unembed.unembedding_hash()`, neither of which is defined here -- confirm
    whether a real implementation is meant to replace this stub.
    """
    pass

class TunedLens(Lens):
    """A tuned lens for decoding hidden states into logits.

    Holds one linear translator per hidden layer. Each translator's output is
    added residually to the hidden state before the unembed is applied.
    """

    def __init__(
        self,
        unembed: Unembed,
        config: TunedLensConfig,
    ):
        """Create a TunedLens.

        Args:
            unembed: The unembed operation to use.
            config: The configuration for this lens.
        """
        super().__init__(unembed)

        self.config = config

        # NOTE(review): the dtype is fixed to float16 here. The disabled code
        # this replaces derived it from the unembedding weight (which might be
        # int8 under bitsandbytes) -- confirm float16 is always appropriate.
        translator = t.nn.Linear(
            config.d_model, config.d_model, bias=config.bias, dtype=t.float16
        )
        # Zero-initialised translators make the lens start out as the identity
        # because the translator output is added residually.
        translator.weight.data.zero_()
        # With `config.bias=False` the layer has no bias parameter (it is None),
        # so only zero it when it exists.
        if translator.bias is not None:
            translator.bias.data.zero_()

        # Don't include the final layer since it does not need a translator
        self.layer_translators = t.nn.ModuleList(
            [deepcopy(translator) for _ in range(self.config.num_hidden_layers)]
        )

    def forward(self, h: t.Tensor, idx: int) -> t.Tensor:
        """Transform and then decode the hidden states into logits.

        Args:
            h: The hidden state to decode.
            idx: The layer of the transformer these hidden states come from.
        """
        h = self.transform_hidden(h, idx)
        return self.unembed.forward(h)

    def transform_hidden(self, h: t.Tensor, idx: int) -> t.Tensor:
        """Transform hidden state from layer `idx`.

        Args:
            h: The hidden state to transform.
            idx: Index of the translator in `layer_translators` to apply.
        """
        # Note that we add the translator output residually, in contrast to the formula
        # in the paper. By parametrizing it this way we ensure that weight decay
        # regularizes the transform toward the identity, not the zero transformation.
        # Index the ModuleList directly: this class defines no `__getitem__`,
        # so `self[idx]` would raise a TypeError.
        return h + self.layer_translators[idx](h)

In [62]:
import json
# Parse the lens config JSON found at `config_path` into a TunedLensConfig.
with open(config_path, "r") as f:
            config = TunedLensConfig.from_dict(json.load(f))
# Build a lens without an unembed (None is passed), so only the translators
# and the config are usable on this instance.
test = TunedLens(None, config)
# Last expression of the cell: display the parsed config.
test.config

TunedLensConfig(base_model_name_or_path='gpt2', d_model=768, num_hidden_layers=12, bias=True, base_model_revision='e7da7f221d5bf496a48136c0cd264e630fe9fcc8', unembed_hash=None, lens_type='linear_tuned_lens')

In [69]:
# Load the lens checkpoint from the path resolved earlier (`ckpt_pat`).
# NOTE(review): depending on the torch version and defaults, `t.load` may
# unpickle arbitrary objects; only load checkpoints from trusted sources, or
# consider `weights_only=True` -- confirm.
state = t.load(ckpt_pat)

# Copy the loaded parameters into the lens' per-layer translators; the
# returned key-matching report is displayed as the cell output.
test.layer_translators.load_state_dict(state)

<All keys matched successfully>

In [18]:
# Create an nnsight LanguageModel for "gpt2".
# NOTE(review): the device is hard-coded to "cuda:0", so this cell presumably
# requires a CUDA GPU -- consider making it configurable.
model = nnsight.LanguageModel("gpt2", device_map = "cuda:0")

In [34]:
# Open an invoke context on the prompt with an empty body; nothing is read or
# saved from `invoker` here.
# NOTE(review): presumably a smoke test / warm-up run of the model -- confirm intent.
with model.invoke("The Space Needle is in the city of ") as invoker:
    pass

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [36]:
# Display the `lm_head` weight parameter of the wrapped local model (cell output).
model.local_model.lm_head.weight

Parameter containing:
tensor([[-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453],
        [ 0.0403, -0.0486,  0.0462,  ...,  0.0861,  0.0025,  0.0432],
        [-0.1275,  0.0479,  0.1841,  ...,  0.0899, -0.1297, -0.0879],
        ...,
        [-0.0445, -0.0548,  0.0123,  ...,  0.1044,  0.0978, -0.0695],
        [ 0.1860,  0.0167,  0.0461,  ..., -0.0963,  0.0785, -0.0225],
        [ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207]],
       device='cuda:0', requires_grad=True)

In [None]:
import inspect
import logging

from typing import Dict, Optional
from dataclasses import dataclass, asdict
from copy import deepcopy

@dataclass
class TunedLensConfig:
    """Settings that describe a TunedLens.

    Attributes:
        base_model_name_or_path: Name of the base model this lens was tuned for.
        d_model: Hidden size of the base model.
        num_hidden_layers: Number of layers in the base model.
        bias: Whether the linear translators use a bias.
        base_model_revision: Revision of the base model this lens was tuned for.
        unembed_hash: Hash of the base model's unembed this lens was tuned for.
        lens_type: Name of the lens type.
    """

    base_model_name_or_path: str
    d_model: int
    num_hidden_layers: int
    bias: bool = True
    base_model_revision: Optional[str] = None
    unembed_hash: Optional[str] = None
    lens_type: str = "linear_tuned_lens"

    def to_dict(self):
        """Return the config fields as a plain dictionary."""
        return asdict(self)

    @classmethod
    def from_dict(cls, config_dict: Dict):
        """Build a config from a dictionary, reporting and dropping unknown keys.

        The input dictionary is deep-copied first, so it is never modified.
        Unknown keys are reported with `print` in this variant of the class.
        """
        accepted = set(inspect.getfullargspec(cls).args)
        kwargs = deepcopy(config_dict)

        # Any key the constructor does not accept is printed and discarded.
        for key in set(kwargs) - accepted:
            print(f"Ignoring config key '{key}'")
            kwargs.pop(key)

        return cls(**kwargs)

In [None]:


def from_unembed_and_pretrained(
    cls,
    unembed: Unembed,
    lens_resource_id: str,
    **kwargs,
) -> "TunedLens":
    """Load a tuned lens from a folder or hugging face hub.

    Args:
        cls: The lens class to instantiate; called as ``cls(unembed, config)``.
        unembed: The unembed operation to use for the lens.
        lens_resource_id: The resource id of the lens to load.
        **kwargs: Additional arguments. Those named after a parameter of
            `load_lens_artifacts` are forwarded to it; all others are passed to
            `torch.load <https://pytorch.org/docs/stable/generated/torch.load.html>`_.

    Returns:
        A TunedLens instance.
    """
    # Split kwargs by the loader's declared parameters. `__code__.co_varnames`
    # would also list the loader's local variables (e.g. `local_path`), so a
    # kwarg with such a name would be wrongly routed to the loader.
    loader_params = set(inspect.signature(load_lens_artifacts).parameters)
    loader_kwargs = {k: v for k, v in kwargs.items() if k in loader_params}
    torch_load_kwargs = {k: v for k, v in kwargs.items() if k not in loader_params}

    # `load_lens_artifacts` is defined directly in this notebook; there is no
    # `load_artifacts` module in scope here.
    config_path, ckpt_path = load_lens_artifacts(
        resource_id=lens_resource_id,
        **loader_kwargs,
    )

    with open(config_path, "r") as f:
        config = TunedLensConfig.from_dict(json.load(f))

    # validate the unembed is the same as the one used to train the lens
    if config.unembed_hash and unembed.unembedding_hash() != config.unembed_hash:
        logger.warning(
            "The unembedding matrix hash does not match the lens' hash. "
            "This lens may have been trained with a different unembedding."
        )

    # Create the lens
    lens = cls(unembed, config)

    # Load parameters; torch is imported as `t` in this notebook.
    state = t.load(ckpt_path, **torch_load_kwargs)

    lens.layer_translators.load_state_dict(state)

    return lens