# Label data demo


In [None]:
!pip install rasterio
!pip install exactextract
!pip install mapclassify

In [1]:
import os
import subprocess
import numpy as np
import pandas as pd
import geopandas as gpd
import rasterio
import rasterio.features
import matplotlib.pyplot as plt


from rasterio.mask import mask
from rasterio.merge import merge
from rasterio.io import MemoryFile
from exactextract import exact_extract
from matplotlib.patches import Rectangle
from shapely.geometry import box, Polygon, MultiPolygon, GeometryCollection

## MOSAIKS grid

The standard resolution of MOSAIKS is a global grid at 0.01° resolution. Each grid cell is approximately 1.1 km × 1.1 km (roughly 1.2 km²) at the equator, and cells get narrower in the east–west direction towards the poles. The grid is often represented as a point grid, where each point is the center of a grid cell. This means that standard MOSAIKS features come with a latitude and longitude coordinate, which is the center of the grid cell.

Here we show several ways to create a MOSAIKS grid for a given location.

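> It is important to note that the grid is standardised so that grid cell centers are offset by half a cell (0.005 degrees) from the 0.01° grid lines. In other words, centers are spaced 0.01° apart and fall on odd multiples of 0.005 degrees (e.g., 10.005, 10.015, 10.025, ...).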

There is no firm reason why this resolution was chosen, but it is a good compromise between resolution and computational efficiency. If a user is computing customized features for their application, they can choose a different resolution. However, it is important to note that the resolution of the grid will affect the computation time and the amount of data that needs to be stored.

### Advantages of the MOSAIKS grid

The MOSAIKS grid has several advantages, but the primary advantage is that it helps avoid overlapping labels. Labels often come with coordinates that are not regularly spaced. Aligning your labels to a common grid prevents one label from bleeding over into another. This is especially important when you are working with data that has a high degree of spatial autocorrelation.

### Create a grid for a given location

The following code creates a grid for a given location. The location can be given as a bounding box of the format `[minx, miny, maxx, maxy]`, a single polygon, or a dataframe with a geometry column. If a bounding box is given, the grid will be created for the bounding box. If a polygon is given, the grid will be created for the bounding box of the polygon, and then cropped to the polygon. If a dataframe is given, the grid will be created for the bounding box of each row, and then cropped to the geometry of that row, and it will be repeated for each row.


In [2]:
def create_grid(
    borders,
    resolution: float = 0.01,
    geometry_col: str = "geometry",
    id_col: str = "NAME",
    return_ids: bool = False,
) -> pd.DataFrame:
    """
    Create a grid of cell-center latitude and longitude coordinates for one or more geometries.
    It can accept a bounding box, a single polygon (or other Shapely geometry),
    or a GeoDataFrame with a geometry column.

    Grid cells are ``resolution`` degrees wide and their centers sit half a cell
    away from multiples of ``resolution`` (for the default 0.01 degree grid:
    ..., 10.005, 10.015, 10.025, ...). A cell is kept for a geometry when the
    rasterization (rasterio's default ``all_touched=False`` rule, i.e. the cell
    center for polygons) selects it.

    Parameters
    ----------
    borders : list or tuple or shapely.geometry.BaseGeometry or geopandas.GeoDataFrame
        - If list/tuple of length 4, interpreted as a bounding box: [minx, miny, maxx, maxy].
        - If a Shapely geometry (Polygon, MultiPolygon, etc.), creates a single-row GeoDataFrame.
        - If a GeoDataFrame, the function iterates over its rows.
    resolution : float, optional
        Grid resolution in degrees, default 0.01. Must be positive.
    geometry_col : str, optional
        Column name for the geometry in the resulting GeoDataFrame, by default "geometry".
    id_col : str, optional
        Column name in the GeoDataFrame to use as the ID column, or the name
        for the new column if bounding box / single polygon is provided. Default is "NAME".
    return_ids : bool, optional
        If True, generate and return the unique IDs for each grid cell. This will create a
        column 'unique_id' which follows the pattern 'lon_{lon}__lat_{lat}'. This option slows
        down the overall operation. Default is False.

    Returns
    -------
    pd.DataFrame
        A DataFrame with columns:
        - 'lat': latitude of each cell center (Y)
        - 'lon': longitude of each cell center (X)
        - `[id_col]`: the identifier of the geometry the cell belongs to
        - 'unique_id': 'lon_{lon}__lat_{lat}' with rounded coordinates
          (only when ``return_ids`` is True)

    Raises
    ------
    ValueError
        If ``resolution`` is not positive or ``borders`` has an unsupported type.
    """
    if resolution <= 0:
        raise ValueError("resolution must be a positive number of degrees")

    # 1. Convert input to a GeoDataFrame
    gdf = _to_geodataframe(borders, geometry_col, id_col)

    # 2. Ensure there's an ID column in the GeoDataFrame
    if id_col not in gdf.columns:
        # If user didn't provide an ID col for bounding box or single geometry,
        # assign a placeholder ID. For a multi-row GDF, user is expected to pass
        # an existing column name.
        gdf[id_col] = [f"feature_{i}" for i in range(len(gdf))]

    # Cell centers lie at k * resolution + half (e.g. x.xx5 for a 0.01 grid)
    half = resolution / 2

    # 3. Rasterize each geometry and collect the selected cell centers
    result_list = []
    for _, row in gdf.iterrows():
        geom = row[geometry_col]
        this_id = row[id_col]

        if geom is None or geom.is_empty:
            # Skip missing or empty geometries
            continue

        minx, miny, maxx, maxy = geom.bounds

        # ---- Candidate cell centers (latitudes run north -> south) ----
        # ceil on the north edge and floor on the west edge guarantee that every
        # cell whose center lies within the bounds is a candidate. An extra
        # row/column whose center falls outside is removed by the mask below.
        lats = np.arange(np.ceil(maxy / resolution) * resolution - half, miny, -resolution)
        lons = np.arange(np.floor(minx / resolution) * resolution + half, maxx, resolution)

        if len(lats) == 0 or len(lons) == 0:
            # If bounding box is too small or resolution is large, might be empty
            continue

        lon_grid, lat_grid = np.meshgrid(lons, lats)

        # ---- Rasterize the geometry onto exactly this grid ----
        # With this transform, pixel (r, c) is centered on (lons[c], lats[r]),
        # so the mask lines up with lat_grid / lon_grid for any bounds.
        transform = rasterio.transform.from_origin(
            lons[0] - half, lats[0] + half, resolution, resolution
        )
        cell_mask = rasterio.features.rasterize(
            [(geom, 1)],
            out_shape=lon_grid.shape,
            transform=transform,
            fill=0,
            dtype=np.uint8,
        ).astype(bool)

        temp_df = pd.DataFrame({"lat": lat_grid[cell_mask], "lon": lon_grid[cell_mask]})
        temp_df[id_col] = this_id
        result_list.append(temp_df)

    # 4. Concatenate the results
    if not result_list:
        columns = ["lat", "lon", id_col] + (["unique_id"] if return_ids else [])
        return pd.DataFrame(columns=columns)

    final_result = pd.concat(result_list, ignore_index=True)

    if return_ids:
        # e.g. 'lon_-10.005__lat_9.995'. Use enough decimals to keep the
        # half-cell offset distinct (3 for resolutions >= 0.01, more if finer).
        decimals = max(3, int(np.ceil(-np.log10(half))))
        final_result["unique_id"] = (
            "lon_"
            + final_result["lon"].round(decimals).astype(str)
            + "__lat_"
            + final_result["lat"].round(decimals).astype(str)
        )

    return final_result


def _to_geodataframe(borders, geometry_col: str, id_col: str) -> gpd.GeoDataFrame:
    """
    Internal helper that converts various input types into a standardized GeoDataFrame.

    Parameters
    ----------
    borders : list/tuple, shapely geometry, or GeoDataFrame
        Bounding box (list/tuple of length 4: [minx, miny, maxx, maxy]),
        any single Shapely geometry (Polygon, MultiPolygon, etc.),
        or a GeoDataFrame.
    geometry_col : str
        The name of the geometry column to use or create. It becomes the
        active geometry column of the result.
    id_col : str
        Column in which to store a placeholder ID ("bbox_1" / "geom_1") for
        bounding-box or single-geometry input. GeoDataFrame input is not
        modified with respect to this column.

    Returns
    -------
    gpd.GeoDataFrame
        A GeoDataFrame whose active geometry column is ``geometry_col``.
        Bounding boxes and single geometries are tagged EPSG:4326. GeoDataFrame
        input in another CRS is reprojected to EPSG:4326, because the grid is
        defined in degrees.

    Raises
    ------
    ValueError
        If ``borders`` is not one of the supported input types.
    """
    from shapely.geometry.base import BaseGeometry

    # Case 1: bounding box
    if isinstance(borders, (list, tuple)) and len(borders) == 4:
        minx, miny, maxx, maxy = borders
        # geometry= makes geometry_col the active geometry even when it is not
        # literally named "geometry" (a CRS cannot be set without one).
        return gpd.GeoDataFrame(
            {id_col: ["bbox_1"], geometry_col: [box(minx, miny, maxx, maxy)]},
            geometry=geometry_col,
            crs="EPSG:4326",
        )

    # Case 2: single shapely geometry of any type
    if isinstance(borders, BaseGeometry):
        return gpd.GeoDataFrame(
            {id_col: ["geom_1"], geometry_col: [borders]},
            geometry=geometry_col,
            crs="EPSG:4326",
        )

    # Case 3: GeoDataFrame
    if isinstance(borders, gpd.GeoDataFrame):
        gdf = borders.copy()
        if geometry_col not in gdf.columns:
            # Rename the active geometry column so everything is consistent
            gdf = gdf.rename_geometry(geometry_col)
        gdf = gdf.set_geometry(geometry_col)

        # Grid coordinates are in degrees, so work in WGS84
        if gdf.crs is not None and gdf.crs.to_epsg() != 4326:
            gdf = gdf.to_crs(epsg=4326)
        return gdf

    raise ValueError(
        "Unsupported input for 'borders'. Must be one of:\n"
        "1) [minx, miny, maxx, maxy]\n"
        "2) A Shapely geometry (Polygon, MultiPolygon, etc.)\n"
        "3) A GeoDataFrame"
    )

### Create a grid for a bounding box

In the following example, we create a grid for an arbitrary bounding box.


In [None]:
# Bounding box ordered as [minx, miny, maxx, maxy], in degrees
bbox = [0.0, 0.0, 0.5, 0.5]

# Every option is spelled out explicitly, including the defaults
bbox_grid_df = create_grid(
    borders=bbox,
    resolution=0.01,
    geometry_col="geometry",
    id_col="NAME",
    return_ids=True,
)
print(bbox_grid_df.head(), "\nShape: ", bbox_grid_df.shape)


#### Visualize the grid

Next we visualize the grid. To do this, we first create a GeoDataFrame with the grid points. We then buffer the points to create a polygon around each point, using `cap_style=3` so that each buffered point becomes a square. Finally, we plot the grid points (green), the grid polygons (black), and the original bounding box (red).

**Note:** the following code produces a warning from the buffer operation. This is because the buffer operation is being conducted on geometry in a geographic CRS. This warning can be safely ignored in the context of MOSAIKS.


In [None]:
# Turn the centroid table into point geometries, then expand each point into a
# square cell (cap_style=3) reaching 0.005 degrees on each side of its center.
bbox_grid_gdf = gpd.GeoDataFrame(
    bbox_grid_df,
    geometry=gpd.points_from_xy(bbox_grid_df.lon, bbox_grid_df.lat),
    crs="EPSG:4326",
)
bbox_grid_gdf.geometry = bbox_grid_gdf.geometry.buffer(0.005, cap_style=3)

fig, ax = plt.subplots(figsize=(6, 6))

# Requested bounding box, drawn in red
x_min, y_min, x_max, y_max = bbox
width, height = x_max - x_min, y_max - y_min
rect = Rectangle((x_min, y_min), width, height, fill=False, color="red", linewidth=2)
ax.add_patch(rect)

# Cell centroids (green) and cell outlines (black)
bbox_grid_gdf.plot.scatter(x="lon", y="lat", s=0.25, c="green", ax=ax)
bbox_grid_gdf.boundary.plot(ax=ax, color="black", linewidth=0.5)

ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
ax.set_title("MOSAIKS grid cells with centroids")
plt.show()

### Create a grid for a polygon

We can also create a grid for a polygon. In the following example, we create a grid for a polygon that defines the outline of a dinosaur. Our function creates a grid for the bounding box of the polygon and then crops the grid to the polygon. This minimizes the number of grid points that are created. This is especially important if the purpose of the grid is to define coordinates to upload to the MOSAIKS API for feature requests via a file query.


In [None]:
# Outline of a dinosaur, in degrees. Its bounding box is [0, 0, 0.5, 0.5],
# the same box used in the previous example.
dino_polygon = Polygon(
    [
        (0.000, 0.444), (0.136, 0.278), (0.227, 0.333), (0.318, 0.500),
        (0.364, 0.500), (0.409, 0.444), (0.455, 0.444), (0.500, 0.389),
        (0.500, 0.333), (0.364, 0.333), (0.500, 0.278), (0.455, 0.222),
        (0.364, 0.278), (0.364, 0.240), (0.390, 0.200), (0.390, 0.180),
        (0.364, 0.180), (0.318, 0.222), (0.273, 0.056), (0.318, 0.000),
        (0.227, 0.000), (0.182, 0.056), (0.136, 0.000), (0.091, 0.000),
        (0.045, 0.056), (0.091, 0.111), (0.136, 0.167), (0.045, 0.278),
    ]
)

# geometry_col and id_col keep their defaults ("geometry" / "NAME")
dino_grid_df = create_grid(dino_polygon, resolution=0.01, return_ids=True)

print(dino_grid_df.head(), "\nShape: ", dino_grid_df.shape)

#### Visualize the grid

Here we visualize our grid over the dinosaur polygon. Notice that the bounding box of the polygon is the same as the bounding box of the previous grid, but the grid is cropped to the polygon. Again, this is preferred because it minimizes file sizes by not querying unnecessary locations.


In [None]:
# Square cells around each centroid, built the same way as for the bounding box
dino_grid_gdf = gpd.GeoDataFrame(
    dino_grid_df,
    geometry=gpd.points_from_xy(dino_grid_df.lon, dino_grid_df.lat),
    crs="EPSG:4326",
)
dino_grid_gdf.geometry = dino_grid_gdf.geometry.buffer(0.005, cap_style=3)

fig, ax = plt.subplots(figsize=(6, 6))

# Dinosaur outline (red)
x, y = dino_polygon.exterior.xy
ax.plot(x, y, color="red")

# Cell centroids (green) and cell outlines (black)
dino_grid_gdf.plot.scatter(x="lon", y="lat", s=0.25, c="green", ax=ax)
dino_grid_gdf.boundary.plot(ax=ax, color="black", linewidth=0.5)

ax.set(xlabel="Longitude", ylabel="Latitude", title="MOSAIKS grid cells with centroids")
plt.show()

### Create a grid for a GeoDataFrame

In the following example, we create a grid for a GeoDataFrame with a geometry column. The grid will be created for the bounding box of each row, and then cropped to the geometry of that row. First, we load an example of a GeoDataFrame with a geometry column, two countries in this case. We then create a grid for each country and visualize the grid for each country.

This is a slightly more practical use case, as it allows a grid to be created for multiple locations at once while still minimizing the number of grid points.


In [None]:
base_url = "https://github.com/wmgeolab/geoBoundaries/raw/9469f09/releaseData/gbOpen"

# Country outlines (ADM0) from a pinned geoBoundaries commit
togo_shape_fp = f"{base_url}/TGO/ADM0/geoBoundaries-TGO-ADM0.geojson"
benin_shape_fp = f"{base_url}/BEN/ADM0/geoBoundaries-BEN-ADM0.geojson"

togo_gdf, benin_gdf = (gpd.read_file(fp) for fp in (togo_shape_fp, benin_shape_fp))

# Stack both countries into one frame with a fresh 0..n-1 index
tgo_ben_gdf = pd.concat([togo_gdf, benin_gdf], ignore_index=True)
tgo_ben_gdf

In [None]:
# One grid per country, each cell tagged with the country's ISO code.
# return_ids stays at its default (False): generating unique IDs is slow
# for grids this large.
tgo_ben_grid = create_grid(
    tgo_ben_gdf,
    resolution=0.01,
    geometry_col="geometry",
    id_col="shapeISO",
)

# Expand the cell centers into square 0.01-degree cells
tgo_ben_grid_gdf = gpd.GeoDataFrame(
    tgo_ben_grid,
    geometry=gpd.points_from_xy(tgo_ben_grid.lon, tgo_ben_grid.lat),
    crs="EPSG:4326",
)
tgo_ben_grid_gdf.geometry = tgo_ben_grid_gdf.geometry.buffer(0.005, cap_style=3)
tgo_ben_grid_gdf

#### Visualize the grids

Here we visualize our grids over the countries. Notice how the definition of the grid cells is lost as we zoom out. It is important to note that the maximum number of locations you can request via the MOSAIKS API is 100,000. In this case, it would be prudent to save a file for each country and upload them separately in 2 file queries.


In [None]:
fig, ax = plt.subplots(figsize=(6, 6))

# Colour the grid cells by country code
tgo_ben_grid_gdf.plot(ax=ax, column="shapeISO", legend=True)

ax.set(xlabel="Longitude", ylabel="Latitude")
ax.set_title("MOSAIKS grid cells centroids")
plt.show()

## Sticking with the grid - preparing labels

For the purposes of this demonstration, we will take advantage of the standardised grid and match our labels to this resolution. This is recommended for the most efficient use of the MOSAIKS API file query. Alternatively, users may take advantage of precomputed aggregations of MOSAIKS features. These precomputed files come on 0.1-degree and 1-degree grids, as well as summarised to several levels of administrative division. We will cover label aggregation to these administrative divisions later in this demonstration.

> There is no strict requirement to use this or any other grid system within the MOSAIKS framework. Using a different system may require the user to compute their own features.

### Example 1: Point labels (point locations, e.g. OpenStreetMap schools)

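Scenario: You have point locations (latitude, longitude) and want a label for each grid cell, for example a count of points or an aggregate of a variable attached to them. Here we use OpenStreetMap points of interest for Togo (downloaded from Geofabrik), keep only schools, and count the number of schools that fall inside each MOSAIKS grid cell in Togo.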


In [None]:
data_dir = "geofabrik"

if not os.path.exists(data_dir):
    os.makedirs(data_dir, exist_ok=True)

    !wget -O {data_dir}/togo-latest-free.zip https://download.geofabrik.de/africa/togo-latest-free.shp.zip

    !unzip {data_dir}/togo-latest-free.zip -d {data_dir}

    !rm {data_dir}/togo-latest-free.zip

!ls -lhR {data_dir}

In [None]:
# Keep only the grid cells that belong to Togo
tgo_grid_gdf = tgo_ben_grid_gdf.loc[tgo_ben_grid_gdf["shapeISO"].eq("TGO")]
tgo_grid_gdf

In [None]:
# OSM points of interest for Togo, restricted to schools
pois_path = os.path.join(data_dir, "gis_osm_pois_free_1.shp")
togo_pois_gdf = gpd.read_file(pois_path)
togo_pois_gdf = togo_pois_gdf.loc[togo_pois_gdf["fclass"] == "school"]
togo_pois_gdf

In [None]:
# togo_pois_gdf.fclass.value_counts()

In [None]:
# togo_pois_gdf.plot(column="fclass", legend=False, figsize=(10, 10), markersize=0.5)

In [None]:
# Spatial join: attach each school point to the grid cell polygon that
# contains it. how="left" keeps cells with no school (their POI columns are NaN).
# Note: a point lying exactly on a cell edge is not "contained" by either cell.
joined = gpd.sjoin(tgo_grid_gdf, togo_pois_gdf, how="left", predicate="contains")

# Count schools per grid cell, grouping by the cell-center lat/lon only.
# count() skips NaN, so cells without a school get 0. The resulting count
# column keeps the name "fclass".
summary = joined.groupby(["lat", "lon"], as_index=False)["fclass"].count()

# Rebuild the square cell polygons from the cell centers
summary = gpd.GeoDataFrame(
    summary,
    geometry=gpd.points_from_xy(summary.lon, summary.lat),
    crs="EPSG:4326",
)
summary.geometry = summary.geometry.buffer(0.005, cap_style=3)
summary

In [16]:
# summary.explore()

### Example 2: Polygon Labels

- Mines?


### Example 3: Line labels

- Road length


### Example 4: Raster Labels

- Forest cover


In [None]:
# Define target folder and ensure it exists
gw_folder = "global_forest_watch/70m_tree-cover"
os.makedirs(gw_folder, exist_ok=True)

# Base URL for downloading the data. This is a separate name so that it does
# not overwrite the geoBoundaries `base_url` defined earlier in the notebook.
gfw_base_url = "https://data-api.globalforestwatch.org/dataset/wri_tropical_tree_cover/v2020/download/geotiff"

# This is the default API key for the Global Forest Watch API (Public Key)
api_key = "2d60cd88-8348-4c0f-a6d5-bd9adb585a8c"

# List of tile IDs to download
tile_ids = ["20N_010W", "20N_000E", "10N_000E"]

for tile_id in tile_ids:
    output_path = os.path.join(gw_folder, f"{tile_id}.tif")

    # Re-running the cell does not fetch tiles again. Failed downloads are
    # deleted below, so a non-empty file here is treated as complete.
    if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
        print(f"{tile_id} already present at {output_path}, skipping")
        continue

    params = (
        f"?grid=10/40000&tile_id={tile_id}&pixel_meaning=percent&x-api-key={api_key}"
    )
    url = f"{gfw_base_url}{params}"

    print(f"Downloading {tile_id}...")
    result = subprocess.run(
        ["wget", "-O", output_path, url], capture_output=True, text=True
    )

    if result.returncode == 0:
        print(f"Downloaded {tile_id} successfully to {output_path}")
    else:
        # wget -O leaves an empty or partial file behind on failure; remove it
        # so the next run retries and the raster step does not read bad data
        if os.path.exists(output_path):
            os.remove(output_path)
        print(f"Error downloading {tile_id}: {result.stderr}")

In [None]:
# Step 1: Load the Togo boundary and keep its geometry in a list, because
# rasterio.mask.mask expects an iterable of shapes
togo_gdf = gpd.read_file(togo_shape_fp)
togo_boundary = [togo_gdf.geometry.values[0]]

# Folder and output paths
gw_files = ["10N_000E.tif", "20N_000E.tif", "20N_010W.tif"]
out_raster = os.path.join(gw_folder, "togo_gfw_tropical_tree_cover_2020.tif")

# Step 2: Crop each tile to Togo and hold the result in memory
memory_files = []
for raster_file in gw_files:
    raster_path = os.path.join(gw_folder, raster_file)
    with rasterio.open(raster_path) as src:
        out_image, out_transform = mask(src, togo_boundary, crop=True)
        out_meta = src.meta.copy()

        # Update metadata for the cropped raster
        out_meta.update(
            {
                "driver": "GTiff",
                "height": out_image.shape[1],
                "width": out_image.shape[2],
                "transform": out_transform,
            }
        )

        memfile = MemoryFile()
        with memfile.open(**out_meta) as dataset:
            dataset.write(out_image)
        memory_files.append(memfile)

# Step 3: Merge all cropped rasters into one. Read the CRS while the datasets
# are still open, then release every in-memory dataset and file.
datasets = [memfile.open() for memfile in memory_files]
try:
    merged_data, merged_transform = merge(datasets)
    merged_crs = datasets[0].crs
finally:
    for dataset in datasets:
        dataset.close()
    for memfile in memory_files:
        memfile.close()

# Treat values outside the valid 0-100 % range as missing (float32 holds NaN)
merged_data = merged_data.astype("float32")
merged_data[(merged_data < 0) | (merged_data > 100)] = np.nan

# mask() fills pixels outside the boundary with the source nodata value, or 0
# when the source has none. 0 is a valid tree-cover value, so set every pixel
# outside Togo to NaN explicitly (geometry_mask is True outside the shapes).
outside_togo = rasterio.features.geometry_mask(
    togo_boundary, out_shape=merged_data.shape[1:], transform=merged_transform
)
merged_data[:, outside_togo] = np.nan

# Step 4: Write the final merged raster to file with compression
with rasterio.open(
    out_raster,
    "w",
    driver="GTiff",
    height=merged_data.shape[1],
    width=merged_data.shape[2],
    count=merged_data.shape[0],
    dtype="float32",  # Matches merged_data
    crs=merged_crs,
    transform=merged_transform,
    tiled=True,  # Enable tiling
    blockxsize=512,  # Block size for tiling
    blockysize=512,
    compress="lzw",  # Apply LZW compression
    nodata=np.nan,  # Set nodata value explicitly
) as dst:
    dst.write(merged_data)


print(f"Clipped, modified, and merged raster saved to: {out_raster}")


In [None]:
# Load the clipped and merged tree-cover raster
forest_cover_raster = f"{gw_folder}/togo_gfw_tropical_tree_cover_2020.tif"

with rasterio.open(forest_cover_raster) as src:
    togo_gfw = src.read(1)  # read the first band
    togo_gfw_bounds = src.bounds

# Create the plot. extent places the image in map coordinates instead of
# pixel indices; vmin/vmax fix the colour scale to the 0-100 % range.
fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(
    togo_gfw,
    cmap="YlGn",  # YlGn is a good colormap for forest cover
    vmin=0,
    vmax=100,
    extent=(
        togo_gfw_bounds.left,
        togo_gfw_bounds.right,
        togo_gfw_bounds.bottom,
        togo_gfw_bounds.top,
    ),
)
plt.colorbar(im, label="Tree Cover (%)")
ax.set_title("Togo Tree Cover 2020")
plt.show()

# The raster stores missing pixels as NaN (outside Togo / outside 0-100), so
# plain min/max/mean would all return nan; use the NaN-aware versions.
print(f"Min value: {np.nanmin(togo_gfw)}")
print(f"Max value: {np.nanmax(togo_gfw)}")
print(f"Mean value: {np.nanmean(togo_gfw):.2f}")

In [None]:
# Summarise the tree-cover raster over each Togo grid cell polygon, carrying
# the cell-center coordinates through to the output table
with rasterio.open(forest_cover_raster) as src:
    extracted_vals = exact_extract(
        src,
        tgo_grid_gdf,
        ["mean", "median", "min", "max"],
        output="pandas",
        include_cols=["lat", "lon"],
    )

# The result is a plain table: give it point geometries at the cell centers...
extracted_vals = gpd.GeoDataFrame(
    extracted_vals,
    crs="EPSG:4326",
    geometry=gpd.points_from_xy(extracted_vals.lon, extracted_vals.lat),
)

# ...and widen each point back into its square grid cell
extracted_vals.geometry = extracted_vals.geometry.buffer(0.005, cap_style=3)

extracted_vals.head()

In [None]:
# Choropleth of mean tree cover per grid cell
extracted_vals.plot(column="mean", cmap="YlGn", legend=True, figsize=(10, 10))

In [52]:
# extracted_vals.explore(column="mean")

## Aggregating labels to administrative divisions

- Reduce the number of labels (pros and cons)
- Reduce noise in the data
- Increase the interpretability of the data
- Not everything needs high resolution data
- Can be predicted at higher resolution (super resolution)
- Some data comes in aggregated already


## Label super resolution

- Predicting at a higher resolution than the labels
- Maybe not here? Might be better in the model notebook


In [None]:
# from sklearn.linear_model import RidgeCV, Ridge
# from sklearn.metrics import r2_score
# from sklearn.model_selection import train_test_split

# features = pd.read_parquet("features/features.parquet")
# features[["lat_temp", "lon_temp"]] = features["unique_id"].str.split("__", expand=True)

# features["lat"] = features["lat_temp"].str.replace("lat_", "")
# features["lon"] = features["lon_temp"].str.replace("lon_", "")

# features["lat"] = features["lat"].str.replace("--", ".").astype(float)
# features["lon"] = features["lon"].str.replace("--", ".").astype(float)

# features = features.drop(["lat_temp", "lon_temp"], axis=1)

# features = gpd.GeoDataFrame(
#     features,
#     geometry=gpd.points_from_xy(features.lon, features.lat),
#     crs="EPSG:4326",
# )

# joined = extracted_vals.sjoin(features, predicate="contains")

# joined = joined.dropna(subset=["mean"])

# # joined["mean"] = np.log1p(joined["mean"])
# joined["mean"].plot.hist(bins=50)

# feature_cols = [f"planet_{i}" for i in range(4000)]

# X = joined[feature_cols]
# y = joined["mean"]

# X_train, X_test, y_train, y_test = train_test_split(
#     X, y, test_size=0.2, random_state=42
# )

# # alphas = np.logspace(-2, 2, base=10, num=5)
# # ridge = RidgeCV(alphas=alphas, scoring="r2", cv=5)

# ridge = Ridge(alpha=0.1)


# ridge.fit(X_train, y_train)

# y_pred = np.maximum(ridge.predict(X_test), 0)

# r2 = r2_score(y_test, y_pred)

# # print(f"Best alpha: {ridge.alpha_}")
# # print(f"Validation R2 performance {ridge.best_score_:0.2f}")
# print(f"Test R2 performance {r2:.4f}")

# min_val = min(min(y_pred), min(y_test)) - 0.1
# max_val = max(max(y_pred), max(y_test)) + 0.1

# plt.figure(figsize=(6, 6))
# plt.plot([min_val, max_val], [min_val, max_val], 'k--', lw=1)
# plt.scatter(y_test, y_pred, alpha=0.5)

# plt.xlabel('Observed Test Values')
# plt.ylabel('Predicted Test Values')
# plt.title('Observed vs Predicted Test Values')
# plt.xlim(min_val, max_val)
# plt.ylim(min_val, max_val)

# # plt.text(
# #     0.05, 0.95,
# #     f'Validation R2: {ridge.best_score_:0.2f}',
# #     transform=plt.gca().transAxes, fontsize=12,
# #     verticalalignment='top'
# # )
# plt.text(
#     0.05, 0.90,
#     f'Test R2: {r2:.2f}',
#     transform=plt.gca().transAxes, fontsize=12,
#     verticalalignment='top'
# )

# plt.show()

In [17]:
# from sklearn.linear_model import RidgeCV, Ridge
# from sklearn.metrics import r2_score
# from sklearn.model_selection import train_test_split

# features = pd.read_parquet("features/features.parquet")
# features[["lat_temp", "lon_temp"]] = features["unique_id"].str.split("__", expand=True)

# features["lat"] = features["lat_temp"].str.replace("lat_", "")
# features["lon"] = features["lon_temp"].str.replace("lon_", "")

# features["lat"] = features["lat"].str.replace("--", ".").astype(float)
# features["lon"] = features["lon"].str.replace("--", ".").astype(float)

# features = features.drop(["lat_temp", "lon_temp"], axis=1)

# features = gpd.GeoDataFrame(
#     features,
#     geometry=gpd.points_from_xy(features.lon, features.lat),
#     crs="EPSG:4326",
# )

# joined = summary.sjoin(features, predicate="contains")

# joined["fclass"] = np.log1p(joined["fclass"])
# joined["fclass"].plot.hist(bins=20)

# feature_cols = [f"planet_{i}" for i in range(4000)]

# X = joined[feature_cols]
# y = joined["fclass"]

# X_train, X_test, y_train, y_test = train_test_split(
#     X, y, test_size=0.2, random_state=42
# )

# # alphas = np.logspace(-2, 2, base=10, num=5)
# # ridge = RidgeCV(alphas=alphas, scoring="r2", cv=5)

# ridge = Ridge(alpha=0.1)


# ridge.fit(X_train, y_train)

# y_pred = np.maximum(ridge.predict(X_test), 0)

# r2 = r2_score(y_test, y_pred)

# # print(f"Best alpha: {ridge.alpha_}")
# # print(f"Validation R2 performance {ridge.best_score_:0.2f}")
# print(f"Test R2 performance {r2:.4f}")

# min_val = min(min(y_pred), min(y_test)) - 0.1
# max_val = max(max(y_pred), max(y_test)) + 0.1

# plt.figure(figsize=(6, 6))
# plt.plot([min_val, max_val], [min_val, max_val], 'k--', lw=1)
# plt.scatter(y_test, y_pred, alpha=0.5)

# plt.xlabel('Observed Test Values')
# plt.ylabel('Predicted Test Values')
# plt.title('Observed vs Predicted Test Values')
# plt.xlim(min_val, max_val)
# plt.ylim(min_val, max_val)

# # plt.text(
# #     0.05, 0.95,
# #     f'Validation R2: {ridge.best_score_:0.2f}',
# #     transform=plt.gca().transAxes, fontsize=12,
# #     verticalalignment='top'
# # )
# plt.text(
#     0.05, 0.90,
#     f'Test R2: {r2:.2f}',
#     transform=plt.gca().transAxes, fontsize=12,
#     verticalalignment='top'
# )

# plt.show()