In [8]:
import torch
import torch.nn as nn
import numpy as np

import torchvision
import torchvision.transforms as transforms

In [9]:
# Define relevant variables for the ML task
batch_size = 64
num_classes = 10
learning_rate = 1e-7
num_epochs = 2

# Device will determine whether to run the training on GPU or CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [74]:
# Use transforms.compose method to reformat images for modeling,
# and save to variable all_transforms for later use
all_transforms = transforms.Compose([transforms.Resize((32,32)),
                                     transforms.ToTensor()])

# Create Training dataset
train_dataset = torchvision.datasets.MNIST(root = './data',
                                             train = True,
                                             transform = all_transforms,
                                             download = True)

# Create Testing dataset
test_dataset = torchvision.datasets.MNIST(root = './data',
                                            train = False,
                                            transform = all_transforms,
                                            download=True)

# Instantiate loader objects to facilitate processing
train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)


test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)

In [107]:
class My2DConvLayer(nn.Module):
    """ Locally rotationally equivariant-baked CNNs (Harmonic Net implementation)."""
    def __init__(self, kernel_size, rotation_speed, in_channels, out_channels):
        # Construct a convolution of size kernel_size[0] (odd int) x kernel_size[1] (odd int)
        # specified by spherical harmonics with specified `rotation_speed`.
        # in_channels: int  - the last dimension of the input signal
        # out_channels: int - the specified last dimension of the output signal
        super().__init__()
        assert (type(kernel_size[0]==int) and type(kernel_size[1]) == int 
                and kernel_size[0]%2 == 1 and kernel_size[0]%2 == 1), 'Ensure kernel is odd-sized and of type tuple(int, int)'
        self.kernel_width, self.kernel_depth = kernel_size[1], kernel_size[0]
        self.in_channels, self.out_channels = in_channels, out_channels
        
        # Construct filters via spherical harmonics: R(r)*exp(i(m*phi + beta))
        # R(r): a function of the radial distance r
        # m: the rotation speed parameters
        # beta: a learnable rotation start parameter
        # R(r) is made a polynomial, here of order 2
        self.weights = nn.Parameter(torch.Tensor(2+2, 1, 1, self.out_channels)).to(device)
        torch.nn.init.uniform_(self.weights, -3.14/2, 3.14/2)
        self.weights = self.weights.to(device)
        
        # 2d kernel generation
        # construct harmonic kernels tensor
        r, theta = self._generate_meshgrid()

        self.kernels = ((self.weights[0]*r**0 + self.weights[1]*r + 
                         self.weights[2]*r**2)*
                         torch.exp((rotation_speed*theta + self.weights[3])*1j))

        # reshape for better use with forward pass
        self.kernels = torch.transpose(self.kernels, 2, 0)           # shape: out_channels, kernel_size (unpacked)
        self.kernels = self.kernels[None,:,None,:,:,None].to(device) # shape: 1, out_channels, 1, kernel_size (unpacked), 1


    def _generate_meshgrid(self):
        # generate coordinate space meshgrid with centre at (0,0)
        y, x = torch.meshgrid(torch.arange((-self.kernel_width+1)//2, (self.kernel_width+1)//2), 
                              torch.arange((-self.kernel_depth+1)//2, (self.kernel_depth+1)//2))
        x, y = x, -y
        # convert cartesian grid to that of polar, fixing non-surjective conversions manually
        r     = torch.sqrt(x**2 + y**2)
        theta = torch.atan(y/x)
        theta[:self.kernel_width//2,:self.kernel_depth//2] += torch.pi # fix 2nd quadrant
        theta[self.kernel_width//2:,:self.kernel_depth//2] -= torch.pi # fix 3rd quadrant
        theta = torch.nan_to_num(theta)
        
        # add a dimension for help with broadcasting later
        r, theta = r[:,:,None].to(device), theta[:,:,None].to(device)
        return r, theta

    
    def forward(self, x):
        # x is of shape (batch_size, 3, 32, 32)

        # perform a convolution, no padding
        # assuming the input is, without batching and empty dimensions, a rank 3 tensor
        conv_output = torch.zeros(x.shape[0], # <- batch size
                                  x.shape[2]-self.kernel_depth+1, 
                                  x.shape[3]-self.kernel_width+1, 
                                  self.out_channels, 
                                  dtype=torch.cdouble).to(device)
        
        # expanding dims for broadcasting reasons
        # size of batch_size, 1, 1, 3, 32, 32
        xcopy = torch.clone(x)
        xcopy = xcopy[:, None, None, :, :, :]
        xcopy = torch.transpose(xcopy, 5, 3)
        for w in range(x.shape[2]-self.kernel_depth):
            for h in range(x.shape[3]-self.kernel_width):
                conv_output[:,h,w] = torch.sum(self.kernels*
                                                xcopy[:,:,:,h:h+self.kernel_width,w:w+self.kernel_depth,:], 
                                                dim=[2, 3, 4, 5])
        
        return torch.transpose(conv_output, 3, 1)


class ComplexRelu(nn.Module):
    # Custom Complex Relu function defined by re^(i * theta) |-> |r|e^(i * theta).
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.nn.functional.relu(abs(x))*x


class ComplexMeanPooling(nn.Module):
    """ Complex valued mean pooling."""
    def __init__(self, kernel_size, stride):
        super().__init__()
        assert (type(kernel_size[0]==int) and type(kernel_size[1]) == int and type(stride) == int), 'Ensure kernel is of type tuple(int, int) and stride is of type int.'
        self.kernel_width, self.kernel_depth = kernel_size[1], kernel_size[0]
        self.stride = stride
        self.kernels = torch.ones((1, 1, self.kernel_width, self.kernel_depth)).to(device)
        self.kernels *= 1./(self.kernel_width*self.kernel_depth)
    
    def forward(self, x):
        # x is of shape [batch_size, prior out_channels, size, size]
        depth_shape = (x.shape[2] + 1 - self.kernel_depth)//self.stride
        width_shape = (x.shape[3] + 1 - self.kernel_width)//self.stride

        conv_output = torch.zeros((x.shape[0],x.shape[1],
                                   depth_shape,
                                   width_shape),
                                   dtype=torch.cdouble).to(device)
        
        for w in range(width_shape):
            for h in range(depth_shape):
                conv_output[:,:,h,w] = torch.sum(self.kernels*
                                                 x[:,:,self.stride*h:self.stride*h+self.kernel_depth,self.stride*w:self.stride*w+self.kernel_width], 
                                                 dim=[2,3])
        return conv_output


def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # Use `is_grad_enabled` to determine whether we are in training mode
    if not torch.is_grad_enabled():
        # In prediction mode, use mean and variance obtained by moving average
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # When using a fully connected layer, calculate the mean and
            # variance on the feature dimension
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # When using a two-dimensional convolutional layer, calculate the
            # mean and variance on the channel dimension (axis=1). Here we
            # need to maintain the shape of `X`, so that the broadcasting
            # operation can be carried out later
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # In training mode, the current mean and variance are used
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the mean and variance using moving average
        moving_mean = (1.0 - momentum) * moving_mean + momentum * mean
        moving_var = (1.0 - momentum) * moving_var + momentum * var
    Y = gamma * X_hat + beta  # Scale and shift
    return Y, moving_mean.data, moving_var.data



class BatchNorm(nn.Module):
    # `num_features`: the number of outputs for a fully connected layer
    # or the number of output channels for a convolutional layer. `num_dims`:
    # 2 for a fully connected layer and 4 for a convolutional layer
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2: shape = (1, num_features)
        else: shape = (1, num_features, 1, 1)
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # The variables that are not model parameters are initialized to 0 and 1
        self.moving_mean = torch.zeros(shape, dtype=torch.cdouble)
        self.moving_var = torch.ones(shape, dtype=torch.cdouble)

    def forward(self, X):
        # If `X` is not on the main memory, copy `moving_mean` and
        # `moving_var` to the device where `X` is located
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # Save the updated `moving_mean` and `moving_var`

        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.1)
        return Y

In [127]:
# Make Harmonic Neural Net
class HarmonicNeuralNet(nn.Module):
	#  Determine what layers and their order in CNN object
    def __init__(self, num_classes):
        super(HarmonicNeuralNet, self).__init__()
        self.conv_layer1 = My2DConvLayer(kernel_size=(5, 5), rotation_speed=1, in_channels=3, out_channels=16)
        self.celu1 = ComplexRelu()
        self.bn1 = BatchNorm(16, 4)
        self.avgpool1 = ComplexMeanPooling(kernel_size=(5,5), stride=2)

        self.conv_layer2 = My2DConvLayer(kernel_size=(5, 5), rotation_speed=1, in_channels=16, out_channels=16)
        self.celu2 = ComplexRelu()
        self.bn2 = BatchNorm(16, 4)
        self.avgpool2 = ComplexMeanPooling(kernel_size=(3,3), stride=1)

        self.fc1 = nn.Linear(576, 10)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(10, num_classes)
        self.relu2 = nn.ReLU()
        self.softmax = nn.LogSoftmax(dim=1)
    
    # Progresses data across layers    
    def forward(self, x):
        out = self.conv_layer1(x)
        out = self.celu1(out)
        out = self.bn1(out)
        out = self.avgpool1(out)
        out = self.conv_layer2(out)
        out = self.celu2(out)
        out = self.bn2(out)
        out = self.avgpool2(out)
        out = out.float()
        out = out.reshape(out.size(0), -1).float()
        out = self.fc1(out)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.softmax(out)
        # print(out[0])
        return out

In [128]:
modelH = HarmonicNeuralNet(num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(modelH.parameters())  
num_epochs=10

In [129]:
for epoch in range(num_epochs):
    #Load in the data in batches using the train_loader object
    for i, (images, labels) in enumerate(train_loader):
        # with torch.autograd.detect_anomaly():
        images = images.to(device)
        labels = labels.to(device)

        outputs = modelH(images.to(device))
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        print('.',end='')
        if i>100: break
    print('\nEpoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

......................................................................................................
Epoch [1/10], Loss: 1.4506
......................................................................................................
Epoch [2/10], Loss: 1.4072
......................................................................................................
Epoch [3/10], Loss: 1.1515
......................................................................................................
Epoch [4/10], Loss: 0.8748
......................................................................................................
Epoch [5/10], Loss: 1.1314
......................................................................................................
Epoch [6/10], Loss: 1.0970
......................................................................................................
Epoch [7/10], Loss: 0.6587
..........................................................................................

In [131]:
# Specify a path
PATH = "state_dict_model.pt"

# Save
torch.save(modelH.state_dict(), PATH)

# Load

model = HarmonicNeuralNet(10).to(device)
model.load_state_dict(torch.load(PATH))
model.eval()

HarmonicNeuralNet(
  (conv_layer1): My2DConvLayer()
  (celu1): ComplexRelu()
  (bn1): BatchNorm()
  (avgpool1): ComplexMeanPooling()
  (conv_layer2): My2DConvLayer()
  (celu2): ComplexRelu()
  (bn2): BatchNorm()
  (avgpool2): ComplexMeanPooling()
  (fc1): Linear(in_features=576, out_features=10, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=10, out_features=10, bias=True)
  (relu2): ReLU()
  (softmax): LogSoftmax(dim=1)
)