In [38]:
import logging

import pandas as pd
import torch
import transformers

from lit_nlp import notebook
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.lib import utils

In [24]:
class SR2PRDataset(lit_dataset.Dataset):
    """LIT dataset of (text, label) records loaded from a local TSV file.

    The file ``./{split}-even.tsv`` is read without a header row. Column 2 is
    used as the text and column 1 as an integer index into ``LABELS``.
    """

    LABELS = ["0", "1"]

    def __init__(self, split="test", max_seq_len=128, max_examples=50):
        """Dataset constructor, loads the data into memory.

        Args:
          split: Prefix of the TSV file to read (``./{split}-even.tsv``).
          max_seq_len: Unused; kept for backward compatibility. No token
            truncation happens here (the model's tokenizer truncates).
          max_examples: Keep only the first ``max_examples`` rows of the file;
            ``None`` keeps every row. Defaults to 50, the previous fixed cap.

        Raises:
          ValueError: If a label is not a valid integer index into ``LABELS``.
        """
        df = pd.read_csv(f"./{split}-even.tsv", sep="\t", header=None)
        if max_examples is not None:
            df = df.head(max_examples)
        self._examples = []  # one {"text", "label"} record per row
        for text, label in zip(df[2].tolist(), df[1].tolist()):
            self._examples.append({
                "text": text,
                "label": self._label_name(label),
            })

    @classmethod
    def _label_name(cls, label):
        """Maps an integer label id from the TSV to its name in ``LABELS``."""
        label_idx = int(label)
        # Validate explicitly: plain list indexing would silently wrap negative
        # ids (e.g. -1 -> "1") and int() alone would truncate values like 1.5.
        if label_idx != label or not 0 <= label_idx < len(cls.LABELS):
            raise ValueError(
                f"Label {label!r} is not a valid index into {cls.LABELS}")
        return cls.LABELS[label_idx]

    def spec(self) -> lit_types.Spec:
        """Dataset spec, which should match the model's input_spec()."""
        return {
            "text": lit_types.TextSegment(),
            "label": lit_types.CategoryLabel(vocab=self.LABELS),
        }

In [25]:
# Datasets exposed to LIT, keyed by the name LIT uses to refer to them.
datasets = {}
datasets["SR-2-PR"] = SR2PRDataset()

[{'text': '[SR][MCS][Walmart] all the connectors from different datacenters are randomly going onto the deadlist, preventing all users from logging in', 'label': '1'}, {'text': '[Likely 3rd party HPE plugin][SR][CT Probate Administration] Unable to mount some of the VMFS datastore that were formatted with ATS after the host reboot', 'label': '1'}, {'text': '[SR][Marsh & McLennan Companies Inc] Need to Migrate Multi-Machines within a Deployment to a different Business Group and maintain the deployment structure', 'label': '1'}, {'text': '[18733616203][Nvidia][ESXi 6.5][vSGA]: vSGA waiver logs for P40 + R390 on ESXi 6.5', 'label': '1'}, {'text': '[DCPNc][PureStorage]Unable to validate MSCS SCSI-3 Reservations when configuring on ESXi/vCenter 6.7 with VVols', 'label': '1'}, {'text': '[18925035209][Server][VMware ESXi 6.5.0 build-7967591][CIS][UCSC-C460-M4][40540][WB: 3.5.7] Server Log Submission', 'label': '1'}, {'text': '[HP] AlterName and the standardInquiry properies of HostScsiDisk is

In [39]:
def _from_pretrained(cls, *args, **kw):
    """Load a transformers model in PyTorch, with fallback to TF2/Keras weights."""
    try:
        return cls.from_pretrained(*args, **kw)
    except OSError as e:
        logging.warning("Caught OSError loading model: %s", e)
        logging.warning("Re-trying to convert from TensorFlow checkpoint (from_tf=True)")
        return cls.from_pretrained(*args, from_tf=True, **kw)

class BERTClassification(lit_model.Model):
    """LIT wrapper around a transformers sequence classifier over ``LABELS``.

    For each input text it returns softmax class probabilities, the word-piece
    tokens, and the last-layer hidden state of the first token.
    """

    LABELS = ["0", "1"]

    def __init__(self, model_name_or_path, max_seq_len=128):
        """Loads tokenizer, config and weights from ``model_name_or_path``.

        Args:
          model_name_or_path: Anything ``from_pretrained`` accepts, e.g. a
            local checkpoint directory.
          max_seq_len: Maximum tokenized length (including special tokens)
            passed to the tokenizer; longer inputs are truncated. Defaults to
            128, the previously hard-coded value.
        """
        self.max_seq_len = max_seq_len
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
        model_config = transformers.AutoConfig.from_pretrained(
            model_name_or_path,
            num_labels=len(self.LABELS),
            # Needed for "cls_emb" below; attentions are requested as well.
            output_hidden_states=True,
            output_attentions=True,
        )
        # This is just a regular PyTorch model.
        self.model = _from_pretrained(
            transformers.AutoModelForSequenceClassification,
            model_name_or_path,
            config=model_config,
        )
        # Inference only (e.g. disables dropout).
        self.model.eval()

    # LIT API implementation
    def max_minibatch_size(self):
        """Batch size lit_model.Model.predict() uses when calling predict_minibatch()."""
        return 32

    def predict_minibatch(self, inputs):
        """Runs the classifier on one minibatch.

        Args:
          inputs: List of example dicts, each with a "text" field.

        Yields:
          One dict per input with "probas" (softmax over LABELS), "cls_emb"
          (last-layer hidden state of the first token) and "tokens" (word
          pieces with the first and last special tokens stripped).
        """
        if not inputs:
            # The tokenizer cannot build a tensor batch from zero texts.
            return
        encoded_input = self.tokenizer.batch_encode_plus(
            [ex["text"] for ex in inputs],
            return_tensors="pt",
            add_special_tokens=True,
            max_length=self.max_seq_len,
            padding="longest",
            truncation="longest_first",
        )
        # Put inputs on whatever device the model lives on (a no-op on CPU),
        # so moving the model to a GPU needs no further changes here.
        device = self.model.device
        encoded_input = {name: tensor.to(device) for name, tensor in encoded_input.items()}

        # Run a forward pass.
        with torch.no_grad():  # remove this if you need gradients.
            out: transformers.modeling_outputs.SequenceClassifierOutput = self.model(**encoded_input)

        # Post-process outputs.
        batched_outputs = {
            "probas": torch.nn.functional.softmax(out.logits, dim=-1),
            "input_ids": encoded_input["input_ids"],
            "ntok": torch.sum(encoded_input["attention_mask"], dim=1),
            "cls_emb": out.hidden_states[-1][:, 0],  # last layer, first token
        }
        # Return as NumPy for further processing.
        detached_outputs = {k: v.cpu().numpy() for k, v in batched_outputs.items()}
        # Unbatch outputs so we get one record per input example.
        for output in utils.unbatch_preds(detached_outputs):
            ntok = output.pop("ntok")
            # Drop padding plus the first and last (special) tokens.
            output["tokens"] = self.tokenizer.convert_ids_to_tokens(output.pop("input_ids")[1:ntok - 1])
            yield output

    def input_spec(self) -> lit_types.Spec:
        """Fields the model reads; "label" is optional."""
        return {
            "text": lit_types.TextSegment(),
            "label": lit_types.CategoryLabel(vocab=self.LABELS, required=False)
        }

    def output_spec(self) -> lit_types.Spec:
        """Fields yielded by predict_minibatch()."""
        return {
            "tokens": lit_types.Tokens(),
            "probas": lit_types.MulticlassPreds(parent="label", vocab=self.LABELS),
            "cls_emb": lit_types.Embeddings()
        }

In [40]:
# Models exposed to LIT, keyed by name; the checkpoint is loaded from the
# current working directory.
models = {}
models["SR-2-PR"] = BERTClassification(".")

INFO:absl:Received 50 predictions from model
INFO:absl:Requested types: ['MulticlassPreds']
INFO:absl:Will return keys: {'probas'}
127.0.0.1 - - [18/May/2021 08:51:35] "POST /get_preds?model=SR-2-PR&dataset_name=SR-2-PR&requested_types=MulticlassPreds HTTP/1.1" 200 2714
INFO:absl:50 of 50 inputs sent as IDs; reconstituting from dataset 'SR-2-PR'
INFO:absl:CachingModelWrapper 'SR-2-PR': misses (dataset=SR-2-PR): []
INFO:absl:CachingModelWrapper 'SR-2-PR': 0 misses out of 50 inputs
INFO:absl:Prepared 0 inputs for model
INFO:absl:Received 0 predictions from model
INFO:absl:Requested types: ['RegressionScore']
INFO:absl:Will return keys: set()
127.0.0.1 - - [18/May/2021 08:51:35] "POST /get_preds?model=SR-2-PR&dataset_name=SR-2-PR&requested_types=RegressionScore HTTP/1.1" 200 200
INFO:absl:50 of 50 inputs sent as IDs; reconstituting from dataset 'SR-2-PR'
INFO:absl:CachingModelWrapper 'SR-2-PR': misses (dataset=SR-2-PR): []
INFO:absl:CachingModelWrapper 'SR-2-PR': 0 misses out of 50 inputs

In [43]:
# Build the LIT widget over the models and datasets defined above
# (construction starts the local LIT server).
widget = notebook.LitWidget(
    models,
    datasets,
    height=1024,
)

INFO:absl:
 (    (           
 )\ ) )\ )  *   ) 
(()/((()/(` )  /( 
 /(_))/(_))( )(_))
(_)) (_)) (_(_()) 
| |  |_ _||_   _| 
| |__ | |   | |   
|____|___|  |_|   


INFO:absl:Starting LIT server...
INFO:absl:CachingModelWrapper 'SR-2-PR': no cache path specified, not loading.


In [44]:
# Render the widget: display the LIT UI in this cell's output. It is served
# by the LIT server that `widget` started when it was constructed.
widget.render()

127.0.0.1 - - [18/May/2021 08:56:10] "GET / HTTP/1.1" 200 1408
127.0.0.1 - - [18/May/2021 08:56:10] "GET /main.js HTTP/1.1" 200 1661454
127.0.0.1 - - [18/May/2021 08:56:10] "POST /get_info? HTTP/1.1" 200 8539
127.0.0.1 - - [18/May/2021 08:56:10] "GET /static/favicon.png HTTP/1.1" 200 13257
INFO:absl:0 of 0 inputs sent as IDs; reconstituting from dataset 'SR-2-PR'
127.0.0.1 - - [18/May/2021 08:56:10] "POST /get_dataset?dataset_name=SR-2-PR HTTP/1.1" 200 10095
INFO:absl:50 of 50 inputs sent as IDs; reconstituting from dataset 'SR-2-PR'
INFO:absl:CachingModelWrapper 'SR-2-PR': misses (dataset=SR-2-PR): ['337523fbe33d5eecc81eca992f63b8af', '25d4b4b3713dd7e18e6a3391f2da2728', 'c478e6429b7fe0e55259bd3c19cfbcd6', '26c1987e952aa32e253373f9e37d4db5', 'b771e5c74788d99dfb892dfd4d6fbd9c', '20c3c0a18d48fe6d93b5c8d9f97d9eca', 'e8c29bb59ed0b43334b9096f37c8f6ab', '23ee7563d8af51aa32d935d8b416fbbc', '207332383a325a838c551098d87a5b0b', 'fe0b52c24dcc76cfd9e51c462bb94850', 'a6c52c755880a6d52b36e8d831e9fbb