In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# %load_ext lab_black

In [2]:
# Filter out hard questions
from collections import defaultdict
import numpy as np
import random
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# IMPORTANT: Run as if from project root so that imports work.
# The notebook is expected to start one level below the project root. Only step
# up when the current directory does not already contain the `preprocessing`
# package (imported in the next cell), so that re-running this cell without a
# kernel restart does not keep climbing the directory tree.
pardir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if not os.path.isdir(os.path.join(os.getcwd(), "preprocessing")):
    os.chdir(pardir)
os.getcwd()

'/cluster/work/cotterell/kdu/context-vs-prior-finetuning'

In [4]:
from preprocessing.dataset import load_dataset_from_path
from preprocessing.utils import format_query, evaluate_model
# NOTE(review): the import below rebinds `evaluate_model`, so the
# `model_utils.utils` implementation is the one this notebook calls and the
# `preprocessing.utils` one imported above is shadowed — confirm which is intended.
from model_utils.utils import evaluate_model

In [5]:
# Seed every random number generator this notebook relies on so that the
# entity shuffling and any stochastic model behaviour are reproducible.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
# torch is imported above and drives the model; seed its default generator too.
torch.manual_seed(SEED)

In [43]:
# Paths of the raw Yago data file and of the model-specific file derived from it
# (presumably where the filtered data ends up — not used in this cell).
ROOT_DATA_DIR = "data/Yago/"
RAW_DATA_PATH = os.path.join(ROOT_DATA_DIR, "yago_qec.json")
FILTERED_RAW_DATA_PATH = os.path.join(ROOT_DATA_DIR, "llama2chat_yago_qec.json")
dataset = load_dataset_from_path(RAW_DATA_PATH)
# `dataset` is used as a mapping (`.keys()` here, `.items()` when the prompts are
# built), and a plain mapping cannot be sliced with `dataset[:1]`. Preview the
# number of entries and the first key instead; as the last expression of the
# cell this is what gets displayed.
len(dataset.keys()), list(dataset.keys())[:1]

In [7]:
# Load the causal language model that is later passed to `evaluate_model`.
# The checkpoint name suggests it is already stored in bitsandbytes 4-bit
# format — TODO confirm.
model_id = "unsloth/llama-2-7b-chat-bnb-4bit"
# Both values are forwarded verbatim to `from_pretrained` below as
# `device_map` and `torch_dtype` respectively.
device = "auto"
dtype = "auto"
# Quantization settings: 4-bit loading with the "nf4" quant type, double
# quantization enabled, and bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map=device,
    torch_dtype=dtype,
    # attn_implementation=attn_implementation,
)

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


In [8]:
# Load the tokenizer for the same checkpoint the model was loaded from, read
# through the public `name_or_path` config property rather than the private
# `_name_or_path` attribute.
tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
# NOTE(review): padding is applied on the right; decoder-only models are usually
# padded on the left for batched generation — confirm this matches what
# `evaluate_model` expects.
tokenizer.padding_side = "right"

In [9]:
# Build one prompt (with an empty context) for every combination of
# (entity, answer) pair and query form, collecting the prompts column by column
# together with the metadata needed to trace each one back to its source query.
my_dataset = defaultdict(list)
# num_entities = 50
for query_id, qec in tqdm(dataset.items()):
    # Pair each entity with its answer, then visit the pairs in shuffled order
    # (the shuffle draws from the global `random` state).
    shuffled_pairs = list(zip(qec["entities"], qec["answers"]))
    random.shuffle(shuffled_pairs)
    # shuffled_pairs = shuffled_pairs[:num_entities]
    for entity, answer in shuffled_pairs:
        for query_type, templates in qec["query_forms"].items():
            # "open" queries are labelled with the answer itself; every other
            # query type is labelled "No".
            label = answer if query_type == "open" else "No"
            for template in templates:
                # Query forms that start with "Q:" are left out.
                if template.startswith("Q:"):
                    continue

                prompt = format_query(
                    query=template, entity=(entity,), context="", answer=answer
                )
                my_dataset["text"].append(prompt)
                my_dataset["labels"].append(label)

                # Metadata describing where the prompt came from
                my_dataset["entity"].append(entity)
                my_dataset["query_id"].append(query_id)
                my_dataset["query_type"].append(query_type)


df_all = pd.DataFrame.from_dict(my_dataset)
df_all

100%|██████████| 125/125 [00:00<00:00, 387.66it/s]


Unnamed: 0,text,labels,entity,query_id,query_type
0,'Lady Blue' is about,Chicago Police Department,Lady Blue,http://schema.org/about,open
1,'The Guys' is about,New York City Fire Department,The Guys,http://schema.org/about,open
2,'installation of Yang di-Pertuan Agong XVI' is...,Abdullah of Pahang,installation of Yang di-Pertuan Agong XVI,http://schema.org/about,open
3,'Utilities Act 2000' is about,public utility,Utilities Act 2000,http://schema.org/about,open
4,'905' is about,cloning,905,http://schema.org/about,open
...,...,...,...,...,...
109050,Qingshui District is the terminus of,Provincial Highway 10,Qingshui District,reverse-http://yago-knowledge.org/resource/ter...,open
109051,Beppu is the terminus of,Japan National Route 213,Beppu,reverse-http://yago-knowledge.org/resource/ter...,open
109052,San Antonio is the terminus of,U.S. Route 380,San Antonio,reverse-http://yago-knowledge.org/resource/ter...,open
109053,Amtali Upazila is the terminus of,R881,Amtali Upazila,reverse-http://yago-knowledge.org/resource/ter...,open


In [10]:
# Convert the prompt table into a `datasets.Dataset` for evaluation.
# NOTE(review): this rebinds `dataset`, which until now held the raw mapping
# loaded from RAW_DATA_PATH. Re-running the prompt-building cell after this one
# will no longer see that mapping — consider a distinct name.
dataset = Dataset.from_pandas(df_all) #.select(range(256))

In [11]:
# Run the model over every prompt. `batch_sz=360` is forwarded to
# `evaluate_model`; the progress log below shows 303 batches for the 109055
# examples and a running accuracy after each batch.
# NOTE(review): the structure of `results` and `metrics` is defined by
# `evaluate_model` and is not visible here — check the helper before relying on it.
results, metrics = evaluate_model(model, tokenizer, dataset, batch_sz=360)

Map: 100%|██████████| 109055/109055 [00:04<00:00, 24280.51 examples/s]
  0%|          | 1/303 [00:08<43:21,  8.61s/it]

Average accuracy at batch 0: 0.008333333333333333 (3/360).


  1%|          | 2/303 [00:13<32:22,  6.45s/it]

Average accuracy at batch 1: 0.008333333333333333 (6/720).


  1%|          | 3/303 [00:17<27:23,  5.48s/it]

Average accuracy at batch 2: 0.009259259259259259 (10/1080).


  1%|▏         | 4/303 [00:22<24:59,  5.01s/it]

Average accuracy at batch 3: 0.006944444444444444 (10/1440).


  2%|▏         | 5/303 [00:26<23:48,  4.79s/it]

Average accuracy at batch 4: 0.005555555555555556 (10/1800).


  2%|▏         | 6/303 [00:30<22:53,  4.63s/it]

Average accuracy at batch 5: 0.005092592592592593 (11/2160).


  2%|▏         | 7/303 [00:35<22:22,  4.54s/it]

Average accuracy at batch 6: 0.005555555555555556 (14/2520).


  3%|▎         | 8/303 [00:39<21:35,  4.39s/it]

Average accuracy at batch 7: 0.005208333333333333 (15/2880).


  3%|▎         | 9/303 [00:43<20:37,  4.21s/it]

Average accuracy at batch 8: 0.0049382716049382715 (16/3240).


  3%|▎         | 10/303 [00:47<20:52,  4.28s/it]

Average accuracy at batch 9: 0.0044444444444444444 (16/3600).


  4%|▎         | 11/303 [00:51<20:29,  4.21s/it]

Average accuracy at batch 10: 0.00404040404040404 (16/3960).


  4%|▍         | 12/303 [00:55<20:24,  4.21s/it]

Average accuracy at batch 11: 0.005092592592592593 (22/4320).


  4%|▍         | 13/303 [00:59<20:13,  4.18s/it]

Average accuracy at batch 12: 0.006196581196581196 (29/4680).


  5%|▍         | 14/303 [01:03<19:51,  4.12s/it]

Average accuracy at batch 13: 0.006944444444444444 (35/5040).


  5%|▍         | 15/303 [01:07<19:37,  4.09s/it]

Average accuracy at batch 14: 0.006851851851851852 (37/5400).


  5%|▌         | 16/303 [01:12<19:38,  4.11s/it]

Average accuracy at batch 15: 0.006597222222222222 (38/5760).


  6%|▌         | 17/303 [01:16<19:42,  4.13s/it]

Average accuracy at batch 16: 0.00980392156862745 (60/6120).


  6%|▌         | 18/303 [01:20<20:00,  4.21s/it]

Average accuracy at batch 17: 0.013734567901234567 (89/6480).


  6%|▋         | 19/303 [01:26<22:50,  4.83s/it]

Average accuracy at batch 18: 0.017105263157894738 (117/6840).


  7%|▋         | 20/303 [01:31<22:20,  4.74s/it]

Average accuracy at batch 19: 0.017638888888888888 (127/7200).


  7%|▋         | 21/303 [01:35<21:36,  4.60s/it]

Average accuracy at batch 20: 0.017063492063492062 (129/7560).


  7%|▋         | 22/303 [01:40<21:11,  4.52s/it]

Average accuracy at batch 21: 0.016414141414141416 (130/7920).


  8%|▊         | 23/303 [01:44<20:20,  4.36s/it]

Average accuracy at batch 22: 0.01570048309178744 (130/8280).


  8%|▊         | 24/303 [01:47<19:27,  4.18s/it]

Average accuracy at batch 23: 0.015046296296296295 (130/8640).


  8%|▊         | 25/303 [01:51<19:01,  4.11s/it]

Average accuracy at batch 24: 0.015 (135/9000).


  9%|▊         | 26/303 [01:55<18:53,  4.09s/it]

Average accuracy at batch 25: 0.014423076923076924 (135/9360).


  9%|▉         | 27/303 [01:59<18:50,  4.10s/it]

Average accuracy at batch 26: 0.014197530864197531 (138/9720).


  9%|▉         | 28/303 [02:04<18:53,  4.12s/it]

Average accuracy at batch 27: 0.01369047619047619 (138/10080).


 10%|▉         | 29/303 [02:08<19:12,  4.21s/it]

Average accuracy at batch 28: 0.013409961685823755 (140/10440).


 10%|▉         | 30/303 [02:12<19:03,  4.19s/it]

Average accuracy at batch 29: 0.013425925925925926 (145/10800).


 10%|█         | 31/303 [02:16<18:52,  4.16s/it]

Average accuracy at batch 30: 0.013351254480286739 (149/11160).


 11%|█         | 32/303 [02:21<19:23,  4.29s/it]

Average accuracy at batch 31: 0.013541666666666667 (156/11520).


 11%|█         | 33/303 [02:25<19:05,  4.24s/it]

Average accuracy at batch 32: 0.013215488215488215 (157/11880).


 11%|█         | 34/303 [02:29<19:20,  4.31s/it]

Average accuracy at batch 33: 0.012826797385620915 (157/12240).


 12%|█▏        | 35/303 [02:34<19:13,  4.30s/it]

Average accuracy at batch 34: 0.01246031746031746 (157/12600).


 12%|█▏        | 36/303 [02:38<18:39,  4.19s/it]

Average accuracy at batch 35: 0.012114197530864198 (157/12960).


 12%|█▏        | 37/303 [02:42<18:31,  4.18s/it]

Average accuracy at batch 36: 0.011786786786786787 (157/13320).


 13%|█▎        | 38/303 [02:46<18:13,  4.13s/it]

Average accuracy at batch 37: 0.011476608187134503 (157/13680).


 13%|█▎        | 39/303 [02:50<17:50,  4.06s/it]

Average accuracy at batch 38: 0.011182336182336182 (157/14040).


 13%|█▎        | 40/303 [02:54<17:46,  4.05s/it]

Average accuracy at batch 39: 0.010902777777777779 (157/14400).


 14%|█▎        | 41/303 [02:58<17:54,  4.10s/it]

Average accuracy at batch 40: 0.011111111111111112 (164/14760).


 14%|█▍        | 42/303 [03:02<17:55,  4.12s/it]

Average accuracy at batch 41: 0.011904761904761904 (180/15120).


 14%|█▍        | 43/303 [03:06<17:40,  4.08s/it]

Average accuracy at batch 42: 0.01214470284237726 (188/15480).


 15%|█▍        | 44/303 [03:10<17:35,  4.07s/it]

Average accuracy at batch 43: 0.012626262626262626 (200/15840).


 15%|█▍        | 45/303 [03:14<17:46,  4.13s/it]

Average accuracy at batch 44: 0.013209876543209877 (214/16200).


 15%|█▌        | 46/303 [03:19<17:44,  4.14s/it]

Average accuracy at batch 45: 0.014190821256038648 (235/16560).


 16%|█▌        | 47/303 [03:23<17:34,  4.12s/it]

Average accuracy at batch 46: 0.013947990543735224 (236/16920).


 16%|█▌        | 48/303 [03:27<17:17,  4.07s/it]

Average accuracy at batch 47: 0.013657407407407408 (236/17280).


 16%|█▌        | 49/303 [03:31<17:10,  4.06s/it]

Average accuracy at batch 48: 0.013378684807256236 (236/17640).


 17%|█▋        | 50/303 [03:35<17:28,  4.15s/it]

Average accuracy at batch 49: 0.013333333333333334 (240/18000).


 17%|█▋        | 51/303 [03:39<17:45,  4.23s/it]

Average accuracy at batch 50: 0.013126361655773421 (241/18360).


 17%|█▋        | 52/303 [03:44<18:03,  4.32s/it]

Average accuracy at batch 51: 0.012980769230769231 (243/18720).


 17%|█▋        | 53/303 [03:57<28:53,  6.93s/it]

Average accuracy at batch 52: 0.012788259958071278 (244/19080).


 18%|█▊        | 54/303 [04:01<25:33,  6.16s/it]

Average accuracy at batch 53: 0.012551440329218106 (244/19440).


 18%|█▊        | 55/303 [04:05<22:43,  5.50s/it]

Average accuracy at batch 54: 0.012424242424242424 (246/19800).


 18%|█▊        | 56/303 [04:09<20:56,  5.09s/it]

Average accuracy at batch 55: 0.012351190476190476 (249/20160).


 19%|█▉        | 57/303 [04:14<19:58,  4.87s/it]

Average accuracy at batch 56: 0.012134502923976609 (249/20520).


 19%|█▉        | 58/303 [04:18<19:22,  4.74s/it]

Average accuracy at batch 57: 0.01192528735632184 (249/20880).


 19%|█▉        | 59/303 [04:23<18:51,  4.64s/it]

Average accuracy at batch 58: 0.012099811676082862 (257/21240).


 20%|█▉        | 60/303 [04:27<18:21,  4.53s/it]

Average accuracy at batch 59: 0.013194444444444444 (285/21600).


 20%|██        | 61/303 [04:31<18:03,  4.48s/it]

Average accuracy at batch 60: 0.014298724954462659 (314/21960).


 20%|██        | 62/303 [04:36<17:48,  4.43s/it]

Average accuracy at batch 61: 0.015277777777777777 (341/22320).


 21%|██        | 63/303 [04:40<17:23,  4.35s/it]

Average accuracy at batch 62: 0.016710758377425045 (379/22680).


 21%|██        | 64/303 [04:44<16:53,  4.24s/it]

Average accuracy at batch 63: 0.01818576388888889 (419/23040).


 21%|██▏       | 65/303 [04:48<16:36,  4.19s/it]

Average accuracy at batch 64: 0.018333333333333333 (429/23400).


 22%|██▏       | 66/303 [04:52<16:29,  4.17s/it]

Average accuracy at batch 65: 0.01813973063973064 (431/23760).


 22%|██▏       | 67/303 [04:56<16:09,  4.11s/it]

Average accuracy at batch 66: 0.018159203980099504 (438/24120).


 22%|██▏       | 68/303 [05:00<15:48,  4.04s/it]

Average accuracy at batch 67: 0.018504901960784314 (453/24480).


 23%|██▎       | 69/303 [05:04<15:56,  4.09s/it]

Average accuracy at batch 68: 0.019444444444444445 (483/24840).


 23%|██▎       | 70/303 [05:08<16:05,  4.14s/it]

Average accuracy at batch 69: 0.020317460317460317 (512/25200).


 23%|██▎       | 71/303 [05:13<16:25,  4.25s/it]

Average accuracy at batch 70: 0.02038341158059468 (521/25560).


 24%|██▍       | 72/303 [05:17<16:28,  4.28s/it]

Average accuracy at batch 71: 0.02029320987654321 (526/25920).


 24%|██▍       | 73/303 [05:21<16:23,  4.27s/it]

Average accuracy at batch 72: 0.020243531202435314 (532/26280).


 24%|██▍       | 74/303 [05:27<17:47,  4.66s/it]

Average accuracy at batch 73: 0.020232732732732732 (539/26640).


 25%|██▍       | 75/303 [05:31<17:13,  4.53s/it]

Average accuracy at batch 74: 0.02 (540/27000).


 25%|██▌       | 76/303 [05:36<16:59,  4.49s/it]

Average accuracy at batch 75: 0.0197733918128655 (541/27360).


 25%|██▌       | 77/303 [05:40<16:53,  4.48s/it]

Average accuracy at batch 76: 0.019552669552669554 (542/27720).


 26%|██▌       | 78/303 [05:44<16:05,  4.29s/it]

Average accuracy at batch 77: 0.019373219373219373 (544/28080).


 26%|██▌       | 79/303 [05:48<15:34,  4.17s/it]

Average accuracy at batch 78: 0.01912798874824191 (544/28440).


 26%|██▋       | 80/303 [05:53<16:49,  4.53s/it]

Average accuracy at batch 79: 0.01888888888888889 (544/28800).


 27%|██▋       | 81/303 [05:58<17:04,  4.62s/it]

Average accuracy at batch 80: 0.018655692729766804 (544/29160).


 27%|██▋       | 82/303 [06:02<16:57,  4.60s/it]

Average accuracy at batch 81: 0.01842818428184282 (544/29520).


 27%|██▋       | 83/303 [06:07<16:39,  4.54s/it]

Average accuracy at batch 82: 0.01820615796519411 (544/29880).


 28%|██▊       | 84/303 [06:11<16:21,  4.48s/it]

Average accuracy at batch 83: 0.01798941798941799 (544/30240).


 28%|██▊       | 85/303 [06:15<15:44,  4.33s/it]

Average accuracy at batch 84: 0.017777777777777778 (544/30600).


 28%|██▊       | 86/303 [06:20<15:50,  4.38s/it]

Average accuracy at batch 85: 0.01757105943152455 (544/30960).


 29%|██▊       | 87/303 [06:24<15:45,  4.38s/it]

Average accuracy at batch 86: 0.01756066411238825 (550/31320).


 29%|██▉       | 88/303 [06:28<15:40,  4.38s/it]

Average accuracy at batch 87: 0.017771464646464646 (563/31680).


 29%|██▉       | 89/303 [06:33<15:36,  4.37s/it]

Average accuracy at batch 88: 0.017852684144818977 (572/32040).


 30%|██▉       | 90/303 [06:37<15:28,  4.36s/it]

Average accuracy at batch 89: 0.01765432098765432 (572/32400).


 30%|███       | 91/303 [06:41<15:13,  4.31s/it]

Average accuracy at batch 90: 0.01746031746031746 (572/32760).


 30%|███       | 92/303 [06:46<15:12,  4.33s/it]

Average accuracy at batch 91: 0.017270531400966183 (572/33120).


 31%|███       | 93/303 [06:50<14:47,  4.23s/it]

Average accuracy at batch 92: 0.017084826762246118 (572/33480).


 31%|███       | 94/303 [06:54<14:38,  4.20s/it]

Average accuracy at batch 93: 0.01690307328605201 (572/33840).


 31%|███▏      | 95/303 [06:58<14:20,  4.14s/it]

Average accuracy at batch 94: 0.01672514619883041 (572/34200).


 32%|███▏      | 96/303 [07:02<14:33,  4.22s/it]

Average accuracy at batch 95: 0.016550925925925927 (572/34560).


 32%|███▏      | 97/303 [07:07<14:57,  4.36s/it]

Average accuracy at batch 96: 0.016380297823596793 (572/34920).


 32%|███▏      | 98/303 [07:11<14:53,  4.36s/it]

Average accuracy at batch 97: 0.016213151927437643 (572/35280).


 33%|███▎      | 99/303 [07:16<14:46,  4.35s/it]

Average accuracy at batch 98: 0.017508417508417508 (624/35640).


 33%|███▎      | 100/303 [07:20<14:54,  4.41s/it]

Average accuracy at batch 99: 0.018944444444444444 (682/36000).


 33%|███▎      | 101/303 [07:25<14:58,  4.45s/it]

Average accuracy at batch 100: 0.020819581958195818 (757/36360).


 34%|███▎      | 102/303 [07:30<15:17,  4.57s/it]

Average accuracy at batch 101: 0.02167755991285403 (796/36720).


 34%|███▍      | 103/303 [07:34<15:10,  4.55s/it]

Average accuracy at batch 102: 0.022114347357065803 (820/37080).


 34%|███▍      | 104/303 [07:38<14:22,  4.33s/it]

Average accuracy at batch 103: 0.0219017094017094 (820/37440).


 35%|███▍      | 105/303 [07:42<14:02,  4.25s/it]

Average accuracy at batch 104: 0.021693121693121695 (820/37800).


 35%|███▍      | 106/303 [07:46<13:46,  4.19s/it]

Average accuracy at batch 105: 0.02148846960167715 (820/38160).


 35%|███▌      | 107/303 [07:50<13:43,  4.20s/it]

Average accuracy at batch 106: 0.021287642782969886 (820/38520).


 36%|███▌      | 108/303 [07:55<13:47,  4.25s/it]

Average accuracy at batch 107: 0.02109053497942387 (820/38880).


 36%|███▌      | 109/303 [07:59<13:49,  4.28s/it]

Average accuracy at batch 108: 0.021075433231396535 (827/39240).


 36%|███▋      | 110/303 [08:03<13:55,  4.33s/it]

Average accuracy at batch 109: 0.021136363636363637 (837/39600).


 37%|███▋      | 111/303 [08:08<13:54,  4.34s/it]

Average accuracy at batch 110: 0.021246246246246245 (849/39960).


 37%|███▋      | 112/303 [08:12<13:45,  4.32s/it]

Average accuracy at batch 111: 0.022048611111111113 (889/40320).


 37%|███▋      | 113/303 [08:16<13:32,  4.28s/it]

Average accuracy at batch 112: 0.023647984267453293 (962/40680).


 38%|███▊      | 114/303 [08:20<13:13,  4.20s/it]

Average accuracy at batch 113: 0.02482943469785575 (1019/41040).


 38%|███▊      | 115/303 [08:24<12:52,  4.11s/it]

Average accuracy at batch 114: 0.02471014492753623 (1023/41400).


 38%|███▊      | 116/303 [08:28<12:48,  4.11s/it]

Average accuracy at batch 115: 0.02449712643678161 (1023/41760).


 39%|███▊      | 117/303 [08:32<12:36,  4.07s/it]

Average accuracy at batch 116: 0.024287749287749287 (1023/42120).


 39%|███▉      | 118/303 [08:36<12:43,  4.13s/it]

Average accuracy at batch 117: 0.02412900188323917 (1025/42480).


 39%|███▉      | 119/303 [08:41<13:14,  4.32s/it]

Average accuracy at batch 118: 0.02392623716153128 (1025/42840).


 40%|███▉      | 120/303 [08:46<13:18,  4.36s/it]

Average accuracy at batch 119: 0.02388888888888889 (1032/43200).


 40%|███▉      | 121/303 [08:50<13:11,  4.35s/it]

Average accuracy at batch 120: 0.024471992653810837 (1066/43560).


 40%|████      | 122/303 [08:54<13:02,  4.32s/it]

Average accuracy at batch 121: 0.024726775956284152 (1086/43920).


 41%|████      | 123/303 [09:00<14:14,  4.75s/it]

Average accuracy at batch 122: 0.0248193315266486 (1099/44280).


 41%|████      | 124/303 [09:04<13:42,  4.59s/it]

Average accuracy at batch 123: 0.027262544802867383 (1217/44640).


 41%|████▏     | 125/303 [09:08<13:15,  4.47s/it]

Average accuracy at batch 124: 0.02942222222222222 (1324/45000).


 42%|████▏     | 126/303 [09:13<12:50,  4.35s/it]

Average accuracy at batch 125: 0.03181216931216931 (1443/45360).


 42%|████▏     | 127/303 [09:16<12:19,  4.20s/it]

Average accuracy at batch 126: 0.03521434820647419 (1610/45720).


 42%|████▏     | 128/303 [09:21<12:35,  4.32s/it]

Average accuracy at batch 127: 0.035980902777777775 (1658/46080).


 43%|████▎     | 129/303 [09:26<12:48,  4.42s/it]

Average accuracy at batch 128: 0.0363479758828596 (1688/46440).


 43%|████▎     | 130/303 [09:30<12:55,  4.48s/it]

Average accuracy at batch 129: 0.03666666666666667 (1716/46800).


 43%|████▎     | 131/303 [09:34<12:30,  4.36s/it]

Average accuracy at batch 130: 0.03695928753180661 (1743/47160).


 44%|████▎     | 132/303 [09:38<12:13,  4.29s/it]

Average accuracy at batch 131: 0.03758417508417508 (1786/47520).


 44%|████▍     | 133/303 [09:43<12:04,  4.26s/it]

Average accuracy at batch 132: 0.037865497076023394 (1813/47880).


 44%|████▍     | 134/303 [09:47<11:48,  4.19s/it]

Average accuracy at batch 133: 0.03758291873963516 (1813/48240).


 45%|████▍     | 135/303 [09:51<11:37,  4.15s/it]

Average accuracy at batch 134: 0.037325102880658434 (1814/48600).


 45%|████▍     | 136/303 [09:55<11:33,  4.15s/it]

Average accuracy at batch 135: 0.03733660130718954 (1828/48960).


 45%|████▌     | 137/303 [09:59<11:32,  4.17s/it]

Average accuracy at batch 136: 0.03775344687753447 (1862/49320).


 46%|████▌     | 138/303 [10:03<11:22,  4.14s/it]

Average accuracy at batch 137: 0.03810386473429952 (1893/49680).


 46%|████▌     | 139/303 [10:07<11:16,  4.12s/it]

Average accuracy at batch 138: 0.03788968824940048 (1896/50040).


 46%|████▌     | 140/303 [10:11<11:12,  4.13s/it]

Average accuracy at batch 139: 0.03761904761904762 (1896/50400).


 47%|████▋     | 141/303 [10:15<11:06,  4.11s/it]

Average accuracy at batch 140: 0.03735224586288416 (1896/50760).


 47%|████▋     | 142/303 [10:20<11:17,  4.21s/it]

Average accuracy at batch 141: 0.03708920187793427 (1896/51120).


 47%|████▋     | 143/303 [10:24<11:21,  4.26s/it]

Average accuracy at batch 142: 0.03682983682983683 (1896/51480).


 48%|████▊     | 144/303 [10:28<11:12,  4.23s/it]

Average accuracy at batch 143: 0.03665123456790124 (1900/51840).


 48%|████▊     | 145/303 [10:33<11:42,  4.44s/it]

Average accuracy at batch 144: 0.03664750957854406 (1913/52200).


 48%|████▊     | 146/303 [10:38<11:35,  4.43s/it]

Average accuracy at batch 145: 0.03656773211567732 (1922/52560).


 49%|████▊     | 147/303 [10:43<12:02,  4.63s/it]

Average accuracy at batch 146: 0.036432350718065006 (1928/52920).


 49%|████▉     | 148/303 [10:47<11:53,  4.61s/it]

Average accuracy at batch 147: 0.03704954954954955 (1974/53280).


 49%|████▉     | 149/303 [10:52<11:51,  4.62s/it]

Average accuracy at batch 148: 0.03747203579418344 (2010/53640).


 50%|████▉     | 150/303 [10:56<11:29,  4.51s/it]

Average accuracy at batch 149: 0.03772222222222222 (2037/54000).


 50%|████▉     | 151/303 [11:01<11:15,  4.45s/it]

Average accuracy at batch 150: 0.03774834437086093 (2052/54360).


 50%|█████     | 152/303 [11:05<11:12,  4.46s/it]

Average accuracy at batch 151: 0.0377375730994152 (2065/54720).


 50%|█████     | 153/303 [11:10<11:11,  4.47s/it]

Average accuracy at batch 152: 0.03749092229484386 (2065/55080).


 51%|█████     | 154/303 [11:14<11:04,  4.46s/it]

Average accuracy at batch 153: 0.037247474747474744 (2065/55440).


 51%|█████     | 155/303 [11:19<11:10,  4.53s/it]

Average accuracy at batch 154: 0.0371505376344086 (2073/55800).


 51%|█████▏    | 156/303 [11:23<10:34,  4.32s/it]

Average accuracy at batch 155: 0.03698361823361823 (2077/56160).


 52%|█████▏    | 157/303 [11:26<10:07,  4.16s/it]

Average accuracy at batch 156: 0.03674805378627035 (2077/56520).


 52%|█████▏    | 158/303 [11:30<09:50,  4.07s/it]

Average accuracy at batch 157: 0.0365154711673699 (2077/56880).


 52%|█████▏    | 159/303 [11:34<09:47,  4.08s/it]

Average accuracy at batch 158: 0.03646051712089448 (2087/57240).


 53%|█████▎    | 160/303 [11:38<09:47,  4.11s/it]

Average accuracy at batch 159: 0.03630208333333333 (2091/57600).


 53%|█████▎    | 161/303 [11:42<09:35,  4.05s/it]

Average accuracy at batch 160: 0.03609385783298827 (2092/57960).


 53%|█████▎    | 162/303 [11:47<09:39,  4.11s/it]

Average accuracy at batch 161: 0.035939643347050756 (2096/58320).


 54%|█████▍    | 163/303 [11:51<09:26,  4.05s/it]

Average accuracy at batch 162: 0.035753237900477165 (2098/58680).


 54%|█████▍    | 164/303 [11:55<09:28,  4.09s/it]

Average accuracy at batch 163: 0.03553523035230352 (2098/59040).


 54%|█████▍    | 165/303 [11:59<09:27,  4.11s/it]

Average accuracy at batch 164: 0.03531986531986532 (2098/59400).


 55%|█████▍    | 166/303 [12:03<09:33,  4.19s/it]

Average accuracy at batch 165: 0.03510709504685408 (2098/59760).


 55%|█████▌    | 167/303 [12:08<09:36,  4.24s/it]

Average accuracy at batch 166: 0.034896872920825016 (2098/60120).


 55%|█████▌    | 168/303 [12:12<09:32,  4.24s/it]

Average accuracy at batch 167: 0.034689153439153436 (2098/60480).


 56%|█████▌    | 169/303 [12:16<09:41,  4.34s/it]

Average accuracy at batch 168: 0.036242603550295856 (2205/60840).


 56%|█████▌    | 170/303 [12:21<09:28,  4.27s/it]

Average accuracy at batch 169: 0.03861111111111111 (2363/61200).


 56%|█████▋    | 171/303 [12:25<09:33,  4.34s/it]

Average accuracy at batch 170: 0.04116309291747888 (2534/61560).


 57%|█████▋    | 172/303 [12:29<09:23,  4.30s/it]

Average accuracy at batch 171: 0.041440568475452196 (2566/61920).


 57%|█████▋    | 173/303 [12:34<09:34,  4.42s/it]

Average accuracy at batch 172: 0.04120102761721259 (2566/62280).


 57%|█████▋    | 174/303 [12:39<09:37,  4.47s/it]

Average accuracy at batch 173: 0.04096424010217114 (2566/62640).


 58%|█████▊    | 175/303 [12:43<09:39,  4.53s/it]

Average accuracy at batch 174: 0.04073015873015873 (2566/63000).


 58%|█████▊    | 176/303 [12:47<09:16,  4.38s/it]

Average accuracy at batch 175: 0.040514520202020204 (2567/63360).


 58%|█████▊    | 177/303 [12:51<09:03,  4.31s/it]

Average accuracy at batch 176: 0.040285624607658506 (2567/63720).


 59%|█████▊    | 178/303 [12:55<08:47,  4.22s/it]

Average accuracy at batch 177: 0.04009051186017478 (2569/64080).


 59%|█████▉    | 179/303 [12:59<08:29,  4.11s/it]

Average accuracy at batch 178: 0.0398820608317815 (2570/64440).


 59%|█████▉    | 180/303 [13:03<08:17,  4.05s/it]

Average accuracy at batch 179: 0.03967592592592593 (2571/64800).


 60%|█████▉    | 181/303 [13:08<08:56,  4.40s/it]

Average accuracy at batch 180: 0.03947206875383671 (2572/65160).


 60%|██████    | 182/303 [13:14<09:20,  4.63s/it]

Average accuracy at batch 181: 0.039255189255189255 (2572/65520).


 60%|██████    | 183/303 [13:19<09:35,  4.79s/it]

Average accuracy at batch 182: 0.03904068002428658 (2572/65880).


 61%|██████    | 184/303 [13:23<09:28,  4.77s/it]

Average accuracy at batch 183: 0.03884359903381643 (2573/66240).


 61%|██████    | 185/303 [13:28<09:00,  4.58s/it]

Average accuracy at batch 184: 0.03866366366366367 (2575/66600).


 61%|██████▏   | 186/303 [13:32<08:48,  4.52s/it]

Average accuracy at batch 185: 0.038485663082437276 (2577/66960).


 62%|██████▏   | 187/303 [13:37<08:45,  4.53s/it]

Average accuracy at batch 186: 0.03827985739750445 (2577/67320).


 62%|██████▏   | 188/303 [13:41<08:40,  4.53s/it]

Average accuracy at batch 187: 0.038076241134751776 (2577/67680).


 62%|██████▏   | 189/303 [13:46<08:34,  4.51s/it]

Average accuracy at batch 188: 0.03787477954144621 (2577/68040).


 63%|██████▎   | 190/303 [13:50<08:24,  4.47s/it]

Average accuracy at batch 189: 0.0377046783625731 (2579/68400).


 63%|██████▎   | 191/303 [13:54<08:18,  4.45s/it]

Average accuracy at batch 190: 0.037507271669575334 (2579/68760).


 63%|██████▎   | 192/303 [13:59<08:13,  4.45s/it]

Average accuracy at batch 191: 0.0373119212962963 (2579/69120).


 64%|██████▎   | 193/303 [14:03<07:51,  4.28s/it]

Average accuracy at batch 192: 0.03711859527921704 (2579/69480).


 64%|██████▍   | 194/303 [14:07<07:38,  4.21s/it]

Average accuracy at batch 193: 0.03692726231386025 (2579/69840).


 64%|██████▍   | 195/303 [14:11<07:45,  4.31s/it]

Average accuracy at batch 194: 0.036737891737891736 (2579/70200).


 65%|██████▍   | 196/303 [14:16<07:47,  4.37s/it]

Average accuracy at batch 195: 0.03655045351473923 (2579/70560).


 65%|██████▌   | 197/303 [14:20<07:55,  4.49s/it]

Average accuracy at batch 196: 0.036364918217710096 (2579/70920).


 65%|██████▌   | 198/303 [14:25<07:45,  4.44s/it]

Average accuracy at batch 197: 0.03633557800224467 (2590/71280).


 66%|██████▌   | 199/303 [14:29<07:44,  4.47s/it]

Average accuracy at batch 198: 0.036306532663316585 (2601/71640).


 66%|██████▌   | 200/303 [14:34<07:47,  4.54s/it]

Average accuracy at batch 199: 0.03625 (2610/72000).


 66%|██████▋   | 201/303 [14:39<07:55,  4.66s/it]

Average accuracy at batch 200: 0.03608347153123272 (2611/72360).


 67%|██████▋   | 202/303 [14:44<07:58,  4.73s/it]

Average accuracy at batch 201: 0.035932343234323434 (2613/72720).


 67%|██████▋   | 203/303 [14:49<07:58,  4.78s/it]

Average accuracy at batch 202: 0.0360016420361248 (2631/73080).


 67%|██████▋   | 204/303 [14:53<07:35,  4.60s/it]

Average accuracy at batch 203: 0.03616557734204793 (2656/73440).


 68%|██████▊   | 205/303 [14:57<07:16,  4.45s/it]

Average accuracy at batch 204: 0.036382113821138214 (2685/73800).


 68%|██████▊   | 206/303 [15:02<07:14,  4.48s/it]

Average accuracy at batch 205: 0.03628640776699029 (2691/74160).


 68%|██████▊   | 207/303 [15:06<07:09,  4.48s/it]

Average accuracy at batch 206: 0.03611111111111111 (2691/74520).


 69%|██████▊   | 208/303 [15:10<06:53,  4.35s/it]

Average accuracy at batch 207: 0.0359375 (2691/74880).


 69%|██████▉   | 209/303 [15:15<06:49,  4.36s/it]

Average accuracy at batch 208: 0.03580542264752791 (2694/75240).


 69%|██████▉   | 210/303 [15:19<06:48,  4.39s/it]

Average accuracy at batch 209: 0.035661375661375665 (2696/75600).


 70%|██████▉   | 211/303 [15:23<06:44,  4.40s/it]

Average accuracy at batch 210: 0.03558451816745656 (2703/75960).


 70%|██████▉   | 212/303 [15:27<06:28,  4.27s/it]

Average accuracy at batch 211: 0.035783542976939205 (2731/76320).


 70%|███████   | 213/303 [15:31<06:18,  4.20s/it]

Average accuracy at batch 212: 0.03592853416797079 (2755/76680).


 71%|███████   | 214/303 [15:36<06:10,  4.17s/it]

Average accuracy at batch 213: 0.03613707165109034 (2784/77040).


 71%|███████   | 215/303 [15:40<06:05,  4.15s/it]

Average accuracy at batch 214: 0.03603359173126615 (2789/77400).


 71%|███████▏  | 216/303 [15:44<06:00,  4.14s/it]

Average accuracy at batch 215: 0.0358667695473251 (2789/77760).


 72%|███████▏  | 217/303 [15:48<06:00,  4.19s/it]

Average accuracy at batch 216: 0.03571428571428571 (2790/78120).


 72%|███████▏  | 218/303 [15:52<05:58,  4.22s/it]

Average accuracy at batch 217: 0.03556320081549439 (2791/78480).


 72%|███████▏  | 219/303 [15:57<05:55,  4.23s/it]

Average accuracy at batch 218: 0.03541349568746829 (2792/78840).


 73%|███████▎  | 220/303 [16:01<05:51,  4.23s/it]

Average accuracy at batch 219: 0.03525252525252525 (2792/79200).


 73%|███████▎  | 221/303 [16:05<05:45,  4.21s/it]

Average accuracy at batch 220: 0.0350930115635998 (2792/79560).


 73%|███████▎  | 222/303 [16:09<05:39,  4.19s/it]

Average accuracy at batch 221: 0.034934934934934936 (2792/79920).


 74%|███████▎  | 223/303 [16:13<05:34,  4.18s/it]

Average accuracy at batch 222: 0.034778276033881415 (2792/80280).


 74%|███████▍  | 224/303 [16:17<05:17,  4.02s/it]

Average accuracy at batch 223: 0.034623015873015875 (2792/80640).


 74%|███████▍  | 225/303 [16:21<05:15,  4.05s/it]

Average accuracy at batch 224: 0.03446913580246914 (2792/81000).


 75%|███████▍  | 226/303 [16:25<05:08,  4.01s/it]

Average accuracy at batch 225: 0.03431661750245821 (2792/81360).


 75%|███████▍  | 227/303 [16:29<05:06,  4.03s/it]

Average accuracy at batch 226: 0.03416544297601566 (2792/81720).


 75%|███████▌  | 228/303 [16:33<04:57,  3.97s/it]

Average accuracy at batch 227: 0.03401559454191033 (2792/82080).


 76%|███████▌  | 229/303 [16:37<04:53,  3.96s/it]

Average accuracy at batch 228: 0.03386705482775352 (2792/82440).


 76%|███████▌  | 230/303 [16:41<04:48,  3.95s/it]

Average accuracy at batch 229: 0.03371980676328502 (2792/82800).


 76%|███████▌  | 231/303 [16:45<04:48,  4.01s/it]

Average accuracy at batch 230: 0.0335978835978836 (2794/83160).


 77%|███████▋  | 232/303 [16:49<04:44,  4.00s/it]

Average accuracy at batch 231: 0.03356082375478927 (2803/83520).


 77%|███████▋  | 233/303 [16:53<04:42,  4.04s/it]

Average accuracy at batch 232: 0.03353600381497377 (2813/83880).


 77%|███████▋  | 234/303 [16:57<04:37,  4.03s/it]

Average accuracy at batch 233: 0.03344017094017094 (2817/84240).


 78%|███████▊  | 235/303 [17:01<04:30,  3.98s/it]

Average accuracy at batch 234: 0.033297872340425534 (2817/84600).


 78%|███████▊  | 236/303 [17:05<04:24,  3.95s/it]

Average accuracy at batch 235: 0.03315677966101695 (2817/84960).


 78%|███████▊  | 237/303 [17:09<04:22,  3.98s/it]

Average accuracy at batch 236: 0.0330168776371308 (2817/85320).


 79%|███████▊  | 238/303 [17:13<04:23,  4.06s/it]

Average accuracy at batch 237: 0.03288982259570495 (2818/85680).


 79%|███████▉  | 239/303 [17:17<04:19,  4.06s/it]

Average accuracy at batch 238: 0.032787075778707576 (2821/86040).


 79%|███████▉  | 240/303 [17:21<04:15,  4.06s/it]

Average accuracy at batch 239: 0.03266203703703704 (2822/86400).


 80%|███████▉  | 241/303 [17:25<04:14,  4.11s/it]

Average accuracy at batch 240: 0.03252650991240203 (2822/86760).


 80%|███████▉  | 242/303 [17:29<04:09,  4.09s/it]

Average accuracy at batch 241: 0.0323921028466483 (2822/87120).


 80%|████████  | 243/303 [17:33<04:01,  4.03s/it]

Average accuracy at batch 242: 0.03225880201188843 (2822/87480).


 81%|████████  | 244/303 [17:37<03:56,  4.01s/it]

Average accuracy at batch 243: 0.03212659380692168 (2822/87840).


 81%|████████  | 245/303 [17:41<03:50,  3.97s/it]

Average accuracy at batch 244: 0.03199546485260771 (2822/88200).


 81%|████████  | 246/303 [17:45<03:48,  4.02s/it]

Average accuracy at batch 245: 0.03186540198735321 (2822/88560).


 82%|████████▏ | 247/303 [17:49<03:44,  4.01s/it]

Average accuracy at batch 246: 0.03173639226270805 (2822/88920).


 82%|████████▏ | 248/303 [17:53<03:42,  4.05s/it]

Average accuracy at batch 247: 0.0316084229390681 (2822/89280).


 82%|████████▏ | 249/303 [17:58<03:40,  4.08s/it]

Average accuracy at batch 248: 0.03148148148148148 (2822/89640).


 83%|████████▎ | 250/303 [18:02<03:36,  4.08s/it]

Average accuracy at batch 249: 0.03135555555555555 (2822/90000).


 83%|████████▎ | 251/303 [18:06<03:31,  4.07s/it]

Average accuracy at batch 250: 0.03123063302346171 (2822/90360).


 83%|████████▎ | 252/303 [18:10<03:28,  4.09s/it]

Average accuracy at batch 251: 0.031106701940035272 (2822/90720).


 83%|████████▎ | 253/303 [18:14<03:23,  4.08s/it]

Average accuracy at batch 252: 0.03098375054896794 (2822/91080).


 84%|████████▍ | 254/303 [18:18<03:20,  4.10s/it]

Average accuracy at batch 253: 0.03088363954505687 (2824/91440).


 84%|████████▍ | 255/303 [18:22<03:20,  4.18s/it]

Average accuracy at batch 254: 0.030773420479302833 (2825/91800).


 84%|████████▍ | 256/303 [18:27<03:18,  4.22s/it]

Average accuracy at batch 255: 0.030696614583333334 (2829/92160).


 85%|████████▍ | 257/303 [18:31<03:14,  4.23s/it]

Average accuracy at batch 256: 0.030685257241677474 (2839/92520).


 85%|████████▌ | 258/303 [18:36<03:14,  4.33s/it]

Average accuracy at batch 257: 0.030684754521963824 (2850/92880).


 85%|████████▌ | 259/303 [18:40<03:07,  4.27s/it]

Average accuracy at batch 258: 0.03065208065208065 (2858/93240).


 86%|████████▌ | 260/303 [18:44<02:58,  4.15s/it]

Average accuracy at batch 259: 0.030544871794871795 (2859/93600).


 86%|████████▌ | 261/303 [18:48<02:51,  4.09s/it]

Average accuracy at batch 260: 0.030427841634738186 (2859/93960).


 86%|████████▋ | 262/303 [18:52<02:47,  4.09s/it]

Average accuracy at batch 261: 0.030311704834605598 (2859/94320).


 87%|████████▋ | 263/303 [18:56<02:46,  4.15s/it]

Average accuracy at batch 262: 0.030196451204055768 (2859/94680).


 87%|████████▋ | 264/303 [19:00<02:43,  4.20s/it]

Average accuracy at batch 263: 0.030082070707070706 (2859/95040).


 87%|████████▋ | 265/303 [19:04<02:39,  4.19s/it]

Average accuracy at batch 264: 0.029968553459119498 (2859/95400).


 88%|████████▊ | 266/303 [19:08<02:30,  4.05s/it]

Average accuracy at batch 265: 0.029866332497911444 (2860/95760).


 88%|████████▊ | 267/303 [19:12<02:26,  4.08s/it]

Average accuracy at batch 266: 0.029754473574698292 (2860/96120).


 88%|████████▊ | 268/303 [19:17<02:24,  4.14s/it]

Average accuracy at batch 267: 0.029643449419568823 (2860/96480).


 89%|████████▉ | 269/303 [19:21<02:24,  4.26s/it]

Average accuracy at batch 268: 0.0295332507228418 (2860/96840).


 89%|████████▉ | 270/303 [19:25<02:20,  4.27s/it]

Average accuracy at batch 269: 0.0294238683127572 (2860/97200).


 89%|████████▉ | 271/303 [19:29<02:14,  4.19s/it]

Average accuracy at batch 270: 0.02933579335793358 (2862/97560).


 90%|████████▉ | 272/303 [19:34<02:11,  4.24s/it]

Average accuracy at batch 271: 0.02922794117647059 (2862/97920).


 90%|█████████ | 273/303 [19:38<02:04,  4.14s/it]

Average accuracy at batch 272: 0.029334554334554334 (2883/98280).


 90%|█████████ | 274/303 [19:41<01:56,  4.02s/it]

Average accuracy at batch 273: 0.029491078669910787 (2909/98640).


 91%|█████████ | 275/303 [19:45<01:50,  3.95s/it]

Average accuracy at batch 274: 0.029595959595959596 (2930/99000).


 91%|█████████ | 276/303 [19:50<01:50,  4.09s/it]

Average accuracy at batch 275: 0.02951892109500805 (2933/99360).


 91%|█████████▏| 277/303 [19:54<01:47,  4.14s/it]

Average accuracy at batch 276: 0.02941235459286001 (2933/99720).


 92%|█████████▏| 278/303 [19:58<01:45,  4.21s/it]

Average accuracy at batch 277: 0.029306554756195043 (2933/100080).


 92%|█████████▏| 279/303 [20:02<01:40,  4.17s/it]

Average accuracy at batch 278: 0.029201513341298287 (2933/100440).


 92%|█████████▏| 280/303 [20:07<01:37,  4.22s/it]

Average accuracy at batch 279: 0.029097222222222222 (2933/100800).


 93%|█████████▎| 281/303 [20:11<01:32,  4.20s/it]

Average accuracy at batch 280: 0.02899367338869118 (2933/101160).


 93%|█████████▎| 282/303 [20:15<01:28,  4.20s/it]

Average accuracy at batch 281: 0.028890858944050433 (2933/101520).


 93%|█████████▎| 283/303 [20:19<01:23,  4.17s/it]

Average accuracy at batch 282: 0.028788771103258736 (2933/101880).


 94%|█████████▎| 284/303 [20:23<01:18,  4.14s/it]

Average accuracy at batch 283: 0.028687402190923316 (2933/102240).


 94%|█████████▍| 285/303 [20:28<01:16,  4.28s/it]

Average accuracy at batch 284: 0.02887914230019493 (2963/102600).


 94%|█████████▍| 286/303 [20:33<01:15,  4.45s/it]

Average accuracy at batch 285: 0.029137529137529136 (3000/102960).


 95%|█████████▍| 287/303 [20:37<01:11,  4.49s/it]

Average accuracy at batch 286: 0.029510259388308167 (3049/103320).


 95%|█████████▌| 288/303 [20:42<01:07,  4.50s/it]

Average accuracy at batch 287: 0.029427083333333333 (3051/103680).


 95%|█████████▌| 289/303 [20:46<01:02,  4.48s/it]

Average accuracy at batch 288: 0.029334871203383316 (3052/104040).


 96%|█████████▌| 290/303 [20:50<00:57,  4.42s/it]

Average accuracy at batch 289: 0.029233716475095785 (3052/104400).


 96%|█████████▌| 291/303 [20:54<00:51,  4.31s/it]

Average accuracy at batch 290: 0.029133256968308513 (3052/104760).


 96%|█████████▋| 292/303 [20:58<00:45,  4.17s/it]

Average accuracy at batch 291: 0.029033485540334854 (3052/105120).


 97%|█████████▋| 293/303 [21:02<00:40,  4.08s/it]

Average accuracy at batch 292: 0.028934395145999242 (3052/105480).


 97%|█████████▋| 294/303 [21:06<00:35,  3.92s/it]

Average accuracy at batch 293: 0.028835978835978836 (3052/105840).


 97%|█████████▋| 295/303 [21:09<00:30,  3.80s/it]

Average accuracy at batch 294: 0.028738229755178906 (3052/106200).


 98%|█████████▊| 296/303 [21:13<00:26,  3.83s/it]

Average accuracy at batch 295: 0.02865990990990991 (3054/106560).


 98%|█████████▊| 297/303 [21:17<00:23,  3.88s/it]

Average accuracy at batch 296: 0.02860082304526749 (3058/106920).


 98%|█████████▊| 298/303 [21:21<00:19,  3.93s/it]

Average accuracy at batch 297: 0.02855145413870246 (3063/107280).


 99%|█████████▊| 299/303 [21:25<00:15,  3.98s/it]

Average accuracy at batch 298: 0.028455964325529545 (3063/107640).


 99%|█████████▉| 300/303 [21:29<00:12,  4.04s/it]

Average accuracy at batch 299: 0.02837037037037037 (3064/108000).


 99%|█████████▉| 301/303 [21:33<00:08,  4.01s/it]

Average accuracy at batch 300: 0.02827611664820967 (3064/108360).


100%|█████████▉| 302/303 [21:37<00:03,  4.00s/it]

Average accuracy at batch 301: 0.028182487122884473 (3064/108720).


100%|██████████| 303/303 [21:41<00:00,  4.30s/it]

Average accuracy at batch 302: 0.028095914905323004 (3064/109055).
109055



Map: 100%|██████████| 109055/109055 [00:00<00:00, 1595530.32 examples/s]


In [12]:
# Convert the evaluation results into a pandas DataFrame so they can be filtered and inspected.
# NOTE(review): `results` is created in an earlier cell; presumably a Hugging Face `Dataset` — confirm.
results_df = results.to_pandas()

In [13]:
# Keep only the rows whose prediction was marked correct, then display them.
answered_correctly = results_df["is_correct"].eq(True)
correct_results_df = results_df.loc[answered_correctly]
correct_results_df

Unnamed: 0,text,labels,entity,query_id,query_type,predictions,is_correct
42,'Eason Chan filmography' is about,Eason Chan,Eason Chan filmography,http://schema.org/about,open,"Eason Chan, a Hong Kong singer, songwriter, an...",True
105,'Norma Talmadge filmography' is about,Norma Talmadge,Norma Talmadge filmography,http://schema.org/about,open,"Norma Talmadge, an American actress who was ac...",True
196,'Faye Dunaway filmography' is about,Faye Dunaway,Faye Dunaway filmography,http://schema.org/about,open,"Faye Dunaway, an American actress who was born...",True
481,'Paul Anka filmography' is about,Paul Anka,Paul Anka filmography,http://schema.org/about,open,"Paul Anka, a Canadian singer, songwriter, and ...",True
585,'Donald Duck filmography' is about,Donald Duck,Donald Duck filmography,http://schema.org/about,open,"Donald Duck's appearances in various films, TV...",True
...,...,...,...,...,...,...,...
107070,Nykøbing-Rørvig Municipality was replaced by,Odsherred Municipality,Nykøbing-Rørvig Municipality,reverse-http://yago-knowledge.org/resource/rep...,open,"Odsherred Municipality on January 1, 2007. Unt...",True
107114,Öland County was replaced by,Kalmar County,Öland County,reverse-http://yago-knowledge.org/resource/rep...,open,Kalmar County in 1998. Unterscheidung between ...,True
107256,Newtownabbey Borough Council was replaced by,Antrim and Newtownabbey Borough Council,Newtownabbey Borough Council,reverse-http://yago-knowledge.org/resource/rep...,open,Antrim and Newtownabbey Borough Council in 201...,True
107268,Ceredigion and Pembroke North was replaced by,Ceredigion,Ceredigion and Pembroke North,reverse-http://yago-knowledge.org/resource/rep...,open,Ceredigion and Pembrokeshire in 1996. Untersch...,True


In [14]:
# Count the correctly answered rows per query id (most frequent first).
correct_results_df["query_id"].value_counts()

query_id
http://yago-knowledge.org/resource/capital               457
http://schema.org/numberOfSeasons                        322
http://schema.org/officialLanguage                       202
http://schema.org/manufacturer                           200
http://schema.org/nationality                            172
                                                        ... 
http://schema.org/numberOfEmployees                        2
http://schema.org/owns                                     1
reverse-http://yago-knowledge.org/resource/parentBody      1
reverse-http://schema.org/sponsor                          1
reverse-http://yago-knowledge.org/resource/studentOf       1
Name: count, Length: 66, dtype: int64

In [15]:
# Print one example of a correctly answered row for every query id as a
# qualitative sanity check.
# Grouping once avoids re-scanning the whole frame for each query id;
# sort=False keeps the groups in order of first appearance, which is the same
# order `.unique()` would produce.
rows_by_query_id = correct_results_df.groupby("query_id", sort=False)
print("Num query ids:", rows_by_query_id.ngroups)
for qid, qid_rows in rows_by_query_id:
    print(qid)
    print(qid_rows.iloc[0])  # first correct row for this query id
    print("\n\n")

Num query ids: 66
http://schema.org/about
text                           'Eason Chan filmography' is about
labels                                                Eason Chan
entity                                    Eason Chan filmography
query_id                                 http://schema.org/about
query_type                                                  open
predictions    Eason Chan, a Hong Kong singer, songwriter, an...
is_correct                                                  True
Name: 42, dtype: object



http://schema.org/address
text                          Dallas Museum of Art is located at
labels                                 1717 North Harwood Street
entity                                      Dallas Museum of Art
query_id                               http://schema.org/address
query_type                                                  open
predictions    1717 North Harwood Street, Dallas, TX 75201. e...
is_correct                                                  

In [16]:
# Re-display the correctly answered rows (the same frame built from `results_df` above).
correct_results_df

Unnamed: 0,text,labels,entity,query_id,query_type,predictions,is_correct
42,'Eason Chan filmography' is about,Eason Chan,Eason Chan filmography,http://schema.org/about,open,"Eason Chan, a Hong Kong singer, songwriter, an...",True
105,'Norma Talmadge filmography' is about,Norma Talmadge,Norma Talmadge filmography,http://schema.org/about,open,"Norma Talmadge, an American actress who was ac...",True
196,'Faye Dunaway filmography' is about,Faye Dunaway,Faye Dunaway filmography,http://schema.org/about,open,"Faye Dunaway, an American actress who was born...",True
481,'Paul Anka filmography' is about,Paul Anka,Paul Anka filmography,http://schema.org/about,open,"Paul Anka, a Canadian singer, songwriter, and ...",True
585,'Donald Duck filmography' is about,Donald Duck,Donald Duck filmography,http://schema.org/about,open,"Donald Duck's appearances in various films, TV...",True
...,...,...,...,...,...,...,...
107070,Nykøbing-Rørvig Municipality was replaced by,Odsherred Municipality,Nykøbing-Rørvig Municipality,reverse-http://yago-knowledge.org/resource/rep...,open,"Odsherred Municipality on January 1, 2007. Unt...",True
107114,Öland County was replaced by,Kalmar County,Öland County,reverse-http://yago-knowledge.org/resource/rep...,open,Kalmar County in 1998. Unterscheidung between ...,True
107256,Newtownabbey Borough Council was replaced by,Antrim and Newtownabbey Borough Council,Newtownabbey Borough Council,reverse-http://yago-knowledge.org/resource/rep...,open,Antrim and Newtownabbey Borough Council in 201...,True
107268,Ceredigion and Pembroke North was replaced by,Ceredigion,Ceredigion and Pembroke North,reverse-http://yago-knowledge.org/resource/rep...,open,Ceredigion and Pembrokeshire in 1996. Untersch...,True


In [19]:
# Load the raw data from RAW_DATA_PATH; the cells below filter it into `eligible_yago_qec`.
# NOTE(review): an earlier cell loads the same path into `dataset` — confirm whether a second load is needed.
yago_qec = load_dataset_from_path(RAW_DATA_PATH)

In [36]:
# Collect, per query id, the entities that were answered correctly.
qid_to_entity_df = (
    correct_results_df
    .groupby("query_id")["entity"]
    .apply(list)
    .reset_index()
)
# Map each query id to the set of its correctly answered entities (set => fast membership tests).
qid_to_valid_entities = {
    query_id: set(entity_list)
    for query_id, entity_list in zip(qid_to_entity_df["query_id"], qid_to_entity_df["entity"])
}

In [29]:
# Show how many items each field of a single query's record holds.
about_record = yago_qec['http://schema.org/about']
{field: len(values) for field, values in about_record.items()}

{'answer_types': 6,
 'answer_uris': 954,
 'answers': 954,
 'context_templates': 6,
 'entities': 954,
 'entity_namesake_to_degree': 954,
 'entity_namesake_to_num_uris': 954,
 'entity_types': 3,
 'entity_uri_to_degree': 954,
 'entity_uri_to_predicate_degree': 954,
 'entity_uris': 954,
 'gpt_fake_entities': 954,
 'query_forms': 2}

In [None]:
# Overview of the raw data: number of candidate entities per query id.
# The previous draft of this cell wrapped each `qec` dict in a set literal
# (`{qec}`), which raises `TypeError: unhashable type: 'dict'` and would stop
# a Restart & Run All.
{qid: len(qec["entities"]) for qid, qec in yago_qec.items()}

In [39]:
# Restrict each query's record to the entities listed in `qid_to_valid_entities`.
eligible_yago_qec = dict()
for qid, qec in yago_qec.items():
    if qid not in qid_to_valid_entities:
        # No valid entities were recorded for this query id, so it is left out entirely.
        continue

    valid_entities = qid_to_valid_entities[qid]
    num_entities = len(qec["entities"])
    eligible_inds = [i for i, e in enumerate(qec["entities"]) if e in valid_entities]
    # Set copy for O(1) membership tests; the list version made each per-field
    # filter below quadratic in the number of entities.
    eligible_ind_set = set(eligible_inds)

    # Filter every field that is parallel with `entities` (e.g. answers, etc.).
    # NOTE(review): "parallel" is detected purely by equal length, so a
    # non-parallel field whose length happens to equal the number of entities
    # would be filtered too — confirm this cannot happen in the data.
    eligible_qec = {
        k: [x for i, x in enumerate(v) if i in eligible_ind_set]
        for k, v in qec.items()
        if len(v) == num_entities
    }
    # These two fields are per-query rather than per-entity, so copy them unfiltered.
    eligible_qec["context_templates"] = qec["context_templates"]
    eligible_qec["query_forms"] = qec["query_forms"]
    eligible_yago_qec[qid] = eligible_qec

In [41]:
# Spot-check the filtered record for a single query id.
# (The former leading `list(eligible_yago_qec.keys())` was dropped: only a
# cell's last expression is displayed, so its result was silently discarded.)
eligible_yago_qec['http://schema.org/about']

{'answer_uris': ['http://yago-knowledge.org/resource/Morgan_Freeman',
  'http://yago-knowledge.org/resource/Gardening_generic_instance',
  'http://yago-knowledge.org/resource/Suriya',
  'http://yago-knowledge.org/resource/Nathan_Lane',
  'http://yago-knowledge.org/resource/Shilpa_Shetty',
  'http://yago-knowledge.org/resource/Norma_Talmadge',
  'http://yago-knowledge.org/resource/Faye_Dunaway',
  'http://yago-knowledge.org/resource/Donald_Duck',
  'http://yago-knowledge.org/resource/Eason_Chan',
  'http://yago-knowledge.org/resource/Paul_Anka'],
 'answers': ['Morgan Freeman',
  'gardening',
  'Suriya',
  'Nathan Lane',
  'Shilpa Shetty',
  'Norma Talmadge',
  'Faye Dunaway',
  'Donald Duck',
  'Eason Chan',
  'Paul Anka'],
 'entities': ['Morgan Freeman filmography',
  'Gardening for the Million',
  'Suriya filmography',
  'Nathan Lane on screen and stage',
  'Shilpa Shetty filmography',
  'Norma Talmadge filmography',
  'Faye Dunaway filmography',
  'Donald Duck filmography',
  'Eason 

In [42]:
# Cross-check: list the correctly answered rows for the same query id inspected above.
is_about_query = correct_results_df["query_id"].eq('http://schema.org/about')
correct_results_df.loc[is_about_query]

Unnamed: 0,text,labels,entity,query_id,query_type,predictions,is_correct
42,'Eason Chan filmography' is about,Eason Chan,Eason Chan filmography,http://schema.org/about,open,"Eason Chan, a Hong Kong singer, songwriter, an...",True
105,'Norma Talmadge filmography' is about,Norma Talmadge,Norma Talmadge filmography,http://schema.org/about,open,"Norma Talmadge, an American actress who was ac...",True
196,'Faye Dunaway filmography' is about,Faye Dunaway,Faye Dunaway filmography,http://schema.org/about,open,"Faye Dunaway, an American actress who was born...",True
481,'Paul Anka filmography' is about,Paul Anka,Paul Anka filmography,http://schema.org/about,open,"Paul Anka, a Canadian singer, songwriter, and ...",True
585,'Donald Duck filmography' is about,Donald Duck,Donald Duck filmography,http://schema.org/about,open,"Donald Duck's appearances in various films, TV...",True
686,'Gardening for the Million' is about,gardening,Gardening for the Million,http://schema.org/about,open,"gardening for the masses, not just the wealthy...",True
762,'Morgan Freeman filmography' is about,Morgan Freeman,Morgan Freeman filmography,http://schema.org/about,open,"Morgan Freeman, an American actor, producer, a...",True
822,'Shilpa Shetty filmography' is about,Shilpa Shetty,Shilpa Shetty filmography,http://schema.org/about,open,"Shilpa Shetty's filmography, including her mov...",True
946,'Nathan Lane on screen and stage' is about,Nathan Lane,Nathan Lane on screen and stage,http://schema.org/about,open,"Nathan Lane's career as an actor, including hi...",True
948,'Suriya filmography' is about,Suriya,Suriya filmography,http://schema.org/about,open,"Suriya's filmography, which includes his movie...",True


In [45]:
import json

# Persist the filtered data to FILTERED_RAW_DATA_PATH.
# Serialize to a string first: opening the file in "w" mode truncates it, so
# streaming with `json.dump` would leave a truncated/corrupt file behind if
# serialization failed part-way (e.g. on a non-JSON-serializable value).
filtered_json = json.dumps(eligible_yago_qec, ensure_ascii=False, indent=4, sort_keys=True)
with open(FILTERED_RAW_DATA_PATH, "w", encoding="utf-8") as fp:
    fp.write(filtered_json)