In [1]:
import pandas as pd

In [11]:
import torch
import numpy as np
import pandas as pd
from datasets import Dataset
from transformers import AutoTokenizer
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from torch.utils.data import DataLoader

In [43]:
# exp004
import dataclasses
import torch
import pandas as pd
from sklearn.model_selection import KFold
from transformers import AutoModel, AutoTokenizer, AutoModelForMultipleChoice
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer
from datasets import Dataset
from typing import Optional, Union

import transformers
import wandb
from datetime import datetime as dt
import os
import numpy as np
import tqdm

import logging 
from logging import Logger



@dataclasses.dataclass
class BertConfig:
    
    experiment_name: str
    dataset_dir: str
    
    debug: bool = False

    lr: float = 1e-5
    model_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
    num_context: int = 3
    max_length: int = 512
    batch_size: int = 2
    epochs: int = 10
    iters_to_accumlate: int = 8
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    
    freeze_embeddings: bool = True
    freeze_layers: int = 18
    reinitialize_layers: int = 0
    
    assume_completely_retrieved: bool = False
    n_samples: int = None
    steps: int = 100
    
    lora_r: float = 2
    lora_alpha: float = 4
    lora_dropout: float = 0.1
    use_peft: bool = False

def get_logger(
    output_dir: str,
):
    """
    logger を作成する. formatter は "%Y-%m-%d %H:%M:%S" で作成する.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    
    # formatter
    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    
    # handler
    handler = logging.StreamHandler()
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    
    handler = logging.FileHandler(f"{output_dir}/log.txt", "w")
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    
    return logger


@dataclasses.dataclass
class DataCollatorForMultipleChoice:
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    
    def __call__(self, features):
        label_name = 'label' if 'label' in features[0].keys() else 'labels'
        labels = [feature.pop(label_name) for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]['input_ids'])
        flattened_features = [
            [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
        ]
        flattened_features = sum(flattened_features, [])
        
        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors='pt',
        )
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch['labels'] = torch.tensor(labels, dtype=torch.int64)
        return batch


        
def preprocess_df(df, config):
    df["context"] = ""
    for i in range(config.num_context):
        df["context"] += df[f"searched_wiki_id_{i}"].astype(str) + "\n\n"
    
    for col in ["A", "B", "C", "D", "E"]:
        df[col] = df[col].fillna("")
    return df[["prompt", "context", "A", "B", "C", "D", "E", "answer"]]


def map_at_3(predictions, labels):
    map_sum = 0
    pred = np.argsort(-1*np.array(predictions),axis=1)[:,:3]
    for x,y in zip(pred,labels):
        z = [1/i if y==j else 0 for i,j in zip([1,2,3],x)]
        map_sum += np.sum(z)
    return map_sum / len(predictions)

def compute_metrics(p):
    predictions = p.predictions.tolist()
    labels = p.label_ids.tolist()
    return {"map@3": map_at_3(predictions, labels)}


In [44]:
df_test = pd.read_parquet("../output/context_pipeline/stage1/exp009.py/20230922162941_gte-base_wikiall_without_sep_targetprompt_and_choice_without_sep_token_length120_stride_sentence4_drop_categoryTrue_all/valid.parquet")

In [45]:
df_test

Unnamed: 0.1,Unnamed: 0,prompt,A,B,C,D,E,answer,source,dataset,...,searched_wiki_id_0,searched_wiki_id_1,searched_wiki_id_2,searched_wiki_id_3,searched_wiki_id_4,searched_wiki_id_5,searched_wiki_id_6,searched_wiki_id_7,searched_wiki_id_8,searched_wiki_id_9
68541,68541,"When was the album ""Bodysong"" by Jonny Greenwo...","October 27, 2004 in the UK and February 24, 20...","October 28, 2003 in the UK and February 23, 20...","October 27, 2003 in the UK and February 24, 20...","October 24, 2004 in the UK and February 27, 20...","October 24, 2003 in the UK and February 27, 20...",C,1,valid,...,#Bodysong (album)\nBodysong is the debut solo ...,#Bodysong (album)\nBodysong is the debut solo ...,#Bodysong\nBodysong is a 2003 BAFTA-winning do...,"#Jonny Greenwood\n ""Jonny Greenwood Chart Hist...",#Bodysong\nBodysong is a 2003 BAFTA-winning do...,#Jonny Greenwood\n ISSN0261-3077. Archived fro...,#Jonny Greenwood\n 20 May 2019. Archived from ...,"#Jonny Greenwood\n ""Caught in the flash"". The ...",#Bodysong (album)\nthemovieblog.com. Retrieved...,#Jonny Greenwood\nJonathan Richard Guy Greenwo...
68542,68542,What is the primary function of mitochondria?,Cell death and apoptosis,Membrane formation and organelle structure,Energy generation through respiration,Cell differentiation and specialization,Cell signaling and communication,C,1,valid,...,#Cell biology\n Mitochondria are commonly refe...,"#Mitochondrion\n75 and 3μm2 in cross section, ...","#Cell biology\n Moreover, researchers have gai...",#Mitochondrion\n A dominant role for the mitoc...,"#Mitochondrion\ne., phosphorylation of ADP), t...",#Cell biology\n Its physiological adaptability...,"#Cell biology\n It, therefore, acts as a found...",#Mitochondrion\n These folds are studded with ...,#Cell biology\n The inner mitochondrial membra...,#Cell biology\n These products are involved in...
68543,68543,What is the role of Thomas Shawn Kleeh in the ...,Thomas Shawn Kleeh is the Chief United States ...,Thomas Shawn Kleeh is a judge at the Supreme C...,Thomas Shawn Kleeh is a defense attorney speci...,Thomas Shawn Kleeh is the Chief United States ...,Thomas Shawn Kleeh is a prosecutor for the Dep...,D,1,valid,...,#Thomas Samuel Zilly\nThomas Samuel Zilly (bor...,#United States District Court for the Southern...,#United States District Court for the Northern...,#United States District Court for the Northern...,#United States District Court for the Western ...,#Judiciary of Virginia\n Its administration is...,#United States District Court for the District...,#United States District Court for the District...,#United States District Court for the Western ...,#United States District Court for the District...
68544,68544,What are the highest points of the Monte Albo ...,Punta Catirina and Monte Boe,Punta Norina and Monte Spina,Punta Catirina and Monte Spina,Punta Catirina and Monte Turuddo,Punta Norina and Monte Boe,D,1,valid,...,#Monte Albo\nThe Monte Albo (Monte Arbu in Sar...,#Monte Alben\n It is formed by a mountainous m...,"#Sardinia\n Due to long erosion processes, the...",#Monte Albo\nThe Monte Albo (Monte Arbu in Sar...,#Sardinia\n The highest peak is Punta La Marmo...,#Mount Alvernia\nMount Alvernia (formerly Como...,#Val Chisone\n Some of the most important moun...,#Monte Renoso massif\n These are (from north t...,#Monte Linas\nMonte Linas is a massif in the p...,#Sardinia\n The island has an ancient geoforma...
68545,68545,Which of the following accurately describes th...,The Delta IV Heavy is the third highest-capaci...,The Delta IV Heavy is the second highest-capac...,The Delta IV Heavy is the world's highest-capa...,The Delta IV Heavy is the fourth highest-capac...,The Delta IV Heavy is the world's second highe...,A,1,valid,...,#Delta IV Heavy\nThe Delta IV Heavy (Delta 925...,#List of Delta IV Heavy launches\nThe followin...,#List of Delta IV Heavy launches\n The Delta I...,#List of Delta IV Heavy launches\n The ULA Del...,"#Delta IV Heavy\n That's current and future, a...",#Delta IV Heavy\n The Delta IV Heavy is the mo...,"#Delta IV Heavy\n ""Falcon Heavy, SpaceX's Big ...",#Delta IV Heavy\n Star 48BV upper stage Curren...,#Delta IV Heavy\n On the last seconds of count...,#Delta IV Heavy\n This can be compared with th...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
72436,72436,What do these two changes have in common?\ndee...,Both are caused by cooling.,Both are chemical changes.,Both are only physical changes.,California toad,Ginkgo trees have flat leaves.,B,additional_data/ScienceQA/test.parquet,valid,...,#Fire ecology\n It could also be due to the in...,#Wildfire\n The spread of wildfires varies bas...,"#Microwave burn\n Damage to the A beta fibers,...",#Fried chicken\n Once the pieces have been add...,#Fried chicken\n There is debate as to how oft...,#Fire ecology\n Fires can cause changes in soi...,#Wildfire\n Less dense material such as grasse...,"#Fried chicken\n Generally, the fat is heated ...",#Creosote\n Over the course of a season creoso...,#Spontaneous human combustion\n The protein in...
72437,72437,What is the mass of an eraser?,2 tons,2 ounces,2 pounds,Does a big toy car go down the wooden ramp fas...,ad hominem: an attack against the person makin...,B,additional_data/ScienceQA/test.parquet,valid,...,"#Mass versus weight\nIn common usage, the mass...","#Size\n In scientific contexts, mass refers lo...",#Elastic collision\n Before collision Ball 1: ...,"#Relativist fallacy\n Take, for example, the s...","#Mass versus weight\n Conversely, the load ind...",#Elastic collision\n5 m/s Ball 2: velocity = 1...,"#Mass versus weight\n For instance, billiard b...",#Proof mass\nA proof mass or test mass is a kn...,#Physics\n Simp. – His language would seem to ...,"#Mass versus weight\n Usually, the relationshi..."
72438,72438,Which type of sentence is this?\nAs Bert sat d...,complex,simple,compound,compound-complex,Tara and her biological father wear sunglasses...,A,additional_data/ScienceQA/test.parquet,valid,...,#Sentence clause structure\n Since a dependent...,#Sentence clause structure\n A compound senten...,#Topic sentence\n A complex sentence is one th...,"#Has Hlai grammar\n, Fas fun lo. sky rain acce...",#Sentence clause structure\n Sentence 4 is com...,#Sentence (linguistics)\n A simple sentence co...,#Sentence clause structure\n) In the backyard ...,#Topic sentence\n As the topic sentence encaps...,"#Sentence clause structure\nIn grammar, senten...",#Structural approach\ng.: I was watching a mov...
72439,72439,"Before the Louisiana Purchase, what was the we...",the Pacific Ocean,the Mississippi River,the Rocky Mountains,the Missouri River,declarative,B,additional_data/ScienceQA/test.parquet,valid,...,#Midwestern United States\n The Ohio River run...,#Western United States\n East of the Rocky Mou...,#Missouri\n Originally the state's western bor...,#Louisiana Purchase Historic State Park\n On A...,#Western United States\n The Mississippi River...,#Louisiana Purchase\n The territory's boundari...,#Louisiana Purchase Historic State Park\n The ...,#Missouri\n This line is known as the Osage Bo...,"#Western United States\n The Columbia River, t...",#Midwestern United States\n Traditional defini...


In [37]:
logger = get_logger(output_dir="")
logger.info("load data")

config = BertConfig(
    debug=False,
    batch_size=1,
    experiment_name=f"",
    dataset_dir="",
    model_name="../output/stage2/exp005.py/20230923195407_new_data_all300val_maxlen256/fold0/",
    max_length=256,
    num_context=2,
)

2023-09-25 06:51:11,960 INFO load data
2023-09-25 06:51:11,960 INFO load data
2023-09-25 06:51:11,960 INFO load data
2023-09-25 06:51:11,960 INFO load data


In [38]:
df_test = preprocess_df(df_test, config)

In [39]:
df_test = df_test[df_test["answer"].isin(["A", "B", "C", "D", "E"])]

In [40]:
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
option_to_index = {option: idx for idx, option in enumerate('ABCDE')}
index_to_option = {v: k for k,v in option_to_index.items()}
def preprocess(example):
    first_sentence = [ "[CLS] " + example['context'] ] * 5
    second_sentences = [" #### " + example['prompt'] + " [SEP] " + example[option] + " [SEP]" for option in 'ABCDE']
    tokenized_example = tokenizer(first_sentence, second_sentences, truncation="only_first", 
                                  max_length=config.max_length, add_special_tokens=False)
    tokenized_example['label'] = option_to_index[example['answer']]
    return tokenized_example

test_dataset = Dataset.from_pandas(df_test)
tokenized_test_dataset = test_dataset.map(preprocess, remove_columns=["prompt", "context", "A", "B", "C", "D", "E", "answer"])

  0%|          | 0/3900 [00:00<?, ?ex/s]

In [41]:
data_collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)
test_dataloader = DataLoader(tokenized_test_dataset, batch_size=1, shuffle=False, collate_fn=data_collator)

In [None]:
model = AutoModelForMultipleChoice.from_pretrained(config.model_name).cuda()

In [42]:
test_predictions = []
for batch in tqdm.tqdm(test_dataloader):
    for k in batch.keys():
        batch[k] = batch[k].cuda()
    with torch.no_grad():
        outputs = model(**batch)
    test_predictions.append(outputs.logits.cpu().detach())

test_predictions = torch.cat(test_predictions)
test_predictions = test_predictions.numpy()

  0%|                                                | 0/3900 [00:00<?, ?it/s]


[{'__index_level_0__': 68541, 'input_ids': [[1, 953, 34649, 22313, 287, 48216, 285, 5907, 22313, 269, 262, 3988, 4872, 1898, 293, 59252, 14330, 36790, 26033, 260, 325, 269, 262, 12159, 264, 262, 6186, 933, 265, 262, 454, 601, 260, 325, 284, 1315, 277, 1370, 1824, 261, 3037, 267, 262, 1222, 263, 277, 1555, 969, 261, 2860, 267, 262, 780, 1017, 260, 325, 284, 17543, 55424, 263, 49867, 277, 3381, 263, 6778, 277, 940, 903, 692, 260, 26033, 280, 268, 2387, 261, 59252, 26623, 10998, 261, 2787, 5839, 277, 262, 1898, 260, 325, 269, 7315, 270, 282, 262, 362, 307, 268, 12274, 309, 1898, 1315, 293, 356, 1034, 265, 59252, 260, 5327, 261, 10998, 26033, 287, 2252, 64063, 280, 268, 2387, 285, 269, 262, 364, 59252, 1034, 272, 303, 298, 729, 1315, 356, 4872, 1146, 260, 279, 1302, 307, 8979, 8919, 58984, 309, 284, 427, 267, 26033, 280, 268, 2271, 270, 262, 2097, 933, 443, 1887, 1498, 7367, 260, 279, 7389, 265, 262, 1362, 641, 275, 27970, 265, 501, 26033, 1302, 261, 307, 28411, 28277, 2479, 3433, 297, 341