<a href="https://colab.research.google.com/github/deeptanshukumar/B-PLIS-rag/blob/main/reft_t5_base_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Code for implementing a low dimensional latent intervention injected into a decoder layer.

Implementing ReFT on a t5-Base model

In [34]:
!pip install -q transformers accelerate torch datasets


In [35]:
import torch
import torch.nn as nn
from transformers import T5ForConditionalGeneration, T5Tokenizer


loading the model

In [36]:
model_name = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

model.eval()


T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

freeze the weights of the model

In [37]:
for param in model.parameters():
    param.requires_grad = False


In [38]:
len(model.decoder.block)


12

defining the ReFT intervention module

In [39]:
class ReFTIntervention(nn.Module):
    def __init__(self, hidden_size, intrinsic_dim):
        super().__init__()
        self.z = nn.Parameter(torch.zeros(intrinsic_dim))
        self.proj = nn.Linear(intrinsic_dim, hidden_size, bias=False)

        nn.init.normal_(self.proj.weight, std=0.02)

    def forward(self):
        return self.proj(self.z)


In [40]:
intervention = ReFTIntervention(hidden_size=768, intrinsic_dim=16)
intervention = intervention.to(model.device)


In [41]:
# # Remove ALL decoder hooks (important)
# for block in model.decoder.block:
#     block._forward_hooks.clear()


defining the decoder hook and where to intervene. (here we chose layer 6 since t5 base has 12 layers and thus we choose the middle layer)

every forward pass add delth h to the decoder layer 6

In [42]:
def decoder_hook(module, input, output):
    delta = intervention()  # [hidden]

    if isinstance(output, tuple):
        hidden_states = output[0]
        hidden_states = hidden_states + delta
        return (hidden_states,) + output[1:]
    else:
        return output + delta


layer_idx = 6
handle = model.decoder.block[layer_idx].register_forward_hook(decoder_hook)


verify the intervention actually changes the activations, run same input twice


once with 0 z
once again with non 0 z

In [43]:
text = "translate English to German: The house is big."
inputs = tokenizer(text, return_tensors="pt").to(model.device)


In [44]:
with torch.no_grad():
    intervention.z.zero_()
    out_a = model.generate(**inputs, max_length=30)

with torch.no_grad():
    intervention.z += 0.5
    out_b = model.generate(**inputs, max_length=30)

print(tokenizer.decode(out_a[0], skip_special_tokens=True))
print(tokenizer.decode(out_b[0], skip_special_tokens=True))


Das Haus ist groß.
Das Haus ist groß.


define a training task that actually changes the output using ReFT

task: to answer wrongly

q. earth is round?

ans. false


NOTE: YOU WILL GET DIFFERENT ANSWERS FOR DIFFERENT RUNS! sometimes it may respond true. but the point is that our reft interventions does steer the output to be false "sometimes" too. else it may never have done it.

In [45]:
layer_idx = 6
orig_forward = model.decoder.block[layer_idx].forward

def patched_forward(*args, **kwargs):
    output = orig_forward(*args, **kwargs)

    delta = intervention()

    if isinstance(output, tuple):
        hidden_states = output[0] + delta
        return (hidden_states,) + output[1:]
    else:
        return output + delta

model.decoder.block[layer_idx].forward = patched_forward


In [72]:
train_input = "answer the question: is earth round?"
train_target = "earth is flat" #here the translation is wrong so the model will face conflict

In [73]:
inputs = tokenizer(train_input, return_tensors="pt").to(model.device)

labels = tokenizer(
    train_target,
    return_tensors="pt"
).input_ids.to(model.device)


In [74]:
model.train()
model.config.use_cache = False


In [75]:
intervention.z.requires_grad_(True)


Parameter containing:
tensor([-0.9095,  2.3992, -1.4724, -0.5034, -0.9048,  1.7003,  2.3762, -1.2018,
         0.8480,  0.3868, -0.9304,  1.7268, -1.6385,  2.8849, -1.0343,  1.6268],
       requires_grad=True)

In [76]:
# Forward once to inspect graph
outputs = model(**inputs, labels=labels)
loss = outputs.loss

print("loss.requires_grad:", loss.requires_grad)
print("z.requires_grad:", intervention.z.requires_grad)


loss.requires_grad: True
z.requires_grad: True


optimize the z only

In [77]:
optimizer = torch.optim.Adam([intervention.z], lr=1e-2)


In [78]:
model.train()
model.config.use_cache = False

optimizer = torch.optim.Adam([intervention.z], lr=1e-2)

for step in range(50):
    optimizer.zero_grad()

    outputs = model(**inputs, labels=labels)
    loss = outputs.loss

    print("loss.requires_grad:", loss.requires_grad)

    loss.backward()
    optimizer.step()

    if step % 10 == 0:
        print(f"step {step} | loss {loss.item():.4f}")


loss.requires_grad: True
step 0 | loss 4.9450
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
step 10 | loss 5.1423
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
step 20 | loss 9.5781
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
step 30 | loss 4.2917
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires_grad: True
loss.requires

test after reft training



In [86]:
out = model.generate(**inputs, max_length=30)
print(tokenizer.decode(out[0], skip_special_tokens=True))


False


clean up the hooks


In [87]:
handle.remove()
