## Fast GMM GPU Evaluation

In [47]:
import os
import time
import math

import torch
from torch import jit

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

To run this notebook make sure to run a 2.x Pytorch version.

In [48]:
print(torch.__version__)

2.6.0+cu126


In [49]:
@jit.script
def jit_log_prob(
    x: torch.Tensor,
    means: torch.Tensor,
    precisions_cholesky: torch.Tensor,
) -> torch.Tensor:

    log_prob = x.new_empty((x.size(0), means.size(0)))
    for k, (mu, prec_chol) in enumerate(zip(means, precisions_cholesky)):
        inner = x.matmul(prec_chol) - mu.matmul(prec_chol)
        log_prob[:, k] = inner.square().sum(1)

    num_features = x.size(1)
    logdet = precisions_cholesky.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    constant = math.log(2 * math.pi) * num_features
    return logdet - 0.5 * (constant + log_prob)

In [50]:
# create random features
feature = torch.randn(640000, 32)

In [51]:
# load real gmm
gmm = torch.load('train_gmm_scale_3.pt', weights_only=False)
loc = gmm.loc
covariance_matrix = gmm.covariance_matrix

In [52]:
# prepare means and precisions_cholesky
target = torch.eye(32, dtype=covariance_matrix.dtype, device=covariance_matrix.device)
target = target.unsqueeze(0).expand(17, -1, -1)
cholesky_decomp = torch.linalg.cholesky(covariance_matrix)

# two ways to compute precisions_cholesky, torch.linalg.solve_triangular is not supported in older pytorch versions
precisions_cholesky = torch.linalg.solve_triangular(cholesky_decomp, target, upper=False).transpose(-2, -1)
precisions_cholesky2 = torch.inverse(cholesky_decomp).matmul(target.transpose(-2, -1)).transpose(-2, -1)

In [53]:
torch.abs(precisions_cholesky2 - precisions_cholesky).max().item()

2.9206275939941406e-06

### GPU GMM

In [54]:
feature = feature.cuda()
loc = loc.cuda()
precisions_cholesky = precisions_cholesky.cuda()

# Warm-up iterations
for _ in range(10):
    _ = jit_log_prob(feature, loc, precisions_cholesky)

# Benchmarking
times = []
for _ in range(100):
    start_time = time.time()
    _ = jit_log_prob(feature, loc, precisions_cholesky)
    times.append(time.time() - start_time)

# Calculate and print average time and standard deviation
avg_time = (sum(times) / len(times)) * 1000  # Convert to milliseconds
std_time = ((sum((t - avg_time / 1000) ** 2 for t in times) / len(times)) ** 0.5) * 1000  # Convert to milliseconds
print(f'Average Time: {avg_time:.6f} ms')
print(f'Standard Deviation: {std_time:.6f} ms')

Average Time: 9.487302 ms
Standard Deviation: 1.224776 ms


### CPU GMM

In [55]:
# Warm-up iterations
feature = feature.cpu()
for _ in range(10):
    _ = gmm.log_prob(feature[:, None, :])

# Benchmarking
times = []
for _ in range(100):
    start_time = time.time()
    _ = gmm.log_prob(feature[:, None, :])
    times.append(time.time() - start_time)

# Calculate and print average time and standard deviation
avg_time = (sum(times) / len(times)) * 1000  # Convert to milliseconds
std_time = ((sum((t - avg_time / 1000) ** 2 for t in times) / len(times)) ** 0.5) * 1000  # Convert to milliseconds
print(f'Average Time: {avg_time:.6f} ms')
print(f'Standard Deviation: {std_time:.6f} ms')

Average Time: 938.410990 ms
Standard Deviation: 175.627048 ms


### Numerical Difference

In [56]:
# Compute log probabilities using gmm.log_prob
gmm_log_prob = gmm.log_prob(feature[:, None, :].cpu())

# Compute log probabilities using jit_log_prob
jit_log_prob_output = jit_log_prob(feature.cuda(), loc, precisions_cholesky).cpu()

# Calculate the maximum absolute difference
max_diff = torch.abs(gmm_log_prob - jit_log_prob_output).max().item()
print(f'Maximum Absolute Difference: {max_diff}')

Maximum Absolute Difference: 0.0025634765625
