In [77]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [38]:
class FFN(nn.Module):
    def __init__(self, dim_in, dim_hidden):
        super(FFN, self).__init__()
        self.network = nn.Sequential(nn.Linear(dim_in, dim_hidden), nn.ReLU(), nn.Linear(dim_hidden, dim_in))
    def forward(self, x):
        return self.network(x)

In [120]:
class Experts(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_experts):
        super(Experts, self).__init__()
        self.dim_in = dim_in
        self.num_experts = num_experts
        W1 = torch.empty(num_experts, dim_in, dim_hidden)
        b1 = torch.empty(num_experts, dim_hidden)
        W2 = torch.empty(num_experts, dim_hidden, dim_in)
        b2 = torch.empty(num_experts, dim_in)

        std = 1/math.sqrt(self.dim_in)
        W1.uniform_(-std, std)
        b1.uniform_(-std, std)
        W2.uniform_(-std, std)
        b2.uniform_(-std, std)
        
        self.W1 = nn.Parameter(W1)
        self.b1 = nn.Parameter(b1)
        self.W2 = nn.Parameter(W2)
        self.b2 = nn.Parameter(b2)

    def forward(self, x):
        #x, weights, experts_indices = input_and_weights
        #batch, context_length, _ = x.shape
        #experts_mask = torch.zeros( (batch, context_length, self.num_experts), device = x.device, dtype = int) # x.shape[:-1] = batch, context_length

        #experts_mask.scatter_(-1, experts_indices, torch.ones_like(experts_indices, device = x.device))
        a = torch.einsum('bcd,ndh->bcnh', x, self.W1) + self.b1
        z = F.relu(a)
        y = torch.einsum('bcnh,nhd->bcnd', z, self.W2) + self.b2
        return y

In [119]:
# W = torch.rand(3,4,5) # num experts, dim_in, dim_hidden
# s = torch.rand(10,7,3) # batch size, context length, num experts (gating vector)
# x = torch.rand(10,7,4) # batch size, context length, feature dim (data)
# s_v, s_i = torch.topk(s, 2, dim=-1)
# b = torch.zeros(10, 7, 3) # batch size, context length, num experts
# b.scatter_(-1, s_i, torch.ones_like(s_v))
# print(torch.einsum("nab,bio->naio",b,W).shape)

e = Experts(10,15,3)
data = torch.rand(100,20,10)
results = e(data)
y0 = F.relu(data @ e.W1[0,:,:] + e.b1[0,:]) @ e.W2[0,:,:] + e.b2[0,:]
y1 = F.relu(data @ e.W1[1,:,:] + e.b1[1,:]) @ e.W2[1,:,:] + e.b2[1,:]
y2 = F.relu(data @ e.W1[2,:,:] + e.b1[2,:]) @ e.W2[2,:,:] + e.b2[2,:]
print(torch.sum(results[:,:,0,:] - y0))
print(torch.sum(results[:,:,1,:] - y1))
print(torch.sum(results[:,:,2,:] - y2))

torch.Size([100, 20, 3, 15])
tensor(1.0356e-06, grad_fn=<SumBackward0>)
tensor(1.3784e-06, grad_fn=<SumBackward0>)
tensor(-1.1176e-08, grad_fn=<SumBackward0>)


In [99]:
e = Experts(6,8,4)
inp = (torch.rand(2,1,6), None, torch.tensor([
    [
        [1,2]
    ],
    [
        [2,3]
    ]
]))
e(inp)

tensor([[[0, 1, 1, 0]],

        [[0, 0, 1, 1]]])


In [143]:
class GatingNetwork(nn.Module):
    def __init__(self, dim_in, num_experts, topk, utilization_factor = 1e-2):
        super(GatingNetwork, self).__init__()
        self.dim_in = dim_in
        self.num_experts = num_experts
        self.topk = topk
        self.Wg = nn.Linear(dim_in, num_experts)
        self.Wnoise = nn.Linear(dim_in, num_experts)
        self.utilization_factor = utilization_factor
        
    def forward(self, x):
        noise = F.softplus(self.Wnoise(x))
        noise *= torch.randn_like(noise).to(noise.device)
        logits = self.Wg(x)
        logits += noise
        mask = torch.full_like(logits, -float('inf'))
        selected_logits, selected_indices = torch.topk(logits, self.topk, dim=-1)
        mask.scatter_(-1,selected_indices, selected_logits)
        weights = F.softmax(mask, dim=-1)
        return weights, self.utilization_loss(weights)

    def utilization_loss(self, weights):
        importance = weights.reshape(-1, self.num_experts).sum(dim=0)
        cv = importance.var(correction=0) / importance.mean().pow(2)
        return self.utilization_factor * cv
        
        

In [147]:
class MoE(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_experts, topk):
        super(MoE, self).__init__()
        # no need for dropout because it's already sparse?
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.num_experts = num_experts
        self.topk = topk
        self.gating = GatingNetwork(dim_in, num_experts, topk)
        self.experts = Experts(dim_in, dim_hidden, num_experts)
    def forward(self, x):
        weights, loss = self.gating(x)
        expert_results = self.experts(x)
        return torch.einsum('bcn,bcnd->bcd', weights, expert_results), loss
        # this implementation probably activates all the parameters, so no computational speed up. But that's not important for this RQ

In [146]:
data = torch.rand(1,1,6)
m = MoE(6,12,4,2)
m(data)[0].shape
m(data)[1]

tensor(0.0151, grad_fn=<MulBackward0>)

In [134]:
0.7804 * torch.tensor([ 0.3652,  0.3562,  0.1133, -0.4165,  0.2291, -0.7671]) + 0.2196 * torch.tensor([ 0.1765, -0.0946, -0.2208, -0.3564, -0.2796, -0.2473])

tensor([ 0.3238,  0.2572,  0.0399, -0.4033,  0.1174, -0.6530])