In [11]:
from typing import Dict, Optional

import polars as pl
import pytorch_lightning as pyl

In [13]:
# Location of the parquet file that is later handed to
# NAICSContrastiveModel._load_ground_truth_distances.
distance_matrix_path = './data/naics_distance_matrix.parquet'

In [28]:
class NAICSContrastiveModel(pyl.LightningModule):

    '''Lightning module holding NAICS evaluation state.

    Attributes:
        ground_truth_distances: Result of ``DataFrame.to_torch()`` on the
            distance-matrix parquet file, or ``None`` if nothing is loaded.
        code_to_idx: Mapping from code string to integer index, parsed from
            the parquet column names, or ``None`` if nothing is loaded.
        validation_embeddings: Initially empty dict.
        validation_codes: Initially empty list.
        code_to_pseudo_label: Initially empty ``Dict[str, int]``.
    '''

    def __init__(
        self,
        distance_matrix_path: Optional[str] = None
    ):

        '''Initialise the module and optionally load the distance matrix.

        Args:
            distance_matrix_path: Path to a parquet file. When falsy
                (``None`` or empty), no ground-truth data is loaded.
        '''

        super().__init__()

        # Both stay None until a distance matrix has been loaded successfully.
        self.ground_truth_distances = None
        self.code_to_idx = None
        if distance_matrix_path:
            self._load_ground_truth_distances(distance_matrix_path)

        self.validation_embeddings = {}
        self.validation_codes = []

        self.code_to_pseudo_label: Dict[str, int] = {}

    def _load_ground_truth_distances(self, distance_matrix_path: str) -> None:

        '''Load ground truth NAICS tree distances for evaluation.

        Reads the parquet file, converts the frame with ``to_torch()`` and
        builds ``code_to_idx`` from the column names. Each column name is
        expected to look like ``idx_<int>-code_<code>``.

        Loading is best-effort: any exception is reported via ``print`` and
        not re-raised. On failure both ``self.ground_truth_distances`` and
        ``self.code_to_idx`` are reset to ``None``, so a failed reload never
        leaves data from an earlier successful load behind.

        Args:
            distance_matrix_path: Path to the parquet file to read.
        '''

        try:
            print(
                f'Loading ground truth distances\n'
                f'  • from: {distance_matrix_path}')

            df = pl.read_parquet(distance_matrix_path)
            n_codes = df.height

            ground_truth_distances = df.to_torch()
            print(f'  • distance matrix: [{n_codes}, {n_codes}]\n')

            code_to_idx = {}
            for col in df.columns:
                # Split on the first hyphen only, so a code that itself
                # contains '-' (e.g. a range such as '31-33') stays intact
                # instead of breaking the two-value unpacking.
                idx_col, code_col = col.split('-', 1)
                idx = int(idx_col.replace('idx_', ''))
                code = code_col.replace('code_', '')
                code_to_idx[code] = idx

        except Exception as e:
            print(f'Could not load ground truth distances: {e}')
            # Reset the instance attributes (not throwaway locals) so callers
            # can rely on None meaning "no ground truth available".
            self.ground_truth_distances = None
            self.code_to_idx = None
            return

        # Publish both attributes together, only after everything parsed.
        self.ground_truth_distances = ground_truth_distances
        self.code_to_idx = code_to_idx

In [29]:
# Build the model without a distance matrix path, so __init__ skips loading
# and both ground-truth attributes start out as None.
model = NAICSContrastiveModel()

In [30]:
# Trigger the load explicitly; progress (or a failure message) is printed.
model._load_ground_truth_distances(distance_matrix_path)

Loading ground truth distances
  • from: ./data/naics_distance_matrix.parquet
  • distance matrix: [2125, 2125]



In [32]:
# Inspect the code -> index mapping parsed from the parquet column names.
# NOTE(review): this displays the whole mapping; consider showing a slice.
model.code_to_idx


{'11': 0,
 '111': 1,
 '1111': 2,
 '11111': 3,
 '111110': 4,
 '11112': 5,
 '111120': 6,
 '11113': 7,
 '111130': 8,
 '11114': 9,
 '111140': 10,
 '11115': 11,
 '111150': 12,
 '11116': 13,
 '111160': 14,
 '11119': 15,
 '111191': 16,
 '111199': 17,
 '1112': 18,
 '11121': 19,
 '111211': 20,
 '111219': 21,
 '1113': 22,
 '11131': 23,
 '111310': 24,
 '11132': 25,
 '111320': 26,
 '11133': 27,
 '111331': 28,
 '111332': 29,
 '111333': 30,
 '111334': 31,
 '111335': 32,
 '111336': 33,
 '111339': 34,
 '1114': 35,
 '11141': 36,
 '111411': 37,
 '111419': 38,
 '11142': 39,
 '111421': 40,
 '111422': 41,
 '1119': 42,
 '11191': 43,
 '111910': 44,
 '11192': 45,
 '111920': 46,
 '11193': 47,
 '111930': 48,
 '11194': 49,
 '111940': 50,
 '11199': 51,
 '111991': 52,
 '111992': 53,
 '111998': 54,
 '112': 55,
 '1121': 56,
 '11211': 57,
 '112111': 58,
 '112112': 59,
 '11212': 60,
 '112120': 61,
 '11213': 62,
 '112130': 63,
 '1122': 64,
 '11221': 65,
 '112210': 66,
 '1123': 67,
 '11231': 68,
 '112310': 69,
 '11232':

2124