Skip to content

Commit

Permalink
Add method chaining support, fixes #382
Browse files Browse the repository at this point in the history
  • Loading branch information
anitagraser committed May 10, 2024
1 parent 6dfe3a1 commit e157868
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 20 deletions.
7 changes: 7 additions & 0 deletions movingpandas/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,7 @@ def add_traj_id(self, overwrite=False):
f"Use overwrite=True to overwrite exiting values."
)
self.df[TRAJ_ID_COL_NAME] = self.id
return self

def add_direction(self, overwrite=False, name=DIRECTION_COL_NAME):
"""
Expand Down Expand Up @@ -1017,6 +1018,7 @@ def add_direction(self, overwrite=False, name=DIRECTION_COL_NAME):
# set the direction in the first row to the direction of the second row
self.df.at[self.get_start_time(), name] = self.df.iloc[1][name]
self.df.drop(columns=["prev_pt"], inplace=True)
return self

def add_angular_difference(
self,
Expand Down Expand Up @@ -1054,6 +1056,7 @@ def add_angular_difference(
self.df.at[self.get_start_time(), name] = 0.0
if not direction_exists:
self.df.drop(columns=[DIRECTION_COL_NAME], inplace=True)
return self

def add_distance(self, overwrite=False, name=DISTANCE_COL_NAME, units=None):
"""
Expand Down Expand Up @@ -1116,6 +1119,7 @@ def add_distance(self, overwrite=False, name=DISTANCE_COL_NAME, units=None):
)
conversion = get_conversion(units, self.crs_units)
self.df = self._get_df_with_distance(conversion, name)
return self

def add_speed(self, overwrite=False, name=SPEED_COL_NAME, units=UNITS()):
"""
Expand Down Expand Up @@ -1186,6 +1190,7 @@ def add_speed(self, overwrite=False, name=SPEED_COL_NAME, units=UNITS()):
)
conversion = get_conversion(units, self.crs_units)
self.df = self._get_df_with_speed(conversion, name)
return self

def add_acceleration(
self, overwrite=False, name=ACCELERATION_COL_NAME, units=UNITS()
Expand Down Expand Up @@ -1264,6 +1269,7 @@ def add_acceleration(
)
conversion = get_conversion(units, self.crs_units)
self.df = self._get_df_with_acceleration(conversion, name)
return self

def add_timedelta(self, overwrite=False, name=TIMEDELTA_COL_NAME):
"""
Expand All @@ -1287,6 +1293,7 @@ def add_timedelta(self, overwrite=False, name=TIMEDELTA_COL_NAME):
f"name arg."
)
self.df = self._get_df_with_timedelta(name)
return self

def _get_df_with_timedelta(self, name=TIMEDELTA_COL_NAME):
temp_df = self.df.copy()
Expand Down
8 changes: 8 additions & 0 deletions movingpandas/trajectory_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ def add_speed(
self._add_speed(self.trajectories, name, units, overwrite)
else:
self._multithread(self._add_speed, n_threads, name, units, overwrite)
return self

def _add_speed(self, trajs, name, units, overwrite):
for traj in trajs:
Expand All @@ -575,6 +576,7 @@ def _multithread(self, fun, n_threads, name, units, overwrite):
for added in p.starmap(fun, args_iter):
results.extend(added)
self.trajectories = results
return results

def add_direction(self, overwrite=False, name=DIRECTION_COL_NAME, n_threads=1):
"""
Expand Down Expand Up @@ -602,6 +604,7 @@ def add_direction(self, overwrite=False, name=DIRECTION_COL_NAME, n_threads=1):
UNITS(),
overwrite,
)
return self

def _add_direction(self, trajs, name, units, overwrite):
for traj in trajs:
Expand Down Expand Up @@ -633,6 +636,7 @@ def add_angular_difference(
self._multithread(
self._add_angular_difference, n_threads, name, UNITS(), overwrite
)
return self

def _add_angular_difference(self, trajs, name, units, overwrite):
for traj in trajs:
Expand Down Expand Up @@ -675,6 +679,7 @@ def add_acceleration(
self._add_acceleration(self.trajectories, name, units, overwrite)
else:
self._multithread(self._add_acceleration, n_threads, name, units, overwrite)
return self

def _add_acceleration(self, trajs, name, units, overwrite):
for traj in trajs:
Expand Down Expand Up @@ -704,6 +709,7 @@ def add_distance(
self._add_distance(self.trajectories, name, units, overwrite)
else:
self._multithread(self._add_distance, n_threads, name, units, overwrite)
return self

def _add_distance(self, trajs, name, units, overwrite):
for traj in trajs:
Expand All @@ -730,6 +736,7 @@ def add_timedelta(self, overwrite=False, name=TIMEDELTA_COL_NAME, n_threads=1):
self._add_distance(self.trajectories, name, UNITS(), overwrite)
else:
self._multithread(self._add_distance, n_threads, name, UNITS(), overwrite)
return self

def _add_timedelta(self, trajs, name, units, overwrite):
for traj in trajs:
Expand All @@ -747,6 +754,7 @@ def add_traj_id(self, overwrite=False):
"""
for traj in self:
traj.add_traj_id(overwrite)
return self

def get_min(self, column):
"""
Expand Down
10 changes: 5 additions & 5 deletions movingpandas/trajectory_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def hvplot(self): # noqa F811
def _hvplot_end_points(self, tc):
from holoviews import dim, Overlay

try:
if "TrajectoryCollection" in str(type(tc)):
end_pts = tc.get_end_locations(with_direction=True)
except AttributeError: # if tc is actually a Trajectory
else: # Trajectory
tc.add_direction(name=self.direction_col_name, overwrite=True)
end_pts = tc.df.tail(1).copy()

Expand Down Expand Up @@ -268,21 +268,21 @@ def hvplot_pts(self):
self.MPD_PALETTE = list(Category10_10) + cc.palette["glasbey"]
self.color = self.kwargs.pop("color", None)

try:
if "TrajectoryCollection" in str(type(self.data)):
tc = self.data.copy()
if self.direction_col_name not in tc.trajectories[0].df.columns:
tc.add_direction(name=self.direction_col_name)
if self.column:
if self.column == self.speed_col_name and self.speed_col_missing:
tc.add_speed()
pts_gdf = tc.to_point_gdf()
except AttributeError:
else: # Trajectory
traj = self.data.copy()
if self.direction_col_name not in traj.df.columns:
traj.add_direction(name=self.direction_col_name)
if self.column:
if self.column == self.speed_col_name and self.speed_col_missing:
tc.add_speed()
traj.add_speed()
pts_gdf = traj.df

ids = None
Expand Down
10 changes: 9 additions & 1 deletion tutorials/0-debug.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@
"toy_traj.df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"toy_traj.add_speed(overwrite=True).df"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -91,7 +100,6 @@
"metadata": {},
"outputs": [],
"source": [
"toy_traj.add_speed()#overwrite=True)\n",
"intersections = toy_traj.clip(polygon)\n",
"intersections"
]
Expand Down
13 changes: 5 additions & 8 deletions tutorials/1-getting-started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@
},
"outputs": [],
"source": [
"toy_traj.add_distance(overwrite=True)\n",
"toy_traj.df"
"toy_traj.add_distance(overwrite=True).df"
]
},
{
Expand All @@ -170,8 +169,7 @@
},
"outputs": [],
"source": [
"toy_traj.add_speed(overwrite=True)\n",
"toy_traj.df"
"toy_traj.add_speed(overwrite=True).df"
]
},
{
Expand All @@ -180,8 +178,7 @@
"metadata": {},
"outputs": [],
"source": [
"toy_traj.add_acceleration(overwrite=True)\n",
"toy_traj.df"
"toy_traj.add_acceleration(overwrite=True).df"
]
},
{
Expand Down Expand Up @@ -975,7 +972,7 @@
"source": [
"cleaned = mpd.OutlierCleaner(split).clean(alpha=2) # .clean(v_max=100, units=(\"km\", \"h\"))\n",
"cleaned.add_speed(units=(\"km\", \"h\"), overwrite=True)\n",
"print(cleaned)"
"cleaned"
]
},
{
Expand Down Expand Up @@ -1093,7 +1090,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.10.11"
}
},
"nbformat": 4,
Expand Down
6 changes: 2 additions & 4 deletions tutorials/2-reading-data-from-files.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,7 @@
" 'link1_href', 'link1_text', 'link1_type', 'link2_href', 'link2_text', 'link2_type', 'sym', \n",
" 'type', 'fix', 'sat', 'hdop', 'vdop', 'pdop', 'ageofdgpsdata', 'dgpsid'], inplace=True) \n",
"traj = mpd.Trajectory(gdf, \"2019-02-18 0745\", obj_id=\"304\")\n",
"traj.add_distance()\n",
"traj.add_speed(name=\"speed (kph)\", units=(\"km\", \"h\"))\n",
"traj"
"traj.add_distance().add_speed(name=\"speed (kph)\", units=(\"km\", \"h\"))"
]
},
{
Expand Down Expand Up @@ -451,7 +449,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.11"
},
"vscode": {
"interpreter": {
Expand Down
3 changes: 1 addition & 2 deletions tutorials/7-multithreading.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@
"outputs": [],
"source": [
"%%time\n",
"tc.add_speed(n_threads=5)\n",
"tc.plot(column=\"speed\", vmin=0, vmax=20)"
"tc.add_speed(n_threads=5).plot(column=\"speed\", vmin=0, vmax=20)"
]
},
{
Expand Down

0 comments on commit e157868

Please sign in to comment.