In [None]:
import numpy as np
import pandas as pd
import torch
from transformers import BertModel, BertTokenizer
from tqdm import tqdm

# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
torch.cuda.empty_cache()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(torch.cuda.is_available())
# Tensors created without an explicit device will now be placed on `device`.
torch.set_default_device(device)
# Querying the current CUDA device raises on a CPU-only machine, so only
# report the GPU name when a GPU was actually selected.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(torch.cuda.current_device()))

def get_embeddings(text, tokenizer, model, layers=(-1, -2, -3, -4)):
    """Return one vector for `text`: the first token's hidden state averaged over `layers`.

    Args:
        text: The string to embed.
        tokenizer: Object exposing ``encode_plus``; called with
            ``add_special_tokens=True`` so the first position is the
            tokenizer's leading special token ([CLS] for BERT tokenizers).
        model: Callable accepting the encoded inputs as keyword arguments and
            returning an object with a ``hidden_states`` sequence (the model
            must be configured to output hidden states).
        layers: Indices into ``hidden_states`` to average. Defaults to the
            last four entries.

    Returns:
        A 1-D NumPy array holding the averaged hidden state of the first
        token. The array owns its own memory.
    """
    # truncation=True caps the input at the tokenizer's maximum length so an
    # over-long text is shortened instead of failing inside the model.
    encoded = tokenizer.encode_plus(
        text, add_special_tokens=True, truncation=True, return_tensors="pt"
    )

    # Place the inputs on the model's device explicitly instead of relying on
    # a process-wide default device being set by the caller.
    target_device = getattr(model, "device", None)
    if target_device is not None:
        encoded = {key: value.to(target_device) for key, value in encoded.items()}

    with torch.no_grad():
        output = model(**encoded)
    hidden_states = output.hidden_states

    # Average the selected layers, then drop only the batch axis; a bare
    # squeeze() would also collapse a length-1 sequence axis.
    embedding = torch.stack([hidden_states[i] for i in layers]).mean(0).squeeze(0)

    # Copy the first-token row so the result does not keep the full
    # per-token matrix alive when callers accumulate many embeddings.
    return embedding[0].detach().cpu().numpy().copy()

# Read every column as a string, then discard rows with any missing value.
df = pd.read_csv("data/sportsQnA.csv", dtype="string")
df = df.dropna()

asian_countries = ["China", "Japan", "South Korea", "Taiwan", "Hong Kong"]

# Row masks for the two nationality groups being compared.
_is_asian = df['nationality'].isin(asian_countries)
_is_american = df['nationality'].eq("United States")

df_asia = df.loc[_is_asian]  # players from asia
df_usa = df.loc[_is_american]  # players from america
# Asian rows first, then US rows, renumbered from zero.
df_all = pd.concat([df_asia, df_usa], ignore_index=True)

# question text that we will be using
q_asia, q_usa, q_all = (frame['question'] for frame in (df_asia, df_usa, df_all))

# Tokenizer and encoder come from the same pretrained checkpoint.
_BERT_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(_BERT_CHECKPOINT)

# output_hidden_states=True makes the model return its hidden states,
# which get_embeddings reads; the model is then moved onto `device`.
model = BertModel.from_pretrained(_BERT_CHECKPOINT, output_hidden_states=True).to(device)

# Embed each question one row at a time, collecting the vectors and the
# matching metadata in two parallel lists.
embedding_holder = []
meta_holder = []
for row in tqdm(df_all.itertuples(), total=df_all.shape[0]):
    text, nationality, name, date = row.question, row.nationality, row.name, row.date
    embedding = get_embeddings(text, tokenizer, model)

    embedding_holder.append(embedding)
    meta_holder.append([name, nationality, text, date])

# Metadata columns go on the left, one column per embedding dimension on the right.
meta_df = pd.DataFrame(meta_holder, columns=["name", "nationality", "text", "date"])
embedding_df = pd.concat([meta_df, pd.DataFrame(embedding_holder)], axis=1)

print(embedding_df.shape)
embedding_df.to_csv("embeddings.csv", index=False)

 20%|█▉        | 15076/76977 [32:54<2:15:08,  7.63it/s]


KeyboardInterrupt: 