In [12]:
import os

import numpy as np

from data import create_weights, get_reg_data, get_nonlinear_data
from scratch_transformer import MultiHeadAttentionBlock

# Problem dimensions and optimisation constants reused by later cells.
feature_size, output_size = 10, 1
M, N = 10, 1000  # passed to the data generators as no_tasks / no_examples
lr = 1e-4

# Precomputed parameters that are later copied into the attention projections.
la_params = create_weights(feature_size, output_size, N, lr)

# Regression data kept aside for evaluation.
eval_data = get_reg_data(no_examples=N, no_tasks=M, feature_size=feature_size)

# Single-head attention block with softmax disabled and no dropout; the model
# width is the feature size plus one.
# Layout noted by the original author: (batch_size, seq_len, d_model).
mha = MultiHeadAttentionBlock(
    softmax_att=False,
    dropout=0.0,
    heads=1,
    d_model=feature_size + 1,
)

In [13]:
import torch


# Replace a layer's weights with the precomputed ones that perform GD in the
# forward pass.
def override_weights(model, new_params, w_name):
    """Copy a precomputed weight matrix into ``model.weight`` in place.

    Args:
        model: Object exposing a ``weight`` tensor parameter (for example a
            linear projection of the attention block).
        new_params: Mapping from fully qualified parameter names to mappings
            that hold the array under the key ``"w"``.
        w_name: Short name of the projection (e.g. ``"query"``); it is prefixed
            with ``"Transformer_gd/multi_head_attention/"`` to build the lookup
            key.

    Raises:
        KeyError: If ``new_params`` has no entry for the prefixed name.
        ValueError: If the stored array does not have the same shape as
            ``model.weight``.
    """
    full_name = "Transformer_gd/multi_head_attention/" + w_name
    if full_name not in new_params:
        raise KeyError(f"No parameters stored under '{full_name}'.")

    w_tensor = torch.tensor(new_params[full_name]["w"], dtype=model.weight.dtype)

    # Rebinding ``.data`` would accept any shape and silently change the layer;
    # fail loudly instead.
    if w_tensor.shape != model.weight.shape:
        raise ValueError(
            f"Shape mismatch for '{full_name}': expected "
            f"{tuple(model.weight.shape)}, got {tuple(w_tensor.shape)}."
        )

    # In-place copy keeps the parameter object (and its device) intact, so any
    # optimizer already holding a reference to it stays valid.
    with torch.no_grad():
        model.weight.copy_(w_tensor)


# Load the precomputed weights into the query, key, value and output projections.
for projection, param_key in (
    (mha.w_q, "query"),
    (mha.w_k, "key"),
    (mha.w_v, "value"),
    (mha.w_o, "linear"),
):
    override_weights(projection, la_params, param_key)

In [14]:
def compute_loss(preds, targets):
    """Return half the summed squared error divided by ``targets.shape[0]``.

    Args:
        preds: Array of predictions.
        targets: Array of reference values; its first dimension is the divisor.
    """
    residual = targets - preds
    half_sse = 0.5 * np.sum(np.square(residual))
    return half_sse / targets.shape[0]

In [15]:
# Evaluation inputs as a float tensor.
e_eval = torch.tensor(eval_data[0]).float()

# The block takes the same tensor for all three of its inputs.
out = mha(e_eval, e_eval, e_eval)

# Prediction: negated entry at index -1 along both trailing axes of the output.
# Target: last column of the second element of the evaluation data.
eval_preds = -1.0 * out[:, -1, -1]
eval_targets = eval_data[1][:, -1]

In [16]:
# Detach from autograd before handing the predictions to the NumPy-based loss.
eval_preds_np = eval_preds.detach().numpy()
loss = compute_loss(eval_preds_np, eval_targets)
print(f"Loss for M: {M}, N: {N} is {loss:.3f}.")


Loss for M: 10, N: 1000 is 0.472.


In [17]:
def train(
    model,
    optimizer,
    criterion,
    eval_data=None,
    training_steps=1000,
    linear_data=False,
    eval_every=100,
    save_dir="models",
):
    """Train ``model`` on freshly sampled tasks and track a held-out loss.

    A new batch of tasks is drawn at every step. Every ``eval_every`` steps the
    model is evaluated on ``eval_data`` with ``compute_loss``; whenever that
    loss improves, the model's ``state_dict`` is saved to
    ``<save_dir>/<attention>-<data>.pth``.

    Relies on the module-level names ``M``, ``feature_size``, ``N``,
    ``get_reg_data``, ``get_nonlinear_data`` and ``compute_loss``.

    Args:
        model: Callable taking the same input tensor three times and exposing
            ``softmax_att``, ``state_dict()``, ``training``, ``eval()`` and
            ``train()``.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss applied to ``(predictions, targets)`` during training.
        eval_data: Indexable whose element 0 holds the inputs and element 1 the
            targets. When ``None``, a set is sampled with the generator
            selected by ``linear_data``.
        training_steps: Number of optimisation steps.
        linear_data: Use ``get_reg_data`` when true, ``get_nonlinear_data``
            otherwise.
        eval_every: Evaluate (and print) every this many steps; must be >= 1.
        save_dir: Directory for checkpoints; created if missing.

    Returns:
        list: Evaluation losses in the order they were measured.

    Raises:
        ValueError: If ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be a positive integer, got {eval_every}.")

    def sample_tasks():
        # Single place that picks the generator, shared by evaluation and training.
        generator = get_reg_data if linear_data else get_nonlinear_data
        return generator(no_tasks=M, feature_size=feature_size, no_examples=N)

    eval_losses = []
    lowest_loss = float("inf")

    # Get the evaluation data if it is not provided.
    if eval_data is None:
        eval_data = sample_tasks()
    e_eval = torch.tensor(eval_data[0]).float()
    eval_targets = eval_data[1][:, -1]

    for step in range(training_steps):
        # Generate train data.
        train_data = sample_tasks()
        e_train = torch.tensor(train_data[0]).float()
        targets = torch.tensor(train_data[1][:, -1]).float()

        # Forward / backward pass. The prediction is the negated entry at index
        # -1 along both trailing axes of the model output.
        optimizer.zero_grad()
        out = model(e_train, e_train, e_train)
        preds = out[:, -1, -1] * (-1.0)
        loss = criterion(preds, targets)
        loss.backward()
        optimizer.step()

        if step % eval_every != 0:
            continue

        # Evaluate without building an autograd graph and with train-only
        # behaviour (e.g. dropout) switched off; restore the mode afterwards.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                ev_out = model(e_eval, e_eval, e_eval)
        finally:
            model.train(was_training)
        ev_preds = ev_out[:, -1, -1] * (-1.0)
        eval_loss = compute_loss(ev_preds.detach().numpy(), eval_targets)
        eval_losses.append(eval_loss)

        # Checkpoint on improvement.
        if eval_loss < lowest_loss:
            lowest_loss = eval_loss
            att = "softmax_attn" if model.softmax_att else "linear_attn"
            data_type = "lin_data" if linear_data else "nonlin_data"
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            path = os.path.join(save_dir, f"{att}-{data_type}.pth")
            torch.save(model.state_dict(), path)

        print(f"Step {step}, Train Loss: {loss.item():.3f}")
        print(f"Step {step}, Eval Loss: {eval_loss:.3f}")

    return eval_losses


In [18]:
# Now let's explore training the model
import torch.optim as optim

# Train the block on linear regression tasks with Adam and an MSE criterion.
training_steps = 1000
criterion = torch.nn.MSELoss()
optimizer = optim.Adam(mha.parameters(), lr=lr)

train(
    model=mha,
    optimizer=optimizer,
    criterion=criterion,
    eval_data=eval_data,
    training_steps=training_steps,
    linear_data=True,
)


Step 0, Train Loss: 0.340
Step 0, Eval Loss: 0.467
Step 100, Train Loss: 0.344
Step 100, Eval Loss: 0.288
Step 200, Train Loss: 0.221
Step 200, Eval Loss: 0.174
Step 300, Train Loss: 0.262
Step 300, Eval Loss: 0.100
Step 400, Train Loss: 0.238
Step 400, Eval Loss: 0.061
Step 500, Train Loss: 0.220
Step 500, Eval Loss: 0.040
Step 600, Train Loss: 0.174
Step 600, Eval Loss: 0.031
Step 700, Train Loss: 0.121
Step 700, Eval Loss: 0.025
Step 800, Train Loss: 0.064
Step 800, Eval Loss: 0.020
Step 900, Train Loss: 0.140
Step 900, Eval Loss: 0.018


In [20]:
# Same experiment, this time with data from the non-linear generator.
eval_nl_data = get_nonlinear_data(no_examples=N, no_tasks=M, feature_size=feature_size)
e_eval_nl = torch.tensor(eval_nl_data[0]).float()
eval_nl_targets = eval_nl_data[1][:, -1]

# Fresh single-head attention block without softmax.
# Layout noted by the original author: (batch_size, seq_len, d_model).
mha_nl = MultiHeadAttentionBlock(
    softmax_att=False,
    dropout=0.0,
    heads=1,
    d_model=feature_size + 1,
)

# Baseline: loss with the block's initial weights.
out_nl = mha_nl(e_eval_nl, e_eval_nl, e_eval_nl)
eval_nl_preds = -1.0 * out_nl[:, -1, -1]
loss_nl = compute_loss(eval_nl_preds.detach().numpy(), eval_nl_targets)
print(f"Loss pre override for M: {M}, N: {N} is {loss_nl:.3f}.")

# Load the precomputed weights into the query, key, value and output projections.
for projection, param_key in (
    (mha_nl.w_q, "query"),
    (mha_nl.w_k, "key"),
    (mha_nl.w_v, "value"),
    (mha_nl.w_o, "linear"),
):
    override_weights(projection, la_params, param_key)

# Loss after the override, on the same evaluation data.
out_nl = mha_nl(e_eval_nl, e_eval_nl, e_eval_nl)
eval_nl_preds = -1.0 * out_nl[:, -1, -1]
loss_nl = compute_loss(eval_nl_preds.detach().numpy(), eval_nl_targets)
print(f"Loss with GD weights for M: {M}, N: {N} is {loss_nl:.3f}.")


Loss pre override for M: 10, N: 1000 is 1919.709.
Loss with GD weights for M: 10, N: 1000 is 0.252.


In [21]:
# Train the overridden block on non-linear tasks, evaluating on eval_nl_data.
training_steps = 1000
criterion = torch.nn.MSELoss()
optimizer = optim.Adam(mha_nl.parameters(), lr=lr)

train(
    model=mha_nl,
    optimizer=optimizer,
    criterion=criterion,
    eval_data=eval_nl_data,
    training_steps=training_steps,
    linear_data=False,
)


Step 0, Train Loss: 1.010
Step 0, Eval Loss: 0.538
Step 100, Train Loss: 1.512
Step 100, Eval Loss: 0.203
Step 200, Train Loss: 1.370
Step 200, Eval Loss: 0.156
Step 300, Train Loss: 1.028
Step 300, Eval Loss: 0.468
Step 400, Train Loss: 1.237
Step 400, Eval Loss: 0.202
Step 500, Train Loss: 0.446
Step 500, Eval Loss: 0.317
Step 600, Train Loss: 1.837
Step 600, Eval Loss: 0.233
Step 700, Train Loss: 0.529
Step 700, Eval Loss: 0.172
Step 800, Train Loss: 0.763
Step 800, Eval Loss: 0.279
Step 900, Train Loss: 0.646
Step 900, Eval Loss: 0.177


In [27]:
# Last variant: a single-head block with softmax attention enabled.
# Layout noted by the original author: (batch_size, seq_len, d_model).
mha_nl_sa = MultiHeadAttentionBlock(
    softmax_att=True,
    dropout=0.0,
    heads=1,
    d_model=feature_size + 1,
)

training_steps = 1000
criterion = torch.nn.MSELoss()
optimizer = optim.Adam(mha_nl_sa.parameters(), lr=lr)

# No evaluation set is passed, so train() samples its own.
train(
    model=mha_nl_sa,
    optimizer=optimizer,
    criterion=criterion,
    eval_data=None,
    training_steps=training_steps,
    linear_data=False,
)


Step 0, Train Loss: 1.029
Step 0, Eval Loss: 0.473
Step 100, Train Loss: 1.641
Step 100, Eval Loss: 0.459
Step 200, Train Loss: 0.993
Step 200, Eval Loss: 0.449
Step 300, Train Loss: 1.458
Step 300, Eval Loss: 0.442
Step 400, Train Loss: 0.446
Step 400, Eval Loss: 0.438
Step 500, Train Loss: 1.105
Step 500, Eval Loss: 0.433
Step 600, Train Loss: 1.011
Step 600, Eval Loss: 0.430
Step 700, Train Loss: 1.001
Step 700, Eval Loss: 0.426
Step 800, Train Loss: 0.972
Step 800, Eval Loss: 0.422
Step 900, Train Loss: 0.477
Step 900, Eval Loss: 0.421
