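import torch
# https://storrs.io/attention/

device = 'cpu'


class transformer_cell(torch.nn.Module):
    """A single post-LN Transformer encoder block (self-attention + feed-forward).

    Shapes (batch_first=True):
        Q (batch_size, num_queries, embedding_dim)
        K (batch_size, num_keys, embedding_dim)
        V (batch_size, num_values, value_dim)

    In self-attention, Q, K and V all come from the same input, so num_queries,
    num_keys, and num_values are all equal to the sequence length.
    """

    def __init__(self, d_emb: int = 512, n_heads: int = 8, d_ff: int = 100) -> None:
        super().__init__()
        # batch_first=True so inputs are (batch, seq, emb); with the default
        # batch_first=False, a (1, 10, 512) input would be read as seq_len=1, batch=10.
        self.atten = torch.nn.MultiheadAttention(d_emb, n_heads, batch_first=True)
        # Position-wise feed-forward network (the original Transformer uses d_ff = 2048).
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(d_emb, d_ff),
            torch.nn.ReLU(),
            torch.nn.Linear(d_ff, d_emb),
        )
        self.ln = torch.nn.LayerNorm(d_emb)
        self.ln2 = torch.nn.LayerNorm(d_emb)

    def forward(self, input: torch.Tensor, mask=None) -> torch.Tensor:
        # Self-attention: queries, keys and values are all the input itself.
        attn_out = self.atten(input, input, input, attn_mask=mask)[0]
        # Residual connection around the attention sub-layer, then LayerNorm.
        o = self.ln(input + attn_out)
        # Residual connection around the feed-forward sub-layer, then LayerNorm.
        o = self.ln2(o + self.ff(o))
        return o


tc = transformer_cell().to(device)

x = torch.normal(0, 1, (1, 10, 512)).to(device)
o = tc(x)  # call the module rather than .forward() so hooks run
print(o.shape)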



torch.Size([1, 10, 512])


Let's consider a sequence of numbers as an example. Say we have the following sequence:

`[1, 2, 3, 4, 5]`

In a self-attention mechanism, each element of this sequence (in practice, first mapped to an embedding vector) is transformed into a query, a key, and a value vector using learned linear transformations (weight matrices and bias vectors). For simplicity, let's assume that these transformations produce the following query, key, and value vectors:

```
Q = [q1, q2, q3, q4, q5]
K = [k1, k2, k3, k4, k5]
V = [v1, v2, v3, v4, v5]
```

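The self-attention mechanism computes a score for each query-key pair by taking their dot product, applies a softmax over all the scores belonging to one query so that they become weights summing to 1, and then takes a weighted sum of the value vectors. (Transformers also divide the scores by √d_k, the square root of the key dimension, before the softmax; that scaling is omitted here for simplicity.) For example, the first element in the output sequence would be computed as: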

```
a1 = softmax([q1·k1, q1·k2, q1·k3, q1·k4, q1·k5])   # ONE softmax over all five scores
output1 = a1[1]*v1 + a1[2]*v2 + a1[3]*v3 + a1[4]*v4 + a1[5]*v5
```
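A minimal NumPy sketch of this per-query computation, assuming small random vectors stand in for q1…q5, k1…k5 and v1…v5 (the sizes and random seed are arbitrary illustration choices). The key point is that the softmax is applied once to the whole list of five scores; applying it to each score on its own would always return 1.

```python
import numpy as np

def softmax(scores):
    """Numerically stable softmax over a 1-D array of scores."""
    shifted = scores - np.max(scores)  # subtracting the max leaves the result unchanged
    exp = np.exp(shifted)
    return exp / exp.sum()

rng = np.random.default_rng(0)
n, d_k, d_v = 5, 4, 3            # illustrative sizes: 5 tokens
q = rng.normal(size=(n, d_k))    # q[0] plays the role of q1, and so on
k = rng.normal(size=(n, d_k))
v = rng.normal(size=(n, d_v))

# Scores of the first query against every key: [q1·k1, ..., q1·k5]
scores1 = np.array([q[0] @ k[j] for j in range(n)])
a1 = softmax(scores1)            # five weights that sum to 1

# output1 = a1[0]*v1 + a1[1]*v2 + ... + a1[4]*v5 (0-based indices here)
output1 = sum(a1[j] * v[j] for j in range(n))
print(a1.sum(), output1.shape)
```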

This is repeated for each element in the sequence, resulting in a new sequence of the same length where each element is a weighted sum of the value vectors of all the elements in the original sequence.

In this way, each element in the output sequence "attends" to every element of the input sequence (including its own position), allowing the model to capture relationships between different parts of the sequence regardless of how far apart they are.

Let's assume that `Q`, `K`, and `V` are matrices, where each row corresponds to a vector in the query, key, and value sets, respectively. For example:

```
Q = [q1; q2; q3; q4; q5]   # shape (5, d_k): row i is q_i
K = [k1; k2; k3; k4; k5]   # shape (5, d_k): row j is k_j
V = [v1; v2; v3; v4; v5]   # shape (5, d_v): row j is v_j
```

We can compute the attention weights by multiplying `Q` by `K^T` (the transpose of `K`) and then applying a softmax to each row. This gives a matrix `A` of attention weights, where row `i` holds the weights for query `q_i` and sums to 1. For example:

```
A = softmax(Q @ K.T, axis=-1)   # row-wise softmax; A has shape (5, 5)
```

Here, `@` denotes matrix multiplication.
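A small NumPy sketch of this step, assuming random `Q` and `K` with arbitrary sizes. The softmax has to run along each row (`axis=-1`) so that every query gets its own set of weights; normalising over the whole matrix would mix different queries together.

```python
import numpy as np

def softmax(x, axis=-1):
    """Softmax along `axis`, shifted by the max for numerical stability."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
Q = rng.normal(size=(5, 4))   # 5 queries, d_k = 4
K = rng.normal(size=(5, 4))   # 5 keys,    d_k = 4

S = Q @ K.T                   # (5, 5): S[i, j] = q_i · k_j
A = softmax(S, axis=-1)       # row i = attention weights of query i

print(A.shape)                # (5, 5)
print(A.sum(axis=-1))         # each row sums to 1 (up to rounding)
```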

We can then compute the output matrix `O` by taking the matrix product of `A` and `V`:

```
O = A @ V
```

The first row of `O` will correspond to the first element in the output sequence, which is the weighted sum of all the value vectors, where the weights are given by the attention weights in the first row of `A`. In other words:

```
output1 = O[0, :]
```

This gives us the first element in the output sequence, computed using matrix multiplication.
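To check that the matrix form agrees with the per-element formula, the NumPy sketch below (random inputs of arbitrary size) computes `O = A @ V` and compares `O[0, :]` with the explicit weighted sum for the first query.

```python
import numpy as np

def softmax(x, axis=-1):
    """Softmax along `axis`, shifted by the max for numerical stability."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(1)
n, d_k, d_v = 5, 4, 3
Q = rng.normal(size=(n, d_k))
K = rng.normal(size=(n, d_k))
V = rng.normal(size=(n, d_v))

A = softmax(Q @ K.T, axis=-1)  # (n, n) attention weights
O = A @ V                      # (n, d_v): row i is the output for query i

# Explicit weighted sum for the first query: sum_j A[0, j] * v_j
output1_loop = sum(A[0, j] * V[j] for j in range(n))
output1 = O[0, :]

print(np.allclose(output1, output1_loop))  # True
```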


In the example I provided, the learned projection weights are implicitly folded into the query, key, and value vectors. These vectors are typically obtained by applying linear transformations (i.e., weight matrices and bias vectors) to the input sequence.

Let's denote the input sequence as `X`, and the weight matrices for the query, key, and value transformations as `W_q`, `W_k`, and `W_v`, respectively. Then, the query, key, and value vectors are computed as:

```
Q = X @ W_q + b_q   # @ is matrix multiplication, as above
K = X @ W_k + b_k
V = X @ W_v + b_v
```

Here, `b_q`, `b_k`, and `b_v` are bias vectors. The weight matrices and bias vectors are learnable parameters that are optimized during training to minimize the loss function of the model.

In the example I provided, I simplified the explanation by assuming that the query, key, and value vectors were already given. In practice, these vectors are computed from the input sequence using the weight matrices and bias vectors.
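Putting the pieces together, here is a minimal NumPy sketch of single-head self-attention computed directly from an input `X`. The weight matrices are random stand-ins for learned parameters (hypothetical values, not trained) and the sizes `d_k` and `d_v` are arbitrary assumptions. To connect with the `[1, 2, 3, 4, 5]` example, each number is treated as a 1-dimensional input vector, so `X` has shape `(5, 1)`; real models would first map tokens to higher-dimensional embeddings. Unlike the simplified formulas above, the scores are divided by `sqrt(d_k)`, as in the Transformer's scaled dot-product attention.

```python
import numpy as np

def softmax(x, axis=-1):
    """Softmax along `axis`, shifted by the max for numerical stability."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, b_q, W_k, b_k, W_v, b_v):
    """Single-head scaled dot-product self-attention; X is (seq_len, d_model)."""
    Q = X @ W_q + b_q              # (seq_len, d_k)
    K = X @ W_k + b_k              # (seq_len, d_k)
    V = X @ W_v + b_v              # (seq_len, d_v)
    d_k = Q.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)  # row-wise attention weights
    return A @ V, A

rng = np.random.default_rng(0)
X = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]])  # the sequence [1, 2, 3, 4, 5]
d_model, d_k, d_v = 1, 4, 4

# Random stand-ins for learned parameters; biases start at zero here
W_q, W_k, W_v = (rng.normal(size=(d_model, d)) for d in (d_k, d_k, d_v))
b_q, b_k, b_v = (np.zeros(d) for d in (d_k, d_k, d_v))

O, A = self_attention(X, W_q, b_q, W_k, b_k, W_v, b_v)
print(O.shape, A.shape)  # (5, 4) (5, 5)
```

In training, `W_q`, `W_k`, `W_v` and the biases would be updated by gradient descent; here they are fixed random values only to show the shapes and data flow.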