In [1]:
import pickle
import warnings
from dataclasses import dataclass, field
from typing import List
from torch.utils.data import DataLoader
import os
from icl.utils.other import dict_to
from transformers.hf_argparser import HfArgumentParser
import torch
import torch.nn.functional as F
from icl.lm_apis.lm_api_base import LMForwardAPI
from icl.utils.data_wrapper import wrap_dataset, tokenize_dataset, wrap_dataset_with_instruct
from icl.utils.load_huggingface_dataset import load_huggingface_dataset_train_and_test
from icl.utils.random_utils import set_seed
from icl.utils.other import load_args, set_gpu, sample_two_set_with_shot_per_class
from transformers import Trainer, TrainingArguments, PreTrainedModel, AutoModelForCausalLM, \
    AutoTokenizer, DataCollatorForLanguageModeling, DataCollatorWithPadding
from icl.utils.load_local import convert_path_old, load_local_model_or_tokenizer, get_model_layer_num
from icl.util_classes.arg_classes import DeepArgs
from icl.util_classes.predictor_classes import Predictor
from icl.utils.prepare_model_and_tokenizer import load_model_and_tokenizer, get_label_id_dict_for_args

In [3]:
# Analysis configuration: GPT2-XL evaluated on TREC.
# NOTE(review): the wrapped model is later placed on 'cuda:0' explicitly; presumably
# set_gpu(args.gpu) maps the device chosen here onto that index — confirm.
args = DeepArgs(
    device='cuda:4',
    model_name='gpt2-xl',
    task_name='trec',
)

In [None]:
# Select the GPU *before* loading the model: device-visibility settings only take
# effect if they are applied before CUDA is initialized, and loading may initialize it.
# set_gpu is called again in the setup cell below; repeating it with the same
# argument is presumably harmless — verify against icl.utils.other.set_gpu.
set_gpu(args.gpu)
model, tokenizer = load_model_and_tokenizer(args)

In [None]:
set_gpu(args.gpu)

# Only sampling analysis examples from the test split is supported.
if args.sample_from != 'test':
    raise NotImplementedError(f"sample_from: {args.sample_from}")
dataset = load_huggingface_dataset_train_and_test(args.task_name)

args.label_id_dict = get_label_id_dict_for_args(args, tokenizer)

# Wrap the raw HF model. NOTE(review): the device is hard-coded to 'cuda:0' while
# args.device names another index; presumably set_gpu remaps visible devices so this
# is the intended card — confirm.
model = LMForwardAPI(
    model=model,
    model_name=args.model_name,
    tokenizer=tokenizer,
    device='cuda:0',
    label_dict=args.label_dict,
)

# Batch size 1 everywhere; remove_unused_columns=False stops Trainer from dropping
# dataset columns that the model's forward signature does not name.
training_args = TrainingArguments(
    "./output_dir",
    remove_unused_columns=False,
    per_device_eval_batch_size=1,
    per_device_train_batch_size=1,
)
num_layer = get_model_layer_num(model=model.model, model_name=args.model_name)


In [6]:
from icl.util_classes.predictor_classes import Predictor

# Predictor configured for this task's label ids and the model's layer count.
predictor = Predictor(
    label_id_dict=args.label_id_dict,
    pad_token_id=tokenizer.pad_token_id,
    task_name=args.task_name,
    tokenizer=tokenizer,
    layer=num_layer,
)

In [7]:
from importlib import reload


from icl.utils.data_wrapper import wrap_dataset, wrap_dataset_with_instruct

from icl.util_classes.context_solver import ContextSolver
from icl.utils.other import TensorStrFinder

# Tokenizer-bound helpers; context_solver is used by prepare_analysis_dataset to
# derive the instruction-only context.
tensor_str_finder = TensorStrFinder(tokenizer=tokenizer)
context_solver = ContextSolver(
    task_name=args.task_name,
    tokenizer=tokenizer,
)



def prepare_analysis_dataset(seed):
    """Sample demonstrations and test examples and build three tokenized analysis datasets.

    Reads the notebook globals ``args``, ``dataset``, ``tokenizer`` and ``context_solver``.
    NOTE: a later cell imports ``prepare_analysis_dataset`` from ``icl.analysis.qkv_getter``,
    which rebinds this name; this definition is only in effect until that import runs.

    Args:
        seed: seed used both for demonstration sampling and for shuffling the test split.

    Returns:
        ``(demo_dataset, empty_demo_dataset, no_demo_dataset)`` — the sampled test examples
        wrapped (1) with the sampled demonstrations, (2) with the context returned by
        ``context_solver.get_empty_demo_context`` for the first demo prompt, and (3) with no
        demonstrations; each one passed through ``tokenize_dataset``.

    Side effects:
        Lowers ``args.actual_sample_size`` to the test-set size (with a warning) when larger.

    Raises:
        NotImplementedError: if ``args.sample_from`` is not ``'test'``.
    """
    # Fail fast, before doing any sampling work.
    if args.sample_from != 'test':
        raise NotImplementedError(f"sample_from: {args.sample_from}")

    demonstration, _ = sample_two_set_with_shot_per_class(dataset['train'],
                                                            args.demonstration_shot,
                                                            0, seed, label_name='label',
                                                            a_total_shot=args.demonstration_total_shot)

    test_size = len(dataset['test'])
    if test_size < args.actual_sample_size:
        # Report the value that was actually compared (actual_sample_size), not
        # args.sample_size, which can differ (e.g. when actual_sample_size is set directly).
        requested_size = args.actual_sample_size
        args.actual_sample_size = test_size
        warnings.warn(
            f"actual_sample_size: {requested_size} is larger than test set size: {test_size}, "
            f"actual_sample_size is {args.actual_sample_size}")
    test_sample = dataset['test'].shuffle(seed=seed).select(range(args.actual_sample_size))

    demo_dataset = tokenize_dataset(
        wrap_dataset(test_sample, demonstration, args.label_dict, args.task_name), tokenizer)

    # Context derived from the first demo prompt, restricted to its demonstration part.
    context = demo_dataset[0]['sentence']
    instruct = context_solver.get_empty_demo_context(context, only_demo_part=True)
    empty_demo_dataset = tokenize_dataset(
        wrap_dataset_with_instruct(test_sample, instruct, args.label_dict, args.task_name), tokenizer)

    no_demo_dataset = tokenize_dataset(
        wrap_dataset(test_sample, [], args.label_dict, args.task_name), tokenizer)

    return demo_dataset, empty_demo_dataset, no_demo_dataset


In [None]:
import numpy as np
from datasets import concatenate_datasets
from datasets.utils.logging import disable_progress_bar
import random

disable_progress_bar()

# Number of demonstrations sampled per class (seed 42) for the contexted key/query extraction.
DEMO_SHOT_PER_CLASS = 64

demonstration, _ = sample_two_set_with_shot_per_class(dataset['train'],
                                                        DEMO_SHOT_PER_CLASS,
                                                        0, 42, label_name='label',
                                                        a_total_shot=args.demonstration_total_shot)

# First test example with its 'text' field blanked. Not used below; kept as an
# alternative query for building contexts without a real test sentence.
empty_test_sample = dataset['test'].select([0])
empty_test_sample = empty_test_sample.map(lambda x: {k: v if k != 'text' else '' for k, v in x.items()})

class_num = len(set(demonstration['label']))
np_labels = np.array(demonstration['label'])
# Row positions in `demonstration` for each class; assumes labels are 0..class_num-1.
ids_for_demonstrations = [np.where(np_labels == class_id)[0] for class_id in range(class_num)]


def _contexted_example(test_index, demo_ids):
    """Wrap test example `test_index` with the rows `demo_ids` of `demonstration` as its demos."""
    return wrap_dataset(dataset['test'].select([test_index]), demonstration.select(demo_ids),
                        args.label_dict, args.task_name)


# Set 1: for i in 0..(largest class size - 1), the i-th demonstration of every class that
# has one, with test example i as the query.
demonstrations_contexted = concatenate_datasets([
    _contexted_example(i, [class_ids[i] for class_ids in ids_for_demonstrations if i < len(class_ids)])
    for i in range(max(map(len, ids_for_demonstrations)))
])
demonstrations_contexted = tokenize_dataset(demonstrations_contexted, tokenizer=tokenizer)

# Set 2: every test example i, with one demonstration per class at position
# i % DEMO_SHOT_PER_CLASS. The extra `% len(class_ids)` keeps this valid when a class
# ended up with fewer than DEMO_SHOT_PER_CLASS demonstrations (previously an IndexError);
# it is a no-op when every class has the full count.
demonstrations_contexted2 = concatenate_datasets([
    _contexted_example(i, [class_ids[(i % DEMO_SHOT_PER_CLASS) % len(class_ids)]
                           for class_ids in ids_for_demonstrations])
    for i in range(len(dataset['test']))
])
demonstrations_contexted2 = tokenize_dataset(demonstrations_contexted2, tokenizer=tokenizer)

In [34]:
from icl.analysis.qkv_getter import QKVGetterManger, cal_results, prepare_analysis_dataset

In [35]:
# If this cell ran before, undo the previous manager's registration before creating a
# new one. On a fresh kernel the name does not exist yet — the only case to skip.
# (The former bare `except: pass` also hid genuine unregister failures and even
# KeyboardInterrupt.)
if 'qkvgettermanager' in globals():
    qkvgettermanager.unregister()
qkvgettermanager = QKVGetterManger(model=model, predictor=predictor)

In [None]:
from functools import partial


from icl.utils.data_wrapper import remove_str_columns

# Forward-pass options for the wrapped LM: return hidden states and attentions.
model.results_args = {'output_hidden_states': True,
                      'output_attentions': True, 'use_cache': True}
model.probs_from_results_fn = None

data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer, pad_to_multiple_of=1, max_length=1024)
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)

# Predict on both contexted datasets after dropping string columns.
# _1 <- demonstrations_contexted, _2 <- demonstrations_contexted2 (names kept for later cells;
# note they coincide with IPython's output-history names).
prediction_outputs = []
for data in (demonstrations_contexted, demonstrations_contexted2):
    data = remove_str_columns(data)
    prediction_outputs.append(trainer.predict(data, ignore_keys=['results']))
_1, _2 = prediction_outputs

In [12]:
def get_keys(y, layer, head=None, qkv_id=0):
    """Collect one array per entry of ``y.predictions[-1][layer][qkv_id]``.

    Each entry is indexed as ``entry[:, head, 0, :]`` when ``head`` is given; otherwise it is
    flattened to ``(entry.shape[0], -1)``. Presumably qkv_id 0/1 select queries/keys (as the
    callers use it) — confirm against the QKV getter.
    """
    per_entry = y.predictions[-1][layer][qkv_id]
    if head is None:
        return [entry.reshape(entry.shape[0], -1) for entry in per_entry]
    return [entry[:, head, 0, :] for entry in per_entry]

In [41]:
from icl.utils.visualization import _plot_confusion_matrix

In [43]:
from sklearn.decomposition import PCA

In [44]:
# One position per class: select_indices[c] chooses which of class c's demonstrations
# (within ids_for_demonstrations[c]) to use. Sized from class_num instead of a hard-coded
# 6 (the TREC class count), so the cell also works for tasks with other class counts.
select_indices = [0] * class_num

In [None]:
# Settings for the cal_results run below: reset model.probs_from_results_fn and request
# 1000 test samples (presumably clamped to the test-set size downstream — confirm).
model.probs_from_results_fn = None
args.actual_sample_size = 1000
def select_demonstrations_from_indices(demonstrations, ids_for_demonstrations,indices):
    """Pick one demonstration per class and return them via ``demonstrations.select``.

    ``indices[c]`` is a position inside ``ids_for_demonstrations[c]`` (the row ids of class
    ``c``); only the first ``len(indices)`` classes contribute, in class order.
    """
    chosen_rows = []
    for class_pos, within_class in enumerate(indices):
        chosen_rows.append(ids_for_demonstrations[class_pos][within_class])
    return demonstrations.select(chosen_rows)
demonstration_part = select_demonstrations_from_indices(
    demonstration, ids_for_demonstrations, select_indices)
# `demonstraions` (sic) is the keyword spelling this call has always used with cal_results.
y = cal_results(
    demonstraions=demonstration_part,
    model=model,
    tokenizer=tokenizer,
    training_args=training_args,
    args=args,
    seed=args.seeds[0],
    dataset=dataset,
)

In [46]:
from sklearn.metrics import roc_auc_score


def cal_roc_auc(probs,labels):
    """Pairwise one-vs-one ROC-AUC matrix between classes.

    For every ordered pair ``(a, b)`` with ``a != b``, keeps the samples whose true label is
    ``a`` or ``b``, scores them by ``probs[:, a] / probs[:, b]`` and computes ROC-AUC with
    label ``a`` as the positive class. The diagonal is set to 1.

    Args:
        probs: per-sample class scores, indexable as ``probs[sample, class]``;
            lists or arrays are accepted.
        labels: integer class id per sample (list or array). Class ids are assumed to be
            ``0 .. N-1``, where ``N`` is the number of distinct labels.

    Returns:
        ``(N, N)`` float ndarray rounded to 2 decimals.
    """
    # Convert up front: the boolean masking below needs ndarrays (``list == int`` is a
    # single bool, not an element-wise mask).
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    n_classes = len(np.unique(labels))
    auc_matrix = np.ones((n_classes, n_classes))  # diagonal stays 1
    for class_a in range(n_classes):
        for class_b in range(n_classes):
            if class_a == class_b:
                continue
            mask = (labels == class_a) | (labels == class_b)
            pair_probs = probs[mask]
            auc_matrix[class_a, class_b] = roc_auc_score(
                labels[mask] == class_a,
                pair_probs[:, class_a] / pair_probs[:, class_b])
    return np.round(auc_matrix, decimals=2)

In [None]:
# y[0] is used as the label array and y[1].predictions[0] as the per-class probabilities.
labels = y[0]
probs = y[1].predictions[0]
_plot_confusion_matrix(cal_roc_auc(probs, labels),
                       classes=args.label_dict.values(),
                       title='ROC-AUC')

In [58]:
# Inspect layer 32 with all heads flattened together (head=None).
layer = 32
head = None
select_test = list(range(500))  # first 500 query rows

# qkv_id=1 from the demonstrations_contexted run (last entry dropped) ...
keys = get_keys(_1, layer, head, 1)[:-1]
# ... and qkv_id=0 from the last entry of the demonstrations_contexted2 run.
querys = get_keys(_2, layer, head, 0)[-1][select_test]

# Fit PCA on the queries, then express every key in the query principal directions,
# scaled by the corresponding singular values. The keys are not mean-centred here
# (pca.transform would subtract pca.mean_).
pca = PCA(n_components=10, random_state=42, whiten=False)
svd_querys = pca.fit(querys)  # fit() returns the estimator itself, so this aliases `pca`
pca_keys = [(key @ pca.components_.T) * pca.singular_values_.reshape(1, -1) for key in keys]

In [None]:
from scipy.spatial.distance import pdist, squareform

# Pairwise Euclidean distances between the selected projected key of each entry of pca_keys.
distances = squareform(pdist(
    np.vstack([pca_key[select_indices[pos]] for pos, pca_key in enumerate(pca_keys)]),
    metric='euclidean',
))
print(distances)
# Normalize by the largest distance, then round for display.
distances = np.round(distances / distances.max(), decimals=2)
# Replace the zero self-distances on the diagonal with 1 before plotting.
np.fill_diagonal(distances, 1)
_plot_confusion_matrix(distances, args.label_dict.values(), title='Distance Matrix', cmap="YlGnBu", fontsize=16)