<a href="https://colab.research.google.com/github/noobylub/Computational-Linguistic/blob/master/AttentionAsKernelRegression_(with_TODOs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
from sklearn import datasets, linear_model
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import mean_squared_error, r2_score
import torch
import torch.optim as optim
import torch.nn as nn

# Load the diabetes dataset

# Ten baseline variables, age, sex, body mass index, average blood pressure,
# and six blood serum measurements were obtained for each of n = 442 diabetes
# patients, as well as the response of interest, a quantitative measure of
# disease progression one year after baseline.
# Note: Each of these 10 feature variables have been mean centered and scaled by
# the standard deviation times the square root of n_samples (i.e. the sum of
# squares of each column totals 1).

# DN note: scaling by 1/sqrt(n)*std instead of 1/std works better with
# regularised regression.

X, y = datasets.load_diabetes(return_X_y=True)
X.shape

(442, 10)

### A good baseline: vanilla OLS linear regression

In [2]:
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

ols_y_pred = None

# TODO: train a linear-regression model on X_train + y_train
# and predict on X_test

lr_instance = linear_model.LinearRegression()
lr_instance.fit(X_train, y_train)

ols_y_pred = lr_instance.predict(X_test)

print("Mean squared error: %.2f" % mean_squared_error(y_test, ols_y_pred))
# The coefficient of determination: 1 is perfect prediction
print("Coefficient of determination: %.2f" % r2_score(y_test, ols_y_pred))

Mean squared error: 2900.19
Coefficient of determination: 0.45


### A dumb baseline: 1-nearest-neighbour regressor

In [3]:
def one_nearest_neighbor_regressor(X_train, y_train, X_test):
    y_pred = []

    # Calculate Euclidean distances from all elements of X_test
    # to all training samples.
    # Find the index of the closest training sample for each.
    # Predict the value from y_train corresponding to the closest sample
    # in X_train.

    for row_i in range(X_test.shape[0]):
        x_test_sample = X_test[row_i]
        vector_differences = X_train - x_test_sample
        euclidean_distances = np.square(vector_differences).sum(axis=1)
        y_pred.append(
            y_train[np.argmin(euclidean_distances)]
        )

    return y_pred

ols_1nn_y_pred = one_nearest_neighbor_regressor(X_train, y_train, X_test)

print("Mean squared error (1-NN): %.2f" % mean_squared_error(y_test, ols_1nn_y_pred))
print("Coefficient of determination (1-NN): %.2f" % r2_score(y_test, ols_1nn_y_pred))

Mean squared error (1-NN): 5191.24
Coefficient of determination (1-NN): 0.02


### Reducing variance: interpolate between k nearest neigbhours

In [4]:
def k_nearest_neighbor_regressor(X_train, y_train, X_test, k=3):
    y_pred = []

    # Calculate Euclidean distance to all training samples
    # Get the indices of the k closest training samples
    # Predict the average of y_train values corresponding to the k closest
    # samples in X_train

    for row_i in range(X_test.shape[0]):
        x_test_sample = X_test[row_i]
        vector_differences = X_train - x_test_sample
        euclidean_distances = np.sum(np.abs(vector_differences), axis=1)
        # euclidean_distances = np.square(vector_differences).sum(axis=1)
        # euclidean_distances = np.linalg.norm(X_train - x_test_sample, axis=1)
        closest_indices = np.argsort(euclidean_distances)[:k]
        y_pred.append(np.mean(y_train[closest_indices]))

    return y_pred

# To select k properly we need a dev set/cross-validation, but this will do
# for comparison purposes
best_knn_mse = float('inf')
best_knn_r2 = 0.0
best_k = None
for k in [3, 5, 10, 15]:
    knn_y_pred = k_nearest_neighbor_regressor(X_train, y_train, X_test, k=k)
    k_mse = mean_squared_error(y_test, knn_y_pred)
    print(f"Mean squared error ({k}-NN): {k_mse:2f}")
    k_r2 = r2_score(y_test, knn_y_pred)
    print(f"Coefficient of determination ({k}-NN): {k_r2:2f}")
    if k_mse < best_knn_mse:
        best_knn_mse = k_mse
        best_knn_r2 = k_r2
        best_k = k
    print()

Mean squared error (3-NN): 3344.272160
Coefficient of determination (3-NN): 0.368785

Mean squared error (5-NN): 2925.796404
Coefficient of determination (5-NN): 0.447770

Mean squared error (10-NN): 3050.318202
Coefficient of determination (10-NN): 0.424267

Mean squared error (15-NN): 3130.723146
Coefficient of determination (15-NN): 0.409091



### Weighted interpolation: weight all values in the training set by their distance to the test sample in the feature space, using dot products as distances

In [5]:
import numpy as np
from scipy.special import softmax

def dot_product_softmax_regressor(X_train, y_train, X_test):
    y_pred = []

    # Similar to k_nearest_neighbor_regressor, but with dot product
    # instead of Euclidean distances, passed through softmax to get weights,
    # and with final result being
    # a weighted average of all y_train elements.

    for row_i in range(X_test.shape[0]):
        x_test_sample = X_test[row_i]
        vector_differences = X_train - x_test_sample
        euclidean_distances = np.square(vector_differences).sum(axis=1)
        weights = softmax(-1 * euclidean_distances)
        y_pred.append(np.dot(weights, y_train))
    return y_pred

dot_softmax_y_pred = dot_product_softmax_regressor(X_train, y_train, X_test)

print("Mean squared error (Dot Product Softmax Regressor): %.2f" % mean_squared_error(y_test, dot_softmax_y_pred))
print("Coefficient of determination (Dot Product Softmax Regressor): %.2f" % r2_score(y_test, dot_softmax_y_pred))

Mean squared error (Dot Product Softmax Regressor): 5285.03
Coefficient of determination (Dot Product Softmax Regressor): 0.00


### Some preliminary work for training on the GPU

In [6]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Convert NumPy arrays to PyTorch tensors
# Ensure X and X_test are float32, and y, y_train, y_test are float32 and
# reshaped to (-1, 1)
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).reshape(-1, 1)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).reshape(-1, 1)

# Split the original training data into new training and validation sets
X_train_new, X_val, y_train_new, y_val = train_test_split(
    X_train_tensor, y_train_tensor, test_size=0.2, random_state=42
)

# Move tensors to GPU
X_train_new = X_train_new.to(device)
y_train_new = y_train_new.to(device)
X_val = X_val.to(device)
y_val = y_val.to(device)
X_test_tensor = X_test_tensor.to(device)
y_test_tensor = y_test_tensor.to(device)
X_train_tensor = X_train_tensor.to(device)
y_train_tensor = y_train_tensor.to(device)


train_indices = torch.arange(X_train_new.shape[0]).to(device)
train_dataset = TensorDataset(X_train_new, y_train_new, train_indices)
val_dataset = TensorDataset(X_val, y_val)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

# Create DataLoader objects to conveniently iterate over batches
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"New training set size: {len(train_dataset)} samples")
print(f"Validation set size: {len(val_dataset)} samples")
print(f"Test set size: {len(test_dataset)} samples")

Using device: cuda
New training set size: 282 samples
Validation set size: 71 samples
Test set size: 89 samples


### Let's make the distance calculation more general by using a bilinear form

In [7]:
class ProjectionSoftmaxRegressor(nn.Module):
    def __init__(self, num_features):
        super(ProjectionSoftmaxRegressor, self).__init__()
        # Initialize a learnable parameter W of shape (num_features, num_features)
        # num_features by num_features parameter, a lot more parameters than linear regression
        # This can cause overfitting
        self.A = nn.Parameter(torch.randn(num_features, num_features))
    # y_train_full already give the answers
    def forward(self, X_query, X_train_full, y_train_full, query_indices=None):
        # X_query: (batch_size, num_features)
        # X_train_full.T: (num_features, num_training_samples)

        # W: (num_features, num_features)
        W = self.A.T @ self.A
        # This ensures that W is symmetric and positive semi-definite,
        # so we are learning a proper metric.

        # Calculate scores: X_query @ W @ X_train_full.T
        # scores: (batch_size, num_training_samples)
        # VERY IMPORTANT TO ATTENTION ----------
        # Similarity score, weighted dot product
        # Dot product is unnormalised correlation
        # This is because similar vectors have high number, this is because of the sign
        # You are adding a product of all the numbers within that vector, and if it has a different sign
        # W esseentially normalises the dot product to prepare a dot product with X_train_full
        scores = X_query @ W @ X_train_full.T

        # TODO: Prevent DATA LEAKAGE
        # Data leakage leads to overfitting
        if query_indices is not None:
            pass

        # Apply softmax along the dimension corresponding to training samples
        weights = torch.softmax(scores, dim=-1)

        # Compute predictions as a weighted sum of y_train_full
        # weights: (batch_size, num_training_samples)
        # y_train_full: (num_training_samples, 1)
        # predictions: (batch_size, 1)
        predictions = weights @ y_train_full
        return predictions

In [8]:
num_features = X_train_new.shape[1]
model = ProjectionSoftmaxRegressor(num_features)
model = model.to(device)

In [9]:
# Mean squared error, similar to euclidian distance, but take the mean, instead of sum
criterion = nn.MSELoss()
# Performs backward pass
optimizer = optim.AdamW(model.parameters(), lr=0.001)

# Early stopping parameters
patience = 10
min_val_loss = float('inf')
trigger_times = 0

# Set a reasonably high number of epochs, early stopping will prevent overfitting
epochs = 20000

# Store training and validation losses for plotting/analysis
train_losses = []
val_losses = []

print("Starting model training...")

for epoch in range(epochs):
    # Set model to training mode: this changes some behaviours
    model.train()
    current_train_loss = 0.0
    for batch_X, batch_y, _ in train_loader:
        # 4. Forward pass, loss calculation, backward pass, and optimizer step
        # Reset the entire gradient set
        optimizer.zero_grad()
        # X_train_new and y_train_new are full training data
        # Is there a potential problem here?
        # How can we avoid it?
        predictions = model(batch_X, X_train_new, y_train_new)
        # Retrieve the gradient
        loss = criterion(predictions, batch_y)
        loss.backward()
        # Implement
        optimizer.step()
        current_train_loss += loss.item()

    avg_train_loss = current_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # 5. Evaluate the model on the val_loader
    model.eval() # Set model to evaluation mode
    current_val_loss = 0.0
    # Torch.no_grad means you are not calculating the gradient,
    with torch.no_grad(): # Disable gradient calculation for validation
        for batch_X_val, batch_y_val in val_loader:
            predictions_val = model(batch_X_val, X_train_new, y_train_new)
            val_loss = criterion(predictions_val, batch_y_val)
            current_val_loss += val_loss.item()

    avg_val_loss = current_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    if (epoch + 1) % 250 == 0:
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

    # 6. Implement early stopping
    if avg_val_loss < min_val_loss:
        min_val_loss = avg_val_loss
        trigger_times = 0
        # Save the best model weights
        torch.save(model.state_dict(), 'best_pretrained_softmax_regressor_model.pth')
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print(f'Early stopping triggered after {epoch+1} epochs due to no improvement in validation loss.')
            break

print("Model training complete. Best model saved.")

Starting model training...
Early stopping triggered after 175 epochs due to no improvement in validation loss.
Model training complete. Best model saved.


### The story so far

In [10]:
model.load_state_dict(torch.load('best_pretrained_softmax_regressor_model.pth'))
model.eval()

all_predictions = []
all_targets = []

print("Evaluating model on the test set...")

with torch.no_grad():
    for batch_X_test, batch_y_test in test_loader:
        predictions_test = model(batch_X_test, X_train_tensor, y_train_tensor)
        all_predictions.append(predictions_test.cpu().numpy())
        all_targets.append(batch_y_test.cpu().numpy())

test_predictions = np.vstack(all_predictions)
test_targets = np.vstack(all_targets)

mse = mean_squared_error(test_targets, test_predictions)
r2 = r2_score(test_targets, test_predictions)

print("\n--- Comparison with previous models ---")
print(f"OLS Regressor - MSE: {mean_squared_error(y_test, ols_y_pred):.2f}, R2: {r2_score(y_test, ols_y_pred):.2f}")
print(f"1-NN Regressor - MSE: {mean_squared_error(y_test, ols_1nn_y_pred):.2f}, R2: {r2_score(y_test, ols_1nn_y_pred):.2f}")
print(f"{best_k}-NN Regressor - MSE: {best_knn_mse:.2f}, R2: {best_knn_r2:.2f}")
print(f"Vanilla Dot Product Softmax Regressor - MSE: {mean_squared_error(y_test, dot_softmax_y_pred):.2f}, R2: {r2_score(y_test, dot_softmax_y_pred):.2f}")
print(f"PyTorch ProjectionSoftmaxRegressor (Trained) - MSE: {mse:.2f}, R2: {r2:.2f}")

Evaluating model on the test set...

--- Comparison with previous models ---
OLS Regressor - MSE: 2900.19, R2: 0.45
1-NN Regressor - MSE: 5191.24, R2: 0.02
5-NN Regressor - MSE: 2925.80, R2: 0.45
Vanilla Dot Product Softmax Regressor - MSE: 5285.03, R2: 0.00
PyTorch ProjectionSoftmaxRegressor (Trained) - MSE: 2840.28, R2: 0.46


### (Almost) standard attention

In [11]:
class AttentionRegressor(nn.Module):
    def __init__(self, num_features, projection_dim=64):
        super(AttentionRegressor, self).__init__()
        self.projection_dim = projection_dim

        # Initialize two learnable parameter matrices for query and train projections
        self.W_query_proj = nn.Parameter(torch.randn(num_features, projection_dim))
        self.W_train_proj = nn.Parameter(torch.randn(num_features, projection_dim))

    def forward(self, X_query, X_train_full, y_train_full, query_indices=None):
        # Project X_query and X_train_full using their respective projection matrices
        # X_query: (batch_size, num_features) -> X_query_projected: (batch_size, projection_dim)
        # X_train_full: (num_training_samples, num_features) -> X_train_projected: (num_training_samples, projection_dim)
        X_query_projected = X_query @ self.W_query_proj
        X_train_projected = X_train_full @ self.W_train_proj

        # Let's use more standard names and wrap things up.
        Q = X_query @ self.W_query_proj
        # Measure of relevance
        K = X_train_full @ self.W_train_proj

        V = y_train_full

        # Calculate scores: Q @ K.T
        # When we multiply the matrix, we essentially multiply below
        # But for the second part, because we are transposing
        # X_query @ self.W_query_proj @ X_train_full @ self.W_train_proj
        scores = Q @ K.T
        weights = torch.softmax(scores, dim=-1)
        predictions = None

        # TODO: Prevent DATA LEAKAGE
        # Data leakage leads to overfitting
        if query_indices is not None:
            pass
        # Compute predictions as a weighted sum of y_train_full
        # weights: (batch_size, num_training_samples)
        # y_train_full: (num_training_samples, 1)
        # predictions: (batch_size, 1)
        predictions = weights @ y_train_full
        return predictions

        # return predictions

In [14]:
# 1. Model instantiation
num_features = X_train_new.shape[1]
# We project inputs to a higher-dimensional space because relations between
# data-points may be more transparent there
projection_dim = 32
model_projected = AttentionRegressor(
    num_features,
    projection_dim=projection_dim)
model_projected.to(device)

# 2. Definition of a loss function and an optimizer
criterion_projected = nn.MSELoss()
optimizer_projected = optim.AdamW(model_projected.parameters(), lr=0.001)

# Early stopping parameters
patience = 10  # Same patience as before
min_val_loss_projected = float('inf')
trigger_times_projected = 0

epochs = 20000

# Store training and validation losses for plotting/analysis
train_losses_projected = []
val_losses_projected = []

print("Starting AttentionRegressor model training...")

for epoch in range(epochs):
    model_projected.train() # Set model to training mode
    current_train_loss_projected = 0.0
    for batch_X,batch_y,_ in train_loader:
        # 4. Forward pass, loss calculation, backward pass, and optimizer step
        optimizer_projected.zero_grad()
        # X_train_tensor and y_train_tensor are the full original training data, required by the model
        predictions = model_projected(batch_X, X_train_tensor, y_train_tensor)
        loss = criterion_projected(predictions, batch_y)
        loss.backward()
        optimizer_projected.step()
        current_train_loss_projected += loss.item()

    avg_train_loss_projected = current_train_loss_projected / len(train_loader)
    train_losses_projected.append(avg_train_loss_projected)

    # 5. Evaluate the model on the val_loader
    model_projected.eval() # Set model to evaluation mode
    current_val_loss_projected = 0.0
    with torch.no_grad(): # Disable gradient calculation for validation
        for batch_X_val, batch_y_val in val_loader:
            predictions_val = model_projected(batch_X_val, X_train_tensor, y_train_tensor)
            val_loss = criterion_projected(predictions_val, batch_y_val)
            current_val_loss_projected += val_loss.item()

    avg_val_loss_projected = current_val_loss_projected / len(val_loader)
    val_losses_projected.append(avg_val_loss_projected)

    if (epoch + 1) % 250 == 0:
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss_projected:.4f}, Val Loss: {avg_val_loss_projected:.4f}')

    # 6. Implement early stopping
    if avg_val_loss_projected < min_val_loss_projected:
        min_val_loss_projected = avg_val_loss_projected
        trigger_times_projected = 0
        # Save the best model weights
        torch.save(model_projected.state_dict(), 'best_projected_softmax_regressor_model.pth')
    else:
        trigger_times_projected += 1
        if trigger_times_projected >= patience:
            print(f'Early stopping triggered after {epoch+1} epochs due to no improvement in validation loss.')
            break

print("AttentionRegressor model training complete. Best model saved.")

Starting AttentionRegressor model training...
Early stopping triggered after 101 epochs due to no improvement in validation loss.
AttentionRegressor model training complete. Best model saved.


In [16]:
from sklearn.metrics import mean_squared_error, r2_score

model_projected.load_state_dict(torch.load('best_projected_softmax_regressor_model.pth'))
model_projected.eval()

all_predictions_projected = []
all_targets_projected = []

print("Evaluating AttentionRegressor on the test set...")

with torch.no_grad():
    for batch_X_test, batch_y_test in test_loader:
        predictions_test_projected = model_projected(batch_X_test, X_train_tensor, y_train_tensor)
        all_predictions_projected.append(predictions_test_projected.cpu().numpy())
        all_targets_projected.append(batch_y_test.cpu().numpy())

test_predictions_projected = np.vstack(all_predictions_projected)
test_targets_projected = np.vstack(all_targets_projected)
mse_projected = mean_squared_error(test_targets_projected, test_predictions_projected)
r2_projected = r2_score(test_targets_projected, test_predictions_projected)

print(f"Mean Squared Error (PyTorch AttentionRegressor on Test Set): {mse_projected:.2f}")
print(f"Coefficient of Determination (PyTorch AttentionRegressor on Test Set): {r2_projected:.4f}")

Evaluating AttentionRegressor on the test set...
Mean Squared Error (PyTorch AttentionRegressor on Test Set): 2742.54
Coefficient of Determination (PyTorch AttentionRegressor on Test Set): 0.4824


## Why ‘almost’?

In this exercise, we interpolated between different values of **y** based on pairwise distances between corresponding rows of X.

In real attention, we interpolate the rows of X themselves. They correspond to input tokens and are ordered. Our goal is to find the most useful/informative linear combination of their values in order to predict something -- usually the next token in the sequence.