In [2]:
%load_ext autoreload
%autoreload 2
import os
import copy
import numpy as np
import json
import argparse
import random
import scipy
import config
from GPT import GPT
from LLAMA import LLAMA
from StimulusModel import LMFeatures
from utils_stim import get_story_wordseqs
# from utils_resp import get_resp
from utils_ridge.ridge import ridge, bootstrap_ridge, ridge_corr
from utils_ridge.ridge_torch import ridge_torch, bootstrap_ridge_torch, ridge_corr_torch
from utils_ridge.stimulus_utils import TRFile, load_textgrids, load_simulated_trfiles
from utils_ridge.dsutils import make_word_ds
from utils_ridge.interpdata import lanczosinterp2D, lanczosinterp2D_torch
from utils_ridge.util import make_delayed
from utils_ridge.utils import mult_diag, counter
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, pipeline
import utils_llama.activation as ana

import scipy
import math
import matplotlib.pyplot as plt

import time
import h5py
import pickle

import datasets

from collections import Counter

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
torch.cuda.empty_cache()
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f3c491eda30>

In [12]:
# torch.cuda.memory._record_memory_history()
class ARGS:
    def __init__(self):
        self.subject = 'S1'
        self.gpt = 'perceived'
        self.sessions = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 20]
        self.layer = 17
        self.layer2 = 18
        self.act_name = 'ffn_gate'
        self.window = 15
        self.chunk = 4

args = ARGS()

# # training stories
# stories = []
# with open(os.path.join(config.DATA_TRAIN_DIR, "sess_to_story.json"), "r") as f:
#     sess_to_story = json.load(f) 
# for sess in args.sessions:
#     stories.extend(sess_to_story[str(sess)])

# stories = stories[:10]


In [13]:
model_dir = '/ossfs/workspace/nas/gzhch/data/models/Llama-2-7b-hf'
# model = AutoModelForCausalLM.from_pretrained(
#     model_dir, 
#     device_map='auto',
#     torch_dtype=torch.float16,
# ).eval()

model = None

tokenizer = AutoTokenizer.from_pretrained(model_dir)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'

In [14]:
## load cached llm act if possible
cache_dir = '/ossfs/workspace/nas/gzhch/data/cache'
llama = LLAMA(model, tokenizer, cache_dir)

In [8]:
def load_data(task_name, n_shot=1, seed=42):
    data_dirs = {
        'xsum' : '/ossfs/workspace/nas/gzhch/data/datasets/xsum',
        'gsm8k' : '/ossfs/workspace/nas/gzhch/data/datasets/gsm8k',
        'alpaca' : '/ossfs/workspace/nas/gzhch/data/datasets/alpaca',
        'wmt' : '/ossfs/workspace/nas/gzhch/data/datasets/wmt14_de-en_test',
        'wikitext2' : '/ossfs/workspace/nas/gzhch/data/datasets/wikitext-2-v1'
    }
    if task_name == 'gsm8k':
        dataset = datasets.load_dataset(data_dirs[task_name])
    elif task_name == 'wikitext2':
        dataset = datasets.load_from_disk(data_dirs[task_name])
        dataset = dataset['train'].filter(lambda x: len(x['text'])>100) 
        dataset = dataset.select(random.sample(range(len(dataset)), 1000))

    return dataset


In [9]:
wiki_data = load_data('wikitext2')

In [8]:
def get_neuron_activation_and_loss(model, input):
    result = ana.custom_forward(model, input['input_ids'].cuda(), inspect_acts=['ffn_gate'])
    logits = result['logits']
    labels = input['input_ids']
    input_ids = input['input_ids'][:, :-1]

    # calculate loss
    shift_logits = logits[..., :-1, :].contiguous().view(-1, 32000)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    loss_fct = torch.nn.CrossEntropyLoss(reduce=False)
    loss = loss_fct(shift_logits, shift_labels).view(labels.shape[0], -1)

    b = 5
    mask = input['attention_mask'][:, :-1] == 1
    loss = loss * mask + -100 * (~mask)
    input_ids = input_ids * mask + -100 * (~mask)
    expanded_loss = torch.cat([torch.ones(loss.shape[0], b) * -100, loss, torch.ones(loss.shape[0], b) * -100], dim=1)
    expanded_input_ids = torch.cat([torch.ones(input_ids.shape[0], b) * -100, input_ids, torch.ones(input_ids.shape[0], b) * -100], dim=1).int()

    # signal delay
    losses = []
    context = []
    for offset in range(2 * b):
        losses.append(expanded_loss[:, offset: offset + loss.shape[1]])
        context.append(expanded_input_ids[:, offset: offset + loss.shape[1]])
    losses = torch.stack(losses).transpose(0,1).transpose(2,1)
    context = torch.stack(context).transpose(0,1).transpose(2,1)

    ## remove padding tokens
    losses = losses.view(-1, 2 * b)[mask.flatten()]
    context = context.view(-1, 2 * b)[mask.flatten()]

    ffn_gate_all_layer = torch.stack(result['ffn_gate'])[:, :, :-1, :]
    l, bs, seq_len, h = ffn_gate_all_layer.shape
    ffn_gate_all_layer = ffn_gate_all_layer.reshape(l, bs * seq_len, h).transpose(0, 1)
    ffn_gate_all_layer = ffn_gate_all_layer[mask.flatten()]

    res = dict(context=context, loss=losses, ffn_gate=ffn_gate_all_layer)

    return res

In [144]:
batch_size = 16
total_batch = len(wiki_data) // batch_size
max_batch = 5
acts = []
for k in range(min(total_batch, max_batch)):
    input = tokenizer(wiki_data['text'][k * batch_size: (k + 1) * batch_size], return_tensors='pt', padding='longest')
    acts.append(get_neuron_activation_and_loss(model, input))

context = torch.cat([i['context'] for i in acts], dim=0).numpy()
loss = torch.cat([i['loss'] for i in acts], dim=0).numpy()
ffn_gate = torch.cat([i['ffn_gate'] for i in acts], dim=0).numpy()

cache_subdir = os.path.join(cache_dir, 'wiki', 'bs_16_0-5')
if not os.path.exists(cache_subdir):
    os.makedirs(cache_subdir, exist_ok=True)
with open(os.path.join(cache_subdir, 'context.pickle'), 'wb') as f:
    pickle.dump(context, f)
with open(os.path.join(cache_subdir, 'loss.pickle'), 'wb') as f:
    pickle.dump(loss, f)
for layer in range(ffn_gate.shape[1]):
    with open(os.path.join(cache_subdir, f'ffn_gate_{layer}.pickle'), 'wb') as f:
        pickle.dump(ffn_gate[:, layer, :], f)

In [15]:
actss = []
for s in [0, 5, 10, 15, 20, 25]:
    acts = llama.get_act(wiki_data, 
                        cache_name = 'wiki', 
                        layers = [10, 15],
                        acts = {},
                        start_batch=s, 
                        end_batch=s+5)
    actss.append(acts)
# print(acts.keys())
stim_data = {}
for k in actss[0].keys():
    stim_data[k] = np.concatenate([i[k] for i in actss])
del actss

In [16]:
args2 = copy.deepcopy(args)

args.layer = 10
args2.layer = 15

n_train=4000
alphas='adaptive'

stim = stim_data['layer_10']
resp = stim_data['layer_15']


n_total = stim.shape[0]
# ids = random.sample(range(n_total), n_train + n_test)


tokens = stim_data['context'][:, 5]
unique_tokens = []
unique_token_ids = []
for idx in range(len(tokens)):
    if tokens[idx] not in unique_tokens:
        unique_tokens.append(tokens[idx])
        unique_token_ids.append(idx)
random.shuffle(unique_token_ids)
ids = unique_token_ids

stim = torch.tensor(stim[ids]).cuda().float()
resp = torch.tensor(resp[ids]).cuda().float()
tstim, hstim = stim[:n_train], stim[n_train:]
tresp, hresp = resp[:n_train], resp[n_train:]

if alphas is None:
    alphas = torch.tensor([1 for _ in range(resp.shape[-1])]).cuda()

elif alphas == 'adaptive':
    nchunks = int(np.ceil(tresp.shape[0] / 5 / 100))
    weights, alphas, bscorrs = bootstrap_ridge_torch(tstim, tresp, use_corr = False, alphas = np.logspace(0, 3, 10),
                nboots = 3, chunklen = 100, nchunks = nchunks)        

bs_weights = ridge_torch(tstim, tresp, alphas)
bs_weights = bs_weights.to(hstim.device).to(hstim.dtype)
pred = hstim.matmul(bs_weights)
pred = pred.cpu()
hresp = hresp.cpu()

2024-01-27 23:21:30,057 - ridge_corr - INFO - Selecting held-out test set..
2024-01-27 23:21:30,121 - ridge_corr - INFO - Doing SVD...
2024-01-27 23:21:31,361 - ridge_corr - INFO - Dropped 0 tiny singular values.. (U is now torch.Size([3200, 3200]))
2024-01-27 23:21:31,362 - ridge_corr - INFO - Training stimulus has Frobenius norm: 588.050
2024-01-27 23:21:31,436 - ridge_corr - INFO - Training: alpha=1.000, mean corr=0.23708, max corr=0.88424, over-under(0.20)=6498
2024-01-27 23:21:31,440 - ridge_corr - INFO - Training: alpha=2.154, mean corr=0.32205, max corr=0.90070, over-under(0.20)=9008
2024-01-27 23:21:31,444 - ridge_corr - INFO - Training: alpha=4.642, mean corr=0.43211, max corr=0.91847, over-under(0.20)=10701
2024-01-27 23:21:31,448 - ridge_corr - INFO - Training: alpha=10.000, mean corr=0.48078, max corr=0.92184, over-under(0.20)=10954
2024-01-27 23:21:31,452 - ridge_corr - INFO - Training: alpha=21.544, mean corr=0.44349, max corr=0.90681, over-under(0.20)=10945
2024-01-27 23

In [19]:
neuron_pearson, neuron_p = [], []
for i in range(pred.shape[1]):
    stat = scipy.stats.pearsonr(pred[:, i].flatten(), hresp[:, i].flatten())
    neuron_pearson.append(stat.statistic)
    neuron_p.append(stat.pvalue)

neuron_pearson = torch.tensor(neuron_pearson)
neuron_std = (pred-hresp).std(dim=0)
# loss = torch.tensor(stim_data['loss'][ids[n_train:]])

In [35]:
ids_pearson = neuron_pearson.topk(len(neuron_pearson)).indices.tolist()
ids_std = neuron_std.topk(len(neuron_pearson)).indices.tolist()

In [39]:
n = 1000
len(set(ids_pearson[-n:] + ids_std[-n:]))

1742

In [273]:
token_pearson, token_p = [], []
for i in range(pred.shape[0]):
    stat = scipy.stats.pearsonr(pred[i, :].flatten(), hresp[i, :].flatten())
    token_pearson.append(stat.statistic)
    token_p.append(stat.pvalue)

token_pearson = torch.tensor(token_pearson)
token_std = (pred-hresp).std(dim=1)
loss = torch.tensor(stim_data['loss'][ids[n_train:]])

0 PearsonRResult(statistic=0.5115818286841631, pvalue=0.0)
1 PearsonRResult(statistic=0.5115818286841631, pvalue=0.0)
2 PearsonRResult(statistic=0.5115818286841631, pvalue=0.0)
3 PearsonRResult(statistic=0.5115818286841631, pvalue=0.0)
4 PearsonRResult(statistic=0.5115818286841631, pvalue=0.0)
5 PearsonRResult(statistic=0.5115818286841631, pvalue=0.0)
6 PearsonRResult(statistic=0.5115818286841631, pvalue=0.0)
7 PearsonRResult(statistic=0.5115818286841631, pvalue=0.0)
8 PearsonRResult(statistic=0.5115818286841631, pvalue=0.0)
9 PearsonRResult(statistic=0.5115818286841631, pvalue=0.0)


In [285]:
hresp.max(dim=1).values.shape

torch.Size([4956])

In [51]:
k = 500

pos = -1
bias = 5
for pos in range(-5, 5):
# for k in [100, 200, 300, 400, 500, 1000]:
    # sorted_index = loss[:, pos + bias].topk(len(loss)).indices
    
    test_tokens = stim_data['context'][:, pos + bias][ids[n_train:]]
    freq = torch.tensor([counter[i] for i in test_tokens])
    sorted_index = freq.topk(len(loss)).indices

    current_pearson = token_pearson[sorted_index]
    current_std = token_std[sorted_index]
    print(pos, k)
    print(current_pearson[-k:].mean(), current_pearson[-k:].std())
    print(current_pearson[:k].mean(), current_pearson[:k].std())
    print(current_std[-k:].mean(), current_std[-k:].std())
    print(current_std[:k].mean(), current_std[:k].std())

-5 500
tensor(0.6548, dtype=torch.float64) tensor(0.0770, dtype=torch.float64)
tensor(0.6630, dtype=torch.float64) tensor(0.0756, dtype=torch.float64)
tensor(0.1743) tensor(0.0201)
tensor(0.1762) tensor(0.0217)
-4 500
tensor(0.6511, dtype=torch.float64) tensor(0.0784, dtype=torch.float64)
tensor(0.6628, dtype=torch.float64) tensor(0.0724, dtype=torch.float64)
tensor(0.1764) tensor(0.0198)
tensor(0.1747) tensor(0.0207)
-3 500
tensor(0.6570, dtype=torch.float64) tensor(0.0769, dtype=torch.float64)
tensor(0.6580, dtype=torch.float64) tensor(0.0777, dtype=torch.float64)
tensor(0.1736) tensor(0.0207)
tensor(0.1748) tensor(0.0220)
-2 500
tensor(0.6663, dtype=torch.float64) tensor(0.0787, dtype=torch.float64)
tensor(0.6623, dtype=torch.float64) tensor(0.0776, dtype=torch.float64)
tensor(0.1716) tensor(0.0241)
tensor(0.1749) tensor(0.0220)
-1 500
tensor(0.6784, dtype=torch.float64) tensor(0.0846, dtype=torch.float64)
tensor(0.6575, dtype=torch.float64) tensor(0.0670, dtype=torch.float64)
tenso

In [55]:
generate = pipeline('text-generation', model=model, tokenizer=tokenizer)

Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers
pip install xformers.


In [164]:
batch_size = 16
num_of_batch = 10
n_shot = 1
seed = 42
dataset = datasets.load_from_disk('/ossfs/workspace/nas/gzhch/data/datasets/alpaca')
dataset = dataset.filter(lambda x: x['input'] == '')
data = dataset['train'].shuffle(seed=seed).select(range(1000))
shots = dataset['train'].shuffle(seed=seed).select(range(1000, 1000 + n_shot))
prompt_shots = ''
template = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n{output}'

for i in range(n_shot):
    prompt = template.format(instruction=shots[i]['instruction'], output=shots[i]['output']) + '\n\n'
    prompt_shots += prompt

def process_input(x):
    x['text'] = prompt_shots + template.format(instruction=x['instruction'], output='')
    x['ground_truth'] = x['output']
    return x
dataset = data.map(process_input, load_from_cache_file=False)

generated_text = generate(dataset[:batch_size * num_of_batch]['text'], return_full_text=False, max_length=300)
llm_generated_response = [i[0]['generated_text'].split('\n\nBelow is an instruction')[0] for i in generated_text]

In [166]:
llm_data = dict(text=[])
real_data = dict(text=[])
for i in range(batch_size * num_of_batch):
    llm_data['text'].append(dataset[i]['text'] + llm_generated_response[i])
    real_data['text'].append(dataset[i]['text'] + dataset[i]['output'])

with open('/ossfs/workspace/nas/gzhch/data/cache/generated_text/llm_data.pkl', 'wb') as f:
    pickle.dump(llm_data, f)
with open('/ossfs/workspace/nas/gzhch/data/cache/generated_text/real_data.pkl', 'wb') as f:
    pickle.dump(real_data, f)

In [173]:
acts_llm = llama.get_act(llm_data, 
                    cache_name = 'alpaca_llm', 
                    layers = [10, 20],
                    acts = {},
                    start_batch=0, 
                    end_batch=10)

acts_real = llama.get_act(real_data, 
                    cache_name = 'alpaca_real', 
                    layers = [10, 20],
                    acts = {},
                    start_batch=0, 
                    end_batch=10)

In [None]:
llm_data = dict(text=[])
real_data = dict(text=[])
for i in range(batch_size * num_of_batch):
    llm_data['text'].append(dataset[i]['text'] + llm_generated_response[i])
    real_data['text'].append(dataset[i]['text'] + dataset[i]['output'])

In [231]:
def get_response_mask(full_text, context_text, batch_size=16, num_of_batch=10):
    response_masks = []
    for k in range(num_of_batch):
        full_ids = tokenizer(full_text[k * batch_size: (k + 1) * batch_size], return_tensors='pt', padding='longest')      
        context_ids = tokenizer(context_text[k * batch_size: (k + 1) * batch_size], return_tensors='pt', padding='longest')
        # response_ids = tokenizer(llm_generated_response[k * batch_size: (k + 1) * batch_size], return_tensors='pt', padding='longest')
        response_mask = full_ids['attention_mask'].clone()
        response_mask[:, :context_ids['attention_mask'].shape[1]] -= context_ids['attention_mask']
        mask = full_ids['attention_mask'][:, :-1] == 1
        flattened_response_mask = response_mask[:, :-1].reshape(-1)[mask.flatten()]
        response_masks.append(flattened_response_mask)
    return torch.cat([i for i in response_masks]) == 1

response_mask_llm = get_response_mask(llm_data['text'], dataset['text'])
response_mask_real = get_response_mask(real_data['text'], dataset['text'])


In [305]:
# counter = Counter(stim_data['context'][:, 5].tolist())
for p in range(10):
    print('pearson', p, scipy.stats.pearsonr(token_pearson, loss[:, p]))
for p in range(10):
    print('std', p, scipy.stats.pearsonr(token_std, loss[:, p]))
    # print(p, scipy.stats.pearsonr(token_std, hresp.min(dim=1).values))

pearson 0 PearsonRResult(statistic=0.06551403563608334, pvalue=9.347085065626561e-12)
pearson 1 PearsonRResult(statistic=0.07697244672208754, pvalue=1.1317568157066025e-15)
pearson 2 PearsonRResult(statistic=0.10583890574131714, pvalue=2.712619446929961e-28)
pearson 3 PearsonRResult(statistic=0.12318941915444186, pvalue=8.372109098943815e-38)
pearson 4 PearsonRResult(statistic=0.08106732100976202, pvalue=3.194064800904176e-17)
pearson 5 PearsonRResult(statistic=0.010039762544111585, pvalue=0.29673461289389774)
pearson 6 PearsonRResult(statistic=0.15831647331992302, pvalue=1.392524427804792e-61)
pearson 7 PearsonRResult(statistic=0.060394023181881425, pvalue=3.331675162588603e-10)
pearson 8 PearsonRResult(statistic=0.045558995154784995, pvalue=2.16458154384645e-06)
pearson 9 PearsonRResult(statistic=0.04169222157316585, pvalue=1.4578105040658204e-05)
std 0 PearsonRResult(statistic=-0.04771486832816875, pvalue=6.982671457270499e-07)
std 1 PearsonRResult(statistic=-0.03389924955492597, pv

In [303]:
# counter = Counter(stim_data['context'][:, 5].tolist())
for p in range(10):
    print('pearson', p, scipy.stats.pearsonr(token_pearson, loss[:, p]))
for p in range(10):
    print('std', p, scipy.stats.pearsonr(token_std, loss[:, p]))
    # print(p, scipy.stats.pearsonr(token_std, hresp.min(dim=1).values))

pearson 0 PearsonRResult(statistic=0.0590567853392971, pvalue=2.091822601007426e-06)
pearson 1 PearsonRResult(statistic=0.07858077007950853, pvalue=2.665693406918057e-10)
pearson 2 PearsonRResult(statistic=0.09584027033477563, pvalue=1.2543318046981006e-14)
pearson 3 PearsonRResult(statistic=0.13248945004205506, pvalue=1.2437206361811875e-26)
pearson 4 PearsonRResult(statistic=0.059881156989395896, pvalue=1.5041071795496062e-06)
pearson 5 PearsonRResult(statistic=-0.1579611303390107, pvalue=2.778968534871959e-37)
pearson 6 PearsonRResult(statistic=0.22492036164890833, pvalue=1.029773446338501e-74)
pearson 7 PearsonRResult(statistic=0.09576699493834981, pvalue=1.3138911626696552e-14)
pearson 8 PearsonRResult(statistic=0.06696570288176965, pvalue=7.405045386885828e-08)
pearson 9 PearsonRResult(statistic=0.06789642796891635, pvalue=4.86936331867933e-08)
std 0 PearsonRResult(statistic=-0.013134732745282256, pvalue=0.29174227643658135)
std 1 PearsonRResult(statistic=-0.026369399877877158, p

In [293]:
acts = acts_llm
response_mask = response_mask_llm

hstim = acts['layer_10'][response_mask]
hresp = acts['layer_20'][response_mask]
context = acts['context'][response_mask]
loss = acts['loss'][response_mask]

hstim = torch.tensor(hstim).cuda().float()
hresp = torch.tensor(hresp).cuda().float()

pred = hstim.matmul(bs_weights)
pred = pred.cpu()
hresp = hresp.cpu()

# scipy.stats.pearsonr(pred.flatten(), hresp.flatten())

token_pearson, token_p = [], []
for i in range(pred.shape[0]):
    stat = scipy.stats.pearsonr(pred[i, :].flatten(), hresp[i, :].flatten())
    token_pearson.append(stat.statistic)
    token_p.append(stat.pvalue)

token_pearson = torch.tensor(token_pearson)
token_std = (pred-hresp).std(dim=1)
loss = torch.tensor(loss)

# k = 500

pos = -1
bias = 5
# for pos in range(-5, 5):
for k in [100, 200, 300, 400, 500, 1000]:

    sorted_index = loss[:, pos + bias].topk(len(loss)).indices
    current_pearson = token_pearson[sorted_index]
    current_std = token_std[sorted_index]
    print(pos, k)
    print(current_pearson[-k:].mean(), current_pearson[-k:].std())
    print(current_pearson[:k].mean(), current_pearson[:k].std())
    print(current_std[-k:].mean(), current_std[-k:].std())
    print(current_std[:k].mean(), current_std[:k].std())


-1 100
tensor(0.5910, dtype=torch.float64) tensor(0.0628, dtype=torch.float64)
tensor(0.5251, dtype=torch.float64) tensor(0.1066, dtype=torch.float64)
tensor(0.1857) tensor(0.0257)
tensor(0.2000) tensor(0.0223)
-1 200
tensor(0.5782, dtype=torch.float64) tensor(0.0684, dtype=torch.float64)
tensor(0.5280, dtype=torch.float64) tensor(0.0968, dtype=torch.float64)
tensor(0.1862) tensor(0.0277)
tensor(0.2024) tensor(0.0214)
-1 300
tensor(0.5743, dtype=torch.float64) tensor(0.0704, dtype=torch.float64)
tensor(0.5342, dtype=torch.float64) tensor(0.0945, dtype=torch.float64)
tensor(0.1871) tensor(0.0258)
tensor(0.2027) tensor(0.0210)
-1 400
tensor(0.5689, dtype=torch.float64) tensor(0.0768, dtype=torch.float64)
tensor(0.5380, dtype=torch.float64) tensor(0.0913, dtype=torch.float64)
tensor(0.1877) tensor(0.0260)
tensor(0.2026) tensor(0.0213)
-1 500
tensor(0.5647, dtype=torch.float64) tensor(0.0804, dtype=torch.float64)
tensor(0.5385, dtype=torch.float64) tensor(0.0910, dtype=torch.float64)
tenso

In [304]:
acts = acts_real
response_mask = response_mask_real

hstim = acts['layer_10'][response_mask]
hresp = acts['layer_20'][response_mask]
context = acts['context'][response_mask]
loss = acts['loss'][response_mask]

hstim = torch.tensor(hstim).cuda().float()
hresp = torch.tensor(hresp).cuda().float()

pred = hstim.matmul(bs_weights)
pred = pred.cpu()
hresp = hresp.cpu()

# scipy.stats.pearsonr(pred.flatten(), hresp.flatten())

token_pearson, token_p = [], []
for i in range(pred.shape[0]):
    stat = scipy.stats.pearsonr(pred[i, :].flatten(), hresp[i, :].flatten())
    token_pearson.append(stat.statistic)
    token_p.append(stat.pvalue)

token_pearson = torch.tensor(token_pearson)
token_std = (pred-hresp).std(dim=1)
loss = torch.tensor(loss)

# k = 500

pos = -1
bias = 5
# for pos in range(-5, 5):
for k in [100, 200, 300, 400, 500, 1000]:

    sorted_index = loss[:, pos + bias].topk(len(loss)).indices
    current_pearson = token_pearson[sorted_index]
    current_std = token_std[sorted_index]
    print(pos)
    print(current_pearson[-k:].mean(), current_pearson[-k:].std())
    print(current_pearson[:k].mean(), current_pearson[:k].std())
    print(current_std[-k:].mean(), current_std[-k:].std())
    print(current_std[:k].mean(), current_std[:k].std())


-1
tensor(0.5944, dtype=torch.float64) tensor(0.0613, dtype=torch.float64)
tensor(0.5626, dtype=torch.float64) tensor(0.0649, dtype=torch.float64)
tensor(0.1848) tensor(0.0224)
tensor(0.1970) tensor(0.0212)
-1
tensor(0.5845, dtype=torch.float64) tensor(0.0683, dtype=torch.float64)
tensor(0.5634, dtype=torch.float64) tensor(0.0660, dtype=torch.float64)
tensor(0.1876) tensor(0.0253)
tensor(0.1978) tensor(0.0195)
-1
tensor(0.5828, dtype=torch.float64) tensor(0.0695, dtype=torch.float64)
tensor(0.5587, dtype=torch.float64) tensor(0.0688, dtype=torch.float64)
tensor(0.1883) tensor(0.0245)
tensor(0.1989) tensor(0.0190)
-1
tensor(0.5760, dtype=torch.float64) tensor(0.0742, dtype=torch.float64)
tensor(0.5548, dtype=torch.float64) tensor(0.0718, dtype=torch.float64)
tensor(0.1897) tensor(0.0259)
tensor(0.1997) tensor(0.0189)
-1
tensor(0.5724, dtype=torch.float64) tensor(0.0761, dtype=torch.float64)
tensor(0.5526, dtype=torch.float64) tensor(0.0747, dtype=torch.float64)
tensor(0.1904) tensor(0.0

In [264]:
token_pearson.shape

torch.Size([10804])

In [266]:
scipy.stats.pearsonr(token_pearson, loss[:, 4])

PearsonRResult(statistic=0.07745112391114177, pvalue=7.528566521747938e-16)

In [271]:
for p in range(10):
    print(scipy.stats.pearsonr(token_std, loss[:, p]))

PearsonRResult(statistic=-0.045255122939267485, pvalue=2.5288748735423885e-06)
PearsonRResult(statistic=-0.0347857845087261, pvalue=0.0002987399432932972)
PearsonRResult(statistic=-0.04060932964580527, pvalue=2.4184340390061123e-05)
PearsonRResult(statistic=-0.018479285124304967, pvalue=0.054766891419967156)
PearsonRResult(statistic=0.025392642163795588, pvalue=0.00830309726029765)
PearsonRResult(statistic=0.1754278043680453, pvalue=2.0861027495329813e-75)
PearsonRResult(statistic=0.04434820326008646, pvalue=3.99986801400221e-06)
PearsonRResult(statistic=0.1365000865910419, pvalue=4.27908452819686e-46)
PearsonRResult(statistic=0.11610285766890506, pvalue=9.6311219996744e-34)
PearsonRResult(statistic=0.09559757894579464, pvalue=2.3200777451621803e-23)
