In [1]:
from transformers import QuestionAnsweringPipeline, AutoAdapterModel, AutoModelWithHeads, AutoTokenizer, AutoConfig
from transformers.onnx import OnnxConfig, validate_model_outputs, export
from transformers.models.bert import BertOnnxConfig
from transformers.models.bart import BartOnnxConfig

import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime import InferenceSession
import onnxruntime

from onnx_opcounter import calculate_params

import os
import time
import torch
import numpy as np

from datasets import load_metric, load_dataset

from typing import Mapping, OrderedDict
from pathlib import Path
import random
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = AutoModelWithHeads.from_pretrained("facebook/bart-base")
adapter_name = model.load_adapter("AdapterHub/narrativeqa", source="hf", set_active=True)
model.set_active_adapters(adapter_name)

Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 3779.79it/s]


In [4]:
question, text = "What does Sara hate?", "Sara hates taxes. She loves vanilla ice cream."
prompt = text + "</s>" + question + "</s>"

encoding = tokenizer(prompt, return_tensors='pt', padding=False)
input_ids = encoding['input_ids']
attention_mask = encoding['attention_mask']

answer = model.generate(input_ids, attention_mask=attention_mask, num_beams=4, max_length=128, early_stopping=True)
answer = tokenizer.decode(answer[0], skip_special_tokens=True)
answer


' Taxes \n'

In [14]:
logits = model(input_ids, attention_mask=attention_mask).logits[0]
print(logits.shape)
print(logits.argmax())

tokenizer.decode(logits[-1:, :].argmax(dim=1))[0]

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

In [12]:
cur_len = 0
max_length = 10
eos_token_id = (
    tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id
)
features = encoding
input_ids = logits
generated_ids = []
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = ()

In [15]:
# We need manual generation for the ONNX model (greedy generation)
while cur_len < max_length:
    input_data = features.copy()
    input_data["input_ids"] = torch.cat(
        (
            features["input_ids"],
            torch.tensor(generated_ids, dtype=int).unsqueeze(dim=0),
        ),
        dim=1,
    )
    input_data["attention_mask"] = torch.ones(input_data["input_ids"].shape, dtype=torch.int64)

    predictions = model(input_data["input_ids"], attention_mask=input_data["attention_mask"])

    next_token_logits = predictions["logits"][:, -1, :]
    scores += (next_token_logits,)

    # argmax
    next_tokens = torch.argmax(next_token_logits, dim=-1)
    # update generated ids, model inputs, and length for next step
    generated_ids.append(next_tokens[:, None].item())
    cur_len = cur_len + 1

    if eos_token_id is not None:
        unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
        # stop when each sentence is finished, or if we exceed the maximum length
    if unfinished_sequences.max() == 0:
        break


In [16]:
generated_ids

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

In [4]:
config = AutoConfig.from_pretrained("facebook/bart-base")
onnx_config = BartOnnxConfig(config, task="seq2seq-lm")

onnx_path = Path("onnx/narrativeqabart/model.onnx")

onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)

onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  if context.output_adapter_gating_scores:
  if tensor is not None and hidden_states.shape[0] != tensor.shape[0]:
  if input_shape[-1] > 1:
  mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
  if getattr(ctx, "output_" + attr, False):
  if input_ids is not None and x.shape[1] == input_ids.shape[1]:
  if len(torch.unique(eos_mask.sum(1))) > 1:


In [168]:
onnx_model = InferenceSession(
    str(onnx_path), providers=["CPUExecutionProvider"]
)

encoding = tokenizer(question, text, return_tensors='np')
input_ids = encoding['input_ids']
attention_mask = encoding['attention_mask']

answer = model.generate(input_ids, attention_mask=attention_mask, num_beams=4, max_length=100, early_stopping=True)
answer = tokenizer.decode(answer[0], skip_special_tokens=True)

outputs = onnx_model.run(input_feed=dict(inputs), output_names=None)

answer_idx = np.argmax(outputs[0])
return choices[answer_idx]

Fish
