In [None]:
# !pip install transformers
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
# !pip install sentence-transformers
# !pip install datasets

In [1]:
import numpy as np
import pandas as pd
import ast
import warnings
import scipy
import matplotlib.pyplot as plt
import spacy
import difflib
import tqdm
import json
import pickle
import logging
import itertools
import torch
import datasets
from torch.utils.data import DataLoader

logging.basicConfig(
    format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
)
from torch.utils.data import DataLoader , IterableDataset , Dataset
from torch.optim import AdamW
from transformers import get_scheduler
from transformers import DataCollatorForLanguageModeling
from transformers import BertTokenizer, BertLMHeadModel, AutoConfig
from sklearn.metrics.pairwise import cosine_similarity
from collections import Counter

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# config = AutoConfig
model = BertLMHeadModel.from_pretrained("google-bert/bert-large-uncased")
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")

dfn = pd.read_csv("lmsys-chatbot-arena/aux_files/dfn.csv")
unique_ids = dfn.id.unique()
dfn.head(3)

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


Unnamed: 0,id,prompt,model_a_answer,model_b_answer,winner
0,30192,Is it morally right to try to have a certain p...,The question of whether it is morally right to...,"As an AI, I don't have personal beliefs or opi...",A
1,53567,What is the difference between marriage licens...,A marriage license is a legal document that al...,A marriage license and a marriage certificate ...,B
2,53567,How can I get both of them as quick as possibl...,If you want to get both a marriage license and...,"In California, here are the general steps to o...",B


In [4]:
id_list = []
id_strings = []

for _id in tqdm.tqdm(unique_ids):
    id_df = dfn[dfn.id == _id]
    id_str = ""
    
    for row in id_df.iterrows():
        prompt = str(row[1]['prompt'])
        
        model_a_answer = str(row[1]['model_a_answer'])
        model_b_answer = str(row[1]['model_b_answer'])
        winner = row[1]['winner']

        if winner == "A":
            id_str += prompt + "\n" + model_a_answer +'\n'
        if winner == "B":
            id_str += prompt + "\n" + model_b_answer +'\n'
    
    if len(id_str) > 10:
        id_strings.append(str(id_str))
        id_list.append(str(_id))

cut = int(len(id_strings) * 0.2)

id_strings_train = id_strings[0:len(id_strings) - cut]
id_strings_test = id_strings[cut:]

100%|███████████████████████████████████| 56759/56759 [00:06<00:00, 8965.02it/s]


In [5]:
class DocsDataset(Dataset):
    def tokenize_ans(self , examples):
        return tokenizer(examples, padding='max_length' , max_length=512 , truncation=True , return_tensors='pt')
    
    def __init__(self, str_list):
        self.str_list = str_list

    def __len__(self):
        return len(self.str_list)
        
    def __getitem__(self , idx):
        tokenized_str_map = self.tokenize_ans(self.str_list[idx])
        tokenized_str_map['labels'] = tokenized_str_map['input_ids']
        return tokenized_str_map

dataset_docs_train = DocsDataset(id_strings_train)
dataset_docs_train_dataloader = DataLoader(dataset_docs_train, batch_size=4)

dataset_docs_test = DocsDataset(id_strings_test)
dataset_docs_test_dataloader = DataLoader(dataset_docs_test, batch_size=4)

In [7]:
from torch.optim import AdamW
import tqdm

optimizer = AdamW(model.parameters(), lr=5e-5)

from transformers import get_scheduler

num_epochs = 3
# num_training_steps = num_epochs * len(dataset_docs_train_dataloader)
num_training_steps = num_epochs * 2

lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))

epochs = 3

model.train()
for epoch in range(num_epochs):
    for batch in dataset_docs_train_dataloader:
        batch = {k: torch.squeeze(v , 1) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)


  0%|                                                     | 0/6 [06:47<?, ?it/s][A

 17%|███████▌                                     | 1/6 [00:45<03:48, 45.61s/it][A
 33%|███████████████                              | 2/6 [01:28<02:55, 43.92s/it][A
 50%|██████████████████████▌                      | 3/6 [02:13<02:13, 44.43s/it][A
 67%|██████████████████████████████               | 4/6 [02:53<01:25, 42.82s/it][A
 83%|█████████████████████████████████████▌       | 5/6 [03:32<00:41, 41.22s/it][A
100%|█████████████████████████████████████████████| 6/6 [04:13<00:00, 41.29s/it][A
7it [04:57, 42.01s/it]                                                          [A
8it [05:35, 40.78s/it][A
9it [06:11, 39.52s/it][A
10it [06:51, 39.41s/it][A
11it [07:28, 38.74s/it][A
12it [08:07, 38.85s/it][A
13it [08:47, 39.22s/it][A
14it [09:26, 39.07s/it][A
KeyboardInterrupt

