In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer
import grain

  from .autonotebook import tqdm as notebook_tqdm


In [26]:
import grain.python
import jax
from jax import numpy as jnp
import grain

def get_dummy_dataset(max_length: int):
    batch = (
        jax.random.randint(
            jax.random.PRNGKey(0),
            (max_length,),
            0,
            max_length,
        ),
        jax.random.randint(
            jax.random.PRNGKey(0),
            (max_length,),
            0,
            max_length,
        )
    )

    def generator():
        while True:
            yield batch

    return None, grain.python.IterDataset(generator())

In [27]:
_, ds = get_dummy_dataset(max_length=4096)

TypeError: Can't instantiate abstract class IterDataset with abstract method __iter__

In [17]:
next(iter(ds))[0]

TypeError: object of type 'generator' has no len()

In [None]:
hf_ds = load_dataset("allenai/c4", "realnewslike", num_proc=4, split="train")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [None]:
parent_ds = (
    grain.MapDataset.source(hf_ds)
    .map(lambda x: {"tokens": tokenizer.encode(x["text"], return_tensors="np")[0]})
)

ds = grain.experimental.ConcatThenSplitIterDataset(
    parent=parent_ds,
    length_struct={"tokens": 1024},
)

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
import grain
from typing import List
import numpy as np
def get_hf_dataset(
        hf_name: List[str],
        tokenizer_name: str,
        max_length: int,
        num_proc: int = 4,
        split: str = "train",
):
    hf_ds = load_dataset(*hf_name, split=split, num_proc=num_proc)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    parent_ds = (
        grain.MapDataset.source(hf_ds)
        .map(lambda x: {"tokens": tokenizer.encode(x["text"], return_tensors="np")[0]})
        
    )

    ds = grain.experimental.ConcatThenSplitIterDataset(
        parent=parent_ds,
        length_struct={"tokens": max_length+1},
    )

    ds = ds.map(lambda x: (x['tokens'][:-1], x['tokens'][1:]))

    return ds

In [None]:
ds = get_hf_dataset(
    hf_name=["allenai/c4", "realnewslike"],
    tokenizer_name="gpt2",
    max_length=1024,
    num_proc=4,
    split="train",
)

ds = ds.batch(2)

for x, y in ds:
    print(x[0, :10])

In [5]:
def pretty_big(n, decimals: int = 2) -> str:
    """
    Format large numbers using M (millions), B (billions), and T (trillions).
    - Only abbreviates when abs(n) >= 10_000_000.
    - 10M<1B  -> M
    - 1B<1T   -> B
    - >=1T     -> T
    - Below 10M: use thousands separators.

    Examples:
      pretty_big(3_450_000)         -> "3.45M"
      pretty_big(250_000_000)       -> "250M"
      pretty_big(1_000_000_000)     -> "1B"
      pretty_big(3_450_000_000)     -> "3.45B"
      pretty_big(1_200_000_000_000) -> "1.2T"
      pretty_big(7_500_000)         -> "7,500,000"
    """
    abs_n = abs(n)

    # Below 10M: plain formatting with separators
    if abs_n < 10_000_000:
        if float(n).is_integer():
            return f"{int(n):,}"
        return f"{n:,.{decimals}f}".rstrip('0').rstrip('.')

    # Choose scale & suffix
    if abs_n < 1_000_000_000:
        value, suffix = n / 1_000_000, "M"
    elif abs_n < 1_000_000_000_000:
        value, suffix = n / 1_000_000_000, "B"
    else:
        value, suffix = n / 1_000_000_000_000, "T"

    # Round and trim trailing zeros
    s = f"{round(value, decimals):.{decimals}f}".rstrip('0').rstrip('.')
    if s in {"-0", "-0.", "-0.0"}:
        s = "0"
    return f"{s}{suffix}"

In [7]:
for i in range(0, 1_000_000_00, 400_000):
    print(pretty_big(i))

0
400,000
800,000
1,200,000
1,600,000
2,000,000
2,400,000
2,800,000
3,200,000
3,600,000
4,000,000
4,400,000
4,800,000
5,200,000
5,600,000
6,000,000
6,400,000
6,800,000
7,200,000
7,600,000
8,000,000
8,400,000
8,800,000
9,200,000
9,600,000
10M
10.4M
10.8M
11.2M
11.6M
12M
12.4M
12.8M
13.2M
13.6M
14M
14.4M
14.8M
15.2M
15.6M
16M
16.4M
16.8M
17.2M
17.6M
18M
18.4M
18.8M
19.2M
19.6M
20M
20.4M
20.8M
21.2M
21.6M
22M
22.4M
22.8M
23.2M
23.6M
24M
24.4M
24.8M
25.2M
25.6M
26M
26.4M
26.8M
27.2M
27.6M
28M
28.4M
28.8M
29.2M
29.6M
30M
30.4M
30.8M
31.2M
31.6M
32M
32.4M
32.8M
33.2M
33.6M
34M
34.4M
34.8M
35.2M
35.6M
36M
36.4M
36.8M
37.2M
37.6M
38M
38.4M
38.8M
39.2M
39.6M
40M
40.4M
40.8M
41.2M
41.6M
42M
42.4M
42.8M
43.2M
43.6M
44M
44.4M
44.8M
45.2M
45.6M
46M
46.4M
46.8M
47.2M
47.6M
48M
48.4M
48.8M
49.2M
49.6M
50M
50.4M
50.8M
51.2M
51.6M
52M
52.4M
52.8M
53.2M
53.6M
54M
54.4M
54.8M
55.2M
55.6M
56M
56.4M
56.8M
57.2M
57.6M
58M
58.4M
58.8M
59.2M
59.6M
60M
60.4M
60.8M
61.2M
61.6M
62M
62.4M
62.8M
63.2M
63.6M
64M
64