Perf segmentation filter #1067

Merged · 4 commits · Apr 10, 2024

Conversation

gogetron (Contributor)

Summary

This PR partially addresses #862

🎯 Purpose: Improve the performance of find_label_issues in segmentation/filter.py.

After profiling, it seems that the loops were the slowest part. In addition, the inference inside get_unique_classes is very slow when multi_label is False because of the isinstance calls. The loops were converted to numpy operations, and since we already know that multi_label is False by the time we call get_unique_classes, we can pass that parameter explicitly to avoid the inference.
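
Roughly, the loop-to-numpy conversion follows this pattern (an illustrative sketch, not the actual diff; all array names and sizes below are hypothetical):

import numpy as np

# Illustrative flattened inputs: N pixels, K classes.
N, K = 1_000_000, 10
rng = np.random.default_rng(0)
pred_probs_flat = rng.random((N, K))      # predicted class probabilities per pixel
labels_flat = rng.integers(0, K, size=N)  # given label per pixel
thresholds = np.full(K, 0.5)              # per-class confidence thresholds

def issues_loop():
    # Slow: one Python-level iteration per pixel.
    out = np.zeros(N, dtype=bool)
    for i in range(N):
        out[i] = pred_probs_flat[i, labels_flat[i]] < thresholds[labels_flat[i]]
    return out

def issues_vectorized():
    # Fast: one fancy-indexing pass gives the probability of each pixel's given label.
    self_confidence = pred_probs_flat[np.arange(N), labels_flat]
    return self_confidence < thresholds[labels_flat]

assert np.array_equal(issues_loop(), issues_vectorized())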

For memory profiling I used the memory-profiler library. The code I used for benchmarking is copied below. In addition, I sorted the imports in the modified files.

Code Setup

import random

import numpy as np

from cleanlab.segmentation.filter import find_label_issues

SIZE = 1000
DATASET_EXP_SIZE = 2
np.random.seed(0)
%load_ext memory_profiler

# Copied from the test file with minor changes
def generate_three_image_dataset(bad_index):
    # "Good" ground truth: class 1 in the top half, class 0 in the bottom half.
    good_gt = np.zeros((SIZE, SIZE))
    good_gt[:SIZE // 2, :] = 1.0
    # "Bad" ground truth: the same image with the classes flipped (deliberately mislabeled).
    bad_gt = np.ones((SIZE, SIZE))
    bad_gt[:SIZE // 2, :] = 0.0
    # Predicted probabilities that agree with the good ground truth:
    # class 1 is favored in the top half and class 0 in the bottom half.
    good_pr = np.random.random((2, SIZE, SIZE))
    good_pr[0, :SIZE // 2, :] = good_pr[0, :SIZE // 2, :] / 10
    good_pr[1, SIZE // 2:, :] = good_pr[1, SIZE // 2:, :] / 10

    val = np.binary_repr([4, 2, 1][bad_index], width=3)
    error = [int(case) for case in val]

    labels = []
    pred = []
    for case in val:
        if case == "0":
            labels.append(good_gt)
            pred.append(good_pr)
        else:
            labels.append(bad_gt)
            pred.append(good_pr)

    labels = np.array(labels)
    pred_probs = np.array(pred)
    return labels, pred_probs, error

# Create input data
labels, pred_probs, error = generate_three_image_dataset(random.randint(0, 2))
for _ in range(DATASET_EXP_SIZE):
    labels = np.append(labels, labels, axis=0)
    pred_probs = np.append(pred_probs, pred_probs, axis=0)

labels, pred_probs = labels.astype(int), pred_probs.astype(float)

Current version

%%timeit
%memit find_label_issues(labels, pred_probs, n_jobs=1, verbose=False)
# peak memory: 1219.42 MiB, increment: 659.78 MiB
# peak memory: 1220.45 MiB, increment: 677.88 MiB
# peak memory: 1220.45 MiB, increment: 677.88 MiB
# peak memory: 1220.65 MiB, increment: 678.08 MiB
# peak memory: 1220.65 MiB, increment: 677.88 MiB
# peak memory: 1220.65 MiB, increment: 677.88 MiB
# peak memory: 1220.64 MiB, increment: 677.88 MiB
# peak memory: 1220.65 MiB, increment: 677.88 MiB
# 8.03 s ± 249 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%%timeit
%memit find_label_issues(labels, pred_probs, downsample=4, n_jobs=1, verbose=False)
# peak memory: 607.32 MiB, increment: 64.55 MiB
# peak memory: 607.34 MiB, increment: 76.01 MiB
# peak memory: 607.34 MiB, increment: 76.01 MiB
# peak memory: 607.35 MiB, increment: 76.02 MiB
# peak memory: 607.34 MiB, increment: 76.01 MiB
# peak memory: 607.34 MiB, increment: 76.01 MiB
# peak memory: 607.36 MiB, increment: 76.03 MiB
# peak memory: 607.36 MiB, increment: 76.03 MiB
# 5.51 s ± 138 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

This PR

%%timeit
%memit find_label_issues(labels, pred_probs, n_jobs=1, verbose=False)
# peak memory: 1240.52 MiB, increment: 674.35 MiB
# peak memory: 1223.41 MiB, increment: 681.34 MiB
# peak memory: 1236.62 MiB, increment: 694.56 MiB
# peak memory: 1246.82 MiB, increment: 704.75 MiB
# peak memory: 1227.48 MiB, increment: 685.22 MiB
# peak memory: 1230.82 MiB, increment: 688.56 MiB
# peak memory: 1227.61 MiB, increment: 685.36 MiB
# peak memory: 1221.61 MiB, increment: 679.35 MiB
# 3.4 s ± 27.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%%timeit
%memit find_label_issues(labels, pred_probs, downsample=4, n_jobs=1, verbose=False)
# peak memory: 622.23 MiB, increment: 79.96 MiB
# peak memory: 622.27 MiB, increment: 94.12 MiB
# peak memory: 622.40 MiB, increment: 94.24 MiB
# peak memory: 622.19 MiB, increment: 94.03 MiB
# peak memory: 622.41 MiB, increment: 94.25 MiB
# peak memory: 622.40 MiB, increment: 94.24 MiB
# peak memory: 622.40 MiB, increment: 94.24 MiB
# peak memory: 622.42 MiB, increment: 94.26 MiB
# 704 ms ± 12.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Testing

🔍 Testing Done: Existing tests.

References

Reviewer Notes

💡 Include any specific points for the reviewer to consider during their review.

codecov bot commented Mar 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 96.14%. Comparing base (52a1f32) to head (78618d7).
Report is 1 commit behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1067      +/-   ##
==========================================
- Coverage   96.20%   96.14%   -0.07%     
==========================================
  Files          74       74              
  Lines        5849     5861      +12     
  Branches     1044     1042       -2     
==========================================
+ Hits         5627     5635       +8     
- Misses        132      135       +3     
- Partials       90       91       +1     


@elisno (Member) left a comment

Thanks for working on this PR, especially optimizing the runtimes of cleanlab.segmentation.filter.find_label_issues and cleanlab.count.num_label_issues! ⚡

I'm concerned about the increased memory usage reported in your tests. Since we only tested on a small number of images (6, if I understand your code correctly), could we check how the memory usage scales with more images and classes? We need to ensure our changes work well even for larger datasets.

If we can find a way to reduce the memory increase or understand its impact better, I'd be more inclined to merge this PR. Could you also profile the specific parts of the code affected by these changes to get a clearer picture of where the memory usage increases and if there are opportunities for optimization?

Your improvement with the get_unique_classes call is particularly valuable, and I'm keen to get that part integrated soon.

Looking forward to your thoughts and any further optimizations you might suggest.
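
For that kind of targeted profiling, the memory-profiler package's @profile decorator reports per-line memory usage; a minimal sketch (the script and function below are placeholders, not cleanlab code):

# profile_filter.py -- run with:  python -m memory_profiler profile_filter.py
import numpy as np
from memory_profiler import profile

@profile  # prints line-by-line memory usage for this function
def build_self_confidence(labels, pred_probs):
    # Stand-in for the array allocations under investigation.
    flat_probs = pred_probs.reshape(pred_probs.shape[0], pred_probs.shape[1], -1)
    flat_labels = labels.reshape(labels.shape[0], 1, -1)
    return np.take_along_axis(flat_probs, flat_labels, axis=1)

if __name__ == "__main__":
    labels = np.random.randint(10, size=(100, 250, 250))
    pred_probs = np.random.random((100, 10, 250, 250))
    build_self_confidence(labels, pred_probs)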

@@ -1458,7 +1455,7 @@ def get_confident_thresholds(
     # this approach is that there will be no standard value returned for missing classes.
     labels = labels_to_array(labels)
     all_classes = range(pred_probs.shape[1])
-    unique_classes = get_unique_classes(labels)
+    unique_classes = get_unique_classes(labels, multi_label=multi_label)

This is the only code change in this file; everything else is just sorting the import statements.

I think this is a smart move 👍
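
For context, the saving comes from skipping type inference: when multi_label is not supplied, the helper has to inspect the labels (isinstance checks over the entries) to decide whether they are multi-label before computing the unique classes. A rough sketch of that pattern (a simplification for illustration, not the actual cleanlab implementation):

import numpy as np

def get_unique_classes_sketch(labels, multi_label=None):
    if multi_label is None:
        # Inference path: a Python-level isinstance scan over every entry.
        multi_label = any(isinstance(lab, list) for lab in labels)
    if multi_label:
        return set(cls for lab in labels for cls in lab)
    return set(np.unique(labels))

labels = np.random.randint(5, size=1_000_000)
# The caller already knows the task is single-label, so the scan can be skipped:
unique_classes = get_unique_classes_sketch(labels, multi_label=False)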

Comment on lines +106 to +110
labels.reshape((num_image, h // factor, factor, w // factor, factor)).mean((4, 2))
)
small_pred_probs = pred_probs.reshape(
(num_image, num_classes, h // factor, factor, w // factor, factor)
).mean((5, 3))

Good catch!
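
For reference, a standalone toy example of the reshape-and-mean downsampling used in the snippet above (shapes chosen arbitrarily): each spatial axis is split into (length // factor, factor) blocks and the within-block axes are averaged away.

import numpy as np

num_image, num_classes, h, w, factor = 2, 3, 8, 8, 4
pred_probs = np.random.random((num_image, num_classes, h, w))

small_pred_probs = pred_probs.reshape(
    (num_image, num_classes, h // factor, factor, w // factor, factor)
).mean((5, 3))

assert small_pred_probs.shape == (num_image, num_classes, h // factor, w // factor)
# Each output pixel is the mean of the corresponding factor x factor block:
assert np.isclose(small_pred_probs[0, 0, 0, 0], pred_probs[0, 0, :factor, :factor].mean())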

@gogetron (Contributor, Author) commented Mar 28, 2024

Hi, thank you for your review. You are absolutely right regarding the memory consumption. With larger datasets (more images and more classes) the memory usage was about 35% higher with this PR compared to the previous version. The issue was that we were allocating very large new arrays to do the mask operation all at once. This gave me a great idea for lowering memory further while maintaining the speed improvements.

I have just pushed some changes that perform the operations in batches using the batch_size parameter. Memory consumption is now lower (because arrays are allocated per batch instead of keeping large arrays in memory) while still running much faster.
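
Roughly, the batching pattern looks like this (a simplified sketch, not the actual diff; the function and variable names are illustrative): instead of building one mask over all pixels at once, the flattened arrays are processed in chunks of batch_size, so only one chunk's intermediate arrays live in memory at a time.

import numpy as np

def find_issues_batched(pred_probs_flat, labels_flat, thresholds, batch_size=100_000):
    # pred_probs_flat: (N, K), labels_flat: (N,), thresholds: (K,)
    n = labels_flat.shape[0]
    issues = np.zeros(n, dtype=bool)
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        batch_labels = labels_flat[start:end]
        # Only this chunk's self-confidences are materialized at any one time.
        self_confidence = pred_probs_flat[np.arange(start, end), batch_labels]
        issues[start:end] = self_confidence < thresholds[batch_labels]
    return issues

rng = np.random.default_rng(0)
n, k = 2_000_000, 10
issues = find_issues_batched(rng.random((n, k)), rng.integers(0, k, size=n), np.full(k, 0.5))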

I have also changed the setup code for the benchmark a little to make it clearer and to generate a larger input dataset:
Code Setup

import numpy as np

from cleanlab.segmentation.filter import find_label_issues

SIZE = 250
NUM_IMAGES = 1000
NUM_CLASSES = 10
np.random.seed(0)
%load_ext memory_profiler

def generate_image_dataset():
    labels = np.random.randint(NUM_CLASSES, size=(NUM_IMAGES, SIZE, SIZE), dtype=int)
    pred_probs = np.random.random((NUM_IMAGES, NUM_CLASSES, SIZE, SIZE))
    return labels, pred_probs

# Create input data
labels, pred_probs = generate_image_dataset()

Current version

%%timeit
%memit find_label_issues(labels, pred_probs, n_jobs=1, verbose=False)
# peak memory: 10161.57 MiB, increment: 4629.44 MiB
# peak memory: 10162.77 MiB, increment: 4643.84 MiB
# peak memory: 10161.71 MiB, increment: 4642.77 MiB
# peak memory: 10161.88 MiB, increment: 4642.95 MiB
# peak memory: 10162.95 MiB, increment: 4643.82 MiB
# peak memory: 10161.89 MiB, increment: 4642.76 MiB
# peak memory: 10161.89 MiB, increment: 4642.75 MiB
# peak memory: 10162.96 MiB, increment: 4643.82 MiB
# 2min 6s ± 1.06 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
%%timeit
%memit find_label_issues(labels, pred_probs, downsample=5, n_jobs=1, verbose=False)
# peak memory: 6694.65 MiB, increment: 1163.38 MiB
# peak memory: 6710.26 MiB, increment: 1182.33 MiB
# peak memory: 6710.26 MiB, increment: 1182.33 MiB
# peak memory: 6710.26 MiB, increment: 1182.33 MiB
# peak memory: 6710.26 MiB, increment: 1182.33 MiB
# peak memory: 6710.26 MiB, increment: 1182.33 MiB
# peak memory: 6710.26 MiB, increment: 1182.33 MiB
# peak memory: 6710.26 MiB, increment: 1182.33 MiB
# 1min 11s ± 402 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

This PR

%%timeit
%memit find_label_issues(labels, pred_probs, n_jobs=1, verbose=False)
# peak memory: 8846.55 MiB, increment: 3326.38 MiB
# peak memory: 8847.80 MiB, increment: 3347.29 MiB
# peak memory: 8848.02 MiB, increment: 3346.53 MiB
# peak memory: 8847.03 MiB, increment: 3345.55 MiB
# peak memory: 8848.19 MiB, increment: 3346.53 MiB
# peak memory: 8847.04 MiB, increment: 3345.38 MiB
# peak memory: 8848.20 MiB, increment: 3346.53 MiB
# peak memory: 8847.04 MiB, increment: 3345.37 MiB
# 32.9 s ± 355 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%%timeit
%memit find_label_issues(labels, pred_probs, downsample=5, n_jobs=1, verbose=False)
# peak memory: 5899.00 MiB, increment: 397.33 MiB
# peak memory: 5898.74 MiB, increment: 364.73 MiB
# peak memory: 5908.86 MiB, increment: 412.89 MiB
# peak memory: 5898.74 MiB, increment: 402.76 MiB
# peak memory: 5898.76 MiB, increment: 402.78 MiB
# peak memory: 5898.63 MiB, increment: 402.64 MiB
# peak memory: 5898.76 MiB, increment: 402.76 MiB
# peak memory: 5898.63 MiB, increment: 402.63 MiB
# 11 s ± 185 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

@elisno (Member) left a comment

This LGTM @gogetron!

Had to make sure that the outputs of find_label_issues before and after the PR are identical, which seems to be the case!

This is a great speedup and memory improvement. Nice work!
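
One way to run that equivalence check (a sketch only; it assumes the pre-PR result was saved from the same inputs to a file such as issues_before.npy before switching branches):

import numpy as np
from cleanlab.segmentation.filter import find_label_issues

np.random.seed(0)
labels = np.random.randint(10, size=(1000, 250, 250), dtype=int)
pred_probs = np.random.random((1000, 10, 250, 250))

issues_after = find_label_issues(labels, pred_probs, n_jobs=1, verbose=False)
issues_before = np.load("issues_before.npy")  # hypothetical file holding the pre-PR result
assert np.array_equal(issues_before, issues_after)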

@elisno (Member) commented Apr 10, 2024

Failing CI is unrelated. Should already be addressed on master.

@elisno merged commit 7d1f335 into cleanlab:master on Apr 10, 2024
13 of 21 checks passed