In [10]:
import os
import random
from os.path import join

import numpy as np
import torch
from transformers import BertTokenizer, BertModel
import pandas as pd
from tqdm.notebook import tqdm

import mymain

# Seed the NumPy, Python and PyTorch random number generators with one shared
# value so the random draws made later in the notebook repeat across runs.
seed = 1735
for _seed_rng in (np.random.seed, random.seed, torch.manual_seed):
    _seed_rng(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

In [2]:
# Directory of the data split used throughout this notebook; test.csv is read
# from it and the model-training entry point is run from inside it.
split_dir = './F24_Proj3_data/split_1'


In [3]:
# Holds the loaded (tokenizer, model) pair so the pretrained weights are read
# at most once per kernel instead of on every get_bert_embeddings() call.
_BERT_CACHE = {}


def get_bert_embeddings(texts):
    """Embed each text with `bert-base-uncased` using mean pooling.

    Every text is tokenized on its own (with truncation enabled), passed
    through the model without gradient tracking, and the last hidden states
    are averaged over the token axis to give one vector per text.

    Args:
        texts: Iterable of strings to embed.

    Returns:
        2-D NumPy array with one row per text and one column per hidden
        unit. An empty input yields an array of shape ``(0, hidden_size)``
        rather than a 1-D empty array, so callers can still read
        ``.shape[1]``.
    """
    if 'bert' not in _BERT_CACHE:
        # Load pre-trained model and tokenizer (only on first use).
        _BERT_CACHE['bert'] = (
            BertTokenizer.from_pretrained('bert-base-uncased'),
            BertModel.from_pretrained('bert-base-uncased'),
        )
    bert_tokenizer, bert_model = _BERT_CACHE['bert']

    def get_sentence_embedding(text):
        inputs = bert_tokenizer(text, truncation=True, padding=True, return_tensors='pt')
        with torch.no_grad():
            outputs = bert_model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        # Average over the token axis, then drop the singleton batch axis.
        return torch.mean(last_hidden_states, dim=1).numpy().flatten()

    # Generate embeddings for texts
    embeddings = [get_sentence_embedding(text) for text in tqdm(texts)]
    if not embeddings:
        # np.array([]) would be 1-D; keep the result 2-D for empty input.
        return np.empty((0, bert_model.config.hidden_size), dtype=np.float32)
    return np.array(embeddings)


def get_embeddings(filepath, num):
    """Sample `num` rows of a CSV and return paired embeddings for them.

    The CSV must contain `id` and `review` columns; all remaining columns are
    treated as a precomputed embedding of the review (presumably OpenAI
    embeddings, going by the variable name -- confirm against the data).

    Rows are drawn without replacement using NumPy's global random state, so
    seed it beforehand for a repeatable sample.

    Args:
        filepath: Path to the CSV file.
        num: Number of rows to sample.

    Returns:
        Tuple ``(bert, openai)``: a ``(num, 768)`` array of BERT embeddings
        computed from the sampled review texts, and a ``(num, 1536)`` array
        holding the precomputed embedding columns of the same rows.

    Raises:
        AssertionError: If either array does not have the expected shape.
    """
    test_df = pd.read_csv(filepath)
    # Drop the non-feature columns once; the result supplies both the row
    # count and the sampled rows instead of repeating the drop later.
    features = test_df.drop(columns=['id', 'review'])

    idxs = np.random.choice(len(features), size=num, replace=False)

    # Select the text column before the rows so the wide feature columns are
    # not copied just to read one field.
    reviews = test_df['review'].iloc[idxs].tolist()
    bert = get_bert_embeddings(reviews)
    assert bert.shape == (num, 768), f'unexpected BERT shape {bert.shape}'

    openai = features.iloc[idxs].to_numpy()
    assert openai.shape == (num, 1536), f'unexpected embedding shape {openai.shape}'

    return bert, openai


In [4]:
# Embed a random sample of 1600 rows from the split's test.csv with BERT and
# fetch the stored embedding columns for the same rows; show both shapes.
bert, openai = get_embeddings(join(split_dir, 'test.csv'), 1600)
bert.shape, openai.shape


  0%|          | 0/1600 [00:00<?, ?it/s]

((1600, 768), (1600, 1536))

In [5]:
# Fit a linear map (with intercept) from the BERT vectors to the stored
# embedding vectors by least squares. A leading column of ones is prepended,
# so the first row of `x` holds the intercept terms.
X = np.c_[np.ones(bert.shape[0]), bert]
# rcond=None requests the machine-precision-based singular-value cutoff
# explicitly, which gives the same behaviour on old and new NumPy releases and
# avoids the FutureWarning older releases emit when rcond is omitted.
# Indexing [0] (instead of unpacking into `_`) leaves IPython's `_`
# last-output variable untouched.
x = np.linalg.lstsq(X, openai, rcond=None)[0]
x.shape

(769, 1536)

In [6]:
# Write trained split 1 model to file

# mymain.main() is run from inside the split directory (presumably it locates
# its input files relative to the working directory -- confirm in mymain).
cwd = os.getcwd()
os.chdir(split_dir)
try:
    model = mymain.main()
finally:
    # Restore the working directory even if training fails; otherwise every
    # later relative path in this kernel would resolve inside split_dir.
    os.chdir(cwd)

model_file = './interpretability_inputs/trained_lr_model.npz'
# Create the output directory on first run so the save below cannot fail on a
# missing folder.
os.makedirs(os.path.dirname(model_file), exist_ok=True)
# Store the fitted model's parameters together with the BERT-to-embedding
# least-squares mapping `x` computed above.
np.savez_compressed(
    model_file,
    intercept=model.intercept_,
    coef=model.coef_,
    features=model.feature_names_in_,
    bert_to_openai_mapping=x,
)

In [7]:
test_path = join(split_dir, "test.csv")

# Load the test split and keep only the feature columns (everything except
# 'id' and 'review'). X_test is the matrix the model scores in the next cell;
# no label file is read here.
test_df = pd.read_csv(test_path)
X_test = test_df.drop(columns=['id', 'review'])


In [8]:
# Score every test row (second predict_proba column), then draw five row
# positions at random from each side of the 0.5 threshold. The global NumPy
# RNG is reseeded first so the same rows are drawn on every run.
probs = model.predict_proba(X_test)[:, 1]
np.random.seed(seed)
above_threshold = np.flatnonzero(probs > 0.5)
below_threshold = np.flatnonzero(probs < 0.5)
pos_idxs = np.random.choice(above_threshold, size=5, replace=False)
neg_idxs = np.random.choice(below_threshold, size=5, replace=False)
pos_idxs, neg_idxs


(array([ 5694,  9920, 14168,  4549, 14891]),
 array([20452,  5722,  6566,  2753,  1055]))

In [11]:
# Hard-coded row positions (into test_df) of the five 'pos' and five 'neg'
# reviews to explain. The values equal the output shown by the sampling cell,
# so this cell can be run without repeating the model scoring and sampling.
pos_idxs = np.array([ 5694,  9920, 14168,  4549, 14891])
neg_idxs = np.array([20452,  5722,  6566,  2753,  1055])


def generate_interpretability_embeddings(tag, id, review):
    """Write leave-one-sentence-out BERT embeddings of one review to a CSV.

    The review is split on '.'. For each resulting piece, the remaining pieces
    are re-joined with '. ' and embedded with get_bert_embeddings(). The CSV
    has one row per piece: the left-out piece in `loo_sentence`, followed by
    the embedding of the review without it in `bert_embedding_1..N`.

    Args:
        tag: Filename prefix for the output file.
        id: Review identifier used in the output filename. (The name shadows
            the builtin; it is kept so existing callers keep working.)
        review: Full review text.

    Returns:
        None. The result is written to
        ``./interpretability_inputs/{tag}_{id}.csv``.
    """
    sentences = review.split('.')
    loo_reviews = [
        '. '.join(sentences[:idx] + sentences[idx + 1:])
        for idx in range(len(sentences))
    ]
    loo_berts = get_bert_embeddings(loo_reviews)

    # Build the table column-wise so the embedding columns stay numeric.
    # Stacking strings and floats into a single ndarray would coerce every
    # value to a fixed-width string before pandas ever sees it.
    embedding_columns = [f'bert_embedding_{idx + 1}' for idx in range(loo_berts.shape[1])]
    frame = pd.DataFrame(loo_berts, columns=embedding_columns)
    frame.insert(0, 'loo_sentence', sentences)

    data_filepath = f'./interpretability_inputs/{tag}_{id}.csv'
    # Make sure the output directory exists before writing.
    os.makedirs(os.path.dirname(data_filepath), exist_ok=True)
    print(f'Writing to {data_filepath}')
    frame.to_csv(data_filepath, index=False, header=True)


# Produce one leave-one-out CSV per selected review: the 'pos' group first,
# then the 'neg' group, in the order the row positions are listed.
for label, positions in (('pos', pos_idxs), ('neg', neg_idxs)):
    for position in positions:
        # Look the row up once and read both fields from it.
        selected = test_df.iloc[position]
        generate_interpretability_embeddings(label, selected.id, selected.review)
    

  0%|          | 0/10 [00:00<?, ?it/s]

Writing to ./interpretability_inputs/pos_2603.csv


  0%|          | 0/47 [00:00<?, ?it/s]

Writing to ./interpretability_inputs/pos_8073.csv


  0%|          | 0/7 [00:00<?, ?it/s]

Writing to ./interpretability_inputs/pos_19034.csv


  0%|          | 0/7 [00:00<?, ?it/s]

Writing to ./interpretability_inputs/pos_17545.csv


  0%|          | 0/8 [00:00<?, ?it/s]

Writing to ./interpretability_inputs/pos_9595.csv


  0%|          | 0/37 [00:00<?, ?it/s]

Writing to ./interpretability_inputs/neg_35754.csv


  0%|          | 0/15 [00:00<?, ?it/s]

Writing to ./interpretability_inputs/neg_10557.csv


  0%|          | 0/23 [00:00<?, ?it/s]

Writing to ./interpretability_inputs/neg_46871.csv


  0%|          | 0/8 [00:00<?, ?it/s]

Writing to ./interpretability_inputs/neg_41564.csv


  0%|          | 0/13 [00:00<?, ?it/s]

Writing to ./interpretability_inputs/neg_2726.csv
