<a href="https://colab.research.google.com/github/klutzydrummer/Python-ML_Self_Learning/blob/main/Better_PaCE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## My attempt at implementing and experimenting with the concepts from the paper *Parsimonious Concept Engineering (PaCE)*
Github: [Link](https://github.com/peterljq/Parsimonious-Concept-Engineering)  
Arxiv: [Link](https://arxiv.org/abs/2406.04331)  
Project Page: [Link](https://peterljq.github.io/project/pace/index.html)  

In [None]:
# @title ColabBuddy v2.2
import importlib
import json
import os
import re
import shlex
import shutil
import subprocess
import sys
from functools import lru_cache
from pathlib import Path
from pprint import pprint
from typing import Callable, Dict, List, Literal, Optional, Tuple

import packaging.version
import pkg_resources
import requests
from google.colab import drive

class BaseBuddy:
    def __init__(self):
        pass

    @staticmethod
    def run_command(command: str, background: bool = False):
        if isinstance(command, list):
            command = ' '.join(command)
        if background:
            command = f'nohup bash -c "{command}"'
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True,
            universal_newlines=True,
            bufsize=1  # Line-buffered mode for real-time output
        )

        while True:
            output = process.stdout.readline()
            error = process.stderr.readline()

            if output:
                print(output, end="")
                sys.stdout.flush()  # Flush the output immediately
            if error:
                print(error, end="")
                sys.stderr.flush()  # Flush the error immediately

            # Break the loop if both output and error are empty and the process has ended
            if not output and not error and process.poll() is not None:
                break

        return process.poll()

class GitHubRepo:
    """A local clone of a Git repository with dynamically generated git commands.

    Attribute names that are not defined on the class are turned into git
    invocations run inside the clone. Underscores separate command words, a
    trailing ``all`` becomes ``-a``, and ``commit`` adds ``-m`` right before the
    positional arguments. For example, ``repo.commit_all("msg")`` runs
    ``git commit -a -m msg`` and ``repo.status()`` runs ``git status``.
    """

    # Last word of a dynamic method name -> git flag it stands for.
    _PREMAPPED_FLAGS = {"all": "-a"}
    # Command word -> flags inserted after all words, just before the positional
    # args (so 'commit_all' yields 'commit -a -m <msg>', not 'commit -m -a <msg>').
    _PREMAPPED_COMMANDS = {"commit": ["-m"]}

    def __init__(self, url: str, path: Path):
        """
        Args:
            url: Clone URL. The repo name is its last path segment minus any
                trailing '/' or '.git' suffix (dots inside the name are kept).
            path: Parent directory; the clone lives at ``path / name``.
        """
        self.name = url.rstrip('/').split('/')[-1].removesuffix('.git')
        self.url = url
        self.path = Path(path) / self.name

    def clone(self):
        """Run ``git clone`` into ``self.path`` unless that directory already exists.

        The git exit status is not checked here.
        """
        if not self.path.exists():
            subprocess.run(["git", "clone", self.url, str(self.path)])

    def cd(self):
        """Make the clone the process's current working directory."""
        os.chdir(self.path)
        print(f"Changed directory to {self.path}")

    @classmethod
    def _build_git_command(cls, name: str, args=()) -> List[str]:
        """Translate a dynamic method name and positional args into a git argv list."""
        words = name.split("_")
        words[-1] = cls._PREMAPPED_FLAGS.get(words[-1], words[-1])
        implied_flags = [flag for word in words for flag in cls._PREMAPPED_COMMANDS.get(word, [])]
        return ["git", *words, *implied_flags, *args]

    def __getattr__(self, name):
        """Return a callable that runs the git command encoded by ``name`` in the clone.

        The callable prints and returns the command's stdout. On a non-zero exit
        it prints the error and returns None. After the first successful call it
        is cached on the instance.
        """
        # __getattr__ only runs when normal lookup fails. Refusing private/dunder
        # names keeps hasattr(), copy/pickle and IPython display hooks from
        # spawning git. Refusing the core fields avoids infinite recursion on a
        # partially constructed instance.
        if name.startswith("_") or name in ("name", "url", "path"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

        def dynamic_method(*args, **kwargs):
            # kwargs are accepted for call compatibility but ignored.
            command_with_args = self._build_git_command(name, args)
            try:
                result = subprocess.run(command_with_args, cwd=self.path, check=True,
                                        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            except subprocess.CalledProcessError as e:
                print(f"Failed to run {name} with error: {e.stderr.decode()}")
                return None
            output = result.stdout.decode()
            print(output)
            # Cache the working command on the instance so later calls skip
            # __getattr__ and behave exactly like this one.
            setattr(self, name, dynamic_method)
            return output

        return dynamic_method

class RepoMan:
    """A registry of GitHubRepo objects keyed by name.

    Repos can be accessed as attributes (``man.foo``), items (``man["foo"]``)
    or via ``get``.
    """

    def __init__(self):
        # Goes through __setattr__'s special case for 'repos' (no type check).
        self.repos = {}

    def __setattr__(self, name, value):
        """Store GitHubRepo values in the registry; only 'repos' is a real attribute."""
        if name == 'repos':
            # Allow setting the repos attribute without type check
            super().__setattr__(name, value)
        elif isinstance(value, GitHubRepo):
            self.repos[name] = value
        else:
            raise TypeError(f"Value for {name} must be a GitHubRepo instance.")

    def __getattr__(self, name):
        """Look up a repo by attribute name; raise AttributeError if it is unknown."""
        # Read the registry through __dict__: if 'repos' has not been set yet
        # (e.g. an instance built by copy/pickle without __init__), self.repos
        # would re-enter __getattr__ and recurse forever.
        repos = self.__dict__.get('repos', {})
        if name not in repos:
            raise AttributeError(f"No repo named {name} found.")
        return repos[name]

    def __getitem__(self, name):
        if name not in self.repos:
            raise KeyError(f"No repo named {name} found.")
        return self.repos[name]

    def __setitem__(self, name, value):
        if isinstance(value, GitHubRepo):
            self.repos[name] = value
        else:
            raise TypeError(f"Value for {name} must be a GitHubRepo instance.")

    def __contains__(self, name):
        return name in self.repos

    def __iter__(self):
        return iter(self.repos)

    def __len__(self):
        return len(self.repos)

    def __str__(self):
        """JSON summary of each registered repo's URL and local path."""
        # GitHubRepo objects are not JSON-serialisable, so summarise their fields.
        summary = {name: {"url": repo.url, "path": str(repo.path)} for name, repo in self.repos.items()}
        return json.dumps(summary)

    def __repr__(self):
        return repr(self.repos)

    def __delattr__(self, name):
        if name in self.repos:
            del self.repos[name]
        else:
            raise AttributeError(f"No repo named {name} found.")

    def __delitem__(self, name):
        if name in self.repos:
            del self.repos[name]
        else:
            raise KeyError(f"No repo named {name} found.")

    def add_repo(self, repo: GitHubRepo):
        """Register repo under its own ``name``."""
        self.repos[repo.name] = repo

    def get(self, name: str, default=None):
        """Return the repo called name, or default if it is not registered."""
        return self.repos.get(name, default)

class RepoBuddy:
    """Clones GitHub repositories and tracks them in a RepoMan registry."""

    def __init__(self):
        self.repo_manager = RepoMan()

    def git_clone(self, url, directory=None) -> GitHubRepo:
        """Clone url into directory (default: the current working directory) and register it.

        Returns:
            GitHubRepo: The cloned (or already present) repository.

        Raises:
            Exception: If cloning raised, or the clone directory does not exist afterwards.
        """
        directory = Path(os.getcwd()) if directory is None else Path(directory)
        repo = GitHubRepo(url=url, path=directory)
        try:
            repo.clone()
        except Exception as err:
            raise Exception("Failed to clone repo: {repo}.".format(repo=repo.name)) from err
        # clone() does not check git's exit status, so a bad URL or network error
        # would otherwise be reported as success. Verify the checkout exists.
        if not repo.path.exists():
            raise Exception("Failed to clone repo: {repo}.".format(repo=repo.name))
        print("Succesfully cloned repo: {repo}.".format(repo=repo.name))
        self.repo_manager.add_repo(repo)
        return repo

    def get_repo(self, name: str, default=None) -> Optional[GitHubRepo]:
        """Return the registered repo called name, or default."""
        return self.repo_manager.get(name, default)


class PacManBuddy(BaseBuddy):
    """Version parsing and comparison helpers shared by the package-manager buddies."""

    def __init__(self):
        pass

    @classmethod
    def parse_version_safe(cls, version):
        """
        Safely parse a version string, handling wildcard versions by returning
        the prefix as a parsed version or None if the version is invalid.

        Returns:
            tuple: (parsed version or None, True if the input contained '*').
        """
        wildcard_flag = '*' in version
        if wildcard_flag:
            # '1.2.*' -> '1.2': keep only the part before the wildcard.
            version = version.split('*')[0].rstrip('.')
        try:
            return packaging.version.parse(version), wildcard_flag
        except packaging.version.InvalidVersion:
            return None, wildcard_flag

    @classmethod
    def compare_versions(cls, current_version, target_version, operator):
        """
        Return True if current_version satisfies ``<operator> target_version``.

        Args:
            current_version: Parsed installed version.
            target_version: Parsed requested version.
            operator: One of ==, !=, <, <=, >, >=, ~=.

        Raises:
            ValueError: For any other operator.
        """
        if operator == "==":
            return current_version == target_version
        elif operator == "!=":
            return current_version != target_version
        elif operator == "<":
            return current_version < target_version
        elif operator == "<=":
            return current_version <= target_version
        elif operator == ">":
            return current_version > target_version
        elif operator == ">=":
            return current_version >= target_version
        elif operator == "~=":
            # PEP 440 compatible release: '~=X.Y.Z' means '>= X.Y.Z, == X.Y.*'.
            # Every release segment except the last must match (at least one is kept).
            target_release = target_version.release
            prefix = target_release[:max(len(target_release) - 1, 1)]
            # Pad with zeros so e.g. '2' compares like '2.0'.
            padding = (0,) * max(0, len(prefix) - len(current_version.release))
            current_release = current_version.release + padding
            return current_version >= target_version and current_release[:len(prefix)] == prefix
        else:
            raise ValueError(f"Unsupported version operator: {operator}")

class PipBuddy(PacManBuddy):
    """Run ``pip install`` only for packages whose installed version does not satisfy the request."""

    def __init__(self):
        super().__init__()
        # package name -> available PyPI versions (newest first), filled lazily
        self.available_pip_pkg_versions = {}

    @property
    def pip_pkg_versions(self):
        """Map of installed pkg_resources key -> installed version string."""
        return {pkg.key: pkg.version for pkg in pkg_resources.working_set}

    @staticmethod
    def _normalize_pkg_name(pkg: str) -> str:
        """Normalise a name the way pkg_resources builds ``key`` (safe_name, then lower())."""
        return re.sub(r'[^A-Za-z0-9.]+', '-', pkg).lower()

    def get_pip_pkg_version(self, pkg):
        """Return the installed version of pkg, or None if it is not installed."""
        # Normalise so that e.g. 'Pillow' or 'typing_extensions' match their keys.
        return self.pip_pkg_versions.get(self._normalize_pkg_name(pkg), None)

    @staticmethod
    @lru_cache(maxsize=128)
    def _get_available_pip_package_versions(package):
        """Fetch every released version of package from PyPI, newest first.

        Version strings that are not valid PEP 440 versions are skipped, since
        they cannot be ordered.

        Raises:
            ValueError: If PyPI does not answer with HTTP 200.
        """
        # Timeout so that an unreachable PyPI cannot hang the notebook cell.
        response = requests.get(f"https://pypi.org/pypi/{package}/json", timeout=30)
        if response.status_code != 200:
            raise ValueError(f"Error fetching package data: {response.status_code}")

        parsed = []
        for version in response.json()["releases"]:
            try:
                parsed.append((packaging.version.parse(version), version))
            except packaging.version.InvalidVersion:
                continue
        return [version for _, version in sorted(parsed, reverse=True)]

    def get_available_pip_package_versions(self, package):
        """Return PyPI versions of package (newest first), cached per instance."""
        if package not in self.available_pip_pkg_versions:
            self.available_pip_pkg_versions[package] = self._get_available_pip_package_versions(package)
        return self.available_pip_pkg_versions[package]

    def get_latest_pip_package_versions(self, package):
        """Return the newest stable PyPI version, or the newest of any kind if none is stable.

        Raises:
            ValueError: If PyPI lists no usable versions.
        """
        versions = self.get_available_pip_package_versions(package)
        if not versions:
            raise ValueError(f"No released versions found for {package}")
        # Mirror pip's default of not choosing pre-releases.
        stable = (v for v in versions if not packaging.version.parse(v).is_prerelease)
        return next(stable, versions[0])

    @staticmethod
    def parse_pip_package_string(command: str) -> Dict[str, Dict[str, Optional[str]]]:
        """
        Parse a pip install command to extract package names and their version specifications.

        Each whitespace-separated token is one requirement: a name, optional
        ``[extras]`` (dropped), and an optional version specifier. Only the first
        of several comma-separated specifiers is kept. Tokens starting with '-'
        (pip options) and tokens that are not plain requirements (URLs, paths)
        are ignored.

        Args:
            command (str): The pip install command string.

        Returns:
            dict: A dictionary where the key is the package name, and the value is a dictionary
                with 'version_operator' and 'target_version' keys. If the version is not specified,
                they will be None.

        Raises:
            ValueError: If command does not start with 'pip install'.
        """
        if not command.startswith('pip install'):
            raise ValueError("The command must start with 'pip install'")

        package_string = command[len('pip install'):].strip()
        package_pattern = re.compile(
            r'([A-Za-z0-9][A-Za-z0-9._\-]*)'               # name (dots allowed, e.g. zope.interface)
            r'(?:\[[^\]]*\])?'                              # optional extras
            r'(?:([=<>!~]+)([0-9A-Za-z.*+]+)'               # first specifier (captured)
            r'(?:,[=<>!~]+[0-9A-Za-z.*+]+)*)?'              # further specifiers (ignored)
        )

        parsed_packages = {}
        for token in package_string.split():
            token = token.strip('"\'')
            if not token or token.startswith('-'):
                continue  # pip options such as -q / -U / --upgrade are not packages
            match = package_pattern.fullmatch(token)
            if match is None:
                continue
            name, version_operator, version = match.groups()
            parsed_packages[name] = {
                'version_operator': version_operator or None,
                'target_version': version or None,
            }
        return parsed_packages

    def pip_update_check(self, pkg: str, version_operator: Optional[str], target_version: Optional[str]):
        """
        Decide whether pkg must be (re)installed to satisfy ``<version_operator><target_version>``.

        Args:
            pkg (str): The package name.
            version_operator (str | None): ==, !=, <, <=, >, >= or ~= (None means ==).
            target_version (str | None): Requested version, possibly with a '.*' wildcard.
                None or 'latest' means the latest stable PyPI release.

        Returns:
            bool: True if the package is missing or its installed version does not
                satisfy the requirement, False otherwise.

        Raises:
            ValueError: If a version cannot be parsed or the operator is unsupported.
        """
        current_version = self.get_pip_pkg_version(pkg)
        if current_version is None:
            # Not installed at all.
            return True

        if target_version is None or str(target_version).lower() == 'latest':
            # Only query PyPI when the caller did not pin a version.
            target_version = self.get_latest_pip_package_versions(pkg)
        if version_operator is None:
            version_operator = '=='

        current_version_parsed, _ = self.parse_version_safe(current_version)
        target_version_parsed, wildcard_flag = self.parse_version_safe(target_version)
        if current_version_parsed is None or target_version_parsed is None:
            raise ValueError(f"Cannot compare versions: {current_version} and {target_version}")

        if wildcard_flag and version_operator in ("==", "!="):
            # Compare release segments, so '1.2.*' matches 1.2.3 but not 1.20.0.
            prefix = target_version_parsed.release
            padding = (0,) * max(0, len(prefix) - len(current_version_parsed.release))
            matches = (current_version_parsed.release + padding)[:len(prefix)] == prefix
            satisfied = matches if version_operator == "==" else not matches
        else:
            satisfied = self.compare_versions(current_version_parsed, target_version_parsed, version_operator)

        # An update is needed exactly when the installed version does not satisfy the requirement.
        return not satisfied

    def _make_pip_string(self, pkg: str, version_operator: Optional[str], target_version: Optional[str]):
        """Render a requirement string; defaults to ``==<latest stable PyPI version>``."""
        if version_operator is None:
            version_operator = '=='
        if target_version is None or str(target_version).lower() == 'latest':
            target_version = self.get_latest_pip_package_versions(pkg)
        return f"{pkg}{version_operator}{target_version}"

    def reload(self):
        """Reload pkg_resources so its working set reflects newly installed packages."""
        importlib.reload(pkg_resources)

    def run_pip_install(self, command: str, background: bool=False):
        """Install only the packages from command that are missing or out of spec.

        Args:
            command: A 'pip install ...' command string. Options in it are not forwarded.
            background: Passed through to run_command.
        """
        package_list = self.parse_pip_package_string(command)
        to_install = [
            self._make_pip_string(pkg, **version_info)
            for pkg, version_info in package_list.items()
            if self.pip_update_check(pkg=pkg, **version_info)
        ]

        print(to_install)
        if not to_install:
            return
        # Quote each requirement: '>' and '<' would otherwise be shell redirections.
        filtered_packages = " ".join(shlex.quote(requirement) for requirement in to_install)
        filtered_command = f"pip install {filtered_packages}"
        self.run_command(filtered_command, background=background)
        self.reload()

class AptBuddy(PacManBuddy):
    """Run ``apt install`` only for packages dpkg does not report as installed."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def is_apt_package_installed(package_name):
        """Return True if dpkg-query reports package_name as 'install ok installed'."""
        try:
            result = subprocess.run(
                ['dpkg-query', '-W', '-f=${Status}', package_name],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                check=True,
                text=True
            )
        except subprocess.CalledProcessError:
            # dpkg-query exits non-zero for packages it does not know about.
            return False
        return 'install ok installed' in result.stdout

    @staticmethod
    def parse_apt_package_string(command):
        """
        Parse an apt install command to extract package names.

        Args:
            command (str): The apt install command string.

        Returns:
            tuple(list: A list of package names, list: A list of arguments.)

        Raises:
            ValueError: If the command lacks 'apt'/'apt-get' or the 'install' word.
        """
        # split() (not split(" ")) so repeated spaces do not yield empty "packages".
        command_pieces = command.split()
        has_apt = any(piece in ('apt', 'apt-get') for piece in command_pieces)
        if not has_apt or 'install' not in command_pieces:
            raise ValueError("The command must start be an 'apt install' or 'apt-get install' command.")

        args = [piece for piece in command_pieces if piece.startswith('-')]
        packages = [
            piece for piece in command_pieces
            if not piece.startswith('-') and piece not in ('apt', 'apt-get', 'install')
        ]
        return packages, args

    def run_apt_install(self, command: str, background: bool=False):
        """
        Run apt install command to install packages.

        Args:
            command (str): The apt install command string.
            background (bool): Passed through to run_command.
        """
        packages, args = self.parse_apt_package_string(command)
        to_install = [package for package in packages if not self.is_apt_package_installed(package)]
        if not to_install:
            return
        command = " ".join(["apt", "install", *args, *to_install])
        self.run_command(command, background=background)

class CloudFlaredBuddy:
    """Install cloudflared and (re)register it as a service for a Cloudflare tunnel.

    Marker files under ``path`` record completed steps, so re-running a notebook
    cell does not repeat the install.
    """

    # Latest amd64 Debian package of cloudflared.
    _DEB_URL = "https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb"

    def __init__(self, path='/content'):
        self.path = Path(path)
        self.cloudflare_prereq_flag = None  # marker file: cloudflared package installed
        self.cloudflag_flag = None          # marker file: tunnel service installed

    def cloudflared_init(self, flag_name='cloudflare_prereq_flag'):
        """Download and install cloudflared once, tracked by a marker file in ``path``.

        The marker is only created when both the download and ``dpkg -i`` succeed,
        so a failed attempt is retried on the next call.
        """
        cloudflare_prereq_flag = self.path / flag_name
        if not cloudflare_prereq_flag.exists():
            self.path.mkdir(parents=True, exist_ok=True)
            deb_path = self.path / 'cloudflared.deb'
            # --fail makes curl exit non-zero on HTTP errors instead of saving the error page.
            download = subprocess.run(["curl", "-L", "--fail", "--output", str(deb_path), self._DEB_URL])
            installed = (download.returncode == 0
                         and subprocess.run(["dpkg", "-i", str(deb_path)]).returncode == 0)
            if installed:
                cloudflare_prereq_flag.touch()
            else:
                print("cloudflared installation failed; it will be retried on the next call.")
        self.cloudflare_prereq_flag = cloudflare_prereq_flag
        self.cloudflag_flag = cloudflare_prereq_flag.parent / 'cloudflag_flag'

    def cloudflare_tunnel(self, cloudflare_token=''):
        """(Re)install the cloudflared service with a tunnel token.

        Args:
            cloudflare_token: Tunnel token. '' falls back to the CLOUDFLARE_TOKEN
                environment variable. If neither provides a token, any previously
                installed service is uninstalled and nothing new is installed.
        """
        if cloudflare_token == '':
            cloudflare_token = os.environ.get("CLOUDFLARE_TOKEN", '')
        if self.cloudflare_prereq_flag is None:
            self.cloudflared_init()
        # Remove a previously installed service so the (new) token takes effect.
        if self.cloudflag_flag.exists():
            subprocess.run(["cloudflared", "service", "uninstall"])
            os.remove(self.cloudflag_flag)
        if cloudflare_token != "":
            result = subprocess.run(["cloudflared", "service", "install", cloudflare_token])
            if result.returncode == 0:
                self.cloudflag_flag.touch()

class PathBuddy:
    """Edit this process's PATH environment variable, with a reset to its initial value."""

    def __init__(self):
        # PATH entries captured at construction, restored by reset_path().
        self.init_path = self.get_path()

    @staticmethod
    def add_to_path(path, prepend=False):
        """Add path (after ~ and $VAR expansion) to PATH, moving it if already present.

        Args:
            path: Directory to add.
            prepend: Put it first instead of last.
        """
        path = os.path.expandvars(os.path.expanduser(path))
        # Drop any existing occurrence so the entry ends up exactly once, at the requested end.
        entries = [entry for entry in PathBuddy.get_path() if entry != path]
        if prepend:
            entries.insert(0, path)
        else:
            entries.append(path)
        os.environ["PATH"] = ":".join(entries)

    @staticmethod
    def remove_path(path, prepend=False):
        """Remove every occurrence of path (after ~ and $VAR expansion) from PATH.

        ``prepend`` is accepted for signature symmetry and ignored.
        """
        path = os.path.expandvars(os.path.expanduser(path))
        os.environ["PATH"] = ":".join(entry for entry in PathBuddy.get_path() if entry != path)

    def reset_path(self):
        """Restore PATH to the entries captured when this object was created."""
        os.environ["PATH"] = ":".join(self.init_path)

    @staticmethod
    def get_path() -> List[str]:
        """Return PATH split on ':', or [] when PATH is unset or empty.

        An empty PATH must not become [''], because an empty entry means the current directory.
        """
        value = os.environ.get("PATH", "")
        return value.split(":") if value else []

class PythonPathBuddy:
    """Edit this process's PYTHONPATH environment variable, with a reset to its initial value.

    Changes affect subprocesses started afterwards, not the running interpreter's sys.path.
    """

    def __init__(self):
        # PYTHONPATH entries captured at construction, restored by reset_pythonpath().
        self.init_pythonpath = self.get_pythonpath()

    @staticmethod
    def add_to_pythonpath(path, prepend=False):
        """Add path (after ~ and $VAR expansion) to PYTHONPATH, moving it if already present.

        Args:
            path: Directory to add.
            prepend: Put it first instead of last.
        """
        path = os.path.expandvars(os.path.expanduser(path))
        # Drop any existing occurrence so the entry ends up exactly once, at the requested end.
        entries = [entry for entry in PythonPathBuddy.get_pythonpath() if entry != path]
        if prepend:
            entries.insert(0, path)
        else:
            entries.append(path)
        os.environ["PYTHONPATH"] = ":".join(entries)

    @staticmethod
    def remove_pythonpath(path, prepend=False):
        """Remove every occurrence of path (after ~ and $VAR expansion) from PYTHONPATH.

        ``prepend`` is accepted for signature symmetry and ignored.
        """
        path = os.path.expandvars(os.path.expanduser(path))
        os.environ["PYTHONPATH"] = ":".join(
            entry for entry in PythonPathBuddy.get_pythonpath() if entry != path
        )

    def reset_pythonpath(self):
        """Restore PYTHONPATH to the entries captured when this object was created."""
        # The snapshot lives in init_pythonpath (there is no init_path on this class).
        os.environ["PYTHONPATH"] = ":".join(self.init_pythonpath)

    @staticmethod
    def get_pythonpath() -> List[str]:
        """Return PYTHONPATH split on ':', or [] when it is unset or empty."""
        value = os.environ.get("PYTHONPATH", "")
        return value.split(":") if value else []

class FileSystemBuddy(BaseBuddy):
    """Thin wrappers around common shell file operations: unzip, chmod, rm, cd."""

    def __init__(self):
        pass

    @classmethod
    def unzip(cls, source: Path | str, destination: Path | None = None):
        """Extract a zip archive unless destination already exists.

        Args:
            source: The .zip file.
            destination: Expected result path; defaults to ``source.parent / source.stem``.

        NOTE(review): the archive is extracted into ``destination.parent``, which
        presumably assumes the zip holds a top-level folder named
        ``destination.name``. Confirm this against the archives used.
        """
        source = Path(source)
        destination = source.parent / source.stem if destination is None else Path(destination)
        if destination.exists():
            return
        # Quote paths: Google Drive paths commonly contain spaces.
        command = " ".join(["unzip", "-q", shlex.quote(str(source)), "-d", shlex.quote(str(destination.parent))])
        return_code = cls.run_command(command)
        if return_code == 0:
            print('Unzipped {source} to {destination}'.format(source=str(source.name), destination=str(destination.name)))
        else:
            print(f'Failed to unzip {source.name} (exit code {return_code})')

    @classmethod
    def chmod(cls, args: List[str] | str, path: Path | str):
        """Run ``chmod <args> <path>``; args may be a list or a space-separated string.

        Raises:
            FileNotFoundError: If path does not exist.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {path}")
        if isinstance(args, str):
            args = args.split(' ')
        quoted = [shlex.quote(arg) for arg in args]
        final_command = " ".join(["chmod", *quoted, shlex.quote(str(path.absolute()))])
        cls.run_command(final_command)

    @staticmethod
    def rm(path: Path | str, recursive=False):
        """Delete a file, or a directory tree when recursive=True (``rm -rf``).

        Raises:
            FileNotFoundError: If path does not exist.
            ValueError: If path is a directory and recursive is False.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {path}")
        if not recursive and path.is_dir():
            raise ValueError(f"{path} is a directory. Use recursive=True to remove it.")
        command_stack = ["rm", "-rf", str(path.absolute())] if recursive else ["rm", str(path.absolute())]
        result = subprocess.run(command_stack)
        if result.returncode != 0:
            print(f'Failed to remove {path} (exit code {result.returncode})')
            return
        print('Removed {target} from {parent}'.format(target=str(path.name), parent=str(path.parent.absolute())))

    @staticmethod
    def cd(path: Path | str):
        """Change the process's current working directory to path."""
        os.chdir(path)
        print('Changed directory to {path}'.format(path=str(path)))

class LocalRemoteFilePair:
    def __init__(self, local_path: Path | str, remote_path: Path | str):
        self.local_path = Path(local_path)
        self.remote_path = Path(remote_path)

    @property
    def local(self):
        return self.local_path.absolute()

    @property
    def remote(self):
        return self.remote_path.absolute()

    def local_exists(self):
        return self.local_path.exists()

    def remote_exists(self):
        return self.remote_path.exists()

    def get_local_path(self):
        return self.local_path

    def get_remote_path(self):
        return self.remote_path

    @staticmethod
    def copy(source, destination):
        source = Path(source).absolute()
        destination = Path(destination).absolute()
        if source.is_dir():
            shutil.copytree(str(source), str(destination))  # Copy directory
        else:
            shutil.copy2(str(source), str(destination))  # Copy file

    def copy_to_local(self):
        if self.remote_exists():
            self.copy(self.remote, self.local)
        else:
            raise FileNotFoundError(f"File not found: {self.remote_path}")

    def copy_to_remote(self):
        if self.local_exists():
            self.copy(self.local, self.remote)
        else:
            raise FileNotFoundError(f"File not found: {self.local_path}")

    def remove_local(self):
        if self.local_exists():
            self.local_path.unlink()
        else:
            raise FileNotFoundError(f"File not found: {self.local_path}")

    def remove_remote(self):
        if self.remote_exists():
            self.remote_path.unlink()
        else:
            raise FileNotFoundError(f"File not found: {self.remote_path}")

class GDriveBuddy:
    """Mount or unmount Google Drive in Colab and build local/remote file pairs."""

    def __init__(self, mount_path: Path | str = '/content/drive'):
        self.mount_path = Path(mount_path)

    @staticmethod
    def connect_gdrive(mount_path: Path | str = '/content/drive'):
        """Mount Google Drive at mount_path unless it already appears mounted.

        Returns:
            Path: ``mount_path / 'MyDrive'``.
        """
        mount_path = Path(mount_path)
        my_drive = mount_path / 'MyDrive'
        # Test MyDrive rather than the mount point itself. The mount directory may
        # presumably be left behind after flush_and_unmount(), which would
        # otherwise skip remounting.
        if not my_drive.exists():
            drive.mount(str(mount_path), force_remount=True, readonly=False)
        return my_drive

    @staticmethod
    def disconnect_gdrive():
        """Flush pending writes and unmount Google Drive."""
        drive.flush_and_unmount()
        print('GDrive disconnected')

    def file_pair(self, local_path: Path | str, remote_path: Path | str):
        """Return a LocalRemoteFilePair for the two paths."""
        return LocalRemoteFilePair(local_path, remote_path)

class DownloadBuddy(BaseBuddy):
    def __init__(self):
        pass

    @classmethod
    def download_curl(cls, url: str, path: Path | str, *args: str, overwrite=False, follow_redirect=True, **kwargs):
        path = Path(path)
        redirect_flag_args = False
        for arg in args:
            dashless_arg = arg.replace('-', '')
            if dashless_arg == 'L' or dashless_arg == 'location':
                redirect_flag_args = True
                break
        else:
            redirect_flag_args = False

        if redirect_flag_args is not True and follow_redirect is not False:
            args = ['-L', *args]

        def remove_leading_dashes(arg):
            while arg.startswith('-'):
                arg = arg[1:]
            return arg

        def arg_processor(arg):
            if len(arg) == 1:
                arg = '-' + arg
            else:
                arg = '--' + arg
            return arg

        def kwargs_processor(key, value):
            match(value):
                case bool():
                    value = str(value).lower()
                case str():
                    value = value
                case Path():
                    value = str(value.absolute())
                case _:
                    value = str(value)
            return "--{key} {value}".format(key=key, value=value)
        args = [remove_leading_dashes(arg) for arg in args]
        args = [arg_processor(arg) for arg in args]
        kwargs = [kwargs_processor(key, value) for key, value in kwargs.items()]
        if path.exists() is not True or overwrite is True:
            path.mkdir(parents=True, exist_ok=True)
            command = " ".join(["curl", *args, *kwargs, "--output", str(path), url])
            cls.run_command(command)
        else:
            print(f"File already exists: {path}")

    @classmethod
    def download_aria2(cls, url: str, path: Path | str, *args: str, max_connections_per_server: int=4, max_concurrent_downloads: int=4, overwrite=False, **kwargs):
        path = Path(path)

        def arg_preprocessor(arg):
            while arg.startswith('-'):
                arg = arg[1:]
            return arg

        def arg_processor(arg):
            if len(arg) == 1:
                arg = '-' + arg
            else:
                arg = '--' + arg
            return arg

        def kwargs_processor(key, value):
            key = arg_preprocessor(key)
            match(value):
                case bool():
                    value = str(value).lower()
                case str():
                    value = value
                case Path():
                    value = str(value.absolute())
                case _:
                    value = str(value)
            return "--{key}={value}".format(key=key, value=value)

        default_args = [
            "max-connection-per-server",
            "max-concurrent-downloads",
            "dir",
            "out"
        ]

        def default_command_present(command_piece):
            if any(arg == command_piece for arg in default_args):
                return True
            else:
                return False

        args = [arg_preprocessor(arg) for arg in args]
        args = [arg_processor(arg) for arg in args if default_command_present(arg) is not True]
        kwargs = [kwargs_processor(key, value) for key, value in kwargs.items() if default_command_present(key) is not True]

        command_pieces = [
            "aria2c",
            *args,
            *kwargs,
            f"--max-connection-per-server={max_connections_per_server}",
            f"--max-concurrent-downloads={max_concurrent_downloads}",
            f"--dir={path.parent}",
            f"--out={path.name}",
            f"{url}"
        ]
        command = " ".join(command_pieces)
        if path.exists() is not True or overwrite is True:
            print(command)
            path.parent.mkdir(parents=True, exist_ok=True)
            cls.run_command(command)
        else:
            print(f"File already exists: {path}")

class MatrixBuddy:
    """Helpers for printing nested containers whose tensor-like leaves are shown as shapes."""

    def __init__(self):
        pass

    @staticmethod
    def stringify(item):
        """
        Convert a tensor-like object to its shape string, or stringify any other item.

        Args:
        item: The value to stringify. Anything exposing a ``shape`` attribute (e.g. a torch
            tensor or numpy array) is rendered as ``str(item.shape)``; if ``shape`` is
            callable it is called first.

        Returns:
        str: The shape representation, or ``str(item)`` for everything else.
        """
        if "shape" in dir(item):
            shape = item.shape
            return str(shape() if callable(shape) else shape)
        return str(item)

    @classmethod
    def shapify_strings(cls, iterable, indent=0, last=True):
        """
        Recursively replace the leaves of a nested dict/list/tuple structure with strings,
        using shape strings for tensor-like leaves.

        Args:
        iterable: A dict, list or tuple (possibly nested), or a single leaf value.
        indent (int): Nesting depth; carried through the recursion but not used in the output.
        last (bool): Whether this node is its parent's last child; carried through the
            recursion but not used in the output.

        Returns:
        The same structure with every leaf stringified. Dicts come back as plain ``dict``;
        lists and tuples keep their type (namedtuples are rebuilt field-by-field).
        """
        if isinstance(iterable, dict):
            entries = list(iterable.items())
            return {
                key: cls._shapify_child(value, indent, index == len(entries) - 1)
                for index, (key, value) in enumerate(entries)
            }
        if isinstance(iterable, (list, tuple)):
            processed_items = [
                cls._shapify_child(item, indent, index == len(iterable) - 1)
                for index, item in enumerate(iterable)
            ]
            if isinstance(iterable, tuple) and hasattr(iterable, "_fields"):
                # namedtuple constructors take one argument per field, not a single iterable.
                return type(iterable)._make(processed_items)
            return type(iterable)(processed_items)
        return cls.stringify(iterable)

    @classmethod
    def _shapify_child(cls, value, indent, is_last):
        """Recurse one level deeper into nested containers; stringify anything else."""
        if isinstance(value, (list, tuple, dict)):
            return cls.shapify_strings(value, indent + 1, is_last)
        return cls.stringify(value)

    @classmethod
    def print_shape(cls, item):
        """Pretty-print ``item`` with tensor-like leaves replaced by their shapes."""
        pprint(cls.shapify_strings(item))

class ColabBuddy(DownloadBuddy, GDriveBuddy, FileSystemBuddy, PathBuddy, PythonPathBuddy, PipBuddy, AptBuddy, CloudFlaredBuddy, RepoBuddy):
    """All-in-one notebook helper combining the *Buddy mixins, plus env and flag-file utilities."""

    def __init__(self, home='/content', drive_path='/content/drive'):
        """
        Parameters:
        home (str | Path): Directory used by go_home(); also passed to CloudFlaredBuddy as ``path``.
        drive_path (str): Passed to GDriveBuddy as ``mount_path``.
        """
        self.home = Path(home)
        # Each mixin's __init__ is called explicitly; GDriveBuddy and CloudFlaredBuddy
        # take their own keyword arguments.
        DownloadBuddy.__init__(self)
        GDriveBuddy.__init__(self, mount_path=drive_path)
        FileSystemBuddy.__init__(self)
        PathBuddy.__init__(self)
        PythonPathBuddy.__init__(self)
        PipBuddy.__init__(self)
        AptBuddy.__init__(self)
        CloudFlaredBuddy.__init__(self, path=home)
        RepoBuddy.__init__(self)

    @staticmethod
    def soft_reset_runtime():
        """Replace the current process with a fresh interpreter (os.execv) using the same argv."""
        os.execv(sys.executable, ['python'] + sys.argv)

    @classmethod
    def store_env(cls, name: str, env_dict: dict):
        """
        Save ``env_dict`` as JSON to ``<drive>/ColabBuddy/<name>_env.json``.

        Parameters:
        name (str): Prefix of the JSON file name.
        env_dict (dict): JSON-serializable mapping to store.
        """
        drive_path = cls.connect_gdrive()
        colab_buddy_path = drive_path / 'ColabBuddy'
        colab_buddy_path.mkdir(exist_ok=True, parents=True)
        with open(colab_buddy_path / f'{name}_env.json', 'w', encoding='utf-8') as f:
            json.dump(env_dict, f)

    @classmethod
    def load_env(cls, name: str):
        """
        Load ``<drive>/ColabBuddy/<name>_env.json`` and copy its entries into ``os.environ``.

        Non-string JSON values (numbers, booleans, ...) are converted with ``str()`` because
        ``os.environ`` only accepts strings.

        Raises:
        FileNotFoundError: if no env file with that name has been stored.
        """
        drive_path = cls.connect_gdrive()
        env_file = drive_path / 'ColabBuddy' / f'{name}_env.json'
        with open(env_file, 'r', encoding='utf-8') as f:
            env_dict = json.load(f)
        for key, value in env_dict.items():
            os.environ[key] = value if isinstance(value, str) else str(value)

    @staticmethod
    def flagged_function(callable: Callable, path: Path):
        """
        Checks if a file exists at the path location.
        If it does exist, don't run the function.
        If it does not exist, run the function and then create the flag file (and any
        missing parent directories). If the function raises, no flag file is created.

        Parameters:
        callable (Callable): a python function or colab shell command wrapped in a python function.
        path (Path | str): The path for the flag file.

        Returns:
        None
        """
        path = Path(path)
        if path.exists():
            return
        callable()
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()

    def go_home(self):
        """Change the working directory to ``self.home``."""
        self.cd(self.home)

    def set_home(self, path: Path | str):
        """Set the directory used by go_home()."""
        self.home = Path(path)
        print('Set home directory to {path}'.format(path=str(self.home)))

In [None]:
# @title setup
# https://github.com/peterljq/Parsimonious-Concept-Engineering
# Mount Google Drive and install the extra packages used by later cells.
buddy = ColabBuddy()
buddy.connect_gdrive()
buddy.run_pip_install('pip install aioshutil aiosqlite icecream')
from icecream import ic
# Pair the project folder on Drive (remote) with a local working location under /content.
remote_project_base = Path("/content/drive/MyDrive/Projects/Control_Vectors")
local_project_base = Path("/content")
project_base = buddy.file_pair(
    local_path=local_project_base,
    remote_path=remote_project_base
)
# Precomputed PaCE dataset (HDF5). hd5_file is the local path the dataset helpers read from.
remote_hd5_file = project_base.remote / "Parsimonious-Concept-Engineering.hdf5"
hd5_file = project_base.local / "Parsimonious-Concept-Engineering.hdf5"
hd5_file_pair = buddy.file_pair(
    local_path=hd5_file,
    remote_path=remote_hd5_file
)
if not hd5_file_pair.remote_exists():
    # No precomputed dataset on Drive: clone the PaCE repo and unpack its concept data.
    # NOTE(review): this branch does not itself create hd5_file; presumably a later cell
    # builds it from the unzipped concepts — confirm before relying on hd5_file.
    print("Precomputed dataset not found.\nInstalling prereqs.")
    Parsimonious_Concept_Engineering = buddy.git_clone('https://github.com/peterljq/Parsimonious-Concept-Engineering.git', '/content')
    Parsimonious_Concept_Engineering.cd()
    buddy.unzip('/content/Parsimonious-Concept-Engineering/concept.zip')
    # Relative path: relies on the cd() into the cloned repo above.
    buddy.chmod('+x', 'pace1m_reader.py')
    # Go home, then back into the repo so it is the working directory for later cells.
    buddy.go_home()
    buddy.cd('/content/Parsimonious-Concept-Engineering')
    # buddy.rm('Parsimonious-Concept-Engineering', recursive=True)
else:
    print("Precomputed dataset found.")
    # presumably copies the Drive copy of the dataset to hd5_file
    hd5_file_pair.copy_to_local()

In [None]:
# @title Load Model
from transformers.generation import GenerationConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

def model_changed():
    """Report whether the global ``model_name`` differs from the one seen last time.

    Stateful: the first call only records ``model_name`` in the global
    ``last_model_name`` and returns False. Each later call returns True exactly once
    after ``model_name`` changes, recording the new name as it does so.
    """
    global model_name, last_model_name
    previous_name = globals().setdefault("last_model_name", model_name)
    if model_name != previous_name:
        last_model_name = model_name
        return True
    return False

import re

def make_filename_safe(filename: str) -> str:
    """
    Make a string safe to use as a filename on both Unix and Windows systems
    by replacing invalid characters with valid ASCII lookalikes where possible.

    Characters Windows forbids are mapped through ``char_map``; any other ASCII
    control character (U+0001-U+001F) becomes '_'. Leading/trailing spaces and
    periods are stripped, and names Windows reserves for devices (CON, PRN, AUX,
    NUL, COM1-COM9, LPT1-LPT9, with or without an extension) get a '_' prefix.

    Parameters:
    filename (str): The input string to be sanitized.

    Returns:
    str: A sanitized version of the input string ('default_filename' if nothing is left).
    """

    # Mapping of invalid characters to valid ASCII lookalikes
    char_map = {
        '<': '(',  # Less than becomes left paren
        '>': ')',  # Greater than becomes right paren
        ':': '-',  # Colon becomes dash
        '"': "'",  # Double quote becomes single quote
        '/': '-',  # Forward slash becomes dash
        '\\': '-', # Backslash becomes dash
        '|': '-',  # Vertical bar becomes dash
        '?': '',   # Question mark is removed
        '*': 'x',  # Asterisk becomes 'x'
        '\0': '',  # Null byte is removed
    }

    # Function to replace invalid characters using the map
    def replace_invalid_char(match):
        # Control characters have no lookalike and fall back to '_'.
        return char_map.get(match.group(0), '_')

    # Characters not allowed in filenames: Windows' reserved punctuation plus ASCII control chars
    invalid_chars = r'[<>:"/\\|?*\x00-\x1f]'

    # Replace invalid characters with ASCII lookalikes
    safe_filename = re.sub(invalid_chars, replace_invalid_char, filename)

    # Strip leading and trailing spaces or periods (Windows restriction)
    safe_filename = safe_filename.strip(' .')

    # Windows device names are reserved regardless of extension (e.g. "NUL.txt").
    reserved_names = {'CON', 'PRN', 'AUX', 'NUL'}
    reserved_names.update(f'COM{i}' for i in range(1, 10))
    reserved_names.update(f'LPT{i}' for i in range(1, 10))
    if safe_filename.split('.')[0].upper() in reserved_names:
        safe_filename = '_' + safe_filename

    # Ensure the filename is not empty after cleaning
    if not safe_filename:
        safe_filename = 'default_filename'

    return safe_filename

default_model_name = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'
model_name = 'microsoft/Phi-3-mini-4k-instruct' # @param {type:'string'}
# Fall back to the default model when the form field is left blank.
if not model_name.strip():
    model_name = default_model_name
filesafe_model_name = make_filename_safe(model_name)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model_changed() reports a change only on the first call after model_name changes,
# so call it exactly once and reuse the result. Calling it a second time (for the
# tokenizer) would always return False and keep the previous model's tokenizer.
model_name_changed = model_changed()

# Drop this namespace's reference to the previous model before loading a new one,
# so two models are not held at once when memory is tight.
if model_name_changed and "model" in globals():
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Load the LLM
if "model" not in globals():
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    print("Model loaded.")
else:
    print("Model already loaded.")

if "tokenizer" not in globals() or model_name_changed:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Tokenizer loaded.")
else:
    print("Tokenizer already loaded.")

Admittedly, I have no formal ML education; however, that has not stopped me from learning through free educational resources and experimentation.

These are utility functions that let me print nested data structures containing torch tensors and NumPy arrays, with each tensor/array printed as its shape, so that I can more easily inspect the dimensionality of my data with simple print statements in this notebook.

In [None]:
# @title Utility code

from pprint import PrettyPrinter
from pprint import pprint as _pprint
from collections.abc import Mapping, Iterable
import torch

# Module-level flag: True only while a wrapped print/pprint call is running
# (PrefixedDict reads it to decide whether to show its prefixed display view).
_is_print = False

# Shadow pprint in this namespace with a wrapper that raises the _is_print flag for the
# duration of the call (the pprint module itself is not modified).
def pprint(*args, **kwargs):
    """Pretty-print exactly like pprint.pprint while the module-level _is_print flag is True."""
    globals()["_is_print"] = True
    try:
        _pprint(*args, **kwargs)
    finally:
        # Reset even if printing raises, so dicts are not left in display mode afterwards.
        globals()["_is_print"] = False

def reassign_print(reset=False):
    """Swap this namespace's ``print`` for a flag-raising wrapper, or restore the original.

    The first call remembers the real ``print`` as ``print_original`` and makes sure
    the ``_is_print`` flag exists. With ``reset=True`` the remembered ``print`` is put
    back; otherwise ``print`` (and ``flagged_print``) become a wrapper that sets
    ``_is_print`` to True for the duration of each call.
    """
    namespace = globals()
    namespace.setdefault("print_original", print)
    namespace.setdefault("_is_print", False)

    if reset:
        namespace['print'] = namespace["print_original"]
        return

    def flagged_print(*args, **kwargs):
        namespace["_is_print"] = True
        try:
            namespace["print_original"](*args, **kwargs)
        finally:
            # Always lower the flag, even when printing fails.
            namespace["_is_print"] = False

    namespace["flagged_print"] = flagged_print
    namespace['print'] = flagged_print

# Install flagged_print as this namespace's print so the _is_print flag is raised during print calls.
reassign_print()

class PrefixedDict(dict):
    """dict that shows prefixed keys and tensor shapes while being printed.

    Outside of a print call it behaves like a plain dict. While the module-level
    ``_is_print`` flag is True, ``items()`` returns a display view instead: every key
    gets ``self._prefix`` prepended and tensor-like values (anything with a ``shape``)
    are replaced by ``'Tensor.shape(...)'`` strings.
    """

    def __init__(self, *args, **kwargs):
        # Subclasses assign self._prefix before calling this; plain PrefixedDict gets "".
        if not hasattr(self, '_prefix'):
            self._prefix = ""
        super().__init__(*args, **kwargs)

    @property
    def dict(self):
        """The display view while printing, otherwise a plain-dict copy of the real items."""
        if globals().get('_is_print', False):
            return self.get_prefixed()
        return dict(dict.items(self))

    def to_dict(self):
        """Recursively convert to plain dicts, always using the real (unprefixed) keys and values."""
        return {
            key: value.to_dict() if isinstance(value, PrefixedDict) else value
            for key, value in dict.items(self)
        }

    def _process_key(self, key):
        """
        Returns the raw key and the prefixed key.
        """
        raw_key = key
        prefixed_key = f"{self._prefix}{key}"
        return raw_key, prefixed_key

    def get_prefixed(self):
        """
        Returns a dictionary with prefixed keys and tensor-like values shown as shapes.

        The view is rebuilt on every call. A cached copy went stale whenever nested values
        were mutated or the dict was changed via update()/pop()/clear()/setdefault(), which
        do not go through __setitem__/__delitem__.
        """
        return self.shapify_tensors({f"{self._prefix}{key}": value for key, value in dict.items(self)})

    def items(self):
        return self.dict.items()

    @classmethod
    def shapify_tensors(cls, value):
        """
        Replace every tensor-like leaf (anything with a ``shape``) with a shape string.
        """
        def shapify(value):
            if "shape" in dir(value):
                if callable(value.shape):
                    shape_dims = [str(dim) for dim in list(value.shape())]
                else:
                    shape_dims = [str(dim) for dim in list(value.shape)]
                return 'Tensor.shape(' + ", ".join(shape_dims) + ')'
            else:
                return value

        return cls.apply_func_recursively(value, shapify)

    @classmethod
    def apply_func_recursively(cls, data, func):
        """Apply ``func`` to every leaf of a nested structure of mappings and iterables.

        Mappings come back as plain dicts (nested PrefixedDicts are read through their own
        ``items()``); other iterables are rebuilt with their own type, namedtuples
        field-by-field. Strings, bytes and anything with a ``shape`` are leaves.
        """
        if isinstance(data, Mapping):
            return {key: cls.apply_func_recursively(value, func) for key, value in data.items()}
        if isinstance(data, Iterable) and not isinstance(data, (str, bytes)) and "shape" not in dir(data):
            elements = [cls.apply_func_recursively(element, func) for element in data]
            if isinstance(data, tuple) and hasattr(data, "_fields"):
                # namedtuple constructors take one argument per field, not a single iterable.
                return type(data)._make(elements)
            return type(data)(elements)
        # Base case: Apply the function to the scalar value
        return func(data)

class LayersDict(PrefixedDict):
    """PrefixedDict whose keys are displayed as ``layer_<key>``."""

    # Read as self._prefix; PrefixedDict.__init__ only assigns a default when none exists.
    _prefix = "layer_"

class PromptDict(PrefixedDict):
    """PrefixedDict whose keys are displayed as ``prompt_<key>``."""

    _prefix = "prompt_"

class TokenDict(PrefixedDict):
    """PrefixedDict whose keys are displayed as ``token_<key>``."""

    _prefix = "token_"

# Demo: five LayersDicts, each holding two random (2, 3, 4) tensors.
sample_dict = {}
for i in range(5):
    sample_dict[str(i)] = LayersDict(**{str(j): torch.rand(2, 3, 4) for j in range(2)})

# Wrap the outer mapping in a LayersDict too.
layers_dict = LayersDict(sample_dict)

# The pprint wrapper raises the print flag, so keys show their prefixes and tensors their shapes.
pprint(layers_dict, compact=False, indent=4, underscore_numbers=True)


I don't have infinite memory; making efficient use of resources is vital when you only have the free tier to work with.

The following classes offer synchronous and asynchronous methods for storing tensors from model activations, either in memory or on disk.

In [None]:
# @title Sync DB Managers
import sqlite3
import numpy as np
import torch
from io import BytesIO

class SQLiteBuddy:
    """Thin synchronous wrapper around one sqlite3 connection.

    sqlite3 errors are printed rather than raised (best-effort helpers for notebook use);
    mutating methods return True/False to report success. Table and column names are
    interpolated directly into SQL, so only pass trusted identifiers; insert() passes
    values through ``?`` placeholders. Usable as a context manager that closes on exit.
    """

    def __init__(self, db_path):
        """Initialize with the database file path and open the connection."""
        self.db_path = db_path
        self.connection = None
        self.cursor = None
        self.connect()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def connect(self):
        """Establish a connection to the SQLite database."""
        try:
            self.connection = sqlite3.connect(self.db_path)
            self.cursor = self.connection.cursor()
            print("Connected to the database successfully.")
        except sqlite3.Error as e:
            print(f"Error connecting to the database: {e}")

    def close(self):
        """Close the connection if open and forget the handles, so close() can be repeated."""
        if self.connection:
            self.connection.close()
            print("Connection closed.")
        self.connection = None
        self.cursor = None

    def _rollback(self):
        """Discard the open transaction after a failed statement (best effort)."""
        if self.connection is None:
            return
        try:
            self.connection.rollback()
        except sqlite3.Error as e:
            print(f"Error rolling back transaction: {e}")

    def execute_query(self, query, params=()):
        """Execute a single statement and commit.

        Returns:
        bool: True on success; False if sqlite3 raised (the error is printed and the
        transaction rolled back so a later commit cannot persist a partial change).
        """
        try:
            self.cursor.execute(query, params)
            self.connection.commit()
        except sqlite3.Error as e:
            self._rollback()
            print(f"Error executing query: {e}")
            return False
        print("Query executed successfully.")
        return True

    def execute_many(self, query, params_list):
        """Execute one statement for each parameter set and commit them together.

        Returns:
        bool: True on success; False on error, in which case rows inserted before the
        failing parameter set are rolled back too.
        """
        try:
            self.cursor.executemany(query, params_list)
            self.connection.commit()
        except sqlite3.Error as e:
            self._rollback()
            print(f"Error executing multiple queries: {e}")
            return False
        print("Multiple queries executed successfully.")
        return True

    def fetch_all(self, query, params=()):
        """Fetch all results from a query (None if sqlite3 raised)."""
        try:
            self.cursor.execute(query, params)
            return self.cursor.fetchall()
        except sqlite3.Error as e:
            print(f"Error fetching data: {e}")
            return None

    def fetch_one(self, query, params=()):
        """Fetch one result from a query (None if sqlite3 raised)."""
        try:
            self.cursor.execute(query, params)
            return self.cursor.fetchone()
        except sqlite3.Error as e:
            print(f"Error fetching data: {e}")
            return None

    def create_table(self, table_name, columns):
        """Create a table (if missing) from a {column_name: column_type} mapping; return success."""
        column_defs = ", ".join([f"{col_name} {col_type}" for col_name, col_type in columns.items()])
        create_table_query = f"CREATE TABLE IF NOT EXISTS {table_name} ({column_defs});"
        succeeded = self.execute_query(create_table_query)
        if succeeded:
            print(f"Table '{table_name}' created successfully.")
        return succeeded

    def insert(self, table_name, data):
        """Insert one row given as {column_name: value}; return success."""
        columns = ", ".join(data.keys())
        placeholders = ", ".join(["?" for _ in data.values()])
        insert_query = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
        succeeded = self.execute_query(insert_query, tuple(data.values()))
        if succeeded:
            print(f"Data inserted into '{table_name}'.")
        return succeeded

    def select_all(self, table_name):
        """Select all data from a table."""
        select_query = f"SELECT * FROM {table_name}"
        return self.fetch_all(select_query)

class PaceDbBuddy(SQLiteBuddy):
    """SQLite store mapping concept names (TEXT primary key) to tensors saved as float32 .npy blobs."""

    def __init__(self, db_path):
        """Open ``db_path`` and make sure the ``tensors`` table exists."""
        super().__init__(db_path)
        self.create_tensor_table()

    def create_tensor_table(self):
        """Create a table with 'key' as the primary key and 'value' for tensors."""
        create_table_query = """
        CREATE TABLE IF NOT EXISTS tensors (
            key TEXT PRIMARY KEY,
            value BLOB
        );
        """
        self.execute_query(create_table_query)

    @staticmethod
    def _tensor_to_bytes(tensor):
        """Serialize a tensor as float32 ``.npy`` bytes."""
        # Cast inside torch first: Tensor.numpy() raises for dtypes numpy lacks (e.g. bfloat16).
        tensor_np = tensor.detach().to(device='cpu', dtype=torch.float32).numpy()
        np_bytes = BytesIO()
        np.save(np_bytes, tensor_np, allow_pickle=False)  # plain numeric array: no pickling needed
        return np_bytes.getvalue()

    @staticmethod
    def _bytes_to_tensor(tensor_bytes):
        """Deserialize ``.npy`` bytes back into a torch tensor."""
        # allow_pickle=False: never unpickle blobs read back from the database.
        loaded_np = np.load(BytesIO(tensor_bytes), allow_pickle=False)
        return torch.tensor(loaded_np)

    def insert_tensor(self, concept, tensor: torch.Tensor):
        """Insert (or replace) the tensor stored under ``concept``."""
        tensor_bytes = self._tensor_to_bytes(tensor)
        insert_query = "INSERT OR REPLACE INTO tensors (key, value) VALUES (?, ?)"
        self.execute_query(insert_query, (concept, tensor_bytes))
        print(f"Tensor for concept '{concept}' inserted successfully.")

    def fetch_tensor(self, concept):
        """Fetch the tensor stored under ``concept``, or None if there is none."""
        result = self.fetch_one("SELECT value FROM tensors WHERE key = ?", (concept,))
        if result and result[0]:
            return self._bytes_to_tensor(result[0])
        print(f"No tensor found for concept '{concept}'")
        return None

    def list_concepts(self):
        """Retrieve all concepts that have been uploaded (present in the database)."""
        select_query = "SELECT key FROM tensors"
        result = self.fetch_all(select_query)
        if result:
            concepts = [row[0] for row in result]
            return concepts
        else:
            print("No concepts found.")
            return []

In [None]:
# @title Async DB Managers
import logging
from io import BytesIO
import asyncio
import torch
import numpy as np
import aiosqlite

# Configure the root logger (INFO and above, timestamped) so the async DB helpers'
# messages show up in the notebook. StreamHandler() writes to stderr by default, and
# basicConfig has no effect if the root logger already has handlers.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Logger used by the async database classes below.
dblogger = logging.getLogger(__name__)


class AsyncSQLiteBuddy:
    """Small async SQLite helper built on aiosqlite.

    Every method opens its own connection to ``db_path`` for one call and closes it
    afterwards; nothing stays open between calls. (So an in-memory ``":memory:"``
    database would not persist from one call to the next — use a file path.)
    """

    def __init__(self, db_path):
        self.db_path = db_path

    async def execute_query(self, query, params=()):
        """Run one statement with optional parameters, then commit."""
        async with aiosqlite.connect(self.db_path) as conn:
            await conn.execute(query, params)
            await conn.commit()

    async def execute_many(self, query, params_list):
        """Run one statement once per parameter set, then commit."""
        async with aiosqlite.connect(self.db_path) as conn:
            await conn.executemany(query, params_list)
            await conn.commit()

    async def fetch_all(self, query, params=()):
        """Return every row produced by ``query``."""
        async with aiosqlite.connect(self.db_path) as conn:
            async with conn.execute(query, params) as cursor:
                return await cursor.fetchall()

    async def fetch_one(self, query, params=()):
        """Return the first row produced by ``query``."""
        async with aiosqlite.connect(self.db_path) as conn:
            async with conn.execute(query, params) as cursor:
                return await cursor.fetchone()

    async def create_table(self, table_name, columns):
        """
        Create ``table_name`` (if it does not exist yet) from a column mapping.

        :param table_name: str - Name of the table
        :param columns: dict - Column names and types (e.g., {'id': 'INTEGER', 'name': 'TEXT'})
        """
        column_sql = ', '.join(f"{name} {sql_type}" for name, sql_type in columns.items())
        await self.execute_query(f"CREATE TABLE IF NOT EXISTS {table_name} ({column_sql})")

    async def insert(self, table_name, data):
        """
        Insert one row into ``table_name``.

        :param table_name: str - Name of the table
        :param data: dict - Row to insert as {column_name: value}; values go through ? placeholders
        """
        column_names = ', '.join(data.keys())
        markers = ', '.join('?' * len(data))
        await self.execute_query(
            f"INSERT INTO {table_name} ({column_names}) VALUES ({markers})",
            tuple(data.values()),
        )

    async def select_all(self, table_name):
        """Return every row of ``table_name``."""
        return await self.fetch_all(f"SELECT * FROM {table_name}")

class AsyncBatchDbBuddy(AsyncSQLiteBuddy):
    """Async key/value store that queues inserts and writes them to SQLite in batches.

    Rows are ``(key, value)`` pairs written with ``INSERT OR REPLACE`` into
    ``table_name``, which must have ``key`` and ``value`` columns. Queued rows are
    written when the queue reaches ``batch_size``, by the optional background worker
    every ``flush_interval`` seconds, or explicitly via ``flush_queue``/``flush_all``.
    Reads go straight to the database, so rows still waiting in the queue are not
    visible to ``fetch_data``/``list_keys`` until they have been flushed.
    """

    def __init__(self, db_path, table_name, batch_size=10, flush_interval=5, max_retries=5, retry_delay=30):
        """
        Initialize the AsyncBatchDbBuddy with common parameters for batch insertion.
        :param db_path: Path to the database file
        :param table_name: The name of the table for this specific buddy (needs key/value columns)
        :param batch_size: Maximum batch size before auto-flush (values below 1 act as 1)
        :param flush_interval: Interval to auto-flush (seconds) used by the background worker
        :param max_retries: Maximum attempts per batch while the database is locked
        :param retry_delay: Base delay (in seconds); after the n-th locked attempt wait n * retry_delay
        """
        super().__init__(db_path)
        self.table_name = table_name
        self.batch_size = batch_size  # Maximum batch size before auto-flush
        self.flush_interval = flush_interval  # Interval to auto-flush (seconds)
        self.insert_queue = asyncio.Queue()  # Pending (key, value) rows
        self.batch_worker_task = None  # Background worker task
        self.running = True  # Flag to control the worker loop
        self.max_retries = max_retries  # Maximum attempts for a database operation
        self.retry_delay = retry_delay  # Base delay for retrying after lock
        self.lock = asyncio.Lock()  # Ensure no concurrent flushes
        self._stop_event = asyncio.Event()  # Wakes the worker early when stopping

    async def start_batch_worker(self):
        """Start the background flush task; does nothing if one is already running."""
        if self.batch_worker_task is not None and not self.batch_worker_task.done():
            return
        self.running = True
        self._stop_event.clear()
        self.batch_worker_task = asyncio.create_task(self.batch_worker())

    async def stop_batch_worker(self):
        """Stop the background task, then write every row still waiting in the queue."""
        self.running = False
        self._stop_event.set()  # Wake the worker instead of waiting out flush_interval.
        if self.batch_worker_task:
            await self.batch_worker_task
            self.batch_worker_task = None
        # The worker writes at most one batch per wake-up; drain whatever is left.
        await self.flush_all()

    async def batch_worker(self):
        """Flush one batch every ``flush_interval`` seconds until ``running`` is cleared."""
        while self.running:
            try:
                await asyncio.wait_for(self._stop_event.wait(), timeout=self.flush_interval)
            except asyncio.TimeoutError:
                pass  # Normal case: the interval elapsed without a stop request.
            try:
                await self.flush_queue()
            except Exception:
                # Keep the worker alive; flush_queue has already re-queued the failed batch.
                dblogger.exception(f"Background flush into {self.table_name} failed.")

    async def flush_queue(self):
        """Write up to ``batch_size`` queued rows as one batch, with lock-retry logic.

        Raises:
        aiosqlite.OperationalError: if the write fails. The dequeued rows are put back
        on the queue first, so a later flush can still write them.
        """
        async with self.lock:  # Ensure no concurrent flushes
            limit = max(1, self.batch_size)  # a non-positive batch size would never make progress
            batch = []
            while not self.insert_queue.empty() and len(batch) < limit:
                batch.append(self.insert_queue.get_nowait())

            if not batch:
                return

            insert_query = f"INSERT OR REPLACE INTO {self.table_name} (key, value) VALUES (?, ?)"
            try:
                await self._execute_with_retry(insert_query, batch)
            except BaseException:
                # Includes cancellation: never silently drop rows that were already dequeued.
                for item in batch:
                    self.insert_queue.put_nowait(item)
                raise
            dblogger.info(f"Batch of {len(batch)} items inserted successfully into {self.table_name}.")

    async def flush_all(self):
        """Flush batch after batch until the queue is empty."""
        while not self.insert_queue.empty():
            await self.flush_queue()

    async def _execute_with_retry(self, query, params_list):
        """Execute a batch, retrying with linear backoff while the database is locked.

        At least one attempt is always made. Raises the last error once ``max_retries``
        attempts have failed, or immediately for any other operational error.
        """
        attempt = 0
        while True:
            try:
                await self.execute_many(query, params_list)
                return
            except aiosqlite.OperationalError as e:
                if "database is locked" not in str(e):
                    dblogger.error(f"Database operation failed: {e}")
                    raise
                attempt += 1
                if attempt >= self.max_retries:
                    dblogger.error(f"Database still locked after {attempt} attempts; giving up.")
                    raise
                wait_time = self.retry_delay * attempt
                dblogger.warning(f"Database is locked, retrying in {wait_time} seconds...")
                await asyncio.sleep(wait_time)

    async def insert_data(self, key, value):
        """Queue data for batched insertion."""
        await self.insert_queue.put((key, value))  # Add to queue

        # If queue size exceeds batch size, flush the queue immediately
        if self.insert_queue.qsize() >= self.batch_size:
            await self.flush_queue()

    async def create_table(self, create_table_sql):
        """Create a table asynchronously using the provided SQL."""
        await self.execute_query(create_table_sql)

    async def fetch_data(self, key):
        """Fetch the stored value for ``key`` (None if absent or empty)."""
        select_query = f"SELECT value FROM {self.table_name} WHERE key = ?"
        result = await self.fetch_all(select_query, (key,))
        if result and result[0][0]:
            return result[0][0]
        else:
            dblogger.info(f"No data found for key '{key}'")
            return None

    async def list_keys(self):
        """Retrieve all keys that are present in the database asynchronously."""
        select_query = f"SELECT key FROM {self.table_name}"
        result = await self.fetch_all(select_query)
        if result:
            keys = [row[0] for row in result]
            return keys
        else:
            dblogger.info("No keys found.")
            return []

class AsyncPaceDbBuddy(AsyncBatchDbBuddy):
    """Async, batched store of concept tensors in the ``tensors`` table (``.npy`` blobs)."""

    def __init__(self, db_path, batch_size=10, flush_interval=5, max_retries=5, retry_delay=1):
        # numpy dtype the tensors are stored as (precision is capped at float32 on the way in).
        self.dtype = np.dtype('float32')
        super().__init__(db_path, "tensors", batch_size, flush_interval, max_retries, retry_delay)

    @classmethod
    async def create_db(cls, db_path, batch_size=10, flush_interval=5, max_retries=5, retry_delay=1):
        """Build an instance and make sure its ``tensors`` table exists.

        ``max_retries``/``retry_delay`` are forwarded to the constructor (same defaults).
        """
        db = cls(db_path=db_path, batch_size=batch_size, flush_interval=flush_interval,
                 max_retries=max_retries, retry_delay=retry_delay)
        await db.create_tensor_table()
        return db

    async def create_tensor_table(self):
        """Create a table for tensors asynchronously."""
        create_table_query = """
        CREATE TABLE IF NOT EXISTS tensors (
            key TEXT PRIMARY KEY,
            value BLOB
        );
        """
        await self.create_table(create_table_query)

    async def insert_tensor(self, concept, tensor: torch.Tensor):
        """Serialize ``tensor`` as ``self.dtype`` ``.npy`` bytes and queue it under ``concept``."""
        # Cast inside torch first: Tensor.numpy() raises for dtypes numpy lacks (e.g. bfloat16).
        tensor_np = tensor.detach().to(device='cpu', dtype=torch.float32).numpy().astype(self.dtype, copy=False)
        np_bytes = BytesIO()
        np.save(np_bytes, tensor_np, allow_pickle=False)  # plain numeric array: no pickling needed
        await self.insert_data(concept, np_bytes.getvalue())

    async def fetch_tensor(self, concept):
        """Fetch the tensor stored under ``concept`` (None if absent); queued rows are not seen."""
        tensor_bytes = await self.fetch_data(concept)
        if not tensor_bytes:
            return None
        # allow_pickle=False: never unpickle blobs read back from the database.
        loaded_np = np.load(BytesIO(tensor_bytes), allow_pickle=False)
        return torch.tensor(loaded_np)

    async def list_concepts(self):
        """Retrieve all keys that are present in the database asynchronously."""
        return await self.list_keys()

class AsyncPaceCacheBuddy(AsyncBatchDbBuddy):
    """Async, batched cache of activation tensors in the ``activation_cache`` table (``.npy`` blobs)."""

    def __init__(self, db_path, batch_size=10, flush_interval=5, max_retries=5, retry_delay=1):
        # numpy dtype the cached tensors are stored as (precision is capped at float32 on the way in).
        self.dtype = np.dtype('float32')
        super().__init__(db_path, "activation_cache", batch_size, flush_interval, max_retries, retry_delay)

    @classmethod
    async def create_db(cls, db_path, batch_size=10, flush_interval=5, max_retries=5, retry_delay=1):
        """Build an instance and make sure its ``activation_cache`` table exists.

        ``max_retries``/``retry_delay`` are forwarded to the constructor (same defaults).
        """
        db = cls(db_path=db_path, batch_size=batch_size, flush_interval=flush_interval,
                 max_retries=max_retries, retry_delay=retry_delay)
        await db.create_cache_table()
        return db

    async def create_cache_table(self):
        """Create a table for cache asynchronously."""
        create_table_query = """
        CREATE TABLE IF NOT EXISTS activation_cache (
            key TEXT PRIMARY KEY,
            value BLOB
        );
        """
        await self.create_table(create_table_query)

    async def insert_cache(self, key, tensor: torch.Tensor):
        """Serialize ``tensor`` as ``self.dtype`` ``.npy`` bytes and queue it under ``key``."""
        # Cast inside torch first: Tensor.numpy() raises for dtypes numpy lacks (e.g. bfloat16).
        tensor_np = tensor.detach().to(device='cpu', dtype=torch.float32).numpy().astype(self.dtype, copy=False)
        np_bytes = BytesIO()
        np.save(np_bytes, tensor_np, allow_pickle=False)  # plain numeric array: no pickling needed
        await self.insert_data(key, np_bytes.getvalue())

    async def fetch_cache(self, key):
        """Fetch the cached tensor for ``key`` (None if absent); queued rows are not seen."""
        tensor_bytes = await self.fetch_data(key)
        if not tensor_bytes:
            return None
        # allow_pickle=False: never unpickle blobs read back from the database.
        loaded_np = np.load(BytesIO(tensor_bytes), allow_pickle=False)
        return torch.tensor(loaded_np)

In [None]:
# @title Test AsyncPaceCacheBuddy
!rm '/content/test.db'
test_db = await AsyncPaceDbBuddy.create_db(db_path='/content/test.db', batch_size=3, flush_interval=3)
values = []
for i in range(9):
    value = torch.rand(52, 2048, 1)
    values.append(value)
    await test_db.insert_tensor(f"test{i}", value)
stored_keys = await test_db.list_concepts()
assert len(stored_keys) == len(values)
for idx, key in enumerate(stored_keys):
    data = await test_db.fetch_tensor(key)
    torch.testing.assert_close(data, values[idx])

In [None]:
# @title Dataset functions
from more_itertools import chunked
from typing import List, Dict, Tuple, Any, Optional
import h5py
from functools import lru_cache


def get_concepts() -> Tuple[str]:
    """Return the names of all concepts, i.e. the top-level keys of the ``hd5_file`` HDF5 file."""
    with h5py.File(hd5_file, 'r') as store:
        return tuple(store.keys())

def get_stimuli(concepts: List[str], restrict_stimuli: Optional[int]=None) -> Tuple[Tuple[str]]:
    """
    Load the stimulus prompts of each concept from ``hd5_file``.

    Args:
        concepts: concept names (HDF5 dataset keys) to read, in output order.
        restrict_stimuli: if given, keep only the first ``restrict_stimuli`` prompts per concept.

    Returns:
        One tuple of decoded prompt strings per concept, in the order of ``concepts``.
    """
    stimuli_slice = slice(None) if restrict_stimuli is None else slice(0, restrict_stimuli)
    with h5py.File(hd5_file, 'r') as store:
        return tuple(
            tuple(raw_prompt.decode() for raw_prompt in store[concept][stimuli_slice])
            for concept in concepts
        )

# lru_cache keys on its arguments, so they must be hashable. get_stimuli_len therefore
# passes the concept list as a JSON string rather than as a list.
@lru_cache(maxsize=None)
def get_stimuli_len_cached(concepts_json: str) -> int:
    """Total number of stimuli over the concepts listed in ``concepts_json`` (a JSON array of names)."""
    concept_names = json.loads(concepts_json)
    with h5py.File(hd5_file, 'r') as store:
        per_concept = [len(store[name]) for name in concept_names]
    return sum(per_concept)

def get_stimuli_len(concepts: List[str]) -> int:
    """Total number of stimuli for ``concepts``, memoised via ``get_stimuli_len_cached`` on the JSON form of the list."""
    return get_stimuli_len_cached(json.dumps(concepts))

@lru_cache(maxsize=None)
def get_all_stimuli_len() -> int:
    """Total number of stimuli over every concept in ``hd5_file``; computed on first call, then memoised."""
    with h5py.File(hd5_file, 'r') as store:
        per_concept = [len(store[name]) for name in store.keys()]
    return sum(per_concept)

def dataset_generator(concepts: List[str], load_n_concepts=3, restrict_stimuli: Optional[int]=None):
    """
    Lazily yield ``(concept, prompts)`` pairs, reading ``load_n_concepts`` concepts from disk at a time.

    ``restrict_stimuli`` is forwarded to ``get_stimuli`` to cap the prompts per concept.
    """
    for concept_group in chunked(concepts, n=load_n_concepts):
        yield from zip(concept_group, get_stimuli(concept_group, restrict_stimuli))

In [None]:
# @title PaCE Extraction Classes
from ctypes import ArgumentError
from tqdm.autonotebook import tqdm
import asyncio
from more_itertools import chunked
import inspect

from typing import Tuple, List
import torch

# Hidden states as returned by ``generate(..., output_hidden_states=True)``: one inner tuple
# per generation step, each holding one tensor per layer (see rebatch_tokens_compiled).
HiddenStatesType = Tuple[Tuple[torch.Tensor, ...],...]

# Type hint for the past key values, which is a tuple of tuples, each containing two torch tensors
PastKeyValuesType = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]

# Type hint for the model-output mapping. As written, each value is annotated as a 3-tuple
# (hidden states, past key values, tensor).
# NOTE(review): a per-key "one of these" (a Union) may have been intended; confirm.
ModelOutputType = dict[str, Tuple[HiddenStatesType, PastKeyValuesType, torch.Tensor]]


# Logger used by the extraction classes below. The logging-level cell further down calls
# logging.getLogger(__name__) as well, so both names refer to the same logger object.
pacelogger = logging.getLogger(__name__)

class RecursiveDict:
    """
    Auto-vivifying nested mapping.

    Reading a missing key with ``rd[key]`` stores and returns a new, empty ``RecursiveDict``,
    so ``rd["a"]["b"] = 1`` works without creating ``rd["a"]`` first. Membership tests and
    iteration never create entries.
    """

    def __init__(self):
        self._data = {}

    def __getitem__(self, key):
        if key not in self._data:
            self._data[key] = RecursiveDict()  # Create new RecursiveDict if key doesn't exist
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value

    def __contains__(self, key):
        # Without __contains__/__iter__, Python implements ``in`` and iteration by calling
        # __getitem__(0), __getitem__(1), ... until IndexError. Here that auto-creates every
        # probed key and never terminates.
        return key in self._data

    def __iter__(self):
        return iter(self._data)

    def __repr__(self):
        return repr(self._data)

    def keys(self):
        return self._data.keys()

    def values(self):
        return self._data.values()

    def items(self):
        return self._data.items()

class RecursivePrefixedDict(PrefixedDict):
    """Auto-vivifying ``PrefixedDict``: reading a missing key stores and returns a new, empty ``RecursivePrefixedDict``.

    NOTE(review): ``PrefixedDict`` is defined elsewhere; what ``_prefix`` controls and how
    its ``__contains__``/``__getitem__``/``__setitem__`` treat keys come from there.
    """
    def __init__(self, *args, **kwargs):
        # Set before super().__init__(), presumably because PrefixedDict's initialiser or
        # item methods read it. TODO confirm against PrefixedDict.
        self._prefix=""
        super().__init__(*args, **kwargs)

    def __getitem__(self, key):
        if not super().__contains__(key):
            super().__setitem__(key, RecursivePrefixedDict())  # Create a new, empty RecursivePrefixedDict if the key doesn't exist
        return super().__getitem__(key)

    # keys/values/items only delegate to PrefixedDict.
    def keys(self):
        return super().keys()

    def values(self):
        return super().values()

    def items(self):
        return super().items()

class LlmBuddy:
    def __init__(self):
        from transformers import AutoModelForCausalLM, AutoTokenizer
        import torch
        pass

    @staticmethod
    def tokenize(prompts, tokenizer, device):
        inputs = tokenizer(prompts, padding=True, return_tensors='pt')
        inputs = {key: val.to(device) for key, val in inputs.items()}
        return inputs

    @staticmethod
    def generate(inputs, model, gen_cfg):
        return model.generate(
            inputs.get("input_ids"),
            attention_mask=inputs.get("attention_mask"),
            generation_config=gen_cfg,
        )

    @staticmethod
    def _padding_args(pad_lengths: Optional[List[int]]=None, pad_direction: Optional[Literal["left", "right"]]=None, remove_padding: bool=False):
        if (pad_lengths is not None and pad_direction is not None) or remove_padding is True:
            if remove_padding is False:
                print("Warning: remove_padding is False, but pad_lengths and pad_direction were provided.\n    Removing padding.")
                return True
            if pad_lengths is None:
                raise ValueError('pad_lengths must be provided as list of int values if remove_padding is True.')
            if pad_direction is None:
                raise ValueError('pad_direction "left" or "right" must be provided if remove_padding is True.')
            return True
        else:
            return False

    @staticmethod
    def remove_padding(ids, pad_len: int, pad_direction: Literal["left", "right"]):
        if pad_len > 0:
            if pad_direction == 'left':
                pad_slice = slice(pad_len, None)
            else:
                pad_slice = slice(0, -pad_len)
            ids = ids[pad_slice]
        return ids

    @classmethod
    def batch_remove_padding(cls, batch_ids: List[torch.tensor] | torch.Tensor, pad_lengths: List[int], pad_direction: Literal["left", "right"]):
        assert len(batch_ids) == len(pad_lengths)
        assert isinstance(batch_ids, torch.Tensor) or isinstance(batch_ids, list)
        output_tensors = []
        for i, pad_len in enumerate(pad_lengths):
            output_tensors.append(cls.remove_padding(batch_ids[i], pad_len, pad_direction))
        return output_tensors

    @classmethod
    def unbatch_hidden_state(cls, hidden_states, pad_lengths: Optional[List[int]], pad_direction: Optional[Literal["left", "right"]]=None, remove_padding: bool=False):
        unbatched_hidden_states = []
        if remove_padding is True:
            assert pad_lengths is not None
            assert pad_direction is not None
            unbatched_hidden_states.extend(cls.batch_remove_padding(batch_ids=hidden_states, pad_lengths=pad_lengths, pad_direction=pad_direction))
        else:
            unbatched_hidden_states.extend(list((hidden_states[i] for i in range(len(hidden_states)))))
        return unbatched_hidden_states

    @classmethod
    def unbatch_hidden_states(cls, hidden_states, pad_lengths: Optional[List[int]]=None, pad_direction: Optional[Literal["left", "right"]]=None, remove_padding: bool=False):
        processed_states = {}
        for tuple_idx, state_tuple in enumerate(hidden_states):
            unbatched_hidden_states = LayersDict()
            for layer_idx, layer_state in enumerate(state_tuple):
                unbatched = cls.unbatch_hidden_state(hidden_states=layer_state, pad_lengths=pad_lengths, pad_direction=pad_direction, remove_padding=remove_padding)
                unbatched_hidden_states[str(layer_idx)] = PromptDict(**{str(prompt_idx): prompt_state for prompt_idx, prompt_state in enumerate(unbatched)})
            processed_states[tuple_idx] = unbatched_hidden_states
        return processed_states

    @classmethod
    def print_text(cls, ids, tokenizer):
        print(tokenizer.decode(ids))

    @staticmethod
    def batch_decode_text(batch_ids: torch.Tensor | List[torch.Tensor], tokenizer: AutoTokenizer):
        batch_size = len(batch_ids)
        decoded = []
        for idx in range(0, batch_size):
            ids = batch_ids[idx]
            decoded.append(tokenizer.decode(ids))
        return decoded

    @classmethod
    def print_batch_text(cls, batch_ids: torch.Tensor | List[torch.Tensor], tokenizer: AutoTokenizer, pad_lengths: Optional[List[int]]=None, pad_direction: Optional[Literal["left", "right"]]=None, remove_padding: bool=False):
        if cls._padding_args(pad_lengths=pad_lengths, pad_direction=pad_direction, remove_padding=remove_padding):
            to_decode = cls.batch_remove_padding(batch_ids=batch_ids, pad_lengths=pad_lengths, pad_direction=pad_direction)
        else:
            to_decode = batch_ids
        decoded = cls.batch_decode_text(batch_ids=to_decode, tokenizer=tokenizer)
        for text in decoded:
            print(text)

    @staticmethod
    def batch_slice_tokens(batch_ids, start_idxs, end_idxs):
        """
        Slices the tensors across the sequence_length dimension based on start and end indexes.

        Args:
            batch_ids: A 3D tensor of shape (batch_size, sequence_length, activation),
                or a Python list of 2D tensors of shape (sequence_length, activation) or
                3D tensors of shape (1, sequence_length, activation).
            start_idxs: A Python list of start indices for slicing.
            end_idxs: A Python list of end indices for slicing.

        Returns:
            A list of sliced tensors for each batch.
        """
        if isinstance(batch_ids, torch.Tensor) and batch_ids.ndim == 2:
            batch_ids = batch_ids.unsqueeze(0)  # Convert 2D tensor to 3D

        # If batch_ids is a 3D tensor, slice it directly
        if isinstance(batch_ids, torch.Tensor) and batch_ids.ndim == 3:
            batch_size = batch_ids.shape[0]
            sliced_tensors = []
            sequence_length = batch_ids.shape[1]
            for i in range(batch_size):
                check_start = start = start_idxs[i]
                check_end = end = end_idxs[i]
                if start is None:
                    check_start = 0
                if end is None:
                    check_end = sequence_length
                if check_start > sequence_length:
                    raise IndexError("Start index exceeds sequence length.")
                if check_end > sequence_length:
                    raise IndexError("End index exceeds sequence length.")
                if check_start > check_end:
                    raise IndexError("Start index cannot be greater than end index.")
                sliced_tensors.append(batch_ids[i, start:end, :])  # Slicing the sequence_length dimension

        # If batch_ids is a list of tensors
        elif isinstance(batch_ids, list):
            sliced_tensors = []
            for i, tensor in enumerate(batch_ids):
                start = start_idxs[i]
                end = end_idxs[i]

                # If the tensor is 3D, reduce it to 2D
                if tensor.ndim == 3 and tensor.shape[0] == 1:
                    tensor = tensor.squeeze(0)  # Convert (1, sequence_length, activation) -> (sequence_length, activation)

                # Check for index errors
                sequence_length = tensor.shape[0]
                check_start = start = start_idxs[i]
                check_end = end = end_idxs[i]
                if start is None:
                    check_start = 0
                if end is None:
                    check_end = sequence_length
                if check_start > sequence_length:
                    raise IndexError("Start index exceeds sequence length.")
                if check_end > sequence_length:
                    raise IndexError("End index exceeds sequence length.")
                if check_start > check_end:
                    raise IndexError("Start index cannot be greater than end index.")
                # Now slice the 2D tensor (sequence_length, activation)
                sliced_tensors.append(tensor[start:end, :])

        else:
            raise ValueError("Input data must be a 3D tensor or a list of 2D/3D tensors.")

        return sliced_tensors

@torch.jit.script  # TorchScript for compilation
def pca_using_svd(data: torch.Tensor, num_components: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Perform PCA using SVD on the input data.

    Args:
        data (torch.Tensor): The input data of shape (n_samples, n_features).
        num_components (int): The number of principal components to retain.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The principal components (right singular vectors,
            shape (num_components, n_features)) and their explained variance
            (squared singular values / (n_samples - 1), shape (num_components,)).
            With a single sample the variance is 0 rather than a 0/0 NaN.
    """
    # Center each feature, then take the economy SVD: centered = U @ diag(S) @ Vh.
    centered_data = data - torch.mean(data, dim=0)
    _, S, Vh = torch.linalg.svd(centered_data, full_matrices=False)

    # Keep the leading `num_components` directions.
    Vh = Vh[:num_components, :]
    S = S[:num_components]

    # Unbiased (n - 1) normalisation; clamp the denominator so n == 1 does not divide by zero.
    denominator = data.size(0) - 1
    if denominator < 1:
        denominator = 1
    explained_variance = S ** 2 / denominator

    return Vh, explained_variance

@torch.jit.script  # TorchScript for compilation
def rebatch_tokens_compiled(hidden_states: List[List[torch.Tensor]]) -> torch.Tensor:
    """
    Collect the last-position hidden state of every generation step into one tensor.

    Args:
        hidden_states: hidden states from ``generate(..., output_hidden_states=True)``:
            one inner list per generation step, each holding one (batch, seq, hidden)
            tensor per layer. Step 0 is the prompt prefill; later steps cover the
            newly generated token.

    Returns:
        torch.Tensor: shape (layers, batch, steps, hidden). Only position -1 of each
            step's tensors is kept, and steps are laid out along dimension 2.
    """
    per_step: List[torch.Tensor] = []
    for step_layers in hidden_states:
        last_positions: List[torch.Tensor] = []
        for layer_state in step_layers:
            # Keep the sequence axis (size 1) so steps can be concatenated along it.
            last_positions.append(layer_state[:, -1:, :])
        # (layers, batch, 1, hidden) for this step
        per_step.append(torch.stack(last_positions, dim=0))
    return torch.concat(per_step, dim=2)

@torch.jit.script  # TorchScript for compilation
def layer_wise_pca(rebatched_hidden_states: torch.Tensor, num_components: int = 1) -> torch.Tensor:
    """
    Performs layer-wise PCA on the hidden states.

    All (batch, sequence) positions of a layer are treated as samples of that layer's
    hidden space, and ``pca_using_svd`` is run once per layer.

    Args:
        rebatched_hidden_states (torch.Tensor): (layer x batch_size x sequence x hidden_dimension),
            e.g. the output of ``rebatch_tokens_compiled``.
        num_components (int): principal components kept per layer (default 1).

    Returns:
        torch.Tensor: The principal components per layer, shape
            (layer x num_components x hidden_dimension).
    """
    layer, batch_size, sequence, hidden_dimension = rebatched_hidden_states.shape
    # reshape (unlike view) also accepts non-contiguous inputs such as sliced or permuted tensors.
    samples_per_layer = rebatched_hidden_states.reshape(layer, batch_size * sequence, hidden_dimension)
    # A separate, annotated list: TorchScript does not allow rebinding a List[Tensor] name to a Tensor.
    components: List[torch.Tensor] = []
    for layer_idx in range(layer):
        principal_components, _ = pca_using_svd(samples_per_layer[layer_idx], num_components)
        components.append(principal_components)
    return torch.stack(components, dim=0)

class PaCEBuddy(LlmBuddy):
    """
    Concept-extraction driver: turns each concept's stimuli into per-layer principal directions.

    For every concept from the dataset, the stimuli are generated from in chunks. The
    last-position hidden state of each generation step is collected per layer
    (``rebatch_tokens_compiled``), a PCA is taken per layer (``layer_wise_pca``), and the
    result is written to ``db``. Raw hidden states can optionally be cached in ``cache`` so
    an interrupted run resumes without regenerating. The database file is periodically
    copied to ``backup_db_path``.
    """

    def __init__(self, model, tokenizer, device, gen_cfg, cache=None, db=None, db_path='/content/pace.db', backup_db_path=None):
        """
        Args:
            model: causal LM whose ``generate`` output exposes ``hidden_states`` (configure via ``gen_cfg``).
            tokenizer: tokenizer matching ``model``.
            device: device for tokenised inputs; ``None`` picks CUDA when available, else CPU.
            gen_cfg: generation config passed to ``model.generate``.
            cache: optional ``AsyncPaceCacheBuddy`` holding raw hidden states.
            db: required result database (this class calls its ``list_concepts`` and ``insert_tensor``).
            db_path: SQLite file behind ``db``; the source of backups.
            backup_db_path: backup destination; ``None`` disables backups.

        Raises:
            ArgumentError: ``cache`` has the wrong type, or ``db`` is missing.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        self.gen_cfg = gen_cfg
        if cache is not None and not isinstance(cache, AsyncPaceCacheBuddy):
            raise ArgumentError("cache must be of type AsyncPaceCacheBuddy or None\nNot type: {type}".format(type=type(cache)))
        self.cache = cache
        self.db_path = Path(db_path)
        if db is None:
            raise ArgumentError("db must be provided")
        self.db = db
        # Optional: Path(None) would raise TypeError, so None is kept to mean "no backups".
        self.backup_db_path = Path(backup_db_path) if backup_db_path is not None else None
        self.pending_tasks = []  # scheduled result-database inserts
        self.cache_tasks = []  # scheduled activation-cache writes
        self.backup_task = None  # latest periodic backup; awaited before the next insert
        self.loop = asyncio.get_event_loop()

    def backup_db(self):
        """Synchronously copy the database file to ``backup_db_path``, closing ``db`` for the duration of the copy."""
        if self.backup_db_path is not None:
            # NOTE(review): if ``db`` is an async buddy, close()/connect() may return coroutines; confirm.
            self.db.close()
            buddy.file_pair(local_path=self.db_path, remote_path=self.backup_db_path).copy_to_remote()
            self.db.connect()

    def tokenize(self, prompts):
        """
        Tokenise ``prompts`` onto ``self.device``.

        Returns:
            (inputs, pad_lengths, unpad_lengths): the tokenised batch plus, per prompt, the
            number of pad and non-pad tokens in its ``input_ids`` row.
        """
        inputs = super().tokenize(prompts, self.tokenizer, self.device)
        pad_unpad_lengths = [
            (
                (ids == self.tokenizer.pad_token_id).sum().item(),
                (ids != self.tokenizer.pad_token_id).sum().item(),
            )
            for ids in inputs['input_ids']
        ]
        pad_lengths, unpad_lengths = zip(*pad_unpad_lengths)
        return inputs, list(pad_lengths), list(unpad_lengths)

    def generate(self, inputs):
        """Run ``model.generate`` on ``inputs`` with ``self.gen_cfg``."""
        return super().generate(inputs, self.model, self.gen_cfg)

    def remove_batch_padding(self, concepts, inputs, pad_lengths, unpad_lengths, outputs, pad_direction):
        """Strip padding from ``outputs.sequences``; returns ``(concepts, inputs, pad_lengths, unpadded_sequences)``."""
        outputs = super().batch_remove_padding(batch_ids=outputs.sequences, pad_lengths=pad_lengths, pad_direction=pad_direction)
        return concepts, inputs, pad_lengths, outputs

    def rebatch_tokens(self, outputs):
        """
        Reshapes the data from the model's output to be in the desired format.

        Args:
            outputs: ``generate`` output whose ``hidden_states`` entry holds one tuple per
                generation step, each with one tensor per layer (step 0 is the prompt prefill).

        Returns:
            torch.Tensor: The tensors reshaped. (layer x batch_size x sequence x hidden_dimension).
        """
        return rebatch_tokens_compiled(outputs.get("hidden_states"))

    def layer_wise_pca(self, rebatched_hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Performs layer-wise PCA on the hidden states.

        Args:
            rebatched_hidden_states (torch.Tensor): The rebatched model output data.

        Returns:
            torch.Tensor: The identified principal component per layer
                (layer x 1 x hidden_dimension).
        """
        return layer_wise_pca(rebatched_hidden_states)

    def print_text(self, concepts, inputs, pad_lengths, unpad_lengths, outputs):
        """Print the decoded prompts grouped by concept, then (if given) every generated sequence."""
        decoded_texts = super().batch_decode_text(batch_ids=inputs.get("input_ids"), tokenizer=self.tokenizer)
        concept_text = {concept: [] for concept in concepts}
        for concept, text in zip(concepts, decoded_texts):
            concept_text[concept].append(text)
        pprint(concept_text, indent=4)

        if outputs is not None:
            # ``sequences`` is a whole batch; tokenizer.decode expects one sequence, so decode row by row.
            super().print_batch_text(batch_ids=outputs.sequences, tokenizer=self.tokenizer)

    @staticmethod
    def print_shape(tensor):
        MatrixBuddy.print_shape(tensor)

    async def update_cache(self, key, value):
        """Schedule a CPU copy of ``value`` to be written to the activation cache under ``key``."""
        task = self.loop.create_task(self.cache.insert_cache(key, value.clone().detach().cpu()))
        self.cache_tasks.append(task)

    async def get_cache(self, key):
        """Load the cached hidden states for ``key`` onto ``self.device``; ValueError if the entry is missing."""
        rebatched_hidden_states = await self.cache.fetch_cache(key)
        if rebatched_hidden_states is None:
            raise ValueError(f"Cache key {key} not found in cache despite being returned by list_keys method.")
        return rebatched_hidden_states.to(self.device)

    async def insert_tensor(self, concepts, pca_result):
        """Schedule insertion of ``pca_result`` under ``concepts``, first waiting for any running backup."""
        if self.backup_task is not None:
            await self.backup_task
            self.backup_task = None
        task = self.loop.create_task(self.db.insert_tensor(concepts, pca_result.clone().detach().cpu()))
        self.pending_tasks.append(task)

    async def async_backup_db(self):
        """
        Copy the local SQLite database to ``backup_db_path`` using SQLite's online backup API.

        Scheduled inserts are awaited first. NOTE(review): rows still buffered inside ``db``
        (if it batches writes) may not be on disk yet; confirm against its flush logic.
        The backup API copies the whole database, ``pages`` pages per step, sleeping between
        steps. When no backup path is configured, only the pending inserts are awaited.
        """
        await asyncio.gather(*self.pending_tasks)
        if self.backup_db_path is None:
            return
        self.backup_db_path.parent.mkdir(parents=True, exist_ok=True)
        # Both connections are closed by their ``async with`` blocks.
        async with aiosqlite.connect(self.db_path) as db:
            async with aiosqlite.connect(self.backup_db_path) as backup_db:
                await db.backup(backup_db, pages=1_000, sleep=0.1)
                print(f"Incremental backup completed to {self.backup_db_path}")
            print(f"Backup completed to {self.backup_db_path}")

    @staticmethod
    def check_pca_input(hidden_states, layer_size, batch_size, sequence_length, hidden_dimension):
        """Raise ValueError unless ``hidden_states`` has shape (layer_size, batch_size, sequence_length, hidden_dimension)."""
        if len(hidden_states.shape) != 4:
            raise ValueError(f"Expected 4 dimensions, got {len(hidden_states.shape)}")
        if layer_size != hidden_states.shape[0]:
            raise ValueError(f"Expected {layer_size} layers, got {hidden_states.shape[0]}")
        if batch_size != hidden_states.shape[1]:
            raise ValueError(f"Expected {batch_size} batches, got {hidden_states.shape[1]}")
        if sequence_length != hidden_states.shape[2]:
            raise ValueError(f"Expected {sequence_length} sequence length, got {hidden_states.shape[2]}")
        if hidden_dimension != hidden_states.shape[3]:
            raise ValueError(f"Expected {hidden_dimension} hidden dimension, got {hidden_states.shape[3]}")

    def _model_dimensions(self):
        """
        Return ``(hidden-state tensors per step, hidden width)`` expected from ``self.model``.

        Hugging Face models return one hidden state for the embedding output plus one per
        transformer layer, hence ``num_hidden_layers + 1``. Falls back to 22 + 1 and 2048
        when the model config does not expose these values.
        """
        config = getattr(self.model, "config", None)
        num_layers = getattr(config, "num_hidden_layers", 22)
        hidden_dimension = getattr(config, "hidden_size", 2048)
        return num_layers + 1, hidden_dimension

    async def _collect_hidden_states(self, concept, stimuli, batch_size, cached_activation_keys, generate_bar):
        """Return one concept's hidden states, (layers, len(stimuli), steps, hidden), loading each chunk from cache or generating it."""
        chunk_states = []
        for chunk_idx, chunk in enumerate(chunked(stimuli, n=batch_size)):
            # The chunk size is part of the key, so entries written with a different
            # chunking are never mixed into one concept's tensor.
            key = f"{concept} | {batch_size} | {chunk_idx}"
            if self.cache is not None and key in cached_activation_keys:
                rebatched_hidden_states = await self.get_cache(key)
                pacelogger.info(f"Cache hit for {key}")
            else:
                inputs, _, _ = self.tokenize(chunk)
                outputs = self.generate(inputs)
                rebatched_hidden_states = self.rebatch_tokens(outputs)
                if self.cache is not None:
                    await self.update_cache(key, rebatched_hidden_states)
            chunk_states.append(rebatched_hidden_states)
            generate_bar.update(len(chunk))
        return torch.concat(chunk_states, dim=1)  # concat along the batch dimension

    async def main(self, dataset, dataset_len, batch_size=10, backup_interval=2, pad_direction="right"):
        """
        Extract and store the PCA result of every concept in ``dataset`` that ``db`` does not have yet.

        Args:
            dataset: iterable of ``(concept, stimuli)`` pairs (e.g. ``dataset_generator``).
            dataset_len: total number of stimuli; the progress-bar total.
            batch_size: number of stimuli per ``generate`` call.
            backup_interval: a backup is started whenever a processed concept's index in
                ``dataset`` is a multiple of this value.
            pad_direction: accepted for interface compatibility; unused.

        Raises:
            ValueError: a concept's hidden states have an unexpected shape (see ``check_pca_input``).
        """
        uploaded_concepts = set(await self.db.list_concepts())
        if self.cache is not None:
            cached_activation_keys = set(await self.cache.list_keys())
        else:
            cached_activation_keys = set()

        layer_size, hidden_dimension = self._model_dimensions()
        # rebatch_tokens_compiled keeps one slice per generation step, so the shape check
        # requires exactly max_new_tokens steps (gen_cfg.min_new_tokens should equal it).
        sequence_length = self.gen_cfg.max_new_tokens

        progress_bar = tqdm(total=dataset_len, desc="Processing stimuli")
        generate_bar = tqdm(total=None, desc=f"Generating in batches of {batch_size} from prompts")
        backup_description = "Backing up database in {remaining} intervals."
        backup_bar = tqdm(total=backup_interval, desc=backup_description.format(remaining=backup_interval))
        just_backed_up = False
        for concept_idx, (iter_concept, iter_stimuli_batch) in enumerate(dataset):
            # Stimuli for this concept; kept separate from ``batch_size``, the per-generate chunk size.
            num_stimuli = len(iter_stimuli_batch)
            if iter_concept in uploaded_concepts:
                # Already stored: still count its stimuli so the progress bar can reach dataset_len.
                progress_bar.update(num_stimuli)
                continue
            just_backed_up = False
            generate_bar.reset(total=num_stimuli)
            hidden_states = await self._collect_hidden_states(iter_concept, iter_stimuli_batch, batch_size, cached_activation_keys, generate_bar)

            progress_bar.set_description_str(f"Performing PCA on concept {iter_concept} tensors")
            pacelogger.info(f"Performing PCA on concept {iter_concept} tensors\nTensor shape is {hidden_states.shape}")
            self.check_pca_input(hidden_states=hidden_states, layer_size=layer_size, batch_size=num_stimuli, sequence_length=sequence_length, hidden_dimension=hidden_dimension)
            pca_result = self.layer_wise_pca(hidden_states)
            pacelogger.info(f"PCA result shape is {pca_result.shape}")
            progress_bar.update(hidden_states.shape[1])

            # Schedule async database insertion
            await self.insert_tensor(iter_concept, pca_result)

            backup_bar.update(1)
            backup_bar.set_description_str(backup_description.format(remaining=backup_interval - (concept_idx % backup_interval)))
            if concept_idx % backup_interval == 0:
                # Flush pending inserts first; don't want to lose progress when Colab shuts down.
                await asyncio.gather(*self.pending_tasks)
                self.backup_task = self.loop.create_task(self.async_backup_db())
                backup_bar.reset(total=backup_interval)
                backup_bar.set_description_str(backup_description.format(remaining=backup_interval))
                just_backed_up = True

        # A backup started for the last processed concept is not awaited anywhere else.
        if self.backup_task is not None:
            await self.backup_task
            self.backup_task = None
        # Wait for every scheduled database insert and cache write.
        await asyncio.gather(*self.pending_tasks, *self.cache_tasks)
        if not just_backed_up:
            # Awaited directly instead of being added to pending_tasks: async_backup_db
            # gathers pending_tasks, so it would otherwise wait on itself forever.
            await self.async_backup_db()

        progress_bar.close()
        generate_bar.close()
        backup_bar.close()



In [None]:
# @title Concept Extraction Config
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Detect whether the concept list changed since this cell last ran, so the stimulus
# count below is only recomputed when needed.
if "concepts" in locals():
    old_concepts_len = len(concepts)
full_concepts = get_concepts()
concepts = full_concepts
if "old_concepts_len" in locals():
    concepts_changed = len(concepts) != old_concepts_len
else:
    concepts_changed = False

print(concepts)

# local_project_base = Path("/content/drive/MyDrive/Projects/Control_Vectors")
# remote_project_base = Path("/content")
dataset_len_json = buddy.file_pair(
    local_path=local_project_base / "dataset_len.json",
    remote_path=remote_project_base / "dataset_len.json"
)


if "dataset_len" not in locals() or concepts_changed is True:
    if len(concepts) != len(full_concepts):
        dataset_len = get_stimuli_len(concepts)
    else:
        # Full concept set: reuse the stored total if a copy exists, otherwise compute and store it.
        if dataset_len_json.remote.exists():
            dataset_len_json.copy_to_local()
            with open(dataset_len_json.local, 'r') as f:
                dataset_len = json.load(f)
        else:
            dataset_len = get_all_stimuli_len()
            with open(dataset_len_json.local, 'w') as f:
                json.dump(dataset_len, f)
            dataset_len_json.copy_to_remote()

db_filename = 'pace_{model_name}.db'.format(model_name=filesafe_model_name)
db_path = buddy.file_pair(
    local_path=local_project_base / db_filename,
    remote_path=remote_project_base / db_filename
)

if db_path.remote.exists() is True:
    db_path.copy_to_local()

# db.fetch_tensor
token_length = 5
gen_cfg = GenerationConfig.from_model_config(model.config)
gen_cfg.max_new_tokens = token_length
# min_length counts prompt tokens too, so it forces no new tokens for prompts of at least
# token_length tokens. min_new_tokens (== max_new_tokens) makes every generation run exactly
# token_length steps, which PaCEBuddy.main's check (sequence_length == max_new_tokens) needs.
gen_cfg.min_new_tokens = token_length
gen_cfg.output_hidden_states = True
gen_cfg.return_dict_in_generate = True

cache_db_path = local_project_base / "pace_cache.db"

I am still experimenting heavily, so sometimes I just need to delete everything and start over.

In [None]:
# Report whether the remote DB, local DB and activation cache files exist (one line each).
print(*(p.exists() for p in (db_path.remote, db_path.local, cache_db_path)), sep="\n")
# Uncomment to wipe everything and start over:
# !rm {db_path.remote}
# print(db_path.remote.exists())
# !rm {db_path.local}
# print(db_path.local.exists())
# !rm {cache_db_path}
# print(cache_db_path.exists())

In [None]:
# Verbosity of this notebook's logger: 0=ERROR, 1=WARNING, 2=INFO, 3=DEBUG; anything else -> ERROR.
logger = logging.getLogger(__name__)
log_level = 1
_LOG_LEVELS = {0: logging.ERROR, 1: logging.WARNING, 2: logging.INFO, 3: logging.DEBUG}
logger.setLevel(_LOG_LEVELS.get(log_level, logging.ERROR))
# A logger level alone does not make INFO/DEBUG records visible. With no handler anywhere in
# the hierarchy, logging falls back to logging.lastResort, which only emits WARNING and above.
# basicConfig() adds a stderr handler to the root logger, and does nothing if one already exists.
logging.basicConfig()


This runs my PaCE variant, in which I try to decompose activations algorithmically. I do not currently have the resources to train sparse autoencoders effectively, so I am hoping to find a more traditional, mathematically grounded approach to representation engineering.

I have seen some interesting results so far, but the models I am working with are small, and therefore very brittle and sensitive to interventions.

In [None]:
# @title Concept Extraction
run_concept_extraction = False # @param {type:'boolean'}
if run_concept_extraction:
    model_prompt_batch_size = 10
    # If None, full dataset is used
    restrict_prompt_len = None
    dataset = dataset_generator(concepts=concepts, load_n_concepts=3, restrict_stimuli=restrict_prompt_len)

    db_flush_interval_minutes = 5
    flush_interval = db_flush_interval_minutes * 60
    db_batch_size = 100 if torch.cuda.is_available() else 1
    db = await AsyncPaceDbBuddy.create_db(db_path=db_path.local, batch_size=db_batch_size, flush_interval=flush_interval)

    cache_flush_interval = 60
    cache_db = await AsyncPaceCacheBuddy.create_db(db_path=cache_db_path, batch_size=3, flush_interval=cache_flush_interval)

    pace_buddy = PaCEBuddy(model=model, tokenizer=tokenizer, device=device, gen_cfg=gen_cfg, cache=cache_db, db=db, db_path=db_path.local, backup_db_path=db_path.remote)
    await pace_buddy.main(dataset, dataset_len, batch_size=model_prompt_batch_size, backup_interval=db_batch_size, pad_direction="right")

In [None]:
# Open the result database with the synchronous PaceDbBuddy and list the stored concepts.
# Note: this rebinds the notebook-global ``db``, replacing any handle created by the extraction cell.
db = PaceDbBuddy(db_path=db_path.local)
stored_concepts = db.list_concepts()
print(stored_concepts)

In [None]:
# @title Functions for tensor operations
import torch

# @torch.jit.script  # TorchScript for compilation
def soft_thresholding_inplace(coef: torch.Tensor, alpha: float):
    """
    In-place soft thresholding for L1 regularization.

    Computes ``coef <- sign(coef) * max(|coef| - alpha, 0)`` elementwise, overwriting
    ``coef``, and returns the same tensor object.
    """
    # Capture the sign before abs_() overwrites the values it is computed from.
    sign = torch.sign(coef)
    coef.abs_().sub_(alpha).clamp_(min=0.0).mul_(sign)
    return coef

# @torch.jit.script  # TorchScript for compilation
def precompute_concept_vectors(concept_vectors: torch.Tensor, eps: float = 1e-5):
    """
    Precompute necessary data for concept_vectors to optimize coordinate descent.
    """
    # Normalize the concept vectors
    concept_vectors = torch.nn.functional.normalize(concept_vectors, dim=1)

    # Precompute norm squared for each concept vector, add eps to avoid divide by zero
    norm_squared = torch.norm(concept_vectors, dim=1) ** 2 + eps  # (num_concepts,)

    # Precompute the Gram matrix (concept_vectors.T @ concept_vectors) and add eps to the diagonal
    gram_matrix = torch.matmul(concept_vectors, concept_vectors.T)  # (num_concepts x num_concepts)
    gram_matrix = gram_matrix + torch.eye(gram_matrix.shape[0], device=gram_matrix.device) * eps

    return concept_vectors, norm_squared, gram_matrix

# @torch.jit.script  # TorchScript for compilation
def coordinate_descent_batch_optimized(concept_vectors: torch.Tensor,
                                       norm_squared: torch.Tensor,
                                       gram_matrix: torch.Tensor,
                                       latent_vector: torch.Tensor,
                                       coefficients: torch.Tensor,
                                       alpha: float,
                                       l1_ratio: float,
                                       tol: float,
                                       max_iter: int,
                                       clip_range: torch.Tensor):
    """
    Elastic-net coordinate descent, run for every row of ``latent_vector`` at once.

    For each row ``x`` of ``latent_vector`` this minimises
        0.5 * ||x - a @ concept_vectors||^2
        + alpha * l1_ratio * ||a||_1
        + 0.5 * alpha * (1 - l1_ratio) * ||a||^2
    over the coefficient row ``a``, with every coefficient clamped to ``clip_range``.
    Concepts are updated one at a time (cyclic / Gauss-Seidel order) and the residual is
    refreshed after each update. Updating all concepts simultaneously instead can oscillate
    forever when concepts are strongly correlated (e.g. duplicated).

    The loop over concepts runs in Python, so cost grows with ``num_concepts``; every
    sequence row is handled in parallel inside each step.

    Args:
        concept_vectors: (num_concepts, hidden_dimension) dictionary, one concept per row.
        norm_squared: unused; kept for signature compatibility (the Gram diagonal is used).
        gram_matrix: (num_concepts, num_concepts); only its diagonal (squared concept norms,
            e.g. from ``precompute_concept_vectors``) is read.
        latent_vector: (seq_len, hidden_dimension) rows to decompose.
        coefficients: (seq_len, num_concepts) starting point; it is not modified.
        alpha: overall penalty strength.
        l1_ratio: mix between L1 (1.0) and L2 (0.0) penalties.
        tol: stop once the L2 norm of one full sweep's coefficient change is below this.
        max_iter: maximum number of sweeps over all concepts.
        clip_range: 2-element tensor ``[low, high]`` every coefficient is clamped to.

    Returns:
        (seq_len, num_concepts) tensor of coefficients.
    """
    l1_penalty = alpha * l1_ratio
    squared_norms = gram_matrix.diagonal()
    denominators = squared_norms + alpha * (1.0 - l1_ratio)
    low = clip_range[0]
    high = clip_range[1]
    num_concepts = concept_vectors.shape[0]

    coefficients = coefficients.clone()  # never mutate the caller's tensor
    residual = latent_vector - torch.matmul(coefficients, concept_vectors)  # (seq_len, hidden)

    for _ in range(max_iter):
        coefficients_old = coefficients.clone()
        for j in range(num_concepts):
            concept = concept_vectors[j]  # (hidden,)
            current = coefficients[:, j]  # (seq_len,)
            # Correlation of concept j with the residual that excludes its own contribution.
            rho = torch.matmul(residual, concept) + current * squared_norms[j]
            # Soft-thresholding, inlined so this solver does not depend on a helper.
            updated = torch.sign(rho) * torch.clamp(rho.abs() - l1_penalty, min=0.0)
            updated = torch.clamp(updated / denominators[j], low, high)
            # Keep the residual consistent with the new coefficient before storing it
            # (``current`` is a view into ``coefficients``).
            delta = updated - current
            residual -= delta.unsqueeze(1) * concept
            coefficients[:, j] = updated

        if torch.norm(coefficients - coefficients_old, p=2) < tol:
            break

    return coefficients


# @torch.jit.script  # TorchScript for compilation
def decompose_latent_activation_cuda(latent_vector: torch.Tensor, concept_vectors: torch.Tensor,
                                     norm_squared: torch.Tensor, gram_matrix: torch.Tensor,
                                     alpha: float, l1_ratio: float, max_iter: int, tol: float,
                                     clip_range: Tuple[float, float]=(-1e3, 1e3)):
    """
    Decompose a batch of latent activations, one CUDA stream per batch element, under autocast.

    Each ``latent_vector[b]`` (seq_len, hidden_dimension) is solved with
    ``coordinate_descent_batch_optimized`` on its own stream, starting from zeros.

    Args:
        latent_vector: (batch_size, seq_len, hidden_dimension) activations.
        concept_vectors, norm_squared, gram_matrix: dictionary data, as produced by
            ``precompute_concept_vectors``.
        alpha, l1_ratio, max_iter, tol: forwarded to the coordinate-descent solver.
        clip_range: (low, high) bounds every coefficient is clamped to.

    Returns:
        (batch_size, seq_len, num_concepts) coefficient tensor on ``latent_vector.device``.
    """
    clip_bounds = torch.tensor(clip_range, device=latent_vector.device)
    batch_size, seq_len, _ = latent_vector.shape
    num_concepts = concept_vectors.shape[0]

    coefficients = torch.zeros(batch_size, seq_len, num_concepts, device=latent_vector.device)

    # Work queued on side streams does not implicitly wait for the stream that produced
    # latent_vector / coefficients, so every side stream must wait on it explicitly.
    producer_stream = torch.cuda.current_stream()

    # Automatic mixed precision (torch.autocast replaces the deprecated torch.cuda.amp.autocast).
    with torch.autocast(device_type="cuda", enabled=True):
        streams = [torch.cuda.Stream() for _ in range(batch_size)]

        for batch_idx, stream in enumerate(streams):
            stream.wait_stream(producer_stream)
            with torch.cuda.stream(stream):
                coefficients[batch_idx] = coordinate_descent_batch_optimized(
                    concept_vectors, norm_squared, gram_matrix, latent_vector[batch_idx],
                    coefficients[batch_idx], alpha, l1_ratio, tol, max_iter, clip_bounds
                )

        # Block until every stream has finished before the result is handed back.
        torch.cuda.synchronize()

    return coefficients

# @torch.jit.script  # TorchScript for compilation
def decompose_latent_activation_cpu(latent_vector: torch.Tensor, concept_vectors: torch.Tensor,
                                    norm_squared: torch.Tensor, gram_matrix: torch.Tensor,
                                    alpha: float, l1_ratio: float, max_iter: int, tol: float,
                                    clip_range: Tuple[float, float]=(-1e3, 1e3)):
    """
    Decompose a batch of latent activations sequentially, one batch element at a time.

    Each ``latent_vector[b]`` is handed to ``coordinate_descent_batch_optimized`` with an
    all-zero starting point; results are gathered into one tensor.

    Args:
        latent_vector: 3-D activations, (batch_size, seq_len, hidden_dimension).
        concept_vectors, norm_squared, gram_matrix: dictionary data for the solver.
        alpha, l1_ratio, max_iter, tol: forwarded to the solver.
        clip_range: (low, high) bounds, converted to a tensor on ``latent_vector.device``.

    Returns:
        (batch_size, seq_len, num_concepts) tensor in the default dtype.
    """
    target_device = latent_vector.device
    bounds = torch.tensor(clip_range, device=target_device)
    n_batch, n_tokens, _ = latent_vector.shape

    result = torch.zeros(n_batch, n_tokens, concept_vectors.shape[0], device=target_device)

    for idx, sequence in enumerate(latent_vector):
        result[idx] = coordinate_descent_batch_optimized(
            concept_vectors, norm_squared, gram_matrix, sequence, result[idx],
            alpha, l1_ratio, tol, max_iter, bounds,
        )

    return result


def initialize_decomposition_fn():
    """
    Bind the module-level ``decompose_latent_activation`` to the CUDA or CPU implementation.

    Chooses ``decompose_latent_activation_cuda`` when ``torch.cuda.is_available()`` is true
    and ``decompose_latent_activation_cpu`` otherwise. Calling it again simply re-binds
    the global.
    """
    global decompose_latent_activation

    if torch.cuda.is_available():
        decompose_latent_activation = decompose_latent_activation_cuda
    else:
        decompose_latent_activation = decompose_latent_activation_cpu


# Select the implementation once when this cell runs (single source of truth for the choice).
initialize_decomposition_fn()

In [None]:
@torch.jit.script  # TorchScript for compilation
def edit_latent_vector(coefficients: torch.Tensor, concept_vectors: torch.Tensor, concept_strengths: torch.Tensor, residual: torch.Tensor):
    """
    Rebuild a latent vector after rescaling its concept coefficients.

    Computes ``(coefficients * concept_strengths) @ concept_vectors.T + residual``.

    Parameters:
        coefficients (torch.Tensor): Per-concept coefficients, [batch_size, seq_len, num_concepts].
        concept_vectors (torch.Tensor): Concept dictionary with one concept per column,
            [hidden_dimension, num_concepts].
        concept_strengths (torch.Tensor): Scale factor per concept, index-aligned with the
            columns of concept_vectors; must broadcast against coefficients (e.g. [num_concepts]).
        residual (torch.Tensor): The part of the latent vector not explained by the concepts,
            [batch_size, seq_len, hidden_dimension].

    Returns:
        edited_latent_vector (torch.Tensor): [batch_size, seq_len, hidden_dimension].
    """
    # Scale coefficients by concept strengths using element-wise multiplication
    scaled_coefficients = coefficients * concept_strengths  # [batch_size, seq_len, num_concepts]

    # One matmul contracts the concept axis directly. Broadcasting and then summing would
    # materialise a [batch_size, seq_len, num_concepts, hidden_dimension] intermediate.
    edited_latent_vector = torch.matmul(scaled_coefficients, concept_vectors.T) + residual

    return edited_latent_vector

In [None]:
import torch
@torch.jit.script  # TorchScript for compilation
def apply_scaling(latent_tensor: torch.Tensor, scaling_tensor: torch.Tensor) -> torch.Tensor:
    """
    Multiply the last axis of a 3-D ``latent_tensor`` by a per-feature ``scaling_tensor``.

    ``scaling_tensor`` is viewed as (1, 1, z) so it broadcasts over the first two axes.
    """
    num_features = scaling_tensor.size(0)
    broadcastable_scale = scaling_tensor.view(1, 1, num_features)
    return latent_tensor * broadcastable_scale

In [None]:
@torch.jit.script
def manipulate_activations(activations: torch.Tensor, pca_matrix: torch.Tensor, scaling_factors: torch.Tensor):
    """
    Manipulate activations along PCA directions.

    Args:
    - activations (Tensor): Model activations, (batch_size, sequence_length, hidden_dimension).
      They need not be contiguous.
    - pca_matrix (Tensor): (1, hidden_dimension) for a single direction, or
      (num_pca_components, hidden_dimension) for several.
    - scaling_factors (Tensor): Scalar, or a tensor broadcastable to (num_pca_components,).

    Behaviour:
    - One direction: returns ``activations + scaling_factors * pca_matrix`` (additive steering).
    - Several directions: returns ``((activations @ P.T) * scaling_factors) @ P``, i.e. the
      scaled projection onto the PCA subspace only.

    Returns:
    - modified_activations (Tensor): (batch_size, sequence_length, hidden_dimension).
    """
    batch_size, sequence_length, hidden_dimension = activations.size()

    # Case 1: single PCA vector, added to every position via broadcasting.
    if pca_matrix.size(0) == 1:
        if scaling_factors.dim() == 0:
            scaling_factors = scaling_factors.unsqueeze(0)  # Convert scalar to tensor for broadcasting

        scaled_pca_vector = pca_matrix * scaling_factors  # (1, hidden_dimension)
        scaled_pca_vector = scaled_pca_vector.expand(batch_size, sequence_length, hidden_dimension)
        modified_activations = activations + scaled_pca_vector

    # Case 2: multiple PCA components.
    else:
        # reshape (not view) so non-contiguous activations, e.g. from a transpose, also work.
        activations_flat = activations.reshape(batch_size * sequence_length, hidden_dimension)

        # Project onto the components and scale each component's coordinate.
        pca_projected = torch.matmul(activations_flat, pca_matrix.t())  # (B*S, num_pca_components)
        pca_projected = pca_projected * scaling_factors

        # NOTE(review): this reconstructs only the PCA-subspace part; the component of the
        # activations orthogonal to pca_matrix is dropped, unlike case 1 which adds to the
        # activations. Confirm this replacement is intended.
        modified_activations_flat = torch.matmul(pca_projected, pca_matrix)  # (B*S, hidden_dimension)
        modified_activations = modified_activations_flat.reshape(batch_size, sequence_length, hidden_dimension)

    return modified_activations

def get_range(tensor: torch.Tensor):
    """Print the minimum and maximum of ``tensor`` as a JSON-like object followed by a comma."""
    lowest, highest = torch.min(tensor), torch.max(tensor)
    print("\n".join(["{", f'"min": {lowest},', f'"max": {highest}', "},"]))

In [None]:
import torch
from typing import List

class PaCEIntervention:
    """
    Steer a language model by editing decoder-layer outputs along stored concept directions.

    Concept tensors are read with ``db.fetch_tensor``. Forward hooks registered on
    ``model.model.layers[i]`` pass each layer's output through ``manipulate_activations``.
    Typical use: ``enhance_concepts(...)`` -> ``apply_hook()`` -> generate -> ``remove_hooks()``.
    """

    def __init__(self, model: torch.nn.Module, db: 'PaceDbBuddy'):
        """
        Initialize the PaCE Intervention class.

        Args:
        - model: The pretrained language model (e.g., AutoModelForCausalLM); its decoder
          layers are accessed as ``model.model.layers``.
        - db: The embedding database with concept vectors (must provide ``fetch_tensor``).
        """
        self.model = model
        self.db = db
        self.hooks = []  # handles returned by register_forward_hook
        self.buffers = []  # (layer_module, concept_buffer_name, strength_buffer_name)
        self.intervened_activations = {}
        self.concept_tensor = None
        self.strengths = None

    def __del__(self):
        # __init__ may not have completed; only detach what was actually set up so that
        # finalisation never raises AttributeError.
        if hasattr(self, "hooks") and hasattr(self, "buffers"):
            self.remove_hooks()

    def enhance_concepts(self, concepts: List[str], strengths: List[float]):
        """
        Load the concept directions to steer with and one strength per concept.

        Args:
        - concepts: Concept names to look up with ``db.fetch_tensor``.
        - strengths: Scale factors, index-aligned with ``concepts``.

        Raises:
        - ValueError: if the lists differ in length, are empty, or a concept is missing
          from the database.
        """
        if len(concepts) != len(strengths):
            raise ValueError(
                f"Got {len(concepts)} concepts but {len(strengths)} strengths; they must match one-to-one."
            )
        if not concepts:
            raise ValueError("At least one concept is required.")

        concept_tensors = []
        for concept in concepts:
            tensor = self.db.fetch_tensor(concept)
            if tensor is None:
                raise ValueError("Concept {concept} not found in the database.".format(concept=concept))
            # Drop index 0 of the first axis (debug data); the remaining rows are per layer.
            # Expected shape afterwards: (layers, hidden_dimension, 1) — TODO confirm against the DB writer.
            concept_tensors.append(tensor[1:, :, :])
        # Concatenate the components along the last axis: (layers, hidden_dimension, num_components).
        self.concept_tensor = torch.concat(concept_tensors, dim=-1)
        # (1, num_concepts): one strength per concept (assumes one component per concept).
        self.strengths = torch.tensor(strengths, dtype=torch.float32).unsqueeze(0)

    def apply_hook(self, layer_strength=True, first_layer_only=False):
        """
        Register a steering forward hook on each decoder layer.

        Any hooks previously registered by this instance are removed first, so calling
        this repeatedly never stacks interventions.

        Args:
        - layer_strength: if True, the strengths at layer ``i`` are multiplied by
          ``i / len(model.model.layers)`` (gradual increase with depth; layer 0 gets 0).
        - first_layer_only: if True, only layer 0 is hooked.

        Raises:
        - AttributeError: if ``enhance_concepts`` has not been called.
        """
        if self.concept_tensor is None:
            raise AttributeError("Concept tensor not set. Call enhance_concepts first.")
        self._detach_hooks()

        layers = self.model.model.layers
        num_model_layers = len(layers)
        # NOTE(review): assumes concept_tensor[i] belongs to model.model.layers[i]; extra
        # concept rows beyond the model's layer count are ignored.
        num_hooked = min(self.concept_tensor.shape[0], num_model_layers)
        if first_layer_only:
            num_hooked = min(num_hooked, 1)

        for layer_idx in range(num_hooked):
            concept_vector = self.concept_tensor[layer_idx]  # (hidden_dimension, num_components)
            strength = self.strengths  # shared (1, num_concepts) strengths for every layer
            if layer_strength:
                strength = strength * (layer_idx / num_model_layers)  # Gradual increase

            # Store per-layer tensors as non-persistent buffers so the hook can read them.
            buffer_concept = f"concept_vector_layer_{layer_idx}"
            buffer_strength = f"strength_vector_layer_{layer_idx}"
            layer_module = layers[layer_idx]
            layer_module.register_buffer(name=buffer_concept, tensor=concept_vector, persistent=False)
            layer_module.register_buffer(name=buffer_strength, tensor=strength, persistent=False)
            self.buffers.append((layer_module, buffer_concept, buffer_strength))

            hook_fn = self.create_hook_fn(concept_vector, buffer_concept, buffer_strength)
            handle = layer_module.register_forward_hook(hook_fn)
            self.hooks.append(handle)

    def create_hook_fn(self, concept_vector, buffer_concept, buffer_strength):
        """
        Create a forward hook that edits a layer's output using the named buffers.

        ``concept_vector`` is not used; the hook reads the tensors back from the
        ``buffer_concept`` / ``buffer_strength`` buffers on the hooked module.
        """

        def hook_fn(module, input: Tuple[torch.Tensor], output: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
            """
            Replace the first element of the layer output with its steered version.

            Parameters:
            - module: The layer where the hook is registered.
            - input: Input to the layer.
            - output: Output from the layer; a tuple whose first element is the hidden states,
              or a bare tensor of hidden states.

            Returns:
            - The output with the hidden states passed through ``manipulate_activations``.
            """
            hidden_states = output if isinstance(output, torch.Tensor) else output[0]
            latent_vector = hidden_states.clone().detach()
            # The strengths are built as a CPU float32 tensor in enhance_concepts, and the
            # concept tensor is used as loaded from the DB; align both with the activations.
            strength = getattr(module, buffer_strength).to(device=latent_vector.device, dtype=latent_vector.dtype)
            concept = getattr(module, buffer_concept).T.to(device=latent_vector.device, dtype=latent_vector.dtype)
            edited_latent_vector = manipulate_activations(latent_vector, concept, strength)

            if isinstance(output, torch.Tensor):
                return edited_latent_vector
            return tuple((edited_latent_vector, *output[1:]))

        return hook_fn

    def _detach_hooks(self):
        """Remove this instance's hooks and per-layer buffers, keeping the loaded concepts."""
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()
        for layer_module, buffer_concept, buffer_strength in self.buffers:
            for buffer_name in (buffer_concept, buffer_strength):
                try:
                    delattr(layer_module, buffer_name)
                except AttributeError:
                    # Already removed (e.g. by another intervention using the same names).
                    pass
        self.buffers = []

    def remove_hooks(self):
        """Remove all hooks from the model and forget the loaded concepts and strengths."""
        self._detach_hooks()
        self.concept_tensor = None
        self.strengths = None

In [None]:
new_tokens = 20

gen_cfg_inf = GenerationConfig.from_model_config(model.config)
gen_cfg_inf.max_new_tokens = new_tokens
gen_cfg_inf.min_length = new_tokens
gen_cfg_inf.output_hidden_states = False

# Prepare inputs (token IDs)
# inputs = tokenizer("Once upon a time", return_tensors="pt").to(model.device)
# prompt = "Once upon a time"
prompt = "User: How old are you?\nAssistant: I am "
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Detach any intervention left from a previous run so the baseline is unsteered.
if "activation_wrangler" in locals():
    activation_wrangler.remove_hooks()
    del activation_wrangler

# Regenerate the baseline only when the prompt or the generation length changed.
if (
    "unenhanced_generated_text" not in locals()
    or locals().get("old_prompt", "") != prompt
    or locals().get("old_new_tokens") != new_tokens
):
    old_prompt = prompt
    old_new_tokens = new_tokens
    output = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        generation_config=gen_cfg_inf
    )
    # Print generated text
    unenhanced_generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("unenhanced")
print(unenhanced_generated_text)

activation_wrangler = PaCEIntervention(model, db)
concepts_strengths = {
    # 'Alabama': 1.2,
    '16-year-old': 0.17 # 1.357e-05
}
concepts = list(concepts_strengths.keys())
strengths = list(concepts_strengths.values())
activation_wrangler.enhance_concepts(concepts=concepts, strengths=strengths)

print('concepts')
print(concepts)
print('strengths')
print(strengths)

activation_wrangler.apply_hook()
try:
    # Perform the intervention
    enhanced_output = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        generation_config=gen_cfg_inf
    )
finally:
    # Always detach the hooks so later cells run against the unmodified model.
    activation_wrangler.remove_hooks()
enhanced_generated_text = tokenizer.decode(enhanced_output[0], skip_special_tokens=True)
print("enhanced")
print(enhanced_generated_text)

In [None]:
# Report the stored tensor shape of each concept, or say that it is missing.
concepts = ['16-year-old', "'90s"]
for concept in concepts:
    data = db.fetch_tensor(concept)
    if data is None:
        print(f"Concept '{concept}' not found in the database.")
        continue
    print(data.shape)

A test suite for evaluating the representation-engineering interventions at different strengths and for different concepts.

When using TinyLlama 1B as the target model, with the intervention scaling decaying over the layers, I noticed a trend: the age the model reported became younger as the strength of the '16-year-old' concept increased. (I picked this concept from the dataset because it seemed the easiest to test, i.e. by asking "How old are you?")

Without intervention, the model stated it was roughly 24. As the intervention strength increased, the stated age dropped to 18, but soon afterwards the model collapsed into gibberish and could no longer continue with coherent next-token prediction.

A larger model, such as Phi 3B, may prove to be less brittle.

In [None]:
import json

def test_strength(prompt, concepts, strengths, new_tokens=5):
    """
    Generate ``new_tokens`` tokens for ``prompt`` while steering ``concepts`` at ``strengths``.

    The intervention hooks are always removed afterwards, even if enhancing or generation
    raises, so the shared ``model`` is left unmodified for the next call.

    Returns:
        (strengths, generated_text): the strengths that were applied and the decoded output.
    """
    gen_cfg_inf = GenerationConfig.from_model_config(model.config)
    gen_cfg_inf.max_new_tokens = new_tokens
    gen_cfg_inf.min_length = new_tokens
    gen_cfg_inf.output_hidden_states = False

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    activation_wrangler = PaCEIntervention(model, db)
    try:
        activation_wrangler.enhance_concepts(concepts=concepts, strengths=strengths)
        activation_wrangler.apply_hook()
        enhanced_output = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            generation_config=gen_cfg_inf
        )
    finally:
        activation_wrangler.remove_hooks()
    enhanced_generated_text = tokenizer.decode(enhanced_output[0], skip_special_tokens=True)
    return strengths, enhanced_generated_text


new_tokens = 5

# Each concept is paired with a prompt whose continuation should reveal its effect.
concepts_prompts = {
    '16-year-old': "My birthday is next week, I'm turning ",
    "'90s": "The year is "
}
for test_idx, (concept, prompt) in enumerate(concepts_prompts.items()):
    print(f"# Test {test_idx}")
    print(f"New Tokens to generate: {new_tokens}")
    print(f"Concept: {concept}")
    print('---')
    for i in range(25):
        strength = i / 100
        # Steer only the concept this test is about, at the current strength.
        _, text = test_strength(prompt, [concept], [strength], new_tokens)
        print(f"Strength: {strength}\nOutput: {text}")
        print('---')

In [None]:
# Run a full garbage-collection pass so objects dropped by earlier cells
# (e.g. deleted intervention instances) are reclaimed now.
import gc
gc.collect()