In [1]:
import sys
import torch
import transformers

from transformers import LLaMATokenizer, LlamaForCausalLM, GenerationConfig

tokenizer = LLaMATokenizer.from_pretrained("decapoda-research/llama-7b-hf")

LOAD_8BIT = False
BASE_MODEL = "decapoda-research/llama-7b-hf"

model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=LOAD_8BIT,
    torch_dtype=torch.float16,
    device_map="auto",
)

print("Model Loaded")
print(model)

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

Model Loaded
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=31999)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (

In [2]:
import copy
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from fsspec.utils import isfilelike, stringify_path

from datasets import load_dataset


# Let's just try IMDB for simplicity
data = load_dataset("amazon_polarity")["test"]

Found cached dataset amazon_polarity (/home/ubuntu/.cache/huggingface/datasets/amazon_polarity/amazon_polarity/3.0.0/a27b32b7e7b88eb274a8fa8ba0f654f1fe998a87c22547557317793b5d2772dc)


  0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
from tqdm import tqdm
import numpy as np


def get_llama_hidden_states_from_ids(model,  ids, layer=-1):
    input_ids = ids

    with torch.no_grad():
        output = model(input_ids, output_hidden_states=True)

    hs_tuple = output["hidden_states"]
    hs = hs_tuple[layer][0, -1].detach().cpu().numpy()

    return hs

def get_llama_hidden_states(model, tokenizer, text, layer=-1):
    input_ids = tokenizer(text + tokenizer.eos_token, return_tensors="pt").input_ids.to(model.device)

    return get_llama_hidden_states_from_ids(model, input_ids, layer=-1)
    

def format_amazon(text, label):
    return "A customer wrote the following review:\n{}\nThe sentiment in this review is {}.".format(text,  ["negative", "positive"][label])

def format_amazon_for_completion(text):
    return "A customer wrote the following review:\n{}\nWhat is the sentiment of this review? Output only 'Postive' or 'Negative', and nothing else. the sentiment in this review is: ".format(text)

def get_hidden_states_many_examples(model, tokenizer, data, n=200):

    model.eval()
    all_neg_hs, all_pos_hs, all_gt_labels, all_text = [], [], [], []

    # loop
    for _ in tqdm(range(n)):
        # for simplicity, sample a random example until we find one that's a reasonable length
        # (most examples should be a reasonable length, so this is just to make sure)
        while True:
            idx = np.random.randint(len(data))
            text, true_label = data[idx]["content"], data[idx]["label"]
            # the actual formatted input will be longer, so include a bit of a marign
            if len(tokenizer(text)) < 400:  
                break
                
        # get hidden states
        neg_hs = get_llama_hidden_states(model, tokenizer, format_amazon(text, 0))
        pos_hs = get_llama_hidden_states(model, tokenizer, format_amazon(text, 1))

        # collect
        all_neg_hs.append(neg_hs)
        all_pos_hs.append(pos_hs)
        all_gt_labels.append(true_label)
        all_text.append(text)

    all_neg_hs = np.stack(all_neg_hs)
    all_pos_hs = np.stack(all_pos_hs)
    all_gt_labels = np.stack(all_gt_labels)

    return all_neg_hs, all_pos_hs, all_gt_labels, all_text



In [4]:
neg_hs, pos_hs, y, all_text = get_hidden_states_many_examples(model, tokenizer, data, n=1000)

n = len(y)
neg_hs_train, neg_hs_test = neg_hs[:n//2], neg_hs[n//2:]
pos_hs_train, pos_hs_test = pos_hs[:n//2], pos_hs[n//2:]
text_train, text_test = all_text[:n//2], all_text[n//2:]
y_train, y_test = y[:n//2], y[n//2:]

100%|██████████| 1000/1000 [01:59<00:00,  8.35it/s]


In [5]:
# Normalize the mean (and maybe also variance) of a data set
def normalize(x, var_normalize = False):
  normalized_x = x - x.mean(axis=0, keepdims=True)
  if var_normalize:
      normalized_x /= normalized_x.std(axis=0, keepdims=True)

  return normalized_x

# Collin's main loss function
def ccs_loss(p0, p1):
  return informative_loss(p0,p1) + consistent_loss(p0,p1)


def informative_loss(p0, p1):
  return (torch.min(p0, p1)**2).mean(0)


def consistent_loss(p0, p1):
  return ((p0 - (1-p1))**2).mean(0)


def get_tensor_data(x0, x1):
  x0 = torch.tensor(x0, dtype=torch.float, requires_grad=False, device=model.device)
  x1 = torch.tensor(x1, dtype=torch.float, requires_grad=False, device=model.device)
  return x0, x1


def ccs(x0, x1, nepochs=200, ntries=10, lr=1e-3, batch_size=-1, 
              verbose=False, linear=True, weight_decay=0.01, var_normalize=False, loss_func=ccs_loss):

    x0 = normalize(x0)
    x1 = normalize(x1)

    # Number of entries in the hidden states
    d = x0.shape[-1]
    
    # probe
    probe = nn.Sequential(nn.Linear(d, 1),nn.Sigmoid())
    probe.to(model.device)  
    best_probe = copy.deepcopy(probe)
      
    best_loss = np.inf

    for train_num in range(ntries):
        # Make a new probe for this run
        probe = nn.Sequential(nn.Linear(d, 1), nn.Sigmoid())
        probe.to(model.device)  

        # ORder the data randomly in a tensor
        x0, x1 = get_tensor_data(x0, x1)
        permutation = torch.randperm(len(x0))
        x0, x1 = x0[permutation], x1[permutation]
        
        # Set up optimizer. Collin uses adamW so that's what we'll go with
        optimizer = torch.optim.AdamW(probe.parameters(), lr=lr, weight_decay=weight_decay)

        # Start training
        for epoch in range(nepochs):
          # probe
          p0, p1 = probe(x0), probe(x1)

          # get the corresponding loss
          loss = loss_func(p0, p1)

          # update the parameters
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

        loss = loss.detach().cpu().item()
        
        if loss < best_loss:
            best_probe = copy.deepcopy(probe)
            best_loss = loss

    return best_probe, best_loss


def predict(probe, x0, x1):
  x0 = torch.tensor(normalize(x0), dtype=torch.float, requires_grad=False, device=model.device)
  x1 = torch.tensor(normalize(x1), dtype=torch.float, requires_grad=False, device=model.device)
  
  with torch.no_grad():
      p0, p1 = probe(x0), probe(x1)

  avg_confidence = p0 - p1
  predictions = (avg_confidence.detach().cpu().numpy())[:, 0]

  
  return predictions

def predict_pair(probe, x0, x1):
  x0 = torch.tensor(x0, dtype=torch.float, requires_grad=False, device=model.device)
  x1 = torch.tensor(x1, dtype=torch.float, requires_grad=False, device=model.device)
  
  with torch.no_grad():
      p0, p1 = probe(x0), probe(x1)

  avg_confidence = p0 - p1
  predictions = (avg_confidence.detach().cpu().numpy())[:,0]
  return predictions

def classify_single(classifier_direction, hs):
  confidences = np.apply_along_axis(lambda x : np.dot(x,classifier_direction), 1, hs)

  return confidences

def predict_single(probe, x0):
  x0 = torch.tensor(normalize(x0), dtype=torch.float, requires_grad=False, device=model.device)
  
  with torch.no_grad():
      p0 = probe(x0)

  avg_confidence = p0
  predictions = (avg_confidence.detach().cpu().numpy())[:,0]
  return predictions

def get_acc(probe, x0_test, x1_test, y_test):
  predictions = (predict(probe, x0_test, x1_test) < 0.5).astype(int)

  # If predictions get messed up (i.e. ever not 1 or 0) this method will show
  # really good accuracy. TODO evaluate vs y_test and y_test inverted to avoid
  # this.
  acc = (predictions == y_test).mean()
  return acc
  acc = max(acc, 1 - acc)

  return acc

def is_reversed(probe, x0_test, x1_test, y_test):
  predictions = (predict(probe, x0_test, x1_test) < 0.5).astype(int)

  # If predictions get messed up (i.e. ever not 1 or 0) this method will show
  # really good accuracy. TODO evaluate vs y_test and y_test inverted to avoid
  # this.
  acc = (predictions == y_test).mean()

  return acc < 0.5

In [6]:
x_train = neg_hs_train - pos_hs_train
x_test = neg_hs_test - pos_hs_test


lr = LogisticRegression(class_weight="balanced", max_iter=1000)
lr.fit(x_train, y_train)
print("Logistic regression accuracy: {}".format(lr.score(x_test, y_test)))

Logistic regression accuracy: 0.95


In [7]:
probe, loss = ccs(neg_hs_train, pos_hs_train, ntries=1)

ccs_acc = get_acc(probe, neg_hs_test, pos_hs_test, y_test)

print("CCS Accuracy: {}, loss: {}".format(ccs_acc, loss))

classifier_direction = np.squeeze(np.transpose(probe[0].weight.detach().cpu().numpy()))

if(is_reversed(probe, neg_hs_test, pos_hs_test, y_test)):
  print("Flipping direction")
  classifier_direction = -classifier_direction



CCS Accuracy: 0.04, loss: 0.012944226153194904
Flipping direction


In [10]:
from transformers import LogitsProcessorList, LogitsProcessor

def sigmoid(x):
      return 1/(1 + np.exp(-x))
    
class HiddenStateDirectedLogitsProcessor(LogitsProcessor):
  def __init__(self, **kwargs):
     pass

  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
    k = 10
    top_k = torch.argsort(scores, descending=True)[:,0:k]

    print(input_ids)

    for i, sequence in enumerate(input_ids):

      # Stick all possible next tokens on the end
      completions = torch.cat([sequence.repeat(k,1), top_k[i].reshape(-1,1)],1)

      # Get hidden states
      hss = [get_llama_hidden_states_from_ids(model,x.unsqueeze(0),-1) for x in completions]

      # print("classified")
      classified = classify_single(classifier_direction, hss)

      # classified = classified - classified.mean()
      classified = F.softmax(torch.tensor(classified, device=model.device, dtype=torch.half))
      # print(classified)
      # print("truth")
      # truthiness = [sigmoid(x/20) for x in classified]
      # print(truthiness)
      # print("score")
      # print(scores[i][top_k[i]])
      # print("softmax")
      # print(F.softmax(scores[i])[top_k[i]])

      probs = F.softmax(scores[i][top_k[i]])
    
      scores[i] = torch.zeros(scores[i].size(), device=model.device, dtype=torch.half)

      # print("probsraw")
      
      # print(probs)

      # probs = probs / probs.sum()

      # print("probs")
      # print(probs)

      scores[i][top_k[i]] =  probs

      # print("Top scores")
      # print(scores[i][top_k[i]])
      # print("All scores")
      # print(scores)
      # print(scores[i])
      # scores[i] = torch.zeros(scores[i.size])
      # print("zerod")
      # print(scores[i])
      # scores[i][top_k[i]] = [truthiness[j] * probs[j] for j in range(len(truthiness))]
      # print("merged")
      # print(scores[i])

    # for j in range(len(completions)):
    #     print("{} : {} x {} = {}".format(tokenizer.decode(completions[j]), classified[j], probs[j], probs[j] + classified[j]))
    return scores
     

def generate(
    text,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=12,
    **kwargs,
):
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            logits_processor=LogitsProcessorList([HiddenStateDirectedLogitsProcessor()])
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    print(generation_output)
    return output

print(generate("Joe Biden is"))

tensor([[    1, 11131,   350,  3615,   338],
        [    1, 11131,   350,  3615,   338],
        [    1, 11131,   350,  3615,   338],
        [    1, 11131,   350,  3615,   338]], device='cuda:0')


  classified = F.softmax(torch.tensor(classified, device=model.device, dtype=torch.half))
  probs = F.softmax(scores[i][top_k[i]])


 Joe Biden is a : 0.080810546875 x 0.30712890625 = 0.387939453125
 Joe Biden is the : 0.270751953125 x 0.266845703125 = 0.53759765625
 Joe Biden is not : 0.2034912109375 x 0.0966796875 = 0.30029296875
 Joe Biden is running : 0.00841522216796875 x 0.0880126953125 = 0.096435546875
 Joe Biden is in : 0.0701904296875 x 0.062408447265625 = 0.132568359375
 Joe Biden is an : 0.278076171875 x 0.057708740234375 = 0.335693359375
 Joe Biden is one : 0.04852294921875 x 0.034454345703125 = 0.0830078125
 Joe Biden is going : 0.00556182861328125 x 0.0297088623046875 = 0.0352783203125
 Joe Biden is now : 0.0174102783203125 x 0.0287933349609375 = 0.04620361328125
 Joe Biden is still : 0.0167999267578125 x 0.02813720703125 = 0.044921875
tensor([[    1, 11131,   350,  3615,   338,   263],
        [    1, 11131,   350,  3615,   338,   278],
        [    1, 11131,   350,  3615,   338,   451],
        [    1, 11131,   350,  3615,   338,  2734]], device='cuda:0')
 Joe Biden is running for : 0.021438598632812