In [2]:
import argparse
import json
import os

import numpy as np
import pandas as pd
import seaborn as sns
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

In [3]:
parser = argparse.ArgumentParser(description='training proof-of-concept')

# Data selection
parser.add_argument('--model_name', type=str, default="meta-llama/Llama-2-7b-hf")
parser.add_argument('--dataset_name', type=str, default='/home/echeng/llm-control/jigsaw-toxic-comment-classification-challenge')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--device', type=str, default='cuda')
# Parse an empty list so the notebook ignores the Jupyter kernel's own argv.
args = parser.parse_args([])

# Hugging Face access token. Never hard-code credentials in a notebook: read it
# from the environment instead. If HF_TOKEN is unset this is None, and
# `from_pretrained` falls back to any locally cached login.
# NOTE: a token was previously committed in this cell -- revoke it.
ACCESS_TOKEN = os.environ.get('HF_TOKEN')

In [4]:
# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_name, token=ACCESS_TOKEN)
# Passing `load_in_8bit=True` directly to from_pretrained is deprecated;
# request 8-bit weights through a BitsAndBytesConfig instead.
model = AutoModelForCausalLM.from_pretrained(
    args.model_name,
    token=ACCESS_TOKEN,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# For Llama-2 checkpoints, reuse EOS as the padding token and pad on the right.
is_llama2 = 'Llama-2' in args.model_name
if is_llama2:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

In [6]:
# Switch to inference mode (same as model.eval()); the model repr is displayed.
model.train(False)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear8bitLt(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
 

In [7]:
# Load the dataset: a 1% subsample of train.csv.
# Fixed seed so the subsample (and everything computed from it) is reproducible.
DATA_SAMPLE_SEED = 0
dataset = pd.read_csv(os.path.join(args.dataset_name, 'train.csv')).sample(
    frac=0.01, random_state=DATA_SAMPLE_SEED
)

In [8]:
# Peek at the first rows of the sampled frame.
dataset.head(n=5)

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
95290,fec47a90d3ff4325,"Hi \n\nHey brownaddictUK, welcome to wikipedia...",0,0,0,0,0,0
60474,a1e22f184f04b2c7,"""I just have a problem with the whole opening ...",0,0,0,0,0,0
5691,0f3a4d4d25070429,RfC: Should the infobox say that the glacier i...,0,0,0,0,0,0
22717,3bfce01ba2a48b38,(i.e. outside France/US),0,0,0,0,0,0
111645,554248e2e10efe24,Cultural Signicance \n\nThe article totally la...,0,0,0,0,0,0


In [9]:
# The raw comment strings that will be tokenized.
dataset.loc[:, 'comment_text']

95290     Hi \n\nHey brownaddictUK, welcome to wikipedia...
60474     "I just have a problem with the whole opening ...
5691      RfC: Should the infobox say that the glacier i...
22717                              (i.e. outside France/US)
111645    Cultural Signicance \n\nThe article totally la...
                                ...                        
97774      "::How about ""Social Contract Theory""? - \n\n"
55176     "\n\n ""I have no tribunal"" \n\nRevision on "...
40850                            I forgot myself at Zilina.
26348     '''xenophobic a highly perjorative term alludi...
54685     Track map? \n\nCan a track map be made by jdjo...
Name: comment_text, Length: 1596, dtype: object

# Preprocess data

In [14]:
def encode_data(tokenizer, N, data, batch_size, max_length, device, last_k=None):
    """Tokenize ``data`` and split it into batches moved to ``device``.

    Args:
        tokenizer: Hugging Face tokenizer used for encoding and padding.
        N: number of leading examples to batch. Used for text input only;
            the token-id path batches every retained sequence.
        data: list of strings, or list of token-id sequences.
        batch_size: number of examples per batch.
        max_length: truncation length for text input.
        device: device the ``input_ids`` / ``attention_mask`` tensors are moved to.
        last_k: if set (text input only), keep only the last ``last_k``
            token positions of every batch.

    Returns:
        A list with one dict per batch containing ``input_ids`` and
        ``attention_mask``, plus ``length`` for text input without ``last_k``.
        Empty ``data`` yields an empty list.
    """
    if len(data) == 0:
        return []

    if isinstance(data[0], str):
        enc = tokenizer(data, padding=True, truncation=True, max_length=max_length,
                        return_length=True, return_tensors="pt")
        batches = []
        for start in range(0, N, batch_size):
            rows = slice(start, start + batch_size)
            input_ids = enc['input_ids'][rows]
            attention_mask = enc['attention_mask'][rows]
            if last_k:
                # Slice dim 1 (token positions), not dim 0 (examples).
                # NOTE(review): with right padding these trailing columns may be
                # padding for shorter rows -- confirm this is intended.
                batches.append({
                    'input_ids': input_ids[:, -last_k:].to(device),
                    'attention_mask': attention_mask[:, -last_k:].to(device),
                })
            else:
                batches.append({
                    'input_ids': input_ids.to(device),
                    'attention_mask': attention_mask.to(device),
                    'length': enc['length'][rows],
                })
        return batches

    # Token-id input: drop the first id of each sequence (presumably a BOS that
    # `tokenizer.encode` adds back -- TODO confirm), pad everything to the longest
    # original length, then batch manually.
    max_len = max(len(seq) for seq in data)
    kept = [seq for seq in data if len(seq) > 2]
    padded = [
        tokenizer.encode(seq[1:], padding='max_length', max_length=max_len, return_tensors="pt")
        for seq in kept
    ]
    pad_id = tokenizer.pad_token_id
    batches = []
    for start in range(0, len(kept), batch_size):
        input_ids = torch.stack(padded[start:start + batch_size]).squeeze(1).to(device)
        # Build the mask from the tokenizer's own pad id. A hard-coded 1 is the
        # Llama BOS id, which would mask BOS and leave the real padding unmasked.
        batches.append({'input_ids': input_ids, 'attention_mask': (input_ids != pad_id).long()})
    return batches

In [11]:
# Debug-sized run: only the first sampled comment is kept (widen the slice for a full run).
data = dataset['comment_text'].iloc[:1].tolist()

In [12]:
# Tokenize every selected comment, truncating to model.config.max_position_embeddings tokens.
encodings = encode_data(
    tokenizer,
    N=len(data),
    data=data,
    batch_size=args.batch_size,
    max_length=model.config.max_position_embeddings,
    device=args.device,
)

In [13]:
# Inspect the first tokenized batch (a dict with input_ids, attention_mask, length).
encodings[0]

{'input_ids': tensor([[    1,  6324, 29871,    13,    13, 29950,  1032, 17354,  1202,   919,
          19960, 29892, 12853,   304,   281,   638,  4652, 29889,   306, 29915,
            345, 10548,   393,   366,  1925,   278,  7821, 29889,  8477,   292,
           1202,   919, 29879, 29889,  1212,  1544,   373,   263,  7303,   310,
           6515,   393,   306,  6505,   975, 29892,   322,  5131,   304,  2367,
            366,   263,  4996, 15883,   701,   373, 16278,   281,   638,  4652,
          29889, 29871,   408,   366, 29915,   345,  3117, 10548, 29892,   263,
           3287,   310,   278,  2988,   892,  5051,  9098,  6206,   491,  3863,
            943,  1363,   445,  3508, 29915, 29873,   278,  2058,   363,   963,
          29889, 29871,   306, 29915, 29885,   599,   363, 10311,  6260,  4371,
            669,   633,   303, 17225,   515,   788,   919,  1080, 29892,   541,
            445,   338,   263,  2058,   304,  3867,  2114,   950,  2472,   373,
          17800, 29889, 298

In [15]:
def last_token_rep(x, attention_mask, padding='right'):
    """Pool each sequence in a batch to the vector at its last real token.

    Args:
        x: tensor indexed as ``x[example, position]``; presumably
            (batch, seq_len, hidden) hidden states -- TODO confirm for all callers.
        attention_mask: (batch, seq_len) 0/1 mask whose row sums give the
            number of real tokens per sequence.
        padding: ``'right'`` selects position ``mask.sum() - 1`` for each row.
            Any other value (left padding) selects the final position.

    Returns:
        The selected per-example vectors, moved to CPU.
    """
    # Build index tensors on x's device to avoid an implicit host-to-device copy.
    rows = torch.arange(x.size(0), device=x.device)
    if padding == 'right':
        positions = attention_mask.sum(dim=1).to(x.device) - 1
    else:
        positions = -1
    return x[rows, positions].cpu()

In [None]:
REPS_PATH = '/home/echeng/llm-control/toxic_reps.pt'  # where the pooled representations are saved

with torch.no_grad():
    per_batch = []
    for batch in tqdm(encodings):
        hidden_states = model(batch['input_ids'], attention_mask=batch['attention_mask'],
                              output_hidden_states=True)['hidden_states']
        # One pooled (last real token) tensor per hidden-state entry.
        per_batch.append(tuple(
            last_token_rep(h, batch['attention_mask'], padding=tokenizer.padding_side)
            for h in hidden_states
        ))
    # Transpose batch-major -> entry-major, then concatenate examples along dim 0:
    # representations[i] holds hidden-state entry i for every example.
    representations = [torch.cat(entry_reps, dim=0) for entry_reps in zip(*per_batch)]

print('Layer 1 reps shape: ')
print(representations[1].shape)
# (A leftover `input()` debug pause that blocked execution here was removed.)
torch.save(representations, REPS_PATH)

In [None]:
# Shape of the pooled hidden-state entry 0 (`reps` was never defined; the list is `representations`).
representations[0].shape

# Training

In [16]:
import random
import torch.nn as nn
from sklearn.metrics import f1_score   

In [17]:
# Reuse the configured device rather than duplicating the literal.
device = args.device

In [18]:
PROBE_PATH = '/home/echeng/llm-control/experiments/toxicity/linear_probe_tiny.pt'
# The file holds a whole pickled nn.Linear module (see the repr below), not a
# state_dict, so weights_only=False is required (torch>=2.6 defaults to True).
# Unpickling can execute arbitrary code: only load files you produced yourself.
W = torch.load(PROBE_PATH, weights_only=False).to(device)
W

Linear(in_features=4096, out_features=2, bias=True)

In [19]:
from scipy import optimize

In [64]:
# Wrap the layer of interest in a module that steers its output along the
# linear probe's class direction, then REPLACE the original layer with it.
class LinearControlWrapper(torch.nn.Module):
    """Layer wrapper that adds a steering vector theta = lambda * w after the layer runs.

    ``w = W[1] - W[0]`` is the difference between the probe's two weight rows
    (as defined in the algorithm: w_2 - w_1). For each sequence, lambda is the
    unique root in (0, 1/gamma) of

        f(l) = l * exp(l * ||w||^2 + w . x) + l - 1/gamma,

    where x is the last-position hidden state produced by the wrapped layer.
    """

    def __init__(self, base_layer: nn.Module, linear_probe: nn.Module, name="", gamma=0.01):
        """
        Args:
            base_layer: layer to wrap; called with exactly the arguments forward receives.
            linear_probe: probe whose ``weight`` has shape 2 x d (d = hidden size).
            name: optional label for this wrapper.
            gamma: positive scalar; lambda is searched for in (0, 1/gamma).
        """
        super().__init__()
        self.base_layer = base_layer
        self.name = name

        # Probe-related parameters
        self.gamma = gamma
        self.W = linear_probe.weight  # linear probe, shape (2, d)
        self.b = linear_probe.bias
        # Detached: the steering direction is a fixed vector, not something to backprop through.
        self.w = (self.W[1, :] - self.W[0, :]).detach()  # as defined in algo w_2 - w_1
        self.w_norm = torch.linalg.vector_norm(self.w)

    def forward(self, x, *args, **kwargs):
        """Run the wrapped layer, then add theta to the last-position hidden state.

        Returns what the wrapped layer returns (a tensor, or a tuple whose first
        element is the hidden states), with the hidden states replaced.
        """
        outputs = self.base_layer(x, *args, **kwargs)
        returns_tensor = isinstance(outputs, torch.Tensor)
        hidden = outputs if returns_tensor else outputs[0]

        # Steer the final position of the current chunk for every row. Indexing by
        # max(position_ids) goes out of range with a KV cache (absolute ids vs. a
        # chunk of new tokens), and hidden[:, idx_vector] mixes rows when batch > 1.
        theta = self.optimal_theta(hidden[:, -1])
        steering = torch.zeros_like(hidden)
        steering[:, -1] = theta.to(device=hidden.device, dtype=hidden.dtype)
        hidden = hidden + steering

        if returns_tensor:
            return hidden
        return (hidden,) + tuple(outputs[1:])

    def optimal_theta(self, x):
        """Return theta = lambda * w for every row of ``x``.

        Args:
            x: hidden states with the hidden size on the last axis, e.g. (batch, d) or (d,).

        Returns:
            A float32 tensor with the same shape as ``x``, on ``x``'s device.
        """
        rows = x.reshape(-1, x.shape[-1]).float()
        w = self.w.to(device=rows.device, dtype=torch.float32)
        w_norm_sq = float(self.w_norm) ** 2
        upper = 1.0 / self.gamma  # the positive root lies in (0, 1/gamma)
        projections = (rows @ w).tolist()  # w . x for each row
        lambdas = [self._solve_lambda(w_norm_sq, proj, upper) for proj in projections]
        lam = torch.tensor(lambdas, dtype=torch.float32, device=rows.device)
        theta = lam[:, None] * w
        return theta.reshape(x.shape)

    @staticmethod
    def _solve_lambda(w_norm_sq, proj, upper):
        """Root of f(l) = l*exp(l*w_norm_sq + proj) + l - upper on (0, upper).

        f(0) = -upper < 0 < f(upper) and f is strictly increasing for l > 0, so the
        root is unique. It is found on the equivalent log form
        log(l) + l*w_norm_sq + proj - log(upper - l) = 0, which cannot overflow exp.
        """
        def g(l):
            return np.log(l) + l * w_norm_sq + proj - np.log(upper - l)

        lo, hi = upper * 1e-12, upper * (1.0 - 1e-12)
        # If the root lies outside the numerically safe bracket, clamp to its edge.
        if g(lo) >= 0:
            return lo
        if g(hi) <= 0:
            return hi
        return optimize.brentq(g, lo, hi)

In [65]:
# Install the steering wrapper on decoder layer 26. Unwrap only if the layer is
# already wrapped, so this works on a fresh kernel and never nests wrappers.
model.model.layers[26] = LinearControlWrapper(
    getattr(model.model.layers[26], 'base_layer', model.model.layers[26]), W
)

In [22]:
# Keep a handle on the unwrapped layer 26 (presumably to restore it later) and
# install a single fresh wrapper around it, avoiding wrapper-inside-wrapper nesting.
prev_layer = getattr(model.model.layers[26], 'base_layer', model.model.layers[26])
model.model.layers[26] = LinearControlWrapper(prev_layer, W)

torch.Size([4096])


In [40]:
# Initialize
d = model.config.hidden_size
eps = 0.0001
# Seeded local generator: reproducible initial theta without touching the global RNG.
theta = torch.randn(d, generator=torch.Generator().manual_seed(0)) * eps
gamma = 0.01
T = 10000 # number of tokens to sample

In [30]:
# Re-inspect the first tokenized batch that is fed to the steered model below.
encodings[0]

{'input_ids': tensor([[    1,  6324, 29871,    13,    13, 29950,  1032, 17354,  1202,   919,
          19960, 29892, 12853,   304,   281,   638,  4652, 29889,   306, 29915,
            345, 10548,   393,   366,  1925,   278,  7821, 29889,  8477,   292,
           1202,   919, 29879, 29889,  1212,  1544,   373,   263,  7303,   310,
           6515,   393,   306,  6505,   975, 29892,   322,  5131,   304,  2367,
            366,   263,  4996, 15883,   701,   373, 16278,   281,   638,  4652,
          29889, 29871,   408,   366, 29915,   345,  3117, 10548, 29892,   263,
           3287,   310,   278,  2988,   892,  5051,  9098,  6206,   491,  3863,
            943,  1363,   445,  3508, 29915, 29873,   278,  2058,   363,   963,
          29889, 29871,   306, 29915, 29885,   599,   363, 10311,  6260,  4371,
            669,   633,   303, 17225,   515,   788,   919,  1080, 29892,   541,
            445,   338,   263,  2058,   304,  3867,  2114,   950,  2472,   373,
          17800, 29889, 298

In [66]:
# Quick forward pass through the steered model. no_grad avoids keeping
# activations for backprop, which this inspection does not need.
with torch.no_grad():
    steered_output = model(input_ids=encodings[0]['input_ids'],
                           attention_mask=encodings[0]['attention_mask'])
steered_output

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 23.50 GiB of which 6.81 MiB is free. Including non-PyTorch memory, this process has 23.48 GiB memory in use. Of the allocated memory 23.14 GiB is allocated by PyTorch, and 57.75 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Training loop
# sample content from the model
for t in range(T):
    # TODO: sample next token from model, evaluating
    pass  # a comment-only loop body is an IndentationError; placeholder until implemented