In [None]:
# plotting imports
import contextily as cx
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.patches import Patch

# dataset import
from srai.datasets import ChicagoCrimeDataset

In [None]:
# Dataset handle; the actual data is obtained by calling .load() in later cells.
chicago_crime = ChicagoCrimeDataset()

## Load the default data

In [None]:
# Load the default version. The result is dict-like, keyed by split name
# (the "train" and "test" keys are used further down).
ds = chicago_crime.load()
ds.keys()

In [None]:
# Check the types of the train/test split attributes held by the dataset object.
type(chicago_crime.train_gdf), type(chicago_crime.test_gdf)

In [None]:
# Report the H3 resolution the dataset aggregates to.
print(f"Aggregation H3 resolution: {chicago_crime.resolution}")

In [None]:
# Report which column the dataset treats as the prediction target.
print(f"Prediction target: {chicago_crime.target}")

In [None]:
# Pull the two splits out of the loaded mapping under explicit names.
gdf_train = ds["train"]
gdf_test = ds["test"]

In [None]:
# Preview the first rows of the training split returned by load().
gdf_train.head()

## Get target values aggregated to H3 cells

In [None]:
# get_h3_with_labels() returns three H3-aggregated splits. The middle one is
# presumably a validation split (possibly None) and is not used below; confirm.
# It is bound to a real name instead of `_`: assigning `_` in IPython makes the
# shell stop updating `_` / `__` / `___` with the latest cell outputs.
train_h3, val_h3, test_h3 = chicago_crime.get_h3_with_labels()

In [None]:
# Preview the H3-aggregated training cells together with their target column.
train_h3.head()

In [None]:
# Preview the H3-aggregated test cells together with their target column.
test_h3.head()

In [None]:
# One entry per split: (legend label, H3-aggregated frame, colour). Both panels
# are drawn from this single list, so the map and KDE styling/labels stay in sync.
splits = [
    ("Train", train_h3, "orange"),
    ("Test", test_h3, "royalblue"),
]

fig, axes = plt.subplots(
    2, 1, sharex=False, sharey=False, figsize=(12, 20), height_ratios=[4, 1]
)
map_ax, kde_ax = axes

for split_label, split_h3, split_color in splits:
    split_target = split_h3[chicago_crime.target]
    # Opacity grows with the target: (target + 0.4)^2, capped at 1. A target of 0
    # still gets alpha 0.16, so every cell stays faintly visible on the map.
    # NOTE(review): the formula assumes the target is roughly in [0, 1]; confirm.
    split_h3.plot(
        color=split_color,
        markersize=0.1,
        ax=map_ax,
        label=split_label,
        alpha=np.minimum(np.power(split_target + 0.4, 2), 1),
    )
    sns.kdeplot(
        x=split_target,
        label=split_label,
        color=split_color,
        ax=kde_ax,
        fill=False,
    )

# crs=4326 assumes the H3 frames are in WGS84 lon/lat; confirm, or pass train_h3.crs.
cx.add_basemap(map_ax, source=cx.providers.CartoDB.PositronNoLabels, crs=4326, zoom=12)
map_ax.set_title("Chicago crime data aggregated to H3 cells")
# Build the legend from explicit patches so each split shows a solid colour swatch
# instead of the per-cell, alpha-faded polygons.
map_ax.legend(
    handles=[Patch(facecolor=color) for _label, _frame, color in splits],
    labels=[label for label, _frame, _color in splits],
)
map_ax.set_axis_off()

kde_ax.set_title("Chicago crime data - target distribution")
kde_ax.legend()

plt.show()

## Load the 2022 version of the data

In [None]:
# Load the "2022" version. Note: this rebinds `ds`, replacing the
# default-version splits loaded earlier in the notebook.
ds = chicago_crime.load(version="2022")
ds.keys()

In [None]:
# Re-check the split attribute types after loading the 2022 version.
type(chicago_crime.train_gdf), type(chicago_crime.test_gdf)

In [None]:
# Preview the first rows of the 2022 training split.
ds["train"].head()

## Create your own train-test split: spatial splitting with bucket stratification

In [None]:
# Build a custom spatial train/test split; one keyword per line so each knob is easy to tune.
train, test = chicago_crime.train_test_split(
    test_size=0.2,  # presumably the fraction held out for the test split
    random_state=42,  # fixed seed so the split is reproducible across runs
    n_bins=10,  # number of buckets used for the stratification
    resolution=9,  # H3 resolution used for the spatial split
)

In [None]:
# Re-check the split attribute types, to see whether the custom split changed them.
type(chicago_crime.train_gdf), type(chicago_crime.test_gdf)

In [None]:
# Resolution reported by the dataset after the custom split (requested with resolution=9).
chicago_crime.resolution

In [None]:
# Preview the training part of the custom split.
train.head()

In [None]:
# Preview the test part of the custom split.
test.head()