## Download and Install Requirements

In [None]:
!pip install -r requirements.txt

## Define the MLP and CNN Network Architectures

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP_APS_Net(nn.Module):
    """Single-hidden-layer perceptron: 400 features in, 5 class probabilities out.

    Pipeline: Linear(400 -> 210) -> ReLU -> Dropout(p=0.3) -> Linear(210 -> 5)
    -> softmax along dim 1.
    """

    def __init__(self):
        super(MLP_APS_Net, self).__init__()
        self.fc1 = nn.Linear(in_features=400, out_features=210)
        # In-place dropout overwrites the ReLU activations it receives.
        self.drop1 = nn.Dropout(p=0.3, inplace=True)
        self.fc2 = nn.Linear(in_features=210, out_features=5)

    def forward(self, x):
        """Return softmax (over dim 1) of the 5 output scores for ``x``."""
        hidden = F.relu(self.fc1(x))
        hidden = self.drop1(hidden)
        scores = self.fc2(hidden)
        return F.softmax(scores, dim=1)
    
class MLP_EMG_Net(nn.Module):
    """Single-hidden-layer perceptron: 16 features in, 5 class probabilities out.

    Pipeline: Linear(16 -> 230) -> ReLU -> Linear(230 -> 5) -> softmax along
    dim 1.
    """

    def __init__(self):
        super(MLP_EMG_Net, self).__init__()
        self.fc1 = nn.Linear(in_features=16, out_features=230)
        self.fc2 = nn.Linear(in_features=230, out_features=5)

    def forward(self, x):
        """Return softmax (over dim 1) of the 5 output scores for ``x``."""
        hidden = F.relu(self.fc1(x))
        scores = self.fc2(hidden)
        return F.softmax(scores, dim=1)
    
class MLP_Fused_Net(nn.Module):
    """Fusion head: one Linear(10 -> 5) layer followed by softmax along dim 1."""

    def __init__(self):
        super(MLP_Fused_Net, self).__init__()
        self.fc1 = nn.Linear(in_features=10, out_features=5)

    def forward(self, x):
        """Return softmax (over dim 1) of the 5 output scores for ``x``."""
        scores = self.fc1(x)
        return F.softmax(scores, dim=1)

class Conv_APS_Net(nn.Module):
    """Small CNN: single-channel image in, 5 class probabilities out.

    Pipeline: three 3x3 convolutions (1 -> 8 -> 16 -> 32 channels), with 2x2
    max-pooling after the first two, then Dropout(0.25), flatten,
    Linear(1152 -> 512), ReLU, Dropout(0.5), Linear(512 -> 5) and softmax
    along dim 1. A 40x40 input yields 32 * 6 * 6 = 1152 flattened features,
    which is what ``fc1`` expects.
    """

    def __init__(self):
        super(Conv_APS_Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3)
        # Both dropout layers work in place on the activations they receive.
        self.drop1 = nn.Dropout(p=0.25, inplace=True)
        self.fc1 = nn.Linear(in_features=1152, out_features=512)
        self.drop2 = nn.Dropout(p=0.5, inplace=True)
        self.fc2 = nn.Linear(in_features=512, out_features=5)

    def forward(self, x):
        """Return softmax (over dim 1) of the 5 output scores for image batch ``x``."""
        # Convolutional feature extractor.
        feat = F.relu(self.conv1(x))
        feat = F.max_pool2d(feat, 2)
        feat = F.relu(self.conv2(feat))
        feat = F.max_pool2d(feat, 2)
        feat = F.relu(self.conv3(feat))
        feat = self.drop1(feat)

        # Flatten everything but the batch dimension for the dense head.
        flat = feat.view(feat.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        hidden = self.drop2(hidden)
        scores = self.fc2(hidden)
        return F.softmax(scores, dim=1)
    
class Conv_EMG_Net(nn.Module):
    """Two-hidden-layer perceptron: 16 features in, 5 class probabilities out.

    Pipeline: Linear(16 -> 128) -> ReLU -> Dropout(p=0.5) -> Linear(128 -> 128)
    -> ReLU -> Linear(128 -> 5) -> softmax along dim 1.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 128)
        # In-place dropout overwrites the ReLU activations it receives.
        self.drop1 = nn.Dropout(0.5, inplace=True)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 5)

    def forward(self, x):
        """Return softmax (over dim 1) of the 5 output scores for ``x``."""
        x = F.relu(self.fc1(x))
        x = self.drop1(x)
        x = F.relu(self.fc2(x))
        # Pass `dim` explicitly: relying on the implicit dimension is
        # deprecated (it emits a UserWarning) and every other network in this
        # notebook normalizes over dim 1.
        return F.softmax(self.fc3(x), dim=1)
    
class Conv_Fused_Net(nn.Module):
    """Fusion head: one Linear(10 -> 5) layer followed by softmax along dim 1."""

    def __init__(self):
        super(Conv_Fused_Net, self).__init__()
        self.fc1 = nn.Linear(in_features=10, out_features=5)

    def forward(self, x):
        """Return softmax (over dim 1) of the 5 output scores for ``x``."""
        scores = self.fc1(x)
        return F.softmax(scores, dim=1)

## Convert Serialized *pt* Objects to Equivalent *prototxt* and *caffemodel* Files

In [None]:
!conda install -c anaconda caffe -y

In [None]:
# Fetch the PyTorch -> Caffe converter and switch into its checkout
# (presumably so that `pytorch_to_caffe` can be imported from the working
# directory -- TODO confirm).
# NOTE(review): the clone lands in the *current* working directory, whereas the
# `%cd` below targets the home directory; the two only agree when the notebook
# is started from `~` -- confirm.
!git clone https://github.com/xxradon/PytorchToCaffe
%cd ~/PytorchToCaffe/
import sys
# Also make the parent directory importable (presumably where this notebook
# and the serialized models live -- TODO confirm).
sys.path.append('..')
import pytorch_to_caffe
from torch.autograd import Variable
import copy


# Every model below is moved to the CPU before conversion.
device = torch.device('cpu')

class Ensemble(nn.Module):
    """Run two sub-networks side by side and fuse their outputs.

    ``forward`` takes an indexable pair ``x``: ``x[0]`` goes to ``modelA`` and
    ``x[1]`` to ``modelB``. Their outputs are concatenated along dim 1, passed
    through ``modelFused`` (stored as ``classifier``) and normalized with a
    softmax over dim 1.

    NOTE(review): the fused heads defined in this notebook already end in a
    softmax, so the result is softmax-ed twice -- confirm this is intended.
    """

    def __init__(self, modelA, modelB, modelFused):
        super().__init__()
        self.modelA = modelA
        self.modelB = modelB
        self.classifier = modelFused

    def forward(self, x):
        """Return the softmax (dim 1) of the classifier applied to both branch outputs."""
        branch_a = self.modelA(x[0])
        branch_b = self.modelB(x[1])
        fused_input = torch.cat((branch_a, branch_b), dim=1)
        fused_scores = self.classifier(fused_input)
        return F.softmax(fused_scores, dim=1)
    
class MLP_APS_Net_Merged(nn.Module):
    """Sum the outputs of four sub-models, each fed its own slice of the input.

    ``forward`` indexes its argument as ``x[0]`` .. ``x[3]``, hands slice ``i``
    to ``model_<i+1>``, stacks the four results along a new dim 1 and sums
    that dimension away again.
    """

    def __init__(self, model_1, model_2, model_3, model_4):
        super(MLP_APS_Net_Merged, self).__init__()
        self.model_1 = model_1
        self.model_2 = model_2
        self.model_3 = model_3
        self.model_4 = model_4

    def forward(self, x):
        """Return the element-wise sum of the four sub-model outputs."""
        members = (self.model_1, self.model_2, self.model_3, self.model_4)
        outputs = [member(x[index]) for index, member in enumerate(members)]
        stacked = torch.stack(outputs, dim=1)
        return stacked.sum(1)
 
import os


def _load_model(path):
    """Deserialize a whole pickled module from ``path`` and place it on ``device``."""
    # map_location keeps checkpoints that were saved on a GPU loadable on a
    # CPU-only machine.
    # NOTE(review): torch.load unpickles arbitrary objects -- only feed it
    # checkpoint files you trust.
    return torch.load(path, map_location=device).to(device)


def _export_to_caffe(net, input_, name, fold):
    """Trace ``net`` on ``input_`` and write ``../FPGA/fold_<fold>_<name>.{prototxt,caffemodel}``."""
    pytorch_to_caffe.trans_net(net, input_, name)
    pytorch_to_caffe.save_prototxt('../FPGA/fold_%d_%s.prototxt' % (fold, name))
    pytorch_to_caffe.save_caffemodel('../FPGA/fold_%d_%s.caffemodel' % (fold, name))


# Make sure the output directory exists before anything is written to it.
os.makedirs('../FPGA', exist_ok=True)

# Networks that are exported on their own. The table does not depend on the
# fold, so it is built once.
# Entry layout: export name -> [checkpoint suffix, network class, dummy-input shape].
models = {}
models['MLP_EMG_Net'] = ['MLP_EMG_Net.pt', MLP_EMG_Net, (1, 16)]
models['Conv_APS_Net'] = ['Conv_APS_Net.pt', Conv_APS_Net, (1, 1, 40, 40)]
models['Conv_EMG_Net'] = ['Conv_EMG_Net.pt', Conv_EMG_Net, (1, 16)]

for fold in range(3):
    # MLP_APS_Net: four separately serialized models merged into one network.
    MLP_APS_Model_0 = _load_model('../fold_%d_MLP_APS_Net_Model_0.pt' % fold)
    MLP_APS_Model_1 = _load_model('../fold_%d_MLP_APS_Net_Model_1.pt' % fold)
    MLP_APS_Model_2 = _load_model('../fold_%d_MLP_APS_Net_Model_2.pt' % fold)
    MLP_APS_Model_3 = _load_model('../fold_%d_MLP_APS_Net_Model_3.pt' % fold)
    net = MLP_APS_Net_Merged(MLP_APS_Model_0, MLP_APS_Model_1, MLP_APS_Model_2, MLP_APS_Model_3).to(device)
    # One (1, 400) slice per merged sub-model.
    input_ = torch.ones((4, 1, 400))
    name = 'MLP_APS_Net'
    _export_to_caffe(net, input_, name, fold)

    # MLP_Fused_Net: the merged APS network above plus the EMG network,
    # combined by the fused classifier.
    net = Ensemble(net,
                   _load_model('../fold_%d_MLP_EMG_Net.pt' % fold),
                   _load_model('../fold_%d_MLP_Fused_Net.pt' % fold))
    input_ = [torch.ones(4, 1, 400), torch.ones(1, 16)]
    name = 'MLP_Fused_Net'
    _export_to_caffe(net, input_, name, fold)

    # Conv_Fused_Net
    net = Ensemble(_load_model('../fold_%d_Conv_APS_Net.pt' % fold),
                   _load_model('../fold_%d_Conv_EMG_Net.pt' % fold),
                   _load_model('../fold_%d_Conv_Fused_Net.pt' % fold))
    input_ = [torch.ones(1, 1, 40, 40), torch.ones(1, 16)]
    name = 'Conv_Fused_Net'
    _export_to_caffe(net, input_, name, fold)

    # Standalone networks. The output files are named after `model`; reusing
    # the stale `name` here would overwrite the Conv_Fused_Net files.
    for model in models:
        net = _load_model('../fold_%d_%s' % (fold, models[model][0]))
        input_ = torch.ones(models[model][2])
        _export_to_caffe(net, input_, model, fold)