In [None]:
# original activation engineering code: https://colab.research.google.com/drive/1y84fhgkGX0ft2DmYJB3K13lAyf-0YonK?usp=sharing#scrollTo=ZExJFurIjKHM
import time
from threading import Thread

import torch
import numpy as np
from transformer_lens import HookedTransformer
from typing import Dict, Union, List, Tuple

# Inference only: switch autograd off globally to save memory.
torch.set_grad_enabled(False)

# Pick the model to steer. The alternatives are kept commented out together with
# the author's notes on size, loss and CPU generation speed.
# # https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=jHj79Pj58cgJKdq4t-ygK-4h
# model = HookedTransformer.from_pretrained("distilgpt2", device="cpu")   # 88M, loss=4.1, 9it/s
# model = HookedTransformer.from_pretrained("gpt2-small", device="cpu")   # 85M, loss=3.7, 7it/s
model = HookedTransformer.from_pretrained("gpt2-medium", device="cpu")   # 300M, loss=3.4, 3.6it/s
# model = HookedTransformer.from_pretrained("pythia-410m-deduped", device="cpu", checkpoint_index=153) # 410M, loss=3.1, 2.5it/s
#
# model = HookedTransformer.from_pretrained("gpt2-large", device="cpu")   # 700M, loss=3.3, 1.1it/s

# Number of transformer blocks; used as the upper bound of the layer slider.
num_layers = len(model.blocks)

In [None]:
from pynput import keyboard
# "keyboard" lib would let us access what is pressed directly, but it requires root


class KeyHandler:
    """Track which keys are held down and decide whether generation should run.

    A pynput keyboard listener feeds `_on_press` / `_on_release`. Generation
    runs while alt+shift are held, or while the "generating lock" is on; the
    lock is flipped by tapping the toggling key on its own. Pressing esc
    always requests a stop.

    Args:
        toggling_key: normalised name of the key that toggles the generating
            lock. Good candidates are "cmd", "alt", "ctrl".
        on_lock_change: optional callable invoked with the new lock state
            (bool) whenever the lock is toggled. When omitted, the handler
            falls back to disabling/enabling the text area of a module-level
            `ui` object, if one exists at that moment.
    """

    def __init__(self, toggling_key="cmd", on_lock_change=None):
        self._toggling_key = toggling_key
        self._on_lock_change = on_lock_change
        # normalised names of the keys that are currently held down
        self.pressed = set()
        # set once esc has been pressed; reset by whoever (re)starts the main loop
        self.esc_registered = False
        # True while the toggling key is down and no other key interfered
        self._toggle_key_pressed_alone = False
        # when True, generation keeps running without any key being held
        self._generating_lock = False
        self.listener = keyboard.Listener(on_press=self._on_press, on_release=self._on_release)
        self.listener.start()

    @staticmethod
    def _key_name(key):
        """Turn a key object into the short lowercase name used in `pressed`."""
        # "<65511>" is what shows up for alt in some combinations; map it back to "alt"
        return str(key).replace("'", "").replace("Key.", "").replace("<65511>", "alt").lower()

    def _on_press(self, key):
        """Listener callback: record a key press and update the toggle state."""
        k = self._key_name(key)
        if k == "<0>":
            return    # this is some weird macro artifact
        if k == "esc":
            self.esc_registered = True
        self.pressed.add(k)

        # toggling behaviour: the tap only counts if the toggling key is the
        # sole key held, and any other key pressed meanwhile cancels it
        if k == self._toggling_key and len(self.pressed) == 1:
            self._toggle_key_pressed_alone = True
        if k != self._toggling_key:
            self._toggle_key_pressed_alone = False

    def _on_release(self, key):
        """Listener callback: record a key release and toggle the lock on a clean tap."""
        k = self._key_name(key)
        self.pressed.discard(k)

        if k == self._toggling_key and self._toggle_key_pressed_alone:
            # toggle key was tapped w/o anything else
            self._generating_lock = not self._generating_lock
            self._notify_lock_change()

    def _notify_lock_change(self):
        """Report the new lock state to the callback, or to the global `ui` if present."""
        if self._on_lock_change is not None:
            self._on_lock_change(self._generating_lock)
            return
        # Fallback: the UI object is created after this handler, so it may not
        # exist yet. Look it up defensively instead of raising NameError inside
        # the listener thread.
        ui_obj = globals().get("ui")
        if ui_obj is not None:
            ui_obj.text_area.disabled = self._generating_lock

    def should_we_stop_generating(self, *args):
        """Return True when generation should stop, False when it should go on.

        Extra positional arguments are accepted and ignored so the method can
        be handed over directly as a stop-criterion callback.
        """
        if self.esc_registered:
            # just to be sure esc can always stop; maybe not needed
            return True

        if self._generating_lock:
            # if generating lock is on, we don't want to stop
            return False
        if "alt" in self.pressed and "shift" in self.pressed:
            return False
        # no reason to continue generating
        return True
                

# Module-level handler instance; constructing it starts the keyboard listener.
# Later cells read `key_handler.pressed` and call `should_we_stop_generating`.
key_handler = KeyHandler()

In [None]:
import hashlib
import matplotlib
import matplotlib.pyplot as plt
import panel as pn


class UI:
    """Panel layout for the steering demo.

    Holds a text area, two sliders (steering strength and layer number), an
    info line, and a square matplotlib plot that previews the first two
    coordinates of the activation and of every steering component.
    """

    def __init__(self):
        self.text_area = pn.widgets.TextAreaInput(value="", sizing_mode="stretch_both")
        self.steering_strength = pn.widgets.FloatSlider(
            name="Steering Strength", start=0.0, end=30, step=0.01, value=3
        )
        self.layer_num = pn.widgets.IntSlider(
            name="Layer Number", start=0, end=num_layers - 1, step=1, value=6
        )
        self.info_box = pn.widgets.StaticText(name="Info", value="Not started")

        # Interactive plotting is switched off so the figure is not displayed
        # in the notebook; it is only shown through the panel pane.
        plt.ioff()
        figure, axes = plt.subplots(figsize=(20, 20))
        axes.set_aspect("equal")  # square plot; unclear whether clear() resets this
        axes.set_facecolor((0., 0., 0.))  # black background
        self._ax = axes
        self._max_plotting_scale = 0.000001
        self.update_plot(None, None)  # draws the empty cross only
        # svg format is necessary; without it there are some weird lags when updating the text!
        self.plot = pn.pane.Matplotlib(figure, tight=True, sizing_mode="stretch_both", format="svg")

        slider_row = pn.Row(self.steering_strength, self.layer_num, sizing_mode="stretch_width")
        side_panel = pn.Column(slider_row, self.info_box, self.plot, sizing_mode="stretch_both")
        self.full = pn.Row(self.text_area, side_panel)

    def _draw_axes(self):
        """Wipe the axes and redraw the fixed [-1, 1] frame with a grey cross."""
        canvas = self._ax
        canvas.clear()
        canvas.set_xlim(-1, 1)
        canvas.set_ylim(-1, 1)
        canvas.set_xticks([])
        canvas.set_yticks([])
        canvas.plot([-1, 1], [0, 0], color="grey", linewidth=1)
        canvas.plot([0, 0], [-1, 1], color="grey", linewidth=1)

    def _grow_scale(self, point):
        """Enlarge the plotting scale so that `point` fits; the scale never shrinks."""
        self._max_plotting_scale = max(np.abs(point[0]), np.abs(point[1]), self._max_plotting_scale)

    @staticmethod
    def _key_color(key):
        """Derive a stable colour from a key name: one hash byte is used as the hue."""
        hash_byte = int(hashlib.shake_128(key.encode('utf-8')).hexdigest(1), 16)
        return matplotlib.colors.hsv_to_rgb((hash_byte / 255, 1, 1))

    def update_plot(self, existing_activation: List[float], modifying_activations: List[Tuple[List[float], str]]):
        """Redraw the 2D preview.

        Args:
            existing_activation: current activation; only its first two entries
                are plotted. Passing None just redraws the empty cross.
            modifying_activations: (vector, key) pairs; the first two entries
                of each vector are chained head-to-tail after the existing
                activation, coloured by key.
        """
        self._draw_axes()
        if existing_activation is None:
            return

        # First pass: walk the chain of partial sums to find the needed scale.
        running = np.array(existing_activation[:2])
        self._grow_scale(running)
        for delta, _ in modifying_activations:
            running += np.array(delta[:2])
            self._grow_scale(running)
        scale = self._max_plotting_scale

        # Second pass: draw. The existing activation is the thick white segment.
        canvas = self._ax
        tip = np.array(existing_activation[:2])
        canvas.plot([0, tip[0] / scale], [0, tip[1] / scale], color="white", linewidth=4)
        for delta, key in modifying_activations:
            segment_color = self._key_color(key)
            next_tip = tip + np.array(delta[:2])
            canvas.plot(
                [tip[0] / scale, next_tip[0] / scale],
                [tip[1] / scale, next_tip[1] / scale],
                color=segment_color,
                linewidth=2
            )
            tip = next_tip

        # tell panel that the figure changed
        self.plot.param.trigger('object')

# Module-level UI instance; other cells read its sliders and text area through `ui`.
ui = UI()
# Show the full layout on port 8000 (presumably served in a browser tab by
# panel's `.show` — confirm against the panel docs).
ui.full.show(port=8000)


In [None]:
# Construct one random direction per steerable key. The fixed seed makes every
# key map to the same direction on each run.
np.random.seed(0)
directions = {
    letter: np.random.normal(0, 1, model.cfg.d_model)
    for letter in "abcdefghijklmnopqrstuvwxyz,./"
}
# TODO handling these directions could be done by a class, together with looping over pressed keys, and later extracting directions from text
    

def add_vector(resid_pre, hook):
    """Forward hook: steer the last position of the chosen layer's residual stream.

    For every currently pressed key that has an entry in `directions`, the
    direction scaled by the steering-strength slider is summed up and added in
    place to `resid_pre[:, -1, :]`. The 2D preview is refreshed with the
    activation as it was before the addition.

    Args:
        resid_pre: activation tensor handed in by the hook machinery; it is
            indexed as [:, -1, :] (assumes (batch, position, d_model) —
            TODO confirm) and modified in place.
        hook: hook point; `hook.layer()` is compared with the layer slider so
            that only the selected layer is touched.

    Returns:
        None. The tensor is changed in place.
    """
    if hook.layer() != ui.layer_num.value:
        return

    # read the slider once so every component of this step uses the same strength
    strength = ui.steering_strength.value
    to_add = np.zeros(model.cfg.d_model)
    modifying_activations = []
    for key in sorted(key_handler.pressed):
        if key in directions:
            component = directions[key] * strength
            to_add += component
            modifying_activations.append((component[:2], key))

    # hand the preview plain floats, so plotting does not depend on the tensor's device
    ui.update_plot(resid_pre[:, -1, :2].flatten().tolist(), modifying_activations)

    if modifying_activations:
        # Convert explicitly to the tensor's dtype and device instead of mixing
        # a float64 NumPy array into a torch in-place add.
        steering = torch.as_tensor(to_add, dtype=resid_pre.dtype, device=resid_pre.device)
        # (d_model,) broadcasts over the leading batch dimension
        resid_pre[:, -1, :] += steering
    # TODO double check that this broadcasting works as intended
    

def new_token_callback(tokens, hooked_transformer):
    """Show the text generated so far in the UI text area.

    Args:
        tokens: token ids; only the first sequence (`tokens[0]`) is displayed.
        hooked_transformer: object whose `tokenizer.decode` turns ids into text.
    """
    first_sequence = tokens[0]
    # the leading BOS token is dropped before decoding
    ui.text_area.value_input = hooked_transformer.tokenizer.decode(first_sequence[1:])


def generate_tokens(text):
    """Generate a continuation of `text` with the steering hook attached.

    `add_vector` is registered on every hook whose name ends with "resid_pre".
    Generation goes on until `key_handler.should_we_stop_generating` asks for a
    stop (or the very large token limit is hit); each new token is reported
    through `new_token_callback`.

    Returns:
        Whatever `model.generate` returns for the given text.
    """
    def _is_resid_pre(hook_name):
        return hook_name.endswith("resid_pre")

    generation_options = dict(
        max_new_tokens=999999,
        temperature=1,
        verbose=False,
        stop_criterion=key_handler.should_we_stop_generating,
        new_token_callback=new_token_callback,
    )
    with model.hooks(fwd_hooks=[(_is_resid_pre, add_vector)]):
        return model.generate(text, **generation_options)

    
def main_loop_func():
    """Poll the keyboard state and generate text whenever generation is requested.

    Runs until esc is registered. While idle it sleeps 10 ms per iteration and
    keeps the info box showing the currently pressed keys; on exit the info
    box is set to "Off".
    """
    key_handler.esc_registered = False  # clear an esc left over from a previous run
    while True:
        if key_handler.esc_registered:
            break
        ui.info_box.value = str(key_handler.pressed)
        if key_handler.should_we_stop_generating():
            # nothing to do right now; avoid spinning at full speed
            time.sleep(0.010)
            continue
        generate_tokens(ui.text_area.value_input)
    ui.info_box.value = "Off"

# it needs to be a thread because otherwise panel can't update
# The loop inside `main_loop_func` ends once esc has been registered by the key handler.
main_loop = Thread(target=main_loop_func)
main_loop.start()


In [None]:
# Scratch check (looks like a leftover): a one-element list renders as the text "['/']".
str(['/'])

TODO
- [x] when I press many keys at once, during alt+shift, sometimes keys don't get unpressed
- [x] display a 2D preview of the modifying activations
- [x] speedup generation by calling generate once per alt+shift press
- [x] fix that nasty panel text update lag
    - fixed by generating the plot as svg; now no need to reimplement in dash
- [x] toggle generation mode, not only hold to generate mode
- [ ] get directions from text, implement a class for that

- [ ] fork transformerlens with my mod (to be able to access it in AWS) and later PR

- [ ] reusing hooked_transformers cache would be better - start of generation would be faster; also there seems to be some small sync issue with using those callbacks
    - the problem is: AssertionError: Pass in one token at a time after loading cache