<a href="https://colab.research.google.com/github/foxtrotmike/CS909/blob/master/xor_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Using Self-Attention to Solve the 2D XOR Problem in PyTorch

By Fayyaz Minhas.

In this tutorial, we solve the 2D XOR problem using a self-attention mechanism in PyTorch. XOR is a classic example of a problem that no linear model can solve, because its two classes are not linearly separable, which makes it a convenient test case for more expressive architectures.
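To make the limitation of linear models concrete, the following minimal sketch fits a least-squares linear model with a bias term to the four XOR points. It assumes the same -1/+1 target encoding used later in this tutorial; the names `X_aug`, `coef`, and `linear_preds` are illustrative only.

```python
import numpy as np

# The four XOR inputs and their -1/+1 targets (same encoding as this tutorial).
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float64)
y = np.array([-1.0, 1.0, 1.0, -1.0])

# Append a constant column so the linear model can learn a bias term.
X_aug = np.hstack([X, np.ones((4, 1))])

# Least-squares fit of y ~ w1*x1 + w2*x2 + b.
coef, *_ = np.linalg.lstsq(X_aug, y, rcond=None)
linear_preds = X_aug @ coef

print("coefficients (w1, w2, b):", coef)
print("linear predictions:", linear_preds)
```

The best linear fit has all coefficients at (numerically) zero, so it predicts roughly 0 for every input and cannot tell the two classes apart.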

## Overall Approach
The key idea is to treat each input bit as a token and use self-attention to build a feature representation of these tokens. Each token's output is a weighted combination of the value vectors of both bits, so the representation of one bit depends on the other, and the downstream layers receive features from which XOR can be computed.
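As a small illustration of "each bit is a token", the sketch below reshapes the four XOR inputs so that every sample becomes a sequence of two one-dimensional tokens. The variable names `xor_inputs` and `tokens` are chosen here for illustration.

```python
import torch

# Four samples, two bits each: shape (4, 2).
xor_inputs = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])

# Add a trailing feature axis so each bit is a 1-dimensional token:
# (batch=4, tokens=2, features=1).
tokens = xor_inputs.unsqueeze(-1)

print(tokens.shape)
```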

## Self-Attention Layer
The self-attention mechanism enables the model to focus on different parts of the input sequence by computing a weighted sum of the value vectors, where the weights are determined by the similarity between the query and key vectors.
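In matrix form, this is the standard scaled dot-product attention. The symbol $d_k$ (the query/key dimension) is introduced here for illustration and is not named in the original text; in this tutorial it equals the model dimension, 2.

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$

The softmax is applied row by row, so each token's attention weights over all tokens sum to 1.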

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class SelfAttention(nn.Module):
    def __init__(self, d_qk = 2, d_model = 2):
        super(SelfAttention, self).__init__()
        self.d_model = d_model

        self.W_q = nn.Linear(1, d_qk, bias=False)
        self.W_k = nn.Linear(1, d_qk, bias=False)
        self.W_v = nn.Linear(1, d_model, bias=False)

    def forward(self, x):
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Scaled dot-product attention: scale by sqrt of the query/key dimension
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (Q.size(-1) ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)

        # Weighted sum of values
        output = torch.matmul(attention_weights, V)

        return output, attention_weights


## Explanation:

Query, Key, and Value Vectors:
Each bit in the input is treated as a token and transformed into query (Q), key (K), and value (V) vectors using linear layers.

Attention Scores:
The attention scores are the dot products of Q and K, scaled by the square root of the query/key dimension (which equals the model dimension, 2, in this example) to keep gradients stable.

Attention Weights:
The softmax function is applied to the scores to obtain the attention weights, which determine the importance of each token.

Weighted Sum:
The final output is the weighted sum of the value vectors, where the weights are the attention weights from the softmax step.
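The four steps above can be traced by hand on a single XOR input. This is a minimal sketch, not the trained model: the projection matrices `W_q`, `W_k`, and `W_v` are hypothetical values picked for readability, and plain matrix products stand in for the bias-free linear layers.

```python
import torch
import torch.nn.functional as F

# One XOR input, [0, 1], as a sequence of two 1-dimensional tokens.
x = torch.tensor([[0.0], [1.0]])            # shape (2, 1)

# Hypothetical projection weights, chosen by hand for illustration only.
W_q = torch.tensor([[1.0, 0.0]])            # (1, d_qk)
W_k = torch.tensor([[2.0, 0.0]])            # (1, d_qk)
W_v = torch.tensor([[1.0, -1.0]])           # (1, d_model)

# Step 1: query, key, and value vectors for each token.
Q, K, V = x @ W_q, x @ W_k, x @ W_v         # each (2, 2)

# Step 2: scaled dot-product scores between every pair of tokens.
scores = Q @ K.T / (Q.size(-1) ** 0.5)      # (2, 2)

# Step 3: softmax over each row gives the attention weights.
weights = F.softmax(scores, dim=-1)

# Step 4: weighted sum of the value vectors.
output = weights @ V

print("attention weights:\n", weights)
print("output:\n", output)
```

Because the projections have no bias, a token with value 0 has a zero query, so all of its scores are 0 and it attends uniformly (0.5, 0.5) to both tokens. Its output is therefore half of the other token's value vector.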

## XOR Attention Model
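The model flattens the self-attention output, applies a tanh nonlinearity, and passes the result through two fully connected layers: fc1 reduces it to a 2D representation and fc2 maps that to the XOR prediction.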

In [29]:
class XORAttentionModel(nn.Module):
    def __init__(self):
        super(XORAttentionModel, self).__init__()
        d_model = 2
        self.attention = SelfAttention(d_model=d_model)
        self.fc1 = nn.Linear(2 * d_model, 2)
        self.fc2 = nn.Linear(2, 1)
        # Attention parameters are trainable here. Set requires_grad to False to
        # freeze them and check whether learned attention is what makes the difference.
        for param in self.attention.parameters():
            param.requires_grad = True

    def forward(self, x):
        attn_output, attention_weights = self.attention(x)
        attn_output = attn_output.view(attn_output.size(0), -1)
        attn_output = torch.tanh(attn_output)  # Nonlinearity
        fc1_output = self.fc1(attn_output)  # Project to a 2D representation
        output = self.fc2(fc1_output)
        return output, attention_weights, fc1_output


## Explanation:
Input Tokens:
Each bit of the XOR input is modeled as a token and passed through the self-attention layer.

Feature Representation:
The attention layer produces a feature representation by considering the importance of each bit in the input.

Fully Connected Layers:
The flattened attention output goes through a tanh nonlinearity and two fully connected layers: the first yields a 2D representation, and the second produces the final prediction.
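The data flow described above can be checked by following tensor shapes. The sketch below is an assumption-laden stand-in: a random tensor replaces the real attention output, and the layers are freshly initialized rather than trained, so only the shapes are meaningful.

```python
import torch
import torch.nn as nn

batch, n_tokens, d_model = 4, 2, 2

# Stand-in for the attention output: one d_model-dimensional vector per token.
attn_output = torch.randn(batch, n_tokens, d_model)   # (4, 2, 2)

# Flatten the two token vectors of each sample into one feature vector.
flat = attn_output.view(batch, -1)                    # (4, 4)
hidden = torch.tanh(flat)                             # (4, 4)

# Untrained layers with the same sizes as fc1 and fc2 in the model.
fc1 = nn.Linear(n_tokens * d_model, 2)
fc2 = nn.Linear(2, 1)

two_d = fc1(hidden)                                   # (4, 2)
pred = fc2(two_d)                                     # (4, 1)

print(flat.shape, two_d.shape, pred.shape)
```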

Training the Model:
The model is trained on the four XOR points, with targets encoded as -1 and +1, using Mean Squared Error (MSE) loss and the Adam optimizer.
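Since the targets are -1 and +1, a real-valued prediction can be turned into a class label by taking its sign. The sketch below shows this alongside the MSE computation; `example_preds` holds made-up values for illustration, not outputs of the trained model.

```python
import torch

# XOR targets in the -1/+1 encoding used for training.
targets = torch.tensor([[-1.0], [1.0], [1.0], [-1.0]])

# Hypothetical raw predictions, invented for this example.
example_preds = torch.tensor([[-0.9], [1.1], [0.8], [-1.2]])

# Mean squared error, equivalent to nn.MSELoss() with default settings.
mse = torch.mean((example_preds - targets) ** 2)

# Threshold at zero: the sign of the prediction is the predicted class.
predicted_labels = torch.sign(example_preds)
accuracy = (predicted_labels == targets).float().mean().item()

print(f"MSE: {mse.item():.4f}, accuracy: {accuracy:.2f}")
```

A low MSE implies correct signs, but the reverse does not hold, so it can be useful to report both.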

In [30]:
# XOR input and output
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
outputs = np.array([[-1], [1], [1], [-1]], dtype=np.float32)

# Convert to PyTorch tensors
inputs = torch.tensor(inputs).unsqueeze(-1)
outputs = torch.tensor(outputs)

# Model, loss function, and optimizer
model = XORAttentionModel()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Training loop
for epoch in range(2000):
    model.train()
    optimizer.zero_grad()
    preds, _, _ = model(inputs)
    loss = criterion(preds, outputs)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# Evaluate and print the transformed representations and attention weights
model.eval()
with torch.no_grad():
    # The third return value is the 2D output of fc1 (attention -> tanh -> fc1)
    final_preds, attention_weights, fc1_outputs = model(inputs)
print("Output: \n", final_preds)
print("Transformed representations at the output of the attention block:")
for i, x in enumerate(inputs):
    print(f"Input: {x.squeeze().tolist()}, Transformed Representation: {fc1_outputs[i].squeeze().tolist()}")

print("Attention Weights for each input:")
for i, x in enumerate(inputs):
    print(f"Input: {x.squeeze().tolist()}, Attention Weights: {attention_weights[i].squeeze().tolist()}")


Epoch 100, Loss: 0.0001357378641841933
Epoch 200, Loss: 1.5329132452279737e-07
Epoch 300, Loss: 2.0622257579816505e-05
Epoch 400, Loss: 1.0135195793736784e-07
Epoch 500, Loss: 9.096464737012866e-08
Epoch 600, Loss: 0.0028985911048948765
Epoch 700, Loss: 1.0409092965346645e-07
Epoch 800, Loss: 5.261746593987482e-10
Epoch 900, Loss: 0.003003119956701994
Epoch 1000, Loss: 8.488447633681062e-08
Epoch 1100, Loss: 6.706748933993367e-09
Epoch 1200, Loss: 0.0020726402290165424
Epoch 1300, Loss: 2.0104394593545294e-07
Epoch 1400, Loss: 1.5863045632613648e-07
Epoch 1500, Loss: 3.0250492272898555e-05
Epoch 1600, Loss: 1.3077320772936218e-06
Epoch 1700, Loss: 5.18033857588307e-07
Epoch 1800, Loss: 0.01752374693751335
Epoch 1900, Loss: 2.225025639290834e-07
Epoch 2000, Loss: 1.4997780795056315e-11
Output: 
 tensor([[-1.0000],
        [ 1.0000],
        [ 1.0000],
        [-1.0000]])
Transformed representations at the output of the attention block:
Input: [0.0, 0.0], Transformed Representation: [-0.

## Conclusions

In this tutorial, we implemented a self-attention layer in PyTorch and used it to solve the 2D XOR problem. Modeling each input bit as a token lets the attention mechanism mix information between the two bits, and the resulting features are enough for a small fully connected head to fit XOR to near-zero loss. The same mechanism, at much larger scale, underlies transformers in natural language processing and beyond.






