In [None]:
from init_notebook import *

In [None]:
from src.models.mamba.mamba import Mamba, ModelArgs

In [None]:
m = Mamba(ModelArgs(vocab_size=256, d_model=100, n_layer=4))
# report the parameter count before showing the module tree
print("param: {:,}".format(num_module_parameters(m)))
m

In [None]:
m(torch.ones((1, 100), dtype=torch.long)).shape  # one dummy sequence of 100 tokens

In [None]:
def _apply_mask(texts: torch.Tensor) -> torch.Tensor:
    """
    :param text: tensor of shape [B, L]
    :return: masked tensor
    """
    B, L = texts.shape
    size = max(1, int(L*.3)) #int(self._mask_ratio * L))
    indices = torch.randint(0, L - size, (B, 1)).to(texts.device)
    coords = torch.arange(0, L).unsqueeze(0).repeat(B, 1)
    mask = (coords < indices) | (coords >= indices + size) 
    return texts * mask
    
texts = torch.randint(low=0, high=10, size=(3, 10))
print(texts)
# the same batch with one contiguous span zeroed per row
masked = _apply_mask(texts)
print(masked)

In [None]:
def _encoded_to_logits(texts: torch.Tensor) -> torch.Tensor:
    """
    :param texts: tensor of [B, L]
    :return: tensor of [B, L, 256]
    """
    B, L = texts.shape
    logits = torch.zeros((B, L, 10)).to(texts)
    logits.scatter_(-1, texts.unsqueeze(-1).to(torch.int64), 1)
    return logits

print(texts)
# one-hot view of the same batch: [B, L] -> [B, L, classes]
_encoded_to_logits(texts)

In [None]:
class TextSegmentIterableDataset(BaseIterableDataset):
    """
    Yields fixed-size segments of one or more texts.

    :param text: one or more strings to segment
    :param size: maximum number of characters per segment
        (the last segments of a text may be shorter)
    :param stride: characters to advance between segments:
        None -> same as ``size`` (non-overlapping),
        int >= 1 -> fixed step,
        "random" -> random step in ``[1, size]``
    :param encode: None -> yield ``str``,
        "bytes" -> yield a uint8 tensor of the UTF-8 encoded segment
    """

    def __init__(
            self,
            *text: str,
            size: int,
            stride: Union[None, int, str] = None,  # int or "random"
            encode: Optional[str] = None,
    ):
        super().__init__()
        self._texts = text
        self._size = size
        self._stride = stride
        self._encode = encode

    def __iter__(self):
        stride = self._stride
        if stride is None:
            stride = self._size

        # validate once, before any text is consumed, so bad arguments fail
        # even for empty texts and a non-positive step can never loop forever
        if stride != "random" and not isinstance(stride, int):
            raise NotImplementedError(f"Invalid stride '{self._stride}'")
        if isinstance(stride, int) and stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        if self._encode not in (None, "bytes"):
            raise NotImplementedError(f"Invalid encode '{self._encode}'")

        for text in self._texts:
            pos = 0
            while pos < len(text):
                segment = text[pos: pos + self._size]

                if stride == "random":
                    # at least 1, otherwise the position may never advance (e.g. size == 1)
                    pos += random.randint(1, max(1, self._size))
                else:
                    pos += stride

                if self._encode is None:
                    yield segment
                else:
                    yield torch.tensor(list(segment.encode()), dtype=torch.uint8)

ds = TextSegmentIterableDataset(
    Path("~/text/der_bandwurm.txt").expanduser().read_text(),
    size=10, stride=1,
).shuffle(1000)
print(len(list(ds)))
# stride=1 yields one segment per character of the whole text,
# so only preview the first 100 instead of flooding the output
for i, seg in zip(range(100), ds):
    print(f"{i:3}:", repr(seg))


In [None]:
next(iter(DataLoader(dataset=ds, batch_size=10)))  # first collated batch of 10 segments

In [None]:
import fnmatch
import glob

class FilenameDataset(BaseDataset):
    """
    Map-style dataset of file paths found below a root directory.

    The file list is collected lazily on first access and returned sorted.

    :param root: directory to search (``~`` is expanded), or a single file,
        which is then the only item (include/exclude are not applied to it)
    :param include: fnmatch pattern(s) matched against the full path;
        a file is kept if it matches ANY of them. None keeps everything
    :param exclude: fnmatch pattern(s); a file is dropped if it matches ANY of them
    :param recursive: also search all sub-directories
    :param max_files: optional limit on the number of paths; applied after
        sorting, so the selection does not depend on directory listing order
    """

    def __init__(
            self,
            root: Union[str, Path],
            include: Union[None, str, Iterable[str]] = None,
            exclude: Union[None, str, Iterable[str]] = None,
            recursive: bool = False,
            max_files: Optional[int] = None,
    ):
        super().__init__()
        self.root = Path(root).expanduser()
        self.include = self._to_pattern_list(include)
        self.exclude = self._to_pattern_list(exclude)
        self.recursive = recursive
        self.max_files = max_files
        self._filenames = None

    def __len__(self):
        self._get_filenames()
        return len(self._filenames)

    def __getitem__(self, i):
        self._get_filenames()
        return self._filenames[i]

    @staticmethod
    def _to_pattern_list(patterns: Union[None, str, Iterable[str]]) -> Optional[list]:
        """Normalize None / a single pattern / an iterable of patterns to None or a list."""
        if patterns is None:
            return None
        if isinstance(patterns, str):
            return [patterns]
        return list(patterns)

    def _is_valid(self, filename: str) -> bool:
        # include patterns are alternatives: matching one of them is enough
        if self.include and not any(fnmatch.fnmatch(filename, pattern) for pattern in self.include):
            return False
        if self.exclude and any(fnmatch.fnmatch(filename, pattern) for pattern in self.exclude):
            return False
        return True

    def _get_filenames(self):
        if self._filenames is not None:
            return

        if self.root.is_file():
            self._filenames = [str(self.root)]
            return

        # escape the root so "[", "*" or "?" in directory names are taken literally
        pattern = Path(glob.escape(str(self.root))) / ("**/*" if self.recursive else "*")
        filenames = sorted(
            filename
            for filename in glob.glob(str(pattern), recursive=self.recursive)
            # glob also returns directories, which cannot be read as files
            if self._is_valid(filename) and Path(filename).is_file()
        )
        if self.max_files:
            filenames = filenames[:self.max_files]
        self._filenames = filenames

dsf = FilenameDataset(
    "../../billion-bubbles/",
    recursive=True,
    include="*.py",
    exclude="*/env/*",
)
# list every matched python file, one per line
for f in dsf:
    print(f)

In [None]:
class FileTextSegmentIterableDataset(BaseIterableDataset):
    """
    Yields fixed-size text segments from the files of a filename dataset.

    The wrapped dataset must yield file paths (e.g. a ``FilenameDataset``).
    Files that cannot be read or are not valid UTF-8 are skipped.

    :param dataset: dataset yielding file paths
    :param size: maximum characters per segment (the last segments of a file may be shorter)
    :param stride: characters to advance between segments:
        None -> same as ``size`` (non-overlapping),
        int >= 1 -> fixed step,
        "random" -> random step in ``[1, size]``
    :param encode: None -> yield ``str``,
        "bytes" -> yield a uint8 tensor of the UTF-8 encoded segment
    """

    def __init__(
            self,
            dataset: Union[Dataset, IterableDataset],
            size: int,
            stride: Union[None, int, str] = None,  # int or "random"
            encode: Optional[str] = None,
    ):
        super().__init__()
        self._dataset = dataset
        self._size = size
        self._stride = stride
        self._encode = encode

    def __iter__(self):
        stride = self._stride
        if stride is None:
            stride = self._size

        # validate once, before any file is read, so a non-positive
        # step can never loop forever
        if stride != "random" and not isinstance(stride, int):
            raise NotImplementedError(f"Invalid stride '{self._stride}'")
        if isinstance(stride, int) and stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        if self._encode not in (None, "bytes"):
            raise NotImplementedError(f"Invalid encode '{self._encode}'")

        for file in self._dataset:
            try:
                # explicit encoding: don't depend on the platform's locale default
                text = Path(file).read_text(encoding="utf-8")
            except (UnicodeDecodeError, OSError):
                # best-effort: skip binary, unreadable or non-regular files (e.g. directories)
                continue

            pos = 0
            while pos < len(text):
                segment = text[pos: pos + self._size]

                if stride == "random":
                    # at least 1, otherwise the position may never advance (e.g. size == 1)
                    pos += random.randint(1, max(1, self._size))
                else:
                    pos += stride

                if self._encode is None:
                    yield segment
                else:
                    yield torch.tensor(list(segment.encode()), dtype=torch.uint8)

ds = FileTextSegmentIterableDataset(
    FilenameDataset(
        "/home/bergi/prog/python/botgard/BotGard3/",  # alternative: "~/prog/python/github"
        recursive=True,
        include="*.py",
        exclude=["*/env/*", "*/node_modules/*", "*/site-packages/*"],
    ),
    size=50,  # optionally: stride=1
)  # optionally: .shuffle(1000)

# count all segments (full pass), then preview the first 100
print(len(list(ds)))
for i, seg in zip(range(100), ds):
    print(f"{i:3}:", repr(seg))


In [None]:
class TextMathFormula(IterableDataset):
    """
    Yields random arithmetic formulas with their result, e.g. ``"3 + 7 = 10"``.

    :param size: number of formulas per iteration (also ``len()``)
    :param num_operands: number of ``operator number`` pairs after the first number
    :param max_number: numbers are drawn uniformly from ``[0, max_number]``
    :param operators: python binary operators to choose from, e.g. "+", "-", "*", "//"
    :param sep: separator placed between all tokens of the yielded string
    :param seed: if given, every iteration yields the same sequence
    """

    def __init__(
            self,
            size: int,
            num_operands: int = 1,
            max_number: int = 10,
            operators: Iterable[str] = ("+",),
            sep: str = " ",
            seed: Optional[int] = None,
    ):
        super().__init__()
        self._size = size
        self._num_operands = num_operands
        self._max_number = max_number
        self._operators = list(operators)
        self._sep = sep
        self._seed = seed

    def __len__(self) -> int:
        return self._size

    def __iter__(self) -> Generator[str, None, None]:
        # a fresh Random per iteration makes seeded datasets repeatable
        rng = random if self._seed is None else random.Random(self._seed)
        for _ in range(self._size):
            yield self._make_formula(rng)

    def _make_formula(self, rng, max_tries: int = 100) -> str:
        """Draw one formula string; redraw when the result is undefined (division by zero)."""
        for _ in range(max_tries):
            tokens = [str(rng.randint(0, self._max_number))]
            for _ in range(self._num_operands):
                tokens.append(rng.choice(self._operators))
                tokens.append(str(rng.randint(0, self._max_number)))

            try:
                # evaluate the space-joined tokens: ``sep`` is only for display and
                # separators like "," or "_" would not form a valid python expression.
                # NOTE: operator strings come from the caller and are eval'ed; builtins
                # are disabled, but never pass untrusted operators.
                result = eval(" ".join(tokens), {"__builtins__": {}}, {})
            except ZeroDivisionError:
                continue

            return self._sep.join(tokens + ["=", str(result)])

        raise ValueError(
            f"Could not draw a valid formula in {max_tries} tries "
            f"(max_number={self._max_number}, operators={self._operators})"
        )

ds = TextMathFormula(20, sep=" ")
# the dataset has only 20 items, so this prints all of them
for i, seg in zip(range(100), ds):
    print(f"{i:3}:", repr(seg))

In [None]:
import datetime
from src.util.gharchive import GHArchive
#gha = GHArchive(verbose=True)
#for commit in gha.iter_commits(datetime.date(2024, 12, 13), hours=16):
#    print(commit["message"])

class TextGithubEventIterableDataset(BaseIterableDataset):
    """
    Yields texts from one hour of GitHub Archive events.

    :param dt: date and hour of the archive to read
    :param type: "commit" and/or "comment"; see NOTE in ``_iter_texts``
    :param min_text_length: skip texts shorter than this; with ``fixed_width``,
        also stop windowing once the remaining text is shorter than this
    :param fixed_width: if set, cut every text into windows of exactly this many
        characters (a short remainder is right-padded with spaces)
    :param stride: window step when ``fixed_width`` is set:
        None -> ``fixed_width`` (non-overlapping), int >= 1 -> fixed step,
        "random" -> random step in ``[1, fixed_width - 1]`` (1 if ``fixed_width`` is 1)
    :param verbose: passed to ``GHArchive``
    """

    def __init__(
            self,
            dt=datetime.datetime(2024, 12, 13, 16),
            type: Union[str, Iterable[str]] = ("commit", "comment"),
            min_text_length: Optional[int] = None,
            fixed_width: Optional[int] = None,
            stride: Union[None, int, str] = None,
            verbose: bool = True,
    ):
        super().__init__()
        self._dt = dt
        self._gha = GHArchive(verbose=verbose)
        self._type = {type} if isinstance(type, str) else set(type)
        self._min_text_length = min_text_length
        self._fixed_width = fixed_width
        self._stride = stride

    def __iter__(self) -> Generator[str, None, None]:
        if self._fixed_width:
            self._check_stride()

        for text in self._iter_texts():
            if self._min_text_length and len(text) < self._min_text_length:
                continue

            if not self._fixed_width:
                yield text
                continue

            while text:
                if self._min_text_length and len(text) < self._min_text_length:
                    break

                # pad a short remainder so every window has exactly fixed_width characters
                yield text.ljust(self._fixed_width)[:self._fixed_width]

                stride = self._stride
                if stride is None:
                    stride = self._fixed_width
                elif stride == "random":
                    # randrange(1, 1) would raise for fixed_width == 1
                    stride = random.randrange(1, max(2, self._fixed_width))

                text = text[stride:]

    def _check_stride(self):
        """Reject strides that would make the windowing loop fail or never end."""
        stride = self._stride
        if stride is None or stride == "random":
            return
        if not isinstance(stride, int):
            raise NotImplementedError(f"Invalid stride '{stride}'")
        if stride < 1:
            # text[stride:] would not keep shrinking the text -> endless loop
            raise ValueError(f"stride must be >= 1, got {stride}")

    def _iter_texts(self):
        #shown = set()
        for event in self._gha.iter_events(
                day=self._dt.date(),
                hours=self._dt.hour,
        ):
            #if event["type"] not in shown:
            #    json.dumps(event, indent=2)
            #    shown.add(event["type"])

            if event["type"] == "PushEvent":
                if "commit" in self._type:
                    # NOTE(review): commit messages are currently disabled (the yield is
                    # commented out), so "commit" in ``type`` has no effect
                    if event.get("payload") and event["payload"].get("commits"):
                        for commit in event["payload"]["commits"]:
                            if commit.get("message"):
                                pass#yield commit["message"]

            elif event["type"] == "IssueCommentEvent":
                if "comment" in self._type:
                    if event.get("payload") and event["payload"].get("comment") and event["payload"]["comment"].get("body"):
                        yield event["payload"]["comment"]["body"]
                #print(json.dumps(event,indent=2))
            
ds = TextGithubEventIterableDataset(
    stride="random",
    fixed_width=20,
)
# preview only the first 100 windows; a full pass would be: print(len(list(tqdm(ds))))
for i, seg in zip(range(100), ds):
    print(f"{i:3}:", repr(seg))

In [None]:
from pathlib import Path
from typing import Union, Tuple

import torch
import torchvision.transforms.functional as VF
import PIL.Image, PIL.ImageFont, PIL.ImageDraw


class FontSquares:
    def __init__(
            self,
            file: Union[str, Path] = Path("~/.local/share/fonts/unscii-8.ttf").expanduser(),
            shape: Tuple[int, int, int] = (1, 8, 8),
            center: bool = False,
    ):
        """
        Generator for monospaced images from text

        :param file: file to use as font
        :param shape: tuple of (C, H, W), square channels & size, C must be 1 or 3
        :param center: bool, center all fonts into their squares
        """
        assert len(shape) == 3, f"Expected [C, H, W] shape, got {shape}"
        assert shape[0] in (1, 3), f"Expected 1 or 3 channels, got {shape}"

        self.shape = shape
        self.center = center
        self.font = PIL.ImageFont.truetype(
            str(file),
            size=min(self.shape[-2:]),
        )
        # cache: code point -> rendered [C, H, W] square
        self._font_map = {}

    def __call__(self, ch: Union[str, int], dim: int = 2) -> torch.Tensor:
        """
        Convert text to image

        Code points below 32 (except 0) are rendered as space;
        code point 0 is rendered as an all-ones square.

        :param ch: int (ordinal) or str (single character or text)
        :param dim: int, dimension for concatenation/stacking
        :return: Tensor of shape
            [C, H, W] if single character
            [C, H, W * N] if dim == 2
            [C, H * N, W] if dim == 1
            [C * N, H, W] if dim == 0
            [N, C, H, W] if dim == -1
        """
        if isinstance(ch, str):
            if len(ch) == 0:
                ch = 32
            elif len(ch) == 1:
                ch = ord(ch)
            else:
                squares = [self(c) for c in ch]
                if dim == -1:
                    return torch.stack(squares)
                else:
                    return torch.cat(squares, dim)

        ch = max(ch, 32) if ch != 0 else 0

        if ch not in self._font_map:
            if ch == 0:
                self._font_map[ch] = torch.ones(self.shape)
            else:
                image = PIL.Image.new(
                    "RGB" if self.shape[0] == 3 else "L",
                    (self.shape[-1], self.shape[-2]),
                )
                draw = PIL.ImageDraw.ImageDraw(image)

                if self.center:
                    L, T, R, B = draw.textbbox((0, 0), chr(ch), font=self.font)
                    xy = (
                        (self.shape[-1] - (R - L)) // 2,
                        -T + (self.shape[-2] - (B - T)) // 2,
                    )
                else:
                    xy = (0, 0)

                draw.text(
                    xy,
                    chr(ch),
                    font=self.font,
                    fill=(255,) * self.shape[0],
                )
                self._font_map[ch] = VF.to_tensor(image)

        return self._font_map[ch]

    def reverse(self, image: torch.Tensor, dim: int = 2) -> str:
        """
        Convert image back to text by best-match.

        Every square is compared (mean absolute difference) with every character
        this instance has rendered so far, and the closest one wins. If nothing
        has been rendered yet, printable ASCII (32..126) is rendered first so
        there is something to compare against.

        :param image: Tensor of shape
            [C, H, W * N] if dim == 2
            [C, H * N, W] if dim == 1
            [C * N, H, W] if dim == 0
            [N, C, H, W] if dim == -1
        :param dim: int, dimension where image is concatenated/stacked
        :return: str
        """
        if dim == -1:
            assert image.ndim == 4, f"Expected 4 dimensions, got {image.shape}"
            assert image.shape[1:] == self.shape, f"Expected square shape of {self.shape}, got {image.shape}"
            squares = image
        elif dim == 0:
            assert image.ndim == 3, f"Expected 3 dimensions, got {image.shape}"
            assert image.shape[0] % self.shape[0] == 0, f"Expected channels divisible by {self.shape[0]}, got {image.shape}"
            assert image.shape[1] == self.shape[1], f"Expected height of {self.shape[1]}, got {image.shape}"
            assert image.shape[2] == self.shape[2], f"Expected width of {self.shape[2]}, got {image.shape}"
        elif dim == 1:
            assert image.ndim == 3, f"Expected 3 dimensions, got {image.shape}"
            assert image.shape[0] == self.shape[0], f"Expected {self.shape[0]} channels, got {image.shape}"
            assert image.shape[1] % self.shape[1] == 0, f"Expected height divisible by {self.shape[1]}, got {image.shape}"
            assert image.shape[2] == self.shape[2], f"Expected width of {self.shape[2]}, got {image.shape}"
        elif dim == 2:
            assert image.ndim == 3, f"Expected 3 dimensions, got {image.shape}"
            assert image.shape[0] == self.shape[0], f"Expected {self.shape[0]} channels, got {image.shape}"
            assert image.shape[1] == self.shape[1], f"Expected height of {self.shape[1]}, got {image.shape}"
            assert image.shape[2] % self.shape[2] == 0, f"Expected width divisible by {self.shape[2]}, got {image.shape}"
        else:
            raise NotImplementedError(f"Expected dim in -1, 0, 1 or 2, got {dim}")

        if dim >= 0:
            squares = image.split(self.shape[dim], dim)

        if not self._font_map:
            # fresh instance: without any candidates torch.stack below would fail
            for o in range(32, 127):
                self(o)

        all_ords = list(self._font_map)
        all_fonts = torch.stack([self._font_map[o] for o in all_ords])
        output = []
        for square in squares:
            diffs = (all_fonts - square).abs().flatten(1).mean(1)
            output.append(chr(all_ords[int(diffs.argmin())]))
        return "".join(output)

f = FontSquares(
    # alternatives:
    #"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
    #"/usr/share/fonts/truetype/freefont/FreeSansBold.ttf",
    #center=True,
)
# glyph sheet for code points 32..255, 32 per row, zoomed 3x
display(VF.to_pil_image(resize(
    make_grid([f(i) for i in range(32, 256)], nrow=32, pad_value=.3),
    3,
)))
display(VF.to_pil_image(resize(f("hello world\0"), 3)))
#print(f.reverse(f("hello world")))
#f("hello", dim=0).shape
# round trip: text -> channel-stacked squares -> text
print(f.reverse(f("hello world", dim=0), dim=0))

In [None]:
v = torch.rand((3, 4, 4))
torch.stack((v, v)).shape  # stacking adds a leading dim: [2, 3, 4, 4]

In [None]:
# output length: 100 + 2*8 - 4*(5-1) = 100, i.e. length is preserved
nn.Conv1d(in_channels=2, out_channels=3, kernel_size=5, padding=8, dilation=4)(
    torch.ones((1, 2, 100))
).shape

In [None]:
size = 52
# brute-force (kernel_size, padding, dilation) combinations that keep the sequence length
for ks in [3, 5, 7, 9]:
    for pad in range(1, 12):
        for dil in range(1, 12):
            # Conv1d raises a RuntimeError if the dilated kernel is longer than the
            # padded input (e.g. ks=7, dil=9 spans 55 > 52 + 2 * 1 positions)
            if dil * (ks - 1) + 1 > size + 2 * pad:
                continue
            shape = nn.Conv1d(2, 3, ks, padding=pad, dilation=dil)(
                torch.ones(1, 2, size)
            ).shape
            if shape[-1] == size:
                print(f"ks={ks:2}, pad={pad:2}, dil={dil:2}, shape={shape}")

# display convolution receptive field

In [None]:
@torch.no_grad()
def plot_conv(
        size: int,
        kernel_size: int,
        dilation: Union[int, Iterable[int]],
        layers: int = 6,
        zoom: int = 5,
        seed: int = 23,
):
    """
    Visualize how a single impulse spreads through stacked dilated 1d convolutions.

    Starts with a [1, 3, size] zero tensor that has a 1 at the center position,
    passes it through the Conv1d(3, 3) layers and displays one row per stage
    (input first). Channels are shown as RGB and clamped to [0, 1].

    :param size: sequence length
    :param kernel_size: convolution kernel size (odd sizes keep the length unchanged)
    :param dilation: one dilation for all layers, or one per layer
        (expanded via ``param_make_list``)
    :param layers: number of convolution layers
    :param zoom: upscale factor passed to ``resize``
    :param seed: seed for the random non-negative weights; every layer is
        re-seeded with it, so all layers get the same weights
    """
    inp = torch.zeros(1, 3, size)
    inp[..., size // 2] = 1

    grid = []

    def add_pic(state):
        # [1, 3, L] -> [3, 1, L]: channels become RGB, the single batch row becomes height
        img = state.permute(1, 0, 2)
        grid.append(resize(img.clamp(0, 1), zoom))

    add_pic(inp)
    for dil in param_make_list(dilation, layers, "dilation"):
        # "same" padding for odd kernel sizes
        padding = (kernel_size // 2) * dil
        conv = nn.Conv1d(3, 3, kernel_size, padding=padding, dilation=dil)
        conv.weight[:] = .5 * torch.rand(conv.weight.shape, generator=torch.Generator().manual_seed(seed))
        conv.bias[:] = 0.
        inp = conv(inp)
        add_pic(inp)

    display(VF.to_pil_image(make_grid(grid, nrow=1, pad_value=.3)))

# 9-tap kernels: three wide dilated layers followed by three dense ones
plot_conv(
    100,
    9,
    dilation=[6, 6, 6, 1, 1, 1],
    layers=6,
    # other schedules tried:
    # dilation=2,
    # dilation=[2, 3, 4, 5, 6, 1],
    # dilation=[2, 3, 2, 3, 2, 3],
)

In [None]:
400 ** 2  # = 160000

In [None]:
from src.datasets.generative.text_gen import TextQABaseIterableDataset

class TextQAProgramIterableDataset(TextQABaseIterableDataset):
    """
    Question/answer pairs for tiny "programs" that rearrange letters.

    The question is the shuffled input letters, ": ", then a comma-separated
    list of operations; the answer is the letters after running them.
    All positions are 1-based:

        ``i>j``  swap the letters at positions i and j
        ``-i``   remove the letter at position i and push it onto a stack
        ``+i``   pop the most recently removed letter and insert it so that
                 it ends up at position i

    Example pair as produced by ``iter_question_answer``:

        question "BCAD: 1>3, -2"  ->  answer "ABD"

    How question and answer are combined into the yielded item is up to
    ``TextQABaseIterableDataset``.
    """
    def __init__(
            self,
            count: int,
            num_items: Union[int, Tuple[int, int]] = 4,
            num_operations: Union[int, Tuple[int, int]] = 3,
            seed: Optional[int] = None,
            exclude: Optional[Iterable[str]] = None,
            with_masked: bool = False,
    ):
        """
        :param count: passed to ``TextQABaseIterableDataset``
            (presumably the number of items per iteration — confirm in base class)
        :param num_items: number of letters, or an inclusive (min, max) range drawn per item
        :param num_operations: number of operations, or an inclusive (min, max) range drawn per item
        :param seed: passed to the base class
        :param exclude: passed to the base class (also stored here as a set)
        :param with_masked: passed to the base class
        """
        super().__init__(
            count=count, seed=seed, exclude=exclude, with_masked=with_masked,
            # the answer never has more letters than the input
            fixed_answer_width=max(num_items) if isinstance(num_items, (tuple, list)) else num_items,
        )
        self._count = count
        self._num_items = num_items
        self._num_operations = num_operations
        self._seed = seed
        self._exclude = None if exclude is None else set(exclude)
        self._with_masked = with_masked

    def iter_question_answer(self, rng: random.Random) -> Generator[Tuple[str, str], None, None]:
        """
        Endless generator of (question, answer) string pairs drawn from ``rng``.

        NOTE: the order of ``rng`` calls determines the output for a given seed.
        """
        while True:

            num_items = self._num_items
            if isinstance(num_items, (tuple, list)):
                num_items = rng.randint(*num_items)

            num_ops = self._num_operations
            if isinstance(num_ops, (tuple, list)):
                num_ops = rng.randint(*num_ops)

            cells = [chr(ord('A') + i) for i in range(num_items)]
            rng.shuffle(cells)
            program_input = cells.copy()

            # letters removed by "-" and waiting to be re-inserted by "+"
            stack = []
            ops = []
            # an op that cannot be applied ("+" with an empty stack, ">" with fewer
            # than 2 letters) is redrawn without being recorded; the loop also ends
            # once every letter has been removed, so the answer may be empty
            while cells and len(ops) < num_ops:
                op = rng.choices(
                    [">", "-", "+"],
                    weights=[1, 1/3, 1/3],  # swaps are 3x as likely as each of "-" and "+"
                )[0]
                if op == "-":
                    idx = rng.randrange(len(cells))
                    stack.append(cells.pop(idx))
                    ops.append(f"{op}{idx+1}")
                elif op == "+" and len(stack):
                    # idx is drawn from range(len(cells)), so a letter is never
                    # appended after the current last one
                    idx = rng.randrange(len(cells))
                    cells.insert(idx, stack.pop())
                    ops.append(f"{op}{idx+1}")
                elif op == ">" and len(cells) >= 2:
                    # two distinct random positions
                    indices = list(range(len(cells)))
                    rng.shuffle(indices)
                    idx1, idx2 = indices[:2]
                    cells[idx1], cells[idx2] = cells[idx2], cells[idx1]
                    ops.append(f"{idx1+1}{op}{idx2+1}")

            question = (
                    "".join(program_input) + ": "
                    + ", ".join(ops)
            )
            answer = "".join(cells)
            yield question, answer

ds = TextQAProgramIterableDataset(
    1000,
    num_operations=(1, 5),
    num_items=(2, 5),
)

# preview the first 100 items
for i, seg in zip(range(100), ds):
    print(f"{i:3}:", repr(seg))
