# Self-Attention in Time Series Models

### 1. Initial Embedding

We take each time series and slice it into patches of a fixed number of time steps (32 in this example).

Each patch is then converted into an embedding vector.

View an illustrated walkthrough here: https://bitsofchris.com/p/how-to-implement-factorized-attention

In [2]:
import torch

# 4 stocks, each with 10 time patches, embedded in a 128-dimensional space
embeddings = torch.randn(4, 10, 128)  # Shape: [4, 10, 128]
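
The random tensor above stands in for real patch embeddings. As a minimal sketch of how such embeddings might be produced, assuming non-overlapping patches and a single linear layer as the embedding (the series length of 320 and the names `series`, `patch_embed`, and `patch_embeddings` are illustrative assumptions, not part of the original notebook):

```python
import torch

torch.manual_seed(0)

num_series, series_len = 4, 320   # assumed: 4 stocks, 320 time steps each
patch_len, embed_dim = 32, 128    # patch length and embedding size from the text

# Hypothetical raw data: one row of values per stock
series = torch.randn(num_series, series_len)

# Slice each series into non-overlapping patches of 32 time steps
patches = series.unfold(dimension=1, size=patch_len, step=patch_len)  # [4, 10, 32]

# Assumed embedding: a learnable linear map from one patch to a 128-dim vector
patch_embed = torch.nn.Linear(patch_len, embed_dim)
patch_embeddings = patch_embed(patches)  # [4, 10, 128]

print(patches.shape, patch_embeddings.shape)
```

With 320 time steps and a patch length of 32, each stock yields 10 patches, which matches the `[4, 10, 128]` shape used throughout.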

### 2. Query, Key, Value Projections

Each attention head in our Transformer model is initialized with its own Query, Key, and Value weight matrices. These matrices are learned: training updates them along with the rest of the model.

In [2]:
# Create matrices
W_query = torch.randn(128, 64)  # embedding_dim -> query_dim
W_key = torch.randn(128, 64)    # embedding_dim -> key_dim
W_value = torch.randn(128, 128)  # embedding_dim -> value_dim

# Project embeddings to Q, K, V for each vector
Q = torch.matmul(embeddings, W_query)  # Shape: [4, 10, 64]
K = torch.matmul(embeddings, W_key)    # Shape: [4, 10, 64]
V = torch.matmul(embeddings, W_value)  # Shape: [4, 10, 128]
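
In a trainable model, these matrices would usually be registered as parameters so an optimizer can update them. A minimal sketch, assuming `torch.nn.Linear` without a bias term as the projection (the names `query_proj`, `key_proj`, and `value_proj` are illustrative):

```python
import torch
from torch import nn

torch.manual_seed(0)
embeddings = torch.randn(4, 10, 128)  # [stocks, time_patches, embedding_dim]

# bias=False makes each layer a plain matrix multiply, like W_query above
query_proj = nn.Linear(128, 64, bias=False)
key_proj = nn.Linear(128, 64, bias=False)
value_proj = nn.Linear(128, 128, bias=False)

Q = query_proj(embeddings)  # [4, 10, 64]
K = key_proj(embeddings)    # [4, 10, 64]
V = value_proj(embeddings)  # [4, 10, 128]

# The weights track gradients, so training can update them
print(query_proj.weight.requires_grad)  # True
print(query_proj.weight.shape)          # torch.Size([64, 128]), stored as [out, in]
```

Note that `nn.Linear` stores its weight as `[out_features, in_features]`, the transpose of the `[128, 64]` layout used in the cell above.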

### 3. Attention Scores and Weights

Now we compute how similar each token's query ("what am I looking for?") is to the key of every token in the sequence, including itself ("what information do I offer?").

After scaling and a softmax, this gives a set of weights that show how much each token attends to every token.

In [7]:
# Transpose K for matrix multiplication
K_transposed = K.transpose(-2, -1)  # Shape: [4, 64, 10]

# Compute attention scores
# Q @ K_T shape: [4, 10, 10]
attention_scores = torch.matmul(Q, K_transposed)

# Scale scores by square root of key dimension
attention_scores = attention_scores / (64 ** 0.5)

# Apply softmax to get weights that sum to 1
attention_weights = torch.softmax(attention_scores, dim=-1)  # Shape: [4, 10, 10]
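
A quick way to sanity-check this step is to confirm that every row of weights is a probability distribution. The sketch below uses random stand-in tensors for `Q` and `K` (an assumption, so the block runs on its own) and the shorter names `scores` and `weights`:

```python
import torch

torch.manual_seed(0)
Q = torch.randn(4, 10, 64)  # stand-in queries: [stocks, time_patches, query_dim]
K = torch.randn(4, 10, 64)  # stand-in keys:    [stocks, time_patches, key_dim]

scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.shape[-1] ** 0.5)
weights = torch.softmax(scores, dim=-1)  # [4, 10, 10]

# Each row is a distribution over the 10 patches of the same stock
row_sums = weights.sum(dim=-1)  # [4, 10]
print(torch.allclose(row_sums, torch.ones(4, 10)))  # True

# weights[s, i, j]: how much patch i of stock s attends to patch j
print(weights[0, 0])
```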

### 4. Context Vector Creation

Finally, we use these weights to create a weighted sum of values:

In [9]:
# Weighted sum creates the context vectors
# Shape: [4, 10, 128]
context_vectors = torch.matmul(attention_weights, V)

In [10]:
context_vectors.shape

torch.Size([4, 10, 128])

In [11]:
# for one of the stocks
context_vectors[0]

tensor([[  1.2767, -11.7497,  21.9428,  ...,  -4.2383,  -7.0721,   5.4788],
        [  3.9176, -12.2673,   3.2509,  ..., -32.1171,   3.0273,   3.6007],
        [ -4.3467,   4.1053,  -6.2469,  ...,  -7.1524,  -3.2380,  -4.0690],
        ...,
        [ -6.1965,  -9.0071,  -5.2099,  ...,   3.4394,   8.5367,  -2.7584],
        [ -9.0476,   7.8684,  10.4539,  ...,   3.8208,  -3.9376,  -1.5432],
        [  1.2767, -11.7497,  21.9428,  ...,  -4.2383,  -7.0721,   5.4788]])
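
Steps 2 to 4 can be collected into one helper. This is a minimal single-head sketch; the function name `self_attention` and its signature are hypothetical, not taken from a library:

```python
import torch

def self_attention(x, W_query, W_key, W_value):
    """Single-head scaled dot-product attention over the second-to-last axis."""
    Q = torch.matmul(x, W_query)
    K = torch.matmul(x, W_key)
    V = torch.matmul(x, W_value)
    # Scale by the square root of the key dimension before the softmax
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.shape[-1] ** 0.5)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights

torch.manual_seed(0)
embeddings = torch.randn(4, 10, 128)
W_query = torch.randn(128, 64)
W_key = torch.randn(128, 64)
W_value = torch.randn(128, 128)

context_vectors, attention_weights = self_attention(embeddings, W_query, W_key, W_value)
print(context_vectors.shape)    # torch.Size([4, 10, 128])
print(attention_weights.shape)  # torch.Size([4, 10, 10])
```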

# Factorized Attention: Splitting Time and Space

The above is basic self-attention, as commonly used in large language models. For multivariate time series, instead of doing one big attention calculation over every patch of every series, we split it into two passes: one along the time dimension and one along the space dimension (across series).

Time-wise attention is similar to what an LLM does: it looks across the sequence of patches within a single series.

Space-wise attention looks across different but related time series at the same time patch.
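
One practical effect of the split is a smaller number of attention scores. The rough count below uses the shapes from this example and assumes a single head, counting only entries of the score matrices:

```python
num_series, num_patches = 4, 10  # stocks and time patches, as in the example

# Full attention: every (stock, patch) token attends to every token
full_score_count = (num_series * num_patches) ** 2

# Time-wise: one [patches x patches] score matrix per stock
time_score_count = num_series * num_patches ** 2
# Space-wise: one [stocks x stocks] score matrix per time patch
space_score_count = num_patches * num_series ** 2

print(full_score_count)                      # 1600
print(time_score_count + space_score_count)  # 560
```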

In [None]:
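embeddings = torch.randn(4, 10, 128)  # [stocks, time_patches, embedding_dim]; batch_size=1 is folded into the first axis

# No-op here because batch_size=1; with a real batch this would merge batch and stocks
time_embeddings = embeddings.reshape(4, 10, 128)  # [batch_size * stocks, time_patches, embedding_dim]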

# Initialize time-wise projection matrices
W_time_query = torch.randn(128, 64)  # embedding_dim -> query_dim
W_time_key = torch.randn(128, 64)    # embedding_dim -> key_dim
W_time_value = torch.randn(128, 128)  # embedding_dim -> value_dim

# Project to Q, K, V
time_Q = torch.matmul(time_embeddings, W_time_query)
time_K = torch.matmul(time_embeddings, W_time_key)
time_V = torch.matmul(time_embeddings, W_time_value)

# Compute time-wise attention
time_scores = torch.matmul(time_Q, time_K.transpose(-2, -1)) / (64 ** 0.5)
time_weights = torch.softmax(time_scores, dim=-1)

# Get the context vectors
time_context = torch.matmul(time_weights, time_V)

In [4]:
time_context.shape

torch.Size([4, 10, 128])

In [5]:
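# Swap the stock and time axes so attention runs across stocks at each time patch
space_embeddings = time_context.transpose(0, 1)  # shape: [10, 4, 128]

# Initialize space-wise projection matrices
W_space_query = torch.randn(128, 64)   # embedding_dim -> query_dim
W_space_key = torch.randn(128, 64)     # embedding_dim -> key_dim
W_space_value = torch.randn(128, 128)  # embedding_dim -> value_dim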

# space_embeddings.shape = [10, 4, 128]
space_Q = torch.matmul(space_embeddings, W_space_query)
space_K = torch.matmul(space_embeddings, W_space_key)   
space_V = torch.matmul(space_embeddings, W_space_value)

# Compute the space-wise attention
space_scores = torch.matmul(space_Q, space_K.transpose(-2, -1)) / (64.0 ** 0.5)
space_weights = torch.softmax(space_scores, dim=-1)  # shape: [10, 4, 4]

# Weighted sum over values => [10, 4, 128]
space_context = torch.matmul(space_weights, space_V)

In [6]:
space_context.shape

torch.Size([10, 4, 128])

In [7]:
# shape: [10, 4, 128] -> [4, 10, 128]
space_context = space_context.transpose(0, 1)
print(space_context.shape)

torch.Size([4, 10, 128])
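
The cells above fold a batch size of 1 into the stock axis. As a sketch of how the same two passes might look with an explicit batch dimension, assuming a batch of 2 and a hypothetical helper named `attend`:

```python
import torch

def attend(x, W_q, W_k, W_v):
    """Single-head attention over the second-to-last axis of x."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    weights = torch.softmax(Q @ K.transpose(-2, -1) / (K.shape[-1] ** 0.5), dim=-1)
    return weights @ V

torch.manual_seed(0)
B, S, T, D = 2, 4, 10, 128  # assumed batch of 2; stocks, time patches, embedding dim
x = torch.randn(B, S, T, D)

# Separate (randomly initialized) projections for each pass
time_W = [torch.randn(D, 64), torch.randn(D, 64), torch.randn(D, D)]
space_W = [torch.randn(D, 64), torch.randn(D, 64), torch.randn(D, D)]

# Time-wise: fold batch and stocks together so attention runs over the T axis
time_out = attend(x.reshape(B * S, T, D), *time_W).reshape(B, S, T, D)

# Space-wise: move time next to batch so attention runs over the S axis
space_in = time_out.permute(0, 2, 1, 3).reshape(B * T, S, D)
space_out = attend(space_in, *space_W).reshape(B, T, S, D).permute(0, 2, 1, 3)

print(space_out.shape)  # torch.Size([2, 4, 10, 128])
```

The output has the same layout as the input, `[batch, stocks, time_patches, embedding_dim]`, so blocks like this can be stacked.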
