In [1]:
from datasets import Dataset
from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForCausalLM
import torch
import json

from ni import HFNIEstimator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def read_jsonl(file_path):
    def gen():
        with open(file_path) as f:
            for data in map(json.loads, f):
                yield data
                
    return gen

In [15]:
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m", 
                                          padding_side="left")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

if tokenizer.eos_token is None and tokenizer.pad_token is None:
    raise RuntimeError("No avaialble padding token")

dataset = Dataset.from_generator(read_jsonl("../../KNOWLAM/clean_datasets/politics__speeches_clean.jsonl"))
#dataset = Dataset.from_generator(read_jsonl("../../KNOWLAM/clean_datasets/math__openai_grade_school_clean.jsonl"))
print(len(dataset))
#dataset = Dataset.from_generator(read_jsonl("../../KNOWLAM/clean_datasets/math__openai_grade_school_clean.jsonl"))
#dataset = dataset.map(lambda sample: tokenizer(sample["text"]),
#                      remove_columns=["text"])

1052


In [18]:
from dataclasses import dataclass
from typing import List

@dataclass
class LMInput:
    input_ids: List[int]
    attention_mask: List[int]
    
    def __len__(self):
        return len(self.input_ids)

class MovingWindow:
    "https://stackoverflow.com/questions/64118654/best-way-to-implement-moving-window-in-python-for-loop"
    def __init__(self, tokens, window_size, step):
        self.current = -step
        self.last = len(tokens.input_ids) - window_size + 1
        self.remaining = (len(tokens.input_ids) - window_size) % step
        self.tokens = tokens
        self.window_size = window_size
        self.step = step

    def __iter__(self):
        return self

    def __next__(self):
        self.current += self.step
        if self.current < self.last:
            return LMInput(input_ids=self.tokens.input_ids[self.current : self.current + self.window_size],
                           attention_mask=self.tokens.attention_mask[self.current : self.current + self.window_size])
        elif self.remaining:
            self.remaining = 0
            return LMInput(input_ids=self.tokens.input_ids[-self.window_size:],
                           attention_mask=self.tokens.attention_mask[-self.window_size:])
        else:
            raise StopIteration
        
def sliding_window(sample):
    samples = {"input_ids": [], "attention_mask": [], "id":[]}
    
    #print(sample.keys())
    
    for i in range(len(sample["id"])):
        for j, s_sample in enumerate(MovingWindow(tokenizer(sample["text"][i]), 2048, 1024)):
            samples["input_ids"].append(s_sample.input_ids)
            samples["attention_mask"].append(s_sample.attention_mask)
            _id = sample["id"][i]
            samples["id"].append(f"{_id}_{j}")
            #samples["text"].append(sample["text"][i])

    #print(samples)
    return samples
        


In [19]:
dataset = dataset.map(sliding_window, batched=True, batch_size=8, remove_columns=["text"])

Map: 100%|██████████| 1052/1052 [00:23<00:00, 45.52 examples/s]


In [23]:
#dataset[0]["input_ids"]

In [33]:
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union


@dataclass
class ConvertToTensor:

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        
        samples = {"input_ids": [], "attention_mask": [], "id":[]}
        
        for feature in features:
            samples["id"].append(feature["id"])
            samples["input_ids"].append(feature["input_ids"])
            samples["attention_mask"].append(feature["attention_mask"])
        
        samples["input_ids"] = torch.as_tensor(samples["input_ids"])
        samples["attention_mask"] = torch.as_tensor(samples["attention_mask"])
        
        return samples


dl = torch.utils.data.DataLoader(dataset,
                                         batch_size=1, 
                                         collate_fn=ConvertToTensor(),
                                         pin_memory=True)

In [34]:
context_tokens = 68

for b_sample in dl:
    
    with torch.no_grad():
       
        b_id = b_sample.pop("id")
        logits = model(**b_sample).logits[:,context_tokens:-1,:] # skip last
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
            
        target_ids = b_sample.input_ids[:,context_tokens+1:, None].long() # skip first + context  
        log_target_probs = torch.gather(log_probs, -1, target_ids).squeeze(-1)
        
    break

<class 'dict'>


AttributeError: 'dict' object has no attribute 'input_ids'

In [26]:
b_sample

{'input_ids': [tensor([[10248,  6180,    13,  ...,  7030,   340,    13]])],
 'attention_mask': [tensor([[1, 1, 1,  ..., 1, 1, 1]])]}

In [None]:
estimator = HFNIEstimator("EleutherAI/gpt-neo-125m")

In [5]:
#for r in estimator.ni_from_generator(read_jsonl("../../clean_datasets/math__openai_grade_school_clean.jsonl")):
#    print(r)

Map: 100%|██████████| 8792/8792 [00:06<00:00, 1458.25 examples/s]
  0%|          | 3/8792 [00:00<27:50,  5.26it/s]  

{'id': tensor(0), 'information': tensor(279.0507), 'seq_len': 83}
{'id': tensor(1), 'information': tensor(288.3089), 'seq_len': 81}
{'id': tensor(2), 'information': tensor(414.1760), 'seq_len': 134}
{'id': tensor(3), 'information': tensor(439.0298), 'seq_len': 151}


  0%|          | 7/8792 [00:00<13:38, 10.74it/s]

{'id': tensor(4), 'information': tensor(288.2872), 'seq_len': 91}
{'id': tensor(5), 'information': tensor(513.2334), 'seq_len': 189}
{'id': tensor(6), 'information': tensor(313.2731), 'seq_len': 119}
{'id': tensor(7), 'information': tensor(518.6471), 'seq_len': 214}


  0%|          | 11/8792 [00:01<11:52, 12.33it/s]

{'id': tensor(8), 'information': tensor(652.6773), 'seq_len': 198}
{'id': tensor(9), 'information': tensor(874.0701), 'seq_len': 346}
{'id': tensor(10), 'information': tensor(592.0710), 'seq_len': 208}


  0%|          | 16/8792 [00:01<09:04, 16.13it/s]

{'id': tensor(11), 'information': tensor(625.5753), 'seq_len': 207}
{'id': tensor(12), 'information': tensor(363.7970), 'seq_len': 107}
{'id': tensor(13), 'information': tensor(382.7283), 'seq_len': 142}
{'id': tensor(14), 'information': tensor(274.9824), 'seq_len': 87}
{'id': tensor(15), 'information': tensor(584.6282), 'seq_len': 203}


  0%|          | 20/8792 [00:01<08:44, 16.71it/s]

{'id': tensor(16), 'information': tensor(503.1963), 'seq_len': 137}
{'id': tensor(17), 'information': tensor(664.3574), 'seq_len': 271}
{'id': tensor(18), 'information': tensor(384.5544), 'seq_len': 197}
{'id': tensor(19), 'information': tensor(444.6639), 'seq_len': 144}


  0%|          | 25/8792 [00:01<07:42, 18.94it/s]

{'id': tensor(20), 'information': tensor(407.0764), 'seq_len': 136}
{'id': tensor(21), 'information': tensor(386.0840), 'seq_len': 154}
{'id': tensor(22), 'information': tensor(375.9242), 'seq_len': 157}
{'id': tensor(23), 'information': tensor(523.7769), 'seq_len': 173}
{'id': tensor(24), 'information': tensor(440.4074), 'seq_len': 127}


  0%|          | 27/8792 [00:02<07:46, 18.77it/s]

{'id': tensor(25), 'information': tensor(506.9406), 'seq_len': 204}
{'id': tensor(26), 'information': tensor(496.5264), 'seq_len': 133}
{'id': tensor(27), 'information': tensor(694.4652), 'seq_len': 233}
{'id': tensor(28), 'information': tensor(312.5536), 'seq_len': 107}


  0%|          | 32/8792 [00:02<07:39, 19.08it/s]

{'id': tensor(29), 'information': tensor(546.5273), 'seq_len': 218}
{'id': tensor(30), 'information': tensor(481.1530), 'seq_len': 173}
{'id': tensor(31), 'information': tensor(474.1942), 'seq_len': 168}
{'id': tensor(32), 'information': tensor(493.3712), 'seq_len': 181}


  0%|          | 37/8792 [00:02<06:50, 21.35it/s]

{'id': tensor(33), 'information': tensor(430.3057), 'seq_len': 178}
{'id': tensor(34), 'information': tensor(310.5411), 'seq_len': 100}
{'id': tensor(35), 'information': tensor(352.9236), 'seq_len': 106}
{'id': tensor(36), 'information': tensor(310.5780), 'seq_len': 116}
{'id': tensor(37), 'information': tensor(505.1980), 'seq_len': 241}


  0%|          | 40/8792 [00:02<06:59, 20.84it/s]

{'id': tensor(38), 'information': tensor(395.6297), 'seq_len': 111}
{'id': tensor(39), 'information': tensor(483.4924), 'seq_len': 162}
{'id': tensor(40), 'information': tensor(580.9738), 'seq_len': 195}
{'id': tensor(41), 'information': tensor(410.3804), 'seq_len': 146}


  1%|          | 46/8792 [00:02<07:15, 20.06it/s]

{'id': tensor(42), 'information': tensor(532.9274), 'seq_len': 171}
{'id': tensor(43), 'information': tensor(380.1443), 'seq_len': 155}
{'id': tensor(44), 'information': tensor(410.1493), 'seq_len': 158}
{'id': tensor(45), 'information': tensor(450.1256), 'seq_len': 192}
{'id': tensor(46), 'information': tensor(458.4406), 'seq_len': 181}


  1%|          | 52/8792 [00:03<06:49, 21.33it/s]

{'id': tensor(47), 'information': tensor(357.3480), 'seq_len': 129}
{'id': tensor(48), 'information': tensor(291.4991), 'seq_len': 85}
{'id': tensor(49), 'information': tensor(330.7055), 'seq_len': 90}
{'id': tensor(50), 'information': tensor(524.4048), 'seq_len': 160}
{'id': tensor(51), 'information': tensor(472.6688), 'seq_len': 187}


  1%|          | 55/8792 [00:03<06:45, 21.54it/s]

{'id': tensor(52), 'information': tensor(379.1021), 'seq_len': 125}
{'id': tensor(53), 'information': tensor(322.9339), 'seq_len': 130}
{'id': tensor(54), 'information': tensor(547.3682), 'seq_len': 178}
{'id': tensor(55), 'information': tensor(292.4805), 'seq_len': 105}
{'id': tensor(56), 'information': tensor(371.2491), 'seq_len': 165}


  1%|          | 61/8792 [00:03<07:08, 20.37it/s]

{'id': tensor(57), 'information': tensor(416.1475), 'seq_len': 211}
{'id': tensor(58), 'information': tensor(583.3618), 'seq_len': 226}
{'id': tensor(59), 'information': tensor(487.8830), 'seq_len': 169}
{'id': tensor(60), 'information': tensor(487.3920), 'seq_len': 188}


  1%|          | 64/8792 [00:03<07:09, 20.32it/s]

{'id': tensor(61), 'information': tensor(618.9885), 'seq_len': 211}
{'id': tensor(62), 'information': tensor(386.2302), 'seq_len': 122}
{'id': tensor(63), 'information': tensor(398.8088), 'seq_len': 159}
{'id': tensor(64), 'information': tensor(510.1176), 'seq_len': 168}
{'id': tensor(65), 'information': tensor(399.2491), 'seq_len': 101}


  1%|          | 70/8792 [00:04<07:21, 19.75it/s]

{'id': tensor(66), 'information': tensor(687.6342), 'seq_len': 243}
{'id': tensor(67), 'information': tensor(682.3404), 'seq_len': 274}
{'id': tensor(68), 'information': tensor(548.3570), 'seq_len': 189}
{'id': tensor(69), 'information': tensor(345.1340), 'seq_len': 98}
{'id': tensor(70), 'information': tensor(250.2471), 'seq_len': 69}


  1%|          | 73/8792 [00:04<07:17, 19.95it/s]

{'id': tensor(71), 'information': tensor(524.5659), 'seq_len': 149}
{'id': tensor(72), 'information': tensor(608.3864), 'seq_len': 226}
{'id': tensor(73), 'information': tensor(439.6553), 'seq_len': 135}
{'id': tensor(74), 'information': tensor(326.5710), 'seq_len': 116}


  1%|          | 79/8792 [00:04<06:42, 21.67it/s]

{'id': tensor(75), 'information': tensor(367.9556), 'seq_len': 178}
{'id': tensor(76), 'information': tensor(368.6167), 'seq_len': 103}
{'id': tensor(77), 'information': tensor(308.9900), 'seq_len': 91}
{'id': tensor(78), 'information': tensor(435.8450), 'seq_len': 125}
{'id': tensor(79), 'information': tensor(279.6101), 'seq_len': 69}
{'id': tensor(80), 'information': tensor(466.7167), 'seq_len': 146}


  1%|          | 85/8792 [00:04<06:53, 21.03it/s]

{'id': tensor(81), 'information': tensor(526.5134), 'seq_len': 231}
{'id': tensor(82), 'information': tensor(433.8779), 'seq_len': 132}
{'id': tensor(83), 'information': tensor(558.9106), 'seq_len': 210}
{'id': tensor(84), 'information': tensor(408.6118), 'seq_len': 111}
{'id': tensor(85), 'information': tensor(341.5211), 'seq_len': 109}


  1%|          | 90/8792 [00:05<08:10, 17.76it/s]

{'id': tensor(86), 'information': tensor(306.8661), 'seq_len': 93}
{'id': tensor(87), 'information': tensor(390.4664), 'seq_len': 110}
{'id': tensor(88), 'information': tensor(283.0951), 'seq_len': 77}
{'id': tensor(89), 'information': tensor(499.7301), 'seq_len': 158}





KeyboardInterrupt: 

In [81]:
{v:k for k,v in tokenizer.vocab.items()}[262]

'Ġthe'