Skip to content

Commit

Permalink
feat: use dask to compute metrics' dataframes in parallel
Browse files — browse the repository at this point in the history
  • Loading branch information
martibosch committed Apr 14, 2024
1 parent 741520f commit 2dafaaf
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 13 deletions.
48 changes: 35 additions & 13 deletions pylandstats/multilandscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import abc
import functools

import dask
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from dask import diagnostics

from . import settings
from .landscape import Landscape
Expand Down Expand Up @@ -179,17 +181,27 @@ def compute_class_metrics_df( # noqa: D102
class_metrics_df.index.names = "class_val", self.attribute_name
class_metrics_df.columns.name = "metric"

for attribute_value, landscape in zip(attribute_values, self.landscapes):
# get the class metrics DataFrame for the landscape that corresponds to this
# attribute value
df = landscape.compute_class_metrics_df(
tasks = [
dask.delayed(landscape.compute_class_metrics_df)(
metrics=metrics, metrics_kws=metrics_kws
)
for landscape in self.landscapes
]
with diagnostics.ProgressBar():
dfs = dask.compute(*tasks)

for attribute_value, df in zip(attribute_values, dfs):
# get the class metrics DataFrame for the landscape that corresponds to this
# attribute value
# df = landscape.compute_class_metrics_df(
# metrics=metrics, metrics_kws=metrics_kws
# )
# filter so we only check the classes considered in this `MultiLandscape`
# instance
df = df.loc[df.index.intersection(classes)]
# df = df.loc[df.index.intersection(classes)]
# put every row of the filtered DataFrame of this particular attribute value
for class_val, row in df.iterrows():
# for class_val, row in df.iterrows():
for class_val, row in df.loc[df.index.intersection(classes)].iterrows():
class_metrics_df.loc[(class_val, attribute_value), columns] = row

class_metrics_df = class_metrics_df.apply(pd.to_numeric)
Expand Down Expand Up @@ -227,14 +239,24 @@ def compute_landscape_metrics_df( # noqa: D102
landscape_metrics_df.index.name = self.attribute_name
landscape_metrics_df.columns.name = "metric"

for attribute_value, landscape in zip(attribute_values, self.landscapes):
landscape_metrics_df.loc[
attribute_value, columns
] = landscape.compute_landscape_metrics_df(
tasks = [
dask.delayed(landscape.compute_landscape_metrics_df)(
metrics=metrics, metrics_kws=metrics_kws
).iloc[
0
]
)
for landscape in self.landscapes
]
with diagnostics.ProgressBar():
dfs = dask.compute(*tasks)

for attribute_value, df in zip(attribute_values, dfs):
# landscape_metrics_df.loc[
# attribute_value, columns
# ] = landscape.compute_landscape_metrics_df(
# metrics=metrics, metrics_kws=metrics_kws
# ).iloc[
# 0
# ]
landscape_metrics_df.loc[attribute_value, columns] = df.iloc[0]

return landscape_metrics_df.apply(pd.to_numeric)

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ classifiers = [
requires-python = ">=3.8"
dependencies = [
"black",
"dask",
"geopandas",
"matplotlib >= 2.2",
"numba ; platform_system == 'Windows'",
Expand Down

0 comments on commit 2dafaaf

Please sign in to comment.