In [7]:
from icl.analysis.compress import CompressArgs
from sklearn.metrics import accuracy_score
import torch

In [12]:
def get_label(y):
    """Return the predicted class index for every example in ``y``.

    Takes the first entry of ``y.predictions`` and returns the argmax over its
    last axis.
    """
    first_prediction = y.predictions[0]
    return first_prediction.argmax(-1)

def get_logits(y):
    """Pick the full-vocabulary logits array out of ``y.predictions``.

    Slot 2 of ``y.predictions`` is returned when its last dimension is larger
    than 30000; otherwise slot 3 is returned.
    """
    candidate = y.predictions[2]
    # NOTE(review): 30000 is a size heuristic, presumably "bigger than any label
    # count, smaller than the model vocabulary" — confirm against how the
    # results are saved.
    return candidate if candidate.shape[-1] > 30000 else y.predictions[3]

def get_topk(y, k):
    """Return, per row, the column indices of the ``k`` largest logits.

    The indices within each row are NOT sorted by score, because
    ``np.argpartition`` only guarantees that the last ``k`` positions hold the
    top-``k`` entries. Use ``get_sorted_topk`` when the order matters.
    """
    scores = get_logits(y)
    partitioned = np.argpartition(scores, -k, axis=1)
    return partitioned[:, -k:]

def get_sorted_topk(y, k):
    """Return the indices of the ``k`` largest logits along the last axis, as numpy.

    The indices come from ``torch.topk``, so they are ordered from highest to
    lowest score.
    """
    logits_tensor = torch.tensor(get_logits(y))
    _values, top_indices = torch.topk(logits_tensor, k)
    return top_indices.numpy()

def jaccard(a, b):
    """Return the mean row-wise Jaccard similarity between two collections.

    Row ``i`` of ``a`` and row ``i`` of ``b`` are each turned into a set. The
    pair's score is ``|A & B| / |A | B|``, and the per-row scores are averaged.

    Args:
        a, b: sequences with the same number of rows; each row is an iterable of
            hashable items (e.g. the index arrays produced by ``get_topk``).

    Returns:
        The mean similarity. Two empty rows count as identical (similarity
        1.0). If there are no rows at all, ``nan`` is returned without emitting
        numpy's empty-mean warning.

    Raises:
        ValueError: if ``a`` and ``b`` have different numbers of rows. The
            previous ``zip`` silently dropped the extra rows.
    """
    rows_a, rows_b = list(a), list(b)
    if len(rows_a) != len(rows_b):
        raise ValueError(
            f"jaccard: row count mismatch ({len(rows_a)} vs {len(rows_b)})")
    if not rows_a:
        return float('nan')

    scores = []
    for single_a, single_b in zip(rows_a, rows_b):
        set_a, set_b = set(single_a), set(single_b)
        union = set_a | set_b
        # An empty union would divide by zero; two empty sets are identical.
        scores.append(len(set_a & set_b) / len(union) if union else 1.0)
    return np.array(scores).mean()

In [None]:
from tqdm import tqdm
import numpy as np
from icl.utils.load_huggingface_dataset import load_huggingface_dataset_train_and_test
import warnings

def get_true_label(dataset, seed, actual_sample_size):
    """Return the gold labels of a seeded random subsample of the test split.

    Shuffles ``dataset['test']`` with ``seed``, keeps the first
    ``actual_sample_size`` rows, and returns their ``'label'`` column as a
    numpy array.
    """
    shuffled_test = dataset['test'].shuffle(seed=seed)
    subsample = shuffled_test.select(range(actual_sample_size))
    return np.array(subsample['label'])

def calculate_average_scores(seeds, task_name, actual_sample_size=1000,
                             model_name='gpt-j-6b', k=5):
    """Average label loyalty, accuracy, and word loyalty over several seeds.

    For each seed, this draws the seeded test subsample with ``get_true_label``
    and loads the four saved prediction sets ``(y1, y2, y3, y4)`` via
    ``CompressArgs(...).load_result()[0]``. ``y2`` is the reference that the
    others are compared with (presumably the uncompressed baseline — confirm).

    Args:
        seeds: iterable of int seeds. Each one selects both the test subsample
            and the saved result set.
        task_name: task identifier passed to
            ``load_huggingface_dataset_train_and_test`` and ``CompressArgs``.
        actual_sample_size: number of test examples per seed. It is clipped to
            the test-set size, with a warning, when it is larger.
        model_name: model whose saved results are loaded.
        k: number of top logits compared by the Jaccard word-loyalty metric.

    Returns:
        Tuple ``(average_scores, average_accs, average_jaccards)`` of numpy
        arrays, each averaged over seeds:
        - average_scores (3,): label loyalty vs. y2 of
          [Hidden_anchor=y1, Text_anchor=y3, Hidden_random=y4]
        - average_accs (4,): accuracy vs. gold labels of [y2, y1, y3, y4]
        - average_jaccards (3,): top-k Jaccard vs. y2 of [y1, y3, y4]

    Raises:
        ValueError: if ``seeds`` is empty.
    """
    seeds = list(seeds)
    if not seeds:
        raise ValueError("seeds must contain at least one seed")

    dataset = load_huggingface_dataset_train_and_test(task_name)
    test_size = len(dataset['test'])
    if test_size < actual_sample_size:
        # Clip rather than fail so small test sets still run; the warning makes
        # the reduced sample size visible.
        warnings.warn(
            f"sample_size: {actual_sample_size} is larger than test set size: {test_size}, "
            f"actual_sample_size is {test_size}",
            stacklevel=2,
        )
        actual_sample_size = test_size

    scores, accs, jaccards = [], [], []
    for seed in tqdm(seeds):
        true_labels = get_true_label(dataset=dataset, seed=seed,
                                     actual_sample_size=actual_sample_size)
        args = CompressArgs(task_name=task_name, seeds=[seed], model_name=model_name)
        y1, y2, y3, y4 = args.load_result()[0]
        label1, label2, label3, label4 = [get_label(y) for y in (y1, y2, y3, y4)]

        # Label Loyalty: agreement of each variant's predictions with y2's.
        scores.append((
            accuracy_score(label2, label1),  # Hidden_anchor
            accuracy_score(label2, label3),  # Text_anchor
            accuracy_score(label2, label4),  # Hidden_random
        ))

        # Acc.: accuracy against the gold labels, ordered (y2, y1, y3, y4).
        accs.append(tuple(accuracy_score(true_labels, predicted)
                          for predicted in (label2, label1, label3, label4)))

        # Word Loyalty: overlap of top-k logit indices with y2's. Compute the
        # reference top-k once instead of once per comparison.
        reference_topk = get_topk(y2, k)
        jaccards.append(tuple(jaccard(get_topk(y, k), reference_topk)
                              for y in (y1, y3, y4)))

    average_scores = np.mean(scores, axis=0)
    average_accs = np.mean(accs, axis=0)
    average_jaccards = np.mean(jaccards, axis=0)
    return average_scores, average_accs, average_jaccards

# Five consecutive seeds, 42 through 46.
seeds = list(range(42, 47))
# NOTE: despite its name, this holds the full
# (average_scores, average_accs, average_jaccards) tuple.
average_scores = calculate_average_scores(seeds, 'agnews')