In [133]:
import os
import sys
import re
from typing import List, Union, Tuple, Dict, Callable, Any
sys.path.append(os.getcwd() + '/..')
sys.path.append(os.getcwd() + '/../..')
import torch
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from simcse import SimCSE
from transformers import BertForSequenceClassification, AutoModelForSeq2SeqLM, BertTokenizer, AutoTokenizer
from utils.taxo_utils import taxonomy
from utils import icon
from dataclasses import dataclass, fields, field

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ret_model = SimCSE('/data2T/jingchuan/tuned/ret/entity_type_tuned_sota/',device=device)
gen_model = AutoModelForSeq2SeqLM.from_pretrained('/data2T/jingchuan/tuned/gen/flan-t5-sota/').to(device)
gen_tokenizer = AutoTokenizer.from_pretrained('/data2T/jingchuan/tuned/gen/flan-t5-sota/')
sub_model = BertForSequenceClassification.from_pretrained('/data2T/jingchuan/tuned/sub/bertsubs-sota/').to(device)
sub_tokenizer = BertTokenizer.from_pretrained('/data2T/jingchuan/tuned/sub/bertsubs-sota/',model_max_length=128)

In [44]:
taxo = taxonomy.from_json('./../data/raw/ebay_us.json')
df = pd.DataFrame(taxo.nodes(data='label'),columns=['ID','Label']).drop(0).reset_index(drop=True)

In [46]:
id_dict = {}
idx_dict = {}
for i,row in df.iterrows():
    idx_dict[i] = row['ID']
    id_dict[row['ID']] = i
def index_to_ID(x):
    return idx_dict[x]
def ID_to_index(id):
    return id_dict[id]

In [50]:
ret_model.build_index(list(df['Label']))
def RET_model(taxo:taxonomy,seed:Union[int,str],k=10):
    if isinstance(seed,int):
        seed = taxo.get_label(seed)
    topk = ret_model.search(seed,top_k=k)
    return [index_to_ID(i) for i,_,_ in topk]

In [3]:
def GEN_model(labels,prefix='summarize: '):
    corpus = prefix
    for l in labels:
        corpus += l + '; '
    corpus = corpus[:-2]
    inputs = gen_tokenizer(corpus,return_tensors='pt').to(device)['input_ids']
    outputs = gen_model.generate(inputs,max_length=64)[0]
    decoded = gen_tokenizer.decode(outputs.cpu().numpy(),skip_special_tokens=True)
    return decoded

In [4]:
def SUB_model(classpairs:Union[Tuple[str, str], Tuple[List[str], List[str]]],batch_size:int=256):
    sub, sup = classpairs
    if isinstance(sub,str):
        sub, sup = [sub], [sup]
    if len(sub) <= batch_size:
        inputs = sub_tokenizer(sub,sup,padding=True,return_tensors='pt').to(device)
        predictions = torch.softmax(sub_model(**inputs).logits.detach().cpu(),1)[:,1].numpy()
    else:
        head = (sub[:batch_size],sup[:batch_size])
        tail = (sub[batch_size:],sup[batch_size:])
        predictions = np.concatenate((sub_model(head),sub_model(tail)))
    return predictions

In [9]:
# kwargs = {'knn_model': knn_model,
#           'gen_model': gen_model,
#           'sub_model': sub_model,
#           'mode': 'manual',
#           'manual_inputs': ['plastic round tubes', 'pipe wrenches', 'mixed lots', 'mountain lions', 'opticals', 'port expansion cards', 'eagles', 'drawer slides', 'steel drums', 'softballs'],
#           'restrict_combinations': False,
#           'retrieve_size': 5,
#           'subs_threshold': 0.95,
#           'log': True}

# enriched_taxo, results = enrich.main(ontology, **kwargs)

In [63]:
kwargs = {'knn_model': RET_model,
        'gen_model': GEN_model,
        'sub_model': SUB_model,
        'mode': 'semiauto',
        'semiauto_seeds': [175781],
        'restrict_combinations': False,
        'retrieve_size': 10,
        'subs_threshold': 0.9,
        'log': True}

enriched_taxo = icon.main(taxo, **kwargs)

Loading:   0%|          | 0/20334 [00:00<?, ?it/s]

Loaded taxonomy with 20334 nodes and 20339 edges. Commencing enrichment


Semiauto mode:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

	Cycle [30m[1m1[0m: Seed 175781 ([34m[1mMen's Vintage T-Shirts[0m)
		Retrieved [30m[1m10[0m classes
			[34m[1mMen's Vintage T-Shirts[0m
			[34m[1mMen's T-Shirts[0m
			[34m[1mMen's Equestrian Shirts[0m
			[34m[1mWomen's Western Show Shirts[0m
			[34m[1mKids' Dance Tops & Shirts[0m
			[34m[1mAdult Dance Tops & Shirts[0m
			[34m[1mVintage Sports Shirts[0m
			[34m[1mUnisex Kids' Tops & T-Shirts[0m
			[34m[1mMen's Western Show Shirts[0m
			[34m[1mBoys' Tops, Shirts & T-Shirts[0m
		Iteration [30m[1m1.1[0m: Combination ([34m[1mMen's Vintage T-Shirts[0m, [34m[1mMen's T-Shirts[0m)
		Generated common parent label: [36m[1mMen's Clothing[0m
			Searching on a domain of 5 classes
			Search complete. [33m[1mMapped[0m to a known class by NIL entity resolver
			Declared [35m[1mequivalence[0m between [33m[1mMen's Clothing[0m (1059) and [36m[1mMen's Clothing[0m
		Iteration [30m[1m1.2[0m: Combination ([34m[1mMen's Vintage T-Shirts[0m, [

In [72]:
class myclass:
    
    def __init__(self,
        ret_model=None,
        gen_model=None,
        sub_model=None,
        taxon_cache:Dict={},
        sub_score_cache:Dict={},
        mode:str='auto',
        max_outer_loop:int=None,
        semiauto_seeds:List[Union[int, str]]=[],
        manual_inputs:List[str]=[],
        inputs_bases:List[List[Union[int, str]]]=None,
        rand_seed=20230103,
        retrieve_size:int=10,
        restrict_combinations=True,
        ignore_label:List[str]=['','All categories','Root Concept','Thing','Allcats','Everything','root'],
        cached_subs_scores:Dict={},
        subgraph_crop:bool=True,
        subgraph_force:List[List[str]]=[['original']],
        subgraph_strict:bool=True,
        subs_threshold:float=0.5,
        search_tolerance:int=0,
        force_known_subsumptions:bool=False,
        force_prune_branches:bool=False,
        eqv_score_func:Callable[[Tuple[float, float]], float]=lambda x: x[0]*x[1],
        transitive_reduction:bool=True,
        log:Union[bool, int, List[str]]=False,
        ):
        saved_args = locals()
        print(saved_args)

In [137]:
@dataclass
class tree_config:
    
    def arglist(self):
        args = []
        for f in fields(self):
            if isinstance(getattr(self,f.name),tree_config):
                args += arglist(getattr(self,f.name))
            else:
                args.append(f.name)
        return args

@dataclass
class icon_models(tree_config):
    ret_model:Any
    gen_model:Any
    sub_model:Any

@dataclass
class icon_caches(tree_config):
    lexical_cache:Dict=field(default_factory=dict)
    sub_score_cache:Dict=field(default_factory=dict)

@dataclass
class icon_auto_config(tree_config):
    max_outer_loop:int=None

@dataclass
class icon_semiauto_config(tree_config):
    semiauto_seeds:List[Union[int, str]]=field(default_factory=list)

@dataclass
class icon_manual_config(tree_config):
    input_concepts:List[str]=field(default_factory=list)
    inputs_concept_bases:List[List[Union[int, str]]]=None

@dataclass
class icon_ret_config(tree_config):
    retrieve_size:int=10
    restrict_combinations:bool=True
    
@dataclass
class icon_gen_config(tree_config):
    ignore_label:List[str]=field(default_factory=list)
    filter_subset:bool=True

@dataclass
class icon_subgraph_config(tree_config):
    subgraph_crop:bool=True
    subgraph_force:List[List[str]]=field(default_factory=list)
    subgraph_strict:bool=True

@dataclass
class icon_search_config(tree_config):
    threshold:float=0.5
    tolerance:int=0
    force_base_subsumptions:bool=False
    force_prune:bool=False
    
@dataclass
class icon_sub_config(tree_config):
    subgraph:icon_subgraph_config=icon_subgraph_config()
    search:icon_search_config=icon_search_config()

@dataclass
class icon_update_config(tree_config):
    eqv_score_func:Callable[[Tuple[float, float]], float]=lambda x: x[0]*x[1]
    do_lexical_check:bool=True

    
@dataclass
class icon_config(tree_config):
    mode:str='auto'
    rand_seed:Any=114514
    auto_config:icon_auto_config=icon_auto_config()
    semiauto_config:icon_semiauto_config=icon_semiauto_config()
    manual_config:icon_manual_config=icon_manual_config()
    ret_config:icon_ret_config=icon_ret_config()
    gen_config:icon_gen_config=icon_gen_config()
    sub_config:icon_sub_config=icon_sub_config()
    update_config:icon_update_config=icon_update_config()
    transitive_reduction:bool=True
    log:Union[bool, int, List[str]]=False

TypeError: non-default argument 'auto_config' follows default argument