# ICON demonstration

This notebook is a guided example of using ICON to enrich the Google Product Type Taxonomy.
Before running this notebook, make sure that you have read README.md of the ICON repository.

## Preparation

**Replace SimCSE script**: For the purpose of this demonstration, please temporarily replace the `tool.py` in your SimCSE directory with `/utils/replace_simcse/tool.py`. The reasons are explained [here](/README.md#replace-simcse-script).

In [1]:
# ! pip show simcse | grep -P "Location: .*$" # Locate your SimCSE package. 
# Copy the directory given by the above command's outputs, which will look like:
    # Location: SIMCSE_DIR
# Now uncomment the following line and replace SIMCSE_DIR with what you have copied
# ! cp utils/replace_simcse/tool.py SIMCSE_DIR/simcse/tool.py

## Importing relevant packages

A complete list of dependencies is available in the [README](/README.md#dependencies).

In [2]:
import os
from typing import List, Union, Hashable
import torch
import pandas as pd
import numpy as np
import faiss
from simcse import SimCSE
from ellement.transformers import AutoModel, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, AutoTokenizer
from utils import taxo_utils
from utils.taxo_utils import Taxonomy
from main.icon import ICON





## Reading data

The taxonomy dataset will be loaded as a `utils.taxo_utils.Taxonomy` object. For I/O format details, please refer to the corresponding section in [README](README.md#file-io-format).

In [3]:
taxo = taxo_utils.from_json('./data/raw/ebay_us.json')

## Loading the models

ICON requires three sub-models: `ret_model`, `gen_model` and `sub_model`.

**If you don't have these models**: The scripts in `/experiments/data_wrangling/` and notebooks in `/experiments/model_training/` will offer a pipeline for preparing the training data and fine-tuning pre-trained language models.

**Models for eBay**: Models fine-tuned on eBay data with the pipeline described below are available at RNO HDFS: `/user/jingcshi/ICON_models/`.

Our choices of `ret_model`, `gen_model` and `sub_model` each require a tokenizer. The tokenizer for `ret_model` is loaded automatically when SimCSE is initialised.

Notice that ICON uses its sub-models as callable functions and doesn't care how the models themselves are implemented. Therefore, we need to wrap these models in callable interfaces. This will be demonstrated in a [cell below](#wrapping-the-models-as-callables).

In [4]:
ret_model_path = '/data/ebay-slc-a100/data/jingcshi/ICON_models/ret/vector-prime'
gen_model_path = '/data/ebay-slc-a100/data/jingcshi/ICON_models/gen/flan-t5-xl-sota/'
sub_model_path = '/data/ebay-slc-a100/data/jingcshi/ICON_models/sub/ebert2-sota/'

## Wrapping the models as callables

Here we create a class for each sub-model with a `__call__` method so that ICON can directly call them.

Each model has its expected inputs and outputs:

- `RET_model`: Takes in a list of concepts, a query string (the text for which we want to find the most similar concepts), and an integer `k`, the number of concepts to retrieve. Returns a list of concept IDs in the taxonomy.

- `GEN_model`: Takes in a list of strings (concept labels which the model should summarise). Returns a single string (label for the union concept).

- `SUB_model`: Takes in two lists of strings (the labels for `sub` and `sup` respectively). Returns a 1D array of prediction scores indicating how likely each concept in `sup` is to subsume the corresponding concept in `sub`.
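
To make these contracts concrete, here is a minimal sketch that uses toy stand-ins rather than the fine-tuned models. The names `LABELS`, `toy_ret`, `toy_gen` and `toy_sub`, and the token-overlap heuristics behind them, are illustrative assumptions and not part of ICON; they only mimic the input and output shapes described above.

```python
from typing import Dict, Hashable, List

import numpy as np

# Hypothetical toy taxonomy (concept ID -> label); not the real dataset.
LABELS: Dict[Hashable, str] = {
    1: "Coffee Roasters",
    2: "Commercial Coffee Roasters",
    3: "Ground Coffee",
    4: "Tea Kettles",
}


def _tokens(text: str) -> set:
    return set(text.lower().split())


def _overlap(a: str, b: str) -> float:
    """Jaccard overlap of tokens: a crude stand-in for embedding similarity."""
    ta, tb = _tokens(a), _tokens(b)
    return len(ta & tb) / len(ta | tb) if (ta | tb) else 0.0


def toy_ret(concepts: List[Hashable], query: str, k: int = 10) -> List[Hashable]:
    """RET contract: concept IDs, a query string and k -> up to k concept IDs."""
    ranked = sorted(concepts, key=lambda c: _overlap(LABELS[c], query), reverse=True)
    return ranked[:k]


def toy_gen(labels: List[str]) -> str:
    """GEN contract: list of labels -> one label (here: the tokens they share)."""
    shared = set.intersection(*(_tokens(label) for label in labels))
    return " ".join(w for w in labels[0].split() if w.lower() in shared)


def toy_sub(sub: List[str], sup: List[str]) -> np.ndarray:
    """SUB contract: two equal-length label lists -> 1D array of scores in [0, 1]."""
    return np.array(
        [len(_tokens(p) & _tokens(s)) / max(len(_tokens(p)), 1) for s, p in zip(sub, sup)]
    )


retrieved = toy_ret([1, 2, 3, 4], "Coffee Roasters", k=2)
union_label = toy_gen(["Commercial Coffee Roasters", "Coffee Roasters"])
scores = toy_sub(["Commercial Coffee Roasters", "Ground Coffee"],
                 ["Coffee Roasters", "Coffee Roasters"])
```

Any object that honours these three signatures can, in principle, be swapped in for the corresponding sub-model, which is handy for dry runs on a machine without a GPU.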

In [5]:
class RET_model:

    def __init__(self, model_path, **kwargs) -> None:
        self.model = SimCSE(model_path, **kwargs)
        self.idx_dict = {}
    
    def build_index(self, concepts: List[Hashable], **kwargs) -> None:
        self.model.build_index(taxo.get_label(concepts), **kwargs)
        self.idx_dict = {i: c for i, c in enumerate(concepts)}

    def __call__(self, concepts: List[Hashable], query: str, k: int=10) -> List[Hashable]:
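        # Rebuild the index if none exists yet or if the indexed concepts differ from the requested ones.
        # The comparison is made on concept IDs (idx_dict values), since index['sentences'] holds labels.
        if self.model.index is None or set(concepts) != set(self.idx_dict.values()):
            self.build_index(concepts)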
        ans = self.model.search(query, top_k=k)
        return [self.idx_dict[i] for i,_,_ in ans]
    
    def similarity(self, query, keys) -> np.ndarray:
        if self.model.index is None:
            key_embeds = self.model.encode(keys, return_numpy=True)
        else:
            key_embeds = []
            n = self.model.index['index'].ntotal
            d = self.model.index['index'].d
            for k in keys:
                try:
                    key_embeds.append(faiss.rev_swig_ptr(self.model.index['index'].get_xb(), n*d).reshape(n, d)[self.model.index['sentences'].index(k)])
                except ValueError:
                    key_embeds.append(self.model.encode(k, return_numpy=True))
            key_embeds = np.stack(key_embeds)
        return self.model.similarity(query, key_embeds)

class GEN_model:

    def __init__(self, model_path, **kwargs) -> None:
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path, **kwargs).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.max_length = self.model.config.max_length

    def __call__(self, labels: List[str], prefix='summarize: ') -> str:
        corpus = prefix
        for l in labels:
            corpus += l + '[SEP]'
        corpus = corpus[:-5]
        inputs = self.tokenizer(corpus,return_tensors='pt').to(device)['input_ids']
        outputs = self.model.generate(inputs,max_length=self.max_length)[0]
        decoded = self.tokenizer.decode(outputs.cpu().numpy(),skip_special_tokens=True)
        return decoded

class SUB_model:

    def __init__(self, model_path, **kwargs) -> None:
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path, **kwargs).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path,model_max_length=128)

    def __call__(self, sub: Union[str, List[str]], sup: Union[str, List[str]], batch_size :int=256) -> np.ndarray:
        if isinstance(sub, str):
            sub, sup = [sub], [sup]
        if len(sub) <= batch_size:
            inputs = self.tokenizer(sub,sup,padding=True,return_tensors='pt').to(device)
            predictions = torch.softmax(self.model(**inputs).logits.detach().cpu(),1)[:,1].numpy()
        else:
            head = (sub[:batch_size], sup[:batch_size])
            tail = (sub[batch_size:],sup[batch_size:])
            # Recurse on this instance (calling the class would invoke the constructor instead)
            predictions = np.concatenate((self(head[0], head[1], batch_size=batch_size), self(tail[0], tail[1], batch_size=batch_size)))
        return predictions

device = 'cuda' if torch.cuda.is_available() else 'cpu'
ret_model = RET_model(ret_model_path, device=device, pooler="cls_before_pooler")
gen_model = GEN_model(gen_model_path, max_length=64)
sub_model = SUB_model(sub_model_path)
ret_model.build_index(list(taxo.nodes))
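
`SUB_model.__call__` handles inputs longer than `batch_size` by splitting them and concatenating the per-batch scores. The same idea can also be written as a plain loop. The sketch below is model-agnostic: `score_in_batches`, `score_fn` and `exact_match` are hypothetical names introduced for illustration, where `score_fn` is assumed to map two equal-length lists of labels to a 1D array of scores (in this notebook, that role would be played by a single tokenizer-plus-model forward pass).

```python
from typing import Callable, List

import numpy as np


def score_in_batches(
    score_fn: Callable[[List[str], List[str]], np.ndarray],
    sub: List[str],
    sup: List[str],
    batch_size: int = 256,
) -> np.ndarray:
    """Score (sub, sup) pairs chunk by chunk and concatenate the results."""
    if len(sub) != len(sup):
        raise ValueError("sub and sup must have the same length")
    chunks = [
        score_fn(sub[i:i + batch_size], sup[i:i + batch_size])
        for i in range(0, len(sub), batch_size)
    ]
    # An empty input yields an empty 1D array rather than an error.
    return np.concatenate(chunks) if chunks else np.empty(0)


# Hypothetical scorer for illustration: 1.0 when the two labels match exactly, else 0.0.
def exact_match(sub: List[str], sup: List[str]) -> np.ndarray:
    return np.array([float(a == b) for a, b in zip(sub, sup)])


demo_scores = score_in_batches(
    exact_match, ["a", "b", "c", "d", "e"], ["a", "x", "c", "x", "e"], batch_size=2
)
```

A loop keeps the call depth constant regardless of input size, whereas the recursive form adds one level per batch; for the taxonomy sizes used here either approach should behave the same.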


## Configuration

Almost there! Configure your run by specifying the data, models and settings. Check [here](/README.md#configurations) to see how to choose the right settings for your purpose. 

In the following example, we will run auto mode with 10 outer loops. We will also set `logging` to `True` to see a detailed logging of ICON's actions and results.

In [14]:
kwargs = {'data': taxo,
        # Pass the model instances created above, not the wrapper classes
        'ret_model': ret_model,
        'gen_model': gen_model,
        'sub_model': sub_model,
        'restrict_combinations': False,
        'retrieve_size': 5,
        'logging': 1}

iconobj = ICON(**kwargs)


## Running

Everything is now ready. The ICON object was initialised with our configuration in the previous cell, so simply call `run()`.

If you change your mind about the settings before running, you don't have to initialise again: calling `update_config` is enough.

The output of a run will be either a new taxonomy (as is the case here) or a list of ICON predictions. To save a taxonomy to a file, use the `to_json` method.

In [11]:
iconobj.update_config(threshold=0.8, logging=True, subgraph_strict=False) # Example of updating configurations
outputs = iconobj.run()

Loaded Taxonomy with 20334 nodes and 20333 edges. Commencing enrichment


	Outer loop 1: Seed 261700 (Collectible Coffee Roasters) selected from 18022 possible candidates
		Retrieved 5 classes
			Collectible Coffee Roasters
			Commercial Coffee Roasters
			Coffee Roasters
			Collectible Coffee Makers
			Ground Coffee
		Inner loop 1.1: Combination (Collectible Coffee Roasters, Commercial Coffee Roasters)
		Generated semantic union label: Coffee Roasters
			Searching on a domain of 20334 classes
			Search complete. Mapped to a known class by lexical check
				Coffee Roasters with score 1.0000
				Commercial Coffee Roasters with score 0.9798
			For safety, only the highest ranked equivalence is preserved
		Inner loop 1.2: Combination (Collectible Coffee Roasters, Coffee Roasters)
		Generated semantic union label: Coffee & 

KeyboardInterrupt: 

In [9]:
iconobj._status.logs

{'Coffee Roasters': {'equivalent': {177753: 1.0},
  'superclass': {},
  'subclass': {261700: 0.9995437264442444, 57070: 0.9990851879119873}},
 'Coffee & Tea Collectibles': {'equivalent': {},
  'superclass': {13905: 0.9973963499069214, 20625: 0.9887692928314209},
  'subclass': {261700: 0.9896911382675171}},
 'Coffee, Tea & Soft Drinks': {'equivalent': {185035: 1.0},
  'superclass': {},
  'subclass': {185036: 0.9974029660224915, 11652: 0.9877796173095703}},
 'Commercial Coffee Roasters': {'equivalent': {57070: 1.0},
  'superclass': {177753: 0.9990851879119873},
  'subclass': {}},
 'Coffee, Tea & Espresso Makers': {'equivalent': {38250: 1.0},
  'superclass': {},
  'subclass': {177753: 0.9695248007774353, 11652: 0.9939215183258057}},
 'Industrial Equipment & Tools': {'equivalent': {61573: 0.9667483867535864},
  'superclass': {42892: 0.8795446753501892},
  'subclass': {181748: 0.851963222026825}},
 'Industrial Tools & Supplies': {'equivalent': {61573: 0.9383582318956734},
  'superclass': {}