In [1]:
import os
import psycopg2
import xarray as xr
import numpy as np
import pandas as pd
import polars as pl
from shapely.geometry import Polygon
from shapely import wkb
import psycopg2
from psycopg2.extras import execute_values
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from dotenv import load_dotenv
from datetime import datetime
import time

# import logging
# logging.basicConfig(level=logging.DEBUG, filename='woa23_debug.log', filemode='a', format='%(asctime)s - %(levelname)s : %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')
# logger = logging.getLogger()

In [2]:
load_dotenv()
DBUSER = os.getenv('DBUSER')
DBPASS = os.getenv('DBPASS')
DBHOST = os.getenv('DBHOST')
DBPORT = os.getenv('DBPORT')
DBNAME = os.getenv('DBNAME')
GRIDSET = os.getenv('GRIDSET')
TIMESET = os.getenv('TIMESET')
PARAMSET = os.getenv('PARAMSET')
# print(PARAMSET)
pars = [c.strip() for c in PARAMSET.split(',')] # if c.strip() in available_params]))
time_periods = [p.strip() for p in str(TIMESET).split(',')]
grids = [g.strip() for g in str(GRIDSET).split(',')]
db_settings = {
    'dbname': DBNAME,
    'user': DBUSER,
    'password': DBPASS,
    'host': DBHOST,
    'port': DBPORT,
    'options': "-c statement_timeout=0",
    'keepalives': 1,
    'keepalives_idle': 30,  # Number of seconds of inactivity after which a keepalive message is sent
    'keepalives_interval': 10,  # Number of seconds between keepalive messages when no response is received
    'keepalives_count': 5  # Number of attempts before concluding the connection is dead

}

# Function to establish connection to PostgreSQL
def connect_db(settings):
    conn = psycopg2.connect(**settings)
    conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
    return conn

print(grids, pars, time_periods)

['04'] ['TS'] ['monthly']


In [3]:
# Function to create TS table in PostGIS
# When try grd025_annual_ts use temperature FLOAT8, salinity FLOAT8,
def create_ts_table(conn, table_name, parameter_set):
    with conn.cursor() as cur:
        columns = ", ".join([f"{param} FLOAT" for param in parameter_set])
        create_table_query = f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            lon FLOAT,
            lat FLOAT,
            depth INTEGER,
            time_period INTEGER,
            {columns},
            geom GEOMETRY
        );
        """
        cur.execute(create_table_query)
        
        index_geom = f"CREATE INDEX idx_{table_name}_geom ON {table_name} USING GIST (geom);"
        index_depth = f"CREATE INDEX idx_{table_name}_depth ON {table_name} (depth);"
        index_time_period = f"CREATE INDEX idx_{table_name}_time_period ON {table_name} (time_period);"

        cur.execute(index_geom)
        cur.execute(index_depth)
        cur.execute(index_time_period)

In [4]:
# Test read zarr dataset
zarr_store_path = '../data/025_degree/annual/TS'
res = '04'  # '04' for 0.25-degree, '01' for 1-degree
ds = xr.open_zarr(zarr_store_path, consolidated=False)
# print(ds)

In [4]:
# Function to generate geom for PostGIS using the grid structure
def generate_grid_polygon(lon, lat, res):
    step = 0.25 if res == '04' else 1.0
    lon_min = lon - step / 2
    lon_max = lon + step / 2
    lat_min = lat - step / 2
    lat_max = lat + step / 2
    return Polygon([(lon_min, lat_min), (lon_min, lat_max), (lon_max, lat_max), (lon_max, lat_min), (lon_min, lat_min)])

def sanitize_value(value):
    """Convert 'NaN' or 'None' to None to handle as NULL in PostgreSQL."""
    if value is None:  # Handle Python's None
        return None
    try:
        # Check if the value is NaN
        if np.isnan(value):
            return None
    except TypeError:
        # If the value is not numeric, it's fine to return it as-is
        pass
    return value

In [6]:
grid_db = {'01': 'grd1', '04': 'grd025'}  # Two gridded resolutions data: 1-degree and 0.25-degree in WOA23
grid_dir = {'01': '1_degree', '04': '025_degree'}

In [None]:
def insert_into_postgis_sequentially(conn, table_name, data, parameter_set, res, batch_size=100000):
    cur = conn.cursor()
    count = 0
    #with conn.cursor() as cur:
    if True:
        for time_idx in range(data.dims['time_periods']):
            for depth_idx in range(data.dims['depth']):
                for lat_idx in range(data.dims['lat']):
                    for lon_idx in range(data.dims['lon']):
                        lon = float(data['lon'].data[lon_idx])
                        lat = float(data['lat'].data[lat_idx])
                        depth = int(data['depth'].data[depth_idx])
                        period = int(data['time_periods'].data[time_idx])

                        # Prepare values for all parameters in the parameter_set
                        values = []
                        for param in parameter_set:
                            values.append(float(data['mn'].sel(parameters=param).data[time_idx, depth_idx, lat_idx, lon_idx]))

                        # Check if all values are NaN; if so, skip the row
                        if all(np.isnan(v) for v in values):
                            continue

                        geom = generate_grid_polygon(lon, lat, res)
                        geom_wkb = wkb.dumps(geom, hex=True)

                        # Dynamically create the query string
                        query = f"""
                        INSERT INTO {table_name} (lon, lat, depth, time_period, {', '.join(parameter_set)}, geom)
                        VALUES (%s, %s, %s, %s, {', '.join(['%s'] * len(parameter_set))}, ST_SetSRID(%s::geometry, 4326));
                        """

                        cur.execute(query, (lon, lat, depth, period, *values, geom_wkb))
                        count += 1

                        if count % batch_size == 0:
                            conn.commit()  # Commit every batch_size records

In [16]:
def insert_into_postgis(conn, table_name, data, parameter_set, res, start_depth_idx=0, batch_size=5000):
    cur = conn.cursor()
    batch = []
    total_records = data.dims['time_periods'] * data.dims['depth'] * data.dims['lat'] * data.dims['lon']
    processed_records = 0
    start_time = datetime.now()
    
    forceInitTime = 1
    forceMaxDepth = 12
    forceTerminate = False

    for time_idx in range(data.dims['time_periods']):
        if forceTerminate:
            break
        if time_idx < forceInitTime:
            continue
        print(f"Processing time_idx: {time_idx}, {int(data['time_periods'].data[time_idx])}")
        for depth_idx in range(start_depth_idx, data.dims['depth']):
            if depth_idx > forceMaxDepth:
                forceTerminate = True
                break
            for lat_idx in range(data.dims['lat']):
                for lon_idx in range(data.dims['lon']):
                    lon = float(data['lon'].data[lon_idx])
                    lat = float(data['lat'].data[lat_idx])
                    depth = int(data['depth'].data[depth_idx])
                    period = int(data['time_periods'].data[time_idx])

                    # Prepare values for all parameters in the parameter_set
                    values = [
                        sanitize_value(float(data['mn'].sel(parameters=param).data[time_idx, depth_idx, lat_idx, lon_idx]))
                        for param in parameter_set
                    ]

                    # Check if all values are None (formerly NaN); if so, skip the row
                    if all(v is None for v in values):
                        continue

                    geom = generate_grid_polygon(lon, lat, res)
                    geom_wkb = wkb.dumps(geom, hex=True)

                    # Add the row to the batch
                    batch.append((lon, lat, depth, period, *values, geom_wkb))
                    processed_records += 1

                    if (processed_records == 1 or processed_records % 1000 == 0):
                        end_time = datetime.now()
                        print(f"In lon, lat, depth, time_period: {lon}, {lat}, {depth}, {period}, query records: {processed_records} with {(end_time - start_time).total_seconds()} seconds")
                        print(f"Values are: {values}", flush=True)
                        start_time = end_time

                    # Insert and commit in batches
                    if len(batch) >= batch_size:
                        psycopg2.extras.execute_values(cur, f"""
                        INSERT INTO {table_name} (lon, lat, depth, time_period, {', '.join(parameter_set)}, geom)
                        VALUES %s;
                        """, batch, template=None, page_size=batch_size)
                        conn.commit()
                        batch.clear()  # Clear the batch after committing

    # Commit any remaining rows
    if batch:
        psycopg2.extras.execute_values(cur, f"""
        INSERT INTO {table_name} (lon, lat, depth, time_period, {', '.join(parameter_set)}, geom)
        VALUES %s;
        """, batch, template=None, page_size=batch_size)
        conn.commit()

#    cur.close()

In [24]:
def insert_into_postgis_test(conn, table_name, data, parameter_set, res, start_depth_idx=0, batch_size=5000):
    cur = conn.cursor()
    batch = []
    total_records = data.dims['time_periods'] * data.dims['depth'] * data.dims['lat'] * data.dims['lon']
    processed_records = 0
    count = 0
    start_time = datetime.now()

    forceInitTime = 1
    forceMaxDepth = 12
    forceTerminate = False

    for time_idx in range(data.dims['time_periods']):
        if forceTerminate:
            break
        if time_idx < forceInitTime:
            continue
        print(f"Processing time_idx: {time_idx}, {int(data['time_periods'].data[time_idx])}")
        for depth_idx in range(start_depth_idx, data.dims['depth']):
            if depth_idx > forceMaxDepth:
                forceTerminate = True
                break
            for lat_idx in range(data.dims['lat']):
                for lon_idx in range(data.dims['lon']):
                    lon = float(data['lon'].data[lon_idx])
                    lat = float(data['lat'].data[lat_idx])
                    depth = int(data['depth'].data[depth_idx])
                    period = int(data['time_periods'].data[time_idx])

                    # Prepare values for all parameters in the parameter_set
                    values = []
                    for param in parameter_set:
                        values.append(float(data['mn'].sel(parameters=param).data[time_idx, depth_idx, lat_idx, lon_idx]))

                    end1_time = datetime.now()
                    check_all_na = all(np.isnan(v) for v in values)
                    if (count == 1 or count % 1000 == 0):
                        print(f"Step1 took {(end1_time - start_time).total_seconds()} seconds and values are all NaN: {check_all_na}")
                        print(f"Data status: {lon} {lat} {depth} {period} {values}")

                    # Skip rows with all NaN values
                    if check_all_na:
                        count += 1
                        continue

                    geom = generate_grid_polygon(lon, lat, res)
                    geom_wkb = wkb.dumps(geom, hex=True)
                    end2_time = datetime.now()

                    batch.append((lon, lat, depth, period, *values, f'ST_SetSRID(\'{geom_wkb}\'::geometry, 4326)'))
                    end3_time = datetime.now()
                    processed_records += 1
                    count += 1

                    if (processed_records == 1 or processed_records % 1000 == 0):
                        end_time = datetime.now()
                        print(f"Time taken to handle query records: {processed_records} with {(end_time - start_time).total_seconds()} seconds")
                        print(f"Each step took {(end1_time - start_time).total_seconds()}, {(end2_time - end1_time).total_seconds()}, {(end3_time - end2_time).total_seconds()} seconds")
                        start_time = end_time

                    # Insert and commit in batches
                    if len(batch) >= batch_size:
                        psycopg2.extras.execute_values(cur, f"""
                        INSERT INTO {table_name} (lon, lat, depth, time_period, {', '.join(parameter_set)}, geom)
                        VALUES %s;
                        """, batch, page_size=batch_size)
                        conn.commit()
                        batch.clear()  # Clear the batch after committing

                        # Print progress
                        print(f"Inserted {processed_records}/{total_records} records. Time: {time.ctime()}")

    # Insert any remaining data
    if batch:
        psycopg2.extras.execute_values(cur, f"""
        INSERT INTO {table_name} (lon, lat, depth, time_period, {', '.join(parameter_set)}, geom)
        VALUES %s;
        """, batch, page_size=batch_size)
        conn.commit()
        print(f"Final insertion completed. Total records inserted: {processed_records}. Time: {time.ctime()}")


In [7]:
if True: #__name__ == '__main__' #only Test
    res = grids[0]  # '04' for 0.25-degree, '01' for 1-degree
    print("Test res, pars", res, grids, pars)

    zarr_store_path = f"../data/{grid_dir[res]}/{time_periods[0]}/{pars[0]}"
    table_name = f"{grid_db[res]}_{time_periods[0]}_{pars[0]}"
    if res == '04' or pars == 'TS':
        parameter_set = ['temperature', 'salinity']  # For TS subgroup
    elif pars[0] == 'Oxy': 
        parameter_set = ['oxygen', 'o2sat', 'AOU']
    else:
        parameter_set = ['silicate', 'phosphate', 'nitrate']        

print("Test table name and parameter_set", table_name, parameter_set)

Test res, pars 04 ['04'] ['TS']
Test table name and parameter_set grd025_monthly_TS ['temperature', 'salinity']


In [8]:
data = xr.open_zarr(zarr_store_path, consolidated=False)
print(data.dims['depth'])

57


  print(data.dims['depth'])


In [9]:
print(data)

<xarray.Dataset> Size: 57GB
Dimensions:       (time_periods: 12, parameters: 2, depth: 57, lat: 720,
                   lon: 1440)
Coordinates:
  * depth         (depth) float32 228B 0.0 5.0 10.0 ... 1.4e+03 1.45e+03 1.5e+03
  * lat           (lat) float32 3kB -89.88 -89.62 -89.38 ... 89.38 89.62 89.88
  * lon           (lon) float32 6kB -179.9 -179.6 -179.4 ... 179.4 179.6 179.9
  * parameters    (parameters) <U11 88B 'temperature' 'salinity'
  * time_periods  (time_periods) <U2 96B '1' '2' '3' '4' ... '9' '10' '11' '12'
Data variables:
    an            (time_periods, parameters, depth, lat, lon) float32 6GB dask.array<chunksize=(1, 1, 8, 90, 360), meta=np.ndarray>
    dd            (time_periods, parameters, depth, lat, lon) float32 6GB dask.array<chunksize=(1, 1, 8, 90, 360), meta=np.ndarray>
    gp            (time_periods, parameters, depth, lat, lon) float32 6GB dask.array<chunksize=(1, 1, 8, 90, 360), meta=np.ndarray>
    ma            (time_periods, parameters, depth, lat, lon

In [10]:
print(data['depth'])

<xarray.DataArray 'depth' (depth: 57)> Size: 228B
array([   0.,    5.,   10.,   15.,   20.,   25.,   30.,   35.,   40.,   45.,
         50.,   55.,   60.,   65.,   70.,   75.,   80.,   85.,   90.,   95.,
        100.,  125.,  150.,  175.,  200.,  225.,  250.,  275.,  300.,  325.,
        350.,  375.,  400.,  425.,  450.,  475.,  500.,  550.,  600.,  650.,
        700.,  750.,  800.,  850.,  900.,  950., 1000., 1050., 1100., 1150.,
       1200., 1250., 1300., 1350., 1400., 1450., 1500.], dtype=float32)
Coordinates:
  * depth    (depth) float32 228B 0.0 5.0 10.0 15.0 ... 1.4e+03 1.45e+03 1.5e+03


In [11]:
print(data['time_periods'])

<xarray.DataArray 'time_periods' (time_periods: 12)> Size: 96B
array(['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12'],
      dtype='<U2')
Coordinates:
  * time_periods  (time_periods) <U2 96B '1' '2' '3' '4' ... '9' '10' '11' '12'


In [12]:
print(float(data['mn'].sel(parameters='temperature').data[0, 3, 60, 60]))
print(float(data['mn'].sel(parameters='temperature').data[0, 3, 200, 100]))

nan
18.777061462402344


In [None]:
#for time_idx in range(data.dims['time_periods']):
#    print(time_idx)
#    print(int(data['time_periods'].data[time_idx]))

In [12]:
print(data.coords['time_periods'].data)

['13' '14' '15' '16']


In [3]:
conn = connect_db(db_settings)


In [18]:
if True:
            time_idx = 1
            depth_idx = 3
            depth = int(data['depth'].data[depth_idx])
            period = int(data['time_periods'].data[time_idx])

            # Load a chunk (all lat/lon for the current depth and time period)
            chunk = data['mn'].isel(time_periods=time_idx, depth=depth_idx).load()
            lon_vals = data['lon'].data
            lat_vals = data['lat'].data
            print(chunk)
            print(lon_vals)
            print(lat_vals)
            
            # Prepare a DataFrame
            df = pd.DataFrame({
                'lon': np.tile(lon_vals, len(lat_vals)),
                'lat': np.repeat(lat_vals, len(lon_vals)),
                'depth': depth,
                'time_period': period,
            })

            # Add parameter values to the DataFrame with sanitization
            for param in parameter_set:
                param_values = chunk.sel(parameters=param).values.flatten()
                df[param] = [sanitize_value(v) for v in param_values]

            # Drop rows where all parameter values are None
            df = df.dropna(subset=parameter_set, how='all')

            # Add geometry as WKB
            df['geom'] = df.apply(
                lambda row: wkb.dumps(generate_grid_polygon(row['lon'], row['lat'], res), hex=True),
                axis=1
            )

            # Convert DataFrame to a list of tuples for batch insertion
            rows = df.to_records(index=False).tolist()




<xarray.DataArray 'mn' (parameters: 2, lat: 720, lon: 1440)> Size: 8MB
array([[[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        ...,
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan]],

       [[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        ...,
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan]]], dtype=float32)
Coordinates:
    depth         float32 4B 15.0
  * lat           (lat) float32 3kB -89.88 -89.62 -89.38 ... 89.38 89.62 89.88
  * lon           (lon) float32 6kB -179.9 -179.6 -179.4 ... 179.4 179.6 179.9
  * parameters    (parameters) <U11 88B 'temperature' 'salinity'
    time_periods  <U2 8B '2'
[-179.875 -179.625 -179.375 ...  179.375  179

In [19]:
print(df)

             lon     lat  depth  time_period  temperature   salinity  \
66304   -163.875 -78.375     15            2    -1.524229  33.875549   
66305   -163.625 -78.375     15            2    -0.711931  33.986385   
67730   -167.375 -78.125     15            2    -2.004463        NaN   
67731   -167.125 -78.125     15            2    -1.569681        NaN   
67735   -166.125 -78.125     15            2    -1.504463        NaN   
...          ...     ...    ...          ...          ...        ...   
1029592  178.125  88.625     15            2    -1.585266        NaN   
1030996  169.125  88.875     15            2    -1.566424        NaN   
1030997  169.375  88.875     15            2    -1.566300        NaN   
1031002  170.625  88.875     15            2    -1.567200  28.735386   
1031004  171.125  88.875     15            2    -1.567458        NaN   

                                                      geom  
66304    0103000000010000000500000000000000008064C00000...  
66305    0103

In [32]:
#Polars version
if True:
    # Create Polars DataFrame
    df = pl.DataFrame({
        'lon': np.tile(lon_vals, len(lat_vals)),
        'lat': np.repeat(lat_vals, len(lon_vals)),
        'depth': [depth] * len(lon_vals) * len(lat_vals),
        'time_period': [period] * len(lon_vals) * len(lat_vals),
    })

    # Add parameter values to the Polars DataFrame
    for param in parameter_set:
        param_values = chunk.sel(parameters=param).values.flatten()
        df = df.with_columns(pl.Series(param, [sanitize_value(v) for v in param_values]))

    # Drop rows where all parameter values are None
    df = df.filter(~pl.all_horizontal([df[param].is_null() for param in parameter_set]))

    # Add geometry as WKB
    geom = [
        wkb.dumps(generate_grid_polygon(row[0], row[1], res), hex=True)
        for row in zip(df['lon'].to_numpy(), df['lat'].to_numpy())
    ]
    df = df.with_columns(pl.Series('geom', geom))


    # Convert Polars DataFrame to a list of tuples for batch insertion
    rows = [
        (
            row['lon'], row['lat'], row['depth'], row['time_period'],
            *[row[param] for param in parameter_set], row['geom']
        )
        for row in df.iter_rows(named=True)
    ]


In [28]:
print(df)

shape: (262_122, 7)
┌──────────┬─────────┬───────┬─────────────┬─────────────┬───────────┬─────────────────────────────┐
│ lon      ┆ lat     ┆ depth ┆ time_period ┆ temperature ┆ salinity  ┆ geom                        │
│ ---      ┆ ---     ┆ ---   ┆ ---         ┆ ---         ┆ ---       ┆ ---                         │
│ f32      ┆ f32     ┆ i64   ┆ i64         ┆ f32         ┆ f32       ┆ str                         │
╞══════════╪═════════╪═══════╪═════════════╪═════════════╪═══════════╪═════════════════════════════╡
│ -163.875 ┆ -78.375 ┆ 15    ┆ 2           ┆ -1.524229   ┆ 33.875549 ┆ 010300000001000000050000000 │
│          ┆         ┆       ┆             ┆             ┆           ┆ 000…                        │
│ -163.625 ┆ -78.375 ┆ 15    ┆ 2           ┆ -0.711931   ┆ 33.986385 ┆ 010300000001000000050000000 │
│          ┆         ┆       ┆             ┆             ┆           ┆ 000…                        │
│ -167.375 ┆ -78.125 ┆ 15    ┆ 2           ┆ -2.004463   ┆ null      ┆ 

In [22]:
#Testing
import polars as pl
import numpy as np

# Mock data for testing
lon_vals = [-179.875, -179.625, -179.375]
lat_vals = [-89.875, -89.625]
depth = 5
time_period = 2
parameter_set = ['temperature', 'salinity']

# Create mock data with some NaN values
mock_data = {
    'lon': np.tile(lon_vals, len(lat_vals)),
    'lat': np.repeat(lat_vals, len(lon_vals)),
    'depth': [depth] * len(lon_vals) * len(lat_vals),
    'time_period': [time_period] * len(lon_vals) * len(lat_vals),
    'temperature': [None, 10.5, 20.1, np.nan, np.nan, 15.2],
    'salinity': [35.1, 34.5, None, np.nan, 36.0, np.nan]
}

# Create the Polars DataFrame
df = pl.DataFrame(mock_data)

# Print initial DataFrame
print("Original DataFrame:")
print(df)

# Ensure sanitization using a custom sanitize_value function
def sanitize_value(value):
    """Convert 'NaN' or 'None' to None to handle as NULL in PostgreSQL."""
    if value is None:  # Handle Python's None
        return None
    try:
        # Check if the value is NaN
        if np.isnan(value):
            return None
    except TypeError:
        # If the value is not numeric, it's fine to return it as-is
        pass
    return value

# Apply sanitization
df = df.with_columns([
    pl.Series('temperature', [sanitize_value(v) for v in df['temperature']]),
    pl.Series('salinity', [sanitize_value(v) for v in df['salinity']])
])

# Filter out rows where all parameter_set columns are null
df = df.filter(~pl.all_horizontal(pl.col(parameter_set).is_null()))

# Print sanitized DataFrame
print("\nSanitized and Filtered DataFrame:")
print(df)


Original DataFrame:
shape: (6, 6)
┌──────────┬─────────┬───────┬─────────────┬─────────────┬──────────┐
│ lon      ┆ lat     ┆ depth ┆ time_period ┆ temperature ┆ salinity │
│ ---      ┆ ---     ┆ ---   ┆ ---         ┆ ---         ┆ ---      │
│ f64      ┆ f64     ┆ i64   ┆ i64         ┆ f64         ┆ f64      │
╞══════════╪═════════╪═══════╪═════════════╪═════════════╪══════════╡
│ -179.875 ┆ -89.875 ┆ 5     ┆ 2           ┆ null        ┆ 35.1     │
│ -179.625 ┆ -89.875 ┆ 5     ┆ 2           ┆ 10.5        ┆ 34.5     │
│ -179.375 ┆ -89.875 ┆ 5     ┆ 2           ┆ 20.1        ┆ null     │
│ -179.875 ┆ -89.625 ┆ 5     ┆ 2           ┆ NaN         ┆ NaN      │
│ -179.625 ┆ -89.625 ┆ 5     ┆ 2           ┆ NaN         ┆ 36.0     │
│ -179.375 ┆ -89.625 ┆ 5     ┆ 2           ┆ 15.2        ┆ NaN      │
└──────────┴─────────┴───────┴─────────────┴─────────────┴──────────┘

Sanitized and Filtered DataFrame:
shape: (5, 6)
┌──────────┬─────────┬───────┬─────────────┬─────────────┬──────────┐
│ lon  

In [13]:
print(table_name)
print(parameter_set)

grd025_monthly_TS
['temperature', 'salinity']


In [34]:
cur = conn.cursor()
processed_records = 0
batch_size=100000

if True:
            try:
                execute_values(
                    cur,
                    f"""
                    INSERT INTO {table_name} (lon, lat, depth, time_period, {', '.join(parameter_set)}, geom)
                    VALUES %s;
                    """,
                    rows,
                    page_size=batch_size
                )
                conn.commit()
                processed_records += len(rows)
                print(f"Inserted {len(rows)} records for depth={depth}, time_period={period}. Total inserted: {processed_records}")

            except Exception as e:
                conn.rollback()
                print(f"Error during batch insert for depth={depth}, time_period={period}: {e}")

conn.close()

Inserted 262122 records for depth=15, time_period=2. Total inserted: 262122


In [14]:
def process_chunk_with_polars(chunk, lon_vals, lat_vals, depth, period, parameter_set, res):
    """Process a chunk of data with Polars."""
    # Create Polars DataFrame
    df = pl.DataFrame({
        'lon': np.tile(lon_vals, len(lat_vals)),
        'lat': np.repeat(lat_vals, len(lon_vals)),
        'depth': [depth] * len(lon_vals) * len(lat_vals),
        'time_period': [period] * len(lon_vals) * len(lat_vals),
    })

    # Add parameter values to the Polars DataFrame
    for param in parameter_set:
        if param not in chunk.coords['parameters'].values:
            raise ValueError(f"Parameter '{param}' not found in the dataset's parameters dimension.")
        param_values = chunk['mn'].sel(parameters=param).values.flatten()
        df = df.with_columns(pl.Series(param, [sanitize_value(v) for v in param_values]))

    # Drop rows where all parameter values are None
    df = df.filter(~pl.all_horizontal([df[param].is_null() for param in parameter_set]))

    # Add geometry as WKB
    geom = [
        wkb.dumps(generate_grid_polygon(row[0], row[1], res), hex=True)
        for row in zip(df['lon'].to_numpy(), df['lat'].to_numpy())
    ]
    df = df.with_columns(pl.Series('geom', geom))

    # Convert Polars DataFrame to a list of tuples for batch insertion
    rows = [
        (
            row['lon'], row['lat'], row['depth'], row['time_period'],
            *[row[param] for param in parameter_set], row['geom']
        )
        for row in df.iter_rows(named=True)
    ]
    return rows

In [None]:
def process_and_insert_zarr(conn, table_name, data, parameter_set, res, start_depth_idx=0, batch_size=100000):
    """Process Zarr data and insert into PostGIS using Polars."""
    cur = conn.cursor()
    lon_vals = data['lon'].values
    lat_vals = data['lat'].values

    processed_records = 0
    start_time = datetime.now()

    #forceInitTime = 3
    #forceMaxTime = 3
    #forceMaxDepth = 0
    #forceTerminate = False

    for time_idx in range(data.dims['time_periods']):
        #if forceTerminate or time_idx > forceMaxTime:
        #    break
        #if time_idx < forceInitTime:
        #    continue

        period = str(data['time_periods'].data[time_idx])  # Convert to string
        print(f"Processing time_idx={time_idx}, time_period={period}")
        
        for depth_idx in range(start_depth_idx, data.dims['depth']):
            #if depth_idx > forceMaxDepth:
            #    forceTerminate = True
            #    break
                        
            depth = float(data['depth'].data[depth_idx])
            print(f"Processing depth_idx={depth_idx}, depth={depth}")

            # Load the data chunk
            chunk = data.sel(time_periods=period, depth=depth)  # Use string `period`
            # Debugging chunk
            print(f"Chunk dimensions: {chunk.dims}")
            print(f"Chunk coordinates: {chunk.coords}")
            
            # Process the chunk with Polars
            rows = process_chunk_with_polars(chunk, lon_vals, lat_vals, int(depth), int(period), parameter_set, res)
                
            # Insert rows in batches
            for i in range(0, len(rows), batch_size):
                batch = rows[i:i + batch_size]
                try:
                    execute_values(
                        cur,
                        f"""
                        INSERT INTO {table_name} (lon, lat, depth, time_period, {', '.join(parameter_set)}, geom)
                        VALUES %s;
                        """,
                        batch,
                        page_size=batch_size
                    )
                    conn.commit()
                    processed_records += len(batch)
                    print(f"Inserted {processed_records} records. Last batch size: {len(batch)}.")
                except Exception as e:
                    print(f"Error during batch insert for depth={depth}, time_period={period}: {e}")
                    conn.rollback()
                    raise

            #if (processed_records == 1 or processed_records % 1000 == 0):
            end_time = datetime.now()
            print(f"Finished depth_idx={depth_idx}, depth={depth}records and handle {processed_records} with {(end_time - start_time).total_seconds()} second")
            start_time = end_time    

    print(f"Processing complete. Total records inserted: {processed_records}.")
    cur.close()

In [17]:
if True: #__name__ == '__main__' #only Test
    print("Test res, pars", res, grids, pars)

print("Test table name and parameter_set", table_name, parameter_set)
print(data)

Test res, pars 04 ['04'] ['TS']
Test table name and parameter_set grd025_monthly_TS ['temperature', 'salinity']
<xarray.Dataset> Size: 57GB
Dimensions:       (time_periods: 12, parameters: 2, depth: 57, lat: 720,
                   lon: 1440)
Coordinates:
  * depth         (depth) float32 228B 0.0 5.0 10.0 ... 1.4e+03 1.45e+03 1.5e+03
  * lat           (lat) float32 3kB -89.88 -89.62 -89.38 ... 89.38 89.62 89.88
  * lon           (lon) float32 6kB -179.9 -179.6 -179.4 ... 179.4 179.6 179.9
  * parameters    (parameters) <U11 88B 'temperature' 'salinity'
  * time_periods  (time_periods) <U2 96B '1' '2' '3' '4' ... '9' '10' '11' '12'
Data variables:
    an            (time_periods, parameters, depth, lat, lon) float32 6GB dask.array<chunksize=(1, 1, 8, 90, 360), meta=np.ndarray>
    dd            (time_periods, parameters, depth, lat, lon) float32 6GB dask.array<chunksize=(1, 1, 8, 90, 360), meta=np.ndarray>
    gp            (time_periods, parameters, depth, lat, lon) float32 6GB dask.a

In [17]:
chunk = data.sel(time_periods='13', depth=10)
print(chunk)

<xarray.Dataset> Size: 83MB
Dimensions:       (parameters: 2, lat: 720, lon: 1440)
Coordinates:
    depth         float32 4B 10.0
  * lat           (lat) float32 3kB -89.88 -89.62 -89.38 ... 89.38 89.62 89.88
  * lon           (lon) float32 6kB -179.9 -179.6 -179.4 ... 179.4 179.6 179.9
  * parameters    (parameters) <U11 88B 'temperature' 'salinity'
    time_periods  <U2 8B '13'
Data variables:
    an            (parameters, lat, lon) float32 8MB dask.array<chunksize=(1, 90, 360), meta=np.ndarray>
    dd            (parameters, lat, lon) float32 8MB dask.array<chunksize=(1, 90, 360), meta=np.ndarray>
    gp            (parameters, lat, lon) float32 8MB dask.array<chunksize=(1, 90, 360), meta=np.ndarray>
    ma            (parameters, lat, lon) float32 8MB dask.array<chunksize=(1, 90, 360), meta=np.ndarray>
    mn            (parameters, lat, lon) float32 8MB dask.array<chunksize=(1, 90, 360), meta=np.ndarray>
    oa            (parameters, lat, lon) float32 8MB dask.array<chunksize=(1

In [66]:
chunk.sel(parameters='temperature').values

<bound method Mapping.values of <xarray.Dataset> Size: 41MB
Dimensions:       (lat: 720, lon: 1440)
Coordinates:
    depth         float32 4B 10.0
  * lat           (lat) float32 3kB -89.88 -89.62 -89.38 ... 89.38 89.62 89.88
  * lon           (lon) float32 6kB -179.9 -179.6 -179.4 ... 179.4 179.6 179.9
    parameters    <U11 44B 'temperature'
    time_periods  <U2 8B '2'
Data variables:
    an            (lat, lon) float32 4MB dask.array<chunksize=(90, 360), meta=np.ndarray>
    dd            (lat, lon) float32 4MB dask.array<chunksize=(90, 360), meta=np.ndarray>
    gp            (lat, lon) float32 4MB dask.array<chunksize=(90, 360), meta=np.ndarray>
    ma            (lat, lon) float32 4MB dask.array<chunksize=(90, 360), meta=np.ndarray>
    mn            (lat, lon) float32 4MB dask.array<chunksize=(90, 360), meta=np.ndarray>
    oa            (lat, lon) float32 4MB dask.array<chunksize=(90, 360), meta=np.ndarray>
    sd            (lat, lon) float32 4MB dask.array<chunksize=(90, 36

In [18]:
conn = connect_db(db_settings)


In [19]:
start_depth_idx=0
process_and_insert_zarr(conn, table_name, data, parameter_set, res, start_depth_idx)
conn.close()

  for time_idx in range(data.dims['time_periods']):
  for depth_idx in range(start_depth_idx, data.dims['depth']):


Processing time_idx=3, time_period=4
Processing depth_idx=0, depth=0.0
Chunk coordinates: Coordinates:
    depth         float32 4B 0.0
  * lat           (lat) float32 3kB -89.88 -89.62 -89.38 ... 89.38 89.62 89.88
  * lon           (lon) float32 6kB -179.9 -179.6 -179.4 ... 179.4 179.6 179.9
  * parameters    (parameters) <U11 88B 'temperature' 'salinity'
    time_periods  <U2 8B '4'
Inserted 100000 records. Last batch size: 100000.
Inserted 200000 records. Last batch size: 100000.
Inserted 249608 records. Last batch size: 49608.
Finished depth_idx=0, depth=0.0records and handle 249608 with 25.027086 second
Processing complete. Total records inserted: 249608.


In [28]:
# Main loop to iterate through grids, time_periods, and pars
# conn = connect_db(db_settings)
# grid_db = {'01': 'grd1', '04': 'grd025'}
# grid_dir = {'01': '1_degree', '04': '025_degree'}

# Loop to write data
for res in grids:
    for time_period in time_periods:
        for par in pars:
            # Determine the Zarr store path and table name
            zarr_store_path = f"../data/{grid_dir[res]}/{time_period}/{par}"
            table_name = f"{grid_db[res]}_{time_period}_{par}"

            # Determine the parameter set based on the subgroup
            if res == '04' or par == 'TS':
                parameter_set = ['temperature', 'salinity']
            elif par == 'Oxy':
                parameter_set = ['oxygen', 'o2sat', 'AOU']
            else:
                parameter_set = ['silicate', 'phosphate', 'nitrate']

            # Create the table and index it
            print("Current handling: ", res, time_period, parameter_set)
            print("Create table name: ", table_name)
            # create_ts_table(conn, table_name, parameter_set)
            print("Read dataset: ", zarr_store_path)
            # Load data from the Zarr store and insert it into the table
            data = xr.open_zarr(zarr_store_path, consolidated=False)
            start_depth_idx = 0  # Replace this with the desired start depth index if needed
            print("Start inserting data at date: ", datetime.now())
            insert_into_postgis(conn, table_name, data, parameter_set, res, start_depth_idx=start_depth_idx)
            print("End this insertion: ", datetime.now())

conn.close()


Current handling:  04 monthly ['temperature', 'salinity']
Create table name:  grd025_monthly_TS
Read dataset:  ../data/025_degree/monthly/TS
Start inserting data at date:  2024-12-02 15:42:14.019984
Processing time_idx: 1, 2


  total_records = data.dims['time_periods'] * data.dims['depth'] * data.dims['lat'] * data.dims['lon']
  for time_idx in range(data.dims['time_periods']):
  for depth_idx in range(start_depth_idx, data.dims['depth']):
  for lat_idx in range(data.dims['lat']):
  for lon_idx in range(data.dims['lon']):


In lon, lat, depth, time_period: -163.875, -78.375, 0, 2, query records: 1 with 1694.878965 seconds
Values are: [-1.5646120309829712, 33.86015319824219]
In lon, lat, depth, time_period: 176.875, -74.875, 0, 2, query records: 1000 with 573.29832 seconds
Values are: [-0.41999998688697815, 34.43000030517578]
In lon, lat, depth, time_period: -107.375, -72.125, 0, 2, query records: 2000 with 392.877128 seconds
Values are: [-1.0499999523162842, None]
In lon, lat, depth, time_period: 27.125, -69.125, 0, 2, query records: 3000 with 481.921367 seconds
Values are: [-1.0289210081100464, 33.130001068115234]
In lon, lat, depth, time_period: -18.625, -67.875, 0, 2, query records: 4000 with 193.17548 seconds
Values are: [0.07239300012588501, None]
In lon, lat, depth, time_period: 73.625, -67.125, 0, 2, query records: 5000 with 124.598179 seconds
Values are: [0.5349999666213989, None]
In lon, lat, depth, time_period: 110.125, -66.375, 0, 2, query records: 6000 with 122.044423 seconds
Values are: [-1.5

KeyboardInterrupt: 

In [None]:
data = xr.open_zarr(zarr_store_path, consolidated=False)
print(data)


In [None]:
create_ts_table(conn, table_name)
# print(conn)
insert_into_postgis(conn, table_name, data, parameter_set, res)
conn.commit()
conn.close()

In [10]:
# Function to find the last record in the table
def find_last_record(conn, table_name):
    query = f"""
    SELECT lon, lat, depth, time_period 
    FROM {table_name} 
    ORDER BY depth DESC, time_period DESC, lat DESC, lon DESC 
    LIMIT 1;
    """
    with conn.cursor() as cur:
        cur.execute(query)
        result = cur.fetchone()
    if result:
        print("Find last record: ", result)
        return {
            'lon': result[0],
            'lat': result[1],
            'depth': result[2],
            'time_period': result[3]
        }
    return None


In [None]:
# Continue from the last record
RecoveryFromLastRecord = True
if RecoveryFromLastRecord:
    for depth_idx, depth in enumerate(data["depth"].data):
        for time_idx, time_period_value in enumerate(data["time_periods"].data):
            for lat_idx, lat in enumerate(data["lat"].data):
                for lon_idx, lon in enumerate(data["lon"].data):
                    # Check if this record should be skipped
                    if last_record:
                        if (
                            depth < last_record["depth"]
                            or (
                                depth == last_record["depth"]
                                and time_period_value < last_record["time_period"]
                            )
                            or (
                                depth == last_record["depth"]
                                and time_period_value == last_record["time_period"]
                                and lat < last_record["lat"]
                            )
                            or (
                                depth == last_record["depth"]
                                and time_period_value == last_record["time_period"]
                                and lat == last_record["lat"]
                                and lon <= last_record["lon"]
                            )
                        ):
                            continue
                        last_record = None  # Start processing from here once we pass the last record

                    # Insert the data into the database
                    insert_into_postgis(conn, table_name, data, parameter_set, res)

In [None]:
# Update your main loop to continue from the last record
# conn = connect_db(db_settings)
# grid_db = {'01': 'grd1', '04': 'grd025'}
# grid_dir = {'01': '1_degree', '04': '025_degree'}

for res in grids:
    for time_period in time_periods:
        for par in pars:
            table_name = f"{grid_db[res]}_{time_period}_{par}"
            last_record = find_last_record(conn, table_name)
            
            # Determine the parameter set
            if res == '04' or par == 'TS':
                parameter_set = ['temperature', 'salinity']
            elif par == 'Oxy':
                parameter_set = ['oxygen', 'o2sat', 'AOU']
            else:
                parameter_set = ['silicate', 'phosphate', 'nitrate']

            data = xr.open_zarr(zarr_store_path, consolidated=False)
            for depth_idx, depth in enumerate(data['depth'].data):
                for time_idx, time_period_value in enumerate(data['time_periods'].data):
                    for lat_idx, lat in enumerate(data['lat'].data):
                        for lon_idx, lon in enumerate(data['lon'].data):
                            # Check if this record should be skipped
                            if last_record:
                                if (depth < last_record['depth'] or
                                    (depth == last_record['depth'] and time_period_value < last_record['time_period']) or
                                    (depth == last_record['depth'] and time_period_value == last_record['time_period'] and lat < last_record['lat']) or
                                    (depth == last_record['depth'] and time_period_value == last_record['time_period'] and lat == last_record['lat'] and lon <= last_record['lon'])):
                                    continue
                                last_record = None  # Start processing from here once we pass the last record
                            
                            # Insert the data into the database
                            insert_into_postgis(conn, table_name, data, parameter_set, res)

conn.commit()
conn.close()


In [11]:
# select count(*) from grd025_annual_ts;
dt1 = [35440693]
# select count(*) from grd025_monthly_ts;
dt1.append(141700649)
# select count(*) from grd025_seasonal_ts;
dt1.append(94845824)
# select count(*) from grd1_annual_ts;
dt1.append(2821725)
# select count(*) from grd1_monthly_ts;
dt1.append(22226594)
# select count(*) from grd1_seasonal_ts;
dt1.append(9842721)
# select count(*) from grd1_annual_oxy;
dt1.append(2048066)
# select count(*) from grd1_monthly_oxy;
dt1.append(6405112)
# select count(*) from grd1_seasonal_oxy;
dt1.append(3844217)
# select count(*) from grd1_annual_nutrients;
dt1.append(1333093)
# select count(*) from grd1_monthly_nutrients;
dt1.append(2202256)
# select count(*) from grd1_seasonal_nutrients;
dt1.append(1552181)
print(dt1)
sum(dt1)

[35440693, 141700649, 94845824, 2821725, 22226594, 9842721, 2048066, 6405112, 3844217, 1333093, 2202256, 1552181]


324263131

In [None]:
conn.close()