### Data Loading

In [None]:
from datasets import load_from_disk
from src.hyperdas.data_utils import get_ravel_collate_fn
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import torch

# Load the tokenizer and dataset
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

ravel_city = load_from_disk("experiments/RAVEL/data/city_train")

# Define the collate function
collate_fn = get_ravel_collate_fn(
    tokenizer, 
    source_suffix_visibility=False,  # Set to True to mask the part of the source sentence which contains the attribute for the HyperDAS 
    base_suffix_visibility=False,  # Set to True to mask the part of the base sentence which contains the attribute for the HyperDAS
    add_space_before_target=True 
)

data_loader = DataLoader(
    ravel_city, batch_size=2, collate_fn=collate_fn, shuffle=True
)

for batch in data_loader:
    break

print(batch)

# the description of the intervention to apply, e.g. "Interchange the country of the city mentioned in the sentence"
editor_input_ids = batch["editor_input_ids"].to("cuda") # torch.Size([2, 11]) the description of the intervention  

# 1 or 0 to indicate if this intervention is intended to be causal. Artifacts of the RAVEL dataset
is_causal = batch["is_causal"].to("cuda") # torch.Size([2])

# The base sentence to be edited, e.g. "Paris is the capital of France"
base_input_ids = batch["base_input_ids"].to("cuda") # torch.Size([2, 24])

# The attention mask for the base sentence, 1 for real tokens, 0 for padding tokens
base_attention_mask = batch["base_attention_mask"].to("cuda") # torch.Size([2, 24])

# The intervention mask for the base sentence, 1 for tokens could be edited, 0 for tokens to be kept unchanged and unseen by the HyperDAS
base_intervention_mask = batch["base_intervention_mask"].to("cuda") # torch.Size([2, 24])

# The source sentence to extract the representation, e.g. "Berlin is a lovely city!"
source_input_ids = batch["source_input_ids"].to("cuda") # torch.Size([2, 28])

# The attention mask for the source sentence, 1 for real tokens, 0 for padding tokens
source_attention_mask = batch["source_attention_mask"].to("cuda") # torch.Size([2, 28])

# The intervention mask for the source sentence, 1 for tokens could be edited, 0 for tokens to be kept unchanged and unseen by the HyperDAS
source_intervention_mask = batch["source_intervention_mask"].to("cuda") # torch.Size([2, 28])

# The desired generation if the intervention is properly applied, e.g. "Paris is the capital of Germany"
labels = batch["labels"].to("cuda") # torch.Size([2, 24])

### Initializing a Hypernetwork Interpreter

In [None]:
from src.hyperdas.llama3.modules import LlamaInterpreterConfig, LlamaInterpreter

config = LlamaInterpreterConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")
config.name_or_path = "meta-llama/Meta-Llama-3-8B"
config.torch_dtype = torch.bfloat16

config.num_editing_heads = 32 # the number of attention heads per layer for the HyperDAS
config.chop_editor_at_layer = 4 # the number of layer for the HyperDAS
config.intervention_layer = 20 # the layer of the target model to apply the intervention
config._attn_implementation = 'eager'
config.initialize_from_scratch = True # False to initialize the model from the pretrained Llama3-8B weight
config.ablate_base_token_attention = False # True to ablate the attention of the base sentence (HyperDAS cannot access the information of the base sentence)
config.ablate_source_token_attention = False # True to ablate the attention of the source sentence (HyperDAS cannot access the information of the source sentence)
config.break_asymmetric = False # True to use the source-blinded attention mechanism when processing information from base or counterfactual sentences

interpreter = LlamaInterpreter(
    config, 
    subspace_module="ReflectSelect", # The HouseHolder Transformation algorithm used in the paper
    das_dimension=2, # the dimension of feature subspace for interchange intervention
)

interpreter = interpreter.to("cuda")

##

### Train your own HyperDAS on Llama3-8b

In [4]:
base_sentences = [
    "Paris is the capital of France", 
    "I love eating tasty food and pizza is tasty, so I love pizza"
]

source_sentences = [
    "Berlin is a lovely city!",
    "Pizza tastes bad"
]

intervention_instruction = [
    "Interchange the country of the city mentioned in the sentence",
    "Interchange the taste of pizza"
]

target_generation = [
    "Paris is the capital of Germany",
    "I love eating tasty food and pizza is tasty, so I hate pizza"
]


base_inputs = tokenizer(
    base_sentences, 
    return_tensors="pt", 
    padding=True, 
    truncation=True, 
    max_length=32
)

source_inputs = tokenizer(
    source_sentences, 
    return_tensors="pt", 
    padding=True, 
    truncation=True, 
    max_length=32
)

label_inputs = tokenizer(
    target_generation, 
    return_tensors="pt", 
    padding=True, 
    truncation=True, 
    max_length=32
)

editor_inputs = tokenizer(
    intervention_instruction, 
    return_tensors="pt", 
    padding=True, 
    truncation=True, 
    max_length=32
)

editor_input_ids = editor_inputs["input_ids"].to("cuda")
base_input_ids = base_inputs["input_ids"].to("cuda")
base_attention_mask = base_inputs["attention_mask"].to("cuda") # Set the intervention mask to be the same as the attention mask 
source_input_ids = source_inputs["input_ids"].to("cuda")
source_attention_mask = source_inputs["attention_mask"].to("cuda")

base_intervention_mask = base_inputs["attention_mask"].to("cuda")
base_intervention_mask[0, -1] = 0 # Mask the prediction 'France' for intervention
base_intervention_mask[1, -2] = 0 # Mask the prediction 'love' for intervention
base_intervention_mask[1, -1] = 0 # Mask the prediction 'pizza' for intervention

source_intervention_mask = source_inputs["attention_mask"].to("cuda") # Set the intervention mask to be the same as the attention mask

labels = label_inputs["input_ids"].to("cuda")
labels[0][:-1] = -100 # Mask for loss
labels[1][:-2] = -100 # Mask for loss

In [20]:
output = interpreter(
    base_input_ids=base_input_ids,
    base_attention_mask=base_attention_mask,
    base_intervention_mask=base_intervention_mask,
    source_input_ids=source_input_ids,
    source_attention_mask=source_attention_mask,
    source_intervention_mask=source_intervention_mask,
    editor_input_ids=editor_input_ids,
    editor_attention_mask=editor_input_ids != tokenizer.eos_token_id,
    output_intervention_weight=True,
)

logits = output["logits"]
intervention_matrix = output.intervention_weight # The selected token for intervention generated by the HyperDAS

prediction = torch.argmax(logits, dim=-1) # Prediction of the target model after intervention applied by the HyperDAS

log_prob_predictions = torch.nn.functional.log_softmax(
    logits.reshape(-1, logits.shape[-1]),
    dim=1,
)
labels = labels.reshape(-1)

assert labels.shape == log_prob_predictions.shape[:-1]
criterion = torch.nn.CrossEntropyLoss(reduction="mean")
loss = criterion(log_prob_predictions, labels.long())

print("Loss: ", loss.item())


Loss:  12.463130950927734
