# Common

## Imports and Helper Function

In [3]:
# imports
import numpy as np
import math
np.random.seed(42)

# helper function
def print_matop(*matrices, separator="\t\t"):
    """Print several matrices side by side, one row of each per output line.

    Args:
        *matrices: Sequences supporting ``len()`` and integer indexing. Each
            element ("row") is either an iterable of numbers or a single
            number.
        separator: String placed between the formatted rows of neighbouring
            matrices on each printed line.

    Returns:
        None. The formatted rows are written to stdout. Calling the function
        with no matrices prints nothing.
    """
    # default=0 turns a call with no matrices into a silent no-op; without it
    # max() raises ValueError on the empty sequence.
    max_rows = max((len(matrix) for matrix in matrices), default=0)

    # Walk the row indices up to the tallest matrix.
    for i in range(max_rows):
        formatted_rows = []
        for matrix in matrices:
            if i < len(matrix):
                row = matrix[i]
                if hasattr(row, '__iter__'):
                    # Iterable row (list, numpy array, ...): format each element.
                    formatted_rows.append("[" + " ".join(f"{val:2.2f}" for val in row) + "]")
                else:
                    # Scalar row: format the value directly.
                    formatted_rows.append(f"{row:2.2f}")
            else:
                # This matrix has run out of rows; emit a tab as a placeholder
                # so the remaining columns keep their position.
                formatted_rows.append('\t')

        print(separator.join(formatted_rows))


## Sample Input

In [4]:
# Toy problem dimensions, small enough that every matrix can be printed in full.
M = 20  # SRAM memory size
d = 2   # simplified head size = hidden size
N = 6   # sequence length

# Block sizes: Bc for the outer loop, Br for the inner loop.
Bc = math.ceil(M / (4 * d))
Br = min(Bc, d)
print(f"Bc: {Bc}")
print(f"Br: {Br}")

# Example Q, K, V matrices, each of shape (N, d). The generator is consumed
# left to right, so the random stream is drawn in the order Q, then K, then V.
Q, K, V = (np.random.randn(N, d) for _ in range(3))

# Show the three matrices side by side, one row of each per line.
with np.printoptions(precision=2, suppress=True):
    print("Q \t\t K \t\t V")
    for q, k, v in zip(Q, K, V):
        print(q, "\t", k, "\t", v)


Bc: 3
Br: 2
Q 		 K 		 V
[ 0.5  -0.14] 	 [ 0.24 -1.91] 	 [-0.54  0.11]
[0.65 1.52] 	 [-1.72 -0.56] 	 [-1.15  0.38]
[-0.23 -0.23] 	 [-1.01  0.31] 	 [-0.6  -0.29]
[1.58 0.77] 	 [-0.91 -1.41] 	 [-0.6   1.85]
[-0.47  0.54] 	 [ 1.47 -0.23] 	 [-0.01 -1.06]
[-0.46 -0.47] 	 [ 0.07 -1.42] 	 [ 0.82 -1.22]


## Simplified Forward Pass 1
- direct attention computation using numpy/scipy

In [5]:
import numpy as np
from scipy.special import softmax  # Optional, can use numpy's method or scipy's

def attention(Q, K, V):
    """Reference (unblocked) attention: ``softmax(Q @ K.T) @ V``.

    No scaling factor and no masking are applied. The softmax is taken
    along axis 1, i.e. across each row of the score matrix.

    Args:
        Q: Query matrix, (N, d) in this notebook.
        K: Key matrix, (N, d) in this notebook.
        V: Value matrix, (N, d) in this notebook.

    Returns:
        The attention output, (N, d) for the shapes above.
    """
    scores = np.dot(Q, K.T)            # (N, d) @ (d, N) -> (N, N)
    weights = softmax(scores, axis=1)  # each row becomes a distribution
    return np.dot(weights, V)          # (N, N) @ (N, d) -> (N, d)

# Reference output for the sample Q, K, V defined above.
_O = attention(Q, K, V)

# Header and matrix go out in a single call; sep="\n" puts the matrix on its
# own line, rounded to two decimals by the print options.
with np.printoptions(precision=2, suppress=True):
    print("Attention Output O:", _O, sep="\n")


Attention Output O:
[[-0.17 -0.33]
 [-0.22 -0.7 ]
 [-0.41  0.14]
 [-0.03 -0.97]
 [-0.6   0.07]
 [-0.47  0.29]]


## Backward Pass

### Simplified Algorithm

In [20]:
with np.printoptions(precision=2, suppress=True):

    # ------------------------------------------------------------------
    # Forward pass: O = softmax(Q @ K.T) @ V
    # ------------------------------------------------------------------
    print("\nForward Pass Outputs:")
    print("#" * 50)
    S = Q @ K.T
    print("S \t\t\t=\t\t\t Q \t\t@\t\t K.T")
    print_matop(S, Q, K.T)

    # Numerically stable row-wise softmax:
    #   m = row max, P = exp(S - m) (unnormalized), l = row sum of P,
    #   A = P / l (the actual attention weights; every row sums to 1).
    m = np.max(S, axis=1, keepdims=True)
    P = np.exp(S - m)
    l = np.sum(P, axis=1, keepdims=True)  # reuse P instead of recomputing exp(S - m)
    A = P / l
    print("A \t\t\t=\t\t\t exp(S - m) \t\t\t/\t sum(exp(S - m))")
    print_matop(A, P, l, separator='\t\t\t')

    _O = A @ V
    print("O \t\t=\t\t\t A \t\t@\t\t V")
    print_matop(_O, A, V)

    dO = np.random.randn(N, d)  # Assume dO comes from the next layer
    print("dO obtained from the next layer in backprop")
    print(dO)

    # ------------------------------------------------------------------
    # Backward pass
    #
    # Every gradient below must use the normalized weights A, not the
    # unnormalized exponentials P: the forward pass computed O = A @ V,
    # so A is the matrix that appears in the chain rule.
    # ------------------------------------------------------------------
    print("\nBackward Pass Outputs:")
    print("#" * 50)

    # O = A @ V  =>  dV = A.T @ dO
    dV = A.T @ dO
    print("dV \t\t=\t\t A.T \t\t\t@\t dO")
    print_matop(dV, A.T, dO)

    # O = A @ V  =>  dA = dO @ V.T  (named dP to follow the usual notation)
    dP = dO @ V.T
    print("dP \t\t\t\t=\t\t dO \t\t@\t\t\t V.T")
    print_matop(dP, dO, V.T)

    # Row-wise softmax backward: dS = A * (dP - rowsum(A * dP))
    row_dot = np.sum(A * dP, axis=1, keepdims=True)
    dS = A * (dP - row_dot)
    # Sanity check: the gradient through a softmax row always sums to zero.
    # This fails if the unnormalized P is used in place of A.
    assert np.allclose(dS.sum(axis=1), 0.0), "softmax backward rows must sum to 0"
    print("dS \t\t\t\t=\t\t A \t\t\t\t* (dP - np.sum(A * dP))")
    print_matop(dS, A, row_dot)

    # S = Q @ K.T  =>  dQ = dS @ K
    dQ = dS @ K
    print("dQ \t\t=\t\t\t dS \t\t\t@\t K")
    print_matop(dQ, dS, K)

    # S = Q @ K.T  =>  dK = dS.T @ Q
    dK = dS.T @ Q
    print("dK \t\t=\t\t\t dS.T \t\t\t@\t Q")
    print_matop(dK, dS.T, Q)



Backward Pass Outputs:
##################################################
S 			=			 Q 		@		 K.T
[0.38 -0.78 -0.55 -0.26 0.76 0.23]		[0.50 -0.14]		[0.24 -1.72 -1.01 -0.91 1.47 0.07]
[-2.76 -1.97 -0.18 -2.74 0.61 -2.13]		[0.65 1.52]		[-1.91 -0.56 0.31 -1.41 -0.23 -1.42]
[0.39 0.54 0.16 0.54 -0.29 0.32]		[-0.23 -0.23]			
[-1.09 -3.16 -1.36 -2.52 2.14 -0.99]		[1.58 0.77]			
[-1.15 0.50 0.65 -0.34 -0.81 -0.80]		[-0.47 0.54]			
[0.78 1.06 0.32 1.08 -0.57 0.63]		[-0.46 -0.47]			
A 			=			 exp(S - m) 			/	 sum(exp(S - m))
[0.22 0.07 0.09 0.12 0.32 0.19]			[0.69 0.21 0.27 0.36 1.00 0.59]			[3.13]
[0.02 0.05 0.27 0.02 0.60 0.04]			[0.03 0.08 0.46 0.04 1.00 0.07]			[1.67]
[0.18 0.21 0.14 0.21 0.09 0.17]			[0.86 0.99 0.68 1.00 0.43 0.80]			[4.77]
[0.04 0.00 0.03 0.01 0.89 0.04]			[0.04 0.01 0.03 0.01 1.00 0.04]			[1.13]
[0.06 0.30 0.35 0.13 0.08 0.08]			[0.17 0.87 1.00 0.37 0.23 0.23]			[2.87]
[0.18 0.24 0.12 0.25 0.05 0.16]			[0.74 0.98 0.47 1.00 0.19 0.64]			[4.03]
O 		=			 A 		@		 V
[-0.17 -0.

In [None]:
def causal_mask(I):
    """Apply a causal (lower-triangular) mask to ``I``.

    Entries on or below the main diagonal are kept; entries above it are
    replaced with ``-inf``.

    Args:
        I: Input array to mask.

    Returns:
        A ``(masked, keep)`` tuple: the masked copy of ``I`` and the mask
        itself (1 where the entry was kept, 0 where it was replaced).
    """
    keep = np.tril(np.ones_like(I))          # ones on and below the diagonal
    masked = np.where(keep != 0, I, -np.inf)  # everything else becomes -inf
    return masked, keep

def d_causal_mask(dO, mask):
    """Backward pass of ``causal_mask``: route gradient through kept entries.

    Masked positions (``mask == 0``) receive a gradient of exactly ``0.0``.
    The values are selected with ``np.where`` rather than multiplied by the
    mask, so that:

    * a negative upstream gradient does not produce ``-0.0``, and
    * an ``inf`` or ``nan`` upstream gradient at a masked position does not
      turn into ``nan`` (``inf * 0`` is ``nan``).

    Args:
        dO: Upstream gradient with respect to the masked output.
        mask: Mask returned by ``causal_mask`` (non-zero where the input was
            kept).

    Returns:
        Gradient with respect to the unmasked input, same shape as the
        broadcast of ``dO`` and ``mask``.
    """
    dI = np.where(mask != 0, dO, 0.0)
    return dI

# Random 4x4 input for the masking demo.
I = np.random.randn(4, 4)

# Forward pass: mask out everything above the diagonal.
O, mask = causal_mask(I)
print("Forward Pass Output (O):", O, sep="\n")

# Assume some gradient coming from the next layer, same shape as O.
dO = np.random.randn(*O.shape)

# Backward pass: only the kept positions receive gradient.
dI = d_causal_mask(dO, mask)
print("\nBackward Pass Output (dI):", dI, sep="\n")
import numpy as np




Forward Pass Output (O):
[[-1.32023321        -inf        -inf        -inf]
 [-1.71313453  1.35387237        -inf        -inf]
 [-1.59442766 -0.59937502  0.0052437         -inf]
 [-0.45006547  0.62284993 -1.06762043 -0.14237949]]

Backward Pass Output (dI):
[[ 0.12029563  0.          0.         -0.        ]
 [-1.53411417  1.27767682  0.         -0.        ]
 [ 1.55115198  0.11567463  1.17929718  0.        ]
 [ 2.06074792  1.75534084 -0.24896415  0.97157095]]


In [None]:
def softmax(I: np.ndarray, axis: int = 1) -> np.ndarray:
    """Numerically stable softmax along ``axis``.

    This definition shadows the ``scipy.special.softmax`` imported earlier
    in the notebook, which ``attention`` calls as ``softmax(S, axis=1)``.
    The ``axis`` keyword is accepted here so that call keeps working after
    this cell has run. The default (``axis=1``, row-wise) matches the
    original single-argument behaviour.

    Args:
        I: Input scores, (n, m).
        axis: Axis along which the probabilities are normalized.

    Returns:
        O: (n, m) array whose slices along ``axis`` sum to 1.
    """
    # Subtracting the max leaves the result unchanged mathematically but
    # keeps exp() from overflowing for large inputs.
    exp_I = np.exp(I - np.max(I, axis=axis, keepdims=True))
    O = exp_I / np.sum(exp_I, axis=axis, keepdims=True)
    return O

def d_softmax(dO: np.ndarray, O: np.ndarray) -> np.ndarray:
    """Backward pass of a row-wise softmax.

    For one row ``o`` of the softmax output, the Jacobian is
    ``J = diag(o) - o o^T``, so the input gradient is

        J @ g = o * g - o * (o . g) = o * (g - sum(o * g))

    The right-hand side is evaluated for all rows at once. This avoids
    building an (m, m) Jacobian per row in a Python loop and gives the same
    result up to floating-point rounding. The output dtype follows the
    arithmetic, so an integer-typed ``dO`` is not truncated.

    Args:
        dO: (n, m) upstream gradient with respect to the softmax output.
        O: (n, m) softmax output from the forward pass.

    Returns:
        dI: (n, m) gradient with respect to the softmax input.
    """
    # Per-row inner product <o, g>, kept as (n, 1) so it broadcasts over columns.
    row_dot = np.sum(O * dO, axis=1, keepdims=True)
    dI = O * (dO - row_dot)
    return dI

# Random (4, 4) input for the softmax demo.
I = np.random.randn(4, 4)
print("I:", I, sep="\n")

# Forward pass: each row of O is a probability distribution.
O = softmax(I)
print("O:", O, sep="\n")

# Assume some gradient coming from the next layer, same shape as O.
dO = np.random.randn(*O.shape)
print("dO:", dO, sep="\n")

# Backward pass: gradient with respect to the softmax input.
dI = d_softmax(dO, O)
print("dI:", dI, sep="\n")

I:
[[ 1.68714164  0.88163976 -0.00797264  1.47994414]
 [ 0.07736831 -0.8612842   1.52312408  0.53891004]
 [-1.03724615 -0.19033868 -0.87561825 -1.38279973]
 [ 0.92617755  1.90941664 -1.39856757  0.56296924]]
O:
[[0.40928237 0.18289339 0.07513534 0.3326889 ]
 [0.13845177 0.05415604 0.58773596 0.21965624]
 [0.19172978 0.44719488 0.22536355 0.13571179]
 [0.22389644 0.59849772 0.02189895 0.15570689]]
dO:
[[-0.65064257 -0.48712538 -0.59239392 -0.86399077]
 [ 0.04852163 -0.83095012  0.27045683 -0.05023811]
 [-0.23894805 -0.90756366 -0.57677133  0.75539123]
 [ 0.50091719 -0.97755524  0.09933231  0.75138712]]
dI:
[[ 0.01501896  0.03661764  0.00713369 -0.05877028]
 [-0.00846181 -0.05093858  0.0945184  -0.03511801]
 [ 0.04605176 -0.1915893  -0.02200276  0.1675403 ]
 [ 0.19135461 -0.37335225  0.0099218   0.17207584]]
