In [5]:
# Install JAX into the environment the kernel runs in, via uv
# (the recorded output below shows uv resolving into the project's .venv).
# NOTE(review): unpinned — pin a version (e.g. `jax==<version>`) so re-runs
# are reproducible, and consider moving installs out of the notebook.
!uv pip install jax

[2mUsing Python 3.12.11 environment at: /Users/andrew.lehe/Documents/roc/foundation/.venv[0m
[2K[2mResolved [1m6 packages[0m [2min 152ms[0m[0m                                         [0m
[2K[37m⠙[0m [2mPreparing packages...[0m (0/3)                                                   
[2K[1A[37m⠙[0m [2mPreparing packages...[0m (0/3)--------------[0m[0m     0 B/661.06 KiB          [1A
[2K[1A[37m⠙[0m [2mPreparing packages...[0m (0/3)--------------[0m[0m     0 B/661.06 KiB          [1A
[2mml-dtypes           [0m [32m[2m------------------------------[0m[0m     0 B/661.06 KiB
[2K[2A[37m⠙[0m [2mPreparing packages...[0m (0/3)--------------[0m[0m     0 B/53.15 MiB           [2A
[2mml-dtypes           [0m [32m[2m------------------------------[0m[0m     0 B/661.06 KiB
[2K[2A[37m⠙[0m [2mPreparing packages...[0m (0/3)--------------[0m[0m     0 B/53.15 MiB           [2A
[2mml-dtypes           [0m [32m[2m------------------------------

In [6]:
import jax.numpy as jnp

# Coefficient matrix and right-hand side of the 3x3 system A @ x = b.
A = jnp.array(
    [
        [1.0, 2.0, 3.0],
        [2.0, 4.0, 2.0],
        [3.0, 2.0, 1.0],
    ]
)
b = jnp.array([14.0, 16.0, 10.0])

# Solve the system directly (no explicit inverse) and display the solution.
x = jnp.linalg.solve(A, b)
x

Array([1., 2., 3.], dtype=float32)

In [13]:
# Elementwise multiplication is commutative, so with b (shape (3,)) broadcast
# along the last axis of A (3x3) from the solve cell, every entry compares equal.
jnp.equal(b * A, A * b)

Array([[ True,  True,  True],
       [ True,  True,  True],
       [ True,  True,  True]], dtype=bool)

In [14]:
import jax
import jax.numpy as jnp
from jax import random

print("=" * 60)
print("LINEAR ALGEBRA BASES - JAX IMPLEMENTATION")
print("=" * 60)

# 1. Standard Basis for R^3
print("\n1. STANDARD BASIS FOR R³")
print("-" * 40)
e1 = jnp.array([1.0, 0.0, 0.0])
e2 = jnp.array([0.0, 1.0, 0.0])
e3 = jnp.array([0.0, 0.0, 1.0])

# Rows of standard_basis are e₁, e₂, e₃, i.e. the 3x3 identity matrix.
standard_basis = jnp.stack([e1, e2, e3])
print("Standard basis vectors:")
print(f"e₁ = {e1}")
print(f"e₂ = {e2}")
print(f"e₃ = {e3}")

# Express a vector in standard basis: its coordinates are simply its entries.
v = jnp.array([3.0, -2.0, 5.0])
print(f"\nVector v = {v}")
print(f"v = {v[0]}·e₁ + {v[1]}·e₂ + {v[2]}·e₃")

# 2. Change of Basis
print("\n\n2. CHANGE OF BASIS")
print("-" * 40)

# Define an alternative basis for R³
b1 = jnp.array([1.0, 1.0, 0.0])
b2 = jnp.array([0.0, 1.0, 1.0])
b3 = jnp.array([1.0, 0.0, 1.0])

# Matrix where columns are basis vectors
B = jnp.column_stack([b1, b2, b3])
print("Alternative basis B:")
print(f"b₁ = {b1}")
print(f"b₂ = {b2}")
print(f"b₃ = {b3}")
print(f"\nBasis matrix B:\n{B}")

# Convert vector from standard basis to new basis.
# Mathematically [v]_B = B^(-1) @ [v]_standard, i.e. the coordinates c solve
# B @ c = v. Solve that system directly (as the earlier solve cell does)
# rather than forming inv(B): solving is cheaper and numerically more accurate.
v_in_new_basis = jnp.linalg.solve(B, v)

print(f"\nVector v in standard basis: {v}")
print(f"Vector v in basis B: {v_in_new_basis}")

# Verify: convert back to standard basis
v_reconstructed = B @ v_in_new_basis
print(f"Reconstructed v: {v_reconstructed}")
print(f"Error: {jnp.linalg.norm(v - v_reconstructed):.2e}")

# 3. Checking Linear Independence
print("\n\n3. CHECKING LINEAR INDEPENDENCE")
print("-" * 40)

def is_linearly_independent(vectors):
    """Check whether a collection of vectors is linearly independent.

    The vectors are stacked as the columns of a matrix; they are independent
    exactly when that matrix has full column rank. The rank is computed
    numerically (with a tolerance), so nearly-dependent vectors may be
    reported as dependent.

    Args:
        vectors: Iterable of 1-D arrays of equal length.

    Returns:
        A Python ``bool``. An empty collection is (vacuously) independent.
    """
    vectors = list(vectors)
    if not vectors:
        # jnp.column_stack([]) raises; the empty set is independent by definition.
        return True
    M = jnp.column_stack(vectors)
    # int() turns the 0-d JAX array into a Python int so we return a real bool.
    rank = int(jnp.linalg.matrix_rank(M))
    return rank == len(vectors)

# Independent case: the alternative basis b₁, b₂, b₃ from section 2.
indep_vecs = [b1, b2, b3]
print("Vectors: b₁, b₂, b₃")
print(f"Linearly independent: {is_linearly_independent(indep_vecs)}")

# Dependent case: the second vector is exactly twice the first,
# so no set containing both can be independent.
dep_vecs = [
    jnp.array([1.0, 2.0, 3.0]),
    jnp.array([2.0, 4.0, 6.0]),
    jnp.array([0.0, 1.0, 0.0]),
]
print("\nVectors: [1,2,3], [2,4,6], [0,1,0]")
print(f"Linearly independent: {is_linearly_independent(dep_vecs)}")

# 4. Gram-Schmidt Orthogonalization: section banner and subtitle.
print(
    "\n\n4. GRAM-SCHMIDT ORTHOGONALIZATION",
    "-" * 40,
    "Creating orthonormal basis from arbitrary basis",
    sep="\n",
)

def gram_schmidt(vectors, tol=1e-6):
    """Apply (modified) Gram-Schmidt to get an orthonormal basis.

    Each vector is orthogonalized against the already-accepted orthonormal
    vectors one projection at a time (the running residual is updated inside
    the inner loop), then normalized.

    Args:
        vectors: Sequence of 1-D arrays of equal length, expected to be
            linearly independent.
        tol: Relative tolerance. If a vector's residual, after removing its
            projections, has norm <= ``tol`` times the vector's own norm, the
            input is treated as linearly dependent.

    Returns:
        List of orthonormal 1-D arrays, one per input vector and in the same
        order, spanning the same space as ``vectors``. Empty input gives ``[]``.

    Raises:
        ValueError: If a vector is (numerically) a linear combination of the
            preceding ones, including the zero vector, rather than silently
            dividing by ~0 and producing NaN/inf entries.
    """
    ortho = []
    for idx, vec in enumerate(vectors):
        # Subtract projections onto previous orthonormal vectors
        residual = vec
        for u in ortho:
            residual = residual - jnp.dot(residual, u) * u

        residual_norm = float(jnp.linalg.norm(residual))
        if residual_norm <= tol * float(jnp.linalg.norm(vec)):
            raise ValueError(
                f"vector {idx} is linearly dependent on the preceding vectors; "
                "cannot build an orthonormal basis"
            )

        # Normalize
        ortho.append(residual / residual_norm)

    return ortho

# Apply to our alternative basis
ortho_basis = gram_schmidt([b1, b2, b3])

# Loop variables are deliberately not named `b`: that would overwrite the
# module-level `b` (right-hand side from the solve cell) as a hidden side effect.
print("\nOriginal basis (not orthonormal):")
for idx, basis_vec in enumerate([b1, b2, b3], 1):
    print(f"b{idx} = {basis_vec}, norm = {jnp.linalg.norm(basis_vec):.3f}")

print("\nOrthonormal basis after Gram-Schmidt:")
for idx, u_vec in enumerate(ortho_basis, 1):
    print(f"u{idx} = {u_vec}, norm = {jnp.linalg.norm(u_vec):.3f}")

# Check orthogonality of every distinct pair (i < j), however many vectors there are.
print("\nDot products (should be ~0 for i≠j):")
n_ortho = len(ortho_basis)
for i in range(n_ortho):
    for j in range(i + 1, n_ortho):
        dot = jnp.dot(ortho_basis[i], ortho_basis[j])
        print(f"u{i+1} · u{j+1} = {dot:.2e}")

# 5. Coordinate Transformation Example
print("\n\n5. PRACTICAL EXAMPLE: ROTATION")
print("-" * 40)

# Rotation matrix: 45 degrees counter-clockwise about the z-axis.
theta = jnp.pi / 4
cos_t, sin_t = jnp.cos(theta), jnp.sin(theta)
rotation_basis = jnp.array([
    [cos_t, -sin_t, 0],
    [sin_t, cos_t, 0],
    [0, 0, 1]
])

print(f"45° rotation around z-axis basis:")
print(rotation_basis)

# A point in standard coordinates
point = jnp.array([1.0, 0.0, 0.0])
print(f"\nOriginal point: {point}")

# Apply the rotation to the point (an active transform): the result is the
# rotated point, still expressed in the standard basis. Expressing the
# original point's coordinates in the rotated frame would instead be
# rotation_basis.T @ point.
rotated_point = rotation_basis @ point
print(f"After rotation: {rotated_point}")
print(f"Expected: [{cos_t:.3f}, {sin_t:.3f}, 0]")

# Sanity check: rotating e₁ by θ about z gives (cos θ, sin θ, 0).
assert jnp.allclose(rotated_point, jnp.array([cos_t, sin_t, 0.0]), atol=1e-5)

# 6. Finding Basis for Column Space
print("\n\n6. FINDING BASIS FOR COLUMN SPACE")
print("-" * 40)

# Create a matrix with dependent columns (row 2 is 2x row 1, so the rank is 2).
# Named A_dep rather than A so this cell does not overwrite the A used by the
# earlier linear-solve cells if those are re-run afterwards.
A_dep = jnp.array([
    [1.0, 2.0, 3.0, 4.0],
    [2.0, 4.0, 6.0, 8.0],
    [0.0, 1.0, 2.0, 3.0]
])

print("Matrix A:")
print(A_dep)
rank = int(jnp.linalg.matrix_rank(A_dep))
print(f"\nRank of A: {rank}")

# The first `rank` left singular vectors (columns of U, ordered by descending
# singular value) always form an orthonormal basis for the column space.
# Taking the first `rank` columns of Q from an unpivoted QR is only correct
# when the first `rank` columns of A happen to be linearly independent.
U, S, Vt = jnp.linalg.svd(A_dep, full_matrices=False)
col_space_basis = U[:, :rank]

print(f"\nOrthonormal basis for column space (first {rank} left singular vectors of A):")
print(col_space_basis)

print("\n" + "=" * 60)
print("DEMONSTRATION COMPLETE")
print("=" * 60)

LINEAR ALGEBRA BASES - JAX IMPLEMENTATION

1. STANDARD BASIS FOR R³
----------------------------------------
Standard basis vectors:
e₁ = [1. 0. 0.]
e₂ = [0. 1. 0.]
e₃ = [0. 0. 1.]

Vector v = [ 3. -2.  5.]
v = 3.0·e₁ + -2.0·e₂ + 5.0·e₃


2. CHANGE OF BASIS
----------------------------------------
Alternative basis B:
b₁ = [1. 1. 0.]
b₂ = [0. 1. 1.]
b₃ = [1. 0. 1.]

Basis matrix B:
[[1. 0. 1.]
 [1. 1. 0.]
 [0. 1. 1.]]

Vector v in standard basis: [ 3. -2.  5.]
Vector v in basis B: [-2.  0.  5.]
Reconstructed v: [ 3. -2.  5.]
Error: 0.00e+00


3. CHECKING LINEAR INDEPENDENCE
----------------------------------------
Vectors: b₁, b₂, b₃
Linearly independent: True

Vectors: [1,2,3], [2,4,6], [0,1,0]
Linearly independent: False


4. GRAM-SCHMIDT ORTHOGONALIZATION
----------------------------------------
Creating orthonormal basis from arbitrary basis

Original basis (not orthonormal):
b1 = [1. 1. 0.], norm = 1.414
b2 = [0. 1. 1.], norm = 1.414
b3 = [1. 0. 1.], norm = 1.414

Orthonormal basi

In [None]:
import jax
import jax.numpy as jnp
from jax import random

# NOTE(review): this cell duplicates the previous cell (In [14]) and
# re-defines is_linearly_independent and gram_schmidt; keep only one copy.
print("=" * 60)
print("LINEAR ALGEBRA BASES - JAX IMPLEMENTATION")
print("=" * 60)

# 1. Standard Basis for R^3
print("\n1. STANDARD BASIS FOR R³")
print("-" * 40)
e1 = jnp.array([1.0, 0.0, 0.0])
e2 = jnp.array([0.0, 1.0, 0.0])
e3 = jnp.array([0.0, 0.0, 1.0])

# Rows of standard_basis are e₁, e₂, e₃, i.e. the 3x3 identity matrix.
standard_basis = jnp.stack([e1, e2, e3])
print("Standard basis vectors:")
print(f"e₁ = {e1}")
print(f"e₂ = {e2}")
print(f"e₃ = {e3}")

# Express a vector in standard basis: its coordinates are simply its entries.
v = jnp.array([3.0, -2.0, 5.0])
print(f"\nVector v = {v}")
print(f"v = {v[0]}·e₁ + {v[1]}·e₂ + {v[2]}·e₃")

# 2. Change of Basis
print("\n\n2. CHANGE OF BASIS")
print("-" * 40)

# Define an alternative basis for R³
b1 = jnp.array([1.0, 1.0, 0.0])
b2 = jnp.array([0.0, 1.0, 1.0])
b3 = jnp.array([1.0, 0.0, 1.0])

# Matrix where columns are basis vectors
B = jnp.column_stack([b1, b2, b3])
print("Alternative basis B:")
print(f"b₁ = {b1}")
print(f"b₂ = {b2}")
print(f"b₃ = {b3}")
print(f"\nBasis matrix B:\n{B}")

# Convert vector from standard basis to new basis.
# Mathematically [v]_B = B^(-1) @ [v]_standard, i.e. the coordinates c solve
# B @ c = v. Solve that system directly (as the earlier solve cell does)
# rather than forming inv(B): solving is cheaper and numerically more accurate.
v_in_new_basis = jnp.linalg.solve(B, v)

print(f"\nVector v in standard basis: {v}")
print(f"Vector v in basis B: {v_in_new_basis}")

# Verify: convert back to standard basis
v_reconstructed = B @ v_in_new_basis
print(f"Reconstructed v: {v_reconstructed}")
print(f"Error: {jnp.linalg.norm(v - v_reconstructed):.2e}")

# 3. Checking Linear Independence
print("\n\n3. CHECKING LINEAR INDEPENDENCE")
print("-" * 40)

def is_linearly_independent(vectors):
    """Check whether a collection of vectors is linearly independent.

    The vectors are stacked as the columns of a matrix; they are independent
    exactly when that matrix has full column rank. The rank is computed
    numerically (with a tolerance), so nearly-dependent vectors may be
    reported as dependent.

    Args:
        vectors: Iterable of 1-D arrays of equal length.

    Returns:
        A Python ``bool``. An empty collection is (vacuously) independent.
    """
    vectors = list(vectors)
    if not vectors:
        # jnp.column_stack([]) raises; the empty set is independent by definition.
        return True
    M = jnp.column_stack(vectors)
    # int() turns the 0-d JAX array into a Python int so we return a real bool.
    rank = int(jnp.linalg.matrix_rank(M))
    return rank == len(vectors)

# Independent case: the alternative basis b₁, b₂, b₃ from section 2.
indep_vecs = [b1, b2, b3]
print("Vectors: b₁, b₂, b₃")
print(f"Linearly independent: {is_linearly_independent(indep_vecs)}")

# Dependent case: the second vector is exactly twice the first,
# so no set containing both can be independent.
dep_vecs = [
    jnp.array([1.0, 2.0, 3.0]),
    jnp.array([2.0, 4.0, 6.0]),
    jnp.array([0.0, 1.0, 0.0]),
]
print("\nVectors: [1,2,3], [2,4,6], [0,1,0]")
print(f"Linearly independent: {is_linearly_independent(dep_vecs)}")

# 4. Gram-Schmidt Orthogonalization: section banner and subtitle.
print(
    "\n\n4. GRAM-SCHMIDT ORTHOGONALIZATION",
    "-" * 40,
    "Creating orthonormal basis from arbitrary basis",
    sep="\n",
)

def gram_schmidt(vectors, tol=1e-6):
    """Apply (modified) Gram-Schmidt to get an orthonormal basis.

    Each vector is orthogonalized against the already-accepted orthonormal
    vectors one projection at a time (the running residual is updated inside
    the inner loop), then normalized.

    Args:
        vectors: Sequence of 1-D arrays of equal length, expected to be
            linearly independent.
        tol: Relative tolerance. If a vector's residual, after removing its
            projections, has norm <= ``tol`` times the vector's own norm, the
            input is treated as linearly dependent.

    Returns:
        List of orthonormal 1-D arrays, one per input vector and in the same
        order, spanning the same space as ``vectors``. Empty input gives ``[]``.

    Raises:
        ValueError: If a vector is (numerically) a linear combination of the
            preceding ones, including the zero vector, rather than silently
            dividing by ~0 and producing NaN/inf entries.
    """
    ortho = []
    for idx, vec in enumerate(vectors):
        # Subtract projections onto previous orthonormal vectors
        residual = vec
        for u in ortho:
            residual = residual - jnp.dot(residual, u) * u

        residual_norm = float(jnp.linalg.norm(residual))
        if residual_norm <= tol * float(jnp.linalg.norm(vec)):
            raise ValueError(
                f"vector {idx} is linearly dependent on the preceding vectors; "
                "cannot build an orthonormal basis"
            )

        # Normalize
        ortho.append(residual / residual_norm)

    return ortho

# Apply to our alternative basis
ortho_basis = gram_schmidt([b1, b2, b3])

# Loop variables are deliberately not named `b`: that would overwrite the
# module-level `b` (right-hand side from the solve cell) as a hidden side effect.
print("\nOriginal basis (not orthonormal):")
for idx, basis_vec in enumerate([b1, b2, b3], 1):
    print(f"b{idx} = {basis_vec}, norm = {jnp.linalg.norm(basis_vec):.3f}")

print("\nOrthonormal basis after Gram-Schmidt:")
for idx, u_vec in enumerate(ortho_basis, 1):
    print(f"u{idx} = {u_vec}, norm = {jnp.linalg.norm(u_vec):.3f}")

# Check orthogonality of every distinct pair (i < j), however many vectors there are.
print("\nDot products (should be ~0 for i≠j):")
n_ortho = len(ortho_basis)
for i in range(n_ortho):
    for j in range(i + 1, n_ortho):
        dot = jnp.dot(ortho_basis[i], ortho_basis[j])
        print(f"u{i+1} · u{j+1} = {dot:.2e}")

# 5. Coordinate Transformation Example
print("\n\n5. PRACTICAL EXAMPLE: ROTATION")
print("-" * 40)

# Rotation matrix: 45 degrees counter-clockwise about the z-axis.
theta = jnp.pi / 4
cos_t, sin_t = jnp.cos(theta), jnp.sin(theta)
rotation_basis = jnp.array([
    [cos_t, -sin_t, 0],
    [sin_t, cos_t, 0],
    [0, 0, 1]
])

print(f"45° rotation around z-axis basis:")
print(rotation_basis)

# A point in standard coordinates
point = jnp.array([1.0, 0.0, 0.0])
print(f"\nOriginal point: {point}")

# Apply the rotation to the point (an active transform): the result is the
# rotated point, still expressed in the standard basis. Expressing the
# original point's coordinates in the rotated frame would instead be
# rotation_basis.T @ point.
rotated_point = rotation_basis @ point
print(f"After rotation: {rotated_point}")
print(f"Expected: [{cos_t:.3f}, {sin_t:.3f}, 0]")

# Sanity check: rotating e₁ by θ about z gives (cos θ, sin θ, 0).
assert jnp.allclose(rotated_point, jnp.array([cos_t, sin_t, 0.0]), atol=1e-5)

# 6. Finding Basis for Column Space
print("\n\n6. FINDING BASIS FOR COLUMN SPACE")
print("-" * 40)

# Create a matrix with dependent columns (row 2 is 2x row 1, so the rank is 2).
# Named A_dep rather than A so this cell does not overwrite the A used by the
# earlier linear-solve cells if those are re-run afterwards.
A_dep = jnp.array([
    [1.0, 2.0, 3.0, 4.0],
    [2.0, 4.0, 6.0, 8.0],
    [0.0, 1.0, 2.0, 3.0]
])

print("Matrix A:")
print(A_dep)
rank = int(jnp.linalg.matrix_rank(A_dep))
print(f"\nRank of A: {rank}")

# The first `rank` left singular vectors (columns of U, ordered by descending
# singular value) always form an orthonormal basis for the column space.
# Taking the first `rank` columns of Q from an unpivoted QR is only correct
# when the first `rank` columns of A happen to be linearly independent.
U, S, Vt = jnp.linalg.svd(A_dep, full_matrices=False)
col_space_basis = U[:, :rank]

print(f"\nOrthonormal basis for column space (first {rank} left singular vectors of A):")
print(col_space_basis)

print("\n" + "=" * 60)
print("DEMONSTRATION COMPLETE")
print("=" * 60)