In [1]:
import os
import psycopg2
import xarray as xr
import numpy as np
from shapely.geometry import Polygon
from shapely import wkb
import psycopg2
from psycopg2 import extras
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from dotenv import load_dotenv
from datetime import datetime
import time

# import logging
# logging.basicConfig(level=logging.DEBUG, filename='woa23_debug.log', filemode='a', format='%(asctime)s - %(levelname)s : %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')
# logger = logging.getLogger()

In [2]:
load_dotenv()
DBUSER = os.getenv('DBUSER')
DBPASS = os.getenv('DBPASS')
DBHOST = os.getenv('DBHOST')
DBPORT = os.getenv('DBPORT')
DBNAME = os.getenv('DBNAME')
GRIDSET = os.getenv('GRIDSET')
TIMESET = os.getenv('TIMESET')
PARAMSET = os.getenv('PARAMSET')
# print(PARAMSET)
pars = [c.strip() for c in PARAMSET.split(',')] # if c.strip() in available_params]))
time_periods = [p.strip() for p in str(TIMESET).split(',')]
grids = [g.strip() for g in str(GRIDSET).split(',')]
db_settings = {
    'dbname': DBNAME,
    'user': DBUSER,
    'password': DBPASS,
    'host': DBHOST,
    'port': DBPORT,
    'options': "-c statement_timeout=0",
    'keepalives': 1,
    'keepalives_idle': 30,  # Number of seconds of inactivity after which a keepalive message is sent
    'keepalives_interval': 10,  # Number of seconds between keepalive messages when no response is received
    'keepalives_count': 5  # Number of attempts before concluding the connection is dead

}

# Function to establish connection to PostgreSQL
def connect_db(settings):
    conn = psycopg2.connect(**settings)
    conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
    return conn

print(grids, pars, time_periods)

['04'] ['TS'] ['monthly']


In [3]:
# Function to create TS table in PostGIS
# When try grd025_annual_ts use temperature FLOAT8, salinity FLOAT8,
def create_ts_table(conn, table_name, parameter_set):
    with conn.cursor() as cur:
        columns = ", ".join([f"{param} FLOAT" for param in parameter_set])
        create_table_query = f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            lon FLOAT,
            lat FLOAT,
            depth INTEGER,
            time_period INTEGER,
            {columns},
            geom GEOMETRY
        );
        """
        cur.execute(create_table_query)
        
        index_geom = f"CREATE INDEX idx_{table_name}_geom ON {table_name} USING GIST (geom);"
        index_depth = f"CREATE INDEX idx_{table_name}_depth ON {table_name} (depth);"
        index_time_period = f"CREATE INDEX idx_{table_name}_time_period ON {table_name} (time_period);"

        cur.execute(index_geom)
        cur.execute(index_depth)
        cur.execute(index_time_period)

In [4]:
# Test read zarr dataset
zarr_store_path = '../data/025_degree/annual/TS'
res = '04'  # '04' for 0.25-degree, '01' for 1-degree
ds = xr.open_zarr(zarr_store_path, consolidated=False)
# print(ds)

In [4]:
# Function to generate geom for PostGIS using the grid structure
def generate_grid_polygon(lon, lat, res):
    step = 0.25 if res == '04' else 1.0
    lon_min = lon - step / 2
    lon_max = lon + step / 2
    lat_min = lat - step / 2
    lat_max = lat + step / 2
    return Polygon([(lon_min, lat_min), (lon_min, lat_max), (lon_max, lat_max), (lon_max, lat_min), (lon_min, lat_min)])

def sanitize_value(value):
    """Convert 'NaN' to None to handle as NULL in PostgreSQL."""
    if np.isnan(value):
        return None
    return value

In [8]:
def insert_into_postgis_sequentially(conn, table_name, data, parameter_set, res, batch_size=100000):
    cur = conn.cursor()
    count = 0
    #with conn.cursor() as cur:
    if True:
        for time_idx in range(data.dims['time_periods']):
            for depth_idx in range(data.dims['depth']):
                for lat_idx in range(data.dims['lat']):
                    for lon_idx in range(data.dims['lon']):
                        lon = float(data['lon'].data[lon_idx])
                        lat = float(data['lat'].data[lat_idx])
                        depth = int(data['depth'].data[depth_idx])
                        time_period = int(data['time_periods'].data[time_idx])

                        # Prepare values for all parameters in the parameter_set
                        values = []
                        for param in parameter_set:
                            values.append(float(data['mn'].sel(parameters=param).data[time_idx, depth_idx, lat_idx, lon_idx]))

                        # Check if all values are NaN; if so, skip the row
                        if all(np.isnan(v) for v in values):
                            continue

                        geom = generate_grid_polygon(lon, lat, res)
                        geom_wkb = wkb.dumps(geom, hex=True)

                        # Dynamically create the query string
                        query = f"""
                        INSERT INTO {table_name} (lon, lat, depth, time_period, {', '.join(parameter_set)}, geom)
                        VALUES (%s, %s, %s, %s, {', '.join(['%s'] * len(parameter_set))}, ST_SetSRID(%s::geometry, 4326));
                        """

                        cur.execute(query, (lon, lat, depth, time_period, *values, geom_wkb))
                        count += 1

                        if count % batch_size == 0:
                            conn.commit()  # Commit every batch_size records

In [5]:
grid_db = {'01': 'grd1', '04': 'grd025'}  # Two gridded resolutions data: 1-degree and 0.25-degree in WOA23
grid_dir = {'01': '1_degree', '04': '025_degree'}

In [6]:
def insert_into_postgis(conn, table_name, data, parameter_set, res, start_depth_idx=0, batch_size=5000):
    cur = conn.cursor()
    batch = []
    total_records = data.dims['time_periods'] * data.dims['depth'] * data.dims['lat'] * data.dims['lon']
    processed_records = 0
    start_time = datetime.now()

    for time_idx in range(data.dims['time_periods']):
        for depth_idx in range(start_depth_idx, data.dims['depth']):
            for lat_idx in range(data.dims['lat']):
                for lon_idx in range(data.dims['lon']):
                    lon = float(data['lon'].data[lon_idx])
                    lat = float(data['lat'].data[lat_idx])
                    depth = int(data['depth'].data[depth_idx])
                    time_period = int(data['time_periods'].data[time_idx])

                    # Prepare values for all parameters in the parameter_set
                    values = [
                        sanitize_value(float(data['mn'].sel(parameters=param).data[time_idx, depth_idx, lat_idx, lon_idx]))
                        for param in parameter_set
                    ]

                    # Check if all values are None (formerly NaN); if so, skip the row
                    if all(v is None for v in values):
                        continue

                    geom = generate_grid_polygon(lon, lat, res)
                    geom_wkb = wkb.dumps(geom, hex=True)

                    # Add the row to the batch
                    batch.append((lon, lat, depth, time_period, *values, geom_wkb))
                    processed_records += 1

                    if (processed_records == 1 or processed_records % 1000 == 0):
                        end_time = datetime.now()
                        print(f"In lon, lat, depth: {lon}, {lat}, {depth}, query records: {processed_records} with {(end_time - start_time).total_seconds()} seconds")
                        print(f"Values are: {values}", flush=True)
                        start_time = end_time

                    # Insert and commit in batches
                    if len(batch) >= batch_size:
                        psycopg2.extras.execute_values(cur, f"""
                        INSERT INTO {table_name} (lon, lat, depth, time_period, {', '.join(parameter_set)}, geom)
                        VALUES %s;
                        """, batch, template=None, page_size=batch_size)
                        conn.commit()
                        batch.clear()  # Clear the batch after committing

    # Commit any remaining rows
    if batch:
        psycopg2.extras.execute_values(cur, f"""
        INSERT INTO {table_name} (lon, lat, depth, time_period, {', '.join(parameter_set)}, geom)
        VALUES %s;
        """, batch, template=None, page_size=batch_size)
        conn.commit()

#    cur.close()

In [34]:
def insert_into_postgis_test(conn, table_name, data, parameter_set, res, start_depth_idx=0, batch_size=10000):
    cur = conn.cursor()
    batch = []
    total_records = data.dims['time_periods'] * data.dims['depth'] * data.dims['lat'] * data.dims['lon']
    processed_records = 0
    count = 0
    start_time = datetime.now()

    for time_idx in range(data.dims['time_periods']):
        for depth_idx in range(start_depth_idx, data.dims['depth']):
            for lat_idx in range(data.dims['lat']):
                for lon_idx in range(data.dims['lon']):
                    lon = float(data['lon'].data[lon_idx])
                    lat = float(data['lat'].data[lat_idx])
                    depth = int(data['depth'].data[depth_idx])
                    time_period = int(data['time_periods'].data[time_idx])

                    # Prepare values for all parameters in the parameter_set
                    values = []
                    for param in parameter_set:
                        values.append(float(data['mn'].sel(parameters=param).data[time_idx, depth_idx, lat_idx, lon_idx]))

                    end1_time = datetime.now()
                    check_all_na = all(np.isnan(v) for v in values)
                    if (count == 1 or count % 1000 == 0):
                        print(f"Step1 took {(end1_time - start_time).total_seconds()} seconds and values are all NaN: {check_all_na}")
                        print(f"Data status: {lon} {lat} {depth} {time_period} {values}")

                    # Skip rows with all NaN values
                    if check_all_na:
                        count += 1
                        continue

                    geom = generate_grid_polygon(lon, lat, res)
                    geom_wkb = wkb.dumps(geom, hex=True)
                    end2_time = datetime.now()

                    batch.append((lon, lat, depth, time_period, *values, f'ST_SetSRID(\'{geom_wkb}\'::geometry, 4326)'))
                    end3_time = datetime.now()
                    processed_records += 1
                    count += 1

                    if (processed_records == 1 or processed_records % 1000 == 0):
                        end_time = datetime.now()
                        print(f"Time taken to handle query records: {processed_records} with {(end_time - start_time).total_seconds()} seconds")
                        print(f"Each step took {(end1_time - start_time).total_seconds()}, {(end2_time - end1_time).total_seconds()}, {(end3_time - end2_time).total_seconds()} seconds")
                        start_time = end_time

                    # Insert and commit in batches
                    if len(batch) >= batch_size:
                        psycopg2.extras.execute_values(cur, f"""
                        INSERT INTO {table_name} (lon, lat, depth, time_period, {', '.join(parameter_set)}, geom)
                        VALUES %s;
                        """, batch, page_size=batch_size)
                        conn.commit()
                        batch.clear()  # Clear the batch after committing

                        # Print progress
                        print(f"Inserted {processed_records}/{total_records} records. Time: {time.ctime()}")

    # Insert any remaining data
    if batch:
        psycopg2.extras.execute_values(cur, f"""
        INSERT INTO {table_name} (lon, lat, depth, time_period, {', '.join(parameter_set)}, geom)
        VALUES %s;
        """, batch, page_size=batch_size)
        conn.commit()
        print(f"Final insertion completed. Total records inserted: {processed_records}. Time: {time.ctime()}")


In [7]:
if True: #__name__ == '__main__' #only Test
    res = grids[0]  # '04' for 0.25-degree, '01' for 1-degree
    print("Test res, pars", res, grids, pars)

    zarr_store_path = f"../data/{grid_dir[res]}/{time_periods[0]}/{pars[0]}"
    table_name = f"{grid_db[res]}_{time_periods[0]}_{pars[0]}"
    if res == '04' or pars == 'TS':
        parameter_set = ['temperature', 'salinity']  # For TS subgroup
    elif pars[0] == 'Oxy': 
        parameter_set = ['oxygen', 'o2sat', 'AOU']
    else:
        parameter_set = ['silicate', 'phosphate', 'nitrate']        

print("Test table name and parameter_set", table_name, parameter_set)

Test res, pars 04 ['04'] ['TS']
Test table name and parameter_set grd025_monthly_TS ['temperature', 'salinity']


In [8]:
data = xr.open_zarr(zarr_store_path, consolidated=False)
print(data.dims['depth'])

57


  print(data.dims['depth'])


In [9]:
print(data['depth'])

<xarray.DataArray 'depth' (depth: 57)> Size: 228B
array([   0.,    5.,   10.,   15.,   20.,   25.,   30.,   35.,   40.,   45.,
         50.,   55.,   60.,   65.,   70.,   75.,   80.,   85.,   90.,   95.,
        100.,  125.,  150.,  175.,  200.,  225.,  250.,  275.,  300.,  325.,
        350.,  375.,  400.,  425.,  450.,  475.,  500.,  550.,  600.,  650.,
        700.,  750.,  800.,  850.,  900.,  950., 1000., 1050., 1100., 1150.,
       1200., 1250., 1300., 1350., 1400., 1450., 1500.], dtype=float32)
Coordinates:
  * depth    (depth) float32 228B 0.0 5.0 10.0 15.0 ... 1.4e+03 1.45e+03 1.5e+03


In [10]:
print(data['time_periods'])

<xarray.DataArray 'time_periods' (time_periods: 12)> Size: 96B
array(['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12'],
      dtype='<U2')
Coordinates:
  * time_periods  (time_periods) <U2 96B '1' '2' '3' '4' ... '9' '10' '11' '12'


In [11]:
print(float(data['mn'].sel(parameters='temperature').data[0, 3, 60, 60]))
print(float(data['mn'].sel(parameters='temperature').data[0, 3, 200, 100]))

nan
18.777061462402344


In [13]:
conn = connect_db(db_settings)


In [None]:
# Main loop to iterate through grids, time_periods, and pars
# conn = connect_db(db_settings)
# grid_db = {'01': 'grd1', '04': 'grd025'}
# grid_dir = {'01': '1_degree', '04': '025_degree'}

# Loop to write data
for res in grids:
    for time_period in time_periods:
        for par in pars:
            # Determine the Zarr store path and table name
            zarr_store_path = f"../data/{grid_dir[res]}/{time_period}/{par}"
            table_name = f"{grid_db[res]}_{time_period}_{par}"

            # Determine the parameter set based on the subgroup
            if res == '04' or par == 'TS':
                parameter_set = ['temperature', 'salinity']
            elif par == 'Oxy':
                parameter_set = ['oxygen', 'o2sat', 'AOU']
            else:
                parameter_set = ['silicate', 'phosphate', 'nitrate']

            # Create the table and index it
            print("Current handling: ", res, time_period, parameter_set)
            print("Create table name: ", table_name)
            # create_ts_table(conn, table_name, parameter_set)
            print("Read dataset: ", zarr_store_path)
            # Load data from the Zarr store and insert it into the table
            data = xr.open_zarr(zarr_store_path, consolidated=False)
            start_depth_idx = 39  # Replace this with the desired start depth index if needed
            print("Start inserting data at date: ", datetime.now())
            insert_into_postgis(conn, table_name, data, parameter_set, res, start_depth_idx=start_depth_idx)
            print("End this insertion: ", datetime.now())

conn.close()


Current handling:  04 monthly ['temperature', 'salinity']
Create table name:  grd025_monthly_TS
Read dataset:  ../data/025_degree/monthly/TS
Start inserting data at date:  2024-11-20 17:27:25.670240


  total_records = data.dims['time_periods'] * data.dims['depth'] * data.dims['lat'] * data.dims['lon']
  for time_idx in range(data.dims['time_periods']):
  for depth_idx in range(start_depth_idx, data.dims['depth']):
  for lat_idx in range(data.dims['lat']):
  for lon_idx in range(data.dims['lon']):


In lon, lat, depth: -41.125, -77.625, 650, query records: 1 with 1767.255617 seconds
Values are: [-1.9065333604812622, 34.700538635253906]
In lon, lat, depth: -112.125, -67.875, 650, query records: 1000 with 1427.262295 seconds
Values are: [1.8961960077285767, 34.713050842285156]
In lon, lat, depth: 42.875, -65.875, 650, query records: 2000 with 319.697725 seconds
Values are: [1.067011833190918, 34.714866638183594]
In lon, lat, depth: 94.375, -64.375, 650, query records: 3000 with 237.597333 seconds
Values are: [0.7666078209877014, 34.709720611572266]
In lon, lat, depth: -139.625, -63.125, 650, query records: 4000 with 168.932369 seconds
Values are: [1.5220389366149902, 34.728126525878906]
In lon, lat, depth: -28.125, -62.125, 650, query records: 5000 with 167.699913 seconds
Values are: [0.3604220151901245, 34.686134338378906]
In lon, lat, depth: 45.625, -61.125, 650, query records: 6000 with 164.362923 seconds
Values are: [1.571279525756836, 34.712711334228516]
In lon, lat, depth: 12.

In [None]:
data = xr.open_zarr(zarr_store_path, consolidated=False)
print(data)


In [None]:
create_ts_table(conn, table_name)
# print(conn)
insert_into_postgis(conn, table_name, data, parameter_set, res)
conn.commit()
conn.close()

In [10]:
# Function to find the last record in the table
def find_last_record(conn, table_name):
    query = f"""
    SELECT lon, lat, depth, time_period 
    FROM {table_name} 
    ORDER BY depth DESC, time_period DESC, lat DESC, lon DESC 
    LIMIT 1;
    """
    with conn.cursor() as cur:
        cur.execute(query)
        result = cur.fetchone()
    if result:
        print("Find last record: ", result)
        return {
            'lon': result[0],
            'lat': result[1],
            'depth': result[2],
            'time_period': result[3]
        }
    return None


In [None]:
# Continue from the last record
RecoveryFromLastRecord = True
if RecoveryFromLastRecord:
    for depth_idx, depth in enumerate(data["depth"].data):
        for time_idx, time_period_value in enumerate(data["time_periods"].data):
            for lat_idx, lat in enumerate(data["lat"].data):
                for lon_idx, lon in enumerate(data["lon"].data):
                    # Check if this record should be skipped
                    if last_record:
                        if (
                            depth < last_record["depth"]
                            or (
                                depth == last_record["depth"]
                                and time_period_value < last_record["time_period"]
                            )
                            or (
                                depth == last_record["depth"]
                                and time_period_value == last_record["time_period"]
                                and lat < last_record["lat"]
                            )
                            or (
                                depth == last_record["depth"]
                                and time_period_value == last_record["time_period"]
                                and lat == last_record["lat"]
                                and lon <= last_record["lon"]
                            )
                        ):
                            continue
                        last_record = None  # Start processing from here once we pass the last record

                    # Insert the data into the database
                    insert_into_postgis(conn, table_name, data, parameter_set, res)

In [None]:
# Update your main loop to continue from the last record
# conn = connect_db(db_settings)
# grid_db = {'01': 'grd1', '04': 'grd025'}
# grid_dir = {'01': '1_degree', '04': '025_degree'}

for res in grids:
    for time_period in time_periods:
        for par in pars:
            table_name = f"{grid_db[res]}_{time_period}_{par}"
            last_record = find_last_record(conn, table_name)
            
            # Determine the parameter set
            if res == '04' or par == 'TS':
                parameter_set = ['temperature', 'salinity']
            elif par == 'Oxy':
                parameter_set = ['oxygen', 'o2sat', 'AOU']
            else:
                parameter_set = ['silicate', 'phosphate', 'nitrate']

            data = xr.open_zarr(zarr_store_path, consolidated=False)
            for depth_idx, depth in enumerate(data['depth'].data):
                for time_idx, time_period_value in enumerate(data['time_periods'].data):
                    for lat_idx, lat in enumerate(data['lat'].data):
                        for lon_idx, lon in enumerate(data['lon'].data):
                            # Check if this record should be skipped
                            if last_record:
                                if (depth < last_record['depth'] or
                                    (depth == last_record['depth'] and time_period_value < last_record['time_period']) or
                                    (depth == last_record['depth'] and time_period_value == last_record['time_period'] and lat < last_record['lat']) or
                                    (depth == last_record['depth'] and time_period_value == last_record['time_period'] and lat == last_record['lat'] and lon <= last_record['lon'])):
                                    continue
                                last_record = None  # Start processing from here once we pass the last record
                            
                            # Insert the data into the database
                            insert_into_postgis(conn, table_name, data, parameter_set, res)

conn.commit()
conn.close()


In [None]:
conn.close()