# Week 6 - NLP and Deep Learning

---

# Lecture 11: Generative LMs and RAG

Today we will complement a generative language model with Retrieval Augmented Generation (RAG). More precisely, we will improve the language model's ability to answer Star Wars trivia questions.

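For this assignment we will assume you have the Transformers and NLTK packages (with 'punkt') installed. The code below also imports PyTorch (`torch`), so that needs to be available as well. You can install Transformers and NLTK by running:
```
pip3 install transformers
pip3 install nltk

python3 -c "import nltk; nltk.download('punkt')"
```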

## 1. QA with Flan-T5

We have provided:

- gold data: the questions and gold answers can be found in `questions.txt` and `answers.txt`. The data is not tokenized.
- raw data: a subset of a scrape of Wookieepedia in `starwarsfandomcom-20200223.txt.filtered.tokked.gz`
- word list: a list of English words from the Aspell dictionary in `en-aspell-dict.txt`
- code: code that loads the questions, supplements them with a prompt, queries the language model, and evaluates performance.

The code uses `google/flan-t5-base` by default, which can be run with 8 GB of memory (GPU/RAM). You can also experiment with other models (use `google/flan-t5-small` if you have less memory available).

It should be noted that evaluation is done with a non-standard metric (which mostly follows Rob's intuition of what should be counted as correct):

In [None]:
def eval_metric(gold, pred):
    """
    An answer is considered correct if at least half of the gold
    tokens are in the prediction. Note that this is a shortcut, 
    and will favor long answers.
    """
    gold = set(gold.strip().lower().replace('.', '').split(' '))
    pred = set(pred.strip().lower().replace('.', '').split(' '))
    return len(gold.intersection(pred)) >= len(gold)/2
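
To get a feel for how lenient this metric is, the short sketch below applies it to a few gold/prediction pairs. The function is repeated so that the snippet runs on its own, and the example strings are made up for illustration; they are not taken from `answers.txt`.

```python
def eval_metric(gold, pred):
    # Same logic as above: correct if at least half of the gold tokens are predicted
    gold = set(gold.strip().lower().replace('.', '').split(' '))
    pred = set(pred.strip().lower().replace('.', '').split(' '))
    return len(gold.intersection(pred)) >= len(gold)/2

# Hypothetical gold/prediction pairs (illustrative only)
exact = eval_metric('Luke Skywalker', 'luke skywalker.')           # both gold tokens found
partial = eval_metric('Luke Skywalker', 'Luke')                    # 1 of 2 gold tokens found
wrong = eval_metric('Luke Skywalker', 'Han Solo')                  # no gold tokens found
long_guess = eval_metric('Tatooine', 'tatooine naboo hoth endor')  # listing many candidates

print(exact, partial, wrong, long_guess)
```

The last case shows why the docstring warns about long answers: a prediction that simply lists many candidates still counts as correct as long as it contains half of the gold tokens.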


The code for querying the model and evaluating its performance is shown below:

In [None]:
from transformers import T5ForConditionalGeneration
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import torch

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

#lms = ['google/flan-t5-base', 'google/flan-t5-large', 'google/flan-t5-xl', 'google/flan-t5-xxl', 'mistralai/Mistral-7B-v0.1', 'mistralai/Mistral-7B-Instruct-v0.2', 'meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-7b-chat-hf', 'google/gemma-2b', 'google/gemma-2b-it', 'EleutherAI/pythia-6.9b', 'tiiuae/falcon-7b', 'tiiuae/falcon-7b-instruct', 'microsoft/phi-2']
lms = ['google/flan-t5-base']

questions = open('questions.txt').readlines()
answers = open('answers.txt').readlines()
#questions = ['What is the capital of Denmark ?']
#answers = ['Copenhagen']

# we will consider the prompt to be given for this project
prefix = 'Q: '
suffix = 'A: '


def eval_lm(lm, contexts=None):
    # For some of the language models, a token from Hugging Face needs to be used.
    # This can be saved in a file called token. If a t5-based model is used this
    # is not necessary.
    if 't5' in lm:
        lang_model = T5ForConditionalGeneration.from_pretrained(lm)
        tokenizer = AutoTokenizer.from_pretrained(lm, legacy=False)
    else:
        # Read the token once and close the file again
        with open('token') as token_file:
            hf_token = token_file.readline().strip()
        lang_model = AutoModelForCausalLM.from_pretrained(lm, token=hf_token)
        tokenizer = AutoTokenizer.from_pretrained(lm, legacy=False, token=hf_token)

    lang_model.to(DEVICE)

    # Without retrieved contexts, use an empty context for every question
    if not contexts:
        contexts = [''] * len(questions)
    correct = 0
    for question, answer, context in zip(questions, answers, contexts):
        # Prepare input
        question = context + prefix + ' ' + question.strip() + ' ' + suffix
        subword_ids = tokenizer(question.strip(), return_tensors='pt')['input_ids']
        subword_ids = subword_ids.to(DEVICE)
        
        # Generate output from model
        generated_ids = lang_model.generate(subword_ids, max_new_tokens=20)
        subwords_out = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        #print()
        #print(question)
        #print(' '.join(subwords_out))
        
        # Evaluate
        correct += int(eval_metric(answer, ' '.join(subwords_out)))

    print(str(correct) + ' out of ' + str(len(answers)) + ' correct', flush=True)
    
for lm in lms:
    eval_lm(lm)
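
To make the prompt layout concrete, the following minimal sketch reproduces only the string concatenation from `eval_lm`, without loading a model. The helper name `build_prompt` and the example question and context sentence are assumptions made for illustration; they are not part of the provided code or data.

```python
PREFIX = 'Q: '
SUFFIX = 'A: '


def build_prompt(question, context=''):
    """Mirror the concatenation in eval_lm: context first, then the question."""
    prompt = context + PREFIX + ' ' + question.strip() + ' ' + SUFFIX
    # eval_lm strips the full string right before tokenizing it
    return prompt.strip()


# Hypothetical question, as it would be read from a file (with a trailing newline)
no_context = build_prompt('Who trained Luke ?\n')
# Hypothetical retrieved sentence; each context sentence ends with a newline
with_context = build_prompt('Who trained Luke ?\n', 'Yoda trained Luke on Dagobah .\n')

print(repr(no_context))
print(repr(with_context))
```

Any retrieved sentences are placed in front of the `Q:` marker, so the model sees the context before the question.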

## 2. RAG (Retrieval Augmented Generation)

We are going to use Retrieval Augmented Generation (RAG) to supply the model with relevant information and hopefully increase performance. We will do this based on the `starwarsfandomcom-20200223.txt.filtered.tokked.gz` file: for each question we rank the sentences in this file by relevance, and use the 5 highest-ranked sentences as additional context.

You are encouraged to evaluate a variety of ranking approaches, in which you can make use of any technologies described in the course (TF-IDF, n-grams), and also external resources (POS tagger, NER tagger, lemmatizer, sentence embeddings, etc.). Of course, heuristic-based approaches can also be included, like the example below.

Below we provide a simple example that extracts all words that are not found in an English word list, and then finds all sentences in the data that contain all these words. You can use this as a starting point (e.g. find more, filter better based on verbs/nouns, improve ranking), but you can also implement your own strategies from scratch.

In [None]:
import gzip
from nltk.tokenize import word_tokenize

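# Open in text mode ('rt') so that every line is a str instead of bytes;
# calling str() on bytes would give strings like "b'...\\n'". UTF-8 is assumed here.
with gzip.open('starwarsfandomcom-20200223.txt.filtered.tokked.gz', 'rt', encoding='utf-8') as data_file:
    data = [line.strip() for line in data_file]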


def eval_plus_contexts(lm, data, context_indices):
    contexts = []
    # first collect the context
    for question_context in context_indices:
        context = ''
        for sent_id in question_context[:5]:
            context += data[sent_id] + '\n'
        contexts.append(context)
    eval_lm(lm, contexts)

def hasAlpha(word):
    for char in word:
        if char.isalpha():
            return True
    return False


# simple approach that extracts all words not in a dictionary, and then returns the 
# indices of the sentences containing all these "raw" words.
en_vocab = set([x.strip() for x in open('en-aspell-dict.txt').readlines()])
en_vocab.add("'s")
en_vocab.add("n't")
en_vocab.add("'re")
context_ids = []

for question in questions:
    # collect all rare words
    rares = []
    for word in word_tokenize(question):
        if word.lower() not in en_vocab and hasAlpha(word):
            rares.append(word)
            
    # Find indices of sentences that contain all rare words
    found_all_indices = []
    if len(rares) > 0:
        for lineIdx, line in enumerate(data):
            allIn = True
            for rare in rares:
                if rare not in line:
                    allIn = False
                    break
            if allIn:
                found_all_indices.append(lineIdx)
    context_ids.append(found_all_indices)
    
# Prepare output file for online submission
with open('robv-rare.tsv', 'w') as outFile:
    for line in context_ids:
        outFile.write('\t'.join([str(index) for index in line]) + '\n')

eval_plus_contexts(lms[0], data, context_ids)
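
As one alternative to the rare-word heuristic above, the sketch below ranks sentences by TF-IDF cosine similarity to the question, using only the standard library. It is a minimal illustration under stated assumptions, not the reference solution: all function names (`tokenize`, `build_idf`, `tfidf_vector`, `cosine`, `rank_sentences`) and the toy sentences are hypothetical, and the IDF formula `log(N / df) + 1` is just one common variant.

```python
import math
from collections import Counter


def tokenize(text):
    # The corpus is already tokenized, so splitting on whitespace is enough here
    return text.lower().split()


def build_idf(sentences):
    """Inverse document frequency, treating every sentence as one document."""
    doc_freq = Counter()
    for sent in sentences:
        doc_freq.update(set(tokenize(sent)))
    n_docs = len(sentences)
    return {word: math.log(n_docs / df) + 1.0 for word, df in doc_freq.items()}


def tfidf_vector(text, idf):
    # Words that never occur in the corpus are ignored
    counts = Counter(tokenize(text))
    return {word: count * idf[word] for word, count in counts.items() if word in idf}


def cosine(vec_a, vec_b):
    dot = sum(weight * vec_b[word] for word, weight in vec_a.items() if word in vec_b)
    norm_a = math.sqrt(sum(weight * weight for weight in vec_a.values()))
    norm_b = math.sqrt(sum(weight * weight for weight in vec_b.values()))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def rank_sentences(question, sentences, idf, top_n=5):
    """Return the indices of the top_n sentences most similar to the question."""
    q_vec = tfidf_vector(question, idf)
    scored = []
    for idx, sent in enumerate(sentences):
        score = cosine(q_vec, tfidf_vector(sent, idf))
        if score > 0:
            scored.append((score, idx))
    # Highest score first; ties are broken by the lower sentence index
    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return [idx for _, idx in scored[:top_n]]


# Toy corpus for illustration only; in the assignment this would be `data`
toy_data = [
    'Tatooine is a desert planet .',
    'Luke Skywalker grew up on Tatooine .',
    'The Millennium Falcon is a starship .',
]
toy_idf = build_idf(toy_data)
toy_ranking = rank_sentences('Which desert planet did Luke Skywalker grow up on ?', toy_data, toy_idf)
print(toy_ranking)
```

Calling `rank_sentences` once per question yields a list of index lists with the same shape as `context_ids`, so it could be passed to `eval_plus_contexts` in the same way. Scoring every sentence for every question in pure Python is likely to be slow on the full file, so an inverted index or a vectorized library may be preferable there.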

## 3. Participate

We are also making an ensemble of all our approaches during the lab hours. The way this works is that you upload your top-N rankings for the sentences to the OneDrive link that was sent to your ITU e-mail (https://ituniversity-my.sharepoint.com/personal/alai_itu_dk/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Falai%5Fitu%5Fdk%2FDocuments%2FWeek6&ga=1). Name the file with your ITU username or another anonymous name, followed by a dash, followed by the name of the approach, e.g. `robv-tfidf`, `robv-bert`, `robv-rarewords`. The format is one line for each question, and within one line the indices of the relevant sentences (start counting with 0), separated by tabs. The code snippet above writes the data in the right format (to `robv-rare.tsv`).

We will then evaluate all the individual submissions, and take the average of the 5 best rankings as an ensemble model. Note that you can upload multiple rankings; just make sure that each upload has a unique name. If your file does not show up on the leaderboard (http://itu.dk/~robv/alai/website.html) after 5 minutes, the format of the file is incorrect; please compare the format to `robv-rare.tsv`.
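
Before uploading, it can help to check the file locally. The sketch below is a hypothetical helper (`check_ranking_file` is not part of the provided code) that tests the properties described above: one line per question, and tab-separated, 0-based sentence indices on each line. It assumes that an empty line is acceptable for a question without retrieved sentences, since the provided snippet writes such lines; the leaderboard may apply further checks that are not covered here.

```python
def check_ranking_file(lines, n_questions, n_sentences):
    """Return a list of problems found; an empty list means the format looks fine."""
    problems = []
    if len(lines) != n_questions:
        problems.append(f'expected {n_questions} lines, found {len(lines)}')
    for line_no, line in enumerate(lines):
        line = line.rstrip('\n')
        if line == '':
            # No sentences were retrieved for this question
            continue
        for field in line.split('\t'):
            if not (field.isascii() and field.isdigit()):
                problems.append(f'line {line_no}: {field!r} is not a non-negative integer')
            elif int(field) >= n_sentences:
                problems.append(f'line {line_no}: index {field} is out of range')
    return problems


# Hypothetical file contents for 3 questions and a corpus of 10 sentences
good_lines = ['0\t5\t2\n', '\n', '7\n']
bad_lines = ['0 5\n', '12\n']  # space instead of tab, index too large, one line missing

good_problems = check_ranking_file(good_lines, n_questions=3, n_sentences=10)
bad_problems = check_ranking_file(bad_lines, n_questions=3, n_sentences=10)
print(good_problems)
print(bad_problems)
```

For a real submission, `lines` would come from reading the `.tsv` file with `readlines()`, `n_questions` from `len(questions)`, and `n_sentences` from `len(data)`.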