<a href="https://colab.research.google.com/github/cuong3004/Multiple_Label_Bert/blob/main/Demo_Multiple_Label.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

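## Setup

Install the dependencies, download the question-classification dataset (used here to build the label vocabularies) and the TorchScript-traced BERT model, then load the tokenizer and the model.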

In [None]:
!pip install -U gdown transformers

Collecting gdown
  Downloading gdown-4.3.1.tar.gz (13 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting transformers
  Downloading transformers-4.16.2-py3-none-any.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 27.1 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 2.1 MB/s 
Collecting tokenizers!=0.11.3,>=0.10.1
  Downloading tokenizers-0.11.5-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.8 MB)
[K     |████████████████████████████████| 6.8 MB 33.5 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 41.9 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0

In [None]:
!gdown --id "1-GQ9GPpPL5tNEG5I7ooGbtn2CBqX50dS"
!gdown --id "1-4aCuw_XZ1iOCL824EKE1KsKsGL7EwXG"

Downloading...
From: https://drive.google.com/uc?id=1-GQ9GPpPL5tNEG5I7ooGbtn2CBqX50dS
To: /content/Question_Classification_Dataset.csv
100% 409k/409k [00:00<00:00, 15.2MB/s]
Downloading...
From: https://drive.google.com/uc?id=1-4aCuw_XZ1iOCL824EKE1KsKsGL7EwXG
To: /content/traced_bert.pt
100% 439M/439M [00:05<00:00, 76.9MB/s]


In [None]:
from transformers import AutoTokenizer
import torch 
import pandas as pd
from pprint import pprint

df = pd.read_csv("/content/Question_Classification_Dataset.csv", index_col=False)

# Sorted, de-duplicated label vocabularies. `predict` maps the model's argmax
# index back to a name by position in these lists, so the ordering matters.
# NOTE(review): columns are selected by position (-3 and -1); confirm they are
# the category / subcategory columns the model was trained on.
label1_list = sorted(set(df.iloc[:, -3]))
label2_list = sorted(set(df.iloc[:, -1]))

# Reverse lookups: label name -> class index.
label1_dir = dict(zip(label1_list, range(len(label1_list))))
label2_dir = {label: idx for idx, label in enumerate(label2_list)}

In [None]:
# Uncased BERT WordPiece tokenizer; presumably the same checkpoint the traced
# model was fine-tuned from — TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")
# TorchScript module saved by `torch.jit` and fetched by the gdown download cell.
model = torch.jit.load(f="traced_bert.pt")

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

In [None]:
def preprocessing_input(text):
    """Tokenize one question into fixed-length BERT model inputs.

    Args:
        text: The question string to encode.

    Returns:
        ``(input_ids, attention_mask)`` as PyTorch tensors with a batch
        dimension of 1. Special tokens are added, and the sequence is padded
        or truncated to exactly 20 tokens.
    """
    encoded = tokenizer(
        text,
        add_special_tokens=True,
        # Presumably the sequence length the model was traced with — TODO confirm.
        max_length=20,
        # `pad_to_max_length=True` is deprecated; `padding="max_length"` replaces
        # it. Unlike the old flag, it does not implicitly enable truncation, so
        # request truncation explicitly to keep long questions at `max_length`.
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    return encoded["input_ids"], encoded["attention_mask"]

def _top_class(logits):
    """Return ``(index, probability)`` of the most likely class in a logit vector.

    Softmax is taken over the last dimension. The index is the argmax of the
    probabilities, and the probability at that index is returned as a Python float.
    """
    probs = torch.softmax(logits, -1)
    best = torch.argmax(probs, -1).item()
    return best, probs[best].item()


def predict(text):
    """Classify ``text`` and report the top category and subcategory.

    Args:
        text: Question string to classify.

    Returns:
        A dict with:
          - ``"text"``: the input string;
          - ``"category"`` / ``"subcategory"``: ``{"name", "confident"}`` dicts
            holding the argmax label and its softmax probability;
          - ``"confident_label"``: product of the two probabilities.
    """
    ids, mask = preprocessing_input(text)
    with torch.no_grad():
        cat_logits, sub_logits = model(ids, mask)

    # A single question is encoded, so each head's output has one row.
    cat_idx, cat_prob = _top_class(cat_logits[0])
    sub_idx, sub_prob = _top_class(sub_logits[0])

    return {
        "text": text,
        "category": {"name": label1_list[cat_idx], "confident": cat_prob},
        "subcategory": {"name": label2_list[sub_idx], "confident": sub_prob},
        "confident_label": sub_prob * cat_prob,
    }



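## Demo

Type a question into the form field and run the cell. The model predicts a category and a subcategory, each reported with its softmax probability (`confident`); `confident_label` is the product of the two probabilities.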

In [None]:
#@title Write text
my_text = "what is the most wonderful sightseeing in the world?" #@param {type:"string"}
# The `#@title` / `#@param` comments are Colab form annotations that expose
# `my_text` as an editable text field; keep their syntax intact.

# Classify the form input and pretty-print the resulting dict.
pprint(predict(my_text))

{'category': {'confident': 0.8559366464614868, 'name': 'ENTITY'},
 'confident_label': 0.6638407024960529,
 'subcategory': {'confident': 0.7755722403526306, 'name': 'other'},
 'text': 'what is the most wonderful sightseeing in the world?'}


