In [None]:
# -------------------------------------------------------------------------------------------------
# Imports and settings
# -------------------------------------------------------------------------------------------------

import logging
import random
from dataclasses import dataclass, field
from random import sample
from typing import Dict, Generator, List, Optional, Tuple

import polars as pl
import torch

from naics_gemini.utils.utilities import get_indices_codes

logger = logging.getLogger(__name__)


# -------------------------------------------------------------------------------------------------
# Configuration
# -------------------------------------------------------------------------------------------------

@dataclass
class CurriculumConfig:
    """Curriculum settings for filtering triplets and sampling negatives.

    Attributes:
        positive_levels: Allowed character lengths of ``positive_code``; the
            triplet filter keeps rows whose code length is in this list, so an
            empty list matches no rows.
        positive_distances: ``(min, max)`` bounds applied to ``positive_distance``.
        hardness_buckets: Hardness values whose triplets are kept and from
            which negatives are sampled.
        n_positives: Number of codes to subsample as candidate positives.
        k_negatives: Number of negatives to sample per (anchor, positive) pair.
        max_positives: Optional cap on positives kept per anchor; ``None`` or
            ``0`` disables the cap.
        hardness_weights: Derived in ``__post_init__`` -- an equal weight for
            every distinct bucket in ``hardness_buckets``.
    """

    positive_levels: List[int] = field(default_factory=list)
    # Tuples are immutable, so a plain default is safe (no factory needed).
    positive_distances: Tuple[float, float] = (0.5, 10.0)
    hardness_buckets: List[int] = field(
        default_factory=lambda: [1, 2, 3, 4, 5, 6, 7, 8]
    )
    n_positives: int = 2025
    k_negatives: int = 16
    # Appended last so existing positional construction keeps working.
    max_positives: Optional[int] = None
    hardness_weights: Dict[int, float] = field(init=False, repr=False)

    def __post_init__(self):
        # Weight each *distinct* bucket equally so the weights sum to 1 even
        # when a bucket is listed twice; an empty bucket list yields {}.
        unique_buckets = list(dict.fromkeys(self.hardness_buckets))
        self.hardness_weights = {
            bucket: 1.0 / len(unique_buckets) for bucket in unique_buckets
        }

    @property
    def bucket_percentages(self) -> Dict[int, float]:
        """Alias of ``hardness_weights`` (the name the negative sampler reads)."""
        return self.hardness_weights

    @property
    def difficulty_buckets(self) -> List[int]:
        """Alias of ``hardness_buckets``."""
        return self.hardness_buckets

In [77]:
# -------------------------------------------------------------------------------------------------
# Index building functions
# -------------------------------------------------------------------------------------------------

# Inline notebook version of
#   build_triplet_indices(fields_path, triplets_path, curriculum, rng)
#       -> (anchor_to_positives, positive_to_negatives)
# which builds the anchor-to-positives and positive-to-negatives mappings.

# Inputs: NAICS descriptions and the training triplets (presumably a directory
# of parquet files, since it is passed straight to pl.scan_parquet -- confirm).
fields_path = '../data/naics_descriptions.parquet'
triplets_path = '../data/naics_training_pairs'

# Keep only positives whose code is 2 characters long and whose
# positive_distance lies in [0.5, 2.0].
curriculum = CurriculumConfig(
    positive_distances=(0.5, 2.0),
    positive_levels=[2],
)

# get_indices_codes yields four values; only the code list is used here.
_, codes, _, _ = get_indices_codes(fields_path)

In [79]:

# Filter the triplets down to the current curriculum: allowed hardness buckets,
# positive-code lengths (levels) and the positive-distance range.
logger.info('Building triplet indices...')

buckets = curriculum.hardness_buckets
levels = curriculum.positive_levels
dist_min, dist_max = curriculum.positive_distances

# Sample into a new name rather than overwriting `codes`, so re-running this
# cell always samples from the full code list. The seeded generator makes the
# sample reproducible; min() avoids ValueError when fewer codes are available
# than `n_positives`.
sampled_codes = random.Random(42).sample(
    codes, min(curriculum.n_positives, len(codes))
)

df = (
    pl
    .scan_parquet(
        triplets_path
    )
    .filter(
        pl.col('hardness')
          .is_in(buckets),
        # Optional: restrict positives to the sampled subset.
        #pl.col('positive_code')
        #  .is_in(sampled_codes),
        pl.col('positive_code')
          .str.len_chars()
          .is_in(levels),
        pl.col('positive_distance')
          .is_between(
              lower_bound=dist_min,
              upper_bound=dist_max
          )
    )
    .sort('anchor_idx', 'positive_idx', 'distance_diff')
    #.group_by('anchor_idx', 'positive_idx', 'distance_diff', maintain_order=True)
    #.agg(
    #    pl.col('negative_idx')
    #)
    .collect(engine='streaming')
)

# An empty frame makes every downstream index empty; surface it loudly.
if df.is_empty():
    logger.warning(
        'No triplets match the curriculum filters (levels=%s, distances=%s, buckets=%s)',
        levels, (dist_min, dist_max), buckets,
    )

df

anchor_idx,positive_idx,negative_idx,anchor_code,positive_code,negative_code,excluded,unrelated,positive_distance,negative_distance,distance_diff,hardness
u32,u32,u32,str,str,str,bool,bool,f32,f32,f32,i64


In [None]:

# Notebook body of build_triplet_indices: turn the filtered triplets in `df` into
#   anchor_to_positives:   anchor_code   -> [positive_code, ...]
#   positive_to_negatives: positive_code -> {hardness: [negative_code, ...]}

# `rng` was a parameter of the function form. Seed it here so the cell runs on
# a fresh kernel and the positive subsampling below is reproducible.
rng = random.Random(42)

# maintain_order=True keeps the (anchor_idx, positive_idx) sort applied when
# `df` was built. List order, and therefore any seeded sampling over it, is
# then stable across runs.
anchor_to_positives: Dict[str, List[str]] = {}
for row in (
    df
    .select('anchor_code', 'positive_code')
    .unique(maintain_order=True)
    .iter_rows(named=True)
):
    anchor_to_positives.setdefault(row['anchor_code'], []).append(row['positive_code'])

# Optionally cap positives per anchor. getattr: the config may not define
# `max_positives`, in which case no cap is applied (same as None / 0).
max_positives = getattr(curriculum, 'max_positives', None)
if max_positives:
    for anchor, positives in anchor_to_positives.items():
        if len(positives) > max_positives:
            anchor_to_positives[anchor] = rng.sample(positives, max_positives)

# A (positive, negative) pair can appear under several anchors. De-duplicate
# within each hardness bucket (keeping first-seen order); otherwise the same
# negative could be sampled twice for one example.
# NOTE(review): negatives are pooled across all anchors of a positive -- confirm
# a negative chosen for one anchor is never a true positive of another anchor.
positive_to_negatives: Dict[str, Dict[int, List[str]]] = {}
for row in (
    df
    .group_by('positive_code', 'hardness', maintain_order=True)
    .agg(pl.col('negative_code').unique(maintain_order=True))
    .iter_rows(named=True)
):
    positive_to_negatives.setdefault(row['positive_code'], {})[row['hardness']] = row['negative_code']

total_pairs = sum(len(v) for v in anchor_to_positives.values())
logger.info(f'Filtered to {len(anchor_to_positives)} anchors, {total_pairs} positive pairs')

#return anchor_to_positives, positive_to_negatives

In [None]:


# -------------------------------------------------------------------------------------------------
# Negative sampling function
# -------------------------------------------------------------------------------------------------

def sample_negatives(
    positive_code: str,
    positive_to_negatives_by_hardness: Dict[str, Dict[int, List[str]]],
    curriculum: CurriculumConfig,
    rng: random.Random
) -> List[str]:
    """Sample up to ``curriculum.k_negatives`` distinct negatives for a positive.

    How slots are allocated and filled:

    * Slots are split across hardness buckets in proportion to the bucket
      weights. ``bucket_percentages`` is used when the config defines it;
      otherwise ``hardness_weights``.
    * Slots lost to ``int()`` truncation go to the highest buckets first.
    * Buckets are filled from highest to lowest.
    * A bucket that cannot fill its quota, including an empty one, borrows
      from progressively lower buckets (``bucket - 1`` down to ``1``).
    * Any slots still open are topped up at random from all of this
      positive's negatives.

    Args:
        positive_code: Code whose negatives are sampled.
        positive_to_negatives_by_hardness: ``positive -> {hardness: [negatives]}``.
        curriculum: Supplies ``k_negatives`` and the bucket weights.
        rng: Source of all randomness; a seeded ``random.Random`` gives
            reproducible output.

    Returns:
        At most ``k_negatives`` unique negative codes. The list is empty when
        the positive has no negatives or ``k_negatives <= 0``.
    """
    k = curriculum.k_negatives
    if k <= 0:
        return []

    negatives_by_hardness = positive_to_negatives_by_hardness.get(positive_code, {})

    # Prefer `bucket_percentages` when the config defines it; otherwise use
    # `hardness_weights`, which CurriculumConfig sets in __post_init__.
    weights = getattr(curriculum, 'bucket_percentages', None)
    if weights is None:
        weights = curriculum.hardness_weights

    target_counts = {bucket: int(k * pct) for bucket, pct in weights.items()}
    hardest_first = sorted(target_counts, reverse=True)

    # int() truncation leaves slots unassigned; give one each to the hardest
    # buckets (max() guards against a negative slice when weights sum > 1).
    leftover = k - sum(target_counts.values())
    for bucket in hardest_first[:max(leftover, 0)]:
        target_counts[bucket] += 1

    sampled = []
    seen = set()

    def _unsampled(bucket):
        """Distinct negatives of `bucket` not yet sampled, in first-seen order."""
        return [
            neg for neg in dict.fromkeys(negatives_by_hardness.get(bucket, []))
            if neg not in seen
        ]

    def _take(pool, count):
        """Add up to `count` items of `pool` to the sample; return how many."""
        picked = rng.sample(pool, count) if len(pool) >= count else pool
        sampled.extend(picked)
        seen.update(picked)
        return len(picked)

    for bucket in hardest_first:
        target = target_counts[bucket]
        if target <= 0:
            continue
        shortage = target - _take(_unsampled(bucket), target)
        # Borrow from easier buckets whenever this one falls short, including
        # when it is completely empty. Assumes bucket ids are consecutive ints >= 1.
        fallback = bucket - 1
        while shortage > 0 and fallback > 0:
            shortage -= _take(_unsampled(fallback), shortage)
            fallback -= 1

    if len(sampled) < k:
        # Ordered de-dup (not a set) keeps the pool order independent of string
        # hash randomization, so seeded runs are reproducible.
        everything = dict.fromkeys(
            neg for negs in negatives_by_hardness.values() for neg in negs
        )
        pool = [neg for neg in everything if neg not in seen]
        if pool:
            _take(pool, min(k - len(sampled), len(pool)))

    return sampled[:k]


# -------------------------------------------------------------------------------------------------
# Streaming dataset generator
# -------------------------------------------------------------------------------------------------

def create_streaming_dataset(
    descriptions_path: str,
    triplets_path: str,
    token_cache: Dict[str, Dict[str, torch.Tensor]],
    curriculum: CurriculumConfig,
    seed: int = 42
) -> Generator[Dict, None, None]:
    """Lazily yield shuffled (anchor, positive, negatives) training examples.

    A single ``random.Random(seed)`` drives three things: index building,
    the shuffle of every (anchor, positive) pair, and the per-pair negative
    sampling. Pairs for which no negatives could be sampled are skipped.

    Each yielded dict holds the cached token entries (``'anchor'``,
    ``'positive'``, ``'negatives'``) plus the matching codes
    (``'anchor_code'``, ``'positive_code'``, ``'negative_codes'``). Every
    code must be a key of ``token_cache``, or a KeyError is raised.
    """
    prng = random.Random(seed)

    positives_of, negatives_of = build_triplet_indices(
        descriptions_path,
        triplets_path,
        curriculum,
        prng
    )

    # Flatten anchor -> [positives] into (anchor, positive) pairs, then shuffle.
    pair_list = [
        (anchor, positive)
        for anchor, positives in positives_of.items()
        for positive in positives
    ]
    prng.shuffle(pair_list)

    for anchor_code, positive_code in pair_list:
        negative_codes = sample_negatives(positive_code, negatives_of, curriculum, prng)
        if not negative_codes:
            continue

        yield {
            'anchor': token_cache[anchor_code],
            'positive': token_cache[positive_code],
            'negatives': [token_cache[code] for code in negative_codes],
            'anchor_code': anchor_code,
            'positive_code': positive_code,
            'negative_codes': negative_codes
        }