## Few-Shot TLQA with FlanT5

Below, we run the generative models FlanT5-Large and FlanT5-XL in a few-shot setting, selecting demonstration examples from the training set of TLQA, which was derived from TempLAMA.
- `flan-t5-large`: 780M params, 1 GB mem
- `flan-t5-xl`: 3B params, 12 GB mem

In [11]:
!pip install transformers==4.37.2 optimum==1.12.0 --quiet
!pip install sentence-transformers

Collecting sentence-transformers
Successfully installed sentence-transformers-2.7.0


#### Running FlanT5-Large and FlanT5-XL (zero-shot)

In [8]:
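from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

# Tokenizers take no dtype; torch_dtype only applies when loading the model.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
# GPU variant in half precision (device_map requires the accelerate package):
# model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto", torch_dtype=torch.float16)

# To try the larger model, swap in "google/flan-t5-xl":
# tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
# model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl")

input_text = "List all Michael Jackson albums between 2000 and 2009."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
# On GPU, move the inputs to the same device as the model:
# input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

# With no arguments, generate() falls back to the default generation settings.
outputs = model.generate(input_ids)
# decode() keeps special tokens such as <pad> and </s> unless skip_special_tokens=True is passed.
print(tokenizer.decode(outputs[0]))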

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


<pad> Michael Jackson: The Greatest Hits</s>


#### Running FlanT5-XL on a CPU (zero-shot)

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl")

input_text = "List all Michael Jackson albums between 1990 and 2009."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))

### Few-Shot Setting

To build the few-shot prompt for a test example, we retrieve the training examples closest to it with a KNN search in a continuous vector space. Any embedding model from sentence-transformers can supply the vectors.
We reuse the KNN search from https://anonymous.4open.science/r/ICAT-47FC/code/2wikimultihopqa/get_chatgpt_skill_transfer_knn.py.
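
The retrieval idea can be illustrated without an embedding model. The sketch below ranks toy 3-dimensional vectors by cosine similarity to a query vector and returns the indices of the `k` closest ones. The vectors, the helper names `cosine` and `top_k_indices`, and the value of `k` are illustrative assumptions; in the notebook the vectors would come from a sentence-transformers model.

```python
import math

def cosine(u, v):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def top_k_indices(query, candidates, k):
    """Indices of the k candidates most similar to the query, best first."""
    scored = [(cosine(query, c), i) for i, c in enumerate(candidates)]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [i for _, i in scored[:k]]

# Toy "training set" embeddings (hypothetical values, for illustration only).
train_emb = [
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
]
# Toy embedding of one test question.
test_emb = [1.0, 0.05, 0.0]

neighbours = top_k_indices(test_emb, train_emb, k=2)
print(neighbours)
```

The returned indices point back into the training set, so the corresponding question-answer pairs can be looked up and used as demonstrations.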

In [12]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

class KnnSearch:
    def __init__(self, data=None, num_trees=None, emb_dim=None):
        # data, num_trees and emb_dim are accepted for compatibility; the search below does not use them.
        self.num_trees = num_trees
        self.emb_dim = emb_dim
        self._model = None  # encoder is loaded lazily on first use

    def get_embeddings_for_data(self, data_ls):
        # Load the encoder once and reuse it instead of reloading it on every call.
        if self._model is None:
            self._model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
        return self._model.encode(data_ls)

    @staticmethod
    def _rank_by_similarity(sent_emb, emb_matrix):
        # Row indices of emb_matrix, ordered from most to least similar to sent_emb.
        sims = cosine_similarity(emb_matrix, [sent_emb]).ravel()
        return sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)

    def get_top_n_neighbours(self, sentence, data_emb, transfer_data, strategy_emb, str_qa, k):
        # sentence is a single string, so sent_emb is a 1-D vector.
        sent_emb = self.get_embeddings_for_data(sentence)

        # The k nearest training examples, most similar first.
        ranked = self._rank_by_similarity(sent_emb, data_emb)
        top_questions = [transfer_data[idx] for idx in ranked[:k]]

        # All strategy examples follow, ordered by similarity; no top-k cutoff is applied to them.
        ranked = self._rank_by_similarity(sent_emb, strategy_emb)
        top_questions.extend(str_qa[idx] for idx in ranked)
        return top_questions
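
Once the neighbours are retrieved, they still have to be turned into a prompt for the model. The sketch below shows one possible way to do that, assuming each demonstration is a dict with `question` and `answers` keys. That schema, the `build_few_shot_prompt` helper, and the prompt template are hypothetical and are not taken from the notebook or the TLQA data format; the demonstration texts are placeholders.

```python
def build_few_shot_prompt(demonstrations, question):
    """Concatenate demonstrations and the test question into one prompt string."""
    blocks = []
    for demo in demonstrations:
        # Join the list-valued answer into a single line.
        answer_text = "; ".join(demo["answers"])
        blocks.append(f"Question: {demo['question']}\nAnswer: {answer_text}")
    # The test question goes last, with the answer left open for the model.
    blocks.append(f"Question: {question}\nAnswer:")
    return "\n\n".join(blocks)

# Placeholder demonstrations standing in for retrieved training examples.
demos = [
    {"question": "Example question A?", "answers": ["answer A1", "answer A2"]},
    {"question": "Example question B?", "answers": ["answer B1"]},
]
prompt = build_few_shot_prompt(
    demos, "List all Michael Jackson albums between 2000 and 2009."
)
print(prompt)
```

The resulting string can be passed to the tokenizer in place of `input_text` from the zero-shot cells. Each added demonstration lengthens the prompt, so `k` usually has to stay small enough for the prompt to fit within the tokenizer's maximum input length.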