
Optimized memory usage and speed for covar type "full" #23

Open · wants to merge 31 commits into master
Conversation

DeMoriarty

Improved speed and memory usage with the following optimizations:

  1. The (N, K, 1, D) * (1, K, D, D) matmul at line 275 is replaced with an equivalent matmul (K, N, D) * (K, D, D). The former is interpreted by cuBLAS as a batched matrix-vector product, while the latter is a batched matrix-matrix product, which is more efficient on GPUs (see the sketch after this list).

  2. In two consecutive iterations of fit, _estimate_log_prob was being called twice with the same input, once in _e_step and once in __score. Now weighted_log_probs is computed only once, in __score of the previous iteration, and cached to be reused in _e_step of the next iteration.

  3. At line 342, mu was originally obtained by element-wise multiplication and summation; this is now simplified to a matmul.

  4. At line 346, the batched vector outer product followed by a summation is rewritten as a single batched matmul, which is more efficient on GPUs.

  5. Computations in _m_step and _estimate_log_prob are split into smaller chunks to prevent OOM as much as possible.

  6. Added an option to choose the dtype of the covariance matrix. torch.linalg.eigvals is used to compute log_det if covariance_data_type = torch.float; otherwise the Cholesky decomposition is used.

  7. Replaced some of the tensor-scalar and tensor-tensor additions/multiplications with their in-place counterparts to reduce unnecessary memory allocation.
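For illustration, here is a minimal, self-contained sketch of optimization 1; the shapes and variable names are made up for the example and are not taken from the PR's code.

```python
import torch

# Toy shapes for illustration only; not the PR's actual variables.
N, K, D = 512, 16, 8
x_mu = torch.randn(N, K, 1, D)        # (x - mu) as row vectors per sample/component
precision = torch.randn(1, K, D, D)   # one precision matrix per component

# Original form: broadcast matmul, which cuBLAS treats as N*K matrix-vector products.
out_broadcast = x_mu.matmul(precision)                      # (N, K, 1, D)

# Regrouped form: K matrix-matrix products of shape (N, D) @ (D, D).
out_batched = torch.bmm(x_mu.squeeze(-2).permute(1, 0, 2),  # (K, N, D)
                        precision.squeeze(0))               # (K, D, D) -> (K, N, D)
out_batched = out_batched.permute(1, 0, 2).unsqueeze(-2)    # back to (N, K, 1, D)

assert torch.allclose(out_broadcast, out_batched, atol=1e-5)
```

On GPUs the second form runs as batched GEMMs, which is what the description above attributes the speed-up to.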

Benchmark results

Remaining issues:

  1. When covariance_data_type = "float" and both n_components and n_features are large, covar contains NaN.

Owner
@ldeecke ldeecke left a comment

Very nice work, and thank you for putting so much thought into this! Impressive speed-ups! 👍

Left a couple of comments (apologies for the delay!). In particular, I'm curious to hear your ideas on whether we should move the optimizations for covariance_type=full elsewhere, as this could benefit readability w.r.t. the underlying EM mechanism.

@@ -0,0 +1,39 @@
# Benchmark
Owner

Awesome results, thanks for sharing these!

Before merging with master, I would suggest removing benchmark.md.

@@ -1,9 +1,11 @@
import torch
import numpy as np
import math
Owner

gmm.py:5 imports from math, so it'd make sense to either replace all occurrences of pi or import ceil alongside it.


from math import pi
from scipy.special import logsumexp
from utils import calculate_matmul, calculate_matmul_n_times
from utils import calculate_matmul, calculate_matmul_n_times, find_optimal_splits
from tqdm import tqdm
Owner

I'd recommend removing this to keep the repository light on dependencies — users that require this functionality can always add it.

return check_available_ram(device) >= size


def find_optimal_splits(n, get_required_memory, device="cpu", safe_mode=True):
Owner

safe_mode doesn't seem to get passed on to will_it_fit.
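A hypothetical sketch of the wiring this comment points at; the internals of find_optimal_splits and the signatures below are assumed for illustration and are not taken from the PR.

```python
import math

def check_available_ram(device="cpu"):
    # Stand-in for the PR's helper; pretend 8 GB are free.
    return 8 * 1024 ** 3

def will_it_fit(size, device="cpu", safe_mode=True):
    if not safe_mode:
        return True                              # caller opted out of the check
    return check_available_ram(device) >= size

def find_optimal_splits(n, get_required_memory, device="cpu", safe_mode=True):
    # Double the number of splits until a chunk fits, forwarding safe_mode explicitly.
    splits = 1
    while splits < n:
        if will_it_fit(get_required_memory(math.ceil(n / splits)), device, safe_mode):
            return splits
        splits *= 2
    return n
```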

@@ -188,7 +203,8 @@ def predict(self, x, probs=False):
"""
x = self.check_size(x)

weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
weighted_log_prob = self._estimate_log_prob(x)
weighted_log_prob.add_(torch.log(self.pi))
Owner

While carrying this out in-place preserves memory, spreading it across two lines here and in lines 369 and 466 decreases readability somewhat. Alternatively, I reckon this could be moved into _estimate_log_prob.
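A minimal sketch of that alternative (the helper name and shapes are illustrative, not from the PR): keep the in-place add, but hide it behind a single call so the call sites stay one-liners.

```python
import torch

def weighted_log_prob(log_prob: torch.Tensor, pi: torch.Tensor) -> torch.Tensor:
    # Fold log(pi_k) into the component log-likelihoods in place,
    # avoiding a second (n, k, 1) allocation at the call site.
    log_prob.add_(torch.log(pi))
    return log_prob

log_prob = torch.randn(4, 3, 1)        # toy (n, k, 1) log-likelihoods
pi = torch.full((1, 3, 1), 1.0 / 3)    # toy mixing weights
print(weighted_log_prob(log_prob, pi).shape)   # torch.Size([4, 3, 1])
```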


log_det = self._calculate_log_det(precision) #[K, 1]

x_mu_T_precision_x_mu = torch.empty(N, K, 1, device=x.device, dtype=x.dtype)
Owner

Unless there are reservations/concerns, I would consider moving this into its own utility function, in the interest of preserving readability of the code (happy to take care of this once it has been merged).
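Something along these lines could live in utils.py; this is only a sketch under assumed shapes ((K, N, D) differences and (K, D, D) precisions), not the PR's code.

```python
import torch

def quadratic_form_chunked(x_mu, precision, chunk_size):
    # Computes (x - mu)^T precision (x - mu) per sample and component,
    # processing components in chunks to bound peak memory.
    # x_mu: (K, N, D), precision: (K, D, D) -> returns (N, K, 1).
    K, N, D = x_mu.shape
    out = torch.empty(N, K, 1, device=x_mu.device, dtype=x_mu.dtype)
    for k0 in range(0, K, chunk_size):
        k1 = min(k0 + chunk_size, K)
        x_mu_T_p = torch.bmm(x_mu[k0:k1], precision[k0:k1])   # (k, N, D)
        out[:, k0:k1, 0] = (x_mu_T_p * x_mu[k0:k1]).sum(-1).t()
    return out
```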

eps = (torch.eye(self.n_features) * self.eps).to(x.device)
var = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1), dim=0,
keepdim=True) / torch.sum(resp, dim=0, keepdim=True).unsqueeze(-1) + eps
var = torch.empty(1, K, D, D, device=x.device, dtype=resp.dtype)
Owner

Nice! 👍

Same thought as before, however: given the additional complexity that's introduced here, it might make sense to define these optimizations in some other place.
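For reference, a self-contained sketch (toy shapes, not the PR's exact code) showing that the weighted sum of outer products above is equivalent to a single batched matmul per component:

```python
import torch

N, K, D = 256, 8, 5
x = torch.randn(N, 1, D, dtype=torch.double)
mu = torch.randn(1, K, D, dtype=torch.double)
resp = torch.rand(N, K, 1, dtype=torch.double)
eps = torch.eye(D, dtype=torch.double) * 1.e-6

# Element-wise outer products followed by a sum, as in the original code path.
var_a = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1),
                  dim=0, keepdim=True) / torch.sum(resp, dim=0, keepdim=True).unsqueeze(-1) + eps

# Batched matmul: (K, D, N) @ (K, N, D) -> (K, D, D), one GEMM per component.
x_mu = (x - mu).permute(1, 0, 2)                      # (K, N, D)
weighted = x_mu * resp.permute(1, 0, 2)               # responsibilities as row weights
var_b = torch.bmm(weighted.transpose(1, 2), x_mu)     # (K, D, D)
var_b = (var_b / resp.sum(dim=0).unsqueeze(-1)).unsqueeze(0) + eps

assert torch.allclose(var_a, var_b)
```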

covariance_type: str
eps: float
init_params: str
covariance_data_type: str or torch.dtype
Owner

Since mu is getting matched against this type, might as well go ahead and introduce this as dtype altogether, right?
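One way to do that while still accepting the string aliases (hypothetical helper, not part of the PR):

```python
import torch

_DTYPE_ALIASES = {"float": torch.float32, "double": torch.float64}

def _resolve_covariance_dtype(covariance_data_type):
    # Accept either a torch.dtype or the string aliases used so far, and store a
    # torch.dtype internally so mu/var/x can be matched against it directly.
    if isinstance(covariance_data_type, torch.dtype):
        return covariance_data_type
    return _DTYPE_ALIASES[covariance_data_type]

print(_resolve_covariance_dtype("double"))        # torch.float64
print(_resolve_covariance_dtype(torch.float32))   # torch.float32
```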

@@ -15,30 +17,31 @@ class GaussianMixture(torch.nn.Module):
probabilities are shaped (n, k, 1) if they relate to an individual sample,
or (1, k, 1) if they assign membership probabilities to one of the mixture components.
"""
def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, init_params="kmeans", mu_init=None, var_init=None):
def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, init_params="kmeans", mu_init=None, var_init=None, covariance_data_type="double"):
Owner

Any reservations against going with "float" as the default type (it matches the torch.Tensor default)?

log_2pi = d * np.log(2. * pi)

log_det = self._calculate_log_det(precision)
x = x.to(var.dtype)
Owner

Since self.covariance_data_type has been allocated, maybe use that instead?
