In [1]:
from typing import List, Union, Tuple
import torch
import pandas as pd
import numpy as np
from simcse import SimCSE
from transformers import BertForSequenceClassification, AutoModelForSeq2SeqLM, BertTokenizer, AutoTokenizer
from utils import taxo_utils
from utils.taxo_utils import Taxonomy
from main.icon import ICON





In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Checkpoint directories, named once so each model/tokenizer pair shares a path.
RET_MODEL_DIR = '/data2T/jingchuan/tuned/ret/entity_type_tuned_sota/'
GEN_MODEL_DIR = '/data2T/jingchuan/tuned/gen/flan-t5-sota/'
SUB_MODEL_DIR = '/data2T/jingchuan/tuned/sub/bertsubs-sota/'

# SimCSE encoder wrapped by RET_model below.
ret_model = SimCSE(RET_MODEL_DIR, device=device)

# Seq2seq model and its tokenizer, wrapped by GEN_model below.
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_DIR).to(device)
gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_DIR)

# BERT sequence classifier and its tokenizer, wrapped by SUB_model below.
sub_model = BertForSequenceClassification.from_pretrained(SUB_MODEL_DIR).to(device)
sub_tokenizer = BertTokenizer.from_pretrained(SUB_MODEL_DIR, model_max_length=128)

In [3]:
# Load the taxonomy from its JSON dump.
taxo = taxo_utils.from_json('./data/raw/ebay_us.json')

# One row per taxonomy node as (ID, Label). The row with index label 0 is
# dropped (presumably the root node -- confirm) and the remaining rows are
# renumbered from 0 so that row positions are contiguous.
df = (
    pd.DataFrame(taxo.nodes(data='label'), columns=['ID', 'Label'])
    .drop(0)
    .reset_index(drop=True)
)

In [4]:
# idx_dict: row position in df -> value of the 'ID' column on that row.
idx_dict = dict(zip(df.index, df['ID']))
# id_dict: the inverse lookup ('ID' value -> row position in df).
id_dict = {node_id: position for position, node_id in idx_dict.items()}


def index_to_ID(x):
    """Return the 'ID' stored on row ``x`` of ``df`` (KeyError if unknown)."""
    return idx_dict[x]


def ID_to_index(id):
    """Return the row position in ``df`` whose 'ID' equals ``id`` (KeyError if unknown)."""
    return id_dict[id]

In [5]:
# Index every label in df so that search hits can be mapped back to df rows.
ret_model.build_index(df['Label'].tolist())


def RET_model(taxo:Taxonomy,seed:Union[int,str],k=10):
    """Return the IDs of the indexed labels retrieved for ``seed``.

    Args:
        taxo: Taxonomy used to resolve ``seed`` to a label when it is an int.
        seed: Either a query string, or an int that is passed to
            ``taxo.get_label`` to obtain the query string.
        k: Forwarded to ``ret_model.search`` as ``top_k``.

    Returns:
        A list with one ID per search hit, in the order the hits were returned.
        Each hit must unpack into exactly three items; only the first is used,
        as a row position translated through ``index_to_ID``.
    """
    query = taxo.get_label(seed) if isinstance(seed,int) else seed
    hits = ret_model.search(query,top_k=k)
    return [index_to_ID(row_position) for row_position,_,_ in hits]

In [6]:
def GEN_model(labels,prefix='summarize: '):
    """Generate one text from a collection of labels.

    The prompt is ``prefix`` followed by the labels joined with ``'; '``.
    It is tokenized with ``gen_tokenizer``, passed to ``gen_model.generate``
    (``max_length=64``), and the first returned sequence is decoded with
    special tokens removed.

    Args:
        labels: Iterable of label strings.
        prefix: Text placed in front of the joined labels.

    Returns:
        The decoded output string.
    """
    # str.join adds no trailing separator, so nothing has to be sliced off
    # afterwards; the previous ``corpus[:-2]`` cut into the prefix itself
    # whenever ``labels`` was empty.
    corpus = prefix + '; '.join(labels)
    inputs = gen_tokenizer(corpus,return_tensors='pt').to(device)['input_ids']
    outputs = gen_model.generate(inputs,max_length=64)[0]
    return gen_tokenizer.decode(outputs.cpu().numpy(),skip_special_tokens=True)

In [7]:
def SUB_model(sub: Union[str, List[str]], sup: Union[str, List[str]], batch_size :int=256):
    """Score (sub, sup) text pairs with the BERT sequence classifier.

    Pairs are tokenized together with ``sub_tokenizer`` and run through
    ``sub_model`` in consecutive batches of at most ``batch_size`` pairs.

    Args:
        sub: A single string, or a list of strings (first element of each pair).
        sup: A single string, or a list of strings (second element of each
            pair). When ``sub`` is a string, both are wrapped into
            one-element lists.
        batch_size: Maximum number of pairs per forward pass; must be >= 1.

    Returns:
        A 1-D numpy array with one value per pair: the softmax probability
        of class index 1. Empty input yields an empty float32 array.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # The recursive implementation never terminated for such values.
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    if isinstance(sub, str):
        sub, sup = [sub], [sup]

    total = len(sub)
    if total == 0:
        # Nothing to score; skip the tokenizer and the model entirely.
        return np.empty(0, dtype=np.float32)

    # NOTE(review): no truncation is requested from the tokenizer, so overly
    # long pairs are presumably passed through as-is -- confirm whether
    # truncation=True is wanted here.
    chunks = []
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        # Same batch boundaries as before, walked iteratively rather than by
        # recursing on an ever-shrinking copy of the tail.
        for start in range(0, total, batch_size):
            stop = start + batch_size
            inputs = sub_tokenizer(sub[start:stop],sup[start:stop],padding=True,return_tensors='pt').to(device)
            logits = sub_model(**inputs).logits.cpu()
            chunks.append(torch.softmax(logits,1)[:,1].numpy())
    return np.concatenate(chunks)

In [8]:
# Configuration handed to ICON: the loaded taxonomy, the three model wrappers
# defined above, and the settings for this run.
kwargs = dict(
    data=taxo,
    ret_model=RET_model,
    gen_model=GEN_model,
    sub_model=SUB_model,
    mode='manual',
    input_concepts=[
        'plastic round tubes',
        'pipe wrenches',
        'mixed lots',
        'mountain lions',
        'opticals',
        'port expansion cards',
        'eagles',
        'drawer slides',
        'steel drums',
        'softballs',
    ],
    restrict_combinations=False,
    retrieve_size=5,
    threshold=0.9,
    log=True,
)

newobj = ICON(**kwargs)

Loading lexical cache:   0%|          | 0/20334 [00:00<?, ?it/s]

In [9]:
# kwargs = {'data': taxo,
#         'ret_model': RET_model,
#         'gen_model': GEN_model,
#         'sub_model': SUB_model,
#         'mode': 'auto',
#         'semiauto_seeds': [175781],
#         'restrict_combinations': True,
#         'retrieve_size': 2,
#         'threshold': 0.9,
#         'log': 1}

# newobj = ICON(**kwargs)

In [10]:
# Run ICON with the configuration it was constructed with.
# NOTE(review): the output recorded for this cell ends with
# "TypeError: forward() got an unexpected keyword argument 'batch_size'".
# The call that passes batch_size to a forward() is not visible in this
# notebook -- TODO trace it before relying on this run.
newobj.run()

Loaded Taxonomy with 20334 nodes and 20333 edges. Commencing enrichment


  0%|          | 0/1 [00:00<?, ?it/s]

		Input: [36m[1mplastic round tubes[0m
			Searching on a domain of 20334 classes


TypeError: forward() got an unexpected keyword argument 'batch_size'