In [1]:
import os
# Hugging Face model cache location. This cell runs before `transformers` is imported
# (next cell) so the cache path is already in the environment when the library loads.
CACHE_DIR = '/projects/chen386/.cache/huggingface'
# TRANSFORMERS_CACHE is deprecated in newer transformers releases; HF_HUB_CACHE is the
# huggingface_hub variable for the same cache directory, so set both to the same path.
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
os.environ['HF_HUB_CACHE'] = CACHE_DIR

In [2]:
import gc
import pickle
import torch

import numpy as np
import pandas as pd

from sklearn.linear_model import LogisticRegression
from transformers import  AutoModel, AutoTokenizer, BertModel, LlamaModel, LlamaTokenizer

In [3]:
# Run on the GPU whenever CUDA reports one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
DATA_PATH = 'sigcse_2024.csv'
df = pd.read_csv(DATA_PATH)

# Fail fast: the embedding / sanity-check cell reads these columns, and discovering a
# missing one only after a large model has loaded wastes a lot of time.
missing_columns = {'response', 'subset', 'qid', 'binary_ground_truth'} - set(df.columns)
assert not missing_columns, f'{DATA_PATH} is missing columns: {sorted(missing_columns)}'

In [5]:
def last_token(model_output, attention_mask=None):
    """Return the final hidden state at the last token of each sequence.

    Args:
        model_output: a model output whose first element is the last hidden state,
            expected to be indexed as (batch, seq_len, hidden).
        attention_mask: optional (batch, seq_len) mask with 1 for real tokens and 0 for
            padding. When given, each row uses the right-most position whose mask is 1,
            so left- or right-padded batches pick a real token. When omitted, the final
            position of every row is used.

    Returns:
        Tensor of shape (batch, hidden); (1, hidden) for a single sequence.
    """
    token_embeddings = model_output[0]
    if attention_mask is None:
        return token_embeddings[:, -1, :]

    mask = attention_mask.to(token_embeddings.device)
    seq_len = mask.shape[1]
    # argmax returns the first maximum, so on the flipped mask it finds the right-most 1.
    last_positions = seq_len - 1 - mask.flip(dims=[1]).argmax(dim=1)
    batch_positions = torch.arange(token_embeddings.shape[0], device=token_embeddings.device)
    return token_embeddings[batch_positions, last_positions]

# adapted from the huggingface sbert example
def mean_pooling(model_output, attention_mask):
    """Average token hidden states over the positions where `attention_mask` is 1.

    Args:
        model_output: a model output whose first element is the last hidden state,
            expected to be (batch, seq_len, hidden).
        attention_mask: (batch, seq_len) mask, 1 for tokens to include.

    Returns:
        Float tensor of shape (batch, hidden). Rows with an all-zero mask yield zeros
        (the denominator is clamped to 1e-9 rather than dividing by zero).
    """
    token_embeddings = model_output[0]
    # Align the mask with the hidden states: with device_map='auto' the model output
    # need not be on the device the tokenizer output was sent to.
    mask = attention_mask.to(token_embeddings.device)
    input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

In [6]:
# One row per model: (name, tokenizer class, model class, path, load in 4-bit).
# Every model loads its tokenizer and its weights from the same path (a Hugging Face
# Hub id, or a local directory for the original LLaMA checkpoints).
_MODEL_SPECS = [
    ('bert_base', AutoTokenizer, BertModel, 'bert-base-uncased', False),
    ('bert_large', AutoTokenizer, BertModel, 'bert-large-uncased', False),
    ('sbert', AutoTokenizer, AutoModel, 'sentence-transformers/all-MiniLM-L6-v2', False),
    ('gpt2', AutoTokenizer, AutoModel, 'gpt2', False),
    ('gpt2_medium', AutoTokenizer, AutoModel, 'gpt2-medium', False),
    ('gpt2_large', AutoTokenizer, AutoModel, 'gpt2-large', False),
    ('gpt2_xl', AutoTokenizer, AutoModel, 'gpt2-xl', False),
    ('llama_7b', LlamaTokenizer, LlamaModel, '7B', True),
    ('llama_13b', LlamaTokenizer, LlamaModel, '13B', True),
    ('llama_30b', LlamaTokenizer, LlamaModel, '30B', True),
    ('llama_65b', LlamaTokenizer, LlamaModel, '65B', True),
    ('llama2_7b', AutoTokenizer, AutoModel, 'meta-llama/Llama-2-7b-hf', True),
    ('llama2_7b_chat', AutoTokenizer, AutoModel, 'meta-llama/Llama-2-7b-chat-hf', True),
    ('llama2_13b', AutoTokenizer, AutoModel, 'meta-llama/Llama-2-13b-hf', True),
    ('llama2_13b_chat', AutoTokenizer, AutoModel, 'meta-llama/Llama-2-13b-chat-hf', True),
    ('llama2_70b', AutoTokenizer, AutoModel, 'meta-llama/Llama-2-70b-hf', True),
    ('llama2_70b_chat', AutoTokenizer, AutoModel, 'meta-llama/Llama-2-70b-chat-hf', True),
    ('vicuna_7b', AutoTokenizer, AutoModel, 'lmsys/vicuna-7b-v1.3', True),
    ('vicuna_13b', AutoTokenizer, AutoModel, 'lmsys/vicuna-13b-v1.3', True),
    ('vicuna_33b', AutoTokenizer, AutoModel, 'lmsys/vicuna-33b-v1.3', True),
]

# name -> loading metadata, in the order the models are processed below.
model_name_to_meta_data = {
    spec_name: {
        'tokenizer_class' : tok_cls,
        'tokenizer_path' : hub_path,
        'model_class' : mdl_cls,
        'model_path' : hub_path,
        'load_in_4bit' : quantize,
    }
    for spec_name, tok_cls, mdl_cls, hub_path, quantize in _MODEL_SPECS
}

In [7]:
def _load_tokenizer_and_model(model_name, meta_data):
    """Load the tokenizer and model described by one `model_name_to_meta_data` entry.

    Args:
        model_name: key of the entry; names containing 'llama' or 'vicuna' get their
            tokenizer loaded with `legacy=False`.
        meta_data: the entry (tokenizer/model classes, paths, and 4-bit flag).

    Returns:
        (tokenizer, model). 4-bit models are placed with `device_map='auto'`; all other
        models are moved to the global `device`.
    """
    tokenizer_kwargs = {'legacy': False} if ('llama' in model_name or 'vicuna' in model_name) else {}
    tokenizer = meta_data['tokenizer_class'].from_pretrained(meta_data['tokenizer_path'], **tokenizer_kwargs)

    if meta_data['load_in_4bit']:
        model = meta_data['model_class'].from_pretrained(
            meta_data['model_path'], device_map='auto', load_in_4bit=True)
    else:
        model = meta_data['model_class'].from_pretrained(meta_data['model_path']).to(device)
    model.eval()  # make sure dropout is disabled so embeddings are deterministic
    return tokenizer, model


def _embed_responses(tokenizer, model, texts):
    """Embed each text on its own (batch size 1) with last-token and mean pooling.

    Returns:
        (last_token_embeddings, mean_pooling_embeddings): float32 numpy arrays with one
        row per text, in input order.
    """
    last_token_rows, mean_pooling_rows = [], []
    for position, text in enumerate(texts):
        if position % 100 == 0:
            print(f'{position}/{len(texts)}')

        # truncation only affects inputs longer than the tokenizer's model_max_length,
        # which would otherwise overflow e.g. BERT's 512 position embeddings.
        encoded_input = tokenizer(text, return_tensors='pt', truncation=True).to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)

        # .float() so half-precision (incl. bfloat16) outputs convert cleanly to numpy.
        last_token_rows.append(last_token(model_output).float().cpu())
        mean_pooling_rows.append(mean_pooling(model_output, encoded_input['attention_mask']).float().cpu())

    return torch.cat(last_token_rows).numpy(), torch.cat(mean_pooling_rows).numpy()


def _sanity_check_accuracy(responses_df, embeddings):
    """Fit one logistic regression per qid on 'train' rows; return accuracy on 'validate' rows.

    Row k of `embeddings` corresponds to the k-th row of `responses_df`. Validation rows
    whose qid has no training rows keep the default prediction 0.
    """
    train_df = responses_df[responses_df.subset == 'train']
    validate_df = responses_df[responses_df.subset == 'validate'].copy()
    validate_df['predicted'] = 0

    for qid, sub_train_df in train_df.groupby('qid'):
        sub_validate_df = validate_df[validate_df.qid == qid]
        if sub_validate_df.empty:
            continue  # nothing to score; LogisticRegression.predict rejects 0 samples

        # map index labels to row positions in `embeddings`
        train_X = embeddings[responses_df.index.get_indexer(sub_train_df.index)].astype(np.float64)
        train_y = sub_train_df['binary_ground_truth'].to_numpy(dtype=int)
        validate_X = embeddings[responses_df.index.get_indexer(sub_validate_df.index)].astype(np.float64)

        if len(np.unique(train_y)) < 2:
            # LogisticRegression cannot be fit on a single class; predict that class.
            predicted = np.full(len(sub_validate_df), train_y[0])
        else:
            lr = LogisticRegression(max_iter=1000000)
            lr.fit(train_X, train_y)
            predicted = lr.predict(validate_X)

        validate_df.loc[sub_validate_df.index, 'predicted'] = predicted

    return float((validate_df.predicted == validate_df.binary_ground_truth).mean())


os.makedirs('embeddings', exist_ok=True)

for model_name, meta_data in model_name_to_meta_data.items():
    print(model_name)
    tokenizer, model = _load_tokenizer_and_model(model_name, meta_data)

    last_token_embeddings, mean_pooling_embeddings = _embed_responses(tokenizer, model, df['response'])
    print(np.shape(mean_pooling_embeddings))

    # sanity check: per-question classifiers on mean-pooled embeddings
    print(_sanity_check_accuracy(df, mean_pooling_embeddings))

    # stored as nested Python lists via .tolist()
    for suffix, embeddings in (('last_token', last_token_embeddings),
                               ('mean_pooling', mean_pooling_embeddings)):
        with open(f'embeddings/{model_name}_{suffix}.pkl', 'wb') as f:
            pickle.dump(embeddings.tolist(), f)

    # Release the model before loading the next one; empty_cache hands blocks cached by
    # PyTorch's CUDA allocator back to the driver.
    del model, tokenizer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

bert_base
0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 768)
0.8643790849673203
bert_large
0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 1024)
0.8741830065359477
sbert
0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 384)
0.89

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 4096)
0.8725490196078431
llama_13b


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 5120)
0.8954248366013072
llama_30b


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 6656)
0.8856209150326797
llama_65b


Loading checkpoint shards:   0%|          | 0/14 [00:00<?, ?it/s]

0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 8192)
0.8758169934640523
llama2_7b


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 4096)
0.8823529411764706
llama2_7b_chat


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 4096)
0.8905228758169934
llama2_13b


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 5120)
0.8839869281045751
llama2_13b_chat


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 5120)
0.8921568627450981
llama2_70b


Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 8192)
0.8709150326797386
llama2_70b_chat


Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 8192)
0.9003267973856209


You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


vicuna_7b


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 4096)
0.8937908496732027
vicuna_13b


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 5120)
0.8872549019607843
vicuna_33b


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

0/3064
100/3064
200/3064
300/3064
400/3064
500/3064
600/3064
700/3064
800/3064
900/3064
1000/3064
1100/3064
1200/3064
1300/3064
1400/3064
1500/3064
1600/3064
1700/3064
1800/3064
1900/3064
2000/3064
2100/3064
2200/3064
2300/3064
2400/3064
2500/3064
2600/3064
2700/3064
2800/3064
2900/3064
3000/3064
(3064, 6656)
0.8888888888888888
