In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [14]:
class Layer(nn.Module):
    """One layer of a (seemingly predictive-coding style) hierarchy.

    The layer holds two pieces of per-batch state:

    * ``belief`` -- the layer's current estimate, shape ``(batch, size)``.
      For the first (input) layer it is simply the observation passed to
      ``forward``; otherwise it starts as noise from ``init_state`` and is
      nudged on every ``forward`` call.
    * ``error`` -- ``belief`` minus the prediction supplied from above, with
      an optional lateral projection of itself subtracted.

    Args:
        size: number of units in this layer.
        prev_size: number of units in the layer below, or ``None`` for the
            first (input) layer, which then has no bottom-up/top-down weights.
        step_size: scale of the belief update applied in ``forward``.
        func: nonlinearity applied to the top-down prediction.
        lateral: if truthy, allocate a ``(size, size)`` lateral matrix.

    The weight matrices are registered as buffers, not parameters. Autograd
    does not train them; ``Network.forward`` updates them in place. As
    buffers they still move with ``.to()`` / ``.double()`` and are saved in
    ``state_dict``.
    """

    def __init__(self, size, prev_size=None, step_size=0.1, func=torch.sigmoid, lateral=True):
        super().__init__()
        self.size = size
        self.prev_size = prev_size
        self.first_layer = prev_size is None
        self.step_size = step_size
        self.func = func

        # All weights start as sigmoid(standard normal), i.e. values in (0, 1).
        if not self.first_layer:
            # (size, prev_size): maps the input from below into this layer.
            self.register_buffer("bottom_up", torch.sigmoid(torch.randn(size, prev_size)))
            # (prev_size, size): maps this layer's belief to a prediction of the layer below.
            self.register_buffer("top_down", torch.sigmoid(torch.randn(prev_size, size)))

        # (size, size) or None. A None buffer is skipped by state_dict and by .to().
        self.register_buffer(
            "lateral", torch.sigmoid(torch.randn(size, size)) if lateral else None
        )

    def init_state(self, batch_size=1):
        """Reset per-batch state to standard-normal noise.

        The first layer's belief is not set here; ``forward`` clamps it to the
        observation.
        """
        if not self.first_layer:
            self.belief = torch.randn(batch_size, self.size)
        self.error = torch.randn(batch_size, self.size)

    def get_prediction(self):
        """Return ``func(belief @ top_down.T)``, a ``(batch, prev_size)`` prediction
        of the layer below.

        Raises:
            RuntimeError: for the first layer, which has no layer below.
        """
        if self.first_layer:
            raise RuntimeError("First layer has no prediction")
        return self.func(F.linear(self.belief, self.top_down))

    def forward(self, inp, pred=None):
        """Update ``belief`` and ``error`` for one inference step.

        Args:
            inp: the observation if this is the first layer, otherwise the
                error of the layer below, shape ``(batch, prev_size)``.
            pred: prediction of this layer's belief from the layer above,
                shape ``(batch, size)``. ``None`` (top layer) draws a random
                standard-normal prediction matching the belief's shape.
        """
        if self.first_layer:
            self.belief = inp
        else:
            drive = F.linear(inp, self.bottom_up)
            # Out-of-place, so tensors previously read from .belief are not mutated.
            self.belief = self.belief - self.step_size * (drive + self.error)

        if pred is None:
            # Match (batch, size), dtype and device of the belief.
            pred = torch.randn_like(self.belief)
        error = self.belief - pred
        if self.lateral is not None:
            error = error - F.linear(error, self.lateral)
        self.error = error

In [15]:
class Network(nn.Module):
    """Stack of ``Layer`` objects run for a fixed number of inference steps.

    Layer 0 is the input layer (its belief is the observation). Every higher
    layer ``i`` is built with ``prev_size=sizes[i - 1]``.

    Args:
        sizes: layer widths, bottom to top; at least two entries.
        steps: default number of inference iterations per ``forward`` call.
        lateral: forwarded to every ``Layer``. Learning (``lr`` in
            ``forward``) currently requires ``lateral=False``.

    Raises:
        ValueError: if fewer than two sizes are given.

    NOTE(review): this notebook defines ``Network`` twice (this cell and the
    next); whichever cell runs last wins. Keep only one definition.
    """

    def __init__(self, sizes, steps=20, lateral=True):
        super().__init__()
        self.sizes = sizes
        self.n_layers = len(sizes)
        if self.n_layers < 2:
            raise ValueError("At least two layers are needed")
        self.steps = steps

        layers = [Layer(sizes[0], lateral=lateral)]
        for i in range(1, self.n_layers):
            layers.append(Layer(sizes[i], sizes[i - 1], lateral=lateral))
        self.layers = nn.ModuleList(layers)

    def forward(self, x, steps=None, lr=None):
        """Run ``steps`` inference iterations on a batch, optionally learning.

        Each iteration updates the layers bottom-up:

        * Layer 0 receives ``x`` and the prediction from layer 1.
        * Layer ``i > 0`` receives the error just computed by layer ``i - 1``
          and the prediction from layer ``i + 1``. The top layer gets no
          prediction.

        If ``lr`` is given, the weights of every layer above the input are then
        incremented by outer products of its belief with the error below.

        Args:
            x: observations, shape ``(batch, sizes[0])``.
            steps: number of iterations; defaults to ``self.steps``.
            lr: learning rate, or ``None`` for inference only.

        Returns:
            None. Results are left on each layer's ``belief`` / ``error``.

        Raises:
            ValueError: if ``x`` is not of shape ``(batch, sizes[0])``.
            NotImplementedError: if ``lr`` is given while any layer has
                lateral connections.
        """
        if steps is None:
            steps = self.steps
        if x.dim() != 2 or x.shape[1] != self.sizes[0]:
            raise ValueError(
                f"expected x of shape (batch, {self.sizes[0]}), got {tuple(x.shape)}"
            )
        # Fail before touching any state rather than after the first step.
        if lr is not None and any(layer.lateral is not None for layer in self.layers):
            raise NotImplementedError("Lateral connections not implemented yet")

        for layer in self.layers:
            layer.init_state(x.shape[0])

        top = self.n_layers - 1
        for _ in range(steps):
            for i, layer in enumerate(self.layers):
                # Only the input layer sees the observation; higher layers are
                # driven by the error of the layer directly below.
                inp = x if i == 0 else self.layers[i - 1].error
                pred = self.layers[i + 1].get_prediction() if i < top else None
                layer(inp, pred)

            if lr is not None:
                for i in range(1, self.n_layers):
                    upper, lower = self.layers[i], self.layers[i - 1]
                    # (size_i, batch) @ (batch, size_{i-1}) matches bottom_up's shape.
                    upper.bottom_up += lr * (upper.belief.T @ lower.error)
                    # (size_{i-1}, batch) @ (batch, size_i) matches top_down's shape.
                    upper.top_down += lr * (lower.error.T @ upper.belief)

                

In [None]:
class Network(nn.Module):
    """Stack of ``Layer`` objects run for a fixed number of inference steps.

    Layer 0 is the input layer (its belief is the observation). Every higher
    layer ``i`` is built with ``prev_size=sizes[i - 1]``.

    Args:
        sizes: layer widths, bottom to top; at least two entries.
        steps: default number of inference iterations per ``forward`` call.
        lateral: forwarded to every ``Layer``. Learning (``lr`` in
            ``forward``) currently requires ``lateral=False``.

    Raises:
        ValueError: if fewer than two sizes are given.

    NOTE(review): this notebook defines ``Network`` twice (the previous cell
    and this one); whichever cell runs last wins. Keep only one definition.
    """

    def __init__(self, sizes, steps=20, lateral=True):
        super().__init__()
        self.sizes = sizes
        self.n_layers = len(sizes)
        if self.n_layers < 2:
            raise ValueError("At least two layers are needed")
        self.steps = steps

        layers = [Layer(sizes[0], lateral=lateral)]
        for i in range(1, self.n_layers):
            layers.append(Layer(sizes[i], sizes[i - 1], lateral=lateral))
        self.layers = nn.ModuleList(layers)

    def forward(self, x, steps=None, lr=None):
        """Run ``steps`` inference iterations on a batch, optionally learning.

        Each iteration updates the layers bottom-up:

        * Layer 0 receives ``x`` and the prediction from layer 1.
        * Layer ``i > 0`` receives the error just computed by layer ``i - 1``
          and the prediction from layer ``i + 1``. The top layer gets no
          prediction.

        If ``lr`` is given, the weights of every layer above the input are then
        incremented by outer products of its belief with the error below.

        Args:
            x: observations, shape ``(batch, sizes[0])``.
            steps: number of iterations; defaults to ``self.steps``.
            lr: learning rate, or ``None`` for inference only.

        Returns:
            None. Results are left on each layer's ``belief`` / ``error``.

        Raises:
            ValueError: if ``x`` is not of shape ``(batch, sizes[0])``.
            NotImplementedError: if ``lr`` is given while any layer has
                lateral connections.
        """
        if steps is None:
            steps = self.steps
        if x.dim() != 2 or x.shape[1] != self.sizes[0]:
            raise ValueError(
                f"expected x of shape (batch, {self.sizes[0]}), got {tuple(x.shape)}"
            )
        # Fail before touching any state rather than after the first step.
        if lr is not None and any(layer.lateral is not None for layer in self.layers):
            raise NotImplementedError("Lateral connections not implemented yet")

        for layer in self.layers:
            layer.init_state(x.shape[0])

        top = self.n_layers - 1
        for _ in range(steps):
            for i, layer in enumerate(self.layers):
                # Only the input layer sees the observation; higher layers are
                # driven by the error of the layer directly below.
                inp = x if i == 0 else self.layers[i - 1].error
                pred = self.layers[i + 1].get_prediction() if i < top else None
                layer(inp, pred)

            if lr is not None:
                for i in range(1, self.n_layers):
                    upper, lower = self.layers[i], self.layers[i - 1]
                    # (size_i, batch) @ (batch, size_{i-1}) matches bottom_up's shape.
                    upper.bottom_up += lr * (upper.belief.T @ lower.error)
                    # (size_{i-1}, batch) @ (batch, size_i) matches top_down's shape.
                    upper.top_down += lr * (lower.error.T @ upper.belief)

                