In [1]:
from typing import List, Union
import torch
import pandas as pd
import numpy as np
from simcse import SimCSE
from transformers import BertForSequenceClassification, AutoModelForSeq2SeqLM, BertTokenizer, AutoTokenizer
from utils import taxo_utils
from utils.taxo_utils import Taxonomy
from main.icon import ICON





In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fine-tuned checkpoints for the three ICON components: retriever (SimCSE),
# generator (Flan-T5) and subsumption classifier (BERT). Each checkpoint path is
# defined once so a tokenizer always loads from the same directory as its model.
# NOTE(review): machine-specific absolute path — change TUNED_DIR to run elsewhere.
TUNED_DIR = '/data2T/jingchuan/tuned'
RET_CKPT = f'{TUNED_DIR}/ret/entity_type_tuned_sota/'
GEN_CKPT = f'{TUNED_DIR}/gen/flan-t5-sota/'
SUB_CKPT = f'{TUNED_DIR}/sub/bertsubs-sota/'

ret_model = SimCSE(RET_CKPT, device=device)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_CKPT).to(device)
gen_tokenizer = AutoTokenizer.from_pretrained(GEN_CKPT)
sub_model = BertForSequenceClassification.from_pretrained(SUB_CKPT).to(device)
# Inputs to the subsumption classifier are (sub, sup) pairs capped at 128 tokens.
sub_tokenizer = BertTokenizer.from_pretrained(SUB_CKPT, model_max_length=128)

In [3]:
taxo = taxo_utils.from_json('./data/raw/ebay_us.json')
# One row per taxonomy node: its ID and its 'label' attribute. The row with
# index label 0 is dropped (presumably the root node — TODO confirm), and the
# index is then renumbered 0..n-1 so positions line up with the retriever index.
df = (
    pd.DataFrame(taxo.nodes(data='label'), columns=['ID', 'Label'])
    .drop(0)
    .reset_index(drop=True)
)

In [4]:
# Bidirectional lookup between a row's position in `df` and its taxonomy node ID.
# Built with zip over the index and ID column instead of df.iterrows(), which
# materializes a Series per row and is slow on a taxonomy of this size.
idx_dict = dict(zip(df.index, df['ID']))
id_dict = {node_id: i for i, node_id in idx_dict.items()}


def index_to_ID(x):
    """Return the taxonomy node ID stored at row position ``x`` of ``df``.

    Raises KeyError if ``x`` is not a row index of ``df``.
    """
    return idx_dict[x]


def ID_to_index(id):
    """Return the row position in ``df`` of taxonomy node ``id``.

    If an ID appeared more than once, the last row wins. Raises KeyError for
    unknown IDs.
    """
    return id_dict[id]

In [12]:
# Index every node label so search results are positions into df['Label'].
ret_model.build_index(list(df['Label']))


def RET_model(taxo: Taxonomy, query: str, k=10) -> list:
    """Return the IDs of the ``k`` taxonomy nodes whose labels best match ``query``.

    Parameters
    ----------
    taxo : Taxonomy
        Not used here: retrieval runs against the SimCSE index built over
        ``df['Label']`` above. Presumably kept so the signature matches what
        ICON passes to its ``ret_model`` callback — TODO confirm in main.icon.
    query : str
        Text to search for.
    k : int
        Number of nearest labels to retrieve.

    Returns
    -------
    list
        Taxonomy node IDs, in the order returned by ``ret_model.search``.
    """
    topk = ret_model.search(query, top_k=k)
    # NOTE(review): assumes this SimCSE build yields (corpus_index, ..., ...)
    # triples; only the first element (a row position in df) is used.
    return [index_to_ID(i) for i, _, _ in topk]

In [6]:
def GEN_model(labels, prefix='summarize: ', max_length=64):
    """Generate one text output from ``labels`` with the fine-tuned seq2seq model.

    The prompt is ``prefix`` followed by the labels joined with ``'; '``.

    Parameters
    ----------
    labels : iterable of str
        Concept labels to put in the prompt. May be empty.
    prefix : str
        Task prefix prepended to the prompt.
    max_length : int
        Upper bound on the generated sequence length passed to ``generate``.

    Returns
    -------
    str
        The first generated sequence, decoded without special tokens.
    """
    # join() is correct for an empty `labels`; appending '; ' per label and then
    # slicing off the last two characters would instead truncate `prefix`.
    corpus = prefix + '; '.join(labels)
    inputs = gen_tokenizer(corpus, return_tensors='pt').to(device)['input_ids']
    outputs = gen_model.generate(inputs, max_length=max_length)[0]
    return gen_tokenizer.decode(outputs.cpu().numpy(), skip_special_tokens=True)

In [7]:
def SUB_model(sub: Union[str, List[str]], sup: Union[str, List[str]], batch_size: int = 256):
    """Score (sub, sup) concept pairs with the fine-tuned BERT pair classifier.

    Parameters
    ----------
    sub, sup : str or list of str
        Candidate sub-concept(s) and super-concept(s), paired element-wise.
        If ``sub`` is a single string, both arguments are wrapped as
        one-element lists.
    batch_size : int
        Maximum number of pairs per forward pass.

    Returns
    -------
    np.ndarray
        One score per pair: the softmax probability of class index 1 of
        ``sub_model`` (presumably "sub is subsumed by sup" — TODO confirm the
        label mapping). Empty input yields an empty array.

    Raises
    ------
    ValueError
        If ``sub`` and ``sup`` differ in length or ``batch_size`` < 1.
    """
    if isinstance(sub, str):
        sub, sup = [sub], [sup]
    if len(sub) != len(sup):
        raise ValueError(f'sub and sup must have equal length, got {len(sub)} and {len(sup)}')
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    chunks = []
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        # Iterate over fixed-size chunks instead of recursing on head/tail
        # slices (which copied the remaining lists at every level).
        for start in range(0, len(sub), batch_size):
            stop = start + batch_size
            inputs = sub_tokenizer(sub[start:stop], sup[start:stop], padding=True, return_tensors='pt').to(device)
            probs = torch.softmax(sub_model(**inputs).logits.detach().cpu(), 1)[:, 1]
            chunks.append(probs.numpy())

    if not chunks:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(chunks)

In [8]:
# ICON configuration for a manual run over a fixed list of new concepts.
# NOTE(review): presumably `threshold` is the minimum score for accepting a
# relation (every score in the output below is >= 0.95) — confirm in main.icon.
kwargs = dict(
    data=taxo,
    ret_model=RET_model,
    gen_model=GEN_model,
    sub_model=SUB_model,
    mode='manual',
    auto_bases=True,
    input_concepts=[
        'plastic round tubes', 'pipe wrenches', 'mixed lots', 'mountain lions',
        'opticals', 'port expansion cards', 'eagles', 'drawer slides',
        'steel drums', 'softballs',
    ],
    restrict_combinations=False,
    retrieve_size=5,
    threshold=0.95,
    do_update=False,
)

newobj = ICON(**kwargs)

Loading lexical cache:   0%|          | 0/20334 [00:00<?, ?it/s]

In [9]:
# kwargs = {'data': taxo,
#         'ret_model': RET_model,
#         'gen_model': GEN_model,
#         'sub_model': SUB_model,
#         'mode': 'auto',
#         'semiauto_seeds': [175781],
#         'restrict_combinations': True,
#         'retrieve_size': 2,
#         'threshold': 0.9,
#         'log': 1}

# newobj = icon.ICON(**kwargs)

In [10]:
# Run ICON over the configured input concepts; the returned predictions are
# displayed in the next cell.
outputs = newobj.run()

Loaded Taxonomy with 20334 nodes and 20333 edges. Commencing enrichment


  0%|          | 0/1 [00:00<?, ?it/s]

Enrichment complete. Begin post-processing with transitive reduction
Return ICON predictions


In [11]:
# Keyed by input concept. Under each: 'equivalent' maps node ID -> a pair of
# scores, while 'superclass' and 'subclass' map node ID -> a single score.
outputs

{'plastic round tubes': {'equivalent': {258242: (0.9989275336265564,
    0.999014139175415)},
  'superclass': {160704: 0.9884874820709229,
   257824: 0.9656429886817932,
   11874: 0.9817126393318176,
   14308: 0.9655066728591919,
   26221: 0.992567777633667,
   20625: 0.9542515873908997,
   3187: 0.9875780940055847,
   160667: 0.9778825640678406},
  'subclass': {}},
 'pipe wrenches': {'equivalent': {20772: (1.0, 1.0)},
  'superclass': {183978: 0.9952627420425415,
   184042: 0.9966223239898682,
   46576: 0.9830930233001709,
   42622: 0.9626042246818542},
  'subclass': {}},
 'mixed lots': {'equivalent': {},
  'superclass': {1: 0.9882097244262695, 11700: 0.9958816766738892},
  'subclass': {32772: 0.9848905205726624,
   262022: 0.9989845156669617,
   527: 0.9930685758590698,
   529: 0.9889154434204102,
   260629: 0.9916552305221558,
   173690: 0.9922962784767151,
   3356: 0.9943913817405701,
   183455: 0.9864705204963684,
   260000: 0.9683923721313477,
   45089: 0.9923934936523438,
   1756