In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import math
from torch.autograd import Variable

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [101]:
MU = 0.5
NU = 1.0
ETA = 0.05
STEPS = 10

In [102]:
class PCInputLayer(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.size = size
        
    def init_vars(self):
        e = torch.zeros((self.size, 1)).to(device)
        return e
        
    def step(self, x, td_pred):
        return x - td_pred

In [103]:
class PCLayer(nn.Module):
    def __init__(self, size_prev, size):
        super().__init__()
        self.size, self.size_prev = size, size_prev
        
        U = torch.zeros((size_prev, size)).to(device)
        self.U = nn.Parameter(U)
        nn.init.kaiming_uniform_(self.U, a=25) # <=== To Revisit
        
#         V = torch.zeros((size, size_prev)).to(device)
#         self.V = nn.Parameter(V)
#         nn.init.kaiming_uniform_(self.V, a=25) # <=== To Revisit
        
    def init_vars(self):
        r = torch.zeros((self.size, 1)).to(device)
        e = torch.zeros((self.size, 1)).to(device)
        return r, e
        
    def pred(self, r):
        return F.relu(torch.mm(self.U, r))

    def step(self, e_inf, r, e, td_pred):
        r = NU*r + MU*torch.mm(self.U.t(),e_inf) - ETA*e
        e = r - td_pred      
        return r, e

In [157]:
class Model(nn.Module):
    def __init__(self, input_size, h1_size, h2_size, num_classes):
        super().__init__()
        self.input_size = input_size
        self.pc0 = PCInputLayer(input_size)
        self.pc1 = PCLayer(input_size, h1_size)
        self.pc2 = PCLayer(h1_size, h2_size)
        self.pc3 = PCLayer(h2_size, num_classes)
    
    def train(self, x, targets, debug=False):
        pc0_e = self.pc0.init_vars()
        pc1_r, pc1_e = self.pc1.init_vars()
        pc2_r, pc2_e = self.pc2.init_vars()
        pc3_r, pc3_e = self.pc3.init_vars()
        
        for _ in range(STEPS):
            pc0_e = self.pc0.step(x, self.pc1.pred(pc1_r))
            pc1_r, pc1_e = self.pc1.step(pc0_e, pc1_r, pc1_e, self.pc2.pred(pc2_r))
            pc2_r, pc2_e = self.pc2.step(pc1_e, pc2_r, pc2_e, self.pc3.pred(pc3_r))
            pc3_r, pc3_e = self.pc3.step(pc2_e, pc3_r, pc3_e, targets)
        
        if debug:
            print("printing pc3_r....")
            print(pc3_r)
            print("printing pc3_e...")
            print(pc3_e)
            
        pc0_err = pc0_e.square().sum()/self.pc0.size
        pc1_err = pc1_e.square().sum()/self.pc1.size
        pc2_err = pc2_e.square().sum()/self.pc2.size
        pc3_err = pc3_e.square().sum()/self.pc3.size
            
        total_sqr_err =  pc0_err + pc1_err + pc2_err + 10*pc3_err
        return total_sqr_err

In [158]:
INPUT_SIZE = 784
H1_SIZE = 784
H2_SIZE = 784
NUM_CLASSES = 10

In [159]:
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, shuffle=True)
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_dataset, shuffle=True)

In [160]:
model = Model(INPUT_SIZE, H1_SIZE, H2_SIZE, NUM_CLASSES).to(device)

In [161]:
LEARNING_RATE = 0.0000001
NUM_EPOCHS = 1

optimiser = optim.SGD(model.parameters(), lr=LEARNING_RATE)

mean_loss = 0


for epoch in range(NUM_EPOCHS):
    for batch_idx, (data, y) in enumerate(train_loader):
        
        x = data.reshape((-1, 1)).to(device)
        targets = torch.zeros((NUM_CLASSES, 1)).to(device)
        targets[y[0]] = 1

        loss = model.train(x, targets, debug=False)
#         print("targets: ", targets)
#         print(loss)
        
        loss.backward(retain_graph=True)
        
        
        mean_loss += loss
        if batch_idx % 64 == 0:
            print("mean_loss:",mean_loss / 64)
            mean_loss = 0
            optimiser.step()
            optimiser.zero_grad()

mean_loss: tensor(0.0099, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.6074, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.5953, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.6075, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.6067, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.6046, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.6105, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.6115, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.6069, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.6089, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.6076, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.6022, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.5921, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.6040, device='cuda:0', grad_fn=<DivBackward0>)
mean_loss: tensor(0.5997, device='cuda:0', grad_fn=<DivBackwar

KeyboardInterrupt: 