In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn

from src.models.weights import V1_weights, sensilla_weights, classical_weights

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'src'

### V1 init

In [2]:
def V1_init(layer, size, spatial_freq, center=None, scale=1., bias=False, seed=None):
    """
    Replace the kernels of a Conv2d layer with V1-inspired random weights.

    The weights are generated by ``V1_weights`` and written into ``layer``
    in place. Only layers with exactly one input channel are supported.

    Parameters
    ----------
    layer : torch.nn.Conv2d layer
        Layer whose weights are overwritten.

    size : float
        Determines the size of the random weights (forwarded to ``V1_weights``).

    spatial_freq : float
        Determines the spatial frequency of the random weights
        (forwarded to ``V1_weights``).

    center : tuple of two values, default=None
        Location of the center of the random weights.
        With default value, the centers uniformly cover the RF space.

    scale : float, default=1
        Normalization factor for Tr norm of cov matrix.

    bias : bool, default=False
        When False, the layer's bias is removed (set to None). Otherwise the
        existing bias is left as it is.

    seed : int, default=None
        Used to set the seed when generating random weights.

    Returns
    -------
    None
        ``layer`` is modified in place.
    """
    layer_kind = layer.__class__.__name__
    assert 'Conv2d' in layer_kind, 'This init only works for conv2d layers'
    assert layer.in_channels == 1, 'This init only works when image has 1 input channel'

    # Conv2d weights are laid out as (out_channels, in_channels, kH, kW).
    num_filters, _, kernel_h, kernel_w = layer.weight.shape
    kernels = V1_weights(
        num_filters, (kernel_h, kernel_w), size, spatial_freq, center, scale, seed=seed
    )
    # Restore the 4-D kernel layout before handing the values to the layer.
    layer.weight.data = Tensor(kernels.reshape(num_filters, 1, kernel_h, kernel_w))

    drop_bias = (bias == False)
    if drop_bias:
        layer.bias = None

In [3]:
class V1_mnist_RFNet(nn.Module):
    """
    Random Feature network to classify MNIST images. The first layer is initialized from GP
    with covariance inspired by V1.

    Architecture: a Conv2d layer with a 28x28 kernel (weights set by ``V1_init``
    and frozen), a ReLU, and a 1x1 Conv2d classifier with 10 output channels.

    Parameters
    ----------
    hidden_dim : int
        Number of random features (output channels of the first layer).

    size, spatial_freq, center, scale, bias, seed
        Forwarded unchanged to ``V1_init`` for the first layer.
    """
    def __init__(self, hidden_dim, size, spatial_freq, center=None, scale=1, bias=False, seed=None):
        super(V1_mnist_RFNet, self).__init__()
        self.v1_layer = nn.Conv2d(in_channels=1, out_channels=hidden_dim, kernel_size=28)
        self.clf = nn.Conv2d(in_channels=hidden_dim, out_channels=10, kernel_size=1)
        self.relu = nn.ReLU()

        # initialize the first layer
        V1_init(self.v1_layer, size, spatial_freq, center, scale, bias, seed)
        # The random features are fixed: only the weight is frozen here.
        self.v1_layer.weight.requires_grad = False

    def forward(self, x):
        """
        Compute class scores for a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Images of shape (batch, 1, H, W); an unbatched (1, H, W) image is
            passed through as before.

        Returns
        -------
        torch.Tensor
            Scores with singleton dimensions removed, e.g. (batch, 10) for
            28x28 inputs. The batch axis is kept even when batch == 1.
        """
        h = self.relu(self.v1_layer(x))
        beta = self.clf(h)
        out = beta.squeeze()
        # A bare squeeze() also removes a batch axis of length 1, which would
        # turn a single-image batch into a (10,) vector. Put the axis back so
        # the output layout does not depend on the batch size.
        if x.dim() == 4 and x.shape[0] == 1:
            out = out.unsqueeze(0)
        return out

In [4]:
# check the init
train = torch.randn((1000, 1, 28, 28))
out_channels = 20
hidden_size, s, f, center, seed = 100, 5, 2, (16, 16), 20
W = nn.Conv2d(1, out_channels, kernel_size=28)
# Apply the V1 initialization to the layer under test; without this call the
# check would only exercise PyTorch's default initialization.
V1_init(W, s, f, center, seed=seed)
W.weight.requires_grad = False
output = W(train)
# A 28x28 kernel on 28x28 images leaves one value per filter.
assert output.shape == (1000, out_channels, 1, 1)


# check the model
model = V1_mnist_RFNet(hidden_size, s, f, center, seed=seed)
output = model(train)
print(output.shape)

torch.Size([1000, 10])


### sensilla init

In [5]:
def sensilla_init(layer, lowcut, highcut, decay_coef=np.inf, scale=1, bias=False, seed=None):
    """
    Replace the weights of a Linear layer with sensilla-inspired random weights.

    The weights are generated by ``sensilla_weights`` (modelled on STAs of
    insect sensilla) and written into ``layer`` in place. The bias is turned
    off by default.

    Parameters
    ----------
    layer : torch.nn.Linear layer
        Layer whose weights are overwritten.

    lowcut : int
        Low end of the frequency band.

    highcut : int
        High end of the frequency band.

    decay_coef : float, default=np.inf
        Controls how fast the random features decay.
        With default value, the weights do not decay.

    scale : float, default=1
        Normalization factor for Tr norm of cov matrix.

    bias : bool, default=False
        When False, the layer's bias is removed (set to None). Otherwise the
        existing bias is left as it is.

    seed : int, default=None
        Used to set the seed when generating random weights.

    Returns
    -------
    None
        ``layer`` is modified in place.
    """
    layer_kind = layer.__class__.__name__
    assert 'Linear' in layer_kind, 'This init only works for Linear layers'

    # Linear weights are laid out as (out_features, in_features).
    n_out, n_in = layer.weight.shape
    random_weights = sensilla_weights(n_out, n_in, lowcut, highcut, decay_coef, scale, seed)
    layer.weight.data = Tensor(random_weights)

    drop_bias = (bias == False)
    if drop_bias:
        layer.bias = None
    

class sensilla_RFNet(nn.Module):
    """
    Random Feature network to classify time-series. The first layer is initialized from GP
    with covariance inspired by mechanosensory sensilla.

    Architecture: a Linear layer (weights set by ``sensilla_init`` and frozen),
    a ReLU, and a Linear classifier with 2 outputs.

    Parameters
    ----------
    input_dim : int
        Number of input features of the first layer.

    hidden_dim : int
        Number of random features (outputs of the first layer).

    lowcut, highcut, decay_coef, scale, bias, seed
        Forwarded unchanged to ``sensilla_init`` for the first layer.
    """
    def __init__(self, input_dim, hidden_dim,
                 lowcut, highcut, decay_coef=np.inf, scale=1, bias=False, seed=None):
        super(sensilla_RFNet, self).__init__()
        self.sensilla_layer = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.clf = nn.Linear(in_features=hidden_dim, out_features=2)
        self.relu = nn.ReLU()

        # initialize the first layer
        sensilla_init(self.sensilla_layer, lowcut, highcut, decay_coef, scale, bias, seed)
        # The random features are fixed: only the weight is frozen here.
        self.sensilla_layer.weight.requires_grad = False

    def forward(self, x):
        """
        Compute class scores for a batch of time-series.

        Parameters
        ----------
        x : torch.Tensor
            Inputs of shape (batch, input_dim); a single unbatched vector of
            shape (input_dim,) is passed through as before.

        Returns
        -------
        torch.Tensor
            Scores with singleton dimensions removed, e.g. (batch, 2).
            The batch axis is kept even when batch == 1.
        """
        h = self.relu(self.sensilla_layer(x))
        beta = self.clf(h)
        out = beta.squeeze()
        # A bare squeeze() also removes a batch axis of length 1, which would
        # turn a one-sample batch into a (2,) vector. Put the axis back so the
        # output layout does not depend on the batch size.
        if x.dim() > 1 and x.shape[0] == 1:
            out = out.unsqueeze(0)
        return out

In [6]:
## check the init: apply the sensilla init to a standalone Linear layer
input_dim = 1600
train = torch.randn(20, input_dim)
W = nn.Linear(in_features=input_dim, out_features=100)
lowcut, highcut, decay_coef = 2, 8, 22
sensilla_init(W, lowcut=lowcut, highcut=highcut, decay_coef=decay_coef)
output = W(train)

## check the model: run the full network on the same random batch
hidden_size = 100
model = sensilla_RFNet(
    input_dim=input_dim,
    hidden_dim=hidden_size,
    lowcut=lowcut,
    highcut=highcut,
    decay_coef=decay_coef,
)
output = model(train)
print(output.shape)

torch.Size([20, 2])


### classical init
Initialize the weights from a Gaussian process (GP) with diagonal covariance.

In [7]:
def classical_init(layer, scale=1, bias=False, seed=None):
    """
    Initialize weights of a Linear layer or convolutional layer according to
    GP with diagonal covariance.

    The weights are generated by ``classical_weights`` and written into
    ``layer`` in place. The bias is turned off by default.

    Parameters
    ----------

    layer : torch.nn.Linear or torch.nn.Conv2d layer
        Layer that will be initialized. Conv2d layers must have exactly one
        input channel.

    scale : float, default=1
        Normalization factor for Tr norm of cov matrix

    bias : bool, default=False
        When False, the layer's bias is removed (set to None). Otherwise the
        existing bias is left as it is.

    seed : int, default=None
        Used to set the seed when generating random weights.

    Returns
    -------
    None
        ``layer`` is modified in place.
    """
    classname = layer.__class__.__name__
    # str.find returns the match *index* (0 for 'Linear' / 'Conv2d'), so the
    # layer kind is decided by substring membership rather than by comparing
    # the index with 1.
    is_linear = 'Linear' in classname
    is_conv = 'Conv2d' in classname
    assert is_linear or is_conv, 'This init only works for Linear or Conv layers'

    if is_linear:
        # nn.Linear stores its weight as (out_features, in_features).
        out_features, in_features = layer.weight.shape
        classical_weight = classical_weights(out_features, in_features, scale, seed)
        layer.weight.data = Tensor(classical_weight)

    elif is_conv:
        assert layer.in_channels == 1, 'This init only works when image has 1 input channel'
        out_channels, in_channels, xdim, ydim = layer.weight.shape
        classical_weight = classical_weights(out_channels, (xdim, ydim), scale, seed=seed)
        # Restore the 4-D kernel layout before handing the values to the layer.
        layer.weight.data = Tensor(classical_weight.reshape(out_channels, 1, xdim, ydim))

    if bias == False:
        layer.bias = None
        
class classical_RFNet(nn.Module):
    """
    Random Feature network to classify time-series or MNIST digits. The first layer is initialized from GP
    with diagonal covariance.

    Parameters
    ----------
    input_dim : int or tuple
        Selects the architecture. An integer builds the time-series network
        (Linear -> ReLU -> Linear with 2 outputs) with that many input
        features. A tuple (or list) builds the MNIST network (Conv2d with a
        28x28 kernel -> ReLU -> 1x1 Conv2d with 10 outputs); the values inside
        the tuple are not used.

    hidden_dim : int
        Number of random features produced by the first layer.

    scale, bias, seed
        Forwarded unchanged to ``classical_init`` for the first layer.

    Raises
    ------
    TypeError
        If ``input_dim`` is neither an integer nor a tuple/list.
    """
    def __init__(self, input_dim, hidden_dim, scale=1, bias=False, seed=None):
        super(classical_RFNet, self).__init__()
        if isinstance(input_dim, (int, np.integer)): ## for time-series
            self.RF_layer = nn.Linear(in_features=int(input_dim), out_features=hidden_dim)
            self.clf = nn.Linear(in_features=hidden_dim, out_features=2)
        elif isinstance(input_dim, (tuple, list)): ## for MNIST
            self.RF_layer = nn.Conv2d(in_channels=1, out_channels=hidden_dim, kernel_size=28)
            self.clf = nn.Conv2d(in_channels=hidden_dim, out_channels=10, kernel_size=1)
        else:
            # Fail here instead of leaving the module without layers, which
            # would only surface later as an AttributeError.
            raise TypeError(
                'input_dim must be an int (time-series) or a tuple (images), got '
                + type(input_dim).__name__
            )
        self.relu = nn.ReLU()

        # initialize the first layer
        classical_init(self.RF_layer, scale, bias, seed)
        # The random features are fixed: only the weight is frozen here.
        self.RF_layer.weight.requires_grad = False

    def forward(self, x):
        """
        Compute class scores for a batch of inputs.

        Parameters
        ----------
        x : torch.Tensor
            (batch, input_dim) for the time-series network or
            (batch, 1, H, W) for the image network.

        Returns
        -------
        torch.Tensor
            Scores with singleton dimensions removed, e.g. (batch, 2) or
            (batch, 10). The batch axis is kept even when batch == 1.
        """
        h = self.relu(self.RF_layer(x))
        beta = self.clf(h)
        out = beta.squeeze()
        # A bare squeeze() also removes a batch axis of length 1. Put the axis
        # back for batched inputs so the output layout does not depend on the
        # batch size. Batched means 4-D for the conv net, >= 2-D for the
        # linear net.
        batched_ndim = 4 if isinstance(self.RF_layer, nn.Conv2d) else 2
        if x.dim() >= batched_ndim and x.shape[0] == 1:
            out = out.unsqueeze(0)
        return out

In [16]:
## check the init on a Linear layer
linear = nn.Linear(in_features=1, out_features=20)
classical_init(linear)

## ... and on a single-channel Conv2d layer
conv = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=28)
classical_init(conv)

In [19]:
## check the model for images
images = torch.randn(20, 1, 28, 28)
# hidden_dim is the single source of truth for the model width; it was
# previously assigned 20 while the model was built with a literal 100.
input_dim, hidden_dim = (1, 28, 28), 100
model = classical_RFNet(input_dim, hidden_dim=hidden_dim, bias=False, seed=1)
output = model(images)
print(output.shape)

torch.Size([20, 10])


In [20]:
## check the model for time-series
tseries = torch.randn(20, 150)
# hidden_dim is the single source of truth for the model width; it was
# previously assigned 20 while the model was built with a literal 100.
input_dim, hidden_dim = 150, 100
model = classical_RFNet(input_dim, hidden_dim=hidden_dim, bias=False, seed=1)
output = model(tseries)
print(output.shape)

torch.Size([20, 2])
