<a href="https://colab.research.google.com/github/edmarRod/IA025A_2022S1/blob/main/IIRC_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Library Installs

In [None]:
# Install the third-party packages the notebook imports (Lightning, Transformers,
# SentencePiece for the T5 tokenizer, the Neptune client and NVML bindings).
# NOTE(review): versions are unpinned, so a re-run may pull newer, incompatible
# releases. The captured output shows pytorch-lightning 1.6.5 was resolved; consider
# pinning (e.g. `%pip install -q pytorch-lightning==1.6.5`) and using %pip so the
# packages land in the kernel's own environment.
! pip install pytorch-lightning
! pip install transformers
! pip install sentencepiece
! pip install neptune-client
! pip install nvidia-ml-py3

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.6.5-py3-none-any.whl (585 kB)
[K     |████████████████████████████████| 585 kB 5.0 MB/s 
[?25hCollecting PyYAML>=5.4
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 43.2 MB/s 
Collecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2022.5.0-py3-none-any.whl (140 kB)
[K     |████████████████████████████████| 140 kB 59.0 MB/s 
Collecting torchmetrics>=0.4.1
  Downloading torchmetrics-0.9.2-py3-none-any.whl (419 kB)
[K     |████████████████████████████████| 419 kB 55.7 MB/s 
[?25hCollecting pyDeprecate>=0.3.1
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.

# Imports

In [None]:
import subprocess
import random
import collections
import re
import string
import json
import numpy as np
from typing import Dict, Tuple, Union, List, Optional

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl


# Metric Functions

In [None]:
def normalize_answer(s):
  """Normalize an answer string before comparing it.

  Steps, in order: lower-case, strip every ASCII punctuation character,
  replace the articles "a", "an" and "the" with a space, and collapse runs
  of whitespace into single spaces (trimming both ends).
  """
  lowered = s.lower()
  punctuation = set(string.punctuation)
  without_punct = ''.join(ch for ch in lowered if ch not in punctuation)
  without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct, flags=re.UNICODE)
  return ' '.join(without_articles.split())

def get_tokens(s):
  """Return the whitespace tokens of the normalized answer; falsy input yields []."""
  return normalize_answer(s).split() if s else []

def compute_exact(a_gold, a_pred):
  """Bag-of-words exact match between a gold and a predicted answer.

  Unlike the SQuAD metric, which compares the normalized strings directly
  (``int(normalize_answer(a_gold) == normalize_answer(a_pred))``), this version
  compares the *sets* of normalized tokens. Token order and repeated tokens
  therefore do not matter.

  Returns:
    1 if both answers normalize to the same token set, otherwise 0.
  """
  gold_tokens = set(normalize_answer(a_gold).split())
  pred_tokens = set(normalize_answer(a_pred).split())
  # Equal sets necessarily have equal sizes, so a single comparison suffices.
  return int(gold_tokens == pred_tokens)

def compute_f1(a_gold, a_pred):
  """Token-level F1 between a gold and a predicted answer (SQuAD style).

  Returns:
    1 or 0 when either side has no tokens (1 only if both are empty),
    0 when no token overlaps, otherwise the harmonic mean of precision
    and recall over the multiset of shared tokens.
  """
  gold_toks = get_tokens(a_gold)
  pred_toks = get_tokens(a_pred)
  if not gold_toks or not pred_toks:
    # No-answer case: the score is 1 only if both sides are empty.
    return int(gold_toks == pred_toks)
  shared = collections.Counter(gold_toks) & collections.Counter(pred_toks)
  overlap = sum(shared.values())
  if overlap == 0:
    return 0
  precision = overlap / len(pred_toks)
  recall = overlap / len(gold_toks)
  return (2 * precision * recall) / (precision + recall)
  
def compute_f1_score(predicted_list: Union[list,str], target_list: Union[list,str]) -> float:
    """
    Mean token-level F1 between paired predictions and targets.

    Each argument may be a single string, a list or a tuple. A non-list value
    is treated as a one-element list. Pairs are formed with ``zip``, so any
    surplus elements in the longer sequence are ignored.

    :param predicted_list: predicted answer(s)
    :param target_list: gold answer(s)
    :return: mean of compute_f1(target, predicted) over the pairs
    """
    def as_list(value):
        if isinstance(value, tuple):
            return list(value)
        return value if isinstance(value, list) else [value]

    pairs = zip(as_list(predicted_list), as_list(target_list))
    return np.array([compute_f1(target, predicted) for predicted, target in pairs]).mean()

def compute_em_score(predicted_list: Union[list,str], target_list: Union[list,str]) -> float:
    """
    Mean exact-match score between paired predictions and targets.

    Exact match is computed per pair with compute_exact, i.e. on the set of
    normalized tokens. Each argument may be a single string, a list or a
    tuple; a non-list value is treated as a one-element list. Pairs are
    formed with ``zip``, so surplus elements in the longer sequence are
    ignored. An empty input yields ``nan``, the numpy mean of an empty array.

    :param predicted_list: predicted answer(s)
    :param target_list: gold answer(s)
    :return: exact-match score in [0, 1]
    """
    def as_list(value):
        if isinstance(value, tuple):
            return list(value)
        return value if isinstance(value, list) else [value]

    pairs = zip(as_list(predicted_list), as_list(target_list))
    return np.array([compute_exact(target, predicted) for predicted, target in pairs]).mean()

# Util Functions

In [None]:
def retrieve_article(title: str, context_articles: Dict) -> Optional[str]:
  """
    Look up an article by title and strip its HTML tags.

    The exact title is tried first, then its lower-cased form.
  Args:
    title: Title of the article.
    context_articles: Mapping from article title to raw article text, which
      may contain HTML tags (e.g. links).
  Returns:
    The article text with every ``<...>`` tag removed, or ``None`` when
    neither ``title`` nor ``title.lower()`` is a key.
  """
  key = title if title in context_articles else title.lower()
  article = context_articles[key] if key in context_articles else None
  if article is None:
    return None
  # Remove html tags (mainly links)
  return re.sub("<[^>]*>", "", article)

def split_article(article: str, context: List, tokenizer, window_size: int, window_stride: int, device: str = "cuda"):
  """
  Split an article into overlapping token windows and flag the windows that
  fully contain a gold span.

  Args:
      article: Main article.
      context: Gold spans from the article; each has an ``"indices"`` entry
          holding its character (start, end).
      tokenizer: Transformers tokenizer.
      window_size: Number of tokens per window; shorter windows are padded
          with ``tokenizer.pad_token_id``.
      window_stride: Number of tokens to advance the window each step (must be > 0).
      device: Torch device for the returned tensors.

  Returns:
      splits: (num_windows, window_size) long tensor of token ids.
      labels: (num_windows,) float tensor, 1 where the window contains every
          token of at least one gold span, else 0.
      lengths: number of real (non-padding) tokens in each window.
  """
  # Descending order so get_token_indices can insert its markers from the end
  # of the string without shifting the character offsets of earlier spans.
  gold_spans = sorted((span["indices"] for span in context), reverse=True)
  article_tokens, token_spans = get_token_indices(article, gold_spans, tokenizer)

  window_starts = range(0, len(article_tokens), window_stride)
  splits = torch.full((len(window_starts), window_size), tokenizer.pad_token_id,
                      dtype=torch.long, device=device)
  labels = torch.zeros(len(window_starts), dtype=torch.float, device=device)
  lengths = []
  for split_index, window_start in enumerate(window_starts):
      window_end = window_start + window_size  # exclusive
      split = article_tokens[window_start:window_end]
      splits[split_index, :len(split)] = torch.tensor(split, dtype=torch.long, device=device)
      lengths.append(len(split))
      # Span ends from get_token_indices are inclusive, so the span's last token
      # must lie strictly before the exclusive window end.
      if any(window_start <= start and end < window_end for start, end in token_spans):
          labels[split_index] = 1
  return splits, labels, lengths

def get_token_indices(passage: str, spans: List, tokenizer, overall_offset: int = 0, max_length: int = -1) -> Tuple[
  List[int], List[Tuple[int, int]]]:
  """
      Tokenize ``passage`` and locate the given character spans (e.g. links) in token space.

      Each span is wrapped in ``tokenizer.sep_token`` markers before tokenizing. The marker
      tokens are then located, converted to indices (shifted by ``overall_offset``) and
      removed from the returned tokens. To retrieve the tokens of a span, use
      context_tokens[start-overall_offset:end-overall_offset+1]; ``end`` is inclusive.

      Markers are inserted in the order given, and each insertion shifts every later
      character offset. ``spans`` must therefore be non-overlapping and ordered from last
      to first; split_article passes them sorted in descending order.
  Args:
      passage: Main context text.
      spans: Character ``(start, end)`` pairs; ``passage[start:end]`` is the wrapped text.
      tokenizer: Transformers tokenizer. Markers are detected by comparing each id with
          ``tokenizer.sep_token_id``. NOTE(review): this assumes the sep token text encodes
          to that single id; if it does not, no spans are found. Confirm for the tokenizer in use.
      overall_offset: Added to every returned index (usually due to the concatenation of the question).
      max_length: If > 0, only the first ``max_length`` tokens are scanned for markers. Markers
          after that point, including the closing marker of a span that starts inside it, stay
          in the returned tokens. The returned tokens are not truncated.

  Returns:
      Tuple: context_tokens: Token ids (no special tokens added) with the found markers removed.
          token_spans: List of (start, end) token indices, both inclusive and offset by
          ``overall_offset``, one per complete span found. When ``spans`` is empty it is
          returned unchanged.
  """
  if not spans:
      # This would happen anyways, but saves iterating through the tokens
      return tokenizer.encode(passage, add_special_tokens=False), spans
  # adds the separator token to the passage to indicate links e.g. "The group was occasionally diverted from strategic missions to carry out [SEP]air support[SEP] and [SEP]interdiction[SEP] missions."
  # Later spans must come first so earlier character offsets are still valid.
  for start, end in spans:
      passage = passage[:start] + tokenizer.sep_token + passage[start:end] + tokenizer.sep_token + passage[end:]

  context_tokens = tokenizer.encode(passage, add_special_tokens=False)

  # Remove added sep tokens to get link indices
  token_spans = []
  offset = 0  # number of marker tokens seen so far
  span_start = -1  # -1 means "not inside a span"
  length = min(len(context_tokens), max_length) if max_length > 0 else len(context_tokens)
  for i in range(length):
      token = context_tokens[i]
      if token == tokenizer.sep_token_id:
          if span_start > -1:
              # Closing marker: after counting it, i - offset is the marker-free index
              # of the span's last token (the token at i - 1 in the marked sequence).
              offset += 1
              token_spans.append((span_start, i - offset + overall_offset))
              span_start = -1
          else:
              # Opening marker: once markers are removed, the span's first token
              # (currently at i + 1) will sit at index i - offset.
              span_start = i - offset + overall_offset
              offset += 1
  # removes separator tokens from context_tokens
  # token_spans are in marker-free coordinates. Because spans are handled in order, the
  # markers already popped exactly account for the shift. The opening marker therefore
  # sits at `start`, and after popping it the closing marker sits at `end + 1`.
  for start, end in token_spans:
      start -= overall_offset
      end -= overall_offset
      assert context_tokens.pop(start) == tokenizer.sep_token_id
      assert context_tokens.pop(end + 1) == tokenizer.sep_token_id

  return context_tokens, token_spans

In [None]:
def get_model(model_name: str):
  """
  Load a pretrained model and its tokenizer from the Hugging Face hub.

  Args:
    model_name: Hub identifier. Names containing 't5' load a
      ``T5ForConditionalGeneration`` with a ``T5Tokenizer``. Names containing
      'bert-' load a PyTorch ``BertForPreTraining`` with a ``BertTokenizer``.

  Returns:
    (model, tokenizer)

  Raises:
    NotImplementedError: for any other model family.
  """
  if 't5' in model_name:
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = T5Tokenizer.from_pretrained(model_name)
  elif 'bert-' in model_name:
    # The notebook trains with PyTorch / PyTorch Lightning, so load the PyTorch
    # class. The TensorFlow TFBertForPreTraining used previously cannot be
    # trained by Lightning and also required TensorFlow to be installed.
    from transformers import BertTokenizer, BertForPreTraining
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForPreTraining.from_pretrained(model_name)
  else:
    raise NotImplementedError(
        f"Unsupported model family: {model_name!r} (expected a 't5' or 'bert-' model)")

  return model, tokenizer

In [None]:
def get_trained_model(model_name: str, source_max_length: Optional[int] = None):
  """
  Load a fine-tuned checkpoint stored in Neptune together with its tokenizer.

  Only ``'gold'`` is supported. It downloads the
  ``epoch=03-val/loss=0.00`` checkpoint of run ``MES-21`` (project
  ``e166690/Mestrado``) to ``model.ckpt`` and loads it into a freshly
  configured ``t5-base`` architecture.

  The Neptune API token is read from the ``NEPTUNE_API_TOKEN`` environment
  variable instead of being hard-coded.
  NOTE(review): the token previously committed in this notebook is exposed
  and should be revoked.

  Args:
    model_name: Identifier of the trained model; currently only ``'gold'``.
    source_max_length: ``model_max_length`` for the tokenizer. When omitted,
      the notebook-level ``source_max_length`` variable is used (what the
      original code read implicitly), falling back to 512 if it is undefined.

  Returns:
    (model, tokenizer): the model in eval mode with weights loaded on CPU, and a
    ``T5Tokenizer`` whose ``sep_token`` is ``'[SEP]'``.

  Raises:
    NotImplementedError: for any ``model_name`` other than ``'gold'``.
  """
  # The original only constructed NotImplementedError without raising it, which
  # then failed with an UnboundLocalError at the return statement.
  if model_name != 'gold':
    raise NotImplementedError(f"No trained model registered under {model_name!r}; only 'gold' is available")

  import os
  import neptune.new as neptune
  import torch
  from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoConfig

  if source_max_length is None:
    source_max_length = globals().get('source_max_length', 512)

  base_model_name = "t5-base"
  champion_run = neptune.init(
    api_token=os.getenv("NEPTUNE_API_TOKEN"),
    project="e166690/Mestrado",
    run="MES-21",
    mode='read-only'
  )
  try:
    champion_run['training/model/checkpoints/epoch=03-val/loss=0.00'].download('model.ckpt')
  finally:
    # Release the Neptune connection even if the download fails.
    champion_run.stop()

  config = AutoConfig.from_pretrained(base_model_name)
  model = T5ForConditionalGeneration(config)
  tokenizer = T5Tokenizer.from_pretrained(base_model_name, model_max_length=source_max_length)
  tokenizer.sep_token = '[SEP]'
  checkpoint = torch.load('model.ckpt', map_location=torch.device('cpu'))
  # Presumably the LightningModule held the T5 model as `self.model`, so every
  # key starts with "model."; strip that 6-character prefix.
  prefix = 'model.'
  state_dict = {k[len(prefix):]: v for k, v in checkpoint['state_dict'].items()}
  model.load_state_dict(state_dict)
  model.eval()

  return model, tokenizer

In [None]:
import nvidia_smi

print(f"Pytorch Lightning Version: {pl.__version__}")
nvidia_smi.nvmlInit()
# NVML handle for the first GPU (index 0); gpu_usage() reads it as a module global.
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
print(f"Device name: {nvidia_smi.nvmlDeviceGetName(handle)}")

def gpu_usage():
    """Return the current utilization of GPU 0 as a percentage string, e.g. '37%'."""
    utilization = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
    return f"{utilization.gpu}%"

Pytorch Lightning Version: 1.6.5
Device name: b'Tesla P100-PCIE-16GB'


# Datasets

In [None]:
class GoldDataset(Dataset):
  """
  IIRC dataset whose passage is built from the gold (annotated) context.

  Each item is ``question + '[SEP]' + passage`` paired with the answer text.
  Both are tokenized (truncated/padded) to ``max_context_size`` tokens, with every
  digit surrounded by spaces first. The passage is either the annotated span
  texts themselves (``no_window=True``) or one random token window around each
  span. Unanswerable questions can optionally receive text from the articles the
  question links to, as distractors.
  """
  def __init__(self, dataset, context_articles, tokenizer, max_context_size: int = 512, no_window = True, window_size_cnst=100, window_stride=25, unanswerable_noise=True):
    """
    Args:
      dataset: Parsed IIRC json: a list of passages, each with ``text`` and ``questions``.
      context_articles: Mapping from linked-article title to article text.
      tokenizer: Transformers tokenizer; its ``sep_token`` is set to '[SEP]'.
      max_context_size: Max number of tokens for both the input and the target.
      no_window: Use the annotated span texts verbatim instead of token windows.
      window_size_cnst: Largest window size, in tokens.
      window_stride: Stride, in tokens, for a window of ``window_size_cnst``
        tokens; smaller windows use a proportionally smaller stride.
      unanswerable_noise: Add linked-article text to questions whose answer type is 'none'.
    """
    self.tokenizer = tokenizer
    self.tokenizer.sep_token = '[SEP]'

    self.context_articles = context_articles

    self.max_context_size = max_context_size
    self.no_window = no_window
    self.window_size_cnst = window_size_cnst
    self.window_stride = window_stride
    self.unanswerable_noise = unanswerable_noise

    self.dataset = list(self.get_dataset(dataset).values())

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    """Return (input_ids, input_mask, target_ids, target_mask, input_str, target_str)."""
    example = self.dataset[idx]
    answer = example['answer']

    # Span and binary answers are stored as lists; multiple spans are joined.
    if isinstance(answer, list):
      answer = ' and '.join(answer)

    input_str = example['question'] + self.tokenizer.sep_token + example['passage']

    # Numbers tokenize inconsistently, so every digit is surrounded by spaces
    # (the extra spaces do not affect tokenization): '123 abc' -> ' 1  2  3  abc'
    input_str = re.sub(r'(\d)', r' \1 ', input_str)
    answer = re.sub(r'(\d)', r' \1 ', answer)

    input_enc = self._encode(input_str)
    target_enc = self._encode(answer)

    return (input_enc['input_ids'].squeeze(0), input_enc['attention_mask'].squeeze(0),
            target_enc['input_ids'].squeeze(0), target_enc['attention_mask'].squeeze(0),
            input_str, answer)

  def _encode(self, text):
    """Tokenize ``text`` as a batch of one, truncated/padded to ``max_context_size``."""
    return self.tokenizer(text, truncation=True, padding='max_length',
                          max_length=self.max_context_size, return_tensors='pt')

  @staticmethod
  def _format_answer(answer_info):
    """Convert an IIRC answer annotation (not of type 'bad') into answer text."""
    a_type = answer_info["type"]
    if a_type == "span":
      return [a["text"] for a in answer_info["answer_spans"]]
    if a_type == "value":
      return answer_info["answer_value"]
    if a_type == "binary":
      return [answer_info["answer_value"]]
    if a_type == "none":
      return "no answer"
    # Previously an unknown type silently reused the previous question's answer.
    raise ValueError(f"Unknown IIRC answer type: {a_type!r}")

  def _gold_window(self, text, span, window_size, window_stride):
    """Return [decoded random window fully containing ``span``], or [] if none does."""
    splits, labels, lengths = split_article(text, [span], self.tokenizer, window_size, window_stride)
    gold_splits = [i for (i, label) in enumerate(labels) if label == 1]
    if not gold_splits:
      return []
    chosen = random.choice(gold_splits)
    # Decode only the real tokens, not the padding that fills the window.
    return [self.tokenizer.decode(splits[chosen][:lengths[chosen]])]

  def _gold_context(self, passage, question):
    """Collect gold context strings: main-passage spans first, then linked-article spans."""
    main_spans = [span for span in question["context"] if span["passage"] == "main"]
    link_spans = [span for span in question["context"] if span["passage"] != "main"]

    if self.no_window:
      # Use the annotated spans as is, dropping duplicates but keeping the main
      # context first (a set would scramble the order non-deterministically).
      return list(dict.fromkeys(span["text"] for span in main_spans + link_spans))

    # Linked article titles, in first-appearance order.
    links = list(dict.fromkeys(span["passage"] for span in link_spans))
    # Reduce the window size in the case of many links so all contexts can fit in the input
    window_size = min(self.max_context_size // (len(links) + 1), self.window_size_cnst)
    # Scale the stride in proportion to the shrunken window (at least 1 token).
    # The stride is kept local so the configured value is not mutated per question.
    window_stride = max(1, int(self.window_stride * window_size / self.window_size_cnst))

    all_context = []
    for span in main_spans:
      all_context.extend(self._gold_window(passage["text"], span, window_size, window_stride))
    for link in links:
      article = retrieve_article(link, self.context_articles)
      if article is None:
        # The linked article is missing from context_articles; nothing to window over.
        continue
      for span in (s for s in link_spans if s["passage"] == link):
        all_context.extend(self._gold_window(article, span, window_size, window_stride))
    return all_context

  def _noise_context(self, question):
    """First window of every non-empty article linked from the question."""
    noise = []
    for target in question["question_links"]:
      article = retrieve_article(target, self.context_articles)
      if article is None or article.strip() == "":
        continue
      splits, _, lengths = split_article(article, [], self.tokenizer,
                                         self.window_size_cnst, self.window_stride)
      if not lengths:
        # The article produced no tokens, so there is no window to take.
        continue
      noise.append(self.tokenizer.decode(splits[0][:lengths[0]]))
    return noise

  def get_dataset(self, json_dataset):
    """Map qid -> {'passage', 'question', 'answer'}, skipping 'bad' answers."""
    dataset = {}
    for passage in json_dataset:
      for question in passage['questions']:
        answer_info = question["answer"]
        if answer_info["type"] == "bad":
          continue
        answer_text = self._format_answer(answer_info)

        all_context = self._gold_context(passage, question)
        if answer_info["type"] == "none" and self.unanswerable_noise:
          all_context.extend(self._noise_context(question))

        dataset[question['qid']] = {'passage': self.tokenizer.sep_token.join(all_context),
                                    'question': question['question'],
                                    'answer': answer_text}

    return dataset

In [None]:
class OnlyMainContextDataset(Dataset):
  """
  IIRC dataset whose passage is only the gold spans annotated in the main passage.

  Linked-article context is ignored. Each item is
  ``question + '[SEP]' + passage`` paired with the answer text. Both are tokenized
  (truncated/padded) to ``max_context_size`` tokens, with every digit surrounded
  by spaces first.
  """
  def __init__(self, dataset, context_articles, tokenizer, max_context_size: int = 512, unanswerable_noise=True):
    """
    Args:
      dataset: Parsed IIRC json: a list of passages, each with ``questions``.
      context_articles: Stored for parity with the other datasets; not used here.
      tokenizer: Transformers tokenizer; its ``sep_token`` is set to '[SEP]'.
      max_context_size: Max number of tokens for both the input and the target.
      unanswerable_noise: Stored for parity with the other datasets; not used here.
    """
    self.tokenizer = tokenizer
    self.tokenizer.sep_token = '[SEP]'

    self.context_articles = context_articles

    self.max_context_size = max_context_size
    self.unanswerable_noise = unanswerable_noise

    self.dataset = list(self.get_dataset(dataset).values())

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    """Return (input_ids, input_mask, target_ids, target_mask, input_str, target_str)."""
    example = self.dataset[idx]
    answer = example['answer']

    # Span and binary answers are stored as lists; multiple spans are joined.
    if isinstance(answer, list):
      answer = ' and '.join(answer)

    input_str = example['question'] + self.tokenizer.sep_token + example['passage']

    # Numbers tokenize inconsistently, so every digit is surrounded by spaces
    # (the extra spaces do not affect tokenization): '123 abc' -> ' 1  2  3  abc'
    input_str = re.sub(r'(\d)', r' \1 ', input_str)
    answer = re.sub(r'(\d)', r' \1 ', answer)

    input_enc = self._encode(input_str)
    target_enc = self._encode(answer)

    return (input_enc['input_ids'].squeeze(0), input_enc['attention_mask'].squeeze(0),
            target_enc['input_ids'].squeeze(0), target_enc['attention_mask'].squeeze(0),
            input_str, answer)

  def _encode(self, text):
    """Tokenize ``text`` as a batch of one, truncated/padded to ``max_context_size``."""
    return self.tokenizer(text, truncation=True, padding='max_length',
                          max_length=self.max_context_size, return_tensors='pt')

  @staticmethod
  def _format_answer(answer_info):
    """Convert an IIRC answer annotation (not of type 'bad') into answer text."""
    a_type = answer_info["type"]
    if a_type == "span":
      return [a["text"] for a in answer_info["answer_spans"]]
    if a_type == "value":
      return answer_info["answer_value"]
    if a_type == "binary":
      return [answer_info["answer_value"]]
    if a_type == "none":
      return "no answer"
    # Previously an unknown type silently reused the previous question's answer.
    raise ValueError(f"Unknown IIRC answer type: {a_type!r}")

  def get_dataset(self, json_dataset):
    """Map qid -> {'passage', 'question', 'answer'}, skipping 'bad' answers."""
    dataset = {}
    for passage in json_dataset:
      for question in passage['questions']:
        answer_info = question["answer"]
        if answer_info["type"] == "bad":
          continue
        # Use the annotated main-passage spans as is, in annotation order.
        main_context = [span["text"] for span in question["context"] if span["passage"] == "main"]

        dataset[question['qid']] = {'passage': self.tokenizer.sep_token.join(main_context),
                                    'question': question['question'],
                                    'answer': self._format_answer(answer_info)}

    return dataset

In [None]:
class PollutedDataset(Dataset):
  """
  IIRC dataset whose gold context is mixed with one retrieved "pollution" document.

  For each question, the top-1 retrieved document (``pollution[question]['top1']['contents']``)
  is appended to the gold context and all documents are shuffled. The input is
  ``"Question: <q> Context document 0: ... Context document 1: ..."``. Digits
  are not split in this variant.
  """
  def __init__(self, dataset, pollution, context_articles, tokenizer, max_context_size: int = 512, no_window = True, window_size_cnst=100, window_stride=25, unanswerable_noise=False):
    """
    Args:
      dataset: Parsed IIRC json: a list of passages, each with ``text`` and ``questions``.
      pollution: Mapping from question text to retrieval results; only
        ``[question]['top1']['contents']`` is read.
      context_articles: Mapping from linked-article title to article text.
      tokenizer: Transformers tokenizer; its ``sep_token`` is set to '[SEP]'.
      max_context_size: Max number of tokens for both the input and the target.
      no_window: Use the annotated span texts verbatim instead of token windows.
      window_size_cnst: Largest window size, in tokens.
      window_stride: Stride, in tokens, for a window of ``window_size_cnst``
        tokens; smaller windows use a proportionally smaller stride.
      unanswerable_noise: Add linked-article text to questions whose answer type is 'none'.
    """
    self.tokenizer = tokenizer
    self.tokenizer.sep_token = '[SEP]'

    self.context_articles = context_articles

    self.max_context_size = max_context_size
    self.no_window = no_window
    self.window_size_cnst = window_size_cnst
    self.window_stride = window_stride
    self.unanswerable_noise = unanswerable_noise

    self.pollution = pollution
    self.dataset = list(self.get_dataset(dataset, pollution).values())

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    """Return (input_ids, input_mask, target_ids, target_mask, input_str, target_str)."""
    example = self.dataset[idx]
    answer = example['answer']

    # Span and binary answers are stored as lists; multiple spans are joined.
    if isinstance(answer, list):
      answer = ' and '.join(answer)

    input_str = "Question: " + example['question'] + ' ' + example['passage']

    input_enc = self._encode(input_str)
    target_enc = self._encode(answer)

    return (input_enc['input_ids'].squeeze(0), input_enc['attention_mask'].squeeze(0),
            target_enc['input_ids'].squeeze(0), target_enc['attention_mask'].squeeze(0),
            input_str, answer)

  def _encode(self, text):
    """Tokenize ``text`` as a batch of one, truncated/padded to ``max_context_size``."""
    return self.tokenizer(text, truncation=True, padding='max_length',
                          max_length=self.max_context_size, return_tensors='pt')

  @staticmethod
  def _format_answer(answer_info):
    """Convert an IIRC answer annotation (not of type 'bad') into answer text."""
    a_type = answer_info["type"]
    if a_type == "span":
      return [a["text"] for a in answer_info["answer_spans"]]
    if a_type == "value":
      return answer_info["answer_value"]
    if a_type == "binary":
      return [answer_info["answer_value"]]
    if a_type == "none":
      return "no answer"
    # Previously an unknown type silently reused the previous question's answer.
    raise ValueError(f"Unknown IIRC answer type: {a_type!r}")

  def _gold_window(self, text, span, window_size, window_stride):
    """Return [decoded random window fully containing ``span``], or [] if none does."""
    splits, labels, lengths = split_article(text, [span], self.tokenizer, window_size, window_stride)
    gold_splits = [i for (i, label) in enumerate(labels) if label == 1]
    if not gold_splits:
      return []
    chosen = random.choice(gold_splits)
    # Decode only the real tokens, not the padding that fills the window.
    return [self.tokenizer.decode(splits[chosen][:lengths[chosen]])]

  def _gold_context(self, passage, question):
    """Collect gold context strings: main-passage spans first, then linked-article spans."""
    main_spans = [span for span in question["context"] if span["passage"] == "main"]
    link_spans = [span for span in question["context"] if span["passage"] != "main"]

    if self.no_window:
      # Use the annotated spans as is, dropping duplicates in a deterministic order.
      return list(dict.fromkeys(span["text"] for span in main_spans + link_spans))

    # Linked article titles, in first-appearance order.
    links = list(dict.fromkeys(span["passage"] for span in link_spans))
    # Reduce the window size in the case of many links so all contexts can fit in the input
    window_size = min(self.max_context_size // (len(links) + 1), self.window_size_cnst)
    # Scale the stride in proportion to the shrunken window (at least 1 token).
    # The stride is kept local so the configured value is not mutated per question.
    window_stride = max(1, int(self.window_stride * window_size / self.window_size_cnst))

    all_context = []
    for span in main_spans:
      all_context.extend(self._gold_window(passage["text"], span, window_size, window_stride))
    for link in links:
      article = retrieve_article(link, self.context_articles)
      if article is None:
        # The linked article is missing from context_articles; nothing to window over.
        continue
      for span in (s for s in link_spans if s["passage"] == link):
        all_context.extend(self._gold_window(article, span, window_size, window_stride))
    return all_context

  def _noise_context(self, question):
    """First window of every non-empty article linked from the question."""
    noise = []
    for target in question["question_links"]:
      article = retrieve_article(target, self.context_articles)
      if article is None or article.strip() == "":
        continue
      splits, _, lengths = split_article(article, [], self.tokenizer,
                                         self.window_size_cnst, self.window_stride)
      if not lengths:
        # The article produced no tokens, so there is no window to take.
        continue
      noise.append(self.tokenizer.decode(splits[0][:lengths[0]]))
    return noise

  def get_dataset(self, json_dataset, pollution):
    """Map qid -> {'passage', 'question', 'answer'}, skipping 'bad' answers."""
    dataset = {}
    for passage in json_dataset:
      for question in passage['questions']:
        answer_info = question["answer"]
        if answer_info["type"] == "bad":
          continue
        answer_text = self._format_answer(answer_info)

        all_context = self._gold_context(passage, question)
        if answer_info["type"] == "none" and self.unanswerable_noise:
          all_context.extend(self._noise_context(question))

        # Add the top retrieved (possibly irrelevant) document, then hide its position.
        all_context.append(pollution[question['question']]['top1']['contents'])
        random.shuffle(all_context)
        passage_text = ' '.join(f'Context document {i}: ' + context for i, context in enumerate(all_context))

        dataset[question['qid']] = {'passage': passage_text,
                                    'question': question['question'],
                                    'answer': answer_text}

    return dataset

In [None]:
import re
class ExplanationDataset(Dataset):
  """
  IIRC dataset for generating an explanation followed by the answer.

  Training mode (``explanation_dataset`` given): items come from parsed
  few-shot prompts. The input is the question/documents text of each prompt's
  "Example 5". The target is "Explanation: <explanation>", a blank line, then
  "Answer: <answer>".

  Evaluation mode (``explanation_dataset`` is None): items come from the IIRC
  json. The input is "Question: <q>" plus a blank line, followed by the shuffled
  gold documents; the target is the plain answer.

  In both modes, target padding tokens are replaced by -100, the label value
  Hugging Face models ignore when computing the loss.
  """
  def __init__(self, dataset, explanation_dataset, context_articles, tokenizer, max_context_size: int = 512, no_window = True, window_size_cnst=100, window_stride=25, unanswerable_noise=False):
    """
    Args:
      dataset: Parsed IIRC json: a list of passages, each with ``text`` and ``questions``.
      explanation_dataset: Dict with parallel ``prompts`` and ``explanations``
        lists, or None for evaluation mode.
      context_articles: Mapping from linked-article title to article text.
      tokenizer: Transformers tokenizer; its ``sep_token`` is set to '[SEP]'.
      max_context_size: Max number of tokens for both the input and the target.
      no_window: Use the annotated span texts verbatim instead of token windows.
      window_size_cnst: Largest window size, in tokens.
      window_stride: Stride, in tokens, for a window of ``window_size_cnst``
        tokens; smaller windows use a proportionally smaller stride.
      unanswerable_noise: Add linked-article text to questions whose answer type is 'none'.
    """
    self.tokenizer = tokenizer
    self.tokenizer.sep_token = '[SEP]'

    self.context_articles = context_articles

    self.max_context_size = max_context_size
    self.no_window = no_window
    self.window_size_cnst = window_size_cnst
    self.window_stride = window_stride
    self.unanswerable_noise = unanswerable_noise

    self.dataset = list(self.get_dataset(dataset).values())
    self.training = explanation_dataset is not None
    self.explanation_dataset = self.parse_explanation(explanation_dataset) if self.training else None

  def __len__(self):
    return len(self.explanation_dataset) if self.training else len(self.dataset)

  def __getitem__(self, idx):
    """Return (input_ids, input_mask, target_ids, target_mask, input_str, target_str)."""
    if self.training:
      question_docs, answer, explanation = self.explanation_dataset[idx]
      input_str = question_docs
      output_str = f"Explanation: {explanation.strip()}\n\nAnswer: {answer.strip()}"
    else:
      example = self.dataset[idx]
      answer = example['answer']
      # Span and binary answers are stored as lists; multiple spans are joined.
      if isinstance(answer, list):
        answer = ' and '.join(answer)
      input_str = "Question: " + example['question'] + '\n\n' + example['passage']
      output_str = answer

    input_enc = self._encode(input_str)
    output_enc = self._encode(output_str)

    input_tokens_ids = input_enc['input_ids'].squeeze(0)
    output_tokens_ids = output_enc['input_ids'].squeeze(0)
    # Mask target padding out of the loss. Use the tokenizer's pad id instead of a
    # hard-coded 0, keeping 0 as the fallback when the tokenizer defines none.
    pad_id = 0 if self.tokenizer.pad_token_id is None else self.tokenizer.pad_token_id
    output_tokens_ids[output_tokens_ids == pad_id] = -100

    input_mask = input_enc['attention_mask'].squeeze(0)
    output_mask = output_enc['attention_mask'].squeeze(0)

    return (input_tokens_ids, input_mask, output_tokens_ids, output_mask, input_str, output_str)

  def _encode(self, text):
    """Tokenize ``text`` as a batch of one, truncated/padded to ``max_context_size``."""
    return self.tokenizer(text, truncation=True, padding='max_length',
                          max_length=self.max_context_size, return_tensors='pt')

  def parse_explanation(self, explanation):
    """
    Extract (question_docs, answer, explanation) triples from few-shot prompts.

    The text between "Example 5:" and the final "Answer:" becomes the question
    and documents. The text between that "Answer:" and a blank line followed by
    "Explanation:" becomes the answer.

    Raises:
      ValueError: if a prompt does not follow that layout.
    """
    prompts = explanation['prompts']
    explanations = explanation['explanations']
    explanations_list = []

    for i, prompt in enumerate(prompts):
      match = re.search('Example 5:\n\n(.+)Answer:(.+)\n\nExplanation:', prompt, re.S)
      if match is None:
        raise ValueError(f"Prompt {i} does not contain an 'Example 5' block with 'Answer:' and 'Explanation:'")
      question_docs, answer = match.groups()
      explanations_list.append((question_docs, answer, explanations[i]))

    return explanations_list

  @staticmethod
  def _format_answer(answer_info):
    """Convert an IIRC answer annotation (not of type 'bad') into answer text."""
    a_type = answer_info["type"]
    if a_type == "span":
      return [a["text"] for a in answer_info["answer_spans"]]
    if a_type == "value":
      return answer_info["answer_value"]
    if a_type == "binary":
      return [answer_info["answer_value"]]
    if a_type == "none":
      return "no answer"
    # Previously an unknown type silently reused the previous question's answer.
    raise ValueError(f"Unknown IIRC answer type: {a_type!r}")

  def _gold_window(self, text, span, window_size, window_stride):
    """Return [decoded random window fully containing ``span``], or [] if none does."""
    splits, labels, lengths = split_article(text, [span], self.tokenizer, window_size, window_stride)
    gold_splits = [i for (i, label) in enumerate(labels) if label == 1]
    if not gold_splits:
      return []
    chosen = random.choice(gold_splits)
    # Decode only the real tokens, not the padding that fills the window.
    return [self.tokenizer.decode(splits[chosen][:lengths[chosen]])]

  def _gold_context(self, passage, question):
    """Collect gold context strings: main-passage spans first, then linked-article spans."""
    main_spans = [span for span in question["context"] if span["passage"] == "main"]
    link_spans = [span for span in question["context"] if span["passage"] != "main"]

    if self.no_window:
      # Use the annotated spans as is, dropping duplicates in a deterministic order.
      return list(dict.fromkeys(span["text"] for span in main_spans + link_spans))

    # Linked article titles, in first-appearance order.
    links = list(dict.fromkeys(span["passage"] for span in link_spans))
    # Reduce the window size in the case of many links so all contexts can fit in the input
    window_size = min(self.max_context_size // (len(links) + 1), self.window_size_cnst)
    # Scale the stride in proportion to the shrunken window (at least 1 token).
    # The stride is kept local so the configured value is not mutated per question.
    window_stride = max(1, int(self.window_stride * window_size / self.window_size_cnst))

    all_context = []
    for span in main_spans:
      all_context.extend(self._gold_window(passage["text"], span, window_size, window_stride))
    for link in links:
      article = retrieve_article(link, self.context_articles)
      if article is None:
        # The linked article is missing from context_articles; nothing to window over.
        continue
      for span in (s for s in link_spans if s["passage"] == link):
        all_context.extend(self._gold_window(article, span, window_size, window_stride))
    return all_context

  def _noise_context(self, question):
    """First window of every non-empty article linked from the question."""
    noise = []
    for target in question["question_links"]:
      article = retrieve_article(target, self.context_articles)
      if article is None or article.strip() == "":
        continue
      splits, _, lengths = split_article(article, [], self.tokenizer,
                                         self.window_size_cnst, self.window_stride)
      if not lengths:
        # The article produced no tokens, so there is no window to take.
        continue
      noise.append(self.tokenizer.decode(splits[0][:lengths[0]]))
    return noise

  def get_dataset(self, json_dataset):
    """Map qid -> {'passage', 'question', 'answer'}, skipping 'bad' answers."""
    dataset = {}
    for passage in json_dataset:
      for question in passage['questions']:
        answer_info = question["answer"]
        if answer_info["type"] == "bad":
          continue
        answer_text = self._format_answer(answer_info)

        all_context = self._gold_context(passage, question)
        if answer_info["type"] == "none" and self.unanswerable_noise:
          all_context.extend(self._noise_context(question))

        random.shuffle(all_context)
        passage_text = ''.join(f'Document {i}: ' + context + '\n\n' for i, context in enumerate(all_context))

        dataset[question['qid']] = {'passage': passage_text,
                                    'question': question['question'],
                                    'answer': answer_text}

    return dataset

In [None]:
# subprocess.run("wget -nc https://iirc-dataset.s3.us-west-2.amazonaws.com/iirc_train_dev.tgz".split())
# #untar it
# subprocess.run("tar -zxvf iirc_train_dev.tgz".split())
# # subprocess.run("rm iirc_train_dev.tar.gz".split())

# # subprocess.run("cd iirc_train_dev".split())

# #get context articles
# subprocess.run("wget -nc https://iirc-dataset.s3.us-west-2.amazonaws.com/context_articles.tar.gz".split())
# subprocess.run("tar -xvzf context_articles.tar.gz".split())

# with open('iirc_train_dev/train.json', 'r') as f:
#   train_json = json.load(f)
# with open('context_articles.json', 'r') as f:
#   context_articles = json.load(f)

# from google.colab import drive
# drive.mount('/content/drive')
# with open('/content/drive/MyDrive/Mestrado/gpt_davinci_few_shot_explain_dict.json', 'r') as f:
#   train_explanation = json.load(f)

# _, tokenizer = get_model('t5-base')

# ds = ExplanationDataset(train_json[:150], train_explanation, context_articles, tokenizer, 512, True)
# ds = ExplanationDataset(train_json[:150], None, context_articles, tokenizer, 512, True)

In [None]:
# input_tokens_ids, input_mask,  output_tokens_ids, output_mask, model_input, model_target = ds[40]
# print(model_input)
# print(model_target)#.split('Document')

# Data Module

In [None]:
class IIRCDataModule(pl.LightningDataModule):
  """LightningDataModule that downloads IIRC and builds train/val/test datasets.

  The dataset class is selected by ``dataset_type``:
    * ``"gold"``         -> GoldDataset
    * ``"main_context"`` -> OnlyMainContextDataset
    * ``"polluted"``     -> PollutedDataset (reads BM25 pollution files from Google Drive)
    * ``"explanation"``  -> ExplanationDataset (reads GPT explanations from Google Drive)

  Validation and test are both built from the IIRC dev split.

  Args:
    model_name: name passed to ``get_model``; only the returned tokenizer is kept.
    batch_size: batch size used by every dataloader.
    dataset_type: one of the values listed above.
    max_context_size: forwarded to the dataset classes.
    no_window, window_size_cnst, window_stride: sliding-window options forwarded to
      the dataset classes that take them (not passed to OnlyMainContextDataset).
    unanswerable_noise: forwarded to the dataset classes.
    shuffle_train: reshuffle the training set every epoch (default True).
  """

  def __init__(self, model_name, batch_size: int = 8, dataset_type: str = 'gold', max_context_size: int = 512, no_window: bool = True, window_size_cnst: int = 100, window_stride:int = 25, unanswerable_noise: bool = True, shuffle_train: bool = True):
    super().__init__()

    self.batch_size = batch_size
    self.dataset_type = dataset_type
    # Only the tokenizer is needed here; the model returned by get_model is discarded.
    _, self.tokenizer = get_model(model_name)

    self.max_context_size = max_context_size
    self.no_window = no_window
    self.window_size_cnst = window_size_cnst
    self.window_stride = window_stride
    self.unanswerable_noise = unanswerable_noise
    self.shuffle_train = shuffle_train

    self.save_hyperparameters()

  def prepare_data(self):
    """Download and extract the IIRC train/dev split and the context articles."""
    # `wget -nc` does not re-download files that already exist locally.
    subprocess.run("wget -nc https://iirc-dataset.s3.us-west-2.amazonaws.com/iirc_train_dev.tgz".split())
    subprocess.run("tar -zxvf iirc_train_dev.tgz".split())
    subprocess.run("wget -nc https://iirc-dataset.s3.us-west-2.amazonaws.com/context_articles.tar.gz".split())
    subprocess.run("tar -xvzf context_articles.tar.gz".split())

  @staticmethod
  def _mount_drive():
    """Mount Google Drive at /content/drive (Colab only)."""
    from google.colab import drive
    drive.mount('/content/drive')

  def setup(self, stage: Optional[str] = None):
    """Load the IIRC JSON files and build the train/val/test datasets.

    ``stage`` is ignored: all three datasets are built on every call.

    Raises:
      NotImplementedError: if ``dataset_type`` is not a supported value.
    """
    with open('context_articles.json', 'r') as f:
      context_articles = json.load(f)
    with open('iirc_train_dev/train.json', 'r') as f:
      train_json = json.load(f)
    with open('iirc_train_dev/dev.json', 'r') as f:
      dev_json = json.load(f)

    # Keyword arguments shared by every keyword-constructed dataset class.
    shared_kwargs = dict(tokenizer=self.tokenizer,
                         max_context_size=self.max_context_size,
                         unanswerable_noise=self.unanswerable_noise)
    # PollutedDataset / ExplanationDataset additionally take the context and window options.
    windowed_kwargs = dict(shared_kwargs,
                           context_articles=context_articles,
                           no_window=self.no_window,
                           window_size_cnst=self.window_size_cnst,
                           window_stride=self.window_stride)

    if self.dataset_type == "gold":
      # GoldDataset is called positionally, in exactly this argument order.
      gold_args = (context_articles, self.tokenizer, self.max_context_size, self.no_window,
                   self.window_size_cnst, self.window_stride, self.unanswerable_noise)
      # NOTE(review): only the first 100 training questions are used -- this looks like
      # a debugging subset; confirm before reporting results.
      self.train_dataset = GoldDataset(train_json[:100], *gold_args)
      self.val_dataset = GoldDataset(dev_json, *gold_args)
      self.test_dataset = GoldDataset(dev_json, *gold_args)
    elif self.dataset_type == "main_context":
      self.train_dataset = OnlyMainContextDataset(train_json, context_articles, **shared_kwargs)
      self.val_dataset = OnlyMainContextDataset(dev_json, context_articles, **shared_kwargs)
      self.test_dataset = OnlyMainContextDataset(dev_json, context_articles, **shared_kwargs)
    elif self.dataset_type == "polluted":
      self._mount_drive()
      with open('/content/drive/MyDrive/Mestrado/dev_bm25_top1.json', 'r') as f:
        dev_pollution = json.load(f)
      with open('/content/drive/MyDrive/Mestrado/train_bm25_top1.json', 'r') as f:
        train_pollution = json.load(f)

      self.train_dataset = PollutedDataset(dataset=train_json, pollution=train_pollution, **windowed_kwargs)
      self.val_dataset = PollutedDataset(dataset=dev_json, pollution=dev_pollution, **windowed_kwargs)
      self.test_dataset = PollutedDataset(dataset=dev_json, pollution=dev_pollution, **windowed_kwargs)
    elif self.dataset_type == "explanation":
      self._mount_drive()
      with open('/content/drive/MyDrive/Mestrado/gpt_davinci_few_shot_explain_dict.json', 'r') as f:
        explanation_dataset = json.load(f)

      # NOTE(review): training uses only the first 50 questions -- presumably because only
      # a few questions have generated explanations; confirm.
      self.train_dataset = ExplanationDataset(dataset=train_json[:50], explanation_dataset=explanation_dataset, **windowed_kwargs)
      # Evaluation is built without explanations (explanation_dataset=None).
      self.val_dataset = ExplanationDataset(dataset=dev_json[:], explanation_dataset=None, **windowed_kwargs)
      self.test_dataset = ExplanationDataset(dataset=dev_json[:], explanation_dataset=None, **windowed_kwargs)
    else:
      raise NotImplementedError(f"Unsupported dataset_type: {self.dataset_type!r}")

  def train_dataloader(self):
    # Shuffle so each epoch (and each gradient-accumulation window) sees a new batch order;
    # without it training iterates the questions in file order every epoch.
    return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=self.shuffle_train)

  def val_dataloader(self):
    return DataLoader(self.val_dataset, batch_size=self.batch_size)

  def test_dataloader(self):
    return DataLoader(self.test_dataset, batch_size=self.batch_size)

# Trainer Module

In [None]:
import re
class ModelFinetuner(pl.LightningModule):
    """LightningModule that fine-tunes the seq2seq model returned by ``get_model``.

    ``forward`` has two modes: while ``self.training`` is True it returns the
    teacher-forced loss; otherwise it returns beam-search generations, which
    ``validation_step``/``test_step`` decode and score with EM and F1.

    Args:
      model_name: name passed to ``get_model``, which returns ``(model, tokenizer)``.
      learning_rate: Adam learning rate.
      source_max_length: stored as a hyperparameter; not used inside this class.
      target_max_length: ``max_length`` passed to ``generate`` at evaluation time.
    """

    def __init__(self, model_name, learning_rate, source_max_length = 512, target_max_length = 512):
      super().__init__()

      model, tokenizer = get_model(model_name)

      self.learning_rate = learning_rate
      self.source_max_length = source_max_length
      self.target_max_length = target_max_length
      self.model_name = model_name
      self.tokenizer = tokenizer
      self.model = model

      self.log_examples = True

      self.save_hyperparameters()

    def _prepare_labels(self, target_token_ids):
      """Return ``target_token_ids`` with padding positions replaced by -100.

      Hugging Face seq2seq models ignore label positions equal to -100 in the
      cross-entropy loss. Without this, the padding of each collated target would be
      trained and scored as if predicting <pad> were part of the answer.
      Assumes targets are padded with ``tokenizer.pad_token_id`` (as the inputs are).
      ``None`` is passed through unchanged.
      """
      if target_token_ids is None:
        return None
      pad_id = self.tokenizer.pad_token_id
      if pad_id is None:
        return target_token_ids
      return target_token_ids.masked_fill(target_token_ids == pad_id, -100)

    def forward(self, source_token_ids, source_mask, target_token_ids=None,
                target_mask=None):
      """Training mode: return the loss. Eval mode: return generated token ids.

      ``target_mask`` is accepted for interface compatibility but not used.
      """
      if self.training:
        return self.model(input_ids=source_token_ids,
                          attention_mask=source_mask,
                          labels=self._prepare_labels(target_token_ids)).loss
      return self.model.generate(input_ids=source_token_ids,
                                 attention_mask=source_mask,
                                 max_length=self.target_max_length,
                                 num_beams=3,
                                 early_stopping=True)

    def training_step(self, batch, batch_nb):
      source_token_ids, source_mask, target_token_ids, target_mask, _, _ = batch

      loss = self(source_token_ids, source_mask, target_token_ids, target_mask)
      self.log('train_loss', loss.detach(), on_step=True, on_epoch=True, logger=True)

      # NOTE(review): 'log' / 'progress_bar' are pre-1.0 Lightning return keys; only
      # 'loss' drives optimization here -- consider self.log(...) for GPU usage instead.
      tensorboard_logs = {'train_loss': loss.detach()}
      progress_bar = {'gpu_usage': gpu_usage()}
      return {'loss': loss, 'log': tensorboard_logs,
              'progress_bar': progress_bar}

    def validation_step(self, batch, batch_nb):
      em, f1 = self.get_scores(batch, batch_nb)
      loss = self.get_loss(batch, batch_nb)
      return {'val_em': em, 'val_loss': loss, 'val_f1': f1}

    def test_step(self, batch, batch_nb):
      em, f1 = self.get_scores(batch, batch_nb)
      loss = self.get_loss(batch, batch_nb)
      return {'test_em': em, 'test_loss': loss, 'test_f1': f1}

    def get_scores(self, batch, batch_nb):
      """Generate answers for ``batch`` and return ``(em, f1)`` against the gold answers.

      For the batch with index 2, each (input, target, prediction) triple is also
      logged to ``self.logger.experiment["val_samples"]`` (Neptune-style API) when a
      logger is attached.
      """
      source_token_ids, source_mask, target_token_ids, target_mask, original_source, original_target = batch

      generated_ids = self(source_token_ids, source_mask, target_token_ids, target_mask)
      answers = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

      em = compute_em_score(answers, original_target)
      f1 = compute_f1_score(answers, original_target)

      if batch_nb == 2 and self.logger is not None:
        for src, trgt, pred in zip(original_source, original_target, answers):
          self.logger.experiment["val_samples"].log(f"Epoch {self.current_epoch}: Input: {src}, Expected Output: {trgt}, Model Output: {pred}, EM: {compute_em_score(pred, trgt)}, F1: {compute_f1_score(pred, trgt)}")

      return em, f1

    def get_loss(self, batch, batch_nb):
      """Teacher-forced loss on ``batch``, computed in both train and eval mode."""
      source_token_ids, source_mask, target_token_ids, _, _, _ = batch
      return self.model(input_ids=source_token_ids,
                        attention_mask=source_mask,
                        labels=self._prepare_labels(target_token_ids)).loss

    @staticmethod
    def _mean(outputs, key):
      """Unweighted mean of ``key`` over the per-batch step outputs."""
      return sum(x[key] for x in outputs) / len(outputs)

    def validation_epoch_end(self, outputs):
      if not outputs:  # no validation batches ran; avoid a ZeroDivisionError
        return
      self.log("avg_val_f1", self._mean(outputs, 'val_f1'), prog_bar=True)
      self.log("avg_val_em", self._mean(outputs, 'val_em'), prog_bar=True)
      self.log("avg_val_loss", self._mean(outputs, 'val_loss'), prog_bar=True)

    def test_epoch_end(self, outputs):
      if not outputs:
        return
      self.log("avg_test_f1", self._mean(outputs, 'test_f1'), prog_bar=True)
      self.log("avg_test_em", self._mean(outputs, 'test_em'), prog_bar=True)
      self.log("avg_test_loss", self._mean(outputs, 'test_loss'), prog_bar=True)

    def configure_optimizers(self):
      optimizer = torch.optim.Adam(
          [p for p in self.parameters() if p.requires_grad],
          lr=self.learning_rate, eps=1e-08)

      scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.99)

      # StepLR counts scheduler.step() calls. Lightning steps schedulers once per
      # *epoch* by default, which would make step_size=2000 mean 2000 epochs (i.e. the
      # LR never decays); step it once per optimizer step instead. No 'monitor' is
      # needed: that key is only used by ReduceLROnPlateau.
      return {'optimizer': optimizer,
              'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}

# Experiment Logger (Neptune)

In [None]:
import os
from getpass import getpass

from pytorch_lightning.loggers import NeptuneLogger

# Never hardcode the API token in the notebook: anyone who can read the notebook can
# use it. Read it from the environment, or prompt for it interactively.
# NOTE(review): a token was previously hardcoded in this cell and should be treated as
# leaked -- revoke/rotate it in Neptune.
neptune_api_token = os.environ.get("NEPTUNE_API_TOKEN") or getpass("Neptune API token: ")

neptune_logger = NeptuneLogger(
    api_key=neptune_api_token,
    project="e166690/Mestrado",
    # Uploading the large, repeatedly rewritten checkpoints (e.g. last.ckpt) failed
    # mid-run ("changed during upload", HTTP 400) and killed Neptune's background
    # thread; checkpoints are still written locally by ModelCheckpoint.
    log_model_checkpoints=False,
)

# Train

In [None]:
# Model and optimiser settings.
model_name = "t5-base"
learning_rate = 1e-3

# Batching: each optimizer step sees batch_size * accumulate_grad_batches examples.
batch_size, accumulate_grad_batches = 4, 4
max_epochs = 10

# Token budgets for the source text / dataset contexts, and the IIRC dataset variant.
source_max_length = max_context_size = 512
dataset_type = "gold"

In [None]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

# Stop once validation F1 has not improved for 5 consecutive validation runs.
early_monitor = EarlyStopping(monitor="avg_val_f1", min_delta=0.00, patience=5, mode="max")

# Keep the 3 best checkpoints by validation F1, plus the most recent one.
# mode="max" is required: ModelCheckpoint defaults to mode="min", which would keep the
# three *worst*-F1 checkpoints. The filename must use a metric that is actually logged
# ("avg_val_f1"); "val/loss" is never logged, so it always rendered as 0.00.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="{epoch:02d}-{avg_val_f1:.4f}",
    save_top_k=3,
    save_last=True,
    monitor="avg_val_f1",
    mode="max",
    every_n_epochs=1,
)

pl.seed_everything(42)

data = IIRCDataModule(model_name=model_name, max_context_size=max_context_size, batch_size=batch_size, dataset_type=dataset_type)

trainer = pl.Trainer(gpus=1,
                     max_epochs=max_epochs,
                     check_val_every_n_epoch=1,
                     accumulate_grad_batches=accumulate_grad_batches,
                     callbacks=[checkpoint_callback, early_monitor],
                     logger=neptune_logger,
                     log_every_n_steps=25,
                    #  fast_dev_run=True,
                    #  overfit_batches=0.04,
                     )

model = ModelFinetuner(model_name=model_name,
                       learning_rate=learning_rate,
                       source_max_length=source_max_length)

trainer.fit(model, data)

Global seed set to 42


Downloading:   0%|          | 0.00/1.17k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/850M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Token indices sequence length is longer than the specified maximum sequence length for this model (5571 > 512). Running this sequence through the model will result in indexing errors
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable par

https://app.neptune.ai/e166690/Mestrado/e/MES-40
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

File /content/checkpoints/last.ckpt changed during upload, restarting upload.


Validation: 0it [00:00, ?it/s]

Unexpected error occurred in Neptune background thread: Killing Neptune asynchronous thread. All data is safe on disk and can be later synced manually using `neptune sync` command.


Exception in thread NeptuneAsyncOpProcessor:
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/neptune/new/internal/backends/utils.py", line 90, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/neptune/new/internal/backends/hosted_file_operations.py", line 442, in upload_raw_data
    response.raise_for_status()
  File "/usr/local/lib/python3.7/dist-packages/requests/models.py", line 941, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 400 Client Error:  for url: https://app.neptune.ai/api/leaderboard/v1/attributes/storage/file/upload/part?uploadId=ABPnzm6NfR-XpRZAzeNel3Oy7fgZUE5PvVG5t5dwy_2Chu_ltLTHaZhYfgsEL51ebzxcezZo&uploadPartIdx=413&experimentIdentifier=96d1491e-01dd-42be-9e3e-76b18d712bdf&attribute=training%2Fmodel%2Fcheckpoints%2Flast

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/li

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# Pull a single batch from the validation loader for manual inspection.
dl = data.val_dataloader()
batch = next(iter(dl))
(source_token_ids, source_mask,
 target_token_ids, target_mask,
 original_source, original_target) = batch

In [None]:
# Raw text inputs (question and contexts joined by [SEP]) of the first validation batch.
original_source

('In what country did Bain attend doctoral seminars of Wlad Godzich?[SEP]The University of Geneva (French: Université de Genève) is a public research university located in Geneva, Switzerland.[SEP]and later attended the doctoral seminars of Wlad Godzich in the University of Geneva.[SEP]He completed M. Phil at the Geneva-based IUEE (Institute for European Studies), and later attended the doctoral seminars of Wlad Godzich in the University of Geneva.',
 'In what Spanish province is the city located where Bain took up Hispanic Studies at a small private college?[SEP]In  1  9  8  2  he moved to Spain, and took up Hispanic Studies in a small private college in Salamanca[SEP]Salamanca ( , ) is a city in western Spain that is the capital of the Province of Salamanca',
 "When was the city were McCarrick left her apartment before disappearing in  1  9  9  3  founded?[SEP]Dublin celebrated its 'official' millennium in  1  9  8  8 , meaning the Irish government recognised  9  8  8  as the year in

In [None]:
# Gold answers for the first validation batch (shown via rich display, not print).
original_target

('Switzerland', 'Province of Salamanca', ' 9  8  8 ', 'no answer')


In [None]:
# Use eval() rather than assigning `model.training = False`: the assignment only flips
# the flag on the top-level module, leaving dropout inside the wrapped transformer in
# training mode. eval() switches every submodule, so generation is deterministic.
model.eval()
with torch.no_grad():
    generated_ids = model(source_token_ids.to(model.device), source_mask.to(model.device))
output_seq = model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
output_seq

['Michigan', 'yes', 'no answer', 'no answer']

In [None]:
# Pull a single batch from the training loader for manual inspection.
# If the model lives on a GPU, move these tensors with `.to(model.device)` before use.
dl = data.train_dataloader()
batch = next(iter(dl))
(source_token_ids, source_mask,
 target_token_ids, target_mask,
 original_source, original_target) = batch

In [None]:
# Use eval() rather than assigning `model.training = False`: the assignment only flips
# the flag on the top-level module, leaving dropout inside the wrapped transformer in
# training mode. eval() switches every submodule, so generation is deterministic.
model.eval()
with torch.no_grad():
    generated_ids = model(source_token_ids.to(model.device), source_mask.to(model.device))
output_seq = model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
output_seq

['no answer', 'yes', 'no answer', 'no answer']

In [None]:
# First example of the training dataset: a tuple whose first element is the
# input token ids (padded with 0s), matching the layout the batches are unpacked into.
dl.dataset[0]

(tensor([  366,   410,     8,  2986,   383,    84,     8,   489,     3,   632,
           314,     3,   189,  6292,  4471,    12,     3, 26655, 12673,  1084,
          2504,   354,   526,   729,  1731,    58,  6306,   134,  8569,   908,
          2092,  6411,  1575,  3611,  5072,     6,     8,  3332,    12,   142,
          1737,     3,     9,  4716,  3313,   640,     8, 11092,   630,    16,
             8, 12023,     6,     8,   489,     3,   632,   314,     3,   189,
          6292,  4471,    12,     3, 26655, 12673,  1084,  2504,   354,   526,
           729,     5,  6306,   134,  8569,   908,   667,   883,   257,  3611,
          5072,    47,     3,     9,  4567,  1150,  1602,  2466,  2716,  2986,
             3, 13973,    16,     8, 12023,    45,   209,   489,    12,   204,
           305,  1600,   209,   668,   314,   314,     3,     5,     1,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [None]:
# Raw text inputs of the first training batch.
original_source

('When did the operation during which the  7  0  4 th dropped supplies to allied troops near Nijmegen begin?[SEP]During Operation Market Garden, the attempt to seize a bridgehead across the Rhine in the Netherlands, the  7  0  4 th dropped supplies to allied troops near Nijmegen.[SEP]Operation Market Garden was a failed World War II military operation fought in the Netherlands from  1  7  to  2  5  September  1  9  4  4 .',
 'What was replaced on the property by the Washington Square Arch?[SEP]Despite being public property, and expanding the Fifth Avenue axis into Washington Square Park, the Washington Square Arch is the unofficial symbol of NYU[SEP]In  1  8  8  9 , a large plaster and wood memorial arch was erected over Fifth Avenue just north of Washington Square Park by local businessman and philanthropist William Rhinelander Stewart',
 "Which of the Ford brothers had directed more films?[SEP]Ford entered the filmmaking industry shortly after graduating from high school with the hel

In [None]:
# Gold answers for the first training batch.
original_target

(' from  1  7  to  2  5  September  1  9  4  4 ',
 'a large plaster and wood memorial arch ',
 'no answer',
 ' 3  2 ')