In [1]:
import copy, json, random, re
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import pandas as pd
from plotnine import ggplot, aes, geom_line, theme_minimal
from matplotlib.ticker import MaxNLocator

In [3]:
import matplotlib.pyplot as plt

# Global figure typography: larger text and the "Sans" font family for every plot.
plt.rcParams["font.size"] = 20
plt.rcParams["font.family"] = "Sans"

In [None]:
import torch, transformers
from datasets import Dataset
from transformers import Trainer, TrainingArguments
from pyreft import (
    TaskType,
    get_reft_model,
    ReftConfig,
    ReftTrainerForCausalLM,
    ReftDataCollator,
    ReftSupervisedDataset,
    make_last_position_supervised_data_module,
    ConsreftIntervention,
    # LearnedSourceLowRankRotatedSpaceIntervention
)

In [5]:
# Sentinel label id; presumably the value the loss skips (not used in the visible cells).
IGNORE_INDEX = -100

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [6]:
def max_char_match_length(retrieved, golden):
    """Score how much of ``retrieved`` reproduces ``golden`` from the start.

    The two sequences are compared position by position; counting stops at
    the first mismatch or as soon as either sequence is exhausted.

    Args:
        retrieved: The generated text to score.
        golden: The reference text it should reproduce.

    Returns:
        The length of the matching prefix divided by ``len(retrieved)``,
        rounded to two decimals. Returns ``0.0`` when ``retrieved`` is empty.
    """
    if len(retrieved) == 0:
        return 0.0

    n_matched = 0
    # zip() stops at the shorter sequence, so a `retrieved` that runs past the
    # end of `golden` is scored instead of raising IndexError on golden[n].
    for got, expected in zip(retrieved, golden):
        if got != expected:
            break
        n_matched += 1
    return round(n_matched / len(retrieved), 2)

In [7]:
# Short alias for pyreft's last-position data-module builder; later cells call it by this name.
make_supervised_data_module = make_last_position_supervised_data_module

In [8]:
# Base checkpoint that the ReFT intervention is attached to.
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load the causal LM in bfloat16 and place it on the selected device.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
)

In [9]:
# Maximum sequence length handed to the tokenizer.
model_max_length = 2048

# Slow (non-fast) tokenizer for the same checkpoint, padding on the right.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name,
    model_max_length=model_max_length,
    padding_side="right",
    use_fast=False,
)

In [10]:
# Reuse the tokenizer's unknown token as its padding token.
tokenizer.pad_token = tokenizer.unk_token

In [11]:
# Index of the transformer layer whose block output is intervened on.
TARGET_LAYER = 15

# One rank-1 ConsreftIntervention on the block output of TARGET_LAYER.
reft_config = ReftConfig(
    representations=dict(
        layer=TARGET_LAYER,
        component="block_output",
        intervention=ConsreftIntervention(
            embed_dim=model.config.hidden_size,
            low_rank_dimension=1,
        ),
    )
)

# Wrap the base model with the configured intervention.
reft_model = get_reft_model(model, reft_config)

# Report how many parameters are trainable.
reft_model.print_trainable_parameters()


trainable intervention params: 2,049 || trainable model params: 0
model params: 1,100,048,384 || trainable%: 0.00018626453434251853


In [12]:
# Training data: teach the intervention to reproduce the passage below.
# Goal: store a short passage in a 1-D linear subspace of the last prompt
# token's representation.
memo_sequence = """
Welcome to the Natural Language Processing Group at Stanford University!
We are a passionate, inclusive group of students and faculty, postdocs
and research engineers, who work together on algorithms that allow computers
to process, generate, and understand human languages. Our interests are very
broad, including basic scientific research on computational linguistics,
machine learning, practical applications of human language technology,
and interdisciplinary work in computational social science and cognitive
science. We also develop a wide variety of educational materials
on NLP and many tools for the community to use, including the Stanza
toolkit which processes text in over 60 human languages.
"""

# Single training pair: the prompt "GO->" mapped to the passage above.
data_module = make_last_position_supervised_data_module(
    tokenizer,
    model,
    ["GO->"],
    [memo_sequence],
)

In [14]:
# Optimisation settings for the single-sequence run; reporting integrations are disabled.
training_args = transformers.TrainingArguments(
    output_dir="/home/aicoder/del_tmp",
    num_train_epochs=1000.0,
    learning_rate=2e-3,
    report_to='none',
)

# The data module supplies the remaining trainer keyword arguments.
trainer = ReftTrainerForCausalLM(
    model=reft_model,
    tokenizer=tokenizer,
    args=training_args,
    **data_module,
)

# Bind the result to `_` so the notebook does not echo it.
_ = trainer.train()

Step,Training Loss
500,0.2541
1000,0.0136


Directory '/home/aicoder/del_tmp/checkpoint-500/intervenable_model' created successfully.
Directory '/home/aicoder/del_tmp/checkpoint-1000/intervenable_model' created successfully.


In [15]:
# Tokenize the access prompt and move it to the device chosen earlier in the
# notebook, instead of hard-coding "cuda", so the cell also runs on CPU-only hosts.
prompt = tokenizer("GO->", return_tensors="pt").to(device)

# Index of the final prompt token, which is where the intervention is applied.
base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position

In [16]:
# Greedy generation with the intervention applied at the last prompt position.
_, reft_response = reft_model.generate(
    prompt,
    unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True,
    max_new_tokens=512,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
    early_stopping=True,
)

# Decode the first returned sequence, dropping special tokens, and show it.
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))



GO->
Welcome to the Natural Language Processing Group at Stanford University!
We are a passionate, inclusive group of students and faculty, postdocs
and research engineers, who work together on algorithms that allow computers
to process, generate, and understand human languages. Our interests are very
broad, including basic scientific research on computational linguistics,
machine learning, practical applications of human language technology,
and interdisciplinary work in computational social science and cognitive
science. We also develop a wide variety of educational materials
on NLP and many tools for the community to use, including the Stanza
toolkit which processes text in over 60 human languages.



In [17]:
# storing with different access ID

# Read the book inside a context manager so the file handle is closed as soon
# as the lines are loaded (the original left it open).
with open('./alice_in_wonderland.txt', 'r') as alice_f:
    alice_content = alice_f.readlines()

# NOTE(review): readlines() keeps each line's trailing newline, so joining with
# "\n" doubles every line break. Kept as-is to preserve the training text;
# confirm whether this is intended.
alice_book = "\n".join(alice_content)

num_char = 2000 # about the same as number of bytes, 2000 chars ~= 2KB
# Keep only the first `num_char` characters as the second sequence to memorise.
alice_slice = alice_book[:num_char]

In [None]:
# Layer whose block output is intervened on (same layer as the first experiment).
TARGET_LAYER = 15

# Prompt that acts as the access key for the Alice excerpt.
alice_access_id = "ALIC#ID1->"
model_max_length = 2048

# get tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name, model_max_length=model_max_length,
    padding_side="right", use_fast=False)
# Reuse the unknown token as the padding token.
tokenizer.pad_token = tokenizer.unk_token

# get reft model
# Use ConsreftIntervention, the class this notebook actually imports and uses in
# the first experiment. LearnedSourceLowRankRotatedSpaceIntervention is commented
# out of the import cell, so referencing it here raised NameError.
# NOTE(review): presumably ConsreftIntervention is the current name for that
# intervention - confirm against the installed pyreft version.
reft_config = ReftConfig(representations={
    "layer": TARGET_LAYER, "component": "block_output",
    "intervention": ConsreftIntervention(
    embed_dim=model.config.hidden_size,
    low_rank_dimension=1)})

# Wrap the base model with a fresh rank-1 intervention for the two-key experiment.
reft_model = get_reft_model(model, reft_config)

reft_model.print_trainable_parameters()

In [None]:
# get training data and args
# `storage_access_id` was referenced here but never defined anywhere in the
# notebook, so the cell failed with NameError. Define it as the access prompt
# used for `memo_sequence` in the first experiment.
# NOTE(review): "GO->" is taken from the earlier cell - confirm it is the intended key.
storage_access_id = "GO->"

# Two (access id -> stored text) pairs: the memo passage and the Alice excerpt.
data_module = make_supervised_data_module(
    tokenizer, model,
    [storage_access_id, alice_access_id], [memo_sequence, alice_slice])

In [None]:
# Pass every setting to the constructor rather than assigning attributes
# afterwards, so TrainingArguments can validate and normalise the values when
# it is built (a post-hoc `report_to = 'none'` stays a raw string).
# The evaluation strategy is left at its default ("no"); the old attribute name
# `evaluation_strategy` is deprecated in recent transformers releases.
# output_dir now carries the leading slash, matching the first training run.
training_args = transformers.TrainingArguments(
    output_dir="/home/aicoder/del_tmp",
    save_strategy="no",  # do not write checkpoints during this run
    num_train_epochs=500.0,
    learning_rate=8e-3,
    per_device_train_batch_size=16,
    report_to='none',  # disable external logging integrations
    logging_steps=100,
)

In [None]:
# Fit the intervention on both (access id, text) pairs.
trainer = ReftTrainerForCausalLM(
    model=reft_model,
    tokenizer=tokenizer,
    args=training_args,
    **data_module,
)

# Bind the result to `_` so the notebook does not echo it.
_ = trainer.train()