## Neural Network Training with the Extended Kalman Filter
This notebook demonstrates how to train a neural network with an extended Kalman filter (EKF). I use a teacher network to produce data: random inputs are fed to the teacher, and its outputs serve as targets. A student network is then trained on these input-output examples to reproduce the teacher's mapping. I use the same architecture for both networks, but this is not a restriction in general.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
```

Below is the main class for the EKF-based trainer. The trainer takes a network object along with a dictionary of parameters. It is assumed that the student's weights perform a random walk, and that each observed output equals the student's output at the current weights plus noise. The training data are constructed by feeding random inputs to the teacher network. At each step, the trainer is given either a single input-output pair or a batch of pairs, and it updates its estimate of the student's weights (a mean vector and a full covariance matrix) so that the student's outputs move closer to the teacher's.
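Written out, the model behind the trainer treats the stacked student weights $w_t$ as the hidden state and the network output $g(x_t, w_t)$ as the observation. The equations below sketch the standard EKF recursion as it is implemented in `Update`, assuming $Q = \sigma_q^2 I$ with $\sigma_q^2$ given by `state_noise` and $R = \sigma_r^2 I$ with $\sigma_r^2$ given by `obs_noise`; $\hat w$ denotes the current student weights and $C$ their covariance.

$$
\begin{aligned}
w_t &= w_{t-1} + q_t, \qquad q_t \sim \mathcal{N}(0, Q),\\
y_t &= g(x_t, w_t) + r_t, \qquad r_t \sim \mathcal{N}(0, R).
\end{aligned}
$$

Prediction and update for one input-output pair (or one batch) $(x_t, y_t)$:

$$
\begin{aligned}
C &\leftarrow C + Q,\\
H &= \left.\frac{\partial g(x_t, w)}{\partial w}\right|_{w = \hat w}, \qquad S = H C H^\top + R,\\
\hat w &\leftarrow \hat w + C H^\top S^{-1}\bigl(y_t - g(x_t, \hat w)\bigr),\\
C &\leftarrow C - C H^\top S^{-1} H C.
\end{aligned}
$$

In the code, `HP` holds $HC$ and `HPH` holds $S$; because $C$ is symmetric, `HP.t()` equals $CH^\top$, so both the mean and the covariance correction follow from these two matrices.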

```python
class Trainer():
    # Extended Kalman filter trainer for a differentiable neural network.
    # The full covariance matrix of the weights is kept throughout the iterations.
    def __init__(self, student, params):
        self.C = params['Cov'].clone()  # initial covariance (copied, so the caller's tensor is not modified in place)
        self.num_param = torch.numel(torch.nn.utils.parameters_to_vector(student.parameters()))  # number of weights
        self.RW = params['state_noise']  # variance of the random walk on the weights
        self.obs_noise = params['obs_noise']  # observation noise variance
        self.student = student  # the network to be trained

    def vectorize_grad(self):
        # Returns the gradients of all parameters concatenated into one vector.
        # torch.nn.utils.parameters_to_vector returns the parameter values, not
        # their gradients, so the gradients are collected by hand, in the same order.
        vec = []
        for param in self.student.parameters():
            vec.append(param.grad.view(-1))
        return torch.cat(vec)

    def Jacobian(self, x):
        # Evaluates z = model(x, w) and the Jacobian H of g(w) = model(x, w) at the
        # current weights w; row k of H is the gradient of the k-th entry of z.view(-1).
        z = self.student(x)
        H = torch.empty(torch.numel(z), self.num_param)
        for k, zk in enumerate(z.view(-1)):
            self.student.zero_grad()
            zk.backward(retain_graph=True)
            H[k, :] = self.vectorize_grad()
        return z.detach(), H

    def Update(self, x, y):
        ## prediction step
        # the weights are assumed to do a random walk, so only the covariance changes
        self.C += self.RW * torch.eye(self.num_param)

        ## update step
        z, H = self.Jacobian(x)
        with torch.no_grad():
            error = (y - z).view(-1, 1)
            HP = torch.mm(H, self.C)  # H C
            HPH = torch.mm(HP, H.t()) + self.obs_noise * torch.eye(torch.numel(y))  # S = H C H^T + R
            update = torch.linalg.solve(HPH, error)  # S^{-1} (y - z)
            mean = torch.nn.utils.parameters_to_vector(self.student.parameters()) + torch.mm(HP.t(), update).view(-1)
            update2 = torch.linalg.solve(HPH, HP)  # S^{-1} H C
            self.C -= torch.mm(HP.t(), update2)
            self.C = (self.C + self.C.t()) / 2  # re-symmetrize against round-off
            # copy the updated mean back into the student's parameters
            torch.nn.utils.vector_to_parameters(mean, self.student.parameters())
```
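To see the predict/update recursion in isolation from PyTorch, here is a minimal NumPy sketch of the same step. `ekf_step` is a hypothetical helper written only for illustration, and it assumes the caller supplies both the model `g` and its Jacobian `jac`. When `g` is linear in the weights, the EKF reduces to an ordinary Kalman filter, so noise-free data should recover the true weights almost exactly, which makes a convenient sanity check.

```python
import numpy as np


def ekf_step(w, C, x, y, g, jac, q, r):
    """One EKF update of the weight mean w and covariance C (hypothetical helper).

    g(x, w)   -> model output for input x (any shape)
    jac(x, w) -> Jacobian of g(x, w).ravel() with respect to w, shape (m, n)
    q, r      -> random-walk variance and observation noise variance
    """
    n = w.size
    C = C + q * np.eye(n)                                # prediction: random walk inflates C
    z = np.ravel(g(x, w))
    H = jac(x, w)
    HP = H @ C                                           # H C
    S = HP @ H.T + r * np.eye(z.size)                    # S = H C H^T + R
    w = w + HP.T @ np.linalg.solve(S, np.ravel(y) - z)   # mean update
    C = C - HP.T @ np.linalg.solve(S, HP)                # covariance update
    return w, (C + C.T) / 2                              # re-symmetrize against round-off


# Sanity check on a model that is linear in w: g(A, v) = A @ v, whose Jacobian is A.
rng = np.random.default_rng(0)
w_true = np.array([0.5, -1.0, 2.0])
w, C = np.zeros(3), np.eye(3)
for _ in range(200):
    X = rng.normal(size=(2, 3))                          # two noise-free observations per step
    w, C = ekf_step(w, C, X, X @ w_true,
                    g=lambda A, v: A @ v, jac=lambda A, v: A,
                    q=1e-8, r=1e-5)
print(np.round(w, 3))
```

The noise settings mirror the notebook's `state_noise` and `obs_noise`; for the student network, `jac` would be replaced by the autograd-based `Jacobian` method above.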

With the trainer defined, we now declare the network class used for both the student and the teacher.

```python
K = [5, 10, 10, 5]  # layer sizes: input, hidden, hidden, output

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = nn.Linear(K[0], K[1])
        self.layer2 = nn.Linear(K[1], K[2])
        self.layer3 = nn.Linear(K[2], K[3])

    def forward(self, x):
        x = self.layer1(x)
        x = F.softsign(x)
        x = self.layer2(x)
        x = F.softsign(x)
        x = self.layer3(x)
        return x
```
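The trainer's `Jacobian` method stacks one gradient per entry of `z.view(-1)`, with columns ordered as in `parameters_to_vector`: for each `nn.Linear`, the weight of shape (out, in) flattened row-major, then the bias, layers in registration order. A central-difference approximation is a cheap way to check such a Jacobian. The NumPy sketch below is illustrative only: `numerical_jacobian` and `tiny_net` are hypothetical helpers, and `tiny_net` is a much smaller 2-3-2 softsign network whose weights are packed into one vector in that same order, assumed here as a stand-in for `Net`.

```python
import numpy as np


def softsign(a):
    return a / (1.0 + np.abs(a))


def tiny_net(w, x):
    # hypothetical 2-3-2 softsign network; w packs (W1, b1, W2, b2) in that order,
    # each weight matrix stored row-major with shape (out, in) as in nn.Linear
    W1, b1 = w[0:6].reshape(3, 2), w[6:9]
    W2, b2 = w[9:15].reshape(2, 3), w[15:17]
    h = softsign(x @ W1.T + b1)          # x has shape (batch, 2)
    return h @ W2.T + b2                 # output shape (batch, 2)


def numerical_jacobian(g, w, eps=1e-6):
    """Central-difference Jacobian of g at w.

    Row k corresponds to the k-th entry of g(w).ravel() (like z.view(-1)),
    column j to the j-th entry of the flat parameter vector w.
    """
    m = np.ravel(g(w)).size
    H = np.empty((m, w.size))
    for j in range(w.size):
        e = np.zeros_like(w)
        e[j] = eps
        H[:, j] = (np.ravel(g(w + e)) - np.ravel(g(w - e))) / (2 * eps)
    return H


rng = np.random.default_rng(1)
w = rng.normal(size=17)
x = rng.normal(size=(4, 2))              # a batch of 4 inputs
H = numerical_jacobian(lambda v: tiny_net(v, x), w)
print(H.shape)                           # (8, 17): 4 inputs x 2 outputs rows, 17 weight columns
# the output is linear in b2, so its columns form one 2x2 identity block per input
print(np.allclose(H[:, 15:17], np.tile(np.eye(2), (4, 1)), atol=1e-6))
```

The same comparison could be made against the PyTorch `Jacobian` method by perturbing the parameter vector with `vector_to_parameters` and re-evaluating the student.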

Instantiate the teacher and student networks.

```python
teacher = Net()
student = Net()
```

Instantiate the trainer. Note that the student network is passed to the trainer, much as a network's parameters are passed to a PyTorch optimizer.

```python
total = torch.nn.utils.parameters_to_vector(student.parameters())
params = {}

# initial covariance of the weights
params['Cov'] = 1e-5 * torch.eye(torch.numel(total))

# variance of the random walk on the weights (added to the covariance at every step)
params['state_noise'] = 1e-8

# observation noise variance.
# The observations are actually noise-free, since the student has the same
# architecture as the teacher and can reproduce it exactly; a small positive
# value keeps the matrix H C H^T + R inverted in each update safely invertible.
params['obs_noise'] = 1e-5

train_st = Trainer(student, params)
```
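For the architecture `K = [5, 10, 10, 5]`, the size of the filter's state can be worked out by hand: each `nn.Linear(a, b)` layer holds a*b weights and b biases. The short plain-Python sketch below (`count_params` is a hypothetical helper) computes the number of weights and the resulting sizes of the covariance `C`, of the Jacobian `H`, and of the matrix `S = H C H^T + R` solved in each update, for a single pair and for a batch of 10.

```python
def count_params(K):
    # each nn.Linear(K[i], K[i+1]) has K[i]*K[i+1] weights plus K[i+1] biases
    return sum(a * b + b for a, b in zip(K[:-1], K[1:]))


K = [5, 10, 10, 5]
n = count_params(K)
for m in (1, 10):                        # single-pair and batch-of-10 updates
    rows = m * K[-1]                     # rows of H: one per output entry
    print(f"batch {m}: C is {n}x{n}, H is {rows}x{n}, S is {rows}x{rows}")
```

Because `C` is dense, its memory grows with the square of the number of weights, and forming `H C` costs on the order of (rows of H) times n^2 operations per update, so this full-covariance approach is best suited to small networks.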

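Train the student network for MAX_ITER steps, passing a single input-output pair to each update. Every 100 steps, the cell prints the base-10 logarithm of the test error: the squared error summed over the outputs and averaged over 100 fixed test inputs.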

```python
MAX_ITER = 1000
test_size = 100
test_in = torch.tensor(np.random.normal(0, 1, (K[0], test_size)), dtype=torch.float32)
for it in range(MAX_ITER):
    # generate a random input, and compute the output using the teacher
    x = torch.tensor(np.random.normal(0, 1, K[0]), dtype=torch.float32)
    with torch.no_grad():
        y = teacher(x)

    train_st.Update(x, y)

    # check the test error every 100 steps
    if it % 100 == 0:
        with torch.no_grad():
            y = teacher(test_in.t())
            z = student(test_in.t())
        mse = ((y - z) ** 2).sum() / test_size  # summed over outputs, averaged over test inputs
        print("log mse: {}".format(torch.log10(mse)))
```

```
log mse: -1.1157795190811157
log mse: -2.36008358001709
log mse: -2.557671308517456
log mse: -2.6477041244506836
log mse: -2.6884725093841553
log mse: -2.737579822540283
log mse: -2.7889180183410645
log mse: -2.8262202739715576
log mse: -2.8347623348236084
log mse: -2.877030611038208
```


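Continue training the same student for another MAX_ITER steps, now passing a batch of 10 input-output pairs to each update, which is similar to minibatch training. Since the trainer and the student are reused, training resumes where the previous cell stopped, and the first printed error is close to the last one above.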

```python
MAX_ITER = 1000
batch = 10
test_size = 100
test_in = torch.tensor(np.random.normal(0, 1, (K[0], test_size)), dtype=torch.float32)
for it in range(MAX_ITER):
    # generate a batch of random inputs (one input per row after transposing)
    x = torch.tensor(np.random.normal(0, 1, (K[0], batch)), dtype=torch.float32)
    with torch.no_grad():
        y = teacher(x.t())

    train_st.Update(x.t(), y)

    # check the test error every 100 steps
    if it % 100 == 0:
        with torch.no_grad():
            y = teacher(test_in.t())
            z = student(test_in.t())
        mse = ((y - z) ** 2).sum() / test_size  # summed over outputs, averaged over test inputs
        print("log mse: {}".format(torch.log10(mse)))
```

```
log mse: -2.8665785789489746
log mse: -3.0722601413726807
log mse: -3.1949334144592285
log mse: -3.2839367389678955
log mse: -3.373305559158325
log mse: -3.475947380065918
log mse: -3.530769109725952
log mse: -3.57464599609375
log mse: -3.6186318397521973
log mse: -3.6387879848480225
```
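A batch update processes all of its pairs with a single linearization of the network at the current weights. As a sketch of why this resembles minibatching: for a linear model with independent Gaussian observation noise and no random-walk step between the pairs, one update with the stacked batch gives the same mean and covariance as sequential single-pair updates. With a nonlinear network the two differ, because sequential updates re-linearize after every pair. The NumPy code below checks the linear case only; `kf_update` is a hypothetical helper, not part of the notebook.

```python
import numpy as np


def kf_update(w, C, H, y, r):
    # Kalman measurement update for y = H w + noise, noise covariance r * I
    HP = H @ C
    S = HP @ H.T + r * np.eye(H.shape[0])
    w = w + HP.T @ np.linalg.solve(S, y - H @ w)
    C = C - HP.T @ np.linalg.solve(S, HP)
    return w, C


rng = np.random.default_rng(2)
n, m, r = 4, 10, 1e-2
w0, C0 = np.zeros(n), np.eye(n)
H = rng.normal(size=(m, n))                          # one row per observation
y = H @ rng.normal(size=n) + np.sqrt(r) * rng.normal(size=m)

# one update with the whole batch
w_batch, C_batch = kf_update(w0, C0, H, y, r)

# m updates with one observation each
w_seq, C_seq = w0, C0
for k in range(m):
    w_seq, C_seq = kf_update(w_seq, C_seq, H[k:k + 1], y[k:k + 1], r)

print(np.allclose(w_batch, w_seq), np.allclose(C_batch, C_seq))
```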


#### Ilker Bayram, ibayram@ieee.org