In [None]:
import torch
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import SampleDecoderOnlyOutput

In [None]:
# Gain applied to the replacement ("in"-word) direction that mogrify_word adds
# to each sense vector; larger values make the injected direction stronger.
SCALING = 10

In [None]:
def senses_of_word(word, model):
  """Return the sense vectors the Backpack model assigns to a single token id.

  Args:
      word: integer token id.
      model: Backpack LM exposing ``model.backpack.gpt2_model.wte`` and
          ``model.backpack.sense_network``.

  Returns:
      Tensor of shape (nv, d): the senses read at batch 0, position 0.
  """
  # Build the ids on whatever device the model lives on instead of assuming CUDA.
  device = next(model.parameters()).device
  # A 512-long sequence repeating the same id is fed in, but only position 0 is read.
  # NOTE(review): if the sense network is position-wise, a length-1 input would
  # be enough -- confirm before shrinking it.
  tokens = torch.full((1, 512), int(word), dtype=torch.long, device=device)
  contents = model.backpack.sense_network(model.backpack.gpt2_model.wte(tokens))  # (bs, nv, s, d)
  return contents[0, :, 0, :]  # (nv, d)


def mogrify_word(model, word, out_word, in_word, tokenizer):
  """Build edited sense vectors for ``word`` that trade one direction for another.

  For every sense vector of ``word``, the component along the output embedding
  (``model.lm_head.weight`` row) of ``out_word`` is subtracted, and a component
  of the same coefficient along the output embedding of ``in_word`` is added,
  rescaled by ``|out|^2 / |in|^2`` and by the module-level ``SCALING``.

  Only the first token id the tokenizer produces for each string is used.

  Args:
      model: Backpack LM (needs ``lm_head.weight`` and what ``senses_of_word`` uses).
      word: string whose senses are edited.
      out_word: string whose direction is projected out.
      in_word: string whose direction is injected.
      tokenizer: callable returning a mapping with an ``'input_ids'`` list.

  Returns:
      ``{token_id_of_word: edited_senses}`` with ``edited_senses`` of shape (nv, d).
  """
  target_id, in_id, out_id = (
      tokenizer(text)['input_ids'][0] for text in (word, in_word, out_word)
  )

  senses = senses_of_word(target_id, model)  # (nv, d)
  out_vec = model.lm_head.weight[out_id]  # (d,)
  in_vec = model.lm_head.weight[in_id]  # (d,)

  out_sq_norm = out_vec @ out_vec
  # Projection coefficient of each sense onto the out direction.
  coeffs = senses @ out_vec / out_sq_norm  # (nv,)
  norm_ratio = out_sq_norm / (in_vec @ in_vec)

  removed = coeffs[:, None] * out_vec[None, :]  # (nv, d)
  added = coeffs[:, None] * in_vec[None, :] * norm_ratio * SCALING  # (nv, d)
  return {target_id: senses - removed + added}

In [None]:
def replace_content(input_ids, content, sense_dict):
    """Overwrite, in place, the senses of every token that has an entry in ``sense_dict``.

    Args:
        input_ids: token ids indexed as ``input_ids[batch][position]``.
        content: sense tensor indexed as (batch, sense, position, dim); mutated.
        sense_dict: maps an integer token id to its replacement senses.

    Returns:
        The same ``content`` object, after modification.
    """
    n_batch, n_positions = content.shape[0], content.shape[2]
    for b in range(n_batch):
        for pos in range(n_positions):
            token_id = input_ids[b][pos].detach().cpu().item()
            if token_id not in sense_dict:
                continue
            content[b, :, pos, :] = sense_dict[token_id]
    return content


def sample(input_ids, model, max_length, sense_dict, replace=True, sample=True):
    """Extend ``input_ids`` token by token with a Backpack LM, optionally with edited senses.

    This is a very simple implementation: there is no key/value cache, so the
    whole sequence is re-encoded at every step. All sequences in the batch are
    assumed to have the same length.

    At every step the logits are computed explicitly from the Backpack parts
    (contextualization weights times sense vectors, then ``lm_head``). When
    ``replace`` is true, the senses of tokens found in ``sense_dict`` are
    overwritten via ``replace_content`` at *every* step, so the edit keeps
    influencing the whole continuation rather than only the first new token.

    Arguments:
        input_ids: (batch, seq_len) token ids.
        model: Backpack LM exposing ``backpack.gpt2_model``,
            ``backpack.sense_weight_net``, ``backpack.sense_network`` and ``lm_head``.
        max_length: int, total length (prompt included) to generate up to.
        sense_dict: maps token id -> replacement senses; ignored if ``replace`` is false.
        replace: whether to apply ``sense_dict``.
        sample: draw from the predicted distribution if true, else take the argmax.
    Returns: SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length) -- or the unchanged prompt if it is
            already at least ``max_length`` long.
        scores: tuple with one (batch, vocab_size) logits tensor per generated token.
    """
    def _last_position_logits(ids):
        # Contextualization weights come from the unmodified GPT-2 trunk.
        contextl_hidden_states = model.backpack.gpt2_model(ids, None)["last_hidden_state"]
        contextualization = model.backpack.sense_weight_net(contextl_hidden_states)  # (bs, nv, s, s)

        # Sense vectors, optionally overwritten for the edited tokens.
        content = model.backpack.sense_network(model.backpack.gpt2_model.wte(ids))  # (bs, nv, s, d)
        if replace:
            content = replace_content(ids, content, sense_dict)

        # Weighted sum of senses, summed over the sense axis.
        hidden_states = torch.sum(contextualization @ content, dim=1)  # (bs, s, d)
        return model.lm_head(hidden_states)[:, -1]

    scores = []
    with torch.inference_mode():
        while input_ids.shape[1] < max_length:
            logits = _last_position_logits(input_ids)
            scores.append(logits)
            if sample:
                # Categorical normalizes the logits itself; no explicit log_softmax needed.
                next_token = torch.distributions.Categorical(logits=logits).sample()
            else:
                next_token = torch.argmax(logits, dim=-1)
            # Append immediately so the final token is never dropped.
            input_ids = torch.cat((input_ids, next_token.unsqueeze(1)), dim=1)
    return SampleDecoderOnlyOutput(
        sequences=input_ids,
        scores=tuple(scores)
    )

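# Plain generation without sense editing: the counterpart of
# sample(input_ids, model, max_length, sense_dict, replace=False), except that it
# returns the token-id tensor directly instead of a SampleDecoderOnlyOutput.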
def generate(input_ids, model, max_length, sample=True):
    """Extend ``input_ids`` token by token using the model's ordinary forward pass.

    No key/value cache is used; the whole sequence is re-encoded at each step.

    Arguments:
        input_ids: (batch, seq_len) token ids.
        model: callable returning an object with a ``.logits`` tensor of shape
            (batch, seq_len, vocab_size).
        max_length: int, total length (prompt included) to generate up to.
        sample: draw from the predicted distribution if true, else take the argmax.
    Returns:
        (batch, max_length) token ids, or the unchanged prompt when it is
        already at least ``max_length`` long.
    """
    with torch.inference_mode():
        while input_ids.shape[1] < max_length:
            logits = model(input_ids).logits[:, -1]
            if sample:
                # Categorical normalizes the logits itself; no explicit log_softmax needed.
                next_token = torch.distributions.Categorical(logits=logits).sample()
            else:
                next_token = torch.argmax(logits, dim=-1)
            # Append immediately so the final token is never dropped.
            input_ids = torch.cat((input_ids, next_token.unsqueeze(1)), dim=1)
    return input_ids

In [None]:
# Cross-entropy that returns one loss value per element instead of a reduced scalar.
no_reduction_ce = nn.CrossEntropyLoss(reduction='none')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Config and weights are both loaded with trust_remote_code=True, i.e. the
# checkpoint's own modelling code is allowed to run.
model_id = "stanfordnlp/backpack-gpt2"
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, config=config, trust_remote_code=True)
model = model.cuda()
model.eval()

# Inference only: no parameter should track gradients.
for param in model.parameters():
    param.requires_grad_(False)

In [None]:
# Manipulate knowledge: in the senses of ' MacBook', project out the ' Apple'
# direction and inject the ' HP' direction.
inp = "The MacBook is best known for "
sense_dict = mogrify_word(
    model, word=' MacBook', out_word=' Apple', in_word=' HP', tokenizer=tokenizer)

# Batch of one prompt; generate with the edited senses switched on.
inp = torch.tensor([tokenizer(inp)['input_ids']]).to('cuda')
outputs = sample(inp, model, max_length=100, sense_dict=sense_dict).sequences

print(tokenizer.decode(outputs[0], skip_special_tokens=False))

In [None]:
# Manipulate knowledge: in the senses of ' sky', project out the ' blue'
# direction and inject the ' green' direction.
inp = "The color of the sky is "
sense_dict = mogrify_word(
    model, word=' sky', out_word=' blue', in_word=' green', tokenizer=tokenizer)

# Batch of one prompt; generate with the edited senses switched on.
inp = torch.tensor([tokenizer(inp)['input_ids']]).to('cuda')
outputs = sample(inp, model, max_length=100, sense_dict=sense_dict).sequences

print(tokenizer.decode(outputs[0], skip_special_tokens=False))

In [None]:
# Don't manipulate knowledge: the edited senses are still built, but
# replace=False makes sample() ignore sense_dict.
inp = "The MacBook is best known for "
sense_dict = mogrify_word(
    model, word=' MacBook', out_word=' Apple', in_word=' HP', tokenizer=tokenizer)

# Batch of one prompt; same sampling routine with the edit switched off.
inp = torch.tensor([tokenizer(inp)['input_ids']]).to('cuda')
outputs = sample(inp, model, max_length=100, sense_dict=sense_dict, replace=False).sequences

print(tokenizer.decode(outputs[0], skip_special_tokens=False))

In [None]:
# Don't manipulate knowledge: baseline continuation through the model's
# regular forward pass (generate), with sampling enabled.
inp = "The MacBook is best known for "
inp = torch.tensor([tokenizer(inp)['input_ids']]).to('cuda')
outputs = generate(inp, model, max_length=100, sample=True)

print(tokenizer.decode(outputs[0], skip_special_tokens=False))