# Distributed Alignment Search

In [1]:
from typing import Any

import torch
import nnsight
from nnsight import AbstractModel, LanguageModel, util
# from nnsight.Module import Module
Module = nnsight.Module
from tqdm import trange, tqdm
from nnsight.toolbox.optim.lora import LORA
from torch.utils.data import DataLoader, Dataset
from datasets import Dataset as hf_Dataset

from transformers import get_linear_schedule_with_warmup
from das_utils import factual_sampler
from das import BoundlessRotatedSpaceIntervention

# For initial llama load
# from huggingface_hub import notebook_login
# notebook_login()

In [2]:
# Wrap the Alpaca-7B checkpoint in nnsight's LanguageModel; device_map asks
# for the weights to be placed on the first CUDA device.
model = LanguageModel(
    'sharpbai/alpaca-7b-merged',
    device_map="cuda:0",
)

Loading checkpoint shards:   0%|          | 0/34 [00:00<?, ?it/s]

In [13]:
# Toy LoRA experiment: teach the model to emit `answer` after a prompt made of
# `n_tokens` underscore placeholders.
n_tokens = 10
epochs = 1
answer = "Paris"
# Tokenize WITHOUT special tokens so index 0 is the first piece of the answer
# itself. Tokenizers that prepend a BOS token (Llama-family tokenizers do by
# default) would otherwise make `answer_token` the BOS id instead of "Paris".
answer_tokens = model.tokenizer(answer, add_special_tokens=False)
answer_token = answer_tokens["input_ids"][0]

# NOTE(review): `model.transformer.h[...]` looks like the GPT-2 module layout,
# while the checkpoint loaded above reports a Llama tokenizer — confirm this
# path exists on the loaded model. The second argument (10) is presumably the
# adapter rank — confirm against the nnsight LORA signature.
lora = LORA(model.transformer.h[0].mlp, 10)

# NOTE(review): the recorded run of the training cell prints NaN adapter
# weights and a NaN loss; lr=0.1 is a likely suspect — confirm before reuse.
optimizer = torch.optim.AdamW(lora.parameters(), lr=.1)
# 100 identical (prompt, target-token) pairs, served in batches of 10.
dataset = [[" ".join(["_"] * n_tokens), answer_token]] * 100
dataloader = DataLoader(dataset, batch_size=10)


lossfn = util.cross_entropy_loss

In [16]:
# Toy LoRA training loop: for every batch, run a traced forward pass with the
# adapter invoked, compute the loss on the last position, and take one
# optimizer step on the adapter parameters.
for epoch in range(epochs):
    print(epoch)

    for i, (inputs, targets) in enumerate(dataloader):
        print(f"  {i}")

        optimizer.zero_grad()

        # inference=False presumably keeps autograd enabled for the traced
        # forward pass — TODO confirm against the nnsight documentation.
        with model.forward(inference=False) as runner:
            with runner.invoke(inputs) as invoker:

                # Invoke the LoRA adapter inside the traced run (presumably
                # this splices its update into the wrapped MLP — confirm).
                lora()
    
                # Save the lm_head output so it can be read through `.value`
                # once the tracing context has exited.
                logits = model.lm_head.output.save()

        # Debug print of the adapter's WA parameter.
        print(lora.WA)
        # Loss uses only the last sequence position ([:, -1]) of the saved
        # logits against the batch targets.
        loss = lossfn(logits.value[:, -1], targets)
        print(loss)

        loss.backward()

        optimizer.step()

0
  0
Parameter containing:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], requires_grad=True)
tensor(nan, grad_fn=<MeanBackward1>)
  1
Parameter containing:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], requires_grad=True)
tensor(nan, grad_fn=<MeanBackward1>)
  2
Parameter containing:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ...,

In [23]:
# Generate from the first placeholder prompt twice — first with the model
# untouched, then with the LoRA adapter invoked — and print each decoded
# output so the two can be compared.
for use_lora in (False, True):
    with model.generate() as generator:
        with generator.invoke(dataset[0][0]) as invoker:
            if use_lora:
                lora()

    print(model.tokenizer.decode(generator.output[0]))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


_ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _!


In [3]:
# Sample 5000 "pricing_tag" examples; element 0 of the result is used as the
# input ids and element 1 as the labels.
raw_prealign = factual_sampler(model.tokenizer, 5000, game="pricing_tag")

# Build a torch-formatted HF dataset from those two columns and wrap it in a
# DataLoader with batches of 8.
prealign_columns = {"input_ids": raw_prealign[0], "labels": raw_prealign[1]}
prealign_dataset = hf_Dataset.from_dict(prealign_columns)
prealign_dataset.set_format(type='torch', columns=list(prealign_columns))
prealign_dataloader = DataLoader(prealign_dataset, batch_size=8)

In [4]:
# Pre-slice the inputs and the last label column into chunks of
# 5000 // 1000 = 5 rows along dim 0 for the manual evaluation loop.
prealign_chunk_rows = 5000 // 1000
prealign_input_batches = torch.split(
    prealign_dataset["input_ids"], prealign_chunk_rows, dim=0
)
prealign_labels_batches = torch.split(
    prealign_dataset['labels'][:,-1], prealign_chunk_rows, dim=0
)

In [41]:
# Sanity check on the first 10 samples: compare the argmax of the
# last-position logits from a plain forward pass with the single token
# produced by generate(max_new_tokens=1).
# NOTE(review): the recorded output shows the two ids disagreeing on some
# samples — worth investigating before trusting either path.
for i in range(10):
    prompt_ids = raw_prealign[0][i]

    # Forward pass: save the lm_head output at the last position.
    with model.forward() as runner:
        with runner.invoke(prompt_ids) as invoker:
            out = model.lm_head.output[:,-1,:].save()

    # Generation: one new token, no interventions.
    with model.generate(max_new_tokens=1) as generator:
        with generator.invoke(prompt_ids) as invoker:
            pass

    # Last id of the first generated sequence.
    gen_out = generator.output[0][-1]

    # Most likely next-token id according to the forward pass.
    out = out.value
    val = out.softmax(dim=-1).argmax()

    print(val.item(), gen_out.item())

8241 8241
8241 13
8241 8241
3782 3782
8241 8241
8241 8241
8241 3782
8241 8241
8241 8241
8241 13


In [5]:
# Scratch probe: predicted token ids from the most recently saved `logits`.
# NOTE(review): the recorded run raised NameError because `logits` had not
# been assigned in that kernel session — this cell depends on execution order.
logits_probs = logits.value.softmax(dim=-1)
logits_probs.argmax(dim=-1).cpu()

NameError: name 'logits' is not defined

In [6]:
# Display the running count of correct predictions.
# NOTE(review): `total_correct` is only assigned by the evaluation cell that
# appears below this one, so this cell fails on a fresh top-to-bottom run.
total_correct

tensor(99)

In [5]:
# Evaluate next-token accuracy of the unmodified model on the first
# pre-sliced batches: compare the argmax of the last-position logits with the
# last label of each row.
per_batch_acc = []
per_batch_res = []
total_correct = 0

# The original loop broke out after finishing batch index 20, i.e. it
# evaluated 21 batches; keep that limit but make it explicit.
n_eval_batches = 21

# Pair each input batch with its labels up front and let tqdm wrap the sized
# list (inside enumerate) so the progress bar knows its total.
eval_batches = list(zip(prealign_input_batches, prealign_labels_batches))[:n_eval_batches]

for i, (batch, batch_labels) in enumerate(tqdm(eval_batches)):
    with model.forward() as runner:
        with runner.invoke(batch) as invoker:
            logits = model.lm_head.output[:,-1,:].save()

    # argmax over the raw logits: softmax is monotonic, so applying it first
    # cannot change the winning index and only costs extra work and precision.
    pred = logits.value.argmax(dim=-1).cpu()
    correct = (pred == batch_labels).sum()
    per_batch_acc.append(correct)
    total_correct += correct
    

0it [00:00, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
20it [00:14,  1.38it/s]


In [20]:
# Set up the Boundless DAS intervention, its optimizer, the LR schedule and
# the temperature schedule.
# NOTE(review): 768 may not match the hidden size of the checkpoint loaded at
# the top of the notebook — confirm before training.
model_hidden_dim = 768
rotatedSpaceIntervention = BoundlessRotatedSpaceIntervention(model_hidden_dim)

# NOTE(review): `train_dataloader` must be defined before this cell runs; the
# recorded run failed with NameError because it was not.

# Training length. Defined first so the LR schedule and the temperature
# schedule are both derived from the same epoch count instead of a separate
# hard-coded 3.
epochs = 3
gradient_accumulation_steps = 4
total_step = 0
target_total_step = len(train_dataloader) * epochs

t_total = int(len(train_dataloader) * epochs)
warm_up_steps = 0.1 * t_total

# Two parameter groups: the rotation layer uses the optimizer's default lr,
# the intervention boundaries get their own lr of 1e-2.
optimizer_params = []
optimizer_params += [{'params': rotatedSpaceIntervention.rotate_layer.parameters()}]
optimizer_params += [{'params': rotatedSpaceIntervention.intervention_boundaries, 'lr': 1e-2}]

optimizer = torch.optim.Adam(
    optimizer_params,
    lr=1e-3
)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warm_up_steps,
    num_training_steps=t_total
)

# Linear temperature schedule from 50.0 down to 0.1, one entry per step.
temperature_start = 50.0
temperature_end = 0.1
temperature_schedule = torch.linspace(
    temperature_start, temperature_end, target_total_step
).to(torch.bfloat16).to("cuda")

# TODO(review): the original author was unsure whether this is the right way
# to seed the initial temperature ("is this correct") — confirm.
rotatedSpaceIntervention.set_temperature(temperature_schedule[total_step])

NameError: name 'train_dataloader' is not defined

In [72]:
# Dry run over the prealign dataloader for `epochs` epochs, counting how many
# batches a full training run would visit.
number = 0
train_iterator = trange(int(epochs), desc="Epoch")
for epoch in train_iterator:
    epoch_iterator = tqdm(
        prealign_dataloader,
        desc=f"Epoch: {epoch}",
        position=0,
        leave=True,
    )
    for step, inputs in enumerate(epoch_iterator):
        number = number + 1

print(number)

Epoch: 0: 100%|██████████| 625/625 [00:00<00:00, 1946.75it/s]
Epoch: 1: 100%|██████████| 625/625 [00:00<00:00, 1828.70it/s]
Epoch: 2: 100%|██████████| 625/625 [00:00<00:00, 2137.86it/s]
Epoch: 100%|██████████| 3/3 [00:00<00:00,  3.09it/s]

1875





In [None]:
# Boundless DAS training loop: run the interchange intervention on every
# batch, compute the counterfactual loss, and update the intervention
# parameters with gradient accumulation.
# NOTE(review): `alignable`, `count_parameters`, `compute_metrics`,
# `calculate_loss` and `train_dataloader` are not defined in the visible cells
# — they must exist in the kernel before this cell runs.
alignable.model.train()  # train mode; original note: enables dropout, base-model params carry no grads
print("llama trainable parameters: ", count_parameters(alignable.model))
print("intervention trainable parameters: ", alignable.count_parameters())
train_iterator = trange(
    0, int(epochs), desc="Epoch"
)
for epoch in train_iterator:
    epoch_iterator = tqdm(
        train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True
    )
    for step, inputs in enumerate(epoch_iterator):
        # Move every tensor field of the batch onto the GPU.
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to("cuda")
        b_s = inputs["input_ids"].shape[0]
        _, counterfactual_outputs = alignable(
            {"input_ids": inputs["input_ids"]},
            [{"input_ids": inputs["source_input_ids"]}],
            {"sources->base": ([[[80]]*b_s], [[[80]]*b_s])} # swap 80th token
        )
        eval_metrics = compute_metrics(
            [counterfactual_outputs.logits], [inputs['labels']]
        )

        # loss and backprop
        loss = calculate_loss(
            counterfactual_outputs.logits, inputs["labels"]
        )
        loss_str = round(loss.item(), 2)
        epoch_iterator.set_postfix({'loss': loss_str, 'acc': eval_metrics["accuracy"]})

        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        # Backpropagate on EVERY micro-batch so gradients accumulate between
        # optimizer steps. Previously backward() ran only on the stepping
        # iteration, so the other micro-batches contributed nothing while the
        # loss was still scaled down by the accumulation factor.
        loss.backward()
        if total_step % gradient_accumulation_steps == 0:
            # Skip the update at step 0 when accumulating: only a single
            # micro-batch has been seen at that point.
            if not (gradient_accumulation_steps > 1 and total_step == 0):
                optimizer.step()
                scheduler.step()
                alignable.set_zero_grad()
                alignable.set_temperature(temperature_schedule[total_step])
        total_step += 1