In [97]:
import pandas as pd
import numpy as np
import os
import rasterio

In [98]:
# Load the prepared Ethiopia table from CSV; later cells read its
# 'nightlights', 'cluster_lat' and 'cluster_lon' columns.
data = pd.read_csv('processed_data/ethiopia_10_by_10_ready.csv')

In [99]:
# Check the table's dimensions as (rows, columns).
data.shape

(523, 5)

In [100]:
# Count how often each distinct 'nightlights' value occurs.
data["nightlights"].value_counts()


nightlights
0.000000    197
0.020875      2
0.107431      2
0.008249      2
0.070480      2
           ... 
0.032912      1
0.002156      1
0.001243      1
0.031277      1
0.000624      1
Name: count, Length: 322, dtype: int64

In [101]:
def drop_0s(df, keep_frac=0.1, random_state=42):
    """
    Randomly drop most of the rows whose 'nightlights' value is exactly 0.

    A fraction `keep_frac` of the zero rows is kept (10% by default, i.e. 90%
    are dropped); every row with a non-zero 'nightlights' value is kept. The
    result is shuffled and re-indexed from 0. Both the subsampling and the
    final shuffle are seeded, so repeated calls on the same input return the
    same rows in the same order.

    Args:
    df (pandas.DataFrame): Input DataFrame containing a 'nightlights' column.
        It is not modified.
    keep_frac (float): Fraction of the zero-'nightlights' rows to keep, between
        0 and 1. The number kept is truncated to an integer.
    random_state (int): Seed used for both the subsampling and the shuffle.

    Returns:
    pandas.DataFrame: New DataFrame with the non-zero rows plus the kept zero
    rows, shuffled, with a fresh 0..n-1 index.

    Raises:
    ValueError: If `keep_frac` is outside the range [0, 1].
    """
    if not 0 <= keep_frac <= 1:
        raise ValueError("keep_frac must be between 0 and 1")

    # Split the frame once on the "is exactly zero" mask.
    is_zero = df['nightlights'] == 0
    zero_rows = df[is_zero]
    non_zero_rows = df[~is_zero]

    # Number of zero rows to keep (truncated, so small groups may keep none).
    n_keep = int(keep_frac * len(zero_rows))
    kept_zero_rows = zero_rows.sample(n=n_keep, random_state=random_state)

    combined = pd.concat([non_zero_rows, kept_zero_rows])

    # Seed the shuffle as well; an unseeded shuffle would change the row order
    # (and anything saved positionally from it) on every run.
    return combined.sample(frac=1, random_state=random_state).reset_index(drop=True)


In [102]:
# Downsample the zero-'nightlights' rows, then display the resulting table.
new_data = drop_0s(data)
new_data

Unnamed: 0,country,cluster_lat,cluster_lon,cons_pc,nightlights
0,eth,10.528154,39.813107,5.367982,0.000393
1,eth,9.368588,42.794475,23.115905,1.319907
2,eth,9.847087,36.342298,12.599424,0.003002
3,eth,8.957485,38.762128,10.622940,13.210879
4,eth,8.093828,36.457461,8.324623,0.010723
...,...,...,...,...,...
340,eth,10.016586,38.249227,12.509533,0.020326
341,eth,10.588352,39.907496,6.779063,0.019677
342,eth,8.533923,39.268271,8.125985,3.471767
343,eth,9.337776,42.080731,13.001333,0.691566


In [103]:
def extract_subimage(src, lat, lon, km_per_pixel=0.418877, size_km=10):
    """
    Read a square sub-image of band 1 around a geographic centre point.

    Args:
    src (rasterio.io.DatasetReader): Open Rasterio dataset to read from.
    lat (float): Latitude of the centre of the sub-image.
    lon (float): Longitude of the centre of the sub-image.
    km_per_pixel (float): How many kilometres one pixel represents. Defaults
        to 0.418877, the value this function previously hard-coded.
    size_km (float): Side length of the square sub-image in kilometres.
        Defaults to 10.

    Returns:
    np.ndarray: The array returned by `src.read` for band 1 over the window;
    the requested window is 2 * int(size_km / km_per_pixel / 2) pixels per side.
    """
    # The inverse of the dataset's transform maps (lon, lat) to fractional
    # (column, row) pixel coordinates; truncate them to whole pixels.
    col, row = ~src.transform * (lon, lat)
    col, row = int(col), int(row)

    # Half the side length, in whole pixels.
    half_side = int(size_km / km_per_pixel / 2)

    # Window arguments are (col_off, row_off, width, height).
    # NOTE(review): windows that extend past the raster edge are not handled
    # here — confirm how src.read treats them for clusters near the border.
    window = rasterio.windows.Window(col - half_side, row - half_side, 2 * half_side, 2 * half_side)
    return src.read(1, window=window)


In [104]:
# Display the downsampled table again.
new_data

Unnamed: 0,country,cluster_lat,cluster_lon,cons_pc,nightlights
0,eth,10.528154,39.813107,5.367982,0.000393
1,eth,9.368588,42.794475,23.115905,1.319907
2,eth,9.847087,36.342298,12.599424,0.003002
3,eth,8.957485,38.762128,10.622940,13.210879
4,eth,8.093828,36.457461,8.324623,0.010723
...,...,...,...,...,...
340,eth,10.016586,38.249227,12.509533,0.020326
341,eth,10.588352,39.907496,6.779063,0.019677
342,eth,8.533923,39.268271,8.125985,3.471767
343,eth,9.337776,42.080731,13.001333,0.691566


In [105]:

import matplotlib.pyplot as plt
import rasterio

# Extract one sub-image per row of new_data, centred on the row's
# (cluster_lat, cluster_lon). The raster is opened in a `with` block so the
# file handle is closed once all sub-images have been read.
X = []
with rasterio.open("raw_data/picture.tif") as source:
    # Only the two coordinate columns are needed, so iterate over them
    # directly instead of materialising every row with iterrows().
    for lat, lon in zip(new_data['cluster_lat'], new_data['cluster_lon']):
        sub_image = extract_subimage(source, lat, lon)
        X.append(sub_image)

In [110]:
# Save all extracted sub-images into a single .npz archive. They are passed
# positionally, so they are stored under the keys arr_0, arr_1, ...
np.savez('ethiopia_10_by_10.npz', *X)

In [107]:
# To load and inspect the saved arrays:
#     loaded = np.load('ethiopia_10_by_10.npz')
#     arrays = [loaded[f'arr_{i}'] for i in range(len(loaded.files))]
#     print(arrays)

In [108]:
# import matplotlib.pyplot as plt
# import seaborn as sns
# vmin, vmax = np.percentile(sub_image, [2,98])
# plt.figure(figsize=(10, 10))  # Figure size; adjust as needed
# sns.heatmap(sub_image, cmap='gray', vmin=vmin, vmax=vmax)