In [4]:
import torch
from tqdm import tqdm
from torch.nn import CrossEntropyLoss
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

##adapted from https://github.com/timoschick/self-debiasing/blob/main/perplexity.py
def compute_loss(model, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
        outputs = model(input_ids, labels=labels)
        lm_logits = outputs[1]

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return loss

def compute_ppl(model, tokenizer, batch_size=4, max_length=1024, stride=512):
    device = model.device
    model.eval()

    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")

    batch_len = len(encodings.input_ids[0]) // batch_size

    # Trim the tensor to fit complete sequences only
    input_tensor = encodings.input_ids[0][:batch_len * batch_size]

    # Reshape into batches
    batched_inputs = input_tensor.view(batch_size, -1)

    lls = []
    ppl = None
    for i in tqdm(range(0, batched_inputs.size(1), stride)):
    # iterates through all 287644 tokens in wikitext test in windows of max_length tokens, in steps of size stride, 
    # with the first max_length - stride tokens masked, as they are context for each new stride tokens and have alread been counted in loss
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, batched_inputs.size(1))
        trg_len = end_loc - i  # stride, but may be different from stride on last loop

        input_ids = batched_inputs[:, begin_loc:end_loc].to(device)

        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            loss = compute_loss(model, input_ids, labels=target_ids)

            log_likelihood = loss * trg_len

        lls.append(log_likelihood)
        ppl = torch.exp(torch.stack(lls).sum() / end_loc)
        print(f"Perplexity after {i} tokens: {ppl}")

    print(f"Final perplexity: {ppl}")
    return lls, ppl

def load_model(model_path, device, center_weights=True):
###    model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=bnb_config, device_map=device_map, token="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")#, torch_dtype=torch.float16
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, token=HF_TOKEN).to(device)#, torch_dtype=torch.float16
    if center_weights:
        for name, param in model.named_parameters():
            if '.'.join(name.split('.')[-2:]) in ['wte.weight','wpe.weight','c_proj.weight','c_proj.bias']:
                param.data -= param.data.mean()
                print(name, param.data.mean(), param.size())
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, token=HF_TOKEN)
    model.tokenizer = tokenizer
    #model.tokenizer.padding_side = "left" #for batching; right (default in gpt2) for training, left for generation
    model.tokenizer.pad_token_id = model.tokenizer.eos_token_id 
    return model

model=None
import gc
gc.collect()
torch.cuda.empty_cache()

HF_TOKEN="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"
_ = torch.set_grad_enabled(False)
###base_model_path: str = "meta-llama/Llama-2-13b-chat-hf" #even on an A40 I have to load 13b in half precision
base_model_path: str = "meta-llama/Llama-2-13b-chat-hf"#"meta-llama/Meta-Llama-3-8B-Instruct"#
###model_path=base_model_path
model_path="cackerman/llama2_13b_chat_projection_tuneinposplustruth_mdo7l11to17"
device: str = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu" #need to upgrade MacOS first
#device: str = "cuda" if torch.cuda.is_available() else "cpu"
center_weights=False

model = load_model(model_path, device, center_weights=center_weights)

lls, ppl = compute_ppl(model, model.tokenizer, batch_size=2, max_length=1024, stride=512)


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 1/334 [00:00<02:25,  2.29it/s]

Perplexity after 0 tokens: 7.584681987762451


  1%|          | 2/334 [00:01<03:36,  1.53it/s]

Perplexity after 512 tokens: 5.914628028869629


  1%|          | 3/334 [00:02<03:58,  1.39it/s]

Perplexity after 1024 tokens: 6.450955390930176


  1%|          | 4/334 [00:02<04:08,  1.33it/s]

Perplexity after 1536 tokens: 6.456805229187012


  1%|▏         | 5/334 [00:03<04:13,  1.30it/s]

Perplexity after 2048 tokens: 6.772259712219238


  2%|▏         | 6/334 [00:04<04:17,  1.28it/s]

Perplexity after 2560 tokens: 6.695085048675537


  2%|▏         | 7/334 [00:05<04:18,  1.26it/s]

Perplexity after 3072 tokens: 6.636634826660156


  2%|▏         | 8/334 [00:06<04:19,  1.26it/s]

Perplexity after 3584 tokens: 6.563382625579834


  3%|▎         | 9/334 [00:06<04:19,  1.25it/s]

Perplexity after 4096 tokens: 6.627808094024658


  3%|▎         | 10/334 [00:07<04:19,  1.25it/s]

Perplexity after 4608 tokens: 6.666723251342773


  3%|▎         | 11/334 [00:08<04:19,  1.25it/s]

Perplexity after 5120 tokens: 6.71409273147583


  4%|▎         | 12/334 [00:09<04:18,  1.24it/s]

Perplexity after 5632 tokens: 6.690678119659424


  4%|▍         | 13/334 [00:10<04:18,  1.24it/s]

Perplexity after 6144 tokens: 6.774406909942627


  4%|▍         | 14/334 [00:10<04:18,  1.24it/s]

Perplexity after 6656 tokens: 6.816168308258057


  4%|▍         | 15/334 [00:11<04:18,  1.24it/s]

Perplexity after 7168 tokens: 6.914283275604248


  5%|▍         | 16/334 [00:12<04:17,  1.23it/s]

Perplexity after 7680 tokens: 6.851447582244873


  5%|▌         | 17/334 [00:13<04:17,  1.23it/s]

Perplexity after 8192 tokens: 6.735580921173096


  5%|▌         | 18/334 [00:14<04:16,  1.23it/s]

Perplexity after 8704 tokens: 6.695837020874023


  6%|▌         | 19/334 [00:14<04:15,  1.23it/s]

Perplexity after 9216 tokens: 6.561611652374268


  6%|▌         | 20/334 [00:15<04:15,  1.23it/s]

Perplexity after 9728 tokens: 6.519295692443848


  6%|▋         | 21/334 [00:16<04:14,  1.23it/s]

Perplexity after 10240 tokens: 6.555604457855225


  7%|▋         | 22/334 [00:17<04:14,  1.23it/s]

Perplexity after 10752 tokens: 6.473830699920654


  7%|▋         | 23/334 [00:18<04:12,  1.23it/s]

Perplexity after 11264 tokens: 6.408520221710205


  7%|▋         | 24/334 [00:19<04:12,  1.23it/s]

Perplexity after 11776 tokens: 6.379539966583252


  7%|▋         | 25/334 [00:19<04:11,  1.23it/s]

Perplexity after 12288 tokens: 6.3542985916137695


  8%|▊         | 26/334 [00:20<04:11,  1.22it/s]

Perplexity after 12800 tokens: 6.293386936187744


  8%|▊         | 27/334 [00:21<04:10,  1.23it/s]

Perplexity after 13312 tokens: 6.192785739898682


  8%|▊         | 28/334 [00:22<04:10,  1.22it/s]

Perplexity after 13824 tokens: 6.145533084869385


  9%|▊         | 29/334 [00:23<04:09,  1.22it/s]

Perplexity after 14336 tokens: 6.109745025634766


  9%|▉         | 30/334 [00:23<04:08,  1.22it/s]

Perplexity after 14848 tokens: 6.016920566558838


  9%|▉         | 31/334 [00:24<04:08,  1.22it/s]

Perplexity after 15360 tokens: 5.982571125030518


 10%|▉         | 32/334 [00:25<04:07,  1.22it/s]

Perplexity after 15872 tokens: 5.931074619293213


 10%|▉         | 33/334 [00:26<04:07,  1.22it/s]

Perplexity after 16384 tokens: 5.9368896484375


 10%|█         | 34/334 [00:27<04:05,  1.22it/s]

Perplexity after 16896 tokens: 5.895257472991943


 10%|█         | 35/334 [00:28<04:05,  1.22it/s]

Perplexity after 17408 tokens: 5.918295383453369


 11%|█         | 36/334 [00:28<04:05,  1.22it/s]

Perplexity after 17920 tokens: 5.956379413604736


 11%|█         | 37/334 [00:29<04:04,  1.21it/s]

Perplexity after 18432 tokens: 5.987278461456299


 11%|█▏        | 38/334 [00:30<04:03,  1.21it/s]

Perplexity after 18944 tokens: 5.998847484588623


 12%|█▏        | 39/334 [00:31<04:03,  1.21it/s]

Perplexity after 19456 tokens: 5.986680507659912


 12%|█▏        | 40/334 [00:32<04:02,  1.21it/s]

Perplexity after 19968 tokens: 5.9940314292907715


 12%|█▏        | 41/334 [00:33<04:01,  1.21it/s]

Perplexity after 20480 tokens: 6.005599021911621


 13%|█▎        | 42/334 [00:33<04:01,  1.21it/s]

Perplexity after 20992 tokens: 6.007212162017822


 13%|█▎        | 43/334 [00:34<04:00,  1.21it/s]

Perplexity after 21504 tokens: 5.998926639556885


 13%|█▎        | 44/334 [00:35<03:59,  1.21it/s]

Perplexity after 22016 tokens: 5.964531421661377


 13%|█▎        | 45/334 [00:36<03:58,  1.21it/s]

Perplexity after 22528 tokens: 5.970579147338867


 14%|█▍        | 46/334 [00:37<03:57,  1.21it/s]

Perplexity after 23040 tokens: 5.976271629333496


 14%|█▍        | 47/334 [00:37<03:56,  1.21it/s]

Perplexity after 23552 tokens: 5.952533721923828


 14%|█▍        | 48/334 [00:38<03:56,  1.21it/s]

Perplexity after 24064 tokens: 5.925643444061279


 15%|█▍        | 49/334 [00:39<03:55,  1.21it/s]

Perplexity after 24576 tokens: 5.926063060760498


 15%|█▍        | 50/334 [00:40<03:54,  1.21it/s]

Perplexity after 25088 tokens: 5.902603626251221


 15%|█▌        | 51/334 [00:41<03:53,  1.21it/s]

Perplexity after 25600 tokens: 5.896758079528809


 16%|█▌        | 52/334 [00:42<03:52,  1.21it/s]

Perplexity after 26112 tokens: 5.9051384925842285


 16%|█▌        | 53/334 [00:42<03:52,  1.21it/s]

Perplexity after 26624 tokens: 5.88224983215332


 16%|█▌        | 54/334 [00:43<03:51,  1.21it/s]

Perplexity after 27136 tokens: 5.8968305587768555


 16%|█▋        | 55/334 [00:44<03:50,  1.21it/s]

Perplexity after 27648 tokens: 5.87007999420166


 17%|█▋        | 56/334 [00:45<03:49,  1.21it/s]

Perplexity after 28160 tokens: 5.846283435821533


 17%|█▋        | 57/334 [00:46<03:48,  1.21it/s]

Perplexity after 28672 tokens: 5.8584113121032715


 17%|█▋        | 58/334 [00:47<03:47,  1.21it/s]

Perplexity after 29184 tokens: 5.869176864624023


 18%|█▊        | 59/334 [00:47<03:46,  1.21it/s]

Perplexity after 29696 tokens: 5.873612403869629


 18%|█▊        | 60/334 [00:48<03:46,  1.21it/s]

Perplexity after 30208 tokens: 5.889565944671631


 18%|█▊        | 61/334 [00:49<03:45,  1.21it/s]

Perplexity after 30720 tokens: 5.8874664306640625


 19%|█▊        | 62/334 [00:50<03:44,  1.21it/s]

Perplexity after 31232 tokens: 5.870312690734863


 19%|█▉        | 63/334 [00:51<03:43,  1.21it/s]

Perplexity after 31744 tokens: 5.879727363586426


 19%|█▉        | 64/334 [00:52<03:43,  1.21it/s]

Perplexity after 32256 tokens: 5.867706298828125


 19%|█▉        | 65/334 [00:52<03:42,  1.21it/s]

Perplexity after 32768 tokens: 5.861433506011963


 20%|█▉        | 66/334 [00:53<03:41,  1.21it/s]

Perplexity after 33280 tokens: 5.8637776374816895


 20%|██        | 67/334 [00:54<03:40,  1.21it/s]

Perplexity after 33792 tokens: 5.867690563201904


 20%|██        | 68/334 [00:55<03:39,  1.21it/s]

Perplexity after 34304 tokens: 5.868614673614502


 21%|██        | 69/334 [00:56<03:38,  1.21it/s]

Perplexity after 34816 tokens: 5.8603901863098145


 21%|██        | 70/334 [00:56<03:37,  1.21it/s]

Perplexity after 35328 tokens: 5.881621360778809


 21%|██▏       | 71/334 [00:57<03:37,  1.21it/s]

Perplexity after 35840 tokens: 5.903784275054932


 22%|██▏       | 72/334 [00:58<03:36,  1.21it/s]

Perplexity after 36352 tokens: 5.9399871826171875


 22%|██▏       | 73/334 [00:59<03:35,  1.21it/s]

Perplexity after 36864 tokens: 5.961665630340576


 22%|██▏       | 74/334 [01:00<03:34,  1.21it/s]

Perplexity after 37376 tokens: 5.970928192138672


 22%|██▏       | 75/334 [01:01<03:33,  1.21it/s]

Perplexity after 37888 tokens: 5.991468906402588


 23%|██▎       | 76/334 [01:01<03:33,  1.21it/s]

Perplexity after 38400 tokens: 5.994836807250977


 23%|██▎       | 77/334 [01:02<03:32,  1.21it/s]

Perplexity after 38912 tokens: 6.00883150100708


 23%|██▎       | 78/334 [01:03<03:31,  1.21it/s]

Perplexity after 39424 tokens: 6.010448455810547


 24%|██▎       | 79/334 [01:04<03:30,  1.21it/s]

Perplexity after 39936 tokens: 6.0125226974487305


 24%|██▍       | 80/334 [01:05<03:29,  1.21it/s]

Perplexity after 40448 tokens: 6.022980213165283


 24%|██▍       | 81/334 [01:06<03:28,  1.21it/s]

Perplexity after 40960 tokens: 6.03278112411499


 25%|██▍       | 82/334 [01:06<03:28,  1.21it/s]

Perplexity after 41472 tokens: 6.0501604080200195


 25%|██▍       | 83/334 [01:07<03:27,  1.21it/s]

Perplexity after 41984 tokens: 6.075323581695557


 25%|██▌       | 84/334 [01:08<03:26,  1.21it/s]

Perplexity after 42496 tokens: 6.090221881866455


 25%|██▌       | 85/334 [01:09<03:25,  1.21it/s]

Perplexity after 43008 tokens: 6.118460178375244


 26%|██▌       | 86/334 [01:10<03:24,  1.21it/s]

Perplexity after 43520 tokens: 6.132575511932373


 26%|██▌       | 87/334 [01:11<03:23,  1.21it/s]

Perplexity after 44032 tokens: 6.152665615081787


 26%|██▋       | 88/334 [01:11<03:23,  1.21it/s]

Perplexity after 44544 tokens: 6.1561279296875


 27%|██▋       | 89/334 [01:12<03:22,  1.21it/s]

Perplexity after 45056 tokens: 6.139721393585205


 27%|██▋       | 90/334 [01:13<03:21,  1.21it/s]

Perplexity after 45568 tokens: 6.162632942199707


 27%|██▋       | 91/334 [01:14<03:20,  1.21it/s]

Perplexity after 46080 tokens: 6.161346912384033


 28%|██▊       | 92/334 [01:15<03:19,  1.21it/s]

Perplexity after 46592 tokens: 6.177140235900879


 28%|██▊       | 93/334 [01:15<03:18,  1.21it/s]

Perplexity after 47104 tokens: 6.195390701293945


 28%|██▊       | 94/334 [01:16<03:18,  1.21it/s]

Perplexity after 47616 tokens: 6.218314170837402


 28%|██▊       | 95/334 [01:17<03:17,  1.21it/s]

Perplexity after 48128 tokens: 6.22950553894043


 29%|██▊       | 96/334 [01:18<03:16,  1.21it/s]

Perplexity after 48640 tokens: 6.24309778213501


 29%|██▉       | 97/334 [01:19<03:15,  1.21it/s]

Perplexity after 49152 tokens: 6.273120880126953


 29%|██▉       | 98/334 [01:20<03:14,  1.21it/s]

Perplexity after 49664 tokens: 6.2884840965271


 30%|██▉       | 99/334 [01:20<03:13,  1.21it/s]

Perplexity after 50176 tokens: 6.307455062866211


 30%|██▉       | 100/334 [01:21<03:12,  1.21it/s]

Perplexity after 50688 tokens: 6.3242292404174805


 30%|███       | 101/334 [01:22<03:12,  1.21it/s]

Perplexity after 51200 tokens: 6.3312554359436035


 31%|███       | 102/334 [01:23<03:11,  1.21it/s]

Perplexity after 51712 tokens: 6.352586269378662


 31%|███       | 103/334 [01:24<03:10,  1.21it/s]

Perplexity after 52224 tokens: 6.37301778793335


 31%|███       | 104/334 [01:25<03:09,  1.21it/s]

Perplexity after 52736 tokens: 6.384072780609131


 31%|███▏      | 105/334 [01:25<03:08,  1.21it/s]

Perplexity after 53248 tokens: 6.412841796875


 32%|███▏      | 106/334 [01:26<03:08,  1.21it/s]

Perplexity after 53760 tokens: 6.420888900756836


 32%|███▏      | 107/334 [01:27<03:07,  1.21it/s]

Perplexity after 54272 tokens: 6.418637752532959


 32%|███▏      | 108/334 [01:28<03:06,  1.21it/s]

Perplexity after 54784 tokens: 6.411534786224365


 33%|███▎      | 109/334 [01:29<03:05,  1.21it/s]

Perplexity after 55296 tokens: 6.414107799530029


 33%|███▎      | 110/334 [01:29<03:04,  1.21it/s]

Perplexity after 55808 tokens: 6.431540012359619


 33%|███▎      | 111/334 [01:30<03:04,  1.21it/s]

Perplexity after 56320 tokens: 6.445395469665527


 34%|███▎      | 112/334 [01:31<03:03,  1.21it/s]

Perplexity after 56832 tokens: 6.4562578201293945


 34%|███▍      | 113/334 [01:32<03:02,  1.21it/s]

Perplexity after 57344 tokens: 6.471207618713379


 34%|███▍      | 114/334 [01:33<03:01,  1.21it/s]

Perplexity after 57856 tokens: 6.468023300170898


 34%|███▍      | 115/334 [01:34<03:00,  1.21it/s]

Perplexity after 58368 tokens: 6.459258556365967


 35%|███▍      | 116/334 [01:34<02:59,  1.21it/s]

Perplexity after 58880 tokens: 6.457670211791992


 35%|███▌      | 117/334 [01:35<02:59,  1.21it/s]

Perplexity after 59392 tokens: 6.457529067993164


 35%|███▌      | 118/334 [01:36<02:58,  1.21it/s]

Perplexity after 59904 tokens: 6.451571941375732


 36%|███▌      | 119/334 [01:37<02:57,  1.21it/s]

Perplexity after 60416 tokens: 6.446998596191406


 36%|███▌      | 120/334 [01:38<02:56,  1.21it/s]

Perplexity after 60928 tokens: 6.449621677398682


 36%|███▌      | 121/334 [01:39<02:55,  1.21it/s]

Perplexity after 61440 tokens: 6.462358474731445


 37%|███▋      | 122/334 [01:39<02:54,  1.21it/s]

Perplexity after 61952 tokens: 6.481746673583984


 37%|███▋      | 123/334 [01:40<02:54,  1.21it/s]

Perplexity after 62464 tokens: 6.491132736206055


 37%|███▋      | 124/334 [01:41<02:53,  1.21it/s]

Perplexity after 62976 tokens: 6.493466377258301


 37%|███▋      | 125/334 [01:42<02:52,  1.21it/s]

Perplexity after 63488 tokens: 6.499812602996826


 38%|███▊      | 126/334 [01:43<02:51,  1.21it/s]

Perplexity after 64000 tokens: 6.488077163696289


 38%|███▊      | 127/334 [01:44<02:50,  1.21it/s]

Perplexity after 64512 tokens: 6.476423740386963


 38%|███▊      | 128/334 [01:44<02:49,  1.21it/s]

Perplexity after 65024 tokens: 6.459280014038086


 39%|███▊      | 129/334 [01:45<02:49,  1.21it/s]

Perplexity after 65536 tokens: 6.476376533508301


 39%|███▉      | 130/334 [01:46<02:48,  1.21it/s]

Perplexity after 66048 tokens: 6.488842010498047


 39%|███▉      | 131/334 [01:47<02:48,  1.21it/s]

Perplexity after 66560 tokens: 6.4858078956604


 40%|███▉      | 132/334 [01:48<02:47,  1.21it/s]

Perplexity after 67072 tokens: 6.471348762512207


 40%|███▉      | 133/334 [01:48<02:46,  1.21it/s]

Perplexity after 67584 tokens: 6.464822292327881


 40%|████      | 134/334 [01:49<02:46,  1.20it/s]

Perplexity after 68096 tokens: 6.471960067749023


 40%|████      | 135/334 [01:50<02:45,  1.20it/s]

Perplexity after 68608 tokens: 6.463851451873779


 41%|████      | 136/334 [01:51<02:44,  1.20it/s]

Perplexity after 69120 tokens: 6.466567039489746


 41%|████      | 137/334 [01:52<02:43,  1.20it/s]

Perplexity after 69632 tokens: 6.456718921661377


 41%|████▏     | 138/334 [01:53<02:43,  1.20it/s]

Perplexity after 70144 tokens: 6.453527450561523


 42%|████▏     | 139/334 [01:53<02:41,  1.20it/s]

Perplexity after 70656 tokens: 6.437588214874268


 42%|████▏     | 140/334 [01:54<02:41,  1.20it/s]

Perplexity after 71168 tokens: 6.414321422576904


 42%|████▏     | 141/334 [01:55<02:40,  1.20it/s]

Perplexity after 71680 tokens: 6.410571575164795


 43%|████▎     | 142/334 [01:56<02:39,  1.20it/s]

Perplexity after 72192 tokens: 6.410834789276123


 43%|████▎     | 143/334 [01:57<02:39,  1.20it/s]

Perplexity after 72704 tokens: 6.407268047332764


 43%|████▎     | 144/334 [01:58<02:38,  1.20it/s]

Perplexity after 73216 tokens: 6.389346122741699


 43%|████▎     | 145/334 [01:58<02:37,  1.20it/s]

Perplexity after 73728 tokens: 6.382951259613037


 44%|████▎     | 146/334 [01:59<02:36,  1.20it/s]

Perplexity after 74240 tokens: 6.371661186218262


 44%|████▍     | 147/334 [02:00<02:35,  1.20it/s]

Perplexity after 74752 tokens: 6.353293418884277


 44%|████▍     | 148/334 [02:01<02:34,  1.20it/s]

Perplexity after 75264 tokens: 6.342925071716309


 45%|████▍     | 149/334 [02:02<02:34,  1.20it/s]

Perplexity after 75776 tokens: 6.330114841461182


 45%|████▍     | 150/334 [02:03<02:33,  1.20it/s]

Perplexity after 76288 tokens: 6.326435089111328


 45%|████▌     | 151/334 [02:03<02:32,  1.20it/s]

Perplexity after 76800 tokens: 6.317356109619141


 46%|████▌     | 152/334 [02:04<02:31,  1.20it/s]

Perplexity after 77312 tokens: 6.3130269050598145


 46%|████▌     | 153/334 [02:05<02:30,  1.20it/s]

Perplexity after 77824 tokens: 6.311659812927246


 46%|████▌     | 154/334 [02:06<02:30,  1.20it/s]

Perplexity after 78336 tokens: 6.315367221832275


 46%|████▋     | 155/334 [02:07<02:29,  1.20it/s]

Perplexity after 78848 tokens: 6.315023422241211


 47%|████▋     | 156/334 [02:08<02:28,  1.20it/s]

Perplexity after 79360 tokens: 6.310116767883301


 47%|████▋     | 157/334 [02:08<02:27,  1.20it/s]

Perplexity after 79872 tokens: 6.30185079574585


 47%|████▋     | 158/334 [02:09<02:26,  1.20it/s]

Perplexity after 80384 tokens: 6.296494007110596


 48%|████▊     | 159/334 [02:10<02:25,  1.20it/s]

Perplexity after 80896 tokens: 6.300902366638184


 48%|████▊     | 160/334 [02:11<02:25,  1.20it/s]

Perplexity after 81408 tokens: 6.2963547706604


 48%|████▊     | 161/334 [02:12<02:24,  1.20it/s]

Perplexity after 81920 tokens: 6.290610313415527


 49%|████▊     | 162/334 [02:13<02:23,  1.20it/s]

Perplexity after 82432 tokens: 6.285775661468506


 49%|████▉     | 163/334 [02:13<02:22,  1.20it/s]

Perplexity after 82944 tokens: 6.286783218383789


 49%|████▉     | 164/334 [02:14<02:21,  1.20it/s]

Perplexity after 83456 tokens: 6.276767730712891


 49%|████▉     | 165/334 [02:15<02:21,  1.20it/s]

Perplexity after 83968 tokens: 6.263434886932373


 50%|████▉     | 166/334 [02:16<02:20,  1.20it/s]

Perplexity after 84480 tokens: 6.252365589141846


 50%|█████     | 167/334 [02:17<02:19,  1.20it/s]

Perplexity after 84992 tokens: 6.2323527336120605


 50%|█████     | 168/334 [02:18<02:18,  1.20it/s]

Perplexity after 85504 tokens: 6.211901664733887


 51%|█████     | 169/334 [02:18<02:17,  1.20it/s]

Perplexity after 86016 tokens: 6.190975666046143


 51%|█████     | 170/334 [02:19<02:16,  1.20it/s]

Perplexity after 86528 tokens: 6.171078681945801


 51%|█████     | 171/334 [02:20<02:16,  1.20it/s]

Perplexity after 87040 tokens: 6.153415203094482


 51%|█████▏    | 172/334 [02:21<02:15,  1.20it/s]

Perplexity after 87552 tokens: 6.134400367736816


 52%|█████▏    | 173/334 [02:22<02:14,  1.20it/s]

Perplexity after 88064 tokens: 6.1234893798828125


 52%|█████▏    | 174/334 [02:23<02:13,  1.20it/s]

Perplexity after 88576 tokens: 6.105424404144287


 52%|█████▏    | 175/334 [02:24<02:12,  1.20it/s]

Perplexity after 89088 tokens: 6.090058326721191


 53%|█████▎    | 176/334 [02:24<02:12,  1.20it/s]

Perplexity after 89600 tokens: 6.0798258781433105


 53%|█████▎    | 177/334 [02:25<02:11,  1.19it/s]

Perplexity after 90112 tokens: 6.063712120056152


 53%|█████▎    | 178/334 [02:26<02:10,  1.20it/s]

Perplexity after 90624 tokens: 6.050877571105957


 54%|█████▎    | 179/334 [02:27<02:09,  1.20it/s]

Perplexity after 91136 tokens: 6.02996301651001


 54%|█████▍    | 180/334 [02:28<02:09,  1.19it/s]

Perplexity after 91648 tokens: 6.014980792999268


 54%|█████▍    | 181/334 [02:29<02:08,  1.20it/s]

Perplexity after 92160 tokens: 6.004253387451172


 54%|█████▍    | 182/334 [02:29<02:07,  1.19it/s]

Perplexity after 92672 tokens: 5.9909539222717285


 55%|█████▍    | 183/334 [02:30<02:06,  1.20it/s]

Perplexity after 93184 tokens: 5.984408855438232


 55%|█████▌    | 184/334 [02:31<02:05,  1.20it/s]

Perplexity after 93696 tokens: 5.972624778747559


 55%|█████▌    | 185/334 [02:32<02:04,  1.19it/s]

Perplexity after 94208 tokens: 5.962885856628418


 56%|█████▌    | 186/334 [02:33<02:03,  1.20it/s]

Perplexity after 94720 tokens: 5.961781978607178


 56%|█████▌    | 187/334 [02:34<02:03,  1.19it/s]

Perplexity after 95232 tokens: 5.964065074920654


 56%|█████▋    | 188/334 [02:34<02:02,  1.20it/s]

Perplexity after 95744 tokens: 5.96090030670166


 57%|█████▋    | 189/334 [02:35<02:01,  1.19it/s]

Perplexity after 96256 tokens: 5.959660053253174


 57%|█████▋    | 190/334 [02:36<02:00,  1.19it/s]

Perplexity after 96768 tokens: 5.9645819664001465


 57%|█████▋    | 191/334 [02:37<01:59,  1.19it/s]

Perplexity after 97280 tokens: 5.961879253387451


 57%|█████▋    | 192/334 [02:38<01:58,  1.19it/s]

Perplexity after 97792 tokens: 5.966291427612305


 58%|█████▊    | 193/334 [02:39<01:57,  1.19it/s]

Perplexity after 98304 tokens: 5.971024990081787


 58%|█████▊    | 194/334 [02:39<01:57,  1.19it/s]

Perplexity after 98816 tokens: 5.983717441558838


 58%|█████▊    | 195/334 [02:40<01:56,  1.20it/s]

Perplexity after 99328 tokens: 5.999858856201172


 59%|█████▊    | 196/334 [02:41<01:55,  1.19it/s]

Perplexity after 99840 tokens: 6.00779914855957


 59%|█████▉    | 197/334 [02:42<01:54,  1.19it/s]

Perplexity after 100352 tokens: 6.01619815826416


 59%|█████▉    | 198/334 [02:43<01:53,  1.19it/s]

Perplexity after 100864 tokens: 6.028628349304199


 60%|█████▉    | 199/334 [02:44<01:53,  1.19it/s]

Perplexity after 101376 tokens: 6.039768695831299


 60%|█████▉    | 200/334 [02:44<01:52,  1.19it/s]

Perplexity after 101888 tokens: 6.0532989501953125


 60%|██████    | 201/334 [02:45<01:51,  1.19it/s]

Perplexity after 102400 tokens: 6.067712306976318


 60%|██████    | 202/334 [02:46<01:50,  1.19it/s]

Perplexity after 102912 tokens: 6.0741753578186035


 61%|██████    | 203/334 [02:47<01:49,  1.19it/s]

Perplexity after 103424 tokens: 6.078855991363525


 61%|██████    | 204/334 [02:48<01:48,  1.19it/s]

Perplexity after 103936 tokens: 6.07487154006958


 61%|██████▏   | 205/334 [02:49<01:47,  1.19it/s]

Perplexity after 104448 tokens: 6.075445175170898


 62%|██████▏   | 206/334 [02:49<01:47,  1.19it/s]

Perplexity after 104960 tokens: 6.078548908233643


 62%|██████▏   | 207/334 [02:50<01:46,  1.19it/s]

Perplexity after 105472 tokens: 6.0800862312316895


 62%|██████▏   | 208/334 [02:51<01:45,  1.19it/s]

Perplexity after 105984 tokens: 6.085583686828613


 63%|██████▎   | 209/334 [02:52<01:44,  1.19it/s]

Perplexity after 106496 tokens: 6.086915493011475


 63%|██████▎   | 210/334 [02:53<01:43,  1.19it/s]

Perplexity after 107008 tokens: 6.0860748291015625


 63%|██████▎   | 211/334 [02:54<01:43,  1.19it/s]

Perplexity after 107520 tokens: 6.0843586921691895


 63%|██████▎   | 212/334 [02:55<01:42,  1.19it/s]

Perplexity after 108032 tokens: 6.077286720275879


 64%|██████▍   | 213/334 [02:55<01:41,  1.19it/s]

Perplexity after 108544 tokens: 6.075170040130615


 64%|██████▍   | 214/334 [02:56<01:40,  1.19it/s]

Perplexity after 109056 tokens: 6.085509777069092


 64%|██████▍   | 215/334 [02:57<01:39,  1.19it/s]

Perplexity after 109568 tokens: 6.091485023498535


 65%|██████▍   | 216/334 [02:58<01:39,  1.19it/s]

Perplexity after 110080 tokens: 6.096659183502197


 65%|██████▍   | 217/334 [02:59<01:38,  1.19it/s]

Perplexity after 110592 tokens: 6.1012725830078125


 65%|██████▌   | 218/334 [03:00<01:37,  1.19it/s]

Perplexity after 111104 tokens: 6.10568904876709


 66%|██████▌   | 219/334 [03:00<01:36,  1.19it/s]

Perplexity after 111616 tokens: 6.105221748352051


 66%|██████▌   | 220/334 [03:01<01:35,  1.19it/s]

Perplexity after 112128 tokens: 6.109435558319092


 66%|██████▌   | 221/334 [03:02<01:34,  1.19it/s]

Perplexity after 112640 tokens: 6.110904216766357


 66%|██████▋   | 222/334 [03:03<01:34,  1.19it/s]

Perplexity after 113152 tokens: 6.115980625152588


 67%|██████▋   | 223/334 [03:04<01:33,  1.19it/s]

Perplexity after 113664 tokens: 6.114425182342529


 67%|██████▋   | 224/334 [03:05<01:32,  1.19it/s]

Perplexity after 114176 tokens: 6.111100673675537


 67%|██████▋   | 225/334 [03:05<01:31,  1.19it/s]

Perplexity after 114688 tokens: 6.104623794555664


 68%|██████▊   | 226/334 [03:06<01:30,  1.19it/s]

Perplexity after 115200 tokens: 6.101124286651611


 68%|██████▊   | 227/334 [03:07<01:29,  1.19it/s]

Perplexity after 115712 tokens: 6.094691276550293


 68%|██████▊   | 228/334 [03:08<01:28,  1.19it/s]

Perplexity after 116224 tokens: 6.090015411376953


 69%|██████▊   | 229/334 [03:09<01:28,  1.19it/s]

Perplexity after 116736 tokens: 6.081085205078125


 69%|██████▉   | 230/334 [03:10<01:27,  1.19it/s]

Perplexity after 117248 tokens: 6.086634635925293


 69%|██████▉   | 231/334 [03:10<01:26,  1.19it/s]

Perplexity after 117760 tokens: 6.086188793182373


 69%|██████▉   | 232/334 [03:11<01:25,  1.19it/s]

Perplexity after 118272 tokens: 6.088830471038818


 70%|██████▉   | 233/334 [03:12<01:24,  1.19it/s]

Perplexity after 118784 tokens: 6.094837665557861


 70%|███████   | 234/334 [03:13<01:24,  1.19it/s]

Perplexity after 119296 tokens: 6.087990760803223


 70%|███████   | 235/334 [03:14<01:23,  1.19it/s]

Perplexity after 119808 tokens: 6.092038631439209


 71%|███████   | 236/334 [03:15<01:22,  1.19it/s]

Perplexity after 120320 tokens: 6.084054470062256


 71%|███████   | 237/334 [03:15<01:21,  1.19it/s]

Perplexity after 120832 tokens: 6.079391956329346


 71%|███████▏  | 238/334 [03:16<01:20,  1.19it/s]

Perplexity after 121344 tokens: 6.075491905212402


 72%|███████▏  | 239/334 [03:17<01:19,  1.19it/s]

Perplexity after 121856 tokens: 6.065227508544922


 72%|███████▏  | 240/334 [03:18<01:18,  1.19it/s]

Perplexity after 122368 tokens: 6.070512294769287


 72%|███████▏  | 241/334 [03:19<01:18,  1.19it/s]

Perplexity after 122880 tokens: 6.071788787841797


 72%|███████▏  | 242/334 [03:20<01:17,  1.19it/s]

Perplexity after 123392 tokens: 6.072791576385498


 73%|███████▎  | 243/334 [03:21<01:16,  1.19it/s]

Perplexity after 123904 tokens: 6.068310260772705


 73%|███████▎  | 244/334 [03:21<01:15,  1.19it/s]

Perplexity after 124416 tokens: 6.064146518707275


 73%|███████▎  | 245/334 [03:22<01:14,  1.19it/s]

Perplexity after 124928 tokens: 6.0625104904174805


 74%|███████▎  | 246/334 [03:23<01:13,  1.19it/s]

Perplexity after 125440 tokens: 6.061614036560059


 74%|███████▍  | 247/334 [03:24<01:13,  1.19it/s]

Perplexity after 125952 tokens: 6.057342052459717


 74%|███████▍  | 248/334 [03:25<01:12,  1.19it/s]

Perplexity after 126464 tokens: 6.0562968254089355


 75%|███████▍  | 249/334 [03:26<01:11,  1.19it/s]

Perplexity after 126976 tokens: 6.052679061889648


 75%|███████▍  | 250/334 [03:26<01:10,  1.19it/s]

Perplexity after 127488 tokens: 6.044329643249512


 75%|███████▌  | 251/334 [03:27<01:09,  1.19it/s]

Perplexity after 128000 tokens: 6.038756847381592


 75%|███████▌  | 252/334 [03:28<01:08,  1.19it/s]

Perplexity after 128512 tokens: 6.031692981719971


 76%|███████▌  | 253/334 [03:29<01:08,  1.19it/s]

Perplexity after 129024 tokens: 6.026339530944824


 76%|███████▌  | 254/334 [03:30<01:07,  1.19it/s]

Perplexity after 129536 tokens: 6.0221266746521


 76%|███████▋  | 255/334 [03:31<01:06,  1.19it/s]

Perplexity after 130048 tokens: 6.017917633056641


 77%|███████▋  | 256/334 [03:31<01:05,  1.19it/s]

Perplexity after 130560 tokens: 6.0097856521606445


 77%|███████▋  | 257/334 [03:32<01:04,  1.19it/s]

Perplexity after 131072 tokens: 6.001719951629639


 77%|███████▋  | 258/334 [03:33<01:03,  1.19it/s]

Perplexity after 131584 tokens: 5.99682092666626


 78%|███████▊  | 259/334 [03:34<01:03,  1.19it/s]

Perplexity after 132096 tokens: 5.990016460418701


 78%|███████▊  | 260/334 [03:35<01:02,  1.19it/s]

Perplexity after 132608 tokens: 5.981733322143555


 78%|███████▊  | 261/334 [03:36<01:01,  1.19it/s]

Perplexity after 133120 tokens: 5.974878311157227


 78%|███████▊  | 262/334 [03:36<01:00,  1.19it/s]

Perplexity after 133632 tokens: 5.96554708480835


 79%|███████▊  | 263/334 [03:37<00:59,  1.19it/s]

Perplexity after 134144 tokens: 5.95758056640625


 79%|███████▉  | 264/334 [03:38<00:58,  1.19it/s]

Perplexity after 134656 tokens: 5.9526872634887695


 79%|███████▉  | 265/334 [03:39<00:57,  1.19it/s]

Perplexity after 135168 tokens: 5.950407028198242


 80%|███████▉  | 266/334 [03:40<00:57,  1.19it/s]

Perplexity after 135680 tokens: 5.955094337463379


 80%|███████▉  | 267/334 [03:41<00:56,  1.19it/s]

Perplexity after 136192 tokens: 5.951995849609375


 80%|████████  | 268/334 [03:42<00:55,  1.19it/s]

Perplexity after 136704 tokens: 5.942865371704102


 81%|████████  | 269/334 [03:42<00:54,  1.19it/s]

Perplexity after 137216 tokens: 5.939833164215088


 81%|████████  | 270/334 [03:43<00:53,  1.19it/s]

Perplexity after 137728 tokens: 5.936948776245117


 81%|████████  | 271/334 [03:44<00:52,  1.19it/s]

Perplexity after 138240 tokens: 5.933740139007568


 81%|████████▏ | 272/334 [03:45<00:52,  1.19it/s]

Perplexity after 138752 tokens: 5.931887149810791


 82%|████████▏ | 273/334 [03:46<00:51,  1.19it/s]

Perplexity after 139264 tokens: 5.937990665435791


 82%|████████▏ | 274/334 [03:47<00:50,  1.19it/s]

Perplexity after 139776 tokens: 5.936634540557861


 82%|████████▏ | 275/334 [03:47<00:49,  1.19it/s]

Perplexity after 140288 tokens: 5.9376325607299805


 83%|████████▎ | 276/334 [03:48<00:48,  1.19it/s]

Perplexity after 140800 tokens: 5.937809944152832


 83%|████████▎ | 277/334 [03:49<00:47,  1.19it/s]

Perplexity after 141312 tokens: 5.940185070037842


 83%|████████▎ | 278/334 [03:50<00:47,  1.19it/s]

Perplexity after 141824 tokens: 5.941305637359619


 84%|████████▎ | 279/334 [03:51<00:46,  1.19it/s]

Perplexity after 142336 tokens: 5.942500114440918


 84%|████████▍ | 280/334 [03:52<00:45,  1.19it/s]

Perplexity after 142848 tokens: 5.948917388916016


 84%|████████▍ | 281/334 [03:52<00:44,  1.19it/s]

Perplexity after 143360 tokens: 5.9526848793029785


 84%|████████▍ | 282/334 [03:53<00:43,  1.19it/s]

Perplexity after 143872 tokens: 5.958979606628418


 85%|████████▍ | 283/334 [03:54<00:42,  1.19it/s]

Perplexity after 144384 tokens: 5.9657721519470215


 85%|████████▌ | 284/334 [03:55<00:42,  1.19it/s]

Perplexity after 144896 tokens: 5.970168113708496


 85%|████████▌ | 285/334 [03:56<00:41,  1.19it/s]

Perplexity after 145408 tokens: 5.966949462890625


 86%|████████▌ | 286/334 [03:57<00:40,  1.19it/s]

Perplexity after 145920 tokens: 5.9701128005981445


 86%|████████▌ | 287/334 [03:58<00:39,  1.19it/s]

Perplexity after 146432 tokens: 5.977071762084961


 86%|████████▌ | 288/334 [03:58<00:38,  1.19it/s]

Perplexity after 146944 tokens: 5.982141017913818


 87%|████████▋ | 289/334 [03:59<00:37,  1.19it/s]

Perplexity after 147456 tokens: 5.989425182342529


 87%|████████▋ | 290/334 [04:00<00:37,  1.19it/s]

Perplexity after 147968 tokens: 5.99003267288208


 87%|████████▋ | 291/334 [04:01<00:36,  1.19it/s]

Perplexity after 148480 tokens: 5.992469787597656


 87%|████████▋ | 292/334 [04:02<00:35,  1.19it/s]

Perplexity after 148992 tokens: 5.986820220947266


 88%|████████▊ | 293/334 [04:03<00:34,  1.19it/s]

Perplexity after 149504 tokens: 5.988824367523193


 88%|████████▊ | 294/334 [04:03<00:33,  1.19it/s]

Perplexity after 150016 tokens: 5.982306957244873


 88%|████████▊ | 295/334 [04:04<00:32,  1.19it/s]

Perplexity after 150528 tokens: 5.967863082885742


 89%|████████▊ | 296/334 [04:05<00:31,  1.19it/s]

Perplexity after 151040 tokens: 5.956691741943359


 89%|████████▉ | 297/334 [04:06<00:31,  1.19it/s]

Perplexity after 151552 tokens: 5.940439224243164


 89%|████████▉ | 298/334 [04:07<00:30,  1.19it/s]

Perplexity after 152064 tokens: 5.9248199462890625


 90%|████████▉ | 299/334 [04:08<00:29,  1.19it/s]

Perplexity after 152576 tokens: 5.919925212860107


 90%|████████▉ | 300/334 [04:08<00:28,  1.19it/s]

Perplexity after 153088 tokens: 5.920547962188721


 90%|█████████ | 301/334 [04:09<00:27,  1.19it/s]

Perplexity after 153600 tokens: 5.9227399826049805


 90%|█████████ | 302/334 [04:10<00:26,  1.19it/s]

Perplexity after 154112 tokens: 5.924520969390869


 91%|█████████ | 303/334 [04:11<00:26,  1.19it/s]

Perplexity after 154624 tokens: 5.924530982971191


 91%|█████████ | 304/334 [04:12<00:25,  1.19it/s]

Perplexity after 155136 tokens: 5.924026012420654


 91%|█████████▏| 305/334 [04:13<00:24,  1.19it/s]

Perplexity after 155648 tokens: 5.920482635498047


 92%|█████████▏| 306/334 [04:14<00:23,  1.19it/s]

Perplexity after 156160 tokens: 5.921947956085205


 92%|█████████▏| 307/334 [04:14<00:22,  1.19it/s]

Perplexity after 156672 tokens: 5.922089576721191


 92%|█████████▏| 308/334 [04:15<00:21,  1.19it/s]

Perplexity after 157184 tokens: 5.92439603805542


 93%|█████████▎| 309/334 [04:16<00:21,  1.19it/s]

Perplexity after 157696 tokens: 5.926873683929443


 93%|█████████▎| 310/334 [04:17<00:20,  1.19it/s]

Perplexity after 158208 tokens: 5.927497386932373


 93%|█████████▎| 311/334 [04:18<00:19,  1.19it/s]

Perplexity after 158720 tokens: 5.926105976104736


 93%|█████████▎| 312/334 [04:19<00:18,  1.19it/s]

Perplexity after 159232 tokens: 5.925988674163818


 94%|█████████▎| 313/334 [04:19<00:17,  1.19it/s]

Perplexity after 159744 tokens: 5.928519248962402


 94%|█████████▍| 314/334 [04:20<00:16,  1.19it/s]

Perplexity after 160256 tokens: 5.917463302612305


 94%|█████████▍| 315/334 [04:21<00:15,  1.19it/s]

Perplexity after 160768 tokens: 5.915250778198242


 95%|█████████▍| 316/334 [04:22<00:15,  1.19it/s]

Perplexity after 161280 tokens: 5.911806106567383


 95%|█████████▍| 317/334 [04:23<00:14,  1.19it/s]

Perplexity after 161792 tokens: 5.906445026397705


 95%|█████████▌| 318/334 [04:24<00:13,  1.19it/s]

Perplexity after 162304 tokens: 5.905899524688721


 96%|█████████▌| 319/334 [04:24<00:12,  1.19it/s]

Perplexity after 162816 tokens: 5.905259132385254


 96%|█████████▌| 320/334 [04:25<00:11,  1.19it/s]

Perplexity after 163328 tokens: 5.904654502868652


 96%|█████████▌| 321/334 [04:26<00:10,  1.19it/s]

Perplexity after 163840 tokens: 5.901081562042236


 96%|█████████▋| 322/334 [04:27<00:10,  1.19it/s]

Perplexity after 164352 tokens: 5.8923821449279785


 97%|█████████▋| 323/334 [04:28<00:09,  1.19it/s]

Perplexity after 164864 tokens: 5.883716583251953


 97%|█████████▋| 324/334 [04:29<00:08,  1.19it/s]

Perplexity after 165376 tokens: 5.878808975219727


 97%|█████████▋| 325/334 [04:30<00:07,  1.19it/s]

Perplexity after 165888 tokens: 5.880178928375244


 98%|█████████▊| 326/334 [04:30<00:06,  1.19it/s]

Perplexity after 166400 tokens: 5.884068012237549


 98%|█████████▊| 327/334 [04:31<00:05,  1.19it/s]

Perplexity after 166912 tokens: 5.882317543029785


 98%|█████████▊| 328/334 [04:32<00:05,  1.19it/s]

Perplexity after 167424 tokens: 5.886343002319336


 99%|█████████▊| 329/334 [04:33<00:04,  1.19it/s]

Perplexity after 167936 tokens: 5.888584613800049


 99%|█████████▉| 330/334 [04:34<00:03,  1.19it/s]

Perplexity after 168448 tokens: 5.892516613006592


 99%|█████████▉| 331/334 [04:35<00:02,  1.19it/s]

Perplexity after 168960 tokens: 5.901963233947754


 99%|█████████▉| 332/334 [04:35<00:01,  1.19it/s]

Perplexity after 169472 tokens: 5.897468566894531


100%|█████████▉| 333/334 [04:36<00:00,  1.19it/s]

Perplexity after 169984 tokens: 5.897426128387451


100%|██████████| 334/334 [04:37<00:00,  1.20it/s]

Perplexity after 170496 tokens: 5.8984222412109375
Final perplexity: 5.8984222412109375





In [5]:
posplustruthtuned_model_lls = lls
posplustruthtuned_model_ppl = ppl
posplustruthtuned_model_ppl

tensor(5.8984, device='cuda:0')

In [15]:
import numpy as np
from scipy import stats

def perplexity(log_probs):
    """Calculate perplexity given log probabilities."""
    return np.exp(-np.mean(log_probs))

def paired_bootstrap_test(log_probs1, log_probs2, num_bootstrap=1000, confidence_level=0.95):
    """
    Perform paired bootstrap test to compare perplexities of two models.
    
    Args:
    log_probs1, log_probs2: Arrays of log probabilities from two models
    num_bootstrap: Number of bootstrap samples
    confidence_level: Confidence level for the test
    
    Returns:
    is_significant: Boolean indicating if the difference is significant
    ci_lower, ci_upper: Confidence interval of the difference in perplexity
    """
    # Ensure inputs are numpy arrays
    log_probs1 = np.array(log_probs1)
    log_probs2 = np.array(log_probs2)
    
    assert len(log_probs1) == len(log_probs2), "Log probability arrays must have the same length"
    
    n = len(log_probs1)
    diffs = []
    
    for _ in range(num_bootstrap):
        # Sample with replacement using integer indices
        indices = np.random.randint(0, n, size=n)
        sample1 = log_probs1[indices]
        sample2 = log_probs2[indices]
        
        # Calculate perplexity difference
        perp1 = perplexity(sample1)
        perp2 = perplexity(sample2)
        diffs.append(perp1 - perp2)
    
    # Calculate confidence interval
    ci_lower, ci_upper = np.percentile(diffs, [(1 - confidence_level) / 2 * 100, (1 + confidence_level) / 2 * 100])
    
    # Check if the confidence interval includes 0
    is_significant = (ci_lower > 0) or (ci_upper < 0)
    
    return is_significant, ci_lower, ci_upper

# Example usage
np.random.seed(42)  # for reproducibility

is_significant, ci_lower, ci_upper = paired_bootstrap_test(untuned_model_lls_list, posplustruthtuned_model_lls_list)

print(f"Is the difference significant? {is_significant}")
print(f"95% Confidence Interval: ({ci_lower:.4f}, {ci_upper:.4f})")

if is_significant:
    if ci_lower > 0:
        print("Model 1 has significantly higher perplexity (worse) than Model 2")
    else:
        print("Model 2 has significantly higher perplexity (worse) than Model 1")
else:
    print("There is no significant difference in perplexity between the two models")

Is the difference significant? False
95% Confidence Interval: (0.0000, 0.0000)
There is no significant difference in perplexity between the two models


In [11]:
untuned_model_lls_list = [ll.to('cpu').item() for ll in untuned_model_lls]
posplustruthtuned_model_lls_list = [ll.to('cpu').item() for ll in posplustruthtuned_model_lls]

In [13]:
untuned_model_lls_list[:5]

[1037.9171142578125,
 776.4319458007812,
 1040.4102783203125,
 943.4534912109375,
 1056.138427734375]