In [1]:
import os
import psycopg2
import xarray as xr
import numpy as np
from shapely.geometry import Polygon
from shapely import wkb
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from dotenv import load_dotenv
from datetime import datetime


In [2]:
load_dotenv()
DBUSER = os.getenv('DBUSER')
DBPASS = os.getenv('DBPASS')
DBHOST = os.getenv('DBHOST')
DBPORT = os.getenv('DBPORT')
DBNAME = os.getenv('DBNAME')
GRIDSET = os.getenv('GRIDSET')
TIMESET = os.getenv('TIMESET')
PARAMSET = os.getenv('PARAMSET')
# print(PARAMSET)
pars = [c.strip() for c in PARAMSET.split(',')] # if c.strip() in available_params]))
time_periods = [p.strip() for p in str(TIMESET).split(',')]
grids = [g.strip() for g in str(GRIDSET).split(',')]
db_settings = {
    'dbname': DBNAME,
    'user': DBUSER,
    'password': DBPASS,
    'host': DBHOST,
    'port': DBPORT
}

# Function to establish connection to PostgreSQL
def connect_db(settings):
    conn = psycopg2.connect(**settings)
    conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
    return conn

print(grids, pars, time_periods)

['01'] ['TS'] ['seasonal', 'monthly']


In [3]:
# Function to create TS table in PostGIS
# When try grd025_annual_ts use temperature FLOAT8, salinity FLOAT8,
def create_ts_table(conn, table_name, parameter_set):
    with conn.cursor() as cur:
        columns = ", ".join([f"{param} FLOAT" for param in parameter_set])
        create_table_query = f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            lon FLOAT,
            lat FLOAT,
            depth INTEGER,
            time_period INTEGER,
            {columns},
            geom GEOMETRY
        );
        """
        cur.execute(create_table_query)
        
        index_geom = f"CREATE INDEX idx_{table_name}_geom ON {table_name} USING GIST (geom);"
        index_depth = f"CREATE INDEX idx_{table_name}_depth ON {table_name} (depth);"
        index_time_period = f"CREATE INDEX idx_{table_name}_time_period ON {table_name} (time_period);"

        cur.execute(index_geom)
        cur.execute(index_depth)
        cur.execute(index_time_period)

In [4]:
# Test read zarr dataset
zarr_store_path = '../data/025_degree/annual/TS'
res = '04'  # '04' for 0.25-degree, '01' for 1-degree
ds = xr.open_zarr(zarr_store_path, consolidated=False)
# print(ds)

In [4]:
# Function to generate geom for PostGIS using the grid structure
def generate_grid_polygon(lon, lat, res):
    step = 0.25 if res == '04' else 1.0
    lon_min = lon - step / 2
    lon_max = lon + step / 2
    lat_min = lat - step / 2
    lat_max = lat + step / 2
    return Polygon([(lon_min, lat_min), (lon_min, lat_max), (lon_max, lat_max), (lon_max, lat_min), (lon_min, lat_min)])


In [5]:
def insert_into_postgis(conn, table_name, data, parameter_set, res, batch_size=100000):
    cur = conn.cursor()
    count = 0
    #with conn.cursor() as cur:
    if True:
        for time_idx in range(data.dims['time_periods']):
            for depth_idx in range(data.dims['depth']):
                for lat_idx in range(data.dims['lat']):
                    for lon_idx in range(data.dims['lon']):
                        lon = float(data['lon'].data[lon_idx])
                        lat = float(data['lat'].data[lat_idx])
                        depth = int(data['depth'].data[depth_idx])
                        time_period = int(data['time_periods'].data[time_idx])

                        # Prepare values for all parameters in the parameter_set
                        values = []
                        for param in parameter_set:
                            values.append(float(data['mn'].sel(parameters=param).data[time_idx, depth_idx, lat_idx, lon_idx]))

                        # Check if all values are NaN; if so, skip the row
                        if all(np.isnan(v) for v in values):
                            continue

                        geom = generate_grid_polygon(lon, lat, res)
                        geom_wkb = wkb.dumps(geom, hex=True)

                        # Dynamically create the query string
                        query = f"""
                        INSERT INTO {table_name} (lon, lat, depth, time_period, {', '.join(parameter_set)}, geom)
                        VALUES (%s, %s, %s, %s, {', '.join(['%s'] * len(parameter_set))}, ST_SetSRID(%s::geometry, 4326));
                        """

                        cur.execute(query, (lon, lat, depth, time_period, *values, geom_wkb))
                        count += 1

                        if count % batch_size == 0:
                            conn.commit()  # Commit every batch_size records

In [6]:
grid_db = {'01': 'grd1', '04': 'grd025'}  # Two gridded resolutions data: 1-degree and 0.25-degree in WOA23
grid_dir = {'01': '1_degree', '04': '025_degree'}

In [10]:
if True: #__name__ == '__main__' #only Test
    res = grids[0]  # '04' for 0.25-degree, '01' for 1-degree
    print("Test res, pars", res, grids, pars)

    zarr_store_path = f"../data/{grid_dir[res]}/{time_periods[0]}/{pars[0]}"
    table_name = f"{grid_db[res]}_{time_periods[0]}_{pars[0]}"
    if res == '04' or pars == 'TS':
        parameter_set = ['temperature', 'salinity']  # For TS subgroup
    elif pars[0] == 'Oxy': 
        parameter_set = ['oxygen', 'o2sat', 'AOU']
    else:
        parameter_set = ['silicate', 'phosphate', 'nitrate']        

print("Test table name and parameter_set", table_name, parameter_set)

Test res, pars 01 ['01'] ['Oxy']
Test table name and parameter_set grd1_annual_Oxy ['oxygen', 'o2sat', 'AOU']


In [7]:
conn = connect_db(db_settings)


In [8]:
# Main loop to iterate through grids, time_periods, and pars
# conn = connect_db(db_settings)
# grid_db = {'01': 'grd1', '04': 'grd025'}
# grid_dir = {'01': '1_degree', '04': '025_degree'}

for res in grids:
    for time_period in time_periods:
        for par in pars:
            # Determine the Zarr store path and table name
            zarr_store_path = f"../data/{grid_dir[res]}/{time_period}/{par}"
            table_name = f"{grid_db[res]}_{time_period}_{par}"

            # Determine the parameter set based on the subgroup
            if res == '04' or par == 'TS':
                parameter_set = ['temperature', 'salinity']
            elif par == 'Oxy':
                parameter_set = ['oxygen', 'o2sat', 'AOU']
            else:
                parameter_set = ['silicate', 'phosphate', 'nitrate']

            # Create the table and index it
            print("Currnet handling: ", res, time_period, par)
            print("Create table name: ", table_name)
            create_ts_table(conn, table_name, parameter_set)
            print("Read dataset: ", zarr_store_path)
            # Load data from the Zarr store and insert it into the table
            data = xr.open_zarr(zarr_store_path, consolidated=False)
            print("Start inserting data at date: ", datetime.now())
            insert_into_postgis(conn, table_name, data, parameter_set, res)
            print("End this insertion: ", datetime.now())

conn.close()


Currnet handling:  01 seasonal TS
Create table name:  grd1_seasonal_TS
Read dataset:  ../data/1_degree/seasonal/TS
Start inserting data at date:  2024-09-18 08:06:52.607566


  for time_idx in range(data.dims['time_periods']):
  for depth_idx in range(data.dims['depth']):
  for lat_idx in range(data.dims['lat']):
  for lon_idx in range(data.dims['lon']):


End this insertion:  2024-09-19 05:23:15.034021
Currnet handling:  01 monthly TS
Create table name:  grd1_monthly_TS
Read dataset:  ../data/1_degree/monthly/TS
Start inserting data at date:  2024-09-19 05:23:15.449830


  for time_idx in range(data.dims['time_periods']):
  for depth_idx in range(data.dims['depth']):
  for lat_idx in range(data.dims['lat']):
  for lon_idx in range(data.dims['lon']):


End this insertion:  2024-09-20 22:28:12.190812


In [None]:
data = xr.open_zarr(zarr_store_path, consolidated=False)
print(data)


In [None]:
create_ts_table(conn, table_name)
# print(conn)
insert_into_postgis(conn, table_name, data, parameter_set, res)
conn.commit()
conn.close()

In [10]:
# Function to find the last record in the table
def find_last_record(conn, table_name):
    query = f"""
    SELECT lon, lat, depth, time_period 
    FROM {table_name} 
    ORDER BY depth DESC, time_period DESC, lat DESC, lon DESC 
    LIMIT 1;
    """
    with conn.cursor() as cur:
        cur.execute(query)
        result = cur.fetchone()
    if result:
        print("Find last record: ", result)
        return {
            'lon': result[0],
            'lat': result[1],
            'depth': result[2],
            'time_period': result[3]
        }
    return None


In [None]:
# Continue from the last record
RecoveryFromLastRecord = True
if RecoveryFromLastRecord:
    for depth_idx, depth in enumerate(data["depth"].data):
        for time_idx, time_period_value in enumerate(data["time_periods"].data):
            for lat_idx, lat in enumerate(data["lat"].data):
                for lon_idx, lon in enumerate(data["lon"].data):
                    # Check if this record should be skipped
                    if last_record:
                        if (
                            depth < last_record["depth"]
                            or (
                                depth == last_record["depth"]
                                and time_period_value < last_record["time_period"]
                            )
                            or (
                                depth == last_record["depth"]
                                and time_period_value == last_record["time_period"]
                                and lat < last_record["lat"]
                            )
                            or (
                                depth == last_record["depth"]
                                and time_period_value == last_record["time_period"]
                                and lat == last_record["lat"]
                                and lon <= last_record["lon"]
                            )
                        ):
                            continue
                        last_record = None  # Start processing from here once we pass the last record

                    # Insert the data into the database
                    insert_into_postgis(conn, table_name, data, parameter_set, res)

In [None]:
# Update your main loop to continue from the last record
# conn = connect_db(db_settings)
# grid_db = {'01': 'grd1', '04': 'grd025'}
# grid_dir = {'01': '1_degree', '04': '025_degree'}

for res in grids:
    for time_period in time_periods:
        for par in pars:
            table_name = f"{grid_db[res]}_{time_period}_{par}"
            last_record = find_last_record(conn, table_name)
            
            # Determine the parameter set
            if res == '04' or par == 'TS':
                parameter_set = ['temperature', 'salinity']
            elif par == 'Oxy':
                parameter_set = ['oxygen', 'o2sat', 'AOU']
            else:
                parameter_set = ['silicate', 'phosphate', 'nitrate']

            data = xr.open_zarr(zarr_store_path, consolidated=False)
            for depth_idx, depth in enumerate(data['depth'].data):
                for time_idx, time_period_value in enumerate(data['time_periods'].data):
                    for lat_idx, lat in enumerate(data['lat'].data):
                        for lon_idx, lon in enumerate(data['lon'].data):
                            # Check if this record should be skipped
                            if last_record:
                                if (depth < last_record['depth'] or
                                    (depth == last_record['depth'] and time_period_value < last_record['time_period']) or
                                    (depth == last_record['depth'] and time_period_value == last_record['time_period'] and lat < last_record['lat']) or
                                    (depth == last_record['depth'] and time_period_value == last_record['time_period'] and lat == last_record['lat'] and lon <= last_record['lon'])):
                                    continue
                                last_record = None  # Start processing from here once we pass the last record
                            
                            # Insert the data into the database
                            insert_into_postgis(conn, table_name, data, parameter_set, res)

conn.commit()
conn.close()


In [None]:
conn.close()