In [2]:
import numpy as np
import numpy.random as rd

import torch
import torch.cuda
import torchvision.transforms as T
import torch.nn.functional as F
from scipy.stats import crystalball_gen
from torchvision.datasets import MNIST

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Train dataset
n_classes = 10  # MNIST digit classes 0-9

# Images -> float tensors with values in [0, 1].
train_data_transform = T.ToTensor()
# Integer label -> float one-hot vector of shape (n_classes,).
# num_classes must be fixed: with -1, one_hot infers the width from the single
# label it receives, so samples get different shapes and cannot be batched.
# Float (not long) targets are what cross_entropy expects for one-hot labels.
train_data_target_transform = T.Compose([
    lambda label: torch.tensor(label, dtype=torch.long),
    lambda label: F.one_hot(label, num_classes=n_classes).to(torch.float32),
])
train_dataset = MNIST(
    root='data/MNIST',
    download=True,
    train=True,
    transform=train_data_transform,
    target_transform=train_data_target_transform,
)
train_kwargs = {
    'batch_size': 256,
    'shuffle': True,  # reshuffle the training set every epoch for SGD
}
train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)

# Test dataset
# Same preprocessing as for training: without a transform MNIST yields PIL
# images, which the default DataLoader collate function cannot batch.
test_data_transform = T.ToTensor()
test_data_target_transform = train_data_target_transform
test_dataset = MNIST(
    root='data/MNIST',
    download=True,
    train=False,
    transform=test_data_transform,
    target_transform=test_data_target_transform,
)
test_kwargs = {
    'batch_size': 256,
}
test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

In [None]:
# Peek at a single batch to sanity-check shapes; looping over the whole loader
# would print every batch of the epoch.
data, label = next(iter(train_loader))
print(f"Data: {data.shape}\nLabel: {label}")

In [None]:
# Hyperparams
n_epochs = 10
n_minibatch = 10
print_every = 100
n_1 = 300
n_2 = 100

# Connection probability of each layer
p01 = .01
p02 = .03
p0out = .3
l1 = 1e-5
gdnoise = 1e-5
lr = 0.5
lr_epoch_decay = 0.8

# Useful constants
dtype = torch.float32
n_pixels = 28 * 28
n_out = 10
# Number of training images (x_train is not defined in this notebook).
n_image_per_epoch = len(train_dataset)
# NOTE(review): n_minibatch (10) differs from the DataLoader batch_size (256),
# so n_iter is not the number of loader batches — confirm which is intended.
n_iter = n_epochs * n_image_per_epoch // n_minibatch
mnist_epoch_completed = 0

# Number of non-zero (connected) weights per layer, keyed by layer name:
# fan_in * fan_out * connection probability.
layer_names = ["layer1", "layer2", "outlayer"]
sparsity_list = [p01, p02, p0out]
layer_shapes = [(n_pixels, n_1), (n_1, n_2), (n_2, n_out)]
# round() instead of int() so float products like 899.9999... do not truncate.
nb_non_zero_coeff_list = {
    name: round(fan_in * fan_out * p0)
    for name, (fan_in, fan_out), p0 in zip(layer_names, layer_shapes, sparsity_list)
}

# Empty input/label tensors with the right feature widths
# (torch.tensor() without data raises a TypeError).
# TODO: decide whether these are still needed (autograd?).
x = torch.empty(0, n_pixels, dtype=dtype)
y = torch.empty(0, n_out, dtype=dtype)

In [3]:
def weight_sampler_strict_number(n_in, n_out, nb_non_zero, dtype=torch.float32):
    """Sample a sparse signed weight matrix with exactly ``nb_non_zero`` connections.

    The weight is factored as ``w = w_sign * theta`` with ``theta >= 0``. An entry
    is connected iff its ``theta`` is strictly positive.

    :param n_in: number of input units (rows).
    :param n_out: number of output units (columns).
    :param nb_non_zero: exact number of connected entries; must be <= n_in * n_out.
    :param dtype: torch dtype of the returned tensors.
    :return: ``(w, w_sign, theta, is_connected)``, each of shape ``(n_in, n_out)``.
        ``theta`` and ``w_sign`` are leaf tensors with ``requires_grad=True``;
        ``is_connected`` is a bool mask; ``w`` is 0 wherever not connected.
    :raises ValueError: (from ``numpy.random.choice``) if ``nb_non_zero`` exceeds
        ``n_in * n_out``.
    """
    w_0 = rd.randn(n_in, n_out) / np.sqrt(n_in)  # initial weight values, scaled by 1/sqrt(fan_in)

    # Random mask with exactly nb_non_zero entries: sample flat positions without
    # replacement. Drawing rows and columns independently with replacement can
    # pick the same (row, col) twice and under-connect the layer.
    is_con_0 = np.zeros(n_in * n_out, dtype=bool)
    is_con_0[rd.choice(n_in * n_out, size=nb_non_zero, replace=False)] = True
    is_con_0 = is_con_0.reshape(n_in, n_out)

    # Generate random signs
    sign_0 = np.sign(rd.randn(n_in, n_out))

    # Define the matrices
    theta = torch.tensor(np.abs(w_0) * is_con_0, dtype=dtype, requires_grad=True)
    w_sign = torch.tensor(sign_0, dtype=dtype, requires_grad=True)
    is_connected = torch.greater(theta, 0)
    # torch.where takes both branches positionally: (condition, input, other).
    w = torch.where(is_connected, w_sign * theta, torch.zeros_like(theta))

    return w, w_sign, theta, is_connected

In [None]:
def assert_connection_number(theta, targeted_number):
    '''
    Check that ``theta`` has exactly ``targeted_number`` connections.

    An entry counts as connected when it is strictly positive.

    :param theta: tensor of connection strengths (any shape).
    :param targeted_number: expected number of connections (int or 0-dim tensor).
    :return: ``True`` if the number of connections matches, ``False`` otherwise.
    '''
    # detach(): counting must not involve autograd. (.item() only works on
    # single-element tensors, so it cannot be used on a weight matrix.)
    is_con = torch.greater(theta.detach(), 0)

    nb_is_con = int(is_con.sum())
    return nb_is_con == int(targeted_number)

In [None]:
def rewiring(theta: torch.Tensor, target_nb_connection, epsilon=1e-12):
    """Reconnect random dormant entries of ``theta`` until it has ``target_nb_connection`` connections.

    An entry is connected iff it is strictly positive. If fewer than
    ``target_nb_connection`` entries are connected, the missing number of
    currently disconnected entries are chosen uniformly at random and set to
    ``epsilon``. If the target is already met or exceeded, ``theta`` is unchanged.

    The update is done in place under ``torch.no_grad()``, so it also works on
    leaf tensors with ``requires_grad=True``.

    :param theta: tensor of connection strengths (any shape); modified in place.
    :param target_nb_connection: desired number of connections (int or 0-dim tensor).
    :param epsilon: small value (must be > 0) given to newly connected entries.
    :return: ``theta`` itself.
    :raises ValueError: if ``target_nb_connection`` exceeds ``theta.numel()``.
    """
    target = int(target_nb_connection)
    if target > theta.numel():
        raise ValueError(
            f"target_nb_connection={target} exceeds the {theta.numel()} entries of theta"
        )

    with torch.no_grad():
        is_con = theta > 0
        nb_reconnect = max(target - int(is_con.sum()), 0)

        if nb_reconnect > 0:
            # Coordinates of the disconnected entries: one row per entry,
            # one column per dimension of theta.
            candidate_coords = torch.nonzero(~is_con)
            picked = torch.randperm(candidate_coords.shape[0], device=candidate_coords.device)[:nb_reconnect]
            reconnect_coords = candidate_coords[picked]

            # Advanced indexing with one index tensor per dimension.
            theta[tuple(reconnect_coords.t())] = epsilon

    return theta

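In this context, `theta` holds the absolute values of the weights: an entry is positive if and only if the corresponding connection is active. `w` is the signed weight, i.e. `theta` multiplied by the random sign `w_sign`, and it is zero for inactive connections:

```python
theta = np.abs(w_0) * is_con_0
w = w_sign * theta if is_connected else 0
```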

**TODO:** extract the weights, thetas, and gradients from the layers.

In [1]:
from collections import OrderedDict

# Define layers
def layer1():
    """First hidden layer: ReLU(x @ W_1).

    Samples a fresh sparse (n_pixels, n_1) weight with
    ``weight_sampler_strict_number`` on every call and applies it to the
    module-level ``x``.
    """
    # nb_non_zero_coeff_list is keyed by layer name, not by position.
    W_1, _, _, _ = weight_sampler_strict_number(n_pixels, n_1, nb_non_zero_coeff_list[layer_names[0]])
    a_1 = torch.matmul(x, W_1)
    # torch.nn.ReLU(a_1) would construct a module (with inplace=a_1), not apply ReLU.
    z_1 = torch.relu(a_1)
    return z_1

def layer2(z_1):
    """Second hidden layer: ReLU(z_1 @ W_2).

    Samples a fresh sparse (n_1, n_2) weight with
    ``weight_sampler_strict_number`` on every call.

    :param z_1: activations of the first hidden layer.
    :return: activations of the second hidden layer.
    """
    # nb_non_zero_coeff_list is keyed by layer name, not by position.
    W_2, _, _, _ = weight_sampler_strict_number(n_1, n_2, nb_non_zero_coeff_list[layer_names[1]])
    a_2 = torch.matmul(z_1, W_2)
    # torch.nn.ReLU(a_2) would construct a module (with inplace=a_2), not apply ReLU.
    z_2 = torch.relu(a_2)
    return z_2

def out_layer(z_2):
    """Output layer: logits = z_2 @ w_out (no activation).

    Samples a fresh sparse (n_2, n_out) weight with
    ``weight_sampler_strict_number`` on every call.

    :param z_2: activations of the second hidden layer.
    :return: unnormalized class scores (logits).
    """
    # nb_non_zero_coeff_list is keyed by layer name, not by position.
    w_out, _, _, _ = weight_sampler_strict_number(n_2, n_out, nb_non_zero_coeff_list[layer_names[2]])
    logits_predict = torch.matmul(z_2, w_out)
    return logits_predict

class DeepR(torch.nn.Module):
    """Fully connected MNIST classifier being ported to Deep R (sparse rewiring).

    Architecture: Flatten -> Linear -> ReLU -> Linear -> ReLU -> Linear (logits).
    ``nb_non_zero_coeff_list`` maps each name in ``layer_names`` to the number of
    connections that layer's sparse matrices should have.

    NOTE(review): the theta/weight matrices built by ``get_underlying_matrices``
    are new tensors created from a copy of the ``nn.Linear`` weights. Updating
    them (``meta_update``) therefore does not change what ``forward`` computes yet.
    """

    def __init__(self, layer_1_dim=(n_pixels, n_1), layer_2_dim=(n_1, n_2), layer_out_dim=(n_2, n_out), nb_non_zero_coeff_list=nb_non_zero_coeff_list):
        super().__init__()
        self.flatten = torch.nn.Flatten()
        self.nb_non_zero_coeff_list = nb_non_zero_coeff_list

        # Custom weight initialization: N(0, 1) / sqrt(fan_in)
        layer1 = self._init_linear(layer_1_dim)
        layer2 = self._init_linear(layer_2_dim)
        outlayer = self._init_linear(layer_out_dim)

        # OrderedDict must be *called* with the (name, module) pairs;
        # OrderedDict[...] is a type subscript, not a mapping.
        self.layers = torch.nn.Sequential(OrderedDict([
            (layer_names[0], layer1),
            ("relu1", torch.nn.ReLU()),
            (layer_names[1], layer2),
            ("relu2", torch.nn.ReLU()),
            (layer_names[2], outlayer),
        ]))

        # Meta info
        self._linear_layers = [layer1, layer2, outlayer]
        self.theta_list = []
        self.weight_list = []

    @staticmethod
    def _init_linear(dim):
        """Build ``nn.Linear(dim[0], dim[1])`` with weights drawn from N(0, 1) / sqrt(dim[0])."""
        fan_in, fan_out = dim
        layer = torch.nn.Linear(in_features=fan_in, out_features=fan_out)
        with torch.no_grad():
            # nn.Linear stores its weight as (out_features, in_features); copy_
            # casts numpy's float64 sample to the parameter's dtype.
            layer.weight.copy_(torch.from_numpy(rd.randn(fan_out, fan_in) / np.sqrt(fan_in)))
        return layer

    def forward(self, x):
        x = self.flatten(x)
        logits = self.layers(x)
        return logits

    def meta_update(self):
        """Sample the Deep R matrices on the first call, then perturb them.

        First call: sample and store ``(theta, w)`` for every linear layer in
        ``theta_list`` / ``weight_list``.
        Later calls: add ``lr * gdnoise * N(0, 1)`` noise in place to each stored
        theta, only where theta is connected (> 0).
        """
        is_first_run = not self.theta_list and not self.weight_list
        if is_first_run:
            for layer_name, layer in zip(layer_names, self._linear_layers):
                w, _, th, _ = self.get_underlying_matrices(layer_name, layer=layer)
                self.theta_list.append(th)
                self.weight_list.append(w)
            return

        # Custom SGD impl with noise and L1 still in progress, should be relocated to a different function/optimizer (sparse_sgd).
        # Gradient computation should be done for only active connections, default autograd is overkill probably ref. deep rewiring paper.
        # In-place update under no_grad: theta tensors are autograd leaves.
        with torch.no_grad():
            for th in self.theta_list:
                connected_mask = (th > 0).to(th.dtype)
                th.add_(lr * connected_mask * gdnoise * torch.randn_like(th))

    def get_underlying_matrices(self, layer_name, layer=None):
        """Sample the Deep R matrices for one linear layer.

        :param layer_name: key into ``self.nb_non_zero_coeff_list``; also the name
            of the submodule in ``self.layers`` used when ``layer`` is None.
        :param layer: the ``nn.Linear`` to use; looked up by name if omitted.
        :return: ``(w, w_sign, theta, is_connected)``, each of shape
            ``(in_features, out_features)`` (the transpose of ``layer.weight``).
            ``theta`` is ``|weight|`` masked to exactly
            ``nb_non_zero_coeff_list[layer_name]`` connected entries; ``w_sign``
            holds random signs; ``w`` is ``w_sign * theta`` where connected, else 0.
        """
        # Sequential only supports integer/slice indexing; named children are attributes.
        layer = layer if layer is not None else getattr(self.layers, layer_name)
        fan_in, fan_out = layer.in_features, layer.out_features
        nb_non_zero = self.nb_non_zero_coeff_list[layer_name]
        weight_dtype = layer.weight.dtype
        weight_device = layer.weight.device

        # (out, in) parameter -> detached (in, out) numpy array matching the mask layout.
        w_0 = layer.weight.detach().t().cpu().numpy()

        # Random mask with exactly nb_non_zero entries (flat sampling without replacement).
        is_con_0 = np.zeros(fan_in * fan_out, dtype=bool)
        is_con_0[rd.choice(fan_in * fan_out, size=nb_non_zero, replace=False)] = True
        is_con_0 = is_con_0.reshape(fan_in, fan_out)

        # Generate random signs
        sign_0 = np.sign(rd.randn(fan_in, fan_out))

        # Define the matrices
        theta = torch.tensor(np.abs(w_0) * is_con_0, dtype=weight_dtype, device=weight_device, requires_grad=True)
        w_sign = torch.tensor(sign_0, dtype=weight_dtype, device=weight_device, requires_grad=True)
        is_connected = torch.greater(theta, 0)
        # torch.where takes both branches positionally: (condition, input, other).
        w = torch.where(is_connected, w_sign * theta, torch.zeros_like(theta))

        return w, w_sign, theta, is_connected

NameError: name 'torch' is not defined

In [None]:
# Calc loss (softmax_cross_entropy_with_logits in tf)
def loss(logits_pred, y):
    """Batch-averaged softmax cross-entropy (TF: softmax_cross_entropy_with_logits + mean).

    :param logits_pred: unnormalized class scores, shape ``(N, C)``.
    :param y: class indices of shape ``(N,)``, or floating one-hot /
        probability targets with the same shape as ``logits_pred``.
    :return: 0-dim tensor with the mean loss over the batch.
    """
    # reduction="mean" already averages over the batch.
    batch_mean = torch.nn.functional.cross_entropy(logits_pred, y, reduction="mean")
    return batch_mean

In [None]:
def mask_connected(th):
    """Return a float32 0/1 mask of the connected entries of ``th`` (``th > 0``)."""
    return torch.greater(th, 0).type(torch.float32)


def noise_update(th, std=None):
    """Gaussian noise shaped like ``th`` with mean 0 and standard deviation ``std``.

    :param th: tensor whose shape, dtype and device the noise matches.
    :param std: noise standard deviation; defaults to the notebook-level ``gdnoise``.
    """
    if std is None:
        std = gdnoise
    # torch.normal(std=..., size=...) has no default mean and raises;
    # randn_like also matches th's dtype and device.
    return torch.randn_like(th) * std


theta_list = []
# Placeholder for the per-theta updates th += lr * mask_connected(th) * noise_update(th).
# PyTorch runs these eagerly (in place, under torch.no_grad()) instead of as graph ops.
add_gradient_op_list = []

In [None]:
# Train
def train(model=None):
    """Train ``model`` with mini-batch SGD on ``train_loader`` and return it.

    Uses the notebook-level hyperparameters ``lr``, ``lr_epoch_decay``,
    ``n_epochs`` and ``print_every``, and the ``loss`` function.

    NOTE(review): this is plain dense SGD. The Deep R steps (noise/L1 update
    of theta and ``rewiring``) are not wired in yet, because the theta/weight
    matrices are not yet extracted from the layers.

    :param model: network to train; a new ``DeepR()`` is built when omitted.
    :return: the trained model (moved to ``device``).
    """
    if model is None:
        model = DeepR()
    model = model.to(device)
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # Multiply the learning rate by lr_epoch_decay at the end of every epoch.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_epoch_decay)

    step = 0
    for epoch in range(n_epochs):
        for images, targets in train_loader:
            images, targets = images.to(device), targets.to(device)

            optimizer.zero_grad()
            batch_loss = loss(model(images), targets)
            batch_loss.backward()
            optimizer.step()

            if step % print_every == 0:
                print(f"epoch {epoch} step {step}: loss {batch_loss.item():.4f}")
            step += 1
        scheduler.step()

    return model

In [None]:
# Stats and results