In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import math

In [14]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [15]:
BATCH_SIZE = 1024
train_dataset = datasets.MNIST(root='../../datasets/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_dataset = datasets.MNIST(root='../../datasets/', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_dataset, shuffle=True, batch_size=BATCH_SIZE)
INPUT_CHANNELS = 1
NUM_CLASSES = 10

In [16]:
BATCH_SIZE = 500
train_val_dataset = datasets.CIFAR10(root='../../datasets/', train=True, transform=transforms.ToTensor(), download=True)
train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [45000, 5000])
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(dataset=val_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_dataset = datasets.CIFAR10(root='../../datasets/', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_dataset, shuffle=True, batch_size=BATCH_SIZE)
INPUT_CHANNELS = 3
NUM_CLASSES = 10

Files already downloaded and verified
Files already downloaded and verified


In [17]:
MU = 0.70
NU = 1.0
ETA = 0.1
STEPS = 5

In [18]:
class PCInputLayer(nn.Module):
    def __init__(self, channels, next_channels, width, height, kernel=(3,3), stride=1, padding='same', bias=False, pool=1, forw_actv=nn.ReLU()):
        super().__init__()
        self.channels = channels
        self.width = width
        self.height = height
        self.forw_actv = forw_actv
        
        self.conv = nn.Sequential(
            nn.Conv2d(channels, next_channels, kernel, stride=stride, padding=padding, bias=bias),
#             nn.Dropout(0.2),
#             nn.BatchNorm2d(next_channels),
        )
        
        self.pool = nn.MaxPool2d(kernel_size=pool)
        
    def init_vars(self, batch_size):
        e = torch.zeros((batch_size, self.channels, self.width, self.height)).to(device)
        return e
        
    def pool(self, x):
        return self.pool(x)
    
    def step(self, x, td_pred):
        e = self.forw_actv(x-td_pred)
        return self.conv(e), e

In [19]:
class PCHiddenLayer(nn.Module):
    def __init__(self, prev_channels, channels, next_channels, width, height,  prev_kernel=(3,3), kernel=(3,3), prev_stride=1, stride=1, prev_padding=1, padding='same', bias=False, pool=1, upsample=1, back_actv=nn.Tanh(), forw_actv=nn.ReLU()):
        super().__init__()
        self.channels = channels
        self.width = width
        self.height = height
        self.back_actv = back_actv
        self.forw_actv = forw_actv
        
        self.conv = nn.Sequential(
            nn.Conv2d(channels, next_channels, kernel, stride=stride, padding=padding, bias=bias),
#             nn.Dropout(0.2),
#             nn.BatchNorm2d(next_channels),
        )
        self.pool = nn.MaxPool2d(kernel_size=pool)
        self.upsample = nn.Upsample(scale_factor=upsample)
        self.convT = nn.Sequential(
            nn.ConvTranspose2d(channels, prev_channels, prev_kernel, stride=prev_stride, padding=prev_padding, bias=bias),
#             nn.Dropout(0.2),
#             nn.BatchNorm2d(prev_channels),
        )
        
    def init_vars(self, batch_size):
        e = torch.zeros((batch_size, self.channels, self.width, self.height)).to(device)
        r = torch.zeros((batch_size, self.channels, self.width, self.height)).to(device)
        return r, e
    
    def pred(self, r):
        td_pred = self.back_actv(self.convT(r))
        td_pred = self.upsample(td_pred)
        return td_pred
        
    def step(self, bu_err, r, e, td_pred):
        r = NU*r + MU*bu_err - ETA*e
        e = self.forw_actv(r-td_pred)
        return self.pool(self.conv(e)), r, e

In [20]:
class PCFinalLayer(nn.Module):
    def __init__(self, prev_channels, channels, width, height, prev_kernel=(3,3), prev_stride=1, prev_padding=1, bias=False, upsample=1, back_actv=nn.Tanh()):
        super().__init__()
        self.channels = channels
        self.width = width
        self.height = height
        self.back_actv = back_actv
        
        self.upsample = nn.Upsample(scale_factor=upsample)
        self.convT = nn.Sequential(
            nn.ConvTranspose2d(channels, prev_channels, prev_kernel, stride=prev_stride, padding=prev_padding, bias=bias),
#             nn.Dropout(0.2),
#             nn.BatchNorm2d(prev_channels),
        )
        
    def init_vars(self, batch_size):
        r = torch.zeros((batch_size, self.channels, self.width, self.height)).to(device)
        return r
    
    def pred(self, r):
        td_pred = self.back_actv(self.convT(r))
        td_pred = self.upsample(td_pred)
        return td_pred
        
    def step(self, bu_err, r, target=None):
        r = NU*r + MU*bu_err
        if target is not None:
            e = r - target
            r = r - ETA*e
        return r

In [23]:
class PCModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_layer = PCInputLayer(INPUT_CHANNELS, 64, 32, 32, 1, forw_actv=nn.ReLU())
        self.hid0_layer = PCHiddenLayer(INPUT_CHANNELS, 64, 64, 32, 32, prev_padding=1, pool=2, forw_actv=nn.ReLU())
        
        self.hid1_layer = PCHiddenLayer(64, 64, 128, 16, 16, prev_padding=1, upsample=2, forw_actv=nn.ReLU())
        self.hid2_layer = PCHiddenLayer(64, 128, 128, 16, 16, prev_padding=1, pool=2, forw_actv=nn.ReLU())
        
        self.hid3_layer = PCHiddenLayer(128, 128, 256, 8, 8, prev_padding=1, upsample=2, forw_actv=nn.ReLU())
        self.hid4_layer = PCHiddenLayer(128, 256, 256, 8, 8, prev_padding=1, forw_actv=nn.ReLU())  
        self.hid5_layer = PCHiddenLayer(128, 256, 256, 8, 8, prev_padding=1, forw_actv=nn.ReLU())  
        self.fin_layer = PCFinalLayer(256, 256, 8, 8, prev_padding=1)
        
        self.fc_layer = nn.Linear(256*8*8, 10)
        
    def forward(self, x, target=None):
        batch_size = x.shape[0]
        in_e = self.in_layer.init_vars(batch_size)
        hid0_r, hid0_e = self.hid0_layer.init_vars(batch_size)
        hid1_r, hid1_e = self.hid1_layer.init_vars(batch_size)
        hid2_r, hid2_e = self.hid2_layer.init_vars(batch_size)
        hid3_r, hid3_e = self.hid3_layer.init_vars(batch_size)
        hid4_r, hid4_e = self.hid4_layer.init_vars(batch_size)
        hid5_r, hid5_e = self.hid4_layer.init_vars(batch_size)
        fin_r = self.fin_layer.init_vars(batch_size)
        
        for _ in range(STEPS):
            in_bu_e, in_e = self.in_layer.step(x, self.hid0_layer.pred(hid0_r))
            hid0_bu_e, hid0_r, hid0_e = self.hid0_layer.step(in_bu_e, hid0_r, hid0_e, self.hid1_layer.pred(hid1_r))
            hid1_bu_e, hid1_r, hid1_e = self.hid1_layer.step(hid0_bu_e, hid1_r, hid1_e, self.hid2_layer.pred(hid2_r))
            hid2_bu_e, hid2_r, hid2_e = self.hid2_layer.step(hid1_bu_e, hid2_r, hid2_e, self.hid3_layer.pred(hid3_r))
            hid3_bu_e, hid3_r, hid3_e = self.hid3_layer.step(hid2_bu_e, hid3_r, hid3_e, self.hid4_layer.pred(hid4_r))
            hid4_bu_e, hid4_r, hid4_e = self.hid4_layer.step(hid3_bu_e, hid4_r, hid4_e, self.fin_layer.pred(fin_r))
            fin_r = self.fin_layer.step(hid4_bu_e, fin_r, target)

        out = self.fc_layer(torch.flatten(F.relu(fin_r), start_dim=1))
        
        in_err = in_e.square().sum()/in_e.numel()
        hid0_err = hid0_e.abs().sum()/hid0_e.numel()
        hid1_err = hid1_e.abs().sum()/hid1_e.numel()
        hid2_err = hid2_e.abs().sum()/hid2_e.numel()
        hid3_err = hid3_e.abs().sum()/hid3_e.numel()
        hid4_err = hid4_e.abs().sum()/hid4_e.numel()
        emag = in_err + hid0_err + hid1_err + hid2_err + hid3_err + hid4_err
        
        return [out, emag]
model = PCModel().to(device)

In [22]:
class CNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        # [Batch_size x Channels x 32 x 32]
        self.network = nn.Sequential(
            
            # [Batch_size x 64 x 32 x 32]
            nn.Conv2d(INPUT_CHANNELS, 64, (3,3), padding='same'),
            nn.Dropout(0.1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, (3,3), padding='same'),
            nn.Dropout(0.1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.MaxPool2d(kernel_size=2), # [Batch_size x 128 x 16 x 16]
            nn.Conv2d(64, 128, (3,3), padding='same'),
            nn.Dropout(0.1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, (3,3), padding='same'),
            nn.Dropout(0.1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.MaxPool2d(kernel_size=2), # [Batch_size x 256 x 8 x 8]
            nn.Conv2d(128, 256, (3,3), padding='same'),
            nn.Dropout(0.1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, (3,3), padding='same'),
            nn.Dropout(0.1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, (3,3), padding='same'),
            nn.Dropout(0.1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            
#             nn.MaxPool2d(kernel_size=2), # [Batch_size x 512 x 4 x 4]
#             nn.Conv2d(256, 512, (3,3), padding='same'),
#             nn.BatchNorm2d(512),
#             nn.ReLU(),
#             nn.Conv2d(512, 512, (3,3), padding='same'),
#             nn.BatchNorm2d(512),
#             nn.ReLU(),
#             nn.Conv2d(512, 512, (3,3), padding='same'),
#             nn.BatchNorm2d(512),
#             nn.ReLU(),
            
#             nn.MaxPool2d(kernel_size=2), # [Batch_size x 512 x 2 x 2]
#             nn.Conv2d(512, 512, (3,3), padding='same'),
#             nn.BatchNorm2d(512),
#             nn.ReLU(),
#             nn.Conv2d(512, 512, (3,3), padding='same'),
#             nn.BatchNorm2d(512),
#             nn.ReLU(),
#             nn.Conv2d(512, 512, (3,3), padding='same'),
#             nn.BatchNorm2d(512),
#             nn.ReLU(),
            
#             nn.MaxPool2d(kernel_size=2), # [Batch_size x 512 x 1 x 1]
        )
        
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 8 * 8, 256 * 8 * 8),
            nn.ReLU(),
            nn.Linear(256 * 8 * 8, NUM_CLASSES),
        )

    def forward(self, x):
        out = self.network(x)
        out = torch.flatten(out, start_dim=1)
        out = self.fc(out)
        return [out]
    
model = CNNModel().to(device)

In [24]:
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

def test():
    
    N, inp_channels, H, W, NUM_CLASSES = 8, 3, 32, 32, 10
    z_dim = 100
    x = torch.randn((N, inp_channels, H, W)).to(device)
    cnn = CNNModel().to(device)
    initialize_weights(cnn)
    assert cnn(x)[0].shape == (N,10)
    
    pccnn = PCModel().to(device)
    initialize_weights(pccnn)
    assert pccnn(x)[0].shape == (N,10)
    
test()

In [25]:
def evaluate(model, with_targets=False):
    model.eval()
    n = 0
    correct = 0
    for batch_idx, (data, y) in enumerate(test_loader):
#         x = data[:,:,:27,:27].to(device)
        x = data.to(device)
        batch_size = data.shape[0]
        n = n + batch_size
        target = None
        if with_targets:
            target = F.one_hot(y).unsqueeze(2).unsqueeze(3).to(device)
        out = model(x, target)
        output = torch.argmax(out[0], dim=1)
        y = y.to(device)
        batch_correct = output == y
        correct = correct + batch_correct.sum()

    print(f'Test Accuracy: {correct/n}')

In [None]:
model.load_state_dict(torch.load('PC-RB_Conv2D-V2.pth'))
model.eval()

In [26]:
model_name = 'PCCNN_wdecay01'
writer = SummaryWriter(f"logs/{model_name}")

In [None]:
LEARNING_RATE = 3e-4
NUM_EPOCHS = 100
WEIGHT_DECAY = 0.01

# optimiser = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
optimiser = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
criterion = nn.CrossEntropyLoss().to(device)
n = 0
correct = 0
val_loss = 0
model.train()
for epoch in range(NUM_EPOCHS):
    
    loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
    
    train_loss = 0
    train_batches = 0
    train_emag = 0
    
    for batch_idx, (data, y) in loop:
#         x = data[:,:,:27,:27].to(device)
        x = data.to(device)
        target = F.one_hot(y).double().to(device)
        out = model(x)
        loss = criterion(out[0], target)
        if len(out) > 1:
            loss += out[1]
        
        loss.backward(retain_graph=True)
        optimiser.step()
        optimiser.zero_grad()
        
        with torch.no_grad():
            output = torch.argmax(out[0], dim=1)
            arg_correct = output == y.to(device)
            correct = correct + arg_correct.sum()
            n = n + x.shape[0]
            
            train_loss += loss.item()
            train_batches += 1
            if len(out) > 1:
                train_emag += out[1].item()
            
            # Update progress bar
            loop.set_description(f"Epoch [{epoch}/{NUM_EPOCHS}]")
            loop.set_postfix(loss = loss.item(), val_loss = val_loss, train_acc=(correct/n).item())
    
    # Calculate validation_acc
    with torch.no_grad():
        val_emag = 0
        val_n = 0
        val_correct = 0
        val_batches = 0
        for batch_idx, (data, y) in enumerate(val_loader):
            x = data.to(device)
            batch_size = data.shape[0]
            val_n += batch_size
            out = model(x)
            output = torch.argmax(out[0], dim=1)
            y = y.to(device)
            batch_correct = output == y
            val_correct += batch_correct.sum()
            val_emag += out[1].item()
            val_batches += 1
            
        train_loss /= train_batches
        val_acc = val_correct/val_n
        
        # Update logs
        writer.add_scalar('training loss', train_loss, epoch*train_batches*BATCH_SIZE)
        writer.add_scalar('validation accuracy', val_acc, epoch*train_batches*BATCH_SIZE)
        
        if len(out) > 1:
            train_emag /= train_batches
            val_emag /= val_batches
            writer.add_scalar('training emag', train_emag, epoch*train_batches*BATCH_SIZE)
            writer.add_scalar('validation emag', val_emag, epoch*train_batches*BATCH_SIZE)
        
    

Epoch [47/100]:  64%|███████████████▍        | 58/90 [00:27<00:15,  2.11it/s, loss=0.00708, train_acc=0.85, val_loss=0]

In [None]:
evaluate(model)

In [None]:
torch.save(model.state_dict(), 'PC-RB_Conv2D-V2.pth')

In [None]:
LEARNING_RATE = 3e-7
NUM_EPOCHS = 5
optimiser = optim.SGD(model.parameters(), lr=LEARNING_RATE)
mean_loss = 0

model.eval()
for epoch in range(NUM_EPOCHS):
    for batch_idx, (data, y) in enumerate(train_loader):
        x = data[:,:,:27,:27].to(device)
        out = model(x)
        break
    break

In [None]:
x = torch.tensor([1, 2, 3, 4, 1, 5,1, 9, 5, 8, 0, 2, 6, 5, 8, 2])
x = F.argmax(x)
x