In [None]:
from itertools import islice

from init_notebook import *

In [None]:
from experiments.datasets.teletext import *

# Peek at the first two raw (text, meta) items. `islice` stops after exactly
# two items, whereas `zip(dataset, range(2))` would fetch and discard a third.
list(islice(TeletextIterableDataset(), 2))

In [None]:
# Print metadata and text of the first four pages (exactly four are fetched).
for text, meta in islice(TeletextIterableDataset(), 4):
    print(meta)
    print(text)

In [None]:
from PIL import Image, ImageFont, ImageDraw

class TeletextPixelIterableDataset(TeletextIterableDataset):
    """Teletext dataset that renders each page's text into a pixel image.

    Yields ``(image, meta)`` where ``image`` is a ``torch.zeros``-initialised
    tensor of shape ``(1, rows * font_h, cols * font_w)`` with
    ``(rows, cols) = screen_size`` and ``(font_h, font_w) = font_shape``.
    ``meta`` is the parent dataset's metadata (presumably a dict — it is
    assigned by key) with the raw page text added under ``"text"``.
    Text outside the screen area is clipped.
    """

    def __init__(
            self,
            filename: Union[str, Path] = DATASET_PATH / "teletext.ndjson.gz",
            total: Optional[int] = 15_505_000,
            font_shape: Tuple[int, int] = (8, 8),
            font_file: Union[str, Path] = Path("~/.local/share/fonts/unscii-8.ttf").expanduser(),
            screen_size: Tuple[int, int] = (25, 40),
    ):
        super().__init__(filename=filename, total=total)
        self._font_shape = font_shape
        self._screen_size = screen_size
        # Glyph cache: each character is rendered once and then reused.
        self._font_map = {}
        # TrueType takes a single size, so use the smaller side of the glyph cell.
        self._font = ImageFont.truetype(str(font_file), min(self._font_shape))

    def __iter__(self):
        for text, meta in super().__iter__():
            meta["text"] = text
            yield self._render_text(text), meta

    def _render_text(self, text: str):
        """Render ``text`` onto a blank ``(1, rows*font_h, cols*font_w)`` canvas."""
        rows, cols = self._screen_size
        glyph_h, glyph_w = self._font_shape
        image = torch.zeros(1, glyph_h * rows, glyph_w * cols)
        # Slicing clips lines/characters that fall outside the screen.
        for y, line in enumerate(text.splitlines()[:rows]):
            for x, ch in enumerate(line[:cols]):
                if ch not in self._font_map:
                    self._font_map[ch] = self._render_font(ch)
                image[:, y * glyph_h:(y + 1) * glyph_h, x * glyph_w:(x + 1) * glyph_w] = self._font_map[ch]
        return image

    def _render_font(self, ch: str):
        """Render a single character into a ``font_shape`` glyph tensor."""
        # PIL sizes are (width, height) while font_shape is (height, width).
        image = Image.new("L", (self._font_shape[1], self._font_shape[0]))
        draw = ImageDraw.Draw(image)
        draw.text(
            (0, 0), ch,
            font=self._font,
            fill=(255,),
        )
        return VF.to_tensor(image)
        
# Preview a 4x4 grid of rendered pages drawn via `ds.shuffle(1000)`.
ds = TeletextPixelIterableDataset()
images = [image for image, _meta in islice(ds.shuffle(1000), 16)]

VF.to_pil_image(make_grid(images, nrow=4))

In [None]:
from PIL import Image, ImageFont, ImageDraw


class TextToPixelIterableDataset(BaseIterableDataset):
    """Wrap a text dataset and render every text item into a fixed-size pixel image.

    If the wrapped dataset yields tuples/lists, their first element is taken as
    the text and ``(image, *item)`` is yielded; otherwise the item itself is the
    text and ``(image, item)`` is yielded.

    ``image`` is the row-major tiling of ``screen_size = (rows, cols)`` glyphs of
    ``font_shape = (font_h, font_w)`` pixels each. Lines/characters outside the
    screen are clipped, and missing cells are rendered as spaces. Glyphs come
    from ``VF.to_tensor`` on an ``"L"`` image, which presumably gives a
    single-channel tensor (assuming ``VF`` is torchvision's functional API).
    """

    def __init__(
            self,
            dataset: Union[Dataset, IterableDataset],
            screen_size: Tuple[int, int] = (24, 40),
            font_shape: Tuple[int, int] = (8, 8),
            font_file: Union[str, Path] = Path("~/.local/share/fonts/unscii-8.ttf").expanduser(),
    ):
        super().__init__()
        self._dataset = dataset
        self._screen_size = screen_size
        self._font_shape = font_shape
        # Glyph cache: character -> rendered glyph tensor.
        self._font_map = {}
        # TrueType takes a single size, so use the smaller side of the glyph cell.
        self._font = ImageFont.truetype(str(font_file), min(self._font_shape))

    def __iter__(self):
        for item in self._dataset:
            # Normalise to a tuple whose first element is the text to render.
            fields = tuple(item) if isinstance(item, (tuple, list)) else (item,)
            yield (self._render_text(fields[0]), *fields)

    def _render_text(self, text: str):
        """Render ``text`` as a ``rows x cols`` grid of cached glyph tensors."""
        rows, cols = self._screen_size[-2], self._screen_size[-1]
        lines = text.splitlines()[:rows]
        # Pad with blank lines so the screen always has exactly `rows` lines.
        lines += [""] * (rows - len(lines))

        font_lines = []
        for line in lines:
            # Clip / space-pad each line to exactly `cols` characters.
            glyphs = [self._glyph(ch) for ch in line[:cols].ljust(cols)]
            font_lines.append(torch.cat(glyphs, dim=-1))

        return torch.cat(font_lines, dim=-2)

    def _glyph(self, ch: str):
        """Return the cached glyph tensor for ``ch``, rendering it on first use."""
        if ch not in self._font_map:
            self._font_map[ch] = self._render_font(ch)
        return self._font_map[ch]

    def _render_font(self, ch: str):
        """Render a single character into a ``font_shape`` glyph tensor."""
        # PIL sizes are (width, height) while font_shape is (height, width).
        image = Image.new("L", (self._font_shape[1], self._font_shape[0]))
        draw = ImageDraw.Draw(image)
        draw.text(
            (0, 0), ch,
            font=self._font,
            fill=(255,),
        )
        return VF.to_tensor(image)

# True: run one full pass over the dataset so tqdm reports rendering throughput.
# False: render a 4x4 preview grid of shuffled pages instead.
BENCHMARK_RENDERING = True

if BENCHMARK_RENDERING:
    for _ in tqdm(TextToPixelIterableDataset(TeletextIterableDataset())):
        pass
else:
    ds = TextToPixelIterableDataset(
        TeletextIterableDataset().shuffle(10000)
    )
    # Items are (image, text, meta); take exactly 16 of them.
    images = [image for image, _text, _meta in islice(ds, 16)]

    display(VF.to_pil_image(make_grid(images, nrow=4, pad_value=.7)))

In [None]:
class TeletextPixelIterableDataset(TextToPixelIterableDataset):
    """Teletext pages rendered to pixel images via ``TextToPixelIterableDataset``.

    NOTE: this replaces the earlier notebook class of the same name. Items are
    now ``(image, text, meta)``, because the wrapped ``TeletextIterableDataset``
    yields ``(text, meta)`` tuples.
    """

    def __init__(
            self,
            filename: Union[str, Path] = DATASET_PATH / "teletext.ndjson.gz",
            total: Optional[int] = 15_505_000,
            font_shape: Tuple[int, int] = (8, 8),
            font_file: Union[str, Path] = Path("~/.local/share/fonts/unscii-8.ttf").expanduser(),
            screen_size: Tuple[int, int] = (24, 40),
    ):
        super().__init__(
            dataset=TeletextIterableDataset(filename=filename, total=total),
            screen_size=screen_size,
            font_shape=font_shape,
            font_file=font_file,
        )

    def __len__(self):
        # Delegates to the wrapped TeletextIterableDataset; presumably derived
        # from `total` there — TODO confirm.
        return len(self._dataset)
        
# Inspect a single rendered page: image, tensor shape, metadata and raw text.
image, text, meta = next(iter(TeletextPixelIterableDataset()))
display(VF.to_pil_image(image))
print(image.shape)
print(meta)
print(text)