In [1]:
import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision.models as models
import torch.autograd as autograd

### Loading Data

In [2]:
from torchvision import datasets, transforms
BATCH_SIZE = 1
FEATURE_SIZE = 784

train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
               transform=transforms.Compose([
                   transforms.ToTensor(),
                   transforms.Normalize((0.1307,), (0.3081,))])),batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([
                   transforms.ToTensor(),
                   transforms.Normalize((0.1307,), (0.3081,))
               ])),batch_size=BATCH_SIZE, shuffle=True)

for batch_idx, (data, target) in enumerate(train_loader):
    input_stimulus = data.view(-1,FEATURE_SIZE)

### Custom Forward and Backward

In [3]:
class Forward(torch.autograd.Function):
    def forward(ctx, hidden_state, feed_forward, context=None):
        if context is not None:
            #print(feed_forward.shape, context.shape,hidden_state.shape)
            state = feed_forward * context
            result = torch.mm(state,hidden_state)
            
        else:
            result = torch.mm(feed_forward,hidden_state)
        ctx.hidden_state = hidden_state
        return result

    def backward(ctx, grad_forward):
        #print('Backward',ctx.hidden_state)
        return ctx.hidden_state,grad_forward

class Backward(torch.autograd.Function):
    def forward(ctx, hidden_state, feed_forward, context=None):
        if context is not None:
            state = feed_forward * context
            result = torch.mm(state,hidden_state)
            
        else:
            result = torch.mm(feed_forward,hidden_state)
        ctx.hidden_state = hidden_state
        return result

    def backward(ctx, grad_forward):
        #print('Backward',ctx.hidden_state)
        return ctx.hidden_state,grad_forward

### Contextual Feedback Network

In [6]:
class CF(nn.Module):
    def __init__(self, feature_size, hidden_size=100, dtype=torch.FloatTensor, num_layers=5,weight_transport=True):
        super(CF, self).__init__()
        self.weight_transport = weight_transport
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # forward weights: 1,2,3,4,5
        self.forward_params = nn.ParameterList([nn.Parameter(torch.randn(feature_size, hidden_size).type(dtype))] +\
                                               [nn.Parameter(torch.randn(hidden_size, hidden_size).type(dtype)) for _ in range(num_layers-1)])
        if weight_transport: # using same weights for forward and backward pass
            self.backward_params = self.forward_params
        else:
            # backward weights: 5,4,3,2,1
            self.backward_params = nn.ParameterList([nn.Parameter(torch.randn(hidden_size, hidden_size).type(dtype)) for _ in range(num_layers-1)]+\
                                [nn.Parameter(torch.randn(hidden_size, feature_size).type(dtype))])
        
        self.backward_context = []
        self.prior_activities = []
        self.current_activities = []
        #self.cell = nn.GRUCell(input_size=feature_size, hidden_size=hidden_size)
        #self.layer_1 = Forward.apply
        
    def reset_hidden(self):
        self.hidden_state = torch.zeros((self.batch_size, self.hidden_size))
        
    def forward(self, feed_forward):
        for seq_idx in range(feed_forward.shape[0]):
            forward = feed_forward[seq_idx].view(-1,784)
            
            # going up
            for layer_idx in range(self.num_layers):
                layer = self.forward_params[layer_idx]
                if seq_idx != 0:
                    forward = Forward()(layer, forward, self.backward_context[layer_idx])
                else:
                    forward = Forward()(layer, forward)

                self.current_activities.append(forward)

            backward = forward
            if self.weight_transport:
                start,stop,step = self.num_layers-1,0,-1
            else:
                start,stop,step = 0,self.num_layers,1
            
            # going down
            for layer_idx in range(start,stop,step):
                layer = self.backward_params[layer_idx]
                backward = Backward()(layer, backward)
                self.backward_context.insert(0,backward)
                
            self.prior_activities = self.current_activities
            
        #self.batch_size = feed_forward.size(1)
        #self.reset_hidden()
        return backward

model = CF(FEATURE_SIZE)
z = model(input_stimulus)

In [7]:
from graphviz import Digraph
# make_dot was moved to https://github.com/szagoruyko/pytorchviz
from torchviz import make_dot
d = make_dot(model(input_stimulus), params=dict(model.named_parameters()))
d.render(filename='one.png')

'one.png.pdf'