In [1]:
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
import warnings
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
    BaseModelOutput,
    Seq2SeqLMOutput,
)
from transformers.utils import (
    logging,
    replace_return_docstrings,
    ModelOutput
)
from transformers.models.t5.modeling_t5 import __HEAD_MASK_WARNING_MSG
import inspect
logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "T5Config"
from transformers import T5ForConditionalGeneration

class FiD(T5ForConditionalGeneration):
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, T5ForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
        >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

        >>> # training
        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits

        >>> # inference
        >>> input_ids = tokenizer(
        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        >>> # studies have shown that owning a dog is good for you.
        ```"""
        # print("attention_mask 1: ", attention_mask.size())
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # NOTE: FiD
        # Reshape from [batch, n_passages, length] to [batch * n_passages, length]
        if input_ids is not None:
            if input_ids.dim() == 3 and attention_mask.dim() == 3:
                #logger.info("Start Reshape from [batch, n_passage, length] to [batch * n_passage, length]")
                self.n_passages = input_ids.size(1)
                self.batch = input_ids.size(0)
                self.seq_length = input_ids.size(2)
                input_ids = input_ids.view(self.batch*self.n_passages, self.seq_length)
                attention_mask = attention_mask.view(self.batch*self.n_passages, self.seq_length)
            else:
                raise ValueError(f"NOT FiD TRAINING, got input_ids {input_ids.size()} and attention_mask {attention_mask.size()}")

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            # print("attention_mask 2: ", attention_mask.size())
            # print(f"batchsize: {self.batch}, n_passages: {self.n_passages}, seq_length: {self.seq_length}")
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # [batch*n_passages, seq_length, model_dim]
        hidden_states = encoder_outputs[0]

        # NOTE: FiD
        # from [batch*n_passages, seq_length, model_dim] to [batch, n_passages*seq_length, model_dim]
        hidden_states = hidden_states.view(self.batch, self.n_passages*self.seq_length, -1)

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # NOTE: change(FiD): reshape attention mask
        # print("attention_mask 3: ", attention_mask.size())
        # print(f"batchsize: {self.batch}, n_passages: {self.n_passages}, seq_length: {self.seq_length}")
        attention_mask = attention_mask.view(-1, self.n_passages*self.seq_length)
        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        # print("attention_mask 4: ", model_kwargs["attention_mask"].size())
        if inputs_tensor.dim() == 3:
            #logger.info("Start Reshape from [batch, n_passage, length] to [batch * n_passage, length]")
            self.n_passages = inputs_tensor.size(1)
            self.batch = inputs_tensor.size(0)
            self.seq_length = inputs_tensor.size(2)
            inputs_tensor = inputs_tensor.view(self.batch*self.n_passages, self.seq_length)
            model_kwargs["attention_mask"] = model_kwargs["attention_mask"].view(self.batch*self.n_passages, self.seq_length)
            # print("attention_mask 5: ", model_kwargs["attention_mask"].size())
            # print(f"batchsize: {self.batch}, n_passages: {self.n_passages}, seq_length: {self.seq_length}")

        # 1. get encoder
        encoder = self.get_encoder()
        # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
        # as the inputs.
        if hasattr(self, "hf_device_map"):
            if hasattr(encoder, "_hf_hook"):
                encoder._hf_hook.io_same_device = True
            else:
                add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))

        # 2. Prepare encoder args and encoder kwargs from model kwargs.
        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix)
        }
        encoder_signature = set(inspect.signature(encoder.forward).parameters)
        encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
        if not encoder_accepts_wildcard:
            encoder_kwargs = {
                argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
            }

        # 3. make sure that encoder returns `ModelOutput`
        model_input_name = model_input_name if model_input_name is not None else self.main_input_name

        encoder_kwargs["return_dict"] = True
        encoder_kwargs[model_input_name] = inputs_tensor
        encoder_outputs = encoder(**encoder_kwargs)
        # print("encoder_outputs: ", encoder_outputs)
        encoder_outputs["last_hidden_state"] = encoder_outputs["last_hidden_state"].view(self.batch, self.n_passages*self.seq_length, -1)
        model_kwargs["encoder_outputs"]: ModelOutput = encoder_outputs
        # model_kwargs["attention_mask"] = model_kwargs["attention_mask"].view(self.batch, self.n_passages*self.seq_length)

        return model_kwargs

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from haystack import BaseComponent
from transformers import AutoTokenizer
from typing import List, Optional

class FiDReader(BaseComponent):
    outgoing_edges = 1

    def __init__(self, 
                 model_name_or_path: str = "gradients-ai/fid_large_en_v1.0",
                 device: str = "cpu"):
        super().__init__()
        self.device = device
        print("Initializing model...")
        self.model = FiD.from_pretrained(model_name_or_path)
        self.model.to(self.device)
        print("Done!")
        print("Initializing tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        print("Done!")

    def append_question(
            self,
            question: str,
            documents: List[str],
            question_prefix: str = "Question: ",
            document_prefix: str = "Document: "
    ) -> List[str]:
        """Pair question to each document

        Args:
            question:
                a string - question
            documents:
                a list of string
        Returns:
            A question is paired with each document in `documents`
            become a list of string
        """

        if documents is None:
            return [question_prefix + question]
        return [question_prefix + question + " " + document_prefix + d for d in documents] 

    def run(self, query, documents):
        # pprint(contexts)
        # print("Contexts len:", len(contexts))
        # print(top2_docs)
        inputs = self.append_question(
            query,
            list(doc.content for doc in documents)
        )
        tokenized_input = self.tokenizer(inputs, return_tensors="pt", padding=True)
        input_tensor = tokenized_input.input_ids[None, :, :].to(self.device)
        attention_mask = tokenized_input.attention_mask[None, :, :].to(self.device)
        print("Generating answers...")
        model_outputs = self.model.generate(
            input_ids=input_tensor,
            attention_mask=attention_mask,
            max_length=256,
            min_length=64,
            do_sample=True,
            num_beams=1,
            top_k=50,
            top_p=0.9,
            temperature=0.7,
            num_return_sequences=1,
            no_repeat_ngram_size=3,
            repetition_penalty=1.1
        )
        output = {"answer": []}
        print("Model output len:", len(model_outputs))
        for out in model_outputs:
            output["answer"].append(
                self.tokenizer.decode(out, skip_special_tokens=True)
            )
        return output, "output_1"
    
    def run_batch(self, queries: List[str], my_optional_param: Optional[int]):
        # process the inputs
        output = {"my_output": ...}
        return output, "output_1"


In [3]:
import pandas as pd

In [4]:
df = pd.read_csv("faq.csv", index_col=0)
df.columns = ['title', 'content']
df.head()

Unnamed: 0,title,content
0,What is a novel coronavirus?,A novel coronavirus is a new coronavirus that ...
1,Why is the disease being called coronavirus di...,"On February 11, 2020 the World Health Organiza..."
2,Why might someone blame or avoid individuals a...,People in the U.S. may be worried or anxious a...
3,How can people help stop stigma related to COV...,"People can fight stigma and help, not hurt, ot..."
4,What is the source of the virus?,Coronaviruses are a large family of viruses. S...


In [5]:
df.describe()

Unnamed: 0,title,content
count,213,213
unique,210,210
top,Why might someone blame or avoid individuals a...,People in the U.S. may be worried or anxious a...
freq,2,2


In [6]:
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever
from haystack.nodes import PreProcessor
from haystack.schema import Document
from haystack.pipelines import Pipeline
import os

In [7]:
if os.path.isfile("faiss_document_store.db"):
    os.remove("faiss_document_store.db")

In [8]:
retriever = EmbeddingRetriever("BAAI/bge-large-en-v1.5", use_gpu=True)
document_store = FAISSDocumentStore(
    faiss_index_factory_str="Flat",
    embedding_dim=1024,
    return_embedding=True
)
reader = FiDReader()

  return self.fget.__get__(instance, owner)()


Initializing model...
Done!
Initializing tokenizer...
Done!


In [9]:
retriever.embedding_encoder.embed("Hello")

Batches: 100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


array([ 0.0510522 ,  0.00719567,  0.00387395, ..., -0.0314708 ,
       -0.0355579 , -0.01076885], dtype=float32)

In [10]:
documents = []

for idx, row in df.iterrows():
    title = row['title']
    content = row['content']
    embedding = retriever.embedding_encoder.embed(title)
    doc = Document(
        content=content,
        embedding=embedding
    )
    documents.append(doc)

Batches: 100%|██████████| 1/1 [00:00<00:00, 43.75it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 38.81it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 43.04it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 28.28it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 27.59it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 41.05it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 41.75it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 48.53it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 42.03it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 47.37it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 52.24it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 48.06it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 54.92it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 55.07it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 50.75it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 48.10it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 52.34it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 51.79it/s]
Batches: 1

In [11]:
document_store.write_documents(documents=documents)
retriever.document_store = document_store

Writing Documents: 10000it [00:00, 32557.62it/s]          


In [12]:
pipeline = Pipeline()
pipeline.add_node(component=retriever, name="retriever", inputs=["Query"])
pipeline.add_node(component=reader, name="reader", inputs=["retriever"])

In [13]:
query = "What is Covid?"

outputs = retriever.retrieve(query, top_k=10)

Batches: 100%|██████████| 1/1 [00:00<00:00, 49.61it/s]


In [14]:
for output in outputs:
    print(output.content)
    print(output.score)

COVID-19 is the infectious disease caused by the most recently discovered coronavirus. This new virus and disease were unknown before the outbreak began in Wuhan, China, in December 2019.
0.5022707453715565
Severe Acute Respiratory Syndrome Coronavirus-2 (SARS-CoV-2) is the name given to the 2019 novel coronavirus. COVID-19 is the name given to the disease associated with the virus. SARS-CoV-2 is a new strain of coronavirus that has not been previously identified in humans.
0.5021145941903251
Coronaviruses are a large family of viruses which may cause illness in animals or humans.  In humans, several coronaviruses are known to cause respiratory infections ranging from the common cold to more severe diseases such as Middle East Respiratory Syndrome (MERS) and Severe Acute Respiratory Syndrome (SARS). The most recently discovered coronavirus causes coronavirus disease COVID-19.
0.5020081152828929
Coronaviruses are a large group of viruses that are common among animals and humans. This no

In [18]:
query = "What is the symptoms of Covid?"

answer = pipeline.run(query=query)

answer

Batches: 100%|██████████| 1/1 [00:00<00:00, 23.08it/s]


Generating answers...
Model output len: 1


{'answer': ["You have a cold, it's like the flu. Your body produces antibodies to fight off the virus, but they don't last for long. In fact, the virus can survive for days on end without your body producing any antibodies at all. It's kind of like how you get pneumonia from the flu: the virus will survive for weeks on end before your body starts producing antibodies."],
 'documents': [<Document: {'content': 'Typically, human coronaviruses cause mild-to-moderate respiratory illness. Symptoms are very similar to the flu, including:\r\n\r\nFever\r\nCough\r\nShortness of breath\r\nCOVID-19 can cause more severe respiratory illness.', 'content_type': 'text', 'score': 0.5022833933638616, 'meta': {'vector_id': '141'}, 'id_hash_keys': ['content'], 'embedding': '<embedding of shape (1024,)>', 'id': 'c61b47cdc2a36a0e4c418499a2f33b8e'}>,
  <Document: {'content': "The most common symptoms of COVID-19 are fever, tiredness, and dry cough. Some patients may have aches and pains, nasal congestion, ru

In [19]:
print(answer['answer'][0])

You have a cold, it's like the flu. Your body produces antibodies to fight off the virus, but they don't last for long. In fact, the virus can survive for days on end without your body producing any antibodies at all. It's kind of like how you get pneumonia from the flu: the virus will survive for weeks on end before your body starts producing antibodies.
