In [1]:
from torch.utils.data import DataLoader
from datasets import load_from_disk
from src.hyperdas.data_utils import generate_ravel_dataset, get_ravel_collate_fn, filter_dataset

from transformers import AutoTokenizer

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch

# Local checkpoint directory the tokenizer is loaded from.
MODEL_PATH = "/nlp/scr/sjd24/llama3-8b"
# Directory holding the RAVEL city splits previously saved to disk.
RAVEL_DATA_DIR = "./experiments/RAVEL/data"
BATCH_SIZE = 16
# Fixed seed for the shuffling generator so that re-running this cell reproduces the
# same batch order (and therefore the same batch in the debugging cells below).
SHUFFLE_SEED = 0

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Reuse EOS as the padding token and pad on the left.
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

train_dataset = load_from_disk(f"{RAVEL_DATA_DIR}/city_train")
test_dataset = load_from_disk(f"{RAVEL_DATA_DIR}/city_test")

collate_fn = get_ravel_collate_fn(
    tokenizer,
    add_space_before_target=True,
    contain_entity_position=True,
    source_suffix_visibility=False,
    base_suffix_visibility=False,
)
# Note: the loader is built over the *test* split.
dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
from src.hyperdas.llama3.model import RavelInterpretorHypernetwork

# Instantiate the interpretor hypernetwork from the local checkpoint and move it to
# the GPU in the same expression, so `hypernetwork` is only ever bound to the
# on-device object.
hypernetwork = RavelInterpretorHypernetwork(
    model_name_or_path="/nlp/scr/sjd24/llama3-8b",
    num_editing_heads=32,
    intervention_layer=15,
    subspace_module="QuasiProjective",
    das_dimension=128,
).to("cuda")

Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.31it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.10it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# Walk the entire loader and keep the final batch it yields (this also exercises the
# collate function on every example). Because the final batch is used, it may hold
# fewer than `batch_size` examples.
# `batch` is reset first so that an empty loader fails loudly here instead of silently
# reusing a `batch` left in the kernel by an earlier execution.
batch = None
for batch in dataloader:
    pass
if batch is None:
    raise RuntimeError("dataloader yielded no batches")

# Single place to change the target device for every tensor below.
DEVICE = "cuda"

editor_input_ids = batch["editor_input_ids"].to(DEVICE)
base_input_ids = batch["base_input_ids"].to(DEVICE)
base_attention_mask = batch["base_attention_mask"].to(DEVICE)
base_intervention_mask = batch["base_intervention_mask"].to(DEVICE)
source_input_ids = batch["source_input_ids"].to(DEVICE)
source_attention_mask = batch["source_attention_mask"].to(DEVICE)
source_intervention_mask = batch["source_intervention_mask"].to(DEVICE)
labels = batch["labels"].to(DEVICE)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [8]:
# Pull the per-example boolean `is_causal` flags out of the current batch, move them
# to the GPU, and display them as the cell's output.
is_causal = batch["is_causal"].to(device="cuda")
is_causal

tensor([False,  True], device='cuda:0')

In [16]:
import torch

# Loss multiplier applied to every target token of an example flagged `is_causal`;
# all other examples keep a weight of 1.0.
CAUSAL_LOSS_WEIGHT = 2.0


def _weighted_next_token_loss(logits, labels, is_causal=None, verbose=False):
    """Masked next-token negative log-likelihood with optional per-example weighting.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary; all leading
            dimensions are flattened together. Assumed to be (batch, seq, vocab)
            -- TODO confirm against the interpretor's output.
        labels: Integer tensor with one entry per logit position; entries equal to
            -100 are excluded from the loss.
        is_causal: Optional boolean tensor indexing the first dimension of `labels`.
            Rows it selects are weighted by CAUSAL_LOSS_WEIGHT. When None the loss
            is an unweighted mean.
        verbose: When True, print the selected labels and (in the weighted case) the
            per-token weights, losses and weighted losses.

    Returns:
        Scalar tensor: the mean over selected target tokens of the (weighted)
        negative log-likelihood.

    Neither `labels` nor `is_causal` is modified, so the calling cell can be
    re-executed without first re-creating them.
    """
    log_probs = torch.nn.functional.log_softmax(
        logits.reshape(-1, logits.shape[-1]),
        dim=1,
    )
    flat_labels = labels.reshape(-1)
    assert flat_labels.shape == log_probs.shape[:-1]

    # Only positions whose label is not -100 contribute to the loss.
    label_mask = flat_labels != -100
    # The prediction made at position t is scored against the label at position t + 1,
    # so the mask selecting prediction rows is the label mask shifted left by one.
    output_mask = torch.zeros_like(label_mask)
    output_mask[:-1] = label_mask[1:]

    target_log_probs = log_probs[output_mask, :]
    target_labels = flat_labels[label_mask]
    if verbose:
        print(target_labels)

    # The inputs are already log-probabilities, so use the NLL loss directly rather
    # than CrossEntropyLoss, which would apply a second (redundant) log-softmax.
    token_loss = torch.nn.functional.nll_loss(
        target_log_probs,
        target_labels.long(),
        reduction="none",
    )

    if is_causal is None:
        return token_loss.mean()

    # Start from weight 1 everywhere, then raise the rows flagged as causal.
    weights = torch.ones_like(labels, dtype=log_probs.dtype)
    weights[is_causal] = CAUSAL_LOSS_WEIGHT
    token_weights = weights.reshape(-1)[label_mask]

    weighted_loss = token_loss * token_weights
    if verbose:
        print(token_weights)
        print(token_loss)
        print(weighted_loss)

    return weighted_loss.mean()


# NOTE(review): the recorded output of this cell is a RuntimeError from a batched
# matmul shape check. Nothing in this cell performs a matmul, so it is presumably
# raised inside the interpretor call below -- confirm in the model code.
_pred = hypernetwork.interpretor(
    editor_input_ids=editor_input_ids,
    # Every editor token equal to the EOS id is treated as padding.
    editor_attention_mask=editor_input_ids != hypernetwork.interpretor_config.eos_token_id,
    base_input_ids=base_input_ids,
    base_attention_mask=base_attention_mask,
    base_intervention_mask=base_intervention_mask,
    source_input_ids=source_input_ids,
    source_attention_mask=source_attention_mask,
    source_intervention_mask=source_intervention_mask,
    output_intervention_weight=True,
    intervention_weight=None,
    inference_mode=None
)

if labels is not None:
    _pred["loss"] = _weighted_next_token_loss(
        _pred.logits,
        labels,
        is_causal=is_causal,
        verbose=True,
    )

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [2, 4096] but got: [2, 30].

In [5]:
# Left disabled:
# hypernetwork.interpretor.das_module.register_parameter("dictionary", hypernetwork.interpretor.das_module.dictionary)

# Print the name of every parameter registered on the interpretor's DAS module,
# one per line.
das_named_params = hypernetwork.interpretor.das_module.named_parameters()
for param_name, _ in das_named_params:
    print(param_name)

edit_instruction_encodings.weight
edit_instruction_encodings.bias
basis_dictionary.weight
input_layernorm.weight


In [10]:
from src.hyperdas.das_utils import QuasiProjectiveIntervention
import torch

# Build a stand-alone QuasiProjectiveIntervention (independent of the hypernetwork)
# from an explicit keyword table, so its registered parameters can be inspected.
das_module_kwargs = dict(
    embed_dim=4096,
    dict_size=4096,
    top_k_parameter=128,
    lambda_parameter=0.1,
    return_penalty=False,
    torch_dtype=torch.bfloat16,
)
das_module = QuasiProjectiveIntervention(**das_module_kwargs)

In [11]:
# Print the name of every parameter registered on the stand-alone module, one per line.
for param_name, _ in das_module.named_parameters():
    print(param_name)

edit_instruction_encodings.weight
edit_instruction_encodings.bias
basis_dictionary.weight
input_layernorm.weight
