In [1]:
from transformers import Trainer, TrainingArguments
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
import torch
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from datasets import load_dataset, load_metric
from rouge import Rouge
import numpy as np

In [None]:
# Load the pretrained GPT-2 checkpoint and its tokenizer, and pick the device
# every later cell evaluates on.
model_name = "gpt2-medium"

# Announce the slow step before it starts rather than after it has finished.
print("Preparing Model ...")

# GPT-2 ships without a pad token, so BOS is reused as one. Padding (and the
# left padding side) only matters if inputs are ever batched; the evaluation
# cells in this notebook encode one sequence at a time.
tokenizer = GPT2Tokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.bos_token
model = GPT2LMHeadModel.from_pretrained(model_name)
# Alternative: build the model from a default config and load a local checkpoint.
# config = GPT2Config()
# model = GPT2LMHeadModel(config)
# checkpoint = torch.load("/shared/data2/minhaoj2/gpt-2-original/pytorch_model.bin")
# model.load_state_dict(checkpoint)
model.eval()

# GPU 2 is the preferred device, but only use it when that index actually
# exists; otherwise fall back to GPU 0, or to the CPU when CUDA is unavailable.
preferred_gpu = 2
if torch.cuda.is_available():
    gpu_index = preferred_gpu if preferred_gpu < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# `Module.to` returns the module itself; assigning it keeps the (very long)
# model repr from being displayed as this cell's output.
model = model.to(device)

In [None]:
# Prompt templates placed between the article text and the candidate label.
# evaluate_agnews reports one accuracy per entry, in this order.
topic_prompt_list = [
    "This text is",
]

In [None]:
def evaluate_agnews_1(model, tokenizer, device=device):
    """Zero-shot AG News topic accuracy with the fixed prompt "This text is".

    Each of the first 100 test articles is framed as "<article> This text is"
    and run through the model once. The candidate label whose first token has
    the largest next-token logit at the end of that prompt is the prediction.

    Args:
        model: causal language model; calling it on a tensor of token ids must
            return an object with `.logits` indexed as [batch, position, vocab].
        tokenizer: tokenizer matching `model`; `encode` is used for the prompt
            and for looking up the label token ids.
        device: torch device the encoded inputs are moved to. It has to be the
            device the model lives on.

    Returns:
        The value of sklearn's `accuracy_score` over the 100 evaluated examples.
    """
    def classify_text(text, possible_outputs):
        # The prompt stops right before the label, so the logits at the last
        # position are the model's scores for the word that completes
        # "This text is ...". (Scoring after "<label>." would instead rate the
        # token that follows an already finished sentence.)
        input_ids = tokenizer.encode(f"{text} This text is", return_tensors="pt").to(device)
        with torch.no_grad():
            next_token_logits = model(input_ids).logits[0, -1, :]

        # GPT-2's BPE attaches the preceding space to a word, so the token that
        # continues the prompt is " world", not "world". Multi-token labels
        # such as "sci/tech" are represented by their first token only.
        token_ids = [tokenizer.encode(f" {output}")[0] for output in possible_outputs]
        class_logits = [next_token_logits[token_id].item() for token_id in token_ids]
        # list.index(max(...)) resolves ties in favour of the earliest label.
        return possible_outputs[class_logits.index(max(class_logits))]

    # The position of each name is compared directly with the dataset's integer
    # `label`, so the list order must match the dataset's label ids.
    possible_outputs = ["world", "sports", "business", "sci/tech"]
    dataset = load_dataset("ag_news", split="test").select(range(100))

    ground_truth = []
    pred_labels = []

    for data in tqdm(dataset):
        prediction = classify_text(data['text'], possible_outputs)
        pred_labels.append(possible_outputs.index(prediction))
        ground_truth.append(data['label'])
    return accuracy_score(ground_truth, pred_labels)

# Run the fixed-prompt AG News evaluation. The call is the last expression of
# the cell, so the returned accuracy is what the cell displays.
evaluate_agnews_1(model, tokenizer)

In [None]:
def evaluate_agnews(model, tokenizer, prompt_list, device=device):
    """Zero-shot AG News topic accuracy for each prompt template in `prompt_list`.

    Every one of the first 100 test articles is framed as "<article> <prompt>"
    and run through the model once. The candidate label whose first token has
    the largest next-token logit at the end of that text is the prediction.

    Args:
        model: causal language model; calling it on a tensor of token ids must
            return an object with `.logits` indexed as [batch, position, vocab].
        tokenizer: tokenizer matching `model`; `encode` is used for the prompt
            and for looking up the label token ids.
        prompt_list: iterable of prompt strings, e.g. ["This text is"].
        device: torch device the encoded inputs are moved to. It has to be the
            device the model lives on.

    Returns:
        A list with one `accuracy_score` value per prompt, in `prompt_list` order.
    """
    print("Evaluating on AG News Dataset")
    # The position of each name is used as the predicted class id and compared
    # with the dataset's integer `label`, so the order must match the label ids.
    possible_outputs = ["world", "sports", "business", "sci/tech"]
    dataset = load_dataset("ag_news", split="test").select(range(100))

    # Neither the gold labels nor the label token ids depend on the prompt, so
    # both are computed once. GPT-2's BPE attaches the preceding space to a
    # word, hence " world" rather than "world"; multi-token labels such as
    # "sci/tech" are represented by their first token only.
    ground_truth = [example['label'] for example in dataset]
    label_token_ids = [tokenizer.encode(f" {output}")[0] for output in possible_outputs]

    def classify_text(text, prompt):
        # The input stops right before the label, so the logits at the last
        # position score the word that completes the prompt. rstrip() removes
        # trailing whitespace (empty prompt, or one ending in a space) that
        # would otherwise sit between the prompt and the space-prefixed label.
        framed_text = f"{text} {prompt}".rstrip()
        input_ids = tokenizer.encode(framed_text, return_tensors="pt").to(device)
        with torch.no_grad():
            next_token_logits = model(input_ids).logits[0, -1, :]
        class_logits = [next_token_logits[token_id].item() for token_id in label_token_ids]
        # list.index(max(...)) resolves ties in favour of the earliest label.
        return class_logits.index(max(class_logits))

    res = []
    for prompt in prompt_list:
        # Only the predicted class ids are needed, so they are collected in a
        # plain loop instead of being written into a mapped copy of the dataset.
        predictions = [
            classify_text(example['text'], prompt)
            for example in tqdm(dataset, desc=prompt)
        ]
        res.append(accuracy_score(ground_truth, predictions))
    return res

# Score every prompt template in topic_prompt_list; `res` holds one accuracy
# per prompt, in the same order as the list.
res = evaluate_agnews(model, tokenizer, topic_prompt_list)
print(res)

In [None]:
import numpy as np
import os

In [None]:
# Append a mean / standard-deviation summary to every results file in the
# classification results directory. Each score row is expected to be
# tab-separated with the numeric value in the second column.
# NOTE: `dir` shadows the builtin of the same name; the name is kept in case
# other cells read it.
dir = "./results/classification/"
files = os.listdir(dir)
for file in files:
    path = os.path.join(dir, file)
    # os.listdir also returns sub-directories (e.g. .ipynb_checkpoints), which
    # cannot be opened as text files.
    if not os.path.isfile(path):
        continue

    with open(path, 'r') as f:
        lines = f.readlines()

    # This cell appends to the files it reads. Skip files that already carry a
    # summary so that re-running the cell neither parses the summary line as a
    # score row nor appends a second summary.
    if any(" Mean: " in line for line in lines):
        continue

    # The last line is left out, as before.
    # NOTE(review): presumably it is not a score row -- confirm against the
    # code that writes these files.
    num = [line.strip('\n').split('\t')[1] for line in lines[:-1]]
    num = np.array(num).astype(float)
    mean = np.mean(num)
    std = np.std(num)

    with open(path, 'a') as f:
        f.write(f" Mean: {mean}, STD: {std}\n")
        # NOTE(review): the totals look like dataset split sizes (SST-2 vs.
        # AG News) -- confirm they match the runs that produced these files.
        if file.split('-')[-1].startswith('sst2'):
            f.write("Total samples: 67349")
        else:
            f.write("Total samples: 7600")