In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import math
from torch.autograd import Variable
from torch.nn.modules.conv import _ConvNd
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.modules.utils import _single, _pair, _triple, _reverse_repeat_tuple
from typing import Optional, List, Tuple, Union
from torch import Tensor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [3]:
MU = 0.5
NU = 1.0
ETA = 0.05
STEPS = 10

In [4]:
class PCInputLayer(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.size = size
        
    def init_vars(self):
        e = torch.zeros((self.size, 1)).to(device)
        return e
        
    def step(self, x, td_pred):
        return x - td_pred

In [45]:
class Conv2d(_ConvNd):
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
            False, _pair(0), groups, bias, padding_mode, **factory_kwargs)

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.weight, self.bias)

In [24]:
class PCConv2d(_ConvNd):
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = False,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        super(PCConv2d, self).__init__(
            in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
            False, _pair(0), groups, bias, padding_mode, **factory_kwargs)
        
    def _output_padding(self, input: Tensor, output_size: Optional[List[int]],
                        stride: List[int], padding: List[int], kernel_size: List[int],
                        num_spatial_dims: int, dilation: Optional[List[int]] = None) -> List[int]:
        if output_size is None:
            ret = _single(self.output_padding)  # converting to list if was not already
        else:
            has_batch_dim = input.dim() == num_spatial_dims + 2
            num_non_spatial_dims = 2 if has_batch_dim else 1
            if len(output_size) == num_non_spatial_dims + num_spatial_dims:
                output_size = output_size[num_non_spatial_dims:]
            if len(output_size) != num_spatial_dims:
                raise ValueError(
                    "ConvTranspose{}D: for {}D input, output_size must have {} or {} elements (got {})"
                    .format(num_spatial_dims, input.dim(), num_spatial_dims,
                            num_non_spatial_dims + num_spatial_dims, len(output_size)))

            min_sizes = torch.jit.annotate(List[int], [])
            max_sizes = torch.jit.annotate(List[int], [])
            for d in range(num_spatial_dims):
                dim_size = ((input.size(d + num_non_spatial_dims) - 1) * stride[d] -
                            2 * padding[d] +
                            (dilation[d] if dilation is not None else 1) * (kernel_size[d] - 1) + 1)
                min_sizes.append(dim_size)
                max_sizes.append(min_sizes[d] + stride[d] - 1)

            for i in range(len(output_size)):
                size = output_size[i]
                min_size = min_sizes[i]
                max_size = max_sizes[i]
                if size < min_size or size > max_size:
                    raise ValueError((
                        "requested an output size of {}, but valid sizes range "
                        "from {} to {} (for an input of {})").format(
                            output_size, min_sizes, max_sizes, input.size()[2:]))

            res = torch.jit.annotate(List[int], [])
            for d in range(num_spatial_dims):
                res.append(output_size[d] - min_sizes[d])

            ret = res
        return ret
      
    def _conv_forward(self, input: Tensor, bias: Optional[Tensor]=None):
        if self.padding_mode != 'zeros':
                        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
        return F.conv2d(input, self.weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)
    
    def _conv_transpose(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
        num_spatial_dims = 2
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        return F.conv_transpose2d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)    
    
    def init_vars(self):
        r = torch.zeros((self.out_channels, 1)).to(device)
        e = torch.zeros((self.out_channels, 1)).to(device)
        return r, e

    def pred(self, r: Tensor) -> Tensor:
        return F.relu(self._conv_transpose(r.unsqueeze(0)))
    
    def step(self, e_inf, r, e, td_pred) -> Tensor:
        r = NU*r + MU*self._conv_forward(e_inf.unsqueeze(0)).reshape((-1, 1)) - ETA*e
        e = r - td_pred
        return r, e
    
x = torch.rand((1,28,28)).to(device)
layer1 = PCConv2d(1, 16, (4,4), (4,4), device=device)
r1 = layer1._conv_forward(x)
print("r1.shape:", r1.shape)
td_pred1 = layer1._conv_transpose(r1)
print("td_pred1.shape:", td_pred1.shape)
e1 = torch.rand(r1.shape).to(device)

layer2 = PCConv2d(16, 32, (2,2), (1,1), device=device)
r2 = layer2._conv_forward(e1)
print("r2.shape:", r2.shape)
td_pred2 = layer2._conv_transpose(r2)
print("td_pred2.shape:", td_pred2.shape)

r1.shape: torch.Size([16, 7, 7])
td_pred1.shape: torch.Size([1, 28, 28])
r2.shape: torch.Size([32, 6, 6])
td_pred2.shape: torch.Size([16, 7, 7])


In [81]:
class PCConv2d(_ConvNd):
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        super(PCConv2d, self).__init__(
            in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
            False, _pair(0), groups, padding_mode, **factory_kwargs)

    def _conv_forward(self, input: Tensor, weight: Tensor):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.stride,
                        self.padding, self.dilation, self.groups)
    

    
    def pred(self, r):
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 2
        output_padding = 0 # <===========HACKY
#         output_padding = self._output_padding(
#             input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
#             num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        return F.relu(F.conv_transpose2d(
            r.unsqueeze(0), self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)).reshape((-1, 1))
    


In [82]:
class PCLayer(nn.Module):
    def __init__(self, size_prev, size):
        super().__init__()
        self.size, self.size_prev = size, size_prev
        
        U = torch.zeros((size_prev, size)).to(device)
        self.U = nn.Parameter(U)
        nn.init.kaiming_uniform_(self.U, a=25) # <=== To Revisit
        
#         V = torch.zeros((size, size_prev)).to(device)
#         self.V = nn.Parameter(V)
#         nn.init.kaiming_uniform_(self.V, a=25) # <=== To Revisit
        
    def init_vars(self):
        r = torch.zeros((self.size, 1)).to(device)
        e = torch.zeros((self.size, 1)).to(device)
        return r, e
        
    def pred(self, r):
        return F.relu(torch.mm(self.U, r))

    def step(self, e_inf, r, e, td_pred):
        r = NU*r + MU*torch.mm(self.U.t(),e_inf) - ETA*e
        e = r - td_pred      
        return r, e

In [83]:
class Model(nn.Module):
    def __init__(self, input_size, h1_size, h2_size, num_classes):
        super().__init__()
        self.input_size = input_size
        self.pc0 = PCInputLayer(input_size)
        self.pc1 = PCConv2d(input_size, h1_size, kernel_size=4, stride=4)
        self.pc2 = PCConv2d(h1_size, h2_size, kernel_size=4, stride=4)
        self.pc3 = PCConv2d(h2_size, num_classes, kernel_size=4, stride=4)
    
    def train(self, x, targets, debug=False):
        pc0_e = self.pc0.init_vars()
        pc1_r, pc1_e = self.pc1.init_vars()
        pc2_r, pc2_e = self.pc2.init_vars()
        pc3_r, pc3_e = self.pc3.init_vars()
        
        for _ in range(STEPS):
            pc0_e = self.pc0.step(x, self.pc1.pred(pc1_r))
            pc1_r, pc1_e = self.pc1.step(pc0_e, pc1_r, pc1_e, self.pc2.pred(pc2_r))
            pc2_r, pc2_e = self.pc2.step(pc1_e, pc2_r, pc2_e, self.pc3.pred(pc3_r))
            pc3_r, pc3_e = self.pc3.step(pc2_e, pc3_r, pc3_e, targets)
        
        if debug:
            print("printing pc3_r....")
            print(pc3_r)
            print("printing pc3_e...")
            print(pc3_e)
            
        pc0_err = pc0_e.square().sum()/self.pc0.size
        pc1_err = pc1_e.square().sum()/self.pc1.size
        pc2_err = pc2_e.square().sum()/self.pc2.size
        pc3_err = pc3_e.square().sum()/self.pc3.size
            
        total_sqr_err =  pc0_err + pc1_err + pc2_err + 10*pc3_err
        return total_sqr_err

In [84]:
INPUT_SIZE = 784
H1_SIZE = 784
H2_SIZE = 784
NUM_CLASSES = 10

In [85]:
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, shuffle=True)
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_dataset, shuffle=True)

In [86]:
model = Model(INPUT_SIZE, H1_SIZE, H2_SIZE, NUM_CLASSES).to(device)

In [87]:
LEARNING_RATE = 0.0000001
NUM_EPOCHS = 1

optimiser = optim.SGD(model.parameters(), lr=LEARNING_RATE)

mean_loss = 0


for epoch in range(NUM_EPOCHS):
    for batch_idx, (data, y) in enumerate(train_loader):
        
        x = data.reshape((-1, 1)).to(device)
        targets = torch.zeros((NUM_CLASSES, 1)).to(device)
        targets[y[0]] = 1

        loss = model.train(x, targets, debug=False)
#         print("targets: ", targets)
#         print(loss)
        
        loss.backward(retain_graph=True)
        
        
        mean_loss += loss
        if batch_idx % 64 == 0:
            print("mean_loss:",mean_loss / 64)
            mean_loss = 0
            optimiser.step()
            optimiser.zero_grad()

RuntimeError: Given transposed=1, weight of size [784, 784, 4, 4], expected input[1, 1, 784, 1] to have 784 channels, but got 1 channels instead