In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import math

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
device

'cpu'

In [3]:
MU = 0.5
NU = 1.0
ETA = 0.05
STEPS = 10

In [4]:
INPUT_CHANNELS = 1

In [5]:
class PCInputLayer(nn.Module):
    def __init__(self, in_channels, in_width, in_height, out_channels, kernel_size, stride_out=1, padding=0, bias=False):
        super().__init__()
        self.in_channels = in_channels
        self.in_width = in_width
        self.in_height = in_height
#         self.out_channels = out_channels
#         self.kernel_size = kernel_size
#         self.stride = stride
#         self.padding = padding
        
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride_out, padding=padding, bias=bias)

        
    def init_vars(self):
        e = torch.zeros((self.in_channels, self.in_width, self.in_height)).to(device)
        return e
        
    def step(self, x, td_pred):
        e = F.relu(x-td_pred)
        return self.conv(e), e

In [6]:
class PCHiddenLayer(nn.Module):
    def __init__(self, in_channels, in_width, in_height, prev_channels, out_channels, kernel_size, stride_in=1, stride_out=1, padding=0, bias=False):
        super().__init__()
        self.in_channels = in_channels
        self.in_width = in_width
        self.in_height = in_height
#         self.out_channels = out_channels
#         self.kernel_size = kernel_size
#         self.stride = stride
#         self.padding = padding
        
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride_out, padding=padding, bias=bias)        
        self.convT = nn.ConvTranspose2d(in_channels=in_channels, out_channels=prev_channels, kernel_size=kernel_size, stride=stride_in, padding=padding, bias=bias)
        
        
    def init_vars(self):
        e = torch.zeros((self.in_channels, self.in_width, self.in_height)).to(device)
        r = torch.zeros((self.in_channels, self.in_width, self.in_height)).to(device)
        return r, e
    
    def pred(self, r):
        td_pred = F.tanh(self.convT(r))
        return td_pred
        
    def step(self, bu_err, r, e, td_pred):
        r = NU*r 
        r = r + MU*bu_err 
        r = r - ETA*e
        e = F.relu(r-td_pred)
        return self.conv(e), r, e

In [7]:
class PCFinalLayer(nn.Module):
    def __init__(self, in_channels, in_width, in_height, prev_channels, out_channels, kernel_size, stride_in=1, stride_out=1, padding=0, bias=False):
        super().__init__()
        self.in_channels = in_channels
        self.in_width = in_width
        self.in_height = in_height
#         self.out_channels = out_channels
#         self.kernel_size = kernel_size
#         self.stride = stride
#         self.padding = padding
        
        self.convT = nn.ConvTranspose2d(in_channels, prev_channels, kernel_size=kernel_size, stride=stride_in, padding=padding, bias=bias)
        
        
    def init_vars(self):
        r = torch.zeros((self.in_channels, self.in_width, self.in_height)).to(device)
        return r
    
    def pred(self, r):
        return F.tanh(self.convT(r))
        
    def step(self, bu_err, r):
        r = NU*r + MU*bu_err
        return r

In [8]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_layer = PCInputLayer(1, 27, 27, 32, (3,3), stride_out=3)
        self.hid0_layer = PCHiddenLayer(32, 9, 9, 1, 128, (3,3), stride_in=3, stride_out=1)
        self.hid1_layer = PCHiddenLayer(128, 7, 7, 32, 256, (3,3), stride_in=1, stride_out=2)
        self.fin_layer = PCFinalLayer(256, 3, 3, 128, 10, (3,3), stride_in=2)
        
    def forward(self, x):
        in_e = self.in_layer.init_vars()
        hid0_r, hid0_e = self.hid0_layer.init_vars()
        hid1_r, hid1_e = self.hid1_layer.init_vars()
        fin_r = self.fin_layer.init_vars()
        
        for _ in range(STEPS):
            in_bu_e, in_e = self.in_layer.step(x, self.hid0_layer.pred(hid0_r))
            hid0_bu_e, hid0_r, hid0_e = self.hid0_layer.step(in_bu_e, hid0_r, hid0_e, self.hid1_layer.pred(hid1_r))
            hid1_bu_e, hid1_r, hid1_e = self.hid1_layer.step(hid0_bu_e, hid1_r, hid1_e, self.fin_layer.pred(fin_r))
            fin_r = self.fin_layer.step(hid1_bu_e, fin_r)
    
        in_err = in_e.square().sum()/in_e.numel()
        hid0_err = hid0_e.square().sum()/hid0_e.numel()
        hid1_err = hid1_e.square().sum()/hid1_e.numel()
        total_mean_sqr_err = in_err + hid0_err + hid1_err
        return total_mean_sqr_err

In [9]:
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, shuffle=True)
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_dataset, shuffle=True)

In [18]:
model = Model().to(device)

In [19]:
model.load_state_dict(torch.load('PC-RB_Conv2D.pth'))
model.eval()

Model(
  (in_layer): PCInputLayer(
    (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(3, 3), bias=False)
  )
  (hid0_layer): PCHiddenLayer(
    (conv): Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (convT): ConvTranspose2d(32, 1, kernel_size=(3, 3), stride=(3, 3), bias=False)
  )
  (hid1_layer): PCHiddenLayer(
    (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (convT): ConvTranspose2d(128, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
  )
  (fin_layer): PCFinalLayer(
    (convT): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)
  )
)

In [34]:
LEARNING_RATE = 3e-7
NUM_EPOCHS = 5
optimiser = optim.SGD(model.parameters(), lr=LEARNING_RATE)
mean_loss = 0

model.train()
for epoch in range(NUM_EPOCHS):
    for batch_idx, (data, y) in enumerate(train_loader):
        x = data[:,:,:27,:27].reshape((1,27,27)).to(device)
        loss = model(x)
        
        loss.backward(retain_graph=True)
        mean_loss += loss
        if batch_idx % 512 == 0:
            if batch_idx % 8192 == 0:
                print(f'mean_loss: {mean_loss/512}')
            mean_loss = 0
            optimiser.step()
            optimiser.zero_grad()
     

mean_loss: 0.00012998381862416863
mean_loss: 0.12273669242858887
mean_loss: 0.12055642902851105
mean_loss: 0.12314210087060928
mean_loss: 0.12240884453058243
mean_loss: 0.12397337704896927
mean_loss: 0.11964358389377594
mean_loss: 0.12050309032201767
mean_loss: 0.021882830187678337
mean_loss: 0.12059059739112854
mean_loss: 0.1163795217871666
mean_loss: 0.11759209632873535
mean_loss: 0.11873355507850647
mean_loss: 0.11686024814844131
mean_loss: 0.11771158874034882
mean_loss: 0.11639812588691711
mean_loss: 0.020974311977624893
mean_loss: 0.11555317044258118
mean_loss: 0.11647531390190125
mean_loss: 0.11258436739444733
mean_loss: 0.11207454651594162
mean_loss: 0.11430878937244415
mean_loss: 0.1108192428946495
mean_loss: 0.11045262962579727
mean_loss: 0.02032107673585415
mean_loss: 0.11010843515396118
mean_loss: 0.10855016112327576
mean_loss: 0.1086776852607727
mean_loss: 0.10940124094486237
mean_loss: 0.10827821493148804
mean_loss: 0.10806307941675186


KeyboardInterrupt: 

In [35]:
torch.save(model.state_dict(), 'PC-RB_Conv2D.pth')

In [36]:
LEARNING_RATE = 3e-7
NUM_EPOCHS = 5
optimiser = optim.SGD(model.parameters(), lr=LEARNING_RATE)
mean_loss = 0

model.eval()
for epoch in range(NUM_EPOCHS):
    for batch_idx, (data, y) in enumerate(train_loader):
        x = data[:,:,:27,:27].reshape((1,27,27)).to(device)
        print(model.fin_r)
        loss = model(x)
        
        print(model.fin_r)
        break
    break

AttributeError: 'Model' object has no attribute 'fin_r'