In [280]:
import torch
from fancy_einsum import einsum

In [281]:
TOPK = 7

In [282]:
# n_experts = 2
# topk (capacity eksperta) = 3
# gate_dim = 4
x = torch.randint(0, 4, (3, 4, 5)) # (batch, cutoff, dmodel)
gate = torch.randint(0, 4, (4, 2)) # (cutoff, n_experts)
print(f'x:\n{x}\n\ngate:\n{gate}')

x:
tensor([[[2, 0, 3, 1, 1],
         [0, 0, 1, 1, 2],
         [1, 1, 1, 2, 1],
         [3, 2, 0, 3, 0]],

        [[2, 3, 0, 1, 1],
         [3, 2, 2, 0, 3],
         [0, 0, 2, 3, 0],
         [0, 0, 0, 2, 0]],

        [[2, 3, 1, 0, 3],
         [0, 3, 3, 3, 0],
         [2, 1, 1, 1, 0],
         [2, 0, 1, 2, 1]]])

gate:
tensor([[3, 1],
        [3, 0],
        [0, 3],
        [0, 0]])


In [283]:
# 1. używamy gate, żeby stworzyć reprezentację ekspertów zamiast dmodel
gate_out = einsum('batch cutoff dmodel, cutoff n_experts -> batch cutoff n_experts', x, gate)
print(f'gate_out:\n{gate_out}\n\n shape: {gate_out.shape}')

gate_out:
tensor([[[21,  7],
         [12,  0],
         [ 0, 18],
         [ 0,  0]],

        [[21,  7],
         [30,  0],
         [ 0, 15],
         [ 0,  0]],

        [[27,  9],
         [27,  0],
         [ 0, 15],
         [ 0,  0]]])

 shape: torch.Size([3, 4, 2])


In [284]:
# wymiar eksperta chcę żeby był na początku
gate_out = gate_out.permute(2, 0, 1)
gate_out

tensor([[[21, 12,  0,  0],
         [21, 30,  0,  0],
         [27, 27,  0,  0]],

        [[ 7,  0, 18,  0],
         [ 7,  0, 15,  0],
         [ 9,  0, 15,  0]]])

In [285]:
gate_out.shape

torch.Size([2, 3, 4])

In [286]:
# spłaszczam wymiary 2 i 3, żeby każdy ekspert wybrał sobie swobodnie tokeny (ogólnie tokenów w batchu jest 12)
gate_out = gate_out.flatten(start_dim=1)
gate_out

tensor([[21, 12,  0,  0, 21, 30,  0,  0, 27, 27,  0,  0],
        [ 7,  0, 18,  0,  7,  0, 15,  0,  9,  0, 15,  0]])

In [287]:
gate_out.shape

torch.Size([2, 12])

In [288]:
# make gate_out float tensor
gate_out = gate_out.float()

In [289]:
# teraz przeprowadzamy softmax po 1 wymiarze
gate_out = gate_out.softmax(dim=1)

In [290]:
gate_out

tensor([[1.1221e-04, 1.3848e-08, 8.5083e-14, 8.5083e-14, 1.1221e-04, 9.0924e-01,
         8.5083e-14, 8.5083e-14, 4.5268e-02, 4.5268e-02, 8.5083e-14, 8.5083e-14],
        [1.5187e-05, 1.3849e-08, 9.0931e-01, 1.3849e-08, 1.5187e-05, 1.3849e-08,
         4.5272e-02, 1.3849e-08, 1.1222e-04, 1.3849e-08, 4.5272e-02, 1.3849e-08]])

In [292]:
softmax_gate_out = gate_out
# wybierz topk tokenów dla każdego eksperta
gate_out = torch.topk(gate_out, k=TOPK, dim=1).indices

gate_out

tensor([[ 5,  8,  9,  4,  0,  1, 10],
        [ 2, 10,  6,  8,  4,  0,  9]])

In [293]:
# zakoduj wybór eksperta jako one hot
# tutaj dodać i odjąć odpowiedni softmax (można jakoś pomnożyć przez one te softmaxy i wtedy dodać i odjąć takie zwierzę)
gate_out = torch.nn.functional.one_hot(gate_out, num_classes=12)
gate_out

tensor([[[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]],

        [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]]])

In [294]:
print(f'Gate out shape: {gate_out.shape}, softmax_gate_out shape: {softmax_gate_out.shape}')

Gate out shape: torch.Size([2, 7, 12]), softmax_gate_out shape: torch.Size([2, 12])


In [308]:
softmax_mask = einsum('n_experts topk n_examples, n_experts n_examples -> n_experts topk n_examples', gate_out, softmax_gate_out)
softmax_mask

tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          9.0924e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 4.5268e-02, 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.5268e-02,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.1221e-04,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.1221e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 1.3848e-08, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+0

In [310]:
gate_out = gate_out + softmax_mask - softmax_mask
print(f'Gate out: {gate_out} \n\n softmax_mask: {softmax_mask}')

Gate out: tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          1.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 1.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 1.0000, 0.0000]],

        [[0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.

In [248]:
gate_out_flattened = gate_out.flatten(start_dim=0, end_dim=1)
print(f'gate_out:\n{gate_out_flattened}\n\n shape: {gate_out_flattened.shape}')

gate_out:
tensor([[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]])

 shape: torch.Size([14, 12])


In [249]:
id = torch.eye(12)
id

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

In [250]:
chosen_examples = einsum('n_elems n_elems, expert_layer_width n_elems -> expert_layer_width n_elems', id, gate_out_flattened)
print(f'chosen_examples:\n{chosen_examples}\n\n shape: {chosen_examples.shape}')

chosen_examples:
tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]])

 shape: torch.Size([14, 12])


In [251]:
not_chosen_examples = (chosen_examples.sum(dim=0) == 0).float()
not_chosen_examples

tensor([0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0.])

In [252]:
gate_out.shape

torch.Size([2, 7, 12])

In [253]:
# z powrotem odwróć spłaszczenie
# gate_out = gate_out.view(2, TOPK, 3, 4)
gate_out

tensor([[[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]]])

In [254]:
# batch_size=3, cutoff=4

In [255]:
x

tensor([[[0, 1, 0, 3, 3],
         [3, 2, 3, 0, 1],
         [3, 1, 1, 0, 0],
         [0, 2, 2, 1, 1]],

        [[1, 3, 1, 2, 1],
         [1, 1, 1, 3, 0],
         [1, 3, 3, 1, 3],
         [1, 1, 2, 1, 3]],

        [[3, 0, 3, 0, 1],
         [0, 1, 1, 0, 3],
         [1, 3, 1, 0, 3],
         [2, 1, 0, 2, 3]]])

In [256]:
x.shape

torch.Size([3, 4, 5])

In [257]:
x = x.flatten(start_dim=0, end_dim=1)
x
# tu zapisać x przed wejściem do layera

tensor([[0, 1, 0, 3, 3],
        [3, 2, 3, 0, 1],
        [3, 1, 1, 0, 0],
        [0, 2, 2, 1, 1],
        [1, 3, 1, 2, 1],
        [1, 1, 1, 3, 0],
        [1, 3, 3, 1, 3],
        [1, 1, 2, 1, 3],
        [3, 0, 3, 0, 1],
        [0, 1, 1, 0, 3],
        [1, 3, 1, 0, 3],
        [2, 1, 0, 2, 3]])

In [258]:
x_before_experts = x

In [259]:
# teraz permutujemy macierz x zgodne z tym one hot
x = einsum('n_elems dmodel, n_experts topk n_elems -> n_experts topk dmodel', x, gate_out)
print(f'x:\n{x}\n\n shape: {x.shape}')

x:
tensor([[[1, 3, 1, 2, 1],
         [0, 1, 0, 3, 3],
         [3, 0, 3, 0, 1],
         [1, 1, 2, 1, 3],
         [2, 1, 0, 2, 3],
         [0, 2, 2, 1, 1],
         [3, 2, 3, 0, 1]],

        [[3, 2, 3, 0, 1],
         [1, 1, 1, 3, 0],
         [0, 1, 1, 0, 3],
         [2, 1, 0, 2, 3],
         [1, 1, 2, 1, 3],
         [0, 2, 2, 1, 1],
         [1, 3, 1, 0, 3]]])

 shape: torch.Size([2, 7, 5])


In [260]:
# wypłasczamy wymiar 0 i 1, tak żeby otrzymać macierz wejścia do warstwy ekspertóœ
x = x.flatten(start_dim=0, end_dim=1)
print(f'x:\n{x}\n\n shape: {x.shape}')

x:
tensor([[1, 3, 1, 2, 1],
        [0, 1, 0, 3, 3],
        [3, 0, 3, 0, 1],
        [1, 1, 2, 1, 3],
        [2, 1, 0, 2, 3],
        [0, 2, 2, 1, 1],
        [3, 2, 3, 0, 1],
        [3, 2, 3, 0, 1],
        [1, 1, 1, 3, 0],
        [0, 1, 1, 0, 3],
        [2, 1, 0, 2, 3],
        [1, 1, 2, 1, 3],
        [0, 2, 2, 1, 1],
        [1, 3, 1, 0, 3]])

 shape: torch.Size([14, 5])


In [261]:
# dodajemy do siebie wyniki dla tych tokenów, które były wybrane przez wielu ekspertów
chosen_examples.shape

torch.Size([14, 12])

In [262]:
summed = einsum('expert_layer_width n_elems, expert_layer_width dmodel -> n_elems dmodel', chosen_examples.float(), x.float())
print(f'summed:\n{summed}\n\n shape: {summed.shape}')

summed:
tensor([[0., 1., 0., 3., 3.],
        [6., 4., 6., 0., 2.],
        [0., 0., 0., 0., 0.],
        [0., 4., 4., 2., 2.],
        [1., 3., 1., 2., 1.],
        [1., 1., 1., 3., 0.],
        [0., 0., 0., 0., 0.],
        [2., 2., 4., 2., 6.],
        [3., 0., 3., 0., 1.],
        [0., 1., 1., 0., 3.],
        [1., 3., 1., 0., 3.],
        [4., 2., 0., 4., 6.]])

 shape: torch.Size([12, 5])


In [264]:
# add examples that were not chosen by any expert
einsum('n_elems dmodel, n_elems -> n_elems dmodel', x_before_experts, not_chosen_examples)

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [3., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1., 3., 3., 1., 3.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

In [265]:
summed += einsum('n_elems dmodel, n_elems -> n_elems dmodel', x_before_experts, not_chosen_examples)
print(f'summed:\n{summed}\n\n shape: {summed.shape}')

summed:
tensor([[0., 1., 0., 3., 3.],
        [6., 4., 6., 0., 2.],
        [3., 1., 1., 0., 0.],
        [0., 4., 4., 2., 2.],
        [1., 3., 1., 2., 1.],
        [1., 1., 1., 3., 0.],
        [1., 3., 3., 1., 3.],
        [2., 2., 4., 2., 6.],
        [3., 0., 3., 0., 1.],
        [0., 1., 1., 0., 3.],
        [1., 3., 1., 0., 3.],
        [4., 2., 0., 4., 6.]])

 shape: torch.Size([12, 5])


In [266]:
# again reshape summed to original shape with batch and cutoff
summed = summed.view(3, 4, 5)
print(f'summed:\n{summed}\n\n shape: {summed.shape}')

summed:
tensor([[[0., 1., 0., 3., 3.],
         [6., 4., 6., 0., 2.],
         [3., 1., 1., 0., 0.],
         [0., 4., 4., 2., 2.]],

        [[1., 3., 1., 2., 1.],
         [1., 1., 1., 3., 0.],
         [1., 3., 3., 1., 3.],
         [2., 2., 4., 2., 6.]],

        [[3., 0., 3., 0., 1.],
         [0., 1., 1., 0., 3.],
         [1., 3., 1., 0., 3.],
         [4., 2., 0., 4., 6.]]])

 shape: torch.Size([3, 4, 5])
