In [0]:
-- Point the session at the control catalog and schema supplied through the
-- :ctrl_catalog and :ctrl_schema parameter markers. The routines created in
-- the following cells use unqualified names, so they resolve against this
-- catalog and schema.
use catalog identifier(:ctrl_catalog);
use schema identifier(:ctrl_schema);

In [0]:
-- Returns the allocated size, in MB, of one SQL Server table by querying the
-- SQL Server system catalog views (sys.tables, sys.indexes, sys.partitions,
-- sys.allocation_units, sys.schemas) through the given catalog.
--
-- Parameters:
--   src_catalog   catalog that exposes the SQL Server `sys` schema
--   src_schema    SQL Server schema name of the table
--   src_table     SQL Server table name
--   table_size_mb (out) rounded size in MB; NULL when no matching table exists
create or replace procedure get_table_size_sqlserver(
    in src_catalog string,
    in src_schema string,
    in src_table string,
    out table_size_mb int
)
language sql
sql security invoker
comment 'Get the size of a SQL Server table in MB using Lakehouse Federation'
as
begin
  -- The sys.* views below are referenced without a catalog, so the current
  -- catalog is switched to the source catalog first.
  -- NOTE(review): this changes the current catalog for the statements that
  -- follow; confirm whether callers expect it to be restored afterwards.
  use catalog identifier(src_catalog);

  set table_size_mb = (
    select
      -- total_pages * 8 gives KB (8 KB per page); / 1024 converts to MB.
      round((sum(a.total_pages) * 8) / 1024) as table_size_mb
    from
      sys.tables t
      -- The filter on s.name below already discards rows without a schema
      -- match, so this is written as the inner join it effectively is.
      inner join sys.schemas s on t.schema_id = s.schema_id
      inner join sys.indexes i on t.object_id = i.object_id
      inner join sys.partitions p on i.object_id = p.object_id and i.index_id = p.index_id
      inner join sys.allocation_units a on p.partition_id = a.container_id
    where
      s.name = src_schema
      and t.name = src_table
      and t.is_ms_shipped = 0
      -- NOTE(review): presumably meant to skip low-numbered system objects;
      -- confirm it is still needed alongside is_ms_shipped = 0.
      and i.object_id > 255
    group by
      s.name, t.name
  );

end;

In [0]:
-- Dispatches a table-size lookup to the implementation that matches the
-- source system type.
--
-- Parameters:
--   src_type      source system type; only 'sqlserver' is handled
--                 (compared case-insensitively)
--   src_catalog   catalog of the source table
--   src_schema    schema of the source table
--   src_table     name of the source table
--   table_size_mb (out) table size in MB as returned by the implementation
--
-- Raises USER_RAISED_EXCEPTION with an 'Unsupported source type: ...' message
-- for any other (or NULL) src_type.
create or replace procedure get_table_size(
    in src_type string,
    in src_catalog string,
    in src_schema string,
    in src_table string,
    out table_size_mb int
)
language sql
sql security invoker
comment 'Get the size of a source table in MB using Lakehouse Federation'
as
begin
  declare arg_map map<string, string>;

  -- lower() keeps the type match case-insensitive, in line with how
  -- get_internal_bound_value treats its type argument.
  if lower(src_type) = 'sqlserver' then
    call get_table_size_sqlserver(src_catalog, src_schema, src_table, table_size_mb);
  else
    -- coalesce() keeps the message non-NULL when src_type itself is NULL;
    -- concat() would otherwise return NULL and the error would carry no text.
    set arg_map = map('errorMessage', concat('Unsupported source type: ', coalesce(src_type, 'NULL')));
    signal user_raised_exception
      set message_arguments = arg_map;
  end if;
end;

In [0]:
-- Determines the data type and the min/max values of a partition column.
--
-- Parameters:
--   src_catalog, src_schema, src_table  three-part name of the table to inspect
--   partition_col         name of the column to compute bounds for
--   partition_col_type    (out) typeof() of the column with any parenthesised
--                         precision/scale removed, e.g. 'decimal(10,2)' -> 'decimal'
--   partition_col_bounds  (out) struct of the column's minimum and maximum,
--                         both rendered as strings
--
-- Bound format:
--   * date / timestamp columns: the value cast directly to string, which is
--     the form get_internal_bound_value parses with cast(... as timestamp).
--   * every other type: the value cast to bigint and then to string.
create or replace procedure get_partition_col_bounds(
    in src_catalog string,
    in src_schema string,
    in src_table string,
    in partition_col string,
    out partition_col_type string,
    out partition_col_bounds struct<lower: string, upper: string>
)
language sql
sql security invoker
comment 'Get partition boundaries (Min and max values) for partition column'
as
begin
  -- Get partition column type. Parentheses with precision & scale are removed.
  set partition_col_type = (
    select regexp_replace(typeof((select identifier(partition_col)
    from identifier(src_catalog || '.' || src_schema || '.' || src_table) limit 1)), '\\([^()]*\\)', '') as data_type
  );

  if lower(partition_col_type) in ('date', 'timestamp') then
    -- Temporal columns: a date cannot be cast to bigint, and an epoch-seconds
    -- string cannot be cast back to timestamp, so keep the native string form.
    set partition_col_bounds = (
      select
        struct(
          cast(min(identifier(partition_col)) as string) as lower,
          cast(max(identifier(partition_col)) as string) as upper
        ) as partition_col_bounds
      from identifier(src_catalog || '.' || src_schema || '.' || src_table)
    );
  else
    -- All other types: bounds are reduced to whole numbers via bigint.
    set partition_col_bounds = (
      select
        struct(
          min(identifier(partition_col))::bigint::string as lower,
          max(identifier(partition_col))::bigint::string as upper
        ) as partition_col_bounds
      from identifier(src_catalog || '.' || src_schema || '.' || src_table)
    );
  end if;
end;

In [0]:
-- Converts a bound value given as a string into a bigint, based on the
-- (case-insensitive) column type name:
--   'int' / 'bigint'     -> the value cast to bigint
--   'timestamp' / 'date' -> the value cast to timestamp, then to Unix epoch
--                           seconds via unix_timestamp()
--   anything else        -> NULL (no branch matches)
create or replace function get_internal_bound_value(bound_val string, bound_val_type string)
returns bigint
return
  case
    when lower(bound_val_type) in ('timestamp', 'date')
      then unix_timestamp(bound_val::timestamp)
    when lower(bound_val_type) in ('bigint', 'int')
      then bound_val::bigint
  end;

In [0]:
-- Splits the range [lower_bound, upper_bound] of a partition column into
-- num_partitions contiguous slices and returns them as a JSON array string.
-- Each element is an object {"id": <slice index>, "where_clause": <predicate>}.
--
-- Parameters:
--   partition_col       column name used verbatim in the generated predicates
--   partition_col_type  column type name (case-insensitive): int, bigint,
--                       float, decimal, date, timestamp or datetime
--   lower_bound         lower bound as an integer (epoch seconds for
--                       date/timestamp/datetime)
--   upper_bound         upper bound, same encoding as lower_bound
--   num_partitions      number of slices to generate (>= 1)
--
-- The first slice has no lower limit and also matches NULLs; the last slice
-- has no upper limit. With a single slice the where_clause is JSON null.
create or replace function get_partition_list(
    partition_col string,
    partition_col_type string,
    lower_bound bigint,
    upper_bound bigint,
    num_partitions int
)
returns string
language python
as $$
import json
from datetime import datetime, timezone

def bound_value_to_str(bound_value: int, partition_col_type: str) -> str:
    """Render a bound value as a literal for a SQL where clause.

    Args:
        bound_value (int): bound value as int; for date/timestamp types it is
            interpreted as seconds since the Unix epoch
        partition_col_type (str): type name of the partition column
            (case-insensitive)

    Returns:
        String representation of the bound value (quoted for temporal types)

    Raises:
        ValueError: if partition_col_type is not a supported type
    """
    col_type = (partition_col_type or '').lower()

    if col_type in ('int', 'bigint', 'float', 'decimal'):
        return str(bound_value)

    # NOTE(review): values are formatted in UTC; confirm this matches the time
    # zone used when the epoch-second bounds were produced.
    moment = None
    if col_type in ('datetime', 'timestamp', 'date'):
        moment = datetime.fromtimestamp(bound_value, tz=timezone.utc)

    # 'timestamp' is what typeof() reports for timestamp columns; 'datetime'
    # is kept for backward compatibility.
    if col_type in ('datetime', 'timestamp'):
        return f"'{moment.strftime('%Y-%m-%d %H:%M:%S')}'"
    if col_type == 'date':
        return f"'{moment.strftime('%Y-%m-%d')}'"

    raise ValueError(
        f'Unsupported data type: {partition_col_type}. '
        'Only int, bigint, float, decimal, date, timestamp and datetime are supported'
    )

if lower_bound is None or upper_bound is None:
    raise ValueError('lower_bound and upper_bound must not be null')
if lower_bound > upper_bound:
    raise ValueError('lower_bound must not be greater than upper_bound')
if num_partitions is None or num_partitions < 1:
    raise ValueError('num_partitions must be a positive integer')

# Integer arithmetic keeps the stride exact for large bigint bounds, where
# float division would lose precision.
stride = (upper_bound - lower_bound) // num_partitions

partition_list = []
current_value = lower_bound
last_index = num_partitions - 1

for i in range(num_partitions):
    # The first slice is open at the bottom, the last slice is open at the top,
    # so values outside [lower_bound, upper_bound] are still covered.
    lower_clause = None
    if i != 0:
        lower_clause = f'{partition_col} >= {bound_value_to_str(current_value, partition_col_type)}'

    current_value += stride

    upper_clause = None
    if i != last_index:
        upper_clause = f'{partition_col} < {bound_value_to_str(current_value, partition_col_type)}'

    if upper_clause is None:
        # Last slice (or the only slice, in which case this is None).
        where_clause = lower_clause
    elif lower_clause is None:
        # First slice also collects rows where the partition column is NULL.
        where_clause = f'{upper_clause} or {partition_col} is null'
    else:
        where_clause = f'{lower_clause} and {upper_clause}'

    partition_list.append({'id': i, 'where_clause': where_clause})

# The function is declared `returns string`, so serialise the list as JSON.
return json.dumps(partition_list)
$$