<a href="https://colab.research.google.com/github/duahauby/character-classifier-cnn-chars74k/blob/master/TS_handler.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install nltk transformers torchserve

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[K     |████████████████████████████████| 5.8 MB 5.3 MB/s 
[?25hCollecting torchserve
  Downloading torchserve-0.7.0-py3-none-any.whl (19.6 MB)
[K     |████████████████████████████████| 19.6 MB 1.2 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 53.7 MB/s 
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 66.4 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers, torchserve
Successfully installed huggingface-hub-0.11.1 tokenizers-0.13.2 torchserve-0.7.0 transformers-4.25.1


In [None]:
from abc import ABC
import json
import logging
import os
import re
import json
import sys
import torch
import string
import nltk
import numpy as np
from unicodedata import normalize
from ts.torch_handler.base_handler import BaseHandler
from nltk.tokenize import word_tokenize
import zipfile
from pathlib import Path
from transformers import AutoTokenizer


# Module-level logger named after this module; used by the handler below.
logger = logging.getLogger(__name__)


class IntentV3Handler(BaseHandler, ABC):
    """TorchServe handler for the intent (v3) text classifier.

    Extracts a zipped checkpoint, loads its Hugging Face tokenizer and an
    ``HSDModel``, turns request text into padded token-id batches, and returns
    for each input every label paired with its softmax score, sorted by
    descending score.
    """

    def __init__(self):
        super(IntentV3Handler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load the tokenizer, label mapping and model weights.

        Args:
            ctx (context): TorchServe context. ``ctx.manifest`` and
                ``ctx.system_properties`` (``model_dir``, ``gpu_id``) are read.
        """
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)

        gpu_id = properties.get("gpu_id")
        self.device = torch.device(
            "cuda:" + str(gpu_id)
            if torch.cuda.is_available() and gpu_id is not None
            else "cpu"
        )

        # sys.path entries must be directories, so register the folder that
        # contains the model definition file (not the file path itself).
        model_def_file = self.manifest["model"].get("modelFile", "")
        model_src_dir = os.path.dirname(os.path.join(model_dir, model_def_file))
        if model_src_dir not in sys.path:
            sys.path.append(model_src_dir)
        from model import HSDModel

        if model_pt_path.endswith('.zip'):
            with zipfile.ZipFile(model_pt_path, 'r') as zip_ref:
                zip_ref.extractall(os.path.join(model_dir, "pretrained"))
        else:
            logger.warning("Model should be compressed in zip file format.")
        pretrained_path = os.path.join(model_dir, "pretrained/weights")
        model_file_path = os.path.join(pretrained_path, "model.pth")
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
        # Special-token ids are derived from the tokenizer; reset the cache.
        self._special_ids_cache = None
        # torch.load unpickles the file: only load trusted model archives.
        state = torch.load(model_file_path, map_location=self.device)
        self.mapping = state["id2label"]
        self.model = HSDModel(pretrained_path, num_classes=len(self.mapping), device=self.device)
        self.model.to(self.device)
        self.model.load_state_dict(state["weights"])
        self.model.eval()

        # Optional exact-text -> intent overrides stored in the checkpoint.
        self.sample_mapping2intent = {}
        if 'sample_mapping2intent' in state:
            self.sample_mapping2intent = state['sample_mapping2intent']

        logger.info(
            "Pretrained intent model from path %s loaded successfully", model_dir
        )

        self.initialized = True

    def preprocess(self, requests):
        """Decode, clean and tokenize a batch of TorchServe requests.

        Args:
            requests (list): Request dicts. The text is read from ``"data"``
                or, failing that, ``"body"``; bytes are decoded as UTF-8.
        Returns:
            tuple: ``(input_ids, attention_mask, texts)``. ``input_ids`` is an
            int64 tensor of shape (batch, max_len), padded with the tokenizer's
            pad id. ``attention_mask`` is a float tensor of the same shape,
            1 for real tokens and 0 for padding. Both are on ``self.device``.
            ``texts`` is the list of cleaned input strings.
        Raises:
            ValueError: if a request carries neither ``data`` nor ``body``.
        """
        logger.info("Received text: '%s'", requests)
        input_ids_batch = []
        texts = []
        for data in requests:
            input_text = data.get("data")
            if input_text is None:
                input_text = data.get("body")
            if input_text is None:
                raise ValueError("Request has neither a 'data' nor a 'body' field")
            if isinstance(input_text, (bytes, bytearray)):
                input_text = input_text.decode('utf-8')
            logger.info("Received text: '%s'", input_text)
            input_text = clean_text(input_text.lower())
            texts.append(input_text)
            input_ids_batch.append(self.custom_tokenize(input_text))

        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 1  # historical default (RoBERTa/PhoBERT-style pad id)
        max_length = max((len(ids) for ids in input_ids_batch), default=0)
        # np.long was removed in NumPy 1.24; int64 is the explicit equivalent.
        padded_input_ids = np.full(
            (len(input_ids_batch), max_length), pad_id, dtype=np.int64
        )
        # Mask from real sequence lengths rather than by comparing token ids.
        input_mask = np.zeros(padded_input_ids.shape)
        for i, ids in enumerate(input_ids_batch):
            padded_input_ids[i, :len(ids)] = ids
            input_mask[i, :len(ids)] = 1

        input_ids_batch = torch.tensor(padded_input_ids, device=self.device)
        attention_mask_batch = torch.tensor(input_mask, device=self.device)

        logger.info("Input shape is: '%s'", input_ids_batch.shape)
        return input_ids_batch, attention_mask_batch, texts

    def _special_ids(self):
        """Return ``(start_id, end_id, unk_ids)``, computed once per tokenizer."""
        cached = getattr(self, "_special_ids_cache", None)
        if cached is None:
            start_end_tokens = self.tokenizer.encode('')
            # get_vocab() can rebuild a large dict, so look it up only once.
            vocab = self.tokenizer.get_vocab()
            if '[UNK]' in vocab:
                unk_ids = [vocab['[UNK]']]
            elif '<unk>' in vocab:
                unk_ids = [vocab['<unk>']]
            else:
                # Without add_special_tokens=False this would also emit
                # start/end ids in the middle of the sequence.
                unk_ids = self.tokenizer.encode('<unk>', add_special_tokens=False)
            cached = (start_end_tokens[0], start_end_tokens[1], unk_ids)
            self._special_ids_cache = cached
        return cached

    def custom_tokenize(self, text):
        """Convert cleaned text into token ids wrapped in start/end specials.

        Each whitespace-separated word is encoded separately, with a leading
        space for every word except the first. A word that splits into more
        than three sub-word pieces is replaced by the unknown-token id(s).

        Args:
            text (str): Already-cleaned input text.
        Returns:
            list[int]: ``[start] + word ids (truncated) + [end]``.
        """
        start_id, end_id, unk_ids = self._special_ids()

        sub_word_ids = []
        for j, word in enumerate(text.split()):
            if j != 0:
                word = ' ' + word
            tokens = self.tokenizer.encode(word, add_special_tokens=False)
            sub_word_ids.extend(unk_ids if len(tokens) > 3 else tokens)
        # Reserve room for the start/end ids; the extra 2 presumably matches a
        # RoBERTa-style position-id offset — TODO confirm.
        sub_word_ids = sub_word_ids[:self.model.config.max_position_embeddings - 4]
        return [start_id] + sub_word_ids + [end_id]

    def inference(self, input_batch):
        """Score each text against every intent label.

        Args:
            input_batch (tuple): ``(input_ids, attention_mask, texts)`` as
                returned by :meth:`preprocess`.
        Returns:
            list: For each input, a list of ``(label, score)`` pairs sorted by
            descending softmax score. If the cleaned text is a key of
            ``sample_mapping2intent``, the top label is replaced by the mapped
            intent with score 1.0 and every other score is 0.0.
        """
        input_ids_batch, attention_mask_batch, texts = input_batch
        if input_ids_batch.shape[0] == 0:
            return []
        # This override bypasses BaseHandler.inference, so disable autograd here.
        with torch.no_grad():
            logits = self.model(input_ids_batch, attention_mask_batch)
        probs = torch.softmax(logits, dim=-1).cpu().numpy()

        inferences = []
        for text, row in zip(texts, probs):
            order = row.argsort()[::-1]
            labels = [self.mapping[idx] for idx in order]
            scores = [float(row[idx]) for idx in order]
            if text in self.sample_mapping2intent:
                labels[0] = self.sample_mapping2intent[text]
                scores = [0.0] * len(scores)
                scores[0] = 1.0
            inferences.append(list(zip(labels, scores)))
        return inferences

    def postprocess(self, inference_output):
        """Return the inference output unchanged.

        Args:
            inference_output (list): Per-input lists of ``(label, score)`` pairs.
        Returns:
            list: The same object, which TorchServe serializes as the response.
        """
        return inference_output


def clean_text(text):
    """Normalize raw user text before tokenization.

    Steps:
    1. Apply Unicode NFC normalization.
    2. Lower-case the text.
    3. Split it with NLTK ``word_tokenize``.
    4. Drop every token that is a substring of ``string.punctuation``
       (e.g. ``","`` or ``"?"``).
    5. Re-join with single spaces.
    6. Collapse any remaining whitespace runs.

    NOTE(review): ``w not in string.punctuation`` is a *substring* test, so
    multi-character tokens such as ``"..."`` are kept, as are the quote
    tokens NLTK may emit for double quotes. This is left unchanged so serving
    stays consistent with how training data and ``sample_mapping2intent``
    keys were presumably produced — confirm before changing.

    Requires the NLTK ``punkt`` tokenizer data to be installed.

    Args:
        text (str): Raw input text.
    Returns:
        str: The cleaned, space-separated text.
    """
    text = normalize('NFC', text)
    text = text.lower()
    text = " ".join([w for w in word_tokenize(text) if w not in string.punctuation])

    # Raw string: '\s' in a plain literal is an invalid escape sequence.
    text = re.sub(r'\s+', ' ', text)
    return text