In [28]:
# -------------------------------------------------------------------------------------------------
# Imports and settings
# -------------------------------------------------------------------------------------------------

import operator
from collections import defaultdict
from dataclasses import dataclass, fields, replace
from functools import reduce
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple

import polars as pl
import pyarrow as pa
import torch
from pyarrow import dataset as ds

from naics_gemini.utils.utilities import get_indices_codes

# -------------------------------------------------------------------------------------------------
# Configuration
# -------------------------------------------------------------------------------------------------

@dataclass
class CurriculumConfig:
    """Settings that select and size the triplet batches of one curriculum stage.

    Fields left as ``None`` are treated as "not configured" and are skipped
    by :meth:`items`.
    """

    # Parquet inputs. ``triplets_parquet`` is a directory that is globbed for
    # ``anchor=<idx>/*.parquet`` files elsewhere in this file.
    codes_parquet: str = './data/naics_descriptions.parquet'
    distance_parquet: str = './data/naics_distances.parquet'
    relation_parquet: str = './data/naics_relations.parquet'
    triplets_parquet: str = './data/naics_training_pairs'

    # Allowed values per level column; a list here becomes an ``isin`` filter
    # on the column of the same name.
    anchor_level: Optional[List[int]] = None
    positive_level: Optional[List[int]] = None
    negative_level: Optional[List[int]] = None

    # Allowed values per distance column (same ``isin`` treatment as above).
    anchor_distance: Optional[List[float]] = None
    positive_distance: Optional[List[float]] = None
    negative_distance: Optional[List[float]] = None

    # Per-file row cap and negative-count threshold used by the batch generator.
    n_positives: int = 2125
    n_negatives: int = 2125

    # Seed shared by the sampling steps.
    seed: int = 42

    def items(self):
        """Yield ``(field_name, value)`` for each configured field.

        Fields are visited in declaration order; a field whose value is
        ``None`` is omitted, as is any field named ``'input_parquet'``.
        """
        for spec in fields(self):
            name = spec.name
            if name == 'input_parquet':
                continue
            value = getattr(self, name)
            if value is None:
                continue
            yield name, value


# -------------------------------------------------------------------------------------------------
# Utility functions
# -------------------------------------------------------------------------------------------------
    
def _get_file_list(
    codes_parquet: str,
    triplets_parquet: str,
    anchor_level: Optional[List[int]] = None
) -> List[str]:
    """Return the POSIX paths of the triplet parquet files to read.

    Args:
        codes_parquet: Parquet file handed to ``get_indices_codes``; it is
            only read when ``anchor_level`` is given.
        triplets_parquet: Directory containing ``anchor=<idx>/`` partitions.
        anchor_level: Code lengths to keep. When ``None`` every ``*.parquet``
            file below ``triplets_parquet`` is returned.

    Returns:
        File paths as strings. With ``anchor_level`` the files are grouped in
        the order of the requested levels and of the codes within each level;
        inside each group (and for the unrestricted listing) paths are sorted
        so that the batch order does not depend on filesystem enumeration.
    """

    root = Path(triplets_parquet)

    if anchor_level is None:
        # No level restriction: the code table is not needed at all.
        return sorted(pq_path.as_posix() for pq_path in root.glob('**/*.parquet'))

    _, codes, codes_to_indices, _ = get_indices_codes(codes_parquet)

    # A code's level is its number of characters.
    codes_by_level = defaultdict(list)
    for code in codes:
        codes_by_level[len(code)].append(code)

    dataset_files = []
    for level in anchor_level:
        # .get avoids inserting empty entries for levels with no codes.
        for code in codes_by_level.get(level, ()):
            anchor_dir = root / f'anchor={codes_to_indices[code]}'
            dataset_files.extend(
                sorted(pq_path.as_posix() for pq_path in anchor_dir.glob('*.parquet'))
            )

    return dataset_files


def _create_dataset(
    codes_parquet: str,
    triplets_parquet: str,
    anchor_level: Optional[List[int]] = None
) -> ds.Dataset:
    """Open the selected triplet parquet files as one pyarrow dataset.

    The file list comes from ``_get_file_list`` and its length is printed.
    The dataset uses hive-flavoured partitioning with a single ``anchor``
    partition field typed as ``uint32``.
    """

    parquet_paths = _get_file_list(
        codes_parquet=codes_parquet,
        triplets_parquet=triplets_parquet,
        anchor_level=anchor_level
    )
    print(f'Number of batches (parquet files): {len(parquet_paths):,}')

    anchor_schema = pa.schema([pa.field('anchor', pa.uint32())])
    hive_partitioning = ds.partitioning(flavor='hive', schema=anchor_schema)

    return ds.dataset(
        parquet_paths,
        format='parquet',
        partitioning=hive_partitioning
    )


def _get_file_filters(
    curriculum: CurriculumConfig
) -> Optional[ds.Expression]:
    """Combine the curriculum's list-valued fields into one dataset filter.

    Each ``(name, value)`` pair from ``curriculum.items()`` whose value is a
    list contributes ``ds.field(name).isin(value)``; the contributions are
    AND-ed together in iteration order.

    Returns:
        The combined expression, or ``None`` when no field holds a list.
    """

    combined = None
    for column, allowed in curriculum.items():
        if not isinstance(allowed, list):
            continue

        membership = ds.field(column).isin(allowed)
        combined = membership if combined is None else combined & membership

    return combined


def _get_distance_dict(
    distance_parquet: str
) -> Dict[Tuple[int, int], float]:
    """Load the distance table into a ``(idx_i, idx_j) -> distance`` mapping.

    Only the ``idx_i``, ``idx_j`` and ``distance`` columns are used, and
    duplicate rows are dropped before the mapping is built.
    """

    pairs = (
        pl
        .read_parquet(distance_parquet)
        .select('idx_i', 'idx_j', 'distance')
        .unique()
    )

    # iter_rows() yields plain tuples in the column order selected above.
    return {
        (idx_i, idx_j): distance
        for idx_i, idx_j, distance in pairs.iter_rows()
    }


def _get_relation_dict(
    relation_parquet: str
) -> Dict[Tuple[int, int], Tuple[int, str]]:
    """Load the relation table into a ``(idx_i, idx_j) -> (relation_id, relation)`` mapping.

    Only the ``idx_i``, ``idx_j``, ``relation_id`` and ``relation`` columns
    are used, and duplicate rows are dropped before the mapping is built.
    """

    pairs = (
        pl
        .read_parquet(relation_parquet)
        .select('idx_i', 'idx_j', 'relation_id', 'relation')
        .unique()
    )

    # iter_rows() yields plain tuples in the column order selected above.
    return {
        (idx_i, idx_j): (relation_id, relation)
        for idx_i, idx_j, relation_id, relation in pairs.iter_rows()
    }


def _fill_incomplete(
    incomplete_df: pl.DataFrame,
    triplets_parquet: str,
    curriculum: CurriculumConfig
) -> pl.DataFrame:
    """Top up the negative lists of the given anchors from a wider distance band.

    Args:
        incomplete_df: One row per anchor with list columns ``positive_idx``
            and ``negative_idx`` (both are exploded together below, so
            ``negative_idx`` is expected to hold one list per positive).
        triplets_parquet: Directory containing ``anchor=<idx>/`` partitions.
        curriculum: Active curriculum; a copy with ``anchor_distance``
            replaced is used to build the filter for the extra rows.

    Returns:
        A frame grouped by ``anchor_idx`` with list columns ``positive_idx``
        and ``negative_idx``.
    """

    # Anchors to top up, in ascending index order.
    incomplete_anchors = (
        incomplete_df
        .get_column('anchor_idx')
        .sort()
        .to_list()
    )

    # Every parquet file in each of those anchors' partition directories.
    incomplete_files = []
    for idx in incomplete_anchors:
        for pq_path in Path(f'{triplets_parquet}/anchor={idx}/').glob('*.parquet'):
            incomplete_files.append(pq_path.as_posix())
        

    # items() omits None-valued fields, so membership means "anchor_distance is set".
    curriculum_keys = [k for k, v in curriculum.items()]
    if 'anchor_distance' in curriculum_keys:

        # Hard-coded ladder of distance values.
        # NOTE(review): presumably the distinct distances present in the data — confirm.
        all_distances = [
            0.125, 0.25, 0.5, 1.0, 2.0, 2.5, 3.0, 
            3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5
        ]
        # NOTE(review): max() raises ValueError if anchor_distance is an empty list.
        max_anchor_distance = max(curriculum.anchor_distance)

        # Distances strictly beyond the configured band; keep at most as many
        # of them as there are configured distances.
        incomplete_distance = [d for d in all_distances if d > max_anchor_distance]
        len_incomplete_distance = min(len(curriculum.anchor_distance), len(incomplete_distance))

        # 7.0 is always appended as an extra allowed distance.
        incomplete_distance = sorted(list(incomplete_distance))[:len_incomplete_distance] + [7.0]

    else:
        incomplete_distance = [7.0]

    # Same curriculum, but filtering on the widened distance band. All other
    # list-valued fields still contribute to the filter.
    incomplete_curriculum = replace(
        curriculum,
        anchor_distance=incomplete_distance
    )

    filter_expr = _get_file_filters(incomplete_curriculum)

    # NOTE(review): behaviour of ds.dataset() with an empty file list is not
    # handled here — confirm incomplete_files can never be empty.
    incomplete_dataset = (
        ds
        .dataset(
            incomplete_files, 
            format='parquet',
            partitioning=ds.partitioning(
                flavor='hive',
                schema=pa.schema([
                    pa.field('anchor', pa.uint32())
                ])
            )        
        )
        .filter(filter_expr)
    )

    # Candidate negatives per (anchor, positive) pair from the widened band.
    incomplete_added = (
        pl
        .from_arrow(
            incomplete_dataset
            .to_table()
        )
        .sort('anchor_idx', 'positive_idx', 'negative_idx', 'anchor_distance')
        .group_by('anchor_idx', 'positive_idx', maintain_order=True)
        .agg(
            negative_added=pl.col('negative_idx')  
        )
    )

    completed = (
        incomplete_df
        # One row per (anchor, positive) with that positive's negative list.
        .explode('positive_idx', 'negative_idx')
        .with_columns(
            # to_add = n_negatives - len(negative_idx)
            to_add=pl.col('negative_idx')
                    .list.len()
                    .add(-curriculum.n_negatives)
                    .mul(-1)
        )
        # Left join keeps pairs with no candidates; their negative_added is null.
        .join(
            incomplete_added,
            how='left',
            on=['anchor_idx', 'positive_idx']
        )
        .with_columns(
            # NOTE(review): sampling without replacement presumably fails when
            # to_add exceeds the number of candidates (e.g. the empty list
            # produced by fill_null) or is negative — confirm and clamp if so.
            pl.col('negative_added')
            .fill_null([])
            .list.sample(
                pl.col('to_add'),
                with_replacement=False,
                shuffle=True, 
                seed=curriculum.seed
            )
        )
        .select(
            anchor_idx=pl.col('anchor_idx'),
            positive_idx=pl.col('positive_idx'),
            # Merge existing and sampled negatives, then de-duplicate.
            negative_idx=pl.col('negative_idx')
                            .list.set_union(pl.col('negative_added'))
                            .list.unique()
        )
        # Back to one row per anchor. No maintain_order here, so the row
        # order of the result is presumably not guaranteed — confirm.
        .group_by('anchor_idx')
        .agg(
            pl.col('positive_idx'),
            pl.col('negative_idx')
        )
    )

    # "Incomplete" reports the number of parquet files scanned, "Completed"
    # the number of rows (one per anchor) in the result.
    print(f'    Incomplete: {len(incomplete_files):,}, Completed: {completed.height:,}')

    return completed
# -------------------------------------------------------------------------------------------------
# Generator function
# -------------------------------------------------------------------------------------------------

def iter_file_batches(
    dataset: Optional[ds.Dataset],
    filter_expr: Optional[ds.Expression],
    codes_parquet: Optional[str],
    triplets_parquet: Optional[str],
    curriculum: CurriculumConfig,
    anchor_level: Optional[List[int]] = None
) -> Iterator[Tuple[ds.FileFragment, pa.Table]]:
    """Yield ``(fragment, table)`` for each fragment of the triplet dataset.

    Args:
        dataset: Dataset to iterate. When ``None`` it is built with
            ``_create_dataset`` from ``codes_parquet``, ``triplets_parquet``
            and ``anchor_level``.
        filter_expr: Row filter. When ``None`` it is derived from
            ``curriculum`` with ``_get_file_filters``.
        codes_parquet: Only used when ``dataset`` is ``None``.
        triplets_parquet: Only used when ``dataset`` is ``None``.
        curriculum: Only used when ``filter_expr`` is ``None``.
        anchor_level: Only used when ``dataset`` is ``None``.

    Yields:
        Each fragment together with its filtered table restricted to the
        ``anchor_idx``, ``positive_idx``, ``negative_idx`` and
        ``anchor_distance`` columns.
    """

    wanted_columns = ['anchor_idx', 'positive_idx', 'negative_idx', 'anchor_distance']

    source = dataset if dataset is not None else _create_dataset(
        codes_parquet=codes_parquet,  # type: ignore
        triplets_parquet=triplets_parquet,  # type: ignore
        anchor_level=anchor_level
    )

    row_filter = filter_expr if filter_expr is not None else _get_file_filters(curriculum)

    for fragment in source.get_fragments():  # type: ignore
        yield fragment, fragment.to_table(filter=row_filter, columns=wanted_columns)


# -------------------------------------------------------------------------------------------------
# Triplet batch generator
# -------------------------------------------------------------------------------------------------

def triplet_batches(
    iter_files: Iterator[Tuple[ds.FileFragment, pa.Table]],
    curriculum: CurriculumConfig,
    rng: Optional[torch.Generator] = None,
) -> Iterator[List[Tuple[int, int, int]]]:
    """Yield one list of index pairs per parquet file.

    For every ``(fragment, table)`` from ``iter_files`` the rows are capped
    at ``curriculum.n_positives`` by seeded sampling, grouped per anchor,
    and split on the ``fallback`` flag (list length below
    ``curriculum.n_negatives``). Flagged anchors are passed through
    ``_fill_incomplete``; unflagged anchors are used as-is.

    Args:
        iter_files: Iterator of ``(file fragment, arrow table)`` pairs.
        curriculum: Supplies ``n_positives``, ``n_negatives``, ``seed`` and
            ``triplets_parquet``.
        rng: Optional torch generator; a seeded one is created when omitted.

    Yields:
        A list that contains, for each ``(anchor, positive)`` row, the pair
        ``(anchor, positive)`` followed by one ``(anchor, negative)`` pair
        per negative.
        NOTE(review): the items are 2-tuples although the return annotation
        says 3-tuples — confirm which one is intended.
    """

    if rng is None:
        # NOTE(review): the generator is seeded here but not consumed below;
        # all sampling in this function uses curriculum.seed directly.
        rng = torch.Generator()
        rng.manual_seed(curriculum.seed)

    n_neg = curriculum.n_negatives
    n_pos = curriculum.n_positives

    for file_num, (file, file_batch) in enumerate(iter_files, start=1):

        df_batch = pl.from_arrow(file_batch)

        # Cap the number of rows taken from one file.
        if df_batch.height > n_pos:
            df_batch = df_batch.sample(
                n=n_pos,
                with_replacement=False,
                shuffle=True,
                seed=curriculum.seed
            )

        df = (
            df_batch
            .group_by('anchor_idx', 'positive_idx', maintain_order=True)  # type: ignore
            .agg(
                pl.col('negative_idx')
            )
            .group_by('anchor_idx', maintain_order=True)
            .agg(
                pl.col('positive_idx'),
                pl.col('negative_idx')
            )
            .with_columns(
                # NOTE(review): after the second group_by `negative_idx` is a
                # list of per-positive lists, so this length looks like the
                # number of positives rather than negatives — confirm intent.
                fallback=pl.col('negative_idx')
                           .list.len()
                           .lt(n_neg)
            )
        )

        print(
            f'  Batch {file_num} '
            f'[{Path(file.path).parent.stem}]: '
            f'triplets = {df_batch.height:,}, '
            f'grouped triplets = {df.height:,}'
        )

        # `fallback` is True for anchors that fall short of n_neg, so those
        # are the incomplete ones; the original had the two sides swapped.
        incomplete = df.filter(pl.col('fallback')).drop('fallback')
        complete = df.filter(~pl.col('fallback')).drop('fallback')

        print(f'    Complete {complete.height:,}, Incomplete = {incomplete.height:,}')

        if incomplete.height == 0:
            completed = complete
        else:
            filled = _fill_incomplete(incomplete, curriculum.triplets_parquet, curriculum)
            completed = filled if complete.height == 0 else pl.concat([complete, filled])

        triplet_iter = (
            completed
            .explode('positive_idx', 'negative_idx')
            .iter_rows(named=True)
        )

        triplets = []
        for row in triplet_iter:
            anchor = row['anchor_idx']
            positive = row['positive_idx']
            negatives = row['negative_idx']

            triplets.append((anchor, positive))

            # `negatives` may be None or empty for a positive without negatives.
            if negatives:
                triplets.extend((anchor, negative) for negative in negatives)

        yield triplets
# -------------------------------------------------------------------------------------------------
# Main logic
# -------------------------------------------------------------------------------------------------

# Curriculum for this run: anchors at levels 2 and 3, anchor distances
# 0.25 and 0.5, at most 100 rows per file and a negative threshold of 30.
curriculum = CurriculumConfig(
    n_positives=100,
    n_negatives=30,
    anchor_level=[2, 3],
    anchor_distance=[0.25, 0.5],
)

# Passing None for dataset and filter_expr makes iter_file_batches build
# both from the curriculum.
file_iterator = iter_file_batches(
    None,
    None,
    curriculum.codes_parquet,
    curriculum.triplets_parquet,
    curriculum,
    curriculum.anchor_level,
)

# Lazy generator: nothing is read until it is iterated.
triplets = triplet_batches(iter_files=file_iterator, curriculum=curriculum)

In [29]:
# Pairwise lookups keyed by (idx_i, idx_j): distance first, then relation.
distance_dict, relation_dict = (
    _get_distance_dict(curriculum.distance_parquet),
    _get_relation_dict(curriculum.relation_parquet),
)

In [30]:
distance_dict[(0, 1)]

0.5

In [31]:
relation_dict[(0, 1)]

(0, 'child')

In [42]:
# Hand-made sample pairs (0, 1) .. (0, 50); this rebinds `triplets`, which
# previously held the triplet_batches generator.
triplets = list(zip([0] * 50, range(1, 51)))

# Only the fourth value returned by get_indices_codes is kept.
_, _, _, idx_to_code = get_indices_codes(curriculum.codes_parquet)

In [None]:
# Extend each (i, j) pair with its distance and relation; pairs missing from
# a lookup get None for the corresponding fields.
triplets_ext = []
for i, j in triplets:
    dist = distance_dict.get((i, j), None)
    rel_id, rel = relation_dict.get((i, j), (None, None))
    triplets_ext.append((i, j, dist, rel_id, rel))

In [41]:
triplets_ext

[(0, 1, 0.5, 0, 'child'),
 (0, 2, 1.5, 2, 'grandchild'),
 (0, 3, 2.5, 4, 'great-grandchild'),
 (0, 4, 3.5, 7, 'great-great-grandchild'),
 (0, 5, 2.5, 4, 'great-grandchild'),
 (0, 6, 3.5, 7, 'great-great-grandchild'),
 (0, 7, 2.5, 4, 'great-grandchild'),
 (0, 8, 3.5, 7, 'great-great-grandchild'),
 (0, 9, 2.5, 4, 'great-grandchild'),
 (0, 10, 3.5, 7, 'great-great-grandchild'),
 (0, 11, 2.5, 4, 'great-grandchild'),
 (0, 12, 3.5, 7, 'great-great-grandchild'),
 (0, 13, 2.5, 4, 'great-grandchild'),
 (0, 14, 3.5, 7, 'great-great-grandchild'),
 (0, 15, 2.5, 4, 'great-grandchild'),
 (0, 16, 3.5, 7, 'great-great-grandchild'),
 (0, 17, 3.5, 7, 'great-great-grandchild'),
 (0, 18, 1.5, 2, 'grandchild'),
 (0, 19, 2.5, 4, 'great-grandchild'),
 (0, 20, 3.5, 7, 'great-great-grandchild'),
 (0, 21, 3.5, 7, 'great-great-grandchild'),
 (0, 22, 1.5, 2, 'grandchild'),
 (0, 23, 2.5, 4, 'great-grandchild'),
 (0, 24, 3.5, 7, 'great-great-grandchild'),
 (0, 25, 2.5, 4, 'great-grandchild'),
 (0, 26, 3.5, 7, 'grea