In [1]:
import numpy as np
import ray
from tokenizers import Tokenizer
from functools import partial

In [2]:
dataset = ray.data.read_numpy('data/train', parallelism=1000)


In [3]:

class TokenToProbsProcessor:

    def __init__(self, rng, concentration: float, vocab_size: int):
        self.rng = rng
        self.concentration = concentration
        self.vocab_size = vocab_size

    def __call__(self, tokens):
        batch, seq_len, = tokens.shape

        def _tokens_to_probs(token_ids):
            x = self.rng.random((seq_len, self.vocab_size)) / self.vocab_size
            # At this point E(x.sum()) == 0.5 
            # What we want is for new_val / (x.sum() + new_val) ~ concentration
            # --> new_val == (concentration * x.sum())/(1 - concentration)
            # Then, in the normalized vector, the appropriate token will have ~ concentration weight,
            # and the others will have the rest
            x_sum = x.sum(axis=1)
            conc_val = np.mean((self.concentration * x_sum) / (1 - self.concentration))
            np.put_along_axis(x, token_ids[:, None], conc_val, axis=1)
            return x / x.sum(axis=1)[:, None]
            
        return np.apply_along_axis(_tokens_to_probs, axis=1, arr=tokens)
    
    
t2p = TokenToProbsProcessor(np.random.default_rng(332), 0.85, 8192)

        

In [6]:
pipe = dataset.window(blocks_per_window=1).map_batches(t2p, batch_size=16).iter_rows()

2022-06-18 23:41:33,998	INFO dataset.py:2969 -- Created DatasetPipeline with 1000 windows: 13.98MiB min, 194.57MiB max, 39.06MiB mean


In [7]:
next(pipe)

Stage 0:   0%|                                                                                                                   | 0/1000 [00:00<?, ?it/s]
  0%|                                                                                                                            | 0/1000 [00:00<?, ?it/s][A
Stage 2:   0%|                                                                                                                   | 0/1000 [01:20<?, ?it/s]
Stage 1:   0%|                                                                                                        | 1/1000 [01:20<22:16:41, 80.28s/it]
Stage 0:   0%|▏                                                                                                       | 2/1000 [01:20<11:07:42, 40.14s/it]


KeyboardInterrupt: 

2022-06-18 23:45:35,995	ERROR worker.py:94 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): The worker died unexpectedly while executing this task. Check python-core-worker-*.log files for more information.
