In [None]:
import os
from torch.optim import AdamW
import torch
import datasets
from transformers import AutoTokenizer
from transformers import AutoModelForQuestionAnswering, AutoModel, AutoModelForSeq2SeqLM
import json, re
from tqdm.auto import tqdm
import numpy as np
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset
from evaluate import load
import evaluate
from UniEval.utils import convert_to_json
from UniEval.metric.evaluator import get_evaluator
from evaluate_models import evaluate_sentence_output, process_data_nat_inst, preprocess_function_nat_inst


# Local Hugging Face cache directory, passed explicitly to the loaders below.
cache_dir = "/scratches/dialfs/alta/hln35/.cache"
# NOTE(review): `transformers` has already been imported when this line runs, so the
# variable presumably no longer influences this process; the explicit `cache_dir=`
# arguments are what direct the loads in this notebook. TODO confirm.
os.environ['TRANSFORMERS_CACHE'] = cache_dir

model_small = "google/flan-t5-small"

# Fail fast: the model is moved to the GPU below, so refuse to run without CUDA.
if not torch.cuda.is_available():
    raise Exception("Cuda is not available, please enable cuda")
# CUDA is guaranteed to be available here, so no CPU fallback is needed.
device = torch.device("cuda")

# Tokenizer and seq2seq model for the small FLAN-T5 checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_small, cache_dir=cache_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_small, cache_dir=cache_dir).to(device)

In [2]:
# Maximum number of input tokens; passed as `max_length` (with truncation) to the
# tokenizer in `preprocess_function`.
max_input_length = 1024
# Presumably the token cap for target/generated sequences; it is not referenced by
# any cell visible in this part of the notebook — TODO confirm it is still used.
max_target_length = 128

In [3]:
# Load one Natural Instructions task file (task706) with the generic "json" loader.
dataset = load_dataset(
    "json",
    data_files="/scratches/dialfs/alta/hln35/natural-instructions/tasks/task706_mmmlu_answer_generation_high_school_mathematics.json",
)

In [4]:
# Display the loaded DatasetDict (splits, column names and row counts).
dataset

DatasetDict({
    train: Dataset({
        features: ['Contributors', 'Source', 'URL', 'Categories', 'Reasoning', 'Definition', 'Input_language', 'Output_language', 'Instruction_language', 'Domains', 'Positive Examples', 'Negative Examples', 'Instances', 'Instance License'],
        num_rows: 1
    })
})

In [5]:
# Print the name of every top-level field of the task file, one per line.
train_features = dataset["train"].features
for column_name in train_features:
    print(column_name)

Contributors
Source
URL
Categories
Reasoning
Definition
Input_language
Output_language
Instruction_language
Domains
Positive Examples
Negative Examples
Instances
Instance License


In [6]:
# Inspect the "Definition" field, i.e. the natural-language instruction of the task.
dataset["train"]["Definition"]


[['You are given a question on high school mathematics. You are also given 4 answer options (associated with "A", "B", "C", "D"), out of which only one is correct. You need to answer the question by selecting the correct option. You should only answer with the choice letter, not the whole answer.']]

In [7]:
# Inspect the "URL" field recorded in the task file.
dataset["train"]["URL"]

[['https://github.com/hendrycks/test']]

In [8]:
# The task file loads as a single "train" row whose "Instances" field holds the
# list of examples; turn that list of records into a dataset of its own.
task_instances = dataset["train"]["Instances"][0]
dataset_formatted = Dataset.from_list(task_instances)

In [9]:
# Display the per-example dataset built from the task's "Instances".
dataset_formatted

Dataset({
    features: ['id', 'input', 'output'],
    num_rows: 181
})

In [10]:
# Load the "samsum" dataset through the shared cache directory.
# NOTE(review): `raw_datasets` is reassigned in a later cell (from
# `dataset_dict["Instances"]`) and no visible cell in between uses this copy other
# than to display it — confirm this load is still needed.
raw_datasets = load_dataset("samsum", cache_dir=cache_dir)

In [11]:
# Display the splits, columns and row counts of `raw_datasets`.
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 14732
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 819
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 818
    })
})

In [10]:
def preprocess_function(examples):
    """Tokenize a batch of task inputs, each prepended with the task instruction.

    Intended to be passed to `Dataset.map(..., batched=True)`.

    Relies on the notebook-level names `prefix`, `tokenizer` and
    `max_input_length`, which must be defined by the time this is called.

    Args:
        examples: Batch mapping whose "input" entry is a list of input strings.

    Returns:
        The tokenizer output for the instruction-prefixed inputs, truncated to
        `max_input_length` tokens.
    """
    # Keep the instruction and the example apart: plain concatenation glued the
    # last word of the instruction to the first word of the input.
    separator = " " if prefix and not prefix[-1].isspace() else ""
    inputs = [prefix + separator + doc for doc in examples["input"]]

    return tokenizer(inputs, max_length=max_input_length, truncation=True)

In [11]:
# The task instruction is the first entry of the "Definition" field of the single
# "train" row; `preprocess_function` prepends it to every input.
task_definitions = dataset["train"]["Definition"]
prefix = task_definitions[0][0]

# Tokenize all instances in batches.
tokenized_datasets = dataset_formatted.map(preprocess_function, batched=True)

Map:   0%|          | 0/181 [00:00<?, ? examples/s]

In [12]:
# Display the tokenized dataset to check which columns the mapping added.
tokenized_datasets

Dataset({
    features: ['id', 'input', 'output', 'input_ids', 'attention_mask'],
    num_rows: 181
})

In [13]:
# Read the training-split task names, one per line.
# Iterating the file and skipping blank lines (instead of `read().split("\n")`)
# avoids a spurious empty entry when the file ends with a newline.
with open("/scratches/dialfs/alta/hln35/natural-instructions/splits/default/train_tasks.txt", "r", encoding="utf-8") as file:
    task_list = [line.strip() for line in file if line.strip()]

In [14]:
# Display the task names read from the split file.
task_list

['task547_alt_translation_entk_en',
 'task706_mmmlu_answer_generation_high_school_mathematics',
 'task1565_triviaqa_classification',
 'task701_mmmlu_answer_generation_high_school_computer_science',
 'task698_mmmlu_answer_generation_global_facts',
 'task104_semeval_2019_task10_closed_vocabulary_mathematical_answer_generation',
 'task926_coached_conv_pref_word_generation',
 'task326_jigsaw_classification_obscene',
 'task1498_24hour_to_12hour_clock',
 'task1731_quartz_question_answering',
 'task1453_person_entity_extraction_btc_corpus',
 'task1399_obqa_answer_generation',
 'task1286_openbookqa_question_answering',
 'task165_mcscript_question_answering_commonsense',
 'task610_conllpp_ner',
 'task864_asdiv_singleop_question_answering',
 'task385_socialiqa_incorrect_answer_generation',
 'task1608_xquad_en_answer_generation',
 'task337_hateeval_classification_individual_en',
 'task563_discofuse_answer_generation',
 'task023_cosmosqa_question_generation',
 'task607_sbic_intentional_offense_bin

In [None]:
# Run the loaded task through the project helper; the result is indexed by field name below.
dataset_dict = process_data_nat_inst(dataset)

# Report the task metadata.
print(f"Categories: {dataset_dict['Categories']}")
print(
    f"Input language: {dataset_dict['Input_language']}, "
    f"Output language: {dataset_dict['Output_language']}"
)

# Task instruction and the task's examples.
# NOTE(review): `prefix` is set in the notebook namespace only; whether the imported
# `preprocess_function_nat_inst` can see it depends on its implementation — confirm.
prefix = dataset_dict["Definition"][0][0]
raw_datasets = dataset_dict["Instances"]

# Tokenize the examples in batches with the project's preprocessing helper.
tokenized_datasets = raw_datasets.map(preprocess_function_nat_inst, batched=True)

# Model inputs and the reference outputs handed to the evaluation below.
test_input_ids = tokenized_datasets["input_ids"]
labels = tokenized_datasets["output"]

In [None]:
# Checkpoints to evaluate: Hugging Face hub identifiers and local checkpoint directories.
model_list = [
    "google/flan-t5-small",
    "/scratches/dialfs/alta/hln35/distillation/model/flant5_small_lr_10-4_race_finetuning_epoch2",
    "/scratches/dialfs/alta/hln35/distillation/model/flant5_small_lr_10-4_race_distill_epoch2",
    "google/flan-t5-large",
    "google/flan-t5-base",
]

# The loop variable is deliberately not named `model`: that name holds the model
# object loaded at the top of the notebook and must not be overwritten by a string.
for model_name in model_list:
    print(model_name)
    # `local_files_only=True`: load from disk / the cache directory only.
    candidate_model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name, cache_dir=cache_dir, local_files_only=True
    ).to(device)
    print(f"For model {model_name}, the average score is: ")

    # Every checkpoint is scored with the same tokenizer, inputs and reference outputs.
    evaluate_sentence_output(candidate_model, tokenizer, test_input_ids, labels)
    # NOTE(review): `em_evaluator` is not defined in any cell visible here — confirm
    # it is created elsewhere before this cell runs, otherwise this raises NameError.
    evaluate_sentence_output(candidate_model, tokenizer, test_input_ids, labels, evaluator=em_evaluator)

    # Drop the finished checkpoint before the next one is loaded so that two
    # models are never resident on the GPU at the same time.
    del candidate_model
    torch.cuda.empty_cache()