# Common

## Imports and Helper Function

In [42]:
def print_matop(*matrices, separator="\t\t", width=5):
    """
    Prints multiple matrices side by side, ensuring proper alignment.
    
    Args:
    - matrices: The matrices to print side by side.
    - separator: The string used to separate the matrices in the output.
    - width: The fixed width for each element in the matrix to ensure alignment.
    """
    # Find the maximum number of rows across all matrices
    max_rows = max(len(matrix) for matrix in matrices)
    
    # Iterate over the rows by index up to the max number of rows
    for i in range(max_rows):
        formatted_rows = []
        for matrix in matrices:
            # Check if the matrix has a row at index i
            if i < len(matrix):
                row = matrix[i]
                # If row is iterable (like a list or numpy array), format each element
                if hasattr(row, '__iter__'):
                    formatted_rows.append("[" + " ".join(f"{val:{width}.2f}" for val in row) + "]")
                else:
                    # If it's a single scalar value, format it directly
                    formatted_rows.append(f"  {row:{width}.2f}")
            else:
                # Add an empty value if the matrix does not have enough rows
                formatted_rows.append('\t' * (width // 4))
        
        # Join the formatted rows with the specified separator and print them
        print(separator.join(formatted_rows))



## Sample Input

In [43]:
# SRAM memory size
M = 20
# simplified head size = hidden size
d = 2
# Sequence length
N = 6

# set block size for outer loop
Bc = math.ceil(M / (4 * d))
print(f"Bc: {Bc}")
# set block size for inner loop
Br = min(Bc, d)
print(f"Br: {Br}")

# Example dimensions for Q, K, V matrices
Q = np.random.randn(N, d)
K = np.random.randn(N, d)
V = np.random.randn(N, d)
with np.printoptions(precision=2, suppress=True):
    print(f"Q \t\t K \t\t V")
    for q, k, v in zip(Q, K, V):
        print(f"{q} \t {k} \t {v}")


Bc: 3
Br: 2
Q 		 K 		 V
[-0.47  1.09] 	 [-2.03  0.19] 	 [1.77 0.4 ]
[ 0.06 -1.08] 	 [-0.66  0.85] 	 [-1.26  0.92]
[-0.72  0.68] 	 [-0.79 -0.11] 	 [2.12 1.03]
[-0.73  0.22] 	 [0.5  0.87] 	 [-1.52 -0.48]
[ 0.05 -0.65] 	 [-1.2  -0.33] 	 [ 1.27 -0.71]
[2.14 0.63] 	 [-0.47 -0.65] 	 [0.44 0.77]


## Simplified Forward Pass
- direct attention computation using numpy/scipy

In [44]:
import numpy as np
from scipy.special import softmax  # Optional, can use numpy's method or scipy's

def attention(Q, K, V):
    # Step 1: Calculate the dot product of Q and K.T (scores)
    S = np.dot(Q, K.T)  # (N, d) @ (d, N) -> (N, N)
    
    # Step 2: Apply the softmax to the scores for each row
    P = softmax(S, axis=1)  # Softmax along the rows
    
    # Step 3: Multiply the attention weights with the value matrix V
    O = np.dot(P, V)  # (N, N) @ (N, d) -> (N, d)
    
    return O

# Calculate the attention output
_O = attention(Q, K, V)

# Print the result
print("Attention Output O:")
with np.printoptions(precision=2, suppress=True):
    print(_O)


Attention Output O:
[[ 0.23  0.37]
 [ 0.85  0.36]
 [ 0.69  0.38]
 [ 0.91  0.35]
 [ 0.74  0.35]
 [-1.28 -0.29]]


## Forward + Backward Pass

### Simplified Algorithm

In [49]:
with np.printoptions(precision=2, suppress=True):

    # Forward pass
    print("\nBackward Pass Outputs:")
    print("#" * 50)
    S = Q @ K.T
    print(f"  S \t\t\t=\t\t\t Q \t\t@\t\t K.T")
    print_matop(S, Q, K.T)

    m = np.max(S, axis=1, keepdims=True)
    P = np.exp(S - m)
    l = np.sum(np.exp(S - m), axis=1, keepdims=True)
    A = P / l
    print(f"  A \t\t\t=\t\t\t\t exp(S - m) \t\t\t/\t\t sum(exp(S - m))")
    print_matop(A, P, l, separator='\t\t\t')

    O = A @ V
    print(f"  O \t\t=\t\t\t A \t\t\t@\t\t V")
    print_matop(O, A, V)
    
    dO = np.random.randn(N, d)  # Assume dO comes from the next layer
    print(f"  dO obtained from the next layer in backprop")
    print(dO)

    # Backward pass
    print("\nBackward Pass Outputs:")
    print("#" * 50)
    dV = P.T @ dO
    print(f"  dV \t\t=\t\t P.T \t\t\t@\t\t dO")
    print_matop(dV, P.T, dO)

    dP = dO @ V.T
    print(f"  dP \t\t\t\t=\t\t dO \t\t@\t\t\t V.T")
    print_matop(dP, dO, V.T)

    dS = P * (dP - np.sum(P * dP, axis=1, keepdims=True))
    print(f"  dS \t\t\t\t=\t\t P \t\t\t\t*\t (dP - np.sum(P * dP))")
    print_matop(dS, P, np.sum(P * dP, axis=1, keepdims=True))

    dQ = dS @ K
    print(f"  dQ \t\t=\t\t\t dS \t\t\t@\t K")
    print_matop(dQ, dS, K)

    dK = dS.T @ Q
    print(f"  dK \t\t=\t\t\t dS.T \t\t\t@\t Q")
    print_matop(dK, dS.T, Q)



Backward Pass Outputs:
##################################################
  S 			=			 Q 		@		 K.T
[ 1.16  1.24  0.25  0.70  0.20 -0.49]		[-0.47  1.09]		[-2.03 -0.66 -0.79  0.50 -1.20 -0.47]
[-0.33 -0.96  0.07 -0.90  0.28  0.67]		[ 0.06 -1.08]		[ 0.19  0.85 -0.11  0.87 -0.33 -0.65]
[ 1.58  1.05  0.49  0.23  0.63 -0.10]		[-0.72  0.68]			
[ 1.52  0.67  0.55 -0.18  0.80  0.21]		[-0.73  0.22]			
[-0.21 -0.59  0.04 -0.54  0.16  0.40]		[ 0.05 -0.65]			
[-4.22 -0.88 -1.77  1.63 -2.79 -1.43]		[ 2.14  0.63]			
  A 			=				 exp(S - m) 			/		 sum(exp(S - m))
[ 0.27  0.29  0.11  0.17  0.10  0.05]			[ 0.92  1.00  0.37  0.59  0.35  0.18]			[ 3.41]
[ 0.12  0.07  0.18  0.07  0.23  0.33]			[ 0.37  0.19  0.55  0.21  0.68  1.00]			[ 2.99]
[ 0.36  0.21  0.12  0.09  0.14  0.07]			[ 1.00  0.59  0.34  0.26  0.39  0.19]			[ 2.77]
[ 0.36  0.16  0.14  0.07  0.18  0.10]			[ 1.00  0.43  0.38  0.18  0.49  0.27]			[ 2.75]
[ 0.14  0.10  0.18  0.10  0.21  0.26]			[ 0.54  0.37  0.69  0.39  0.79  1.00]			[ 3.78]
[ 0.00

### Detailed Algorithm

In [63]:
tau = 1 / np.sqrt(d)
p_drop = 0.1


with np.printoptions(precision=2, suppress=True):

    # Forward pass
    print("\nForward Pass Outputs:")
    print("#" * 50)
    S = tau * Q @ K.T
    print(f"\tS \t\t\t\t=\t\t Q \t@\t\t K.T \t\t\t\t*\t tau")
    print_matop(S, Q, K.T, np.array([tau]))

    # Apply causal mask directly to S
    S_cm = np.where(np.tril(np.ones_like(S), k=0) == 1, S, -np.inf)
    print(f"\tmasked S")
    print(S_cm)

    m = np.max(S_cm, axis=1, keepdims=True)
    S_ = S_cm - m
    P = np.exp(S_)
    print(f"softmax numerator P:")
    print(f"\tP \t\t\t\t<=exp\t\t\t (S - m):")
    print_matop(P, S_, separator='\t\t')

    l = np.sum(np.exp(P), axis=1, keepdims=True)

    # Apply dropout to P
    drop_mask = ((np.random.rand(*P.shape) > p_drop) / (1 - p_drop))
    P_dm = P * drop_mask 
    print(f"P after drop_out:")
    print(f"\tP_dm \t\t\t<=drop_out\t\t\t P \t\t\tby\t\t\t drop_mask:")
    print_matop(P_dm, P, drop_mask, separator='\t\t')

    A = P_dm / l
    print(f"Attention Score A:")
    print(f"\t\tA \t\t\t=\t\t\t\t P_dm \t\t\t/\t sum(exp(S_cm - m))")
    print_matop(A, P_dm, l, separator='\t\t\t')

    O = A @ V
    print(f"output matrix O:")
    print(f"\tO \t=\t\t\t A \t\t\t@\t\t V")
    print_matop(O, A, V)

    # Backward pass
    print("\nBackward Pass Outputs:")    
    print("#" * 50)
    dO = np.random.randn(N, d)  # Assume dO comes from the next layer
    print(f"\tdO obtained from the next layer in backprop")
    print(dO)

    dV = P_dm.T @ dO
    print(f"\tdV \t\t=\t\t P.T \t\t\t@\t\t dO")
    print_matop(dV, P.T, dO)

    dP_dm = dO @ V.T
    print(f"\tdP_dm \t\t\t\t=\t\t dO \t@\t\t\t V.T")
    print_matop(dP_dm, dO, V.T)

    dP = dP_dm * drop_mask

    dS = P * (dP - np.sum(P * dP, axis=1, keepdims=True))
    print(f"\tdS \t\t\t\t=\t\t P \t\t\t*\t (dP - np.sum(P * dP))")
    print_matop(dS, P, np.sum(P * dP, axis=1, keepdims=True))

    dQ = dS @ K * tau
    print(f"\tdQ \t=\t\t\t dS \t\t\t@\t\t K\t*\t tau")
    print_matop(dQ, dS, K, np.array([tau]))

    dK = dS.T @ Q * tau
    print(f"\tdK \t=\t\t\t dS.T \t\t\t@\t\t Q\t*\t tau")
    print_matop(dK, dS.T, Q, np.array([tau]))




Forward Pass Outputs:
##################################################
	S 				=		 Q 	@		 K.T 				*		 tau
[ 0.82  0.88  0.18  0.50  0.14 -0.34]		[-0.47  1.09]		[-2.03 -0.66 -0.79  0.50 -1.20 -0.47]		 0.71
[-0.23 -0.68  0.05 -0.64  0.20  0.48]		[ 0.06 -1.08]		[ 0.19  0.85 -0.11  0.87 -0.33 -0.65]			
[ 1.11  0.74  0.35  0.16  0.45 -0.07]		[-0.72  0.68]						
[ 1.07  0.47  0.39 -0.13  0.57  0.15]		[-0.73  0.22]						
[-0.15 -0.41  0.03 -0.38  0.12  0.29]		[ 0.05 -0.65]						
[-2.99 -0.62 -1.25  1.15 -1.97 -1.01]		[ 2.14  0.63]						
	masked S
[[ 0.82  -inf  -inf  -inf  -inf  -inf]
 [-0.23 -0.68  -inf  -inf  -inf  -inf]
 [ 1.11  0.74  0.35  -inf  -inf  -inf]
 [ 1.07  0.47  0.39 -0.13  -inf  -inf]
 [-0.15 -0.41  0.03 -0.38  0.12  -inf]
 [-2.99 -0.62 -1.25  1.15 -1.97 -1.01]]
softmax numerator P:
	P 				<=exp			 (S - m):
[ 1.00  0.00  0.00  0.00  0.00  0.00]		[ 0.00  -inf  -inf  -inf  -inf  -inf]
[ 1.00  0.64  0.00  0.00  0.00  0.00]		[ 0.00 -0.45  -inf  -inf  -inf  -inf]
[ 1.00  0.69  0.4

In [51]:
def causal_mask(I):
    mask = np.tril(np.ones_like(I), k=0)  # Lower triangular matrix with ones
    O = np.where(mask == 0, -np.inf, I)   # Apply the mask
    return O, mask

def d_causal_mask(dO, mask):
    dI = dO * mask  # Propagate gradient only through unmasked elements (mask == 1)
    return dI

I = np.random.randn(4, 4)

# Forward pass
O, mask = causal_mask(I)
print("Forward Pass Output (O):")
print(O)

# Assume some gradient coming from the next layer
dO = np.random.randn(*O.shape)

# Backward pass
dI = d_causal_mask(dO, mask)
print("\nBackward Pass Output (dI):")
print(dI)
import numpy as np




Forward Pass Output (O):
[[ 0.51934651        -inf        -inf        -inf]
 [ 0.69014399 -0.40122047        -inf        -inf]
 [ 0.0976761  -0.77300978  0.02451017        -inf]
 [ 1.45114361  0.95927083  2.15318246 -0.76734756]]

Backward Pass Output (dI):
[[ 0.87232064  0.          0.         -0.        ]
 [-0.83972184 -0.59939265 -0.         -0.        ]
 [-0.75913266  0.15039379  0.34175598  0.        ]
 [ 0.95042384 -0.57690366 -0.89841467  0.49191917]]


In [47]:
def softmax(I: np.ndarray) -> np.ndarray:
    """
    I: (n, m)
    O: (n, m)
    """
    exp_I = np.exp(I - np.max(I, axis=1, keepdims=True))  # Subtract max for numerical stability
    O = exp_I / np.sum(exp_I, axis=1, keepdims=True)
    return O

def d_softmax(dO: np.ndarray, O: np.ndarray) -> np.ndarray:
    """
    dO: (n, m)
    O: (n, m)
    dI: (n, m)
    """
    n, m = O.shape
    dI = np.zeros_like(dO)

    for i in range(n):
        # Reshape O[i] to a column vector of shape (m, 1)
        O_i = O[i].reshape(-1, 1)
        # Compute the Jacobian matrix of the softmax function for this row, shape (m, m)
        jacobian_i = np.diagflat(O_i) - np.dot(O_i, O_i.T)
        # Compute the gradient for this row, shape (m,)
        dI[i] = np.dot(jacobian_i, dO[i])
    return dI

# Example input matrix I of shape (4, 4)
I = np.random.randn(4, 4)
print("I:")
print(I)
# Forward pass
O = softmax(I)
print("O:")
print(O)

# Assume some gradient coming from the next layer, of shape (4, 4)
dO = np.random.randn(*O.shape)
print("dO:")
print(dO)

# Backward pass
dI = d_softmax(dO, O)
print("dI:")
print(dI)

I:
[[-0.02090159  0.11732738  1.2776649  -0.59157139]
 [ 0.54709738 -0.20219265 -0.2176812   1.09877685]
 [ 0.82541635  0.81350964  1.30547881  0.02100384]
 [ 0.68195297 -0.31026676  0.32416635 -0.13014305]]
O:
[[0.15680308 0.18004733 0.57453284 0.08861676]
 [0.27216028 0.12865072 0.12667346 0.47251554]
 [0.2468106  0.24388932 0.3988892  0.11041088]
 [0.39778803 0.14748103 0.27814225 0.17658869]]
dO:
[[ 0.09699596  0.59515703 -0.81822068  2.09238728]
 [-1.00601738 -1.21418861  1.15811087  0.79166269]
 [ 0.62411982  0.62834551 -0.01224677 -0.89725437]
 [ 0.07580456 -0.67716171  0.97511973 -0.14705738]]
dI:
[[ 0.04065972  0.13637962 -0.37684316  0.19980382]
 [-0.29850213 -0.16788394  0.13520368  0.3311824 ]
 [ 0.1038543   0.10365567 -0.085993   -0.12151698]
 [-0.03967315 -0.12575717  0.22239719 -0.05696687]]
