<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/iSnkhorn_Knopp_algorithm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. PyTorch Execution (Dynamic Logic)

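In PyTorch, execution is eager and the computational graph is built dynamically (define-by-run). Each operation runs as soon as Python reaches it, so a loop executes its iterations one at a time, and autograd records exactly the operations that actually ran.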
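
The cell that follows treats the raw weights $A$ as logits and balances the strictly positive matrix $\exp(A)$ in log space. Writing $L^{(t)}$ for `log_A` after $t$ iterations, one pass of the loop in `execute_sinkhorn` is:

$$
\tilde{L}^{(t)}_{ij} = L^{(t)}_{ij} - \log \sum_{k} \exp\big(L^{(t)}_{ik}\big),
\qquad
L^{(t+1)}_{ij} = \tilde{L}^{(t)}_{ij} - \log \sum_{k} \exp\big(\tilde{L}^{(t)}_{kj}\big),
$$

with $L^{(0)} = A$ and output $P = \exp\big(L^{(T)}\big)$ taken elementwise ($T = 20$ by default). Each step only subtracts a per-row or per-column constant in log space, which rescales the rows or columns of $\exp(A)$, so $P = \operatorname{diag}(u)\,\exp(A)\,\operatorname{diag}(v)$ for some positive vectors $u, v$. This is the Sinkhorn-Knopp scaling: the columns of $P$ sum to 1 after the final column step, and the row sums converge to 1 as $T$ grows.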

```python
import torch

# 1. SETUP: Create a random "chaotic" weight matrix (simulating the raw output of a learning step)
raw_weights = torch.randn(8, 8)  # A small 8x8 mixing matrix

# 2. DEFINE THE EXECUTION: Log-space Sinkhorn-Knopp
def execute_sinkhorn(A, iterations=20):
    # Treat the raw weights as log-domain values (logits): the matrix being
    # balanced is exp(A), which is strictly positive even where A is negative.
    # Using logsumexp avoids the overflow/underflow of exponentiating first.
    log_A = A
    for _ in range(iterations):
        # Row normalization: subtract each row's log-sum-exp
        log_A = log_A - torch.logsumexp(log_A, dim=-1, keepdim=True)
        # Column normalization: subtract each column's log-sum-exp
        log_A = log_A - torch.logsumexp(log_A, dim=-2, keepdim=True)
    return torch.exp(log_A)

# 3. RUN: Apply the Sinkhorn-Knopp iteration (Sinkhorn & Knopp, 1967)
stable_matrix = execute_sinkhorn(raw_weights)

# 4. VERIFY: The result is (approximately) doubly stochastic
print(f"Row sums: {stable_matrix.sum(dim=-1)}")  # ≈ 1.0; the error shrinks as iterations grow
print(f"Col sums: {stable_matrix.sum(dim=-2)}")  # 1.0 up to float rounding (the last step normalizes columns)
```

```text
Row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
Col sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
```
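
To make the dynamic-graph point concrete, here is a minimal sketch (not part of the original notebook) in which the number of Sinkhorn iterations is decided at runtime by a tolerance check on the current tensor values. In eager mode this is just an ordinary Python `while` loop, and autograd backpropagates through however many iterations actually ran. The function name `sinkhorn_until_converged`, the `tol`/`max_iter` parameters, and the toy `cost` loss are illustrative assumptions.

```python
import torch


def sinkhorn_until_converged(log_A, tol=1e-5, max_iter=1000):
    """Log-domain Sinkhorn-Knopp that stops on a runtime tolerance check (illustrative sketch)."""
    n_iter = 0
    while n_iter < max_iter:
        # Same row-then-column update as execute_sinkhorn
        log_A = log_A - torch.logsumexp(log_A, dim=-1, keepdim=True)
        log_A = log_A - torch.logsumexp(log_A, dim=-2, keepdim=True)
        n_iter += 1
        # Columns are exact after the column step, so only row sums need checking.
        # .item() returns a Python float, so this branch is decided eagerly at runtime.
        with torch.no_grad():
            row_err = (log_A.exp().sum(dim=-1) - 1).abs().max().item()
        if row_err < tol:
            break
    return torch.exp(log_A), n_iter


torch.manual_seed(0)
raw_weights = torch.randn(8, 8, requires_grad=True)
P, n_iter = sinkhorn_until_converged(raw_weights)

# Toy loss: autograd backpropagates through exactly the n_iter iterations that ran
cost = torch.arange(64.0).reshape(8, 8)
loss = (P * cost).sum()
loss.backward()
print(f"stopped after {n_iter} iterations; grad shape: {tuple(raw_weights.grad.shape)}")
```

Because the stopping rule reads a tensor value with `.item()`, it forces a device-to-host synchronization on every iteration when running on a GPU; that is the cost of data-dependent control flow in eager mode.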


## 2. JAX Execution (Compiled & Parallel)

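```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# 1. SETUP: JAX uses explicit PRNG keys instead of a global random state
key = jax.random.PRNGKey(2026)
raw_weights = jax.random.normal(key, (8, 8))

# 2. DEFINE THE BODY: one Sinkhorn iteration (row, then column normalization)
def sinkhorn_step(x, _):
    x = x - logsumexp(x, axis=-1, keepdims=True)
    x = x - logsumexp(x, axis=-2, keepdims=True)
    return x, None  # (new carry, no per-step output)

# 3. COMPILE: jit traces the function and compiles it with XLA for the available
#    backend (CPU, GPU, or TPU). n_iter is static because lax.scan needs a
#    concrete Python int for its length.
@partial(jax.jit, static_argnames=("n_iter",))
def fast_sinkhorn_execution(A, n_iter=20):
    # lax.scan compiles the loop body once; a Python for-loop under jit
    # would be unrolled n_iter times, inflating compile time
    final_log_A, _ = jax.lax.scan(sinkhorn_step, A, None, length=n_iter)
    return jnp.exp(final_log_A)

# 4. EXECUTE: the first call triggers compilation; later calls with the same
#    input shape, dtype, and n_iter reuse the compiled program
stable_matrix = fast_sinkhorn_execution(raw_weights)
print(f"Row sums: {stable_matrix.sum(axis=-1)}")
```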

```text
Row sums: [1.         0.99999994 1.         1.         1.         1.
 0.99999994 0.99999994]
```
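
The "Parallel" in the heading can be made explicit with `jax.vmap`, which maps the Sinkhorn iteration over a batch of matrices inside one compiled program, with no Python loop over the batch. A minimal sketch, assuming a batch of 16 independent 8x8 logit matrices; the names `sinkhorn_single` and `batched_sinkhorn` are illustrative, and `n_iter` is marked static so that passing a different iteration count triggers a recompile rather than a tracing error.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_step(x, _):
    # One log-domain iteration: normalize rows, then columns
    x = x - logsumexp(x, axis=-1, keepdims=True)
    x = x - logsumexp(x, axis=-2, keepdims=True)
    return x, None


def sinkhorn_single(log_A, n_iter):
    # n_iter must be a concrete Python int because lax.scan needs a static length
    final_log_A, _ = jax.lax.scan(sinkhorn_step, log_A, None, length=n_iter)
    return jnp.exp(final_log_A)


@partial(jax.jit, static_argnames=("n_iter",))
def batched_sinkhorn(batch, n_iter=20):
    # vmap maps over the leading (batch) axis; inside, each 8x8 matrix is balanced independently
    return jax.vmap(lambda m: sinkhorn_single(m, n_iter))(batch)


key = jax.random.PRNGKey(0)
batch = jax.random.normal(key, (16, 8, 8))  # 16 independent 8x8 logit matrices
P = batched_sinkhorn(batch, n_iter=50)
print(P.shape)
print(jnp.abs(P.sum(axis=-1) - 1).max(), jnp.abs(P.sum(axis=-2) - 1).max())
```

Changing `n_iter` or the batch shape triggers one new compilation; repeated calls with the same values reuse the cached program.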

