-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Made some performance improvements and added type-hints with function…
… definitions (#5) Co-authored-by: Najib Ishaq <najib_ishaq@uri.edu>
- Loading branch information
Showing
9 changed files
with
399 additions
and
226 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,4 +9,5 @@ dist | |
.DS_Store | ||
.pytest_cache | ||
.ipynb* | ||
.coverage | ||
.coverage | ||
plots |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import time | ||
|
||
import numpy | ||
|
||
import distogram | ||
import utils | ||
|
||
|
||
def bench_count_at():
    """Benchmark ``distogram.count_at`` across a range of bin counts.

    A single sample of 100,000 normally distributed values is drawn once.
    For each bin count in 32, 64, ..., 1024 a distogram is built from that
    sample with ``utils.create_distogram``, ``distogram.count_at(h, 0)`` is
    called ``num_samples`` times, and the mean time per call is recorded.
    The collected timings are handed to ``utils.plot_times_dict``.
    """
    num_samples = 100
    num_points = 100_000
    values = numpy.random.normal(size=num_points)

    times_dict: utils.TimesDict = {num_points: {}}

    for n in range(6):
        bin_count = 32 * (2 ** n)
        # Building the distogram is setup work and stays outside the timer.
        h = utils.create_distogram(bin_count, values)

        # perf_counter() is the high-resolution clock intended for measuring
        # short durations; time.time() can be coarse and may jump if the
        # system clock is adjusted.
        start = time.perf_counter()
        # Plain loop: no throwaway result list is allocated while timing.
        for _ in range(num_samples):
            distogram.count_at(h, 0)
        time_taken = (time.perf_counter() - start) / num_samples

        times_dict[num_points][bin_count] = time_taken

    utils.plot_times_dict(
        times_dict,
        title='count-at',
    )


if __name__ == '__main__':
    bench_count_at()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import time | ||
|
||
import numpy | ||
|
||
import distogram | ||
import utils | ||
|
||
|
||
def bench_histogram():
    """Benchmark ``distogram.histogram`` across a range of bin counts.

    A single sample of 100,000 normally distributed values is drawn once.
    For each bin count in 32, 64, ..., 1024 a distogram is built from that
    sample with ``utils.create_distogram``, then
    ``distogram.histogram(h, ucount=bin_count)`` is called ``num_samples``
    times and the mean time per call is recorded. The collected timings are
    handed to ``utils.plot_times_dict``.
    """
    num_samples = 100
    num_points = 100_000
    values = numpy.random.normal(size=num_points)

    times_dict: utils.TimesDict = {num_points: {}}

    for n in range(6):
        bin_count = 32 * (2 ** n)
        # Building the distogram is setup work and stays outside the timer.
        h = utils.create_distogram(bin_count, values)

        # perf_counter() is the high-resolution clock intended for measuring
        # short durations; time.time() can be coarse and may jump if the
        # system clock is adjusted.
        start = time.perf_counter()
        # Plain loop: the returned histograms are discarded immediately
        # instead of being collected into a list during the measurement.
        for _ in range(num_samples):
            distogram.histogram(h, ucount=bin_count)
        time_taken = (time.perf_counter() - start) / num_samples

        times_dict[num_points][bin_count] = time_taken

    utils.plot_times_dict(
        times_dict,
        title='histogram',
    )


if __name__ == '__main__':
    bench_histogram()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import time | ||
from functools import reduce | ||
|
||
import distogram | ||
import utils | ||
|
||
|
||
def bench_merge():
    """Benchmark ``distogram.merge`` across a range of bin counts.

    Ten value sets of 100,000 points each are generated with
    ``utils.create_values`` (means 0..9, standard deviation 0.3). For each
    bin count in 32, 64, ..., 1024 one distogram per value set is built, and
    all of them are folded into a fresh empty ``distogram.Distogram`` with
    ``distogram.merge``. The elapsed time divided by ``num_samples`` (one
    merge per input distogram) is recorded and the timings are handed to
    ``utils.plot_times_dict``.
    """
    num_samples = 10
    num_points = 100_000
    values_list = [
        utils.create_values(mean, 0.3, num_points)
        for mean in range(num_samples)
    ]

    times_dict: utils.TimesDict = {num_points: {}}

    for n in range(6):
        bin_count = 32 * (2 ** n)

        # Building the input distograms is setup work, kept outside the timer.
        histograms = [
            utils.create_distogram(bin_count, values)
            for values in values_list
        ]

        # perf_counter() is the high-resolution clock intended for measuring
        # short durations; time.time() can be coarse and may jump if the
        # system clock is adjusted.
        start = time.perf_counter()
        # distogram.merge is passed directly: a wrapping lambda would only
        # add one extra Python call per merge to the measured interval.
        reduce(
            distogram.merge,
            histograms,
            distogram.Distogram(bin_count=bin_count),
        )
        time_taken = (time.perf_counter() - start) / num_samples

        times_dict[num_points][bin_count] = time_taken

    utils.plot_times_dict(
        times_dict,
        title='merge',
    )


if __name__ == '__main__':
    bench_merge()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import time | ||
|
||
import numpy | ||
|
||
import utils | ||
|
||
|
||
# This one takes the longest to run
def bench_update():
    """Benchmark distogram construction for several input sizes and bin counts.

    For each input size in ``num_points_list`` a normally distributed sample
    is drawn, and for each bin count in 32, 64, ..., 1024 the time taken by
    ``utils.create_distogram(bin_count, values)`` is measured (averaged over
    ``num_samples`` runs). Timings are stored as
    ``{num_points: {bin_count: seconds}}`` and handed to
    ``utils.plot_times_dict``.
    """
    num_samples = 1
    num_points_list = [100_000, 250_000, 500_000, 1_000_000, 2_000_000]

    times_dict: utils.TimesDict = {
        num_points: {}
        for num_points in num_points_list
    }

    for num_points in num_points_list:
        values = numpy.random.normal(size=num_points)

        for n in range(6):
            bin_count = 32 * (2 ** n)

            # perf_counter() is the high-resolution clock intended for
            # measuring durations; time.time() can be coarse and may jump if
            # the system clock is adjusted.
            start = time.perf_counter()
            # Plain loop: each built distogram is dropped right away rather
            # than being accumulated in a list while the timer is running.
            for _ in range(num_samples):
                utils.create_distogram(bin_count, values)
            time_taken = (time.perf_counter() - start) / num_samples

            times_dict[num_points][bin_count] = time_taken

    utils.plot_times_dict(
        times_dict,
        title='update',
    )


if __name__ == '__main__':
    bench_update()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import os | ||
from functools import reduce | ||
from typing import Dict | ||
|
||
import numpy | ||
from matplotlib import pyplot | ||
|
||
import distogram | ||
|
||
# A Dictionary to store runtimes. The structure is intended to be: | ||
# { num_points: { bin_count: time_taken } } | ||
TimesDict = Dict[int, Dict[int, float]] | ||
|
||
COL_NAMES = ['num_points', 'bin_count', 'old_time', 'new_time'] | ||
|
||
|
||
def create_values(mean: float, stddev: float, num_points: int) -> numpy.ndarray:
    """Draw ``num_points`` samples from a normal distribution.

    Args:
        mean: Centre of the distribution (passed as ``loc``).
        stddev: Standard deviation of the distribution (passed as ``scale``).
        num_points: Number of samples to draw (passed as ``size``).

    Returns:
        The array produced by ``numpy.random.normal`` for these arguments.
    """
    return numpy.random.normal(loc=mean, scale=stddev, size=num_points)
|
||
|
||
def create_distogram(bin_count: int, values: numpy.ndarray):
    """Build a distogram by feeding every element of ``values`` into it.

    Starts from ``distogram.Distogram(bin_count)`` and applies
    ``distogram.update`` once per element, in ``values.flat`` order. Each
    element is converted with ``float()`` before being passed on.

    Returns:
        The result of the final ``distogram.update`` call, or the initial
        distogram when ``values`` has no elements.
    """
    accumulated = distogram.Distogram(bin_count)
    for element in values.flat:
        # Carry the returned value forward, exactly as a left fold would.
        accumulated = distogram.update(accumulated, float(element))
    return accumulated
|
||
|
||
def plot_times_dict(times_dict: TimesDict, title: str):
    """Plot recorded timings and save the figure to ``plots/<title>.png``.

    When ``times_dict`` has exactly one ``num_points`` entry, a single line of
    time against bin count is drawn. Otherwise one line per bin count is drawn
    against ``num_points``; the bin counts plotted are those of the entry with
    the smallest ``num_points``.

    Args:
        times_dict: Timings laid out as ``{num_points: {bin_count: seconds}}``.
        title: Used in the plot title and as the output file name.
    """
    # Drop any figures left over from earlier calls, then start a new one.
    pyplot.close('all')
    pyplot.figure()

    if len(times_dict) == 1:
        (timings,) = times_dict.values()
        bin_counts = sorted(timings)
        durations = [timings[bin_count] for bin_count in bin_counts]

        pyplot.plot(bin_counts, durations, label='time taken')
        pyplot.title(f'time vs bin-count for {title}')
        pyplot.xlabel('bin-count')
    else:
        point_counts = sorted(times_dict)

        for bin_count in sorted(times_dict[point_counts[0]]):
            durations = [times_dict[count][bin_count] for count in point_counts]
            pyplot.plot(point_counts, durations, label=f'{bin_count} bins')

        pyplot.title(f'time vs num-points for {title}')
        pyplot.xlabel('num-points')

    pyplot.ylabel('time (s)')
    pyplot.legend()

    # Create the output directory on first use.
    os.makedirs('plots', exist_ok=True)
    pyplot.savefig(f'plots/{title}.png', dpi=300)
Oops, something went wrong.