In [None]:
from init_notebook import *
from src.datasets.generative import *

In [None]:
# Create a training and a validation dataset of "program" question/answer texts.
# input_length and num_operations are given as 2-tuples here
# (presumably inclusive (min, max) ranges sampled per example — confirm in
# src.datasets.generative).
ds, ds_val = TextQAProgramIterableDataset.create_train_and_validation_set(
    train_count=100000,
    validation_count=10000,
    validation_seed=23,
    input_length=(2,5),
    num_operations=(1,5),
)
# Number of distinct validation samples, then the first ten of them.
print("val: ", len(set(ds_val)))
for i, text in zip(range(10), ds_val):
    print(f"{i:2}: {repr(text)}")
# Number of distinct samples in `ds` (with a tqdm progress bar), then the first twenty.
# NOTE(review): the label says "test" but `ds` is the training set returned above.
print("test:", len(set(tqdm(ds))))
for i, text in zip(range(20), ds):
    print(f"{i:2}: {repr(text)}")

In [None]:
# Start ten separate iterations over `ds` and print the first two samples of
# each (presumably to see whether repeated passes yield the same samples).
for j in range(10):
    for i, text in zip(range(2), ds):
        #print(f"{i:2}: {repr(text)}")
        print(text)

In [None]:
# Same dataset family as above, but with a fixed input length of 5, between
# 2 and 5 operations, and only the swap operator ">" enabled (weight 1).
# NOTE(review): this rebinds the notebook-global `ds` / `ds_val`.
ds, ds_val = TextQAProgramIterableDataset.create_train_and_validation_set(
    train_count=100000,
    validation_count=10000,
    validation_seed=23,
    input_length=5,
    num_operations=[2,5],
    operators={">": 1},
)
# Show the first twenty training samples.
for i, text in zip(range(20), ds):
    #print(f"{i:2}: {repr(text)}")
    print(text)

In [None]:
# Arithmetic question/answer dataset: training samples use 3 operations,
# validation samples use 5, with the operators "+", "-" and "*".
# The exact meaning of with_masked, max_number and fixed_answer_width is
# defined by TextQAMathIterableDataset (not visible in this notebook).
train_set, validation_set = TextQAMathIterableDataset.create_train_and_validation_set(
    train_count=100_000,
    validation_count=10_000,
    validation_seed=23,
    with_masked=True,
    max_number=100,
    num_operations=3,
    validation_num_operations=5,
    fixed_answer_width=3*5,
    operators=["+", "-", "*"],
)
# Show the first ten training samples.
for i, text in zip(range(10), train_set):
    print(f"{i:2}: {repr(text)}")


In [None]:
class TextQAProgramIterableDataset(TextQABaseIterableDataset):
    """
    Question/answer pairs about tiny "programs" that rearrange a row of letters.

    The question is the initial row, a colon, and a comma-separated list of
    operations; the answer is the row after all operations were applied, e.g.

        question "ABCD: 1>2"   ->   answer "BACD"

    (How question and answer are joined into one text is up to the base class.)

    Positions are 1-based. Supported operations:

        i>j   swap the cells at positions i and j
        -i    remove the cell at position i and push it onto a stack
        +i    pop the most recently removed cell and insert it at position i

    Parameters
    ----------
    count, seed, exclude, with_masked:
        Forwarded unchanged to ``TextQABaseIterableDataset``.
    input_length:
        Number of cells in the initial row, or an inclusive ``(min, max)`` range
        sampled per example. Its maximum is also passed on as ``fixed_answer_width``.
    num_items:
        Size of the alphabet (letters starting at ``'A'``), or an inclusive range.
        Values above 26 produce characters beyond ``'Z'``.
    num_operations:
        Maximum number of operations per program, or an inclusive range.
        A program may be shorter because generation stops once the row is empty.
    operators:
        Optional mapping ``operator symbol -> relative sampling weight``,
        e.g. ``{">": 1}`` for swap-only programs. Defaults to
        ``DEFAULT_OPERATORS``, which reproduces the original fixed weights.

    Notes
    -----
    ``iter_question_answer`` never finishes on its own and silently skips
    questions it already produced; a consumer that asks for more samples than
    there are distinct questions for the given settings will therefore spin
    forever.
    """

    # Operator symbol -> relative sampling weight used when `operators` is not given.
    # The order matters: it is the order handed to `rng.choices`.
    DEFAULT_OPERATORS = {">": 1, "-": 1/3, "+": 1/3}

    def __init__(
            self,
            count: int,
            input_length: Union[int, Tuple[int, int]] = 4,
            num_items: Union[int, Tuple[int, int]] = 26,
            num_operations: Union[int, Tuple[int, int]] = 3,
            seed: Optional[int] = None,
            exclude: Optional[Iterable[str]] = None,
            with_masked: bool = False,
            operators: Optional[dict] = None,
    ):
        if operators is None:
            operators = self.DEFAULT_OPERATORS
        operators = dict(operators)
        unknown = [op for op in operators if op not in self.DEFAULT_OPERATORS]
        if unknown:
            raise ValueError(
                f"Unsupported operators {unknown}, expected a subset of {list(self.DEFAULT_OPERATORS)}"
            )
        if not any(weight > 0 for weight in operators.values()):
            raise ValueError("At least one operator needs a positive weight")

        super().__init__(
            count=count, seed=seed, exclude=exclude, with_masked=with_masked,
            fixed_answer_width=max(input_length) if isinstance(input_length, (tuple, list)) else input_length,
        )
        self._count = count
        self._input_length = input_length
        self._num_items = num_items
        self._num_operations = num_operations
        self._seed = seed
        self._exclude = None if exclude is None else set(exclude)
        self._with_masked = with_masked
        self._operators = operators

    def iter_question_answer(self, rng: random.Random) -> Generator[Tuple[str, str], None, None]:
        """
        Endlessly yield unique ``(question, answer)`` string pairs drawn with ``rng``.
        """
        op_names = list(self._operators.keys())
        op_weights = list(self._operators.values())
        remove_enabled = self._operators.get("-", 0) > 0
        insert_enabled = self._operators.get("+", 0) > 0
        swap_enabled = self._operators.get(">", 0) > 0

        duplicates_set = set()
        while True:

            input_length = self._input_length
            if isinstance(input_length, (tuple, list)):
                input_length = rng.randint(*input_length)

            num_items = self._num_items
            if isinstance(num_items, (tuple, list)):
                num_items = rng.randint(*num_items)

            num_ops = self._num_operations
            if isinstance(num_ops, (tuple, list)):
                num_ops = rng.randint(*num_ops)

            # Draw `input_length` distinct letters in random order.
            items = [chr(ord('A') + i) for i in range(num_items)]
            rng.shuffle(items)
            cells = items[:input_length]
            program_input = cells.copy()

            stack = []
            ops = []
            while cells and len(ops) < num_ops:
                # Stop early if none of the enabled operators can be applied to
                # the current state; otherwise the sampling below would retry
                # forever (e.g. swap-only with a single cell). With the default
                # operators "-" is always applicable, so this never triggers and
                # no random numbers are consumed by the check.
                applicable = (
                    remove_enabled
                    or (insert_enabled and bool(stack))
                    or (swap_enabled and len(cells) >= 2)
                )
                if not applicable:
                    break

                op = rng.choices(op_names, weights=op_weights)[0]
                if op == "-":
                    idx = rng.randrange(len(cells))
                    stack.append(cells.pop(idx))
                    ops.append(f"{op}{idx+1}")
                elif op == "+" and len(stack):
                    idx = rng.randrange(len(cells))
                    cells.insert(idx, stack.pop())
                    ops.append(f"{op}{idx+1}")
                elif op == ">" and len(cells) >= 2:
                    indices = list(range(len(cells)))
                    rng.shuffle(indices)
                    idx1, idx2 = indices[:2]
                    cells[idx1], cells[idx2] = cells[idx2], cells[idx1]
                    ops.append(f"{idx1+1}{op}{idx2+1}")
                # An operator that is not applicable right now is simply re-drawn.

            question = (
                    "".join(program_input) + ": "
                    + ", ".join(ops)
            )
            # Never yield the same question twice within one iteration.
            if question in duplicates_set:
                continue
            duplicates_set.add(question)

            answer = "".join(cells)
            yield question, answer

# Smoke test of the notebook-local TextQAProgramIterableDataset.
ds = TextQAProgramIterableDataset(count=1000, seed=23)
# A bare expression in the middle of a cell is not displayed, so print the
# number of distinct samples explicitly.
print("unique:", len(set(ds)))
for i, text in zip(range(20), ds):
    print(f"{i:2}: {repr(text)}")

In [None]:
# Probe the output shape of an LSTM with 20 input features and 30 hidden units.
# NOTE(review): the input is 2-D, so it is presumably treated as an unbatched
# sequence (length 1, 20 features) despite batch_first=True — confirm against
# the installed torch version.
m = nn.LSTM(20, 30, batch_first=True)
m(torch.ones(1, 20))[0].shape

In [None]:
# Broadcasting check: a (ch, 1) weight multiplied with a (ch, l) input scales
# every row of the input by its own weight and yields a (ch, l) tensor.
# NOTE(review): `input` shadows the Python builtin of the same name.
ch, l = 64, 100  # presumably channels and sequence length
weight = torch.randn(ch, 1)
input = torch.rand(ch, l)
print(weight.shape, input.shape)
#print(weight[:, None].shape)
weight * input
#conv = nn.Conv1d(ch, ch, 3, padding=1)

In [None]:
class PositionEmbedding1d(nn.Module):
    """
    Parameter-free sine/cosine position code with a growing frequency.

    For every position ``0 .. length-1`` a linear phase ``pos / period * 2*pi``
    is computed and then warped to ``phase * (1 + 0.02 * phase)``, so the
    oscillation gets faster towards larger positions.
    """

    def __init__(
            self,
            period: float = 20.,
    ):
        super().__init__()
        self.period = period

    def forward(self, length: int) -> torch.Tensor:
        """Return a ``(2, length)`` tensor: row 0 is the sine, row 1 the cosine."""
        positions = torch.arange(0, length)
        linear_phase = positions / self.period * math.pi * 2
        warped_phase = linear_phase * (1 + .02 * linear_phase)
        return torch.stack([torch.sin(warped_phase), torch.cos(warped_phase)])

# Visualise the (2, 100) sine/cosine embedding as an image
# (`px` is presumably plotly.express, provided by the star imports).
m = PositionEmbedding1d()
px.imshow(m(100))
   

In [None]:
# Inspect the parameters of torch's standard multi-head attention
# (64-dim embedding, 4 heads): total count, then the shape and name of each.
a = nn.MultiheadAttention(
    embed_dim=64,
    num_heads=4,
    batch_first=True,
)
print(f"params: {num_module_parameters(a):,}")
for n, p in a.named_parameters():
    print(p.shape, n)

In [None]:
class LinearSelfAttention2d(nn.Module):
    """
    Attention without an explicit attention matrix, adapted from
    https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py

    Expected layouts (taken from the einsum subscripts):
    ``Q`` is ``(n, l, h, d)``, ``K`` is ``(n, s, h, d)``, ``V`` is ``(n, s, h, m)``;
    the result is ``(n, l, h, m)``. No feature map is applied inside this
    module, ``Q`` and ``K`` are used exactly as passed in.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, Q, K, V):
        # Keys summed over the source axis; shared by all query positions.
        key_totals = K.sum(dim=1)

        # Denominator per (batch, query position, head); eps keeps it away from zero.
        denominator = torch.einsum("nlhd,nhd->nlh", Q, key_totals) + self.eps

        # Contract keys with values over the source axis first, so the
        # (query x source) attention matrix is never materialised.
        key_value = torch.einsum("nshd,nshm->nhmd", K, V)

        # Apply the queries and the normaliser in one contraction.
        attended = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, key_value, 1. / denominator)

        return attended.contiguous()

# Shape check with all-ones inputs of shape (1, 32, 100, 90) for Q, K and V.
m = LinearSelfAttention2d()
m(torch.ones(1, 32, 100, 90), torch.ones(1, 32, 100, 90), torch.ones(1, 32, 100, 90)).shape

In [None]:
class SelfAttention1d(nn.Module):
    """
    Linear attention that never builds the explicit attention matrix.

    Expected layouts (taken from the einsum subscripts):
    ``Q`` is ``(n, l, h, d)``, ``K`` is ``(n, s, h, d)``, ``V`` is ``(n, s, h, m)``;
    the result is ``(n, l, h, m)``.

    Parameters
    ----------
    eps:
        Small constant added to the normaliser before inverting it, to avoid a
        division by zero. ``forward`` reads ``self.eps``, so it has to be set
        here; previously it was missing and every call raised ``AttributeError``.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, Q, K, V):
        # Compute the KV matrix, namely the dot product of keys and values so
        # that we never explicitly compute the attention matrix and thus
        # decrease the complexity
        KV = torch.einsum("nshd,nshm->nhmd", K, V)

        # Compute the normalizer
        Z = 1. / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)

        # Finally compute and return the new values
        V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

        return V.contiguous()

# Reference run of torch's multi-head attention (embed_dim=3, one head, no bias)
# on a tiny deterministic input.
a1 = nn.MultiheadAttention(3, 1, bias=False)
print("params", num_module_parameters(a1))
# 45 evenly spaced values in [0, 1), arranged as (1, 9, 5).
qkv = torch.arange(0, 9 * 5).view(1, 9, 5).float() / (9*5)
# The permute gives (1, 5, 9); splitting the last axis into chunks of 3 yields
# q, k and v of shape (1, 5, 3) each. These stay in the notebook namespace and
# are used by the following cells.
q, k, v = torch.split(qkv.permute(0, 2, 1), 3, dim=2)
a1(q, k, v)
#m = LinearSelfAttention2d()
#m(torch.ones(1, 32, 100, 90), torch.ones(1, 32, 100, 90), torch.ones(1, 32, 100, 90)).shape

In [None]:
# Feature map applied to queries and keys before the linear-attention product.
# NOTE(review): `dpfp` is defined in a cell further down in this notebook and
# q, k, v come from the cell above; on a fresh kernel this cell only runs if
# `dpfp` is also provided by the star imports — confirm or reorder the cells.
f = dpfp
#f = lambda x: F.elu(x) + 1

# f(q) @ (f(k)^T @ v), divided by the product of v's last two dimensions.
v1 = f(q) @ (f(k).permute(0, 2, 1) @ v) / (v.shape[-1] * v.shape[-2])
print(v1.shape)
v1

In [None]:
# Batched product of k with v transposed over its last two axes
# (k and v are whatever the previously executed cells left in the namespace).
k @ v.permute(0, 2, 1)

In [None]:
#q @ (k.permute(0, 2, 1) @ v)

In [None]:
# Plot the sigmoid over 100 evenly spaced points in [-10, 10].
x=torch.linspace(-10, 10, 100)
px.line(x=x, y=torch.sigmoid(x))  # last expression, so the figure is the cell output

In [None]:
import graphviz
graphviz.Digraph?

In [None]:
# Data-flow sketch of an attention-like block, rendered with graphviz "dot".
# Edge labels are the tensor shapes flowing along that edge
# (presumably B=batch, C=channels, L=length).
g = graphviz.Digraph(engine="dot")
#g.edge("in", "out", label="B,C,L")
# A convolution doubles the channels, the result is split into two (B,C,L) parts.
g.edge("in", "conv C,2C", label="B,C,L")
g.edge("conv C,2C", "split", label="B,2C,L")
# One split part and the raw input feed the Kᵀ·V product, which is (B,L,L).
g.edge("split", "Kᵀ dot V", label="B,C,L")
g.edge("in", "Kᵀ dot V", label="B,C,L")
# The other split part is multiplied with that product, giving (B,C,L) again.
g.edge("Kᵀ dot V", "Q dot (Kᵀ dot V)", label="B,L,L")
g.edge("split", "Q dot (Kᵀ dot V)", label="B,C,L")
g.edge("Q dot (Kᵀ dot V)", "act", label="B,C,L")
g.edge("act", "out", label="B,C,L")

#g.edge("in B,C,L", "out B,C,L", label="x")
#g.edge("in B,C,L", "conv B,C,L -> B,2C,L")
# Last expression: the graph object is displayed as the cell output.
g#.edge?

In [None]:
# Shape check for v @ (vᵀ @ v) with v of shape (B, C, L):
# (B,L,C) @ (B,C,L) -> (B,L,L), then (B,C,L) @ (B,L,L) -> (B,C,L).
# NOTE(review): this rebinds the notebook-global `v` used by earlier cells.
B, C, L = 1, 32, 100
v = torch.ones(B, C, L)
(v @ (v.permute(0, 2, 1) @ v)).shape

In [None]:
import itertools
# All six orderings of the letters Q, K and V as three-character strings.
list("".join(p) for p in itertools.permutations("QKV"))
#for t in ("QK", "QV", "KV", "QKV"):
#print(sorted(t))

In [None]:
# Toy result table: two experiments ("bla", "blub") with two trials each.
df = pd.DataFrame(
    data=dict(
        trial=[1, 2, 1, 2],
        b=[1, 1, 3, 5],
        c=[1, 1, 1, 1.5],
        s=["s"] * 4,
    ),
    index=[
        "bla-trial:1-bla",
        "bla-trial:2-bla",
        "blub-trial:1-bla",
        "blub-trial:2-bla",
    ],
)
df
# Per-trial mean of the numeric columns (the string column is left out).
df.groupby("trial").mean(numeric_only=True)#.apply(lambda x: x)
# Per-trial maximum of every column; as the last expression this is the cell output.
df.groupby("trial").max()#numeric_only=True)

In [None]:
# Collapse several trials of the same experiment into one row: numeric columns
# are averaged, non-numeric columns take the per-group maximum.
df = pd.DataFrame({
    "trial": [1, 2, 1, 2],
    "b": [1, 1, 3, 5],
    "bumm": ["a", "a", "b", "b"],
    "c": [1, 1, 1, 1.5],
    "s": ["s", "s", "s", "s"],
}, index=["bla-trial:1-bla", "bla-trial:2-bla", "blub-trial:1-bla", "blub-trial:2-bla"])
def _remove_trial(x):
    # Strip every "trial:<t>" marker from an index label; reads the global `df`.
    # NOTE(review): plain substring replacement — with trial numbers such as 1
    # and 10, removing "trial:1" would also cut into "trial:10". Not an issue
    # for the values used here.
    for t in df["trial"].unique():
        x = x.replace(f"trial:{t}", "")
    return x
# Group key: the index label with the trial marker removed.
df["id_without"] = df.index.map(_remove_trial)
display(df)
# Mean of the numeric columns per experiment ...
df2 = df.groupby("id_without").mean(numeric_only=True)
# ... and maximum of all columns per experiment.
df3 = df.groupby("id_without").max()
# Add the columns the numeric mean dropped, taking their values from the max frame.
for c in df3.columns:
    if c not in df2.columns:
        df2.loc[:, c] = df3.loc[:, c]
#df2.loc[:, df2.columns] = df2.loc[:, df3.columns]
#df2#.columns
# Rebuild the frame in df3's column order and turn the group key back into a column.
pd.DataFrame({
    c: df2.loc[:, c]
    for c in df3.columns
}).reset_index()

In [None]:
def _():
    """Yield all 32 five-element patterns; entry j is 0 if bit j of the counter is clear, else True."""
    # NOTE(review): the name `_` shadows IPython's "last output" variable.
    for i in range(2**5):
        yield [
            # `&` binds tighter than `==`, so this tests ((i >> j) & 1) == 0.
            0 if (i >> j) & 1 == 0 else True
            for j in range(5)
        ]
# Print each pattern reversed, i.e. with the highest bit first.
for s in _():
    print("-", list(reversed(s)))

In [None]:
def dpfp(x, nu=1):
    """
    Feature map that expands the last axis of ``x`` from ``d`` to ``2 * d * nu``.

    The positive and negative parts of ``x`` are concatenated to a non-negative
    vector of size ``2 * d``; for every shift ``1 .. nu`` that vector is
    multiplied element-wise with a copy of itself rolled by the shift along the
    last axis, and the ``nu`` products are concatenated.
    """
    features = torch.cat([F.relu(x), F.relu(-x)], dim=-1)
    products = [
        features * features.roll(shifts=shift, dims=-1)
        for shift in range(1, nu + 1)
    ]
    return torch.cat(products, dim=-1)

# Integer test input of shape (6, 5) with values -10 .. 19, so both signs occur.
x = torch.arange(0, 2*3*5).view(6, 5) -10#- (2*3*5)//2
display(x)
# With nu=1 the last axis doubles: (6, 5) -> (6, 10).
dpfp(x, nu=1).shape
#x.roll(shifts=1, dims=-1) * x

In [None]:
class DepthWiseConv1d(nn.Module):
    """
    Depthwise-separable 1-D convolution.

    A grouped convolution with one group per input channel (``depth_conv``)
    filters every channel on its own; a kernel-size-1 convolution
    (``point_conv``) then mixes the channels and maps them to ``channels_out``.
    ``kernel_size``, ``stride`` and ``padding`` apply to the depthwise stage
    only, ``bias`` applies to both stages.
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ):
        super().__init__()
        # Created first, so the order of random weight initialisation is unchanged.
        self.depth_conv = nn.Conv1d(
            in_channels=channels_in,
            out_channels=channels_in,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=channels_in,
            bias=bias,
        )
        self.point_conv = nn.Conv1d(
            in_channels=channels_in,
            out_channels=channels_out,
            kernel_size=1,
            bias=bias,
        )

    def forward(self, x: torch.Tensor):
        """Apply the depthwise stage, then the pointwise stage."""
        return self.point_conv(self.depth_conv(x))

# Probe: 1 -> 3 channels, kernel 7 with padding 3, so the length of 10 is kept.
m = DepthWiseConv1d(1, 3, kernel_size=7, padding=3)
print(f"params: {num_module_parameters(m):,}")
inp = torch.ones(1, 1, 10)
outp = m(inp)
print(inp.shape, "->", outp.shape)
outp

In [None]:
def diagonal_matrix(shape: Union[int, Tuple[int, int]]) -> torch.Tensor:
    """
    Build a "soft diagonal" matrix that also works for non-square shapes.

    ``shape`` is either a single int (square matrix) or a sequence whose last
    two entries are ``(rows, cols)``. For ``rows >= cols`` the rows are spread
    evenly over the column range ``0 .. cols-1`` and each entry is
    ``1 - min(1, |column index - row position|)``, i.e. a linear blend between
    the two nearest columns. For ``rows < cols`` the transposed problem is
    solved and the result transposed back. A square shape gives the identity.
    No normalisation is applied to the result.
    """
    if isinstance(shape, int):
        rows = cols = shape
    else:
        rows, cols = shape[-2], shape[-1]

    if rows < cols:
        return diagonal_matrix((cols, rows)).T

    column_index = torch.arange(0, cols).float()
    row_position = torch.linspace(0, cols - 1, rows)
    # Broadcast to (rows, cols): distance of every column to every row's position.
    distance = (column_index.unsqueeze(0) - row_position.unsqueeze(1)).abs()
    return 1 - distance.clamp(0, 1)

# Example with fewer rows than columns (2 x 5), which takes the transposed branch.
diagonal_matrix((2, 5))

In [None]:
# Render diagonal_matrix for every shape (j rows, i columns) with i, j in 1..10.
grid = []
for i in range(1, 11):
    for j in range(1, 11):
        m = diagonal_matrix((j, i))
        # Pad right and bottom with the value 0.2 so every tile is 10 x 10.
        m = F.pad(m, (0, 10-i, 0, 10-j), value=.2)
        # Add a leading axis of size 1 (presumably the image channel) for make_grid.
        grid.append(m.unsqueeze(0))

# 10 tiles per row; `resize(..., 7)` is a project helper (presumably a 7x upscale).
VF.to_pil_image(resize(make_grid(grid, nrow=10, pad_value=.2, padding=1), 7))

In [None]:
# Initialise a 10 x 5 embedding table with the soft diagonal matrix of the same
# shape, then look up all ten indices to show the resulting rows.
e = nn.Embedding(10, 5)
# no_grad: the weight is a parameter, so the in-place copy must not be tracked by autograd.
with torch.no_grad():
    e.weight[:] = diagonal_matrix(e.weight.shape)
e(torch.arange(0, 10))