Update: This showed some good initial results, but training proved infeasible because of the limited expressivity of a 2-layer neural network, even at ultra-wide widths. This is to be expected.

Demonstrated some promising results. The concept is similar to gradient boosting: the model is first trained with a few neurons, then new neurons are added, the original neurons are frozen, and training continues. Freezing prevents the original neurons from co-adapting with the new neurons to memorize the training data, rather than generalizing through robust patterns. It also uses the "looks linear" initialization, which helps speed up training.

In [None]:
class TwoLayerGradientBoosting(nn.Module):
    """Two-layer ReLU network that is grown in stages, boosting-style.

    The model always holds one *trainable* block (``fc1_trainable`` ->
    ReLU -> ``fc2_trainable``).  Each call to :meth:`grow_network` folds the
    current trainable block into a *frozen* block (``fc1_frozen`` /
    ``fc2_frozen``, ``requires_grad=False``) and creates a fresh trainable
    block.  The network output is the sum of the frozen and trainable
    block outputs.

    Every new trainable block has ``2 * input_size + n`` hidden units: the
    first ``2 * input_size`` units split each input into its positive and
    negative part, the remaining ``n`` units keep ``nn.Linear``'s default
    weight initialisation.  The second layer of a new block is all zeros, so
    a freshly created block contributes exactly zero to the output.

    Args:
        input_size: Number of input features.
        output_size: Number of outputs.
        num_neurons: Number of extra hidden units (beyond the
            ``2 * input_size`` input-splitting units) in the first block.
    """

    def splits_inputs_init(self, layer):
        """Initialise the first ``2 * in_features`` rows of ``layer.weight``.

        Row ``2 * i`` is set to ``+e_i`` and row ``2 * i + 1`` to ``-e_i``, so
        that after ReLU the pair of units carries ``max(x_i, 0)`` and
        ``max(-x_i, 0)``.  Example of those rows for ``in_features == 3``::

            [[ 1,  0,  0],
             [-1,  0,  0],
             [ 0,  1,  0],
             [ 0, -1,  0],
             [ 0,  0,  1],
             [ 0,  0, -1]]

        Rows past ``2 * in_features`` are left untouched.  They must not be
        zeroed: with zero weights and zero bias their pre-activation is
        exactly 0, where ReLU passes no gradient, and because the second
        layer also starts at zero those units would never be updated.

        Args:
            layer: The ``nn.Linear`` to initialise in place.

        Raises:
            ValueError: If ``layer`` has fewer than ``2 * in_features``
                output units.
        """
        split_rows = 2 * layer.in_features
        if layer.out_features < split_rows:
            raise ValueError(
                "layer needs at least 2 * in_features output units "
                f"({split_rows}), got {layer.out_features}"
            )

        with torch.no_grad():
            pattern = layer.weight.detach().clone()
            pattern[:split_rows] = 0
            idx = torch.arange(layer.in_features, device=pattern.device)
            pattern[2 * idx, idx] = 1
            pattern[2 * idx + 1, idx] = -1
            layer.weight.copy_(pattern)

    def _new_trainable_block(self, num_neurons):
        """Build one freshly initialised (fc1, fc2) pair for a trainable block."""
        hidden = 2 * self.input_size + num_neurons
        fc1 = nn.Linear(self.input_size, hidden)
        fc2 = nn.Linear(hidden, self.output_size)

        self.splits_inputs_init(fc1)
        nn.init.zeros_(fc1.bias)

        # A zero second layer makes the new block output exactly zero, so
        # adding it does not change the function the network computes.
        nn.init.zeros_(fc2.weight)
        nn.init.zeros_(fc2.bias)
        return fc1, fc2

    def __init__(self, input_size, output_size, num_neurons):
        super().__init__()
        self.activation = nn.ReLU()

        self.input_size = input_size
        self.output_size = output_size

        self.fc1_trainable, self.fc2_trainable = self._new_trainable_block(num_neurons)

        # No frozen block exists until the first call to grow_network().
        self.fc1_frozen = None
        self.fc2_frozen = None

    def grow_network(self, num_new_neurons):
        """Freeze the current trainable block and start a new one.

        The current trainable layers are merged into the frozen layers (or
        become the frozen layers on the first call), a new zero-output
        trainable block is created, and all frozen parameters get
        ``requires_grad = False``.  The output of :meth:`forward` is the same
        immediately before and after this call.

        The new layers are placed on the device and dtype of the existing
        parameters.

        NOTE: the trainable parameters are new objects after this call, so an
        optimizer constructed earlier does not track them.

        Args:
            num_new_neurons: Number of extra hidden units (beyond the
                ``2 * input_size`` input-splitting units) in the new block.
        """
        reference = self.fc1_trainable.weight
        target_device, target_dtype = reference.device, reference.dtype

        if self.fc1_frozen is None:
            self.fc1_frozen = self.fc1_trainable
            self.fc2_frozen = self.fc2_trainable
        else:
            old_fc1, old_fc2 = self.fc1_frozen, self.fc2_frozen
            merged_width = old_fc1.out_features + self.fc1_trainable.out_features

            merged_fc1 = nn.Linear(self.input_size, merged_width)
            merged_fc2 = nn.Linear(merged_width, self.output_size)
            merged_fc1.to(device=target_device, dtype=target_dtype)
            merged_fc2.to(device=target_device, dtype=target_dtype)

            with torch.no_grad():
                # Hidden units are concatenated: frozen units first, then the
                # units of the block being frozen now.
                merged_fc1.weight.copy_(
                    torch.cat([old_fc1.weight, self.fc1_trainable.weight], dim=0)
                )
                merged_fc1.bias.copy_(
                    torch.cat([old_fc1.bias, self.fc1_trainable.bias], dim=0)
                )
                merged_fc2.weight.copy_(
                    torch.cat([old_fc2.weight, self.fc2_trainable.weight], dim=1)
                )
                # The two block outputs are summed in forward(), so their
                # output biases add.
                merged_fc2.bias.copy_(old_fc2.bias + self.fc2_trainable.bias)

            self.fc1_frozen = merged_fc1
            self.fc2_frozen = merged_fc2

        self.fc1_trainable, self.fc2_trainable = self._new_trainable_block(num_new_neurons)

        for layer in (
            self.fc1_frozen,
            self.fc2_frozen,
            self.fc1_trainable,
            self.fc2_trainable,
        ):
            layer.to(device=target_device, dtype=target_dtype)

        for layer in (self.fc1_frozen, self.fc2_frozen):
            for param in layer.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Return the sum of the frozen-block and trainable-block outputs.

        Before the first call to :meth:`grow_network` there is no frozen
        block and only the trainable block is evaluated.
        """
        trainable_out = self.fc2_trainable(self.activation(self.fc1_trainable(x)))
        if self.fc1_frozen is None:
            return trainable_out

        frozen_out = self.fc2_frozen(self.activation(self.fc1_frozen(x)))
        return frozen_out + trainable_out
