In [13]:
embedding_dim = 8
context_length = 4
number_of_experts = 6
top_k = 2
feedforward_factor = 2
dropout_rate = 0.4
# batch_size = 1

In [3]:
from torch import nn
import torch


class FeedForward(nn.Module):
    '''Feed-forward neural network'''
    def __init__(self, C, dropout = dropout_rate):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(C, C * feedforward_factor),
            nn.ReLU(),
            nn.Linear(C * feedforward_factor, C),
            nn.Dropout(dropout)
        )
        self.matrix = nn.Parameter(torch.randn(C, C) / C**0.5)

    def forward(self, x):
        return self.net(x)


class FeedForwardMixture(nn.Module):
    '''Feed-forward neural network'''
    def __init__(self, C, n_experts):
        super().__init__()
        self.nets = nn.ModuleList([FeedForward(C) for n in range(n_experts)])

    def forward(self, x, ixs):
        # NETS RUN SEQUENTIALLY! DON"T DO THIS
        return torch.mean(torch.stack([self.nets[ix](x) for ix in ixs],dim=0),dim=0) # BE CAREFUL WITH BATCH HERE

In [3]:
import torch
a = torch.randn([3,4])

a[1:3]

tensor([[ 0.2809, -0.5921, -1.7067, -0.8721],
        [-0.3847, -0.4496, -0.5358,  1.1395]])

In [8]:

embedded_tokens = torch.randn(context_length,embedding_dim)
embedded_tokens


routing_weights = torch.randn(embedding_dim, number_of_experts)
router_logits = (embedded_tokens @ routing_weights) / (embedding_dim ** 0.5)
router_probs = torch.softmax(router_logits, dim=-1)

top = torch.multinomial(router_probs, num_samples = top_k) # each row contains the indices of the top_k experts
# convert top to a list
top = top.tolist()
# print(top)

mixture = FeedForwardMixture(embedding_dim, number_of_experts)
mixture(embedded_tokens, top)



TypeError: 'list' object cannot be interpreted as an integer

In [12]:
# mixture.nets[0,1]
print(top)

[[2, 5], [5, 2], [5, 3], [3, 4]]


In [122]:
a = torch.randn(4, 3)
b = torch.rand(4, 3)
stacked = torch.stack([a,b],dim=-1)
torch.mean(stacked,dim=-1)

tensor([[ 0.8317, -0.2049,  0.4555],
        [-0.0144,  0.3624,  0.4267],
        [ 1.0783,  0.8905,  0.8838],
        [ 1.6270, -0.6571,  0.1192]])

In [124]:
a = torch.randn(4, 3)
b = torch.rand(4, 3)

#average = torch.mean([a,b], dim=0)
# average


a = torch.tensor([0,1, 2])
b = torch.tensor([2,3, 4])

c = torch.stack([a,b], dim=0)

print(c)
print(c.shape)

tensor([[0, 1, 2],
        [2, 3, 4]])
torch.Size([2, 3])
