In [2]:
import os
import json
import re
import shutil

import pyarrow.dataset as ds
import pyarrow as pa
import pyarrow.csv as csv
import pyarrow.compute as pc
import pyarrow

In [3]:
# Root directories used by this notebook:
#   laptop_data_dir - local working copy of the data
#   repo_data_dir   - data folder inside the project repository
laptop_data_dir = '/home/matthew/data/'
repo_data_dir = '/home/matthew/Documents/TSE/AppliedEconometrics/repo/data/'

# JSON file holding the per-table column definitions (loaded in the next cell).
schema_path = os.path.join(
    repo_data_dir,
    'schemas.json',
)


In [4]:
# Read the table/column definitions into `schemas`.
# The encoding is pinned to UTF-8 so the result does not depend on the
# machine's locale default.
with open(schema_path, 'r', encoding='utf-8') as schema_file:
    schemas = json.load(schema_file)

In [4]:
def aemo_type_to_arrow_type(t: str, date_as_str=False) -> pa.DataType:
    """Map an AEMO (Oracle SQL) column type to a type arrow can use.

    Examples of the mapping:
        VARCHAR2(10), CHAR(1)   -> pa.string()
        NUMBER(2,0), NUMBER(2)  -> pa.int8()   (smallest signed int that fits)
        NUMBER(15,5)            -> pa.float64()
        DATE, TIMESTAMP(3)      -> pa.timestamp('s'), or pa.string() when
                                   date_as_str is true

    Args:
        t: The Oracle type string from the schema; matched case-insensitively.
        date_as_str: If true, return a string type for DATE/TIMESTAMP columns
            instead of a timestamp type (the original author's note: because
            pyarrow can't read datetimes when parsing from CSV).

    Returns:
        The pyarrow data type to use for the column.

    Raises:
        ValueError: If the type string is not recognised, including a NUMBER
            type with no readable precision.
    """
    t = t.upper()

    if re.match(r"VARCHAR(2)?\(\d+\)", t):
        return pa.string()
    if re.match(r"CHAR\((\d+)\)", t):
        # single character
        # arrow has no dedicated type for that
        # so use string
        # (could use categorical?)
        return pa.string()

    if t.startswith("NUMBER"):
        match = re.match(r"NUMBER ?\((\d+), ?(\d+)\)", t)
        if match:
            # Oracle NUMBER(p, s): p is the total digit count, s the number
            # of digits after the decimal point.
            whole_digits = int(match.group(1))
            decimal_digits = int(match.group(2))
        else:
            # e.g. NUMBER(2)
            match = re.match(r"NUMBER ?\((\d+)", t)
            if not match:
                # Raise explicitly rather than assert, so the check still
                # runs under `python -O`.
                raise ValueError(f"Unsure how to cast {t} to arrow type")
            whole_digits = int(match.group(1))
            decimal_digits = 0

        if decimal_digits != 0:
            # we could use pa.decimal128(whole_digits, decimal_digits)
            # but we don't need that much accuracy
            return pa.float64()

        # integer
        # we assume signed (can't tell unsigned from the schema)
        # and pick the narrowest width whose range covers every value
        # with `whole_digits` digits.
        max_val = 10**whole_digits
        signed_int_types = (
            (8, pa.int8),
            (16, pa.int16),
            (32, pa.int32),
            (64, pa.int64),
        )
        for bits, make_type in signed_int_types:
            if 2**(bits - 1) > max_val:
                return make_type()
        # 19 or more digits can exceed the int64 range, so use an exact
        # decimal with no fractional digits instead of overflowing.
        return pa.decimal128(whole_digits, 0)

    if (t == 'DATE') or re.match(r"TIMESTAMP\((\d)\)", t):
        # watch out, when AEMO say "date" they mean "datetime"
        # for both dates and datetimes they say "date",
        # but both have a time component. (For actual dates, it's always midnight.)
        # and some dates go out as far as 9999-12-31 23:59:59.999
        # (and some dates are 9999-12-31 23:59:59.997)
        if date_as_str:
            return pa.string()
        return pa.timestamp('s')

    raise ValueError(f"Unsure how to convert AEMO type {t} to arrow type")


In [5]:
# Load the split CSV files for one AEMO table as a hive-partitioned dataset.
# `table` stays bound to the table NAME throughout: later cells interpolate
# it into output paths, so the loaded data gets its own name (`arrow_table`).
table = "DISPATCHPRICE"
source_dir = f"/home/matthew/data/01-D-split-mapped-csv-done/{table}/"

# Column types for the CSV contents, derived from the AEMO schema.
csv_schema = {
    column: aemo_type_to_arrow_type(spec['AEMO_type'], date_as_str=False)
    for (column, spec) in schemas[table]['columns'].items()
}
# Columns that come from the hive-style directory names, not the CSV files.
part_schema = {
    "SCHEMA_VERSION": pa.int8(),
    "TOP_TIMESTAMP": pa.string(),
}
schema = {**csv_schema, **part_schema}
part = ds.partitioning(
    pa.schema(part_schema),
    flavor="hive"
)
dataset = ds.dataset(
    source=source_dir,
    format=ds.CsvFileFormat(
        convert_options=csv.ConvertOptions(
            timestamp_parsers=["%Y/%m/%d %H:%M:%S"]
        )
    ),
    partitioning=part,
    schema=pa.schema(schema)
)

# Materialise the whole dataset in memory, then tag every timestamp column
# with the Australia/Brisbane timezone (the values are parsed as naive).
arrow_table = dataset.to_table()
for (column, arrow_type) in schema.items():
    if isinstance(arrow_type, pa.TimestampType):
        col_i = arrow_table.schema.get_all_field_indices(column)[0]
        arrow_table = arrow_table.set_column(
            col_i,
            column,
            pc.assume_timezone(arrow_table.column(col_i), 'Australia/Brisbane'),
        )


KeyboardInterrupt: 

In [None]:
# Re-write the CSV-backed `dataset` as parquet under a debug directory.
# NOTE(review): `table` must hold the table-name string here, since it is
# interpolated into the destination path.
#dest_partition_cols = ['REGIONID', 'INTERVENTION']
dest_partition_cols = []
dest_dir = f"/home/matthew/data/debug/test-pyarrow-partitioned/{table}-3"

# Only the columns named in dest_partition_cols become hive partition keys.
dest_partition_schema = pa.schema(
    {name: dtype for (name, dtype) in schema.items() if name in dest_partition_cols}
)
dest_partitioning = ds.partitioning(dest_partition_schema, flavor="hive")

# Start from a clean directory; a missing directory is not an error.
shutil.rmtree(dest_dir, ignore_errors=True)
ds.write_dataset(
    data=dataset,
    base_dir=dest_dir,
    format="parquet",
    partitioning=dest_partitioning,
    min_rows_per_group=1024*(2**3),
    existing_data_behavior="delete_matching",
)

In [None]:
import polars as pl

# Preview the parquet output with polars.
# NOTE(review): this reads the "-2" directory, while the write cell above
# targets "-3" - confirm which run is meant to be inspected.
preview_dir = f"/home/matthew/data/debug/test-pyarrow-partitioned/{table}-2/"
pl.scan_parquet(preview_dir).fetch()

In [32]:
# Small experiment: read a one-file CSV dataset whose 't' column is declared
# as a microsecond timestamp, offering the parser two formats (with and
# without fractional seconds).
timestamp_formats = [
    "%Y/%m/%d %H:%M:%S",
    "%Y/%m/%d %H:%M:%S.%f",
]
schema = {'a': pa.int64(), 't': pa.timestamp('us')}

csv_format = ds.CsvFileFormat(
    convert_options=csv.ConvertOptions(timestamp_parsers=timestamp_formats)
)
dataset = ds.dataset(
    source='/home/matthew/data/debug/testcsv/data.csv',
    format=csv_format,
    schema=pa.schema(schema),
)

dataset.to_table().to_pandas()

Unnamed: 0,a,t
0,1,2016-04-20 10:12:10


In [19]:
from datetime import datetime

# Check that Python's own strptime accepts the fractional-seconds format
# passed to the CSV timestamp parser.
sample_text = "2016/04/20 10:12:10.123456"
sample_format = "%Y/%m/%d %H:%M:%S.%f"
datetime.strptime(sample_text, sample_format)

datetime.datetime(2016, 4, 20, 10, 12, 10, 123456)

In [None]:
# Interactive exploration: list the attributes and methods of `dataset`.
dir(dataset)

In [None]:
# Scratch check of the column-index lookup: position of the first field
# named 't' in the table's schema.
# NOTE(review): assumes `table` is a pyarrow Table with a 't' column - confirm
# which table this was run against.
table.schema.get_all_field_indices('t')[0]

In [None]:
# Scratch check: is a concrete timestamp type an instance of pa.TimestampType?
isinstance(pa.timestamp('s'), pa.TimestampType)

In [None]:
# Display the pa.TimestampType class itself.
pa.TimestampType