In [1]:
from model import CNN_Resnet, ResidualBlock

In [50]:
from tqdm import tqdm

In [2]:
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch
import torch.nn.functional as F

In [44]:
class Expert(nn.Module):
    def __init__(self, size, in_channels, num_classes, hidden_units=100, hidden_channels=32,out_channels = 16,  dropout=0.5):
        super(Expert, self).__init__()
        #in_channels: so kenh dau vao cua expert = so kenh dau ra cua gating
        #size: kich thuoc dau ra cua gating
        self.batch_norm = nn.BatchNorm2d(in_channels)
        self.dropout = nn.Dropout(dropout)
        self.residual_layer1 = ResidualBlock(in_channels, hidden_channels)
        self.residual_layer2 = ResidualBlock(hidden_channels, out_channels)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(int(size*size*out_channels), hidden_units)
        self.fc2 = nn.Linear(hidden_units, num_classes)
    
    def forward(self, x):
        x = self.batch_norm(x)
        x = self.residual_layer1(x)
        x = self.dropout(x)
        x = self.residual_layer2(x)
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

In [45]:
class Resnet_gating(nn.Module):
    def __init__(self, in_channels, out_channels,downsample, hidden_channels=64, dropout=0.5):
        #in_channels: so kenh anh dau vao
        #out_channels: so kenh dau ra cua gating = in_channels cua expert
        super().__init__()
        self.batch_norm = nn.BatchNorm2d(in_channels)
        self.residual_layer1 = ResidualBlock(in_channels, hidden_channels, downsample=downsample)
        self.residual_layer2 = ResidualBlock(hidden_channels, hidden_channels)
        self.residual_layer3 = ResidualBlock(hidden_channels, out_channels)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.batch_norm(x)
        x = self.residual_layer1(x)
        x = self.residual_layer2(x)
        x = self.dropout(x)
        x = self.residual_layer3(x)

        return x
        

In [46]:
class Gating(nn.Module):
    def __init__(self, in_channels, num_experts, size,hidden_units=64, dropout = 0.5):
        super(Gating, self).__init__()
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(int(size*size*in_channels), hidden_units)
        self.fc2 = nn.Linear(hidden_units, 128)
        self.fc3 = nn.Linear(128, num_experts)
    
    def forward(self, x):
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

In [47]:
class MoE(nn.Module):
    def __init__(self, num_experts, in_channels, out_channels_res, size, num_classes, downsample=True, dropout=0.5,topk=2):
        # in_channels: so kenh anh dau vao
        # size: kich thuoc cua anh
        # out_channels_res: so kenh dau ra của resnet_gating
        super(MoE, self).__init__()
        if downsample: self.size = size/2
        else: self.size = size
        self.topk = topk
        self.CNN_gating = Resnet_gating(in_channels=in_channels, out_channels=out_channels_res, downsample=downsample)
        self.gating = Gating(in_channels=out_channels_res,num_experts= num_experts,size= self.size, dropout=dropout)
        self.experts = nn.ModuleList([Expert(size=self.size, in_channels= out_channels_res,num_classes= num_classes) for _ in range(num_experts)])
    
    def forward(self, x):
        CNNgate_output = self.CNN_gating(x)
        gating_output = self.gating(CNNgate_output)
        top_values, top_indices = torch.topk(gating_output, self.topk, dim=1)
        zeros = torch.full_like(gating_output, float('-inf'))
        sparse_weigths = F.softmax(zeros.scatter_(1, top_indices, top_values), dim=-1)
        experts_outputs = torch.stack([torch.tensor(expert(CNNgate_output)) for expert in self.experts], dim=1)
        sparse_weigths = sparse_weigths.unsqueeze(-1)
        final_outputs = sparse_weigths*experts_outputs
        return final_outputs.sum(dim=1)

In [48]:
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

In [49]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1024)
testloader = torch.utils.data.DataLoader(testset, batch_size=1024)

In [54]:
images, labels = next(iter(testloader))

In [None]:
epochs = 100
model = MoE(num_experts=6, in_channels=1, out_channels_res=16, size=28, num_classes=10)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
train_loss_hist = []
valid_loss_hist = []
num_valid_batch = len(testloader)
num_train_batch = len(trainloader)
for epoch in range(1,epochs+1):
    model.train()
    epoch_train_loss = 0
    for batch in tqdm(trainloader, desc="Epoch {}".format(epoch)):
        optimizer.zero_grad()
        outputs = model(batch[0])
        loss = criterion(outputs, batch[1])
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item()
    print(epoch_train_loss/num_train_batch)
    train_loss_hist.append(epoch_train_loss/num_train_batch)
    
    model.eval()
    with torch.no_grad():
        epoch_valid_loss = 0
        for batch in tqdm(testloader, desc="Validate epoch {}".format(epoch)):
            valid_outputs = model(batch[0])
            valid_loss = criterion(valid_outputs, batch[1])
            epoch_valid_loss += valid_loss
        print(epoch_valid_loss/num_valid_batch)
        valid_loss_hist.append(epoch_valid_loss/num_valid_batch)
        

  experts_outputs = torch.stack([torch.tensor(expert(CNNgate_output)) for expert in self.experts], dim=1)
Epoch 1: 100%|██████████| 59/59 [01:35<00:00,  1.62s/it]


2.335179914862423


Validate epoch 1: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]


tensor(2.3118)


Epoch 2: 100%|██████████| 59/59 [01:34<00:00,  1.61s/it]


2.3363744000257074


Validate epoch 2: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]


tensor(2.3089)


Epoch 3:  54%|█████▍    | 32/59 [00:52<00:44,  1.64s/it]


KeyboardInterrupt: 

In [53]:
train_loss_hist

[2.246499417191845,
 2.1912409612687966,
 2.184216685214285,
 2.1655157784284174,
 2.15790807190588,
 2.1548096729537187,
 2.1520512467723782,
 2.1434801230996343,
 2.1445296053159035,
 2.142780093823449,
 2.1474807828159657,
 2.1430247072446145,
 2.13886850163088,
 2.1400596610570357,
 2.134870517051826,
 2.1399735515400513,
 2.1353436041686495,
 2.135501085701635,
 2.1384689565432273,
 2.1381928718696206,
 2.1360329450187034,
 2.140641176094443,
 2.1387435541314592,
 2.1391610493094233,
 2.1394749940451927,
 2.135067814487522,
 2.1371908955654857,
 2.1317489268416066,
 2.132109625864837,
 2.130346403283588,
 2.1361818273188704,
 2.133071220527261,
 2.137053788718531,
 2.1325807126901917,
 2.1323697930675443,
 2.132864244913651,
 2.1362427574093057,
 2.134420600988097,
 2.1363175723512295,
 2.1375965950852733,
 2.1373437744075967,
 2.135022268456928,
 2.1340454917843057,
 2.1336403539625266,
 2.128615743022854,
 2.130455635361752,
 2.130951586416212,
 2.1335526927042814,
 2.1305424480

In [63]:
model.eval()
out1 = model.CNN_gating(images)
outputs = model.gating(out1)

In [64]:
F.softmax(outputs, dim=-1)

tensor([[4.1625e-25, 1.0000e+00, 6.0943e-23, 3.9817e-16, 7.2082e-21, 2.6845e-13],
        [8.3777e-04, 1.1612e-01, 1.2217e-04, 8.8081e-01, 1.5746e-03, 5.3502e-04],
        [5.8534e-01, 2.4043e-02, 1.0507e-02, 2.9346e-02, 1.7873e-02, 3.3289e-01],
        ...,
        [4.2683e-20, 1.0000e+00, 1.2568e-18, 6.6745e-13, 1.1356e-16, 1.6347e-10],
        [1.1070e-05, 1.0685e-05, 6.1248e-07, 1.4987e-06, 1.0375e-06, 9.9998e-01],
        [1.0000e+00, 1.7268e-10, 1.4752e-09, 1.3625e-09, 1.4916e-12, 5.0004e-11]],
       grad_fn=<SoftmaxBackward0>)