In [1]:
import os
import json
import torch

import numpy as np
from scipy.special import softmax
from numpy import argmax

from transformers import BertTokenizer, BertModel, BertConfig
from datetime import datetime
from torch.utils.data import Dataset, DataLoader


class FactoryEncoder:

    def __init__(self, tokenizer, device, max_length):
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length
    # end

    # Q: WHY NOT USING PARAMETERS?
    # A: FOR FUN
    def get_instance(self, sentences):
        instance = SimpleEncoder(sentences)
        instance.set_tokenizer(self.tokenizer)
        instance.set_device(self.device)
        instance.set_max_length(self.max_length)
        return instance
    # end
# end

class SimpleEncoder(Dataset):
    def __init__(self, sentences):
        self.sentences = sentences
    # end

    def __getitem__(self, index):
        sentence = self.sentences[index]
        sentence = ' '.join(sentence.split())
        inputs = self.tokenizer.encode_plus(
            sentence, None,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_tensors='pt'
        )

        for key in inputs:
            inputs[key].to(self.device)
        # end

        return inputs
    # end

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer
    # end

    def set_device(self, device):
        self.device = device
    # end

    def set_max_length(self, max_length):
        self.max_length = max_length
    # end


class SimpleBertClassifier(torch.nn.Module):

    DEFAULT_FILENAME_CLASSIFIER = '.model.json'
    DEFAULT_FILENAME_BERT = 'bert_config.json'
    DEFAULT_FILENAME_MODEL = 'model.pt'
    DEFAULT_KEYS_IGNORED_CLASSIFIER = ['metrics', 'allmetrics']

    def __init__(self, path_folder_model=None):
        super(SimpleBertClassifier, self).__init__()

        filename_config_classifier = self.__class__.DEFAULT_FILENAME_CLASSIFIER
        filename_config_bert = self.__class__.DEFAULT_FILENAME_BERT
        filename_model = self.__class__.DEFAULT_FILENAME_MODEL
        keys_ignored_classifier = self.__class__.DEFAULT_KEYS_IGNORED_CLASSIFIER


        self.path_folder_model = path_folder_model
        self.path_config_bert = os.path.join(path_folder_model, filename_config_bert)
        self.path_config_classifier = os.path.join(path_folder_model, filename_config_classifier)
        self.path_file_model = os.path.join(path_folder_model, filename_model)

        with open(self.path_config_classifier, 'r') as file:
            config_classifier = json.load(file)
        # end

        for key in keys_ignored_classifier:
            del(config_classifier[key])
        # end

        # classfier parameters
        self.classifier_input_size = config_classifier.get('bert').get('input_size')
        self.classifier_max_length = config_classifier.get('bert').get('max_length')
        self.classifier_output_size = config_classifier.get('bert').get('output_size')
        #

        self.labels_output_classifier = config_classifier.get('classes')
        self.dict_label_index = {label: index for index, label in enumerate(self.labels_output_classifier)}

        self.l1 = None
        self.linear = None

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.loaded = False
    # end

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        output_bert = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_bert[0]
        pooler = hidden_state[:, 0]
        output = self.classifier(pooler)
        return output
    # end

    def load(self):
        if not self.loaded:
            self.l1 = BertModel(BertConfig.from_pretrained(self.path_config_bert))
            print('loading {} x {} linear classfier layer'.format(self.classifier_input_size, self.classifier_output_size))
            self.classifier = torch.nn.Linear(self.classifier_input_size, self.classifier_output_size)

            self.to(self.device)
            self.load_state_dict(torch.load(self.path_file_model, map_location=torch.device(self.device)))

            print('Please Ignore warning message sent by BertTokenizer below')
            self.tokenizer = BertTokenizer.from_pretrained(self.path_folder_model)
            self.factory_encoder = FactoryEncoder(self.tokenizer, self.device, self.classifier_max_length)
            self.loaded = True
        # end

        return self
    # end

    def predicts(self, samples_input):
        encoder = self.factory_encoder.get_instance(samples_input)
        outputs_raw = []
        for sample_encoded in encoder:
            with torch.no_grad():
                output_raw = self(**sample_encoded).cpu().numpy().flatten().tolist()
                # print('jinyuj: predicts:83 output_raw: {}'.format(output_raw))
                outputs_raw.append(output_raw)
            # end
        # end

        info_result = {
            'outputs': outputs_raw,
            'labels': self.labels_output_classifier
        }

        return info_result
    # end
# end

In [2]:
classfier = SimpleBertClassifier('vBERT-base-20')
classfier.load()
print('done')

loading 768 x 16 linear classfier layer
done


In [3]:
sentence = 'timestamp failed at play deploy vm efi nvme vmxnet number timestamp task upload local file to esxi data store task path home worker workspace dw rhel number arm ansible vsphere gos validation common esxi upload data store file yml number exception in vsphere copy python when main in request python when http error default fatal localhost failed http error number not found timestamp task testing exit due to failure task path home worker workspace dw rhel number arm ansible vsphere gos validation common test rescue yml number fatal localhost failed exit testing when exit testing when fail is set to true in test case deploy vm efi nvme vmxnet number'
classfier.predicts([sentence])

RuntimeError: Input, output and indices must be on the current device