In [1]:
import argparse
import jax
import jax.numpy as jnp
from jax.experimental import pallas

In [2]:
# Configure JAX to use CPU
jax.config.update('jax_platform_name', 'cpu')

In [5]:
# Create small tensors for testing
BATCH = 1
HEADS = 2
SEQ_LEN = 8
HEAD_DIM = 4

print("Creating random tensors")
K = jax.random.normal(jax.random.key(0), (BATCH, HEADS, SEQ_LEN, HEAD_DIM))
V = jax.random.normal(jax.random.key(1), (BATCH, HEADS, SEQ_LEN, HEAD_DIM))
Q = jax.random.normal(jax.random.key(2), (BATCH, HEADS, SEQ_LEN, HEAD_DIM))
print("Random tensors created")

Creating random tensors
Random tensors created


In [6]:
K[0][0][:]

Array([[ 1.6226422 ,  2.0252647 , -0.43359444, -0.07861735],
       [ 0.1760909 , -0.97208923, -0.49529874,  0.4943786 ],
       [ 0.6643493 , -0.9501635 ,  2.1795304 , -1.9551506 ],
       [ 0.35857072,  0.15779513,  1.2770847 ,  1.5104648 ],
       [ 0.970656  ,  0.59960806,  0.0247007 , -1.9164772 ],
       [-1.8593491 ,  1.728144  ,  0.04719035,  0.814128  ],
       [ 0.13132767,  0.28284705,  1.2435943 ,  0.6902801 ],
       [-0.80073744, -0.74099   , -1.5388287 ,  0.30269185]],      dtype=float32)

In [7]:
Q[0][0][:]

Array([[ 0.36057416,  1.2849895 , -0.73873436,  1.1830745 ],
       [-0.20641916, -0.8333566 ,  0.6233476 , -0.88721675],
       [ 0.6922108 ,  0.8711505 ,  1.4978964 , -1.32336   ],
       [ 1.6417218 , -0.46518597,  0.4113348 ,  0.8994859 ],
       [-0.39429244,  0.3383484 , -0.88536334, -1.8190236 ],
       [-0.38549703, -0.1720128 , -0.7236013 , -1.852459  ],
       [-0.14248204, -0.23113886, -2.6718314 , -0.07999262],
       [-0.9395458 , -0.9943608 ,  2.2809246 ,  0.91642076]],      dtype=float32)

In [8]:
V[0][0][:]

Array([[-0.15443718,  0.08470728, -0.13598049, -0.15503626],
       [ 1.2666674 ,  0.14829758,  2.1415603 ,  1.0026742 ],
       [-0.29033586,  0.3583448 , -0.70792735, -0.24555527],
       [ 0.8855825 ,  0.7861191 ,  0.88892716,  0.54932535],
       [ 0.9658084 , -1.797185  ,  0.86045414, -1.6176059 ],
       [-0.05994489, -1.5054733 , -1.0007281 ,  0.64356536],
       [-1.6915705 ,  0.5178627 ,  0.7885133 , -1.1122868 ],
       [-0.40614718, -1.1358441 ,  1.1137921 , -0.3322445 ]],      dtype=float32)

https://www.youtube.com/watch?v=IvgV6QcsC64

![./einsum_overview.png](einsum_overview.png)
![./einsum_element_wise_multiplication.png](einsum_element_wise_multiplication.png)
![./einsum_rule_1.png](einsum_rule_1.png)
![./einsum_rule_3-4.png](einsum_rule_3-4.png)
![./einsum_summation_along_axis.png](einsum_summation_along_axis.png)


In [9]:
# Matrix a
# create a 3x3 matrix
a = [[1,2,3],
     [4,5,6],
     [7,8,9]]

# Matrix b
b = [[1,2,3],
     [4,5,6],
     [7,8,9]]

print("Method 1: Using nested for loops")
# Initialize result matrix with zeros
result_loops = [[0 for x in range(len(b[0]))] for y in range(len(a))]

# Iterate through rows of a
for i in range(len(a)):
    # Iterate through columns of b
    for j in range(len(b[0])):
        # Iterate through rows of b
        for k in range(len(b)):
            result_loops[i][j] += a[i][k] * b[k][j]

print("Result using for loops:")
for row in result_loops:
    print(row)

print("\nMethod 2: Using numpy einsum")
import numpy as np
a_np = np.array(a)
b_np = np.array(b)
result_einsum = np.einsum('ik,kj->ij', a_np, b_np)

print("Result using einsum:")
print(result_einsum)


Method 1: Using nested for loops
Result using for loops:
[30, 36, 42]
[66, 81, 96]
[102, 126, 150]

Method 2: Using numpy einsum
Result using einsum:
[[ 30  36  42]
 [ 66  81  96]
 [102 126 150]]


In [32]:
def attention_ourselves(_Q, _K, _V):
    print("Computing attention weights")
    _weights_unnormalized = jax.numpy.einsum("bshd,bthd -> bhst", _Q, _K)
    _weights = jax.nn.softmax(_weights_unnormalized)
    output = jax.numpy.einsum("bhst,bshd -> bshd", _weights, _V)
    print("Weights computed successfully")
    breakpoint()
    return output


In [33]:
print("Calling attention function")
result = attention_ourselves(Q, K, V)
att_value = pallas.attention(Q, K, V, segment_ids=None)
print("Function completed")



Calling attention function
Computing attention weights
Weights computed successfully
Function completed


In [4]:
if __name__ == "__main__":
    print("Starting standalone_test.py")
    args = parse_args()
    print(f"Parsed arguments: {args}")

    if args.l5_p1:
        print("About to run lesson5_concept_1")
        lesson5_concept_1()
        print("Finished running lesson5_concept_1")
    else:
        print("No flag provided, running default behavior")

    print("Script completed")

Starting standalone_test.py


usage: ipykernel_launcher.py [-h] [--l5-p1]
ipykernel_launcher.py: error: unrecognized arguments: -f /home/ext_isaackkaranja_google_com/.local/share/jupyter/runtime/kernel-3b825520-81fa-4cda-85af-cea12c5b5658.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
