In [5]:
import torch
import torch.nn as nn
import torch.optim as optim

In [6]:
class RawAttention(nn.Module):
    """Single-head self-attention over two scalar tokens, without softmax or scaling.

    Each input row (x1, x2) is treated as a sequence of two width-1 tokens. The
    query, key and value projections are bias-free 1x1 linear maps, i.e. single
    scalar weights wq, wk, wv. Because the raw dot-product scores are used
    directly (no softmax, no 1/sqrt(d)), each output is a polynomial in the inputs:

        attended_i = (wq * wk * wv) * x_i * (x1**2 + x2**2)

    Input:  x of shape (batch, 2)
    Output: tensor of shape (batch, 2)
    """

    def __init__(self):
        super().__init__()
        # Construction order (q, k, v) determines which random draws each weight
        # receives under a fixed seed, so keep it stable.
        self.Wq = nn.Linear(1, 1, bias=False)
        self.Wk = nn.Linear(1, 1, bias=False)
        self.Wv = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        # Split the two features into two width-1 tokens: (batch, 2) -> (batch, 2, 1).
        tokens = torch.stack([x[:, col:col + 1] for col in (0, 1)], dim=1)

        queries, keys, values = self.Wq(tokens), self.Wk(tokens), self.Wv(tokens)

        # Unnormalized token-to-token scores, shape (batch, 2, 2); deliberately no softmax.
        scores = queries @ keys.transpose(1, 2)

        # Score-weighted sum of values is (batch, 2, 1); drop the width-1 axis -> (batch, 2).
        return (scores @ values).squeeze(-1)

class RawAttentionWithTopNeuron(nn.Module):
    """RawAttention with a residual connection, read out by a single linear neuron.

    output = attn_out(RawAttention(x) + x), where attn_out is Linear(2 -> 1) with bias.
    With C = wq * wk * wv this is
        (w1 * x1 + w2 * x2) * (C * (x1**2 + x2**2) + 1) + b

    Input:  x of shape (batch, 2)
    Output: tensor of shape (batch, 1)
    """

    def __init__(self):
        super().__init__()
        # Attention weights are created before the readout so seeded init stays stable.
        self.raw_attention = RawAttention()
        # Readout neuron. The bias is required for NAND: at input (0, 0) the neuron's
        # input is all zeros, so the output there is exactly the bias, which must be 1.
        self.attn_out = nn.Linear(2, 1)

    def forward(self, x):
        # Residual: add the raw inputs back onto the attention output before the readout.
        residual_sum = self.raw_attention(x) + x
        return self.attn_out(residual_sum)


In [7]:
torch.manual_seed(0)

# -----------------------------
# Dataset: XOR
# -----------------------------
# All four 2-bit inputs in truth-table order: (0,0), (0,1), (1,0), (1,1).
bit_values = torch.tensor([0., 1.])
X = torch.cartesian_prod(bit_values, bit_values)

# Target column of shape (4, 1): 1 exactly when the two inputs differ.
y = torch.logical_xor(X[:, 0], X[:, 1]).float().unsqueeze(1)

In [None]:
torch.manual_seed(0)

# -----------------------------
# Dataset: NAND
# -----------------------------
# All four 2-bit inputs in truth-table order: (0,0), (0,1), (1,0), (1,1).
bit_values = torch.tensor([0., 1.])
X = torch.cartesian_prod(bit_values, bit_values)

# Target column of shape (4, 1): 0 only when both inputs are 1.
y = torch.logical_not(torch.logical_and(X[:, 0], X[:, 1])).float().unsqueeze(1)

In [8]:
# Train RawAttentionWithTopNeuron on (X, y) with AdamW, full batch.
# Notes from earlier experiments:
#   - Alternatives considered: optim.Adam / optim.SGD with lr=0.05.
#   - lr=1e-3 did not work for this problem.
#   - ExponentialLR(gamma=0.9995) helped, but since the loss only fluctuates at the
#     ~0.00x level it may not be worth it.
NUM_EPOCHS = 8000
LOG_EVERY = 50

# Re-seed here so re-running this cell alone reproduces the same initialization,
# independent of which cells ran before it.
torch.manual_seed(0)
model = RawAttentionWithTopNeuron()

# NOTE: AdamW also applies decoupled weight decay (PyTorch default weight_decay=0.01).
optimizer = optim.AdamW(model.parameters(), lr=5e-2)

loss_fn = nn.MSELoss()
loss_history = []  # loss per epoch, recorded before that epoch's update
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    loss = loss_fn(model(X), y)
    loss_history.append(loss.item())
    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch}, loss = {loss.item():.4f}")


Epoch 0, loss = 0.8390
Epoch 50, loss = 0.1130
Epoch 100, loss = 0.0257
Epoch 150, loss = 0.0036
Epoch 200, loss = 0.0005
Epoch 250, loss = 0.0001
Epoch 300, loss = 0.0001
Epoch 350, loss = 0.0000
Epoch 400, loss = 0.0000
Epoch 450, loss = 0.0000
Epoch 500, loss = 0.0001
Epoch 550, loss = 0.0000
Epoch 600, loss = 0.0000
Epoch 650, loss = 0.0000
Epoch 700, loss = 0.0005
Epoch 750, loss = 0.0001
Epoch 800, loss = 0.0000
Epoch 850, loss = 0.0000
Epoch 900, loss = 0.0000
Epoch 950, loss = 0.0004
Epoch 1000, loss = 0.0000
Epoch 1050, loss = 0.0000
Epoch 1100, loss = 0.0000
Epoch 1150, loss = 0.0000
Epoch 1200, loss = 0.0000
Epoch 1250, loss = 0.0001
Epoch 1300, loss = 0.0000
Epoch 1350, loss = 0.0000
Epoch 1400, loss = 0.0000
Epoch 1450, loss = 0.0000
Epoch 1500, loss = 0.0000
Epoch 1550, loss = 0.0000
Epoch 1600, loss = 0.0000
Epoch 1650, loss = 0.0000
Epoch 1700, loss = 0.0000
Epoch 1750, loss = 0.0032
Epoch 1800, loss = 0.0000
Epoch 1850, loss = 0.0000
Epoch 1900, loss = 0.0000
Epoch 195

In [None]:
# Same model and loss as the AdamW run, optimized with L-BFGS instead.
# Each outer step runs up to `max_iter` inner iterations, each calling `closure`.
NUM_EPOCHS = 8000
LOG_EVERY = 50

# Re-seed so this run's initialization doesn't depend on which cells ran before it
# (otherwise the RNG state left by earlier model constructions leaks in).
torch.manual_seed(0)
model = RawAttentionWithTopNeuron()
optimizer = torch.optim.LBFGS(
    model.parameters(),
    lr=0.05,
    max_iter=20,
    history_size=100,
)

loss_fn = nn.MSELoss()
loss_history = []


def closure():
    """Recompute the loss and its gradients; L-BFGS may call this several times per step."""
    optimizer.zero_grad()
    loss = loss_fn(model(X), y)
    loss.backward()
    return loss


for epoch in range(NUM_EPOCHS):
    # LBFGS.step returns the loss from its first closure evaluation (before this step's update).
    loss = optimizer.step(closure)
    loss_history.append(loss.item())

    if epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch}, loss = {loss.item():.4f}")


In [None]:
# Newton's method applied to the scalar loss itself, targeting L(θ) = 0
# (MSE is >= 0 and the model can fit these targets exactly).
# A scalar Newton step would be Δθ = L / L'. The gradient g is a vector, so we can't
# divide by it; multiply and divide by g instead:
#     θ <- θ - (L / ||g||²) · g
# This is the Polyak step size with optimal value f* = 0.
NUM_EPOCHS = 800
LOG_EVERY = 1          # print every epoch
GRAD_NORM_SQ_EPS = 1e-12  # below this the step L/||g||² would blow up, so stop

# Re-seed so this run's initialization doesn't depend on which cells ran before it.
torch.manual_seed(0)
model = RawAttentionWithTopNeuron()

loss_fn = nn.MSELoss()
loss_history = []

for epoch in range(NUM_EPOCHS):
    model.zero_grad()
    loss = loss_fn(model(X), y)
    loss_history.append(loss.item())
    loss.backward()

    with torch.no_grad():
        # ||g||² over ALL parameters that received a gradient.
        grads = [p.grad for p in model.parameters() if p.grad is not None]
        grad_norm_sq = sum((g ** 2).sum() for g in grads)

        if grad_norm_sq < GRAD_NORM_SQ_EPS:
            print("Gradient vanished — stopping.")
            break

        step_scale = loss / grad_norm_sq
        for p in model.parameters():
            if p.grad is not None:
                p -= step_scale * p.grad

    if epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch}, loss = {loss.item():.6f}")


In [None]:
import numpy as np

# Report how often the loss went back up after it first reached THRESHOLD.
THRESHOLD = 0.0001  # "converged" once the loss is at or below this
MAX_SHOWN = 10      # how many individual increases to list

losses = np.asarray(loss_history)

# First epoch whose loss is at or below THRESHOLD.
below_idx = np.where(losses <= THRESHOLD)[0]

if len(below_idx) == 0:
    print("Loss never reached threshold.")
else:
    start = below_idx[0]
    # increases[i] marks epoch start+i+1 having a higher loss than epoch start+i.
    increases = np.where(np.diff(losses[start:]) > 0)[0]

    if len(increases) == 0:
        print("No loss increases after convergence.")
    else:
        print(f"Loss increased {len(increases)} times after reaching {THRESHOLD}.")
        for i in increases[:MAX_SHOWN]:
            e = start + i + 1
            # Scientific notation: sub-threshold losses would all print as 0.000000 with .6f.
            print(f"Epoch {e}: {losses[e-1]:.3e} → {losses[e]:.3e}")
        if len(increases) > MAX_SHOWN:
            print(f"... ({len(increases) - MAX_SHOWN} more not shown)")

# Always report the final loss (useful whether or not the threshold was reached).
if losses.size:
    print(f"Loss at last {losses[-1]:.12f}")

In [None]:
import matplotlib.pyplot as plt

# Skip the initial rapid descent so late-training fluctuations are visible.
PLOT_START = 400

# Clamp so a short history (e.g. an early-stopped run) still plots consistently.
start = min(PLOT_START, len(loss_history))
# Real epoch numbers on the x-axis; plotting the slice alone would relabel epoch 400 as 0.
epochs = range(start, len(loss_history))

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs, loss_history[start:], label='Training Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title(f'Training Loss History (from epoch {start})')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# List every trained parameter by name; detach() gives a plain tensor for display.
print('Learned Parameters:')
for param_name, param_value in model.named_parameters():
    print(f'Parameter: {param_name}, Value: {param_value.detach()}')

### Numerical Substitution for the NAND Gate

Using the learned parameters for the NAND gate:

*   $w_q = 0.6367$
*   $w_k = 1.1251$
*   $w_v = -1.3959$
*   $w_{attn\_out,1} = 0.4634$
*   $w_{attn\_out,2} = 0.5366$
*   $b_{attn\_out} = 1.0000$

First, calculate $C = w_q w_k w_v = (0.6367) \cdot (1.1251) \cdot (-1.3959) \approx -0.99995 \approx -1.0$.

Substituting these values into the simplified formula for binary inputs $(x_1^2 = x_1, x_2^2 = x_2)$:

$output = (w_{attn\_out,1} \cdot x_1 + w_{attn\_out,2} \cdot x_2) \cdot (C \cdot (x_1 + x_2) + 1) + b_{attn\_out}$

becomes:

$output = (0.4634 \cdot x_1 + 0.5366 \cdot x_2) \cdot (-1.0 \cdot (x_1 + x_2) + 1) + 1.0$

In [None]:
# Read the scalar weights back out of the trained model.
wq = model.raw_attention.Wq.weight.item()
wk = model.raw_attention.Wk.weight.item()
wv = model.raw_attention.Wv.weight.item()
w_attn_out1 = model.attn_out.weight[0, 0].item()
w_attn_out2 = model.attn_out.weight[0, 1].item()
b_attn_out = model.attn_out.bias[0].item()
C = wq * wk * wv

# For binary inputs (x**2 == x) the model output is
#   (w1*x1 + w2*x2) * (C*(x1 + x2) + 1) + b
#   = w1*(C+1)*x1 + w2*(C+1)*x2 + C*(w1 + w2)*x1*x2 + b
coeff_x1_term = w_attn_out1 * (C + 1)
coeff_x2_term = w_attn_out2 * (C + 1)
coeff_x1x2_term = C * (w_attn_out1 + w_attn_out2)
print(coeff_x1_term, coeff_x2_term, coeff_x1x2_term, b_attn_out)
print(f"{coeff_x1_term:.4f}, {coeff_x2_term:.4f}, {coeff_x1x2_term:.4f}, {b_attn_out:.4f}")

# Sanity check: the expanded polynomial must reproduce the model on the binary inputs X.
with torch.no_grad():
    x1_col, x2_col = X[:, 0], X[:, 1]
    poly_out = (coeff_x1_term * x1_col + coeff_x2_term * x2_col
                + coeff_x1x2_term * x1_col * x2_col + b_attn_out)
    assert torch.allclose(poly_out.unsqueeze(1), model(X), atol=1e-4), \
        "expanded polynomial does not match model output"



### Numerical Substitution for the XOR Gate

Using the learned parameters for XOR:

*   $w_q = -1.0343$
*   $w_k = -0.7918$
*   $w_v = -0.6119$
*   $w_{attn\_out,1} = 1.9998$
*   $w_{attn\_out,2} = 1.9997$
*   $b_{attn\_out} = -0.0005$

First, calculate $C = w_q w_k w_v = (-1.0343) \cdot (-0.7918) \cdot (-0.6119) \approx -0.5011$.

Substituting these values into the simplified formula:

$output = (1.9998 \cdot x_1 + 1.9997 \cdot x_2) \cdot (-0.5011 \cdot (x_1^2 + x_2^2) + 1) - 0.0005$
