# Training a simple single layer Neural network with DFA



Work from Anas. Served as a basis for my own code

### useful imports

In [129]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

### Define NN structure

In [130]:
# Define the network architecture with a hidden layer and Tanh activation
class SimpleDFA(nn.Module):
    
    def __init__(self, input_size, hidden_size, output_size):
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        
        super(SimpleDFA, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)  # Input weight matrix (input to hidden)
        self.fc2 = nn.Linear(hidden_size, output_size)  # Output weight matrix (hidden to output)
        self.activation = nn.Tanh()  # Use Tanh activation function

    def forward(self, x):
        
        x = self.activation(self.fc1(x))  # Hidden layer with Tanh activation
        return self.activation(self.fc2(x))  # Output layer with Tanh activation
    
    
    def Init_weights(self):
        # Xavier init on weights
        nn.init.xavier_uniform_(self.fc1.weight) 
        nn.init.xavier_uniform_(self.fc2.weight) 

        print(self.fc1.weight)
        print(self.fc2.weight)
        
    
    def Init_feedback(self):
        
        self.B_1 = torch.randn(self.output_size, self.hidden_size) * 0.01  # Feedback for 1st layer
        self.B_2 = torch.randn(self.hidden_size, self.input_size) * 0.01 # Feedback for 2nd layer
        
        #normalize the feedback matrices to avoid exploding feedback
        self.B_1 = self.B_1 / self.B_1.norm()
        self.B_2 = self.B_2 / self.B_2.norm()
        return self.B_1, self.B_2
    
    def forward_pass(self,Input,Target,loss_func):
        
        # Forward pass
        Output = self.forward(Input)
        loss = loss_func(Output, Target)
        return Output, loss
        
    def activation_derivative(self, x):
        
        return 1 - self.activation(x)**2  # Derivative of Tanh activation function
    
    
def dfa_update_step(model, feedback_matrix_output, feedback_matrix_input, output, target, hidden_activations, inputs, learning_rate=0.01):
    #function to update the weights and biases using DFA

    error = output - target  # Compute error
    
    # update the output weights and biases at the output
    feedback_signal_output = torch.matmul(error * model.activation_derivative(output), feedback_matrix_output)
    
    with torch.no_grad():
        #weight update for the output layer
        delta_w_out = torch.matmul(feedback_signal_output.T, hidden_activations)
        model.fc2.weight -= learning_rate * delta_w_out
        #bias update for the output layer
        delta_b_out = feedback_signal_output.sum(0)
        model.fc2.bias -= learning_rate * delta_b_out
        
    print("output no problem")

    # update the input to hidden weights and biases
    feedback_signal_input = torch.matmul(error * model.activation_derivative(hidden_activations), feedback_matrix_input)
    
    with torch.no_grad():
        #weight update for the input layer
        delta_w_in = torch.matmul(feedback_signal_input.T, inputs)
        model.fc1.weight -= learning_rate * delta_w_in
        #bias update for the input layer
        delta_b_in = feedback_signal_input.sum(0)
        model.fc1.bias -= learning_rate * delta_b_in
        


        
def fa_update_step(model, feedback_matrix_output, feedback_matrix_input, output, target, hidden_activations, inputs, learning_rate=0.01):
    #function to update the weights and biases using simple FA
    error = output - target  # Compute error
    
    # update the output weights and biases at the output
    feedback_signal_output = torch.matmul(error * model.activation_derivative(output), feedback_matrix_output)
    
    with torch.no_grad():
        #weight update for the output layer
        delta_w_out = torch.matmul(feedback_signal_output.T, hidden_activations)
        model.fc2.weight -= learning_rate * delta_w_out
        #bias update for the output layer
        delta_b_out = feedback_signal_output.sum(0)
        model.fc2.bias -= learning_rate * delta_b_out
        
    # update the input to hidden weights and biases
    feedback_signal_input = torch.matmul(error * model.activation_derivative(hidden_activations), feedback_matrix_input)
    
    with torch.no_grad():
        #weight update for the input layer
        delta_w_in = torch.matmul(feedback_signal_input.T, inputs)
        model.fc1.weight -= learning_rate * delta_w_in
        #bias update for the input layer
        delta_b_in = feedback_signal_input.sum(0)
        model.fc1.bias -= learning_rate * delta_b_in
        
    
    

In [131]:
# import torch

# Credits : https://github.com/iacolippo/Direct-Feedback-Alignment/blob/master/dfa-mnist.ipynb

def dfa_found_update_step(model, feedback_matrix_output, feedback_matrix_input, output, target, hidden_activations, inputs, learning_rate=0.01):




    print("feedback_matrix_output", feedback_matrix_output.size())
    print("feedback_matrix_input", feedback_matrix_input.size())
    print("output", output.size())
    print("inputs", inputs.size())
    print("target", target.size())
    print("hidden_activations", hidden_activations.size())



    # Compute error at the output layer
    e = output - target  # Error signal
    
    # Generate feedback-aligned signal for the output layer using feedback_matrix_output
    da2 = torch.matmul(feedback_matrix_output.T, e.T)  # (output_size, batch_size)
    
    # Compute weight gradient for the output layer using DFA
    dW2 = -torch.matmul(da2.T, hidden_activations)  # (output_size, batch_size) * (batch_size, hidden_size) -> (output_size, hidden_size)
    db2 = -torch.sum(da2, dim=1, keepdim=True)  # Sum along batch dimension -> (output_size, 1)
    
    # Generate feedback-aligned signal for the hidden layer using feedback_matrix_input
    da1 = torch.matmul(feedback_matrix_input.T, e.T) * (1 - torch.tanh(hidden_activations) ** 2)  # (hidden_size, batch_size)
    
    # Compute weight gradients for the input-to-hidden layer using inputs
    dW1 = -torch.matmul(da1.T, inputs)  # (batch_size, hidden_size) * (batch_size, input_size) -> (hidden_size, input_size)
    db1 = -torch.sum(da1, dim=1, keepdim=True)  # Sum along batch dimension -> (hidden_size, 1)

    # Update the model's weights and biases
    with torch.no_grad():
        model.fc2.weight -= learning_rate * dW2  # Update hidden-to-output weights
        model.fc2.bias -= learning_rate * db2.squeeze()  # Update output layer bias
        
        model.fc1.weight -= learning_rate * dW1  # Update input-to-hidden weights
        model.fc1.bias -= learning_rate * db1.squeeze()  # Update hidden layer bias




### Define training loop funciton

In [132]:
def train(model, data, targets, loss_fn, learning_rate=0.01, epochs=1000):
    # Initialize feedback matrices
    B_1, B_2 = model.Init_feedback()

    losses = []  # To store the loss at each epoch
    for epoch in range(epochs):
        # Forward pass
        hidden_activations = model.activation(model.fc1(data))  # Get hidden layer activations
        output, loss = model.forward_pass(data, targets, loss_fn)

        # Apply DFA update step
        # dfa_update_step(model, B_1, B_2, output, targets, hidden_activations, data, learning_rate)
        dfa_found_update_step(model, B_1, B_2, output, targets, hidden_activations, data, learning_rate)

        # Store the loss
        losses.append(loss.item())

        # Print loss and accuracy every 100 epochs
        if epoch % 100 == 0:
            print(f"Epoch [{epoch}/{epochs}], Loss: {loss.item()}")

            # Compute accuracy
            predictions = (output >= 0.5).float()  # Threshold the output to get binary predictions
            correct = (predictions == targets).float().sum()
            accuracy = correct / targets.size(0)
            print(f"Accuracy: {accuracy * 100:.2f}%")
    
    # Plot the convergence of loss
    plt.plot(losses)
    plt.title("Convergence of Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.show()


#### Trying the code on a simple XOR task

In [133]:
# Example usage with XOR data
input_size = 3  # XOR has two input features (0 or 1)
hidden_size = 4  # You can adjust this value for the hidden layer
output_size = 2  # XOR has one output feature (0 or 1)


data = torch.tensor([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]], dtype=torch.float32)  # XOR inputs
targets = torch.tensor([[0, 0], [1, 0], [1, 0], [0, 0]], dtype=torch.float32)  # XOR targets

# Initialize the model and loss function
model = SimpleDFA(input_size, hidden_size, output_size)
model.Init_weights()  # Initialize the weights

loss_fn = nn.MSELoss()  

# Train the network using DFA and print XOR accuracy
train(model, data, targets, loss_fn, learning_rate=0.1, epochs=10000)


Parameter containing:
tensor([[-0.0299, -0.4139, -0.3307],
        [-0.5169,  0.2924,  0.8625],
        [-0.6726,  0.5981, -0.0745],
        [-0.4026, -0.3973, -0.4581]], requires_grad=True)
Parameter containing:
tensor([[-0.9295,  0.0952, -0.5856,  0.2938],
        [-0.4500, -0.2215,  0.0379, -0.3219]], requires_grad=True)
feedback_matrix_output torch.Size([2, 4])
feedback_matrix_input torch.Size([4, 3])
output torch.Size([4, 2])
inputs torch.Size([4, 3])
target torch.Size([4, 2])
hidden_activations torch.Size([4, 4])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (3x4 and 2x4)