In [None]:
from init_notebook import *
from src.train.experiment import load_experiment_trainer

In [None]:
# Load the trainer described by the experiment's YAML file; everything in this
# notebook runs on the CPU.
trainer = load_experiment_trainer(
    "../experiments/minimind/fefe2.yml",
    device="cpu",
)

In [None]:
def generate(input_text: str, temperature: float = .7):
    """
    Print ``input_text`` followed by a streamed model continuation.

    The prompt is tokenized with the trainer's tokenizer and passed to
    ``trainer.model.generate`` in streaming mode. Only the tokens that have
    not been printed yet are decoded and written to stdout on each step.

    Args:
        input_text: Prompt the model should continue.
        temperature: Sampling temperature forwarded to ``model.generate``.
            Defaults to the previously hard-coded value of 0.7.

    Returns:
        None. The text is written to stdout as a side effect.
    """
    tokenizer = trainer._tokenizer
    input_tokens = tokenizer(input_text, return_tensors="pt")

    print(input_text, end="")
    num_printed = 0
    for tokens in trainer.model.generate(
        input_tokens.input_ids,
        eos_token_id=tokenizer.eos_token_id,
        stream=True,
        temperature=temperature,
    ):
        # NOTE(review): assumes each streamed item is (batch, tokens-so-far)
        # and excludes the prompt -- confirm against `model.generate`.
        new_tokens = tokens[0, num_printed:]
        text = tokenizer.decode(new_tokens)
        # "⬇" and "⬅" are mapped back to a space and a newline; presumably the
        # tokenizer's whitespace placeholders -- confirm.
        print(text.replace("⬇", " ").replace("⬅", "\n"), end="", flush=True)
        num_printed += new_tokens.shape[0]

# Prompt to continue. Other prompts tried here: "!!1!", "Faschisten",
# "Cinderella hat ein U-Boot gesprengt".
generate("Die Erklärung lautet: ")

In [None]:
from typing import Iterator, Optional, Sequence

from experiments.minimind.tokenizedataset import TokenizeDataset
from src.datasets import FefePostIterableDataset
# Source texts for the tokenization experiments below.
# NOTE(review): `.freeze()` presumably fixes/caches the dataset contents -- confirm.
dataset = (
    FefePostIterableDataset()
    .freeze()
)

class TokenizeDataset(BaseIterableDataset):
    """
    Iterable dataset turning raw texts into next-token-prediction samples.

    Iterating yields ``(X, Y, loss_mask)`` tuples derived from one 1-D tensor
    of token ids: ``X`` is the chunk without its last token, ``Y`` the chunk
    without its first token, and ``loss_mask`` is ``True`` wherever ``Y`` is
    not the tokenizer's padding token.

    ``method`` selects how texts are turned into chunks:

    - ``"truncate"``: one chunk per text, truncated/padded by the tokenizer
      to the current sequence length.
    - ``"fragments"``: texts longer than the sequence length are cut into
      windows that advance by half the sequence length; the last (or only)
      short window is left-padded with the padding token.
    - ``"concat"``: long texts are cut into windows as above, short texts are
      joined with the tokenizer's ``sep_token_id`` and emitted once enough
      tokens are buffered. Tokens still buffered when the texts run out are
      dropped.

    `__len__` is intentionally not defined: the number of samples depends on
    the tokenization and is not ``len(texts)`` for "fragments" and "concat".
    """

    def __init__(
            self,
            texts: Sequence[str],
            tokenizer,
            max_seq_length: int,
            min_seq_length: Optional[int] = None,
            batch_size: Optional[int] = None,
            method: str = "concat",
    ):
        """
        Args:
            texts: Iterable of raw strings to tokenize.
            tokenizer: Callable returning an object with an ``input_ids``
                tensor; must expose ``pad_token_id`` and, for "concat",
                ``sep_token_id``.
            max_seq_length: Chunk length (upper bound when a minimum is set).
            min_seq_length: If set (and non-zero), the chunk length is drawn
                uniformly from ``[min_seq_length, max_seq_length]`` once per
                group of ``batch_size`` samples.
            batch_size: Number of consecutive samples sharing one chunk
                length; required when ``min_seq_length`` is set.
            method: One of "truncate", "fragments" or "concat".
        """
        self._texts = texts
        self._tokenizer = tokenizer
        self._min_seq_length = min_seq_length
        self._max_seq_length = max_seq_length
        self._batch_size = batch_size
        self._method = method

    def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Yield ``(X, Y, loss_mask)`` for every non-empty chunk.

        Raises:
            ValueError: If ``method`` is not a known chunking method (raised
                when iteration starts, not at construction time).
        """
        if self._method == "truncate":
            iterable = self._iter_truncated()
        elif self._method == "fragments":
            iterable = self._iter_fragments()
        elif self._method == "concat":
            iterable = self._iter_concat()
        else:
            raise ValueError(f"Unknown method `{self._method}`")

        for input_ids in iterable:
            if input_ids.shape[0]:
                pad_token_id = self._tokenizer.pad_token_id
                if pad_token_id is None:
                    # Comparing a tensor with None does not give a tensor;
                    # without a padding token every position counts.
                    loss_mask = torch.ones_like(input_ids, dtype=torch.bool)
                else:
                    loss_mask = (input_ids != pad_token_id)

                # Shift by one: the target of position t is the token at t + 1.
                yield input_ids[:-1], input_ids[1:], loss_mask[1:]

    def _encode(self, text: str, **tokenizer_kwargs) -> torch.Tensor:
        """Tokenize ``text`` and return its token ids as a 1-D tensor."""
        encoding = self._tokenizer(text, return_tensors='pt', **tokenizer_kwargs)
        # reshape(-1) instead of squeeze(): squeeze() would also drop the
        # sequence axis of a one-token text and produce a 0-d tensor, on which
        # `.shape[0]` raises IndexError.
        return encoding.input_ids.reshape(-1)

    def _next_seq_length(self, count: int, current: int) -> int:
        """
        Return the sequence length to use for sample number ``count``.

        Without a (non-zero) ``min_seq_length`` the length never changes.
        Otherwise a new length is drawn whenever ``count`` is a multiple of
        ``batch_size``, so all samples of one batch share the same length.
        """
        if not self._min_seq_length:
            return current
        assert self._batch_size, "Must define `batch_size` when defining `min_seq_length`"
        if count % self._batch_size == 0:
            return random.randint(self._min_seq_length, self._max_seq_length)
        return current

    def _iter_truncated(self):
        """Yield one tokenizer-truncated/padded chunk per text."""
        seq_length = self._max_seq_length
        for i, text in enumerate(self._texts):
            seq_length = self._next_seq_length(i, seq_length)
            yield self._encode(
                text,
                max_length=seq_length,
                padding='max_length',
                truncation=True,
            )

    def _iter_fragments(self):
        """Yield half-overlapping windows per text, left-padding short tails."""
        seq_length = self._max_seq_length
        count = 0
        for text in self._texts:
            input_ids = self._encode(text)
            if input_ids.shape[0] == 0:
                continue

            while True:
                # Every pass through this loop yields exactly one chunk.
                seq_length = self._next_seq_length(count, seq_length)
                count += 1

                if input_ids.shape[0] == seq_length:
                    yield input_ids
                    break

                elif input_ids.shape[0] < seq_length:
                    padding = torch.full(
                        (seq_length - input_ids.shape[0], ),
                        self._tokenizer.pad_token_id,
                        dtype=input_ids.dtype,
                    )
                    yield torch.cat([padding, input_ids])
                    break

                else:
                    yield input_ids[:seq_length]
                    # Advance by at least one token; `seq_length // 2` is 0
                    # for a length of 1 and would never terminate.
                    input_ids = input_ids[max(1, seq_length // 2):]

    def _iter_concat(self):
        """Yield windows of long texts and sep-joined runs of short texts."""
        seq_length = self._max_seq_length
        count = 0
        current_ids = None
        for text in self._texts:
            input_ids = self._encode(text)
            if input_ids.shape[0] == 0:
                continue

            while True:
                # `count` only advances when a chunk is emitted, so while a
                # batch has not produced its first chunk the length is drawn
                # again on every pass.
                seq_length = self._next_seq_length(count, seq_length)

                if input_ids.shape[0] == seq_length:
                    yield input_ids
                    count += 1
                    break

                elif input_ids.shape[0] < seq_length:
                    if current_ids is None:
                        current_ids = input_ids
                    else:
                        separator = torch.full(
                            (1, ),
                            self._tokenizer.sep_token_id,
                            dtype=input_ids.dtype,
                        )
                        current_ids = torch.cat([current_ids, separator, input_ids])

                    if current_ids.shape[0] >= seq_length:
                        yield current_ids[:seq_length]
                        count += 1
                        current_ids = current_ids[seq_length:]
                    break

                else:
                    yield input_ids[:seq_length]
                    count += 1
                    # Advance by at least one token (see `_iter_fragments`).
                    input_ids = input_ids[max(1, seq_length // 2):]

# min_seq_length == max_seq_length, so the per-batch length draw always gives
# 256 and every emitted chunk is 256 tokens long.
# NOTE(review): `.skip(1000)` presumably drops the first 1000 posts -- confirm.
ds = TokenizeDataset(
    dataset.skip(1000),
    trainer._tokenizer,
    max_seq_length=256,
    min_seq_length=256,
    method="concat",
    batch_size=16,
)

In [None]:
# Sanity check: all samples within one group of 16 (the dataset's batch size)
# must have the same length, otherwise they cannot be collated into a batch.
for i, (X, Y, _) in enumerate(tqdm(ds)):
    if i % 16:
        # Inside a batch: compare against the length of its first sample.
        assert bs == X.shape[0], f"@{i} {bs} != {X.shape[0]}"
    else:
        # First sample of a batch sets the reference length.
        bs = X.shape[0]

In [None]:
# Decode every input chunk back to text for visual inspection; each chunk is
# preceded by an empty line.
for i, (X, Y, _) in enumerate(tqdm(ds)):
    print("", trainer._tokenizer.decode(X), sep="\n")

In [None]:
# Run the dataset through a DataLoader with one worker process and batches of
# 16, showing the number of samples consumed so far in the progress bar.
with tqdm(DataLoader(ds, batch_size=16, num_workers=1)) as progress:
    for i, _ in enumerate(progress):
        progress.set_postfix(step=i * 16)
