<a href="https://colab.research.google.com/github/casanovaalonso/TritonTutorials/blob/main/05_triton_Johnson_Lindenstrauss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install torch
!pip install triton


In [3]:
import torch

import triton
import triton.language as tl
from triton.runtime import driver

In [7]:
DEVICE = torch.device(f"cuda:{torch.cuda.current_device()}")

## Sparse Johnson-Lindenstrauss Transform (SJLDT)

The **Sparse Johnson-Lindenstrauss Transform (SJLDT)** reduces the dimensionality of high-dimensional data while approximately preserving pairwise distances. It multiplies the data by a sparse random matrix. Because most entries of that matrix are zero, the projection can use less memory and fewer multiply-adds than the dense **Johnson-Lindenstrauss Transform (JLT)**.

### Key Steps:

1. **Original Data**:
   Let $ X = \{x_1, x_2, \dots, x_n\} $ be a set of $ n $ data points, where each point $ x_i \in \mathbb{R}^d $.

2. **Target Dimension**:
   We want to reduce the data to a lower dimension $ k $, where $ k \ll d $.

3. **Sparse Random Matrix**:
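   Construct a sparse random matrix $ \mathbf{A} \in \mathbb{R}^{k \times d} $ to project the data into a lower-dimensional space. Each entry of $ \mathbf{A} $ is non-zero with probability $ p $, and the non-zero entries are drawn from a zero-mean distribution (e.g., Gaussian or $ \pm 1 $). In the Gaussian case:

   $$
   A_{ij} = b_{ij} \, g_{ij}, \qquad b_{ij} \sim \text{Bernoulli}(p), \quad g_{ij} \sim \mathcal{N}(0, \sigma^2)
   $$

   Choosing $ \sigma^2 = 1/(kp) $ gives $ \mathbb{E}\,\| \mathbf{A} x \|^2 = \| x \|^2 $, so squared norms are preserved in expectation.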

4. **Projection**:
   The projection of each data point $ x_i $ is given by:

   $$
   \hat{x}_i = \mathbf{A} x_i
   $$

5. **Distance Preservation**:
   The SJLDT approximately preserves the pairwise distances between points:

   $$
   \| \mathbf{A} x_i - \mathbf{A} x_j \|^2 \approx \| x_i - x_j \|^2
   $$

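   This approximation holds with high probability over the random choice of $ \mathbf{A} $ when $ k $ is sufficiently large, and the error shrinks as $ k $ grows.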

6. **Result**:
   After projection, the data $ X $ is represented in the lower-dimensional space by $ \hat{X} = \{ \hat{x}_1, \hat{x}_2, \dots, \hat{x}_n \} \subset \mathbb{R}^k $.
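
Steps 3 to 5 can be sketched on the CPU with plain NumPy. This is only an illustration under stated assumptions: the sizes, the seed, and the scaling $ \sigma^2 = 1/(kp) $ are choices made for the example, not values taken from the notebook's GPU code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumed for this sketch, smaller than the GPU example).
num_points, input_dim, output_dim, density = 50, 1000, 400, 0.2

# Step 1: n points in R^d, one per row.
X = rng.standard_normal((num_points, input_dim))

# Step 3: A_ij = b_ij * g_ij with b_ij ~ Bernoulli(density), g_ij ~ N(0, sigma^2).
# sigma^2 = 1 / (output_dim * density) keeps squared norms unchanged in expectation.
sigma = 1.0 / np.sqrt(output_dim * density)
keep = rng.random((output_dim, input_dim)) < density
A = np.where(keep, rng.standard_normal((output_dim, input_dim)) * sigma, 0.0)

# Step 4: project every point; row i of X_hat is A @ x_i.
X_hat = X @ A.T

# Step 5: compare squared pairwise distances before and after projection.
i, j = np.triu_indices(num_points, k=1)
original_sq = np.sum((X[i] - X[j]) ** 2, axis=1)
projected_sq = np.sum((X_hat[i] - X_hat[j]) ** 2, axis=1)
ratios = projected_sq / original_sq

print(f"ratio min {ratios.min():.3f}, mean {ratios.mean():.3f}, max {ratios.max():.3f}")
```

Ratios close to 1 mean the pairwise distances survived the projection. Increasing `output_dim` should tighten the spread around 1.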



In [None]:
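@triton.jit
def sparse_jldt_kernel(
    input_data_ptr,             # (num_input_elements, input_dim), row-major
    sparse_matrix_indices_ptr,  # (output_dim, input_dim) 0/1 mask, row-major
    sparse_matrix_values_ptr,   # (output_dim, input_dim) values, row-major
    output_data_ptr,            # (num_input_elements, output_dim), row-major
    num_input_elements,
    output_dim,
    input_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Each program computes one (BLOCK_M, BLOCK_N) tile of the output X @ A^T.
    rows = tl.program_id(axis=0) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(axis=1) * BLOCK_N + tl.arange(0, BLOCK_N)
    row_mask = rows < num_input_elements
    col_mask = cols < output_dim
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Walk over the input dimension in chunks of BLOCK_K.
    for k_start in range(0, input_dim, BLOCK_K):
        ks = k_start + tl.arange(0, BLOCK_K)
        k_mask = ks < input_dim
        x = tl.load(input_data_ptr + rows[:, None] * input_dim + ks[None, :],
                    mask=row_mask[:, None] & k_mask[None, :], other=0.0)
        # Load the tile of A already transposed, as (BLOCK_K, BLOCK_N).
        a_offsets = cols[None, :] * input_dim + ks[:, None]
        a_mask = col_mask[None, :] & k_mask[:, None]
        keep = tl.load(sparse_matrix_indices_ptr + a_offsets, mask=a_mask, other=0)
        values = tl.load(sparse_matrix_values_ptr + a_offsets, mask=a_mask, other=0.0)
        # Entries whose mask is 0 do not contribute to the projection.
        a_t = tl.where(keep != 0, values, 0.0)
        # tl.dot may use TF32 for float32 inputs on GPUs that support it.
        acc += tl.dot(x, a_t)
    tl.store(output_data_ptr + rows[:, None] * output_dim + cols[None, :], acc,
             mask=row_mask[:, None] & col_mask[None, :])


def sparse_jldt(input_data, sparse_matrix_indices, sparse_matrix_values, output_dim):
    num_input_elements, input_dim = input_data.shape
    assert sparse_matrix_indices.shape == (output_dim, input_dim), \
        f"Shape mismatch: expected (output_dim, input_dim) for indices, got {sparse_matrix_indices.shape}"
    assert sparse_matrix_values.shape == (output_dim, input_dim), \
        f"Shape mismatch: expected (output_dim, input_dim) for values, got {sparse_matrix_values.shape}"
    # The kernel computes offsets by hand, so it needs row-major (contiguous) tensors.
    input_data = input_data.contiguous()
    sparse_matrix_indices = sparse_matrix_indices.contiguous()
    sparse_matrix_values = sparse_matrix_values.contiguous()
    output_data = torch.empty((num_input_elements, output_dim), device=input_data.device, dtype=torch.float32)
    # 2D launch grid: one program per (row tile, output-column tile).
    grid = lambda meta: (
        triton.cdiv(num_input_elements, meta["BLOCK_M"]),
        triton.cdiv(output_dim, meta["BLOCK_N"]),
    )
    sparse_jldt_kernel[grid](
        input_data, sparse_matrix_indices, sparse_matrix_values, output_data,
        num_input_elements, output_dim, input_dim,
        BLOCK_M=32, BLOCK_N=32, BLOCK_K=32,
    )

    return output_data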
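
A dense PyTorch reference is a convenient way to sanity-check the output of a GPU kernel like the one above. The sketch below assumes that `sparse_matrix_indices` acts as a 0/1 mask over `sparse_matrix_values`, which is how the example cell constructs it. The helper name `sparse_jldt_reference` is hypothetical and not part of the notebook.

```python
import torch


def sparse_jldt_reference(input_data, sparse_matrix_indices, sparse_matrix_values):
    """Dense reference for the masked projection X @ (mask * values)^T.

    Hypothetical helper for checking results; it runs on CPU or GPU tensors.
    """
    # Zero out every value whose mask entry is 0.
    masked_values = torch.where(
        sparse_matrix_indices != 0,
        sparse_matrix_values,
        torch.zeros_like(sparse_matrix_values),
    )
    return input_data.to(torch.float32) @ masked_values.to(torch.float32).T


# Tiny CPU example: one point in R^3 projected to R^2.
x = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1, 0, 1], [0, 0, 0]], dtype=torch.int32)
values = torch.tensor([[2.0, 5.0, -1.0], [7.0, 7.0, 7.0]])
reference = sparse_jldt_reference(x, mask, values)
print(reference)
```

On the GPU, the same helper can be compared against the kernel output with `torch.testing.assert_close(projected_data, sparse_jldt_reference(input_data, sparse_matrix_indices, sparse_matrix_values), rtol=1e-2, atol=1e-2)`. The tolerances here are a loose guess that leaves room for reduced-precision matrix multiplies.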

In [None]:
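num_input_elements, input_dim, output_dim = 1024, 2048, 512  # Example dimensions
density = 0.2  # Fraction of non-zero entries in A, i.e. 80% sparsity
input_data = torch.randn(num_input_elements, input_dim, device=DEVICE)
# Gaussian values with variance 1 / (density * output_dim), so squared norms are preserved in expectation
sparse_matrix_values = torch.randn(output_dim, input_dim, device=DEVICE) / (density * output_dim) ** 0.5
# Despite the name, this is a 0/1 mask: 1 keeps the matching value, 0 drops it
sparse_matrix_indices = (torch.rand(output_dim, input_dim, device=DEVICE) < density).int()
projected_data = sparse_jldt(input_data, sparse_matrix_indices, sparse_matrix_values, output_dim)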

print(f"Original Input Data shape: {input_data.shape}")
print(f"Projected Data shape: {projected_data.shape}")

