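**Notes**: Notebook to create embeddings for the training questions in the Spider text-to-SQL dataset using the `bert-base-uncased` checkpoint from Hugging Face. Each question is represented by the final hidden state of its `[CLS]` token, and a held-out validation question is then matched to its most similar training questions by cosine similarity.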

### **Setup**

In [None]:
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
device

'cuda'

In [None]:
!pip install transformers



In [None]:
!pip install datasets



### **Download dataset**

In [None]:
from datasets import load_dataset

# Spider rows pair a natural-language `question` with its SQL `query` and `db_id`.
raw_datasets = load_dataset(path='spider')
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['db_id', 'query', 'question', 'query_toks', 'query_toks_no_value', 'question_toks'],
        num_rows: 7000
    })
    validation: Dataset({
        features: ['db_id', 'query', 'question', 'query_toks', 'query_toks_no_value', 'question_toks'],
        num_rows: 1034
    })
})

### **Tokenize dataset**

In [None]:
# Load the tokenizer that matches the pre-trained checkpoint used below
from transformers import AutoTokenizer

checkpoint = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Every question is padded or truncated to this many tokens
max_seq_length = 50

def tokenize_func(examples):
  """Tokenize a batch of Spider rows.

  Args:
    examples: batch mapping passed in by `Dataset.map(batched=True)`; only
      the 'question' column is read.

  Returns:
    The tokenizer output for the questions, each padded or truncated to
    exactly `max_seq_length` tokens.
  """
  questions = examples['question']
  return tokenizer(
      questions,
      max_length=max_seq_length,
      truncation=True,
      padding='max_length',
  )

In [None]:
# Tokenize every split in batches; the raw text columns are dropped so only
# the tokenizer outputs remain as features.
tokenized_datasets = raw_datasets.map(
    tokenize_func,
    batched=True,
    remove_columns=raw_datasets['train'].column_names,
)
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 7000
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1034
    })
})

In [None]:
# Sanity check: decode the first training example back to text (padding included)
first_train_ids = tokenized_datasets['train'][0]['input_ids']
tokenizer.decode(first_train_ids)

'[CLS] how many heads of the departments are older than 56? [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

### **Prepare dataloaders**

In [None]:
# Have indexing into the datasets return PyTorch tensors
tokenized_datasets.set_format(type='torch')

In [None]:
from torch.utils.data import DataLoader

# shuffle=False keeps dataset order, so embedding i later lines up with
# raw_datasets['train'][i].
train_dataloader = DataLoader(
    tokenized_datasets['train'], batch_size=32, shuffle=False
)
len(train_dataloader)

219

### **Get embeddings from pre-trained model**

In [None]:
from transformers import AutoModelForTextEncoding

# Load the encoder for the same checkpoint, then move its weights to `device`
model = AutoModelForTextEncoding.from_pretrained(checkpoint)
model = model.to(device)

In [None]:
from tqdm.auto import tqdm

# One (batch_size, hidden) tensor of [CLS] embeddings per batch, stored on the CPU
embeddings = []

# Run the model in eval mode without autograd tracking
model.eval()
with torch.inference_mode():
  for batch in tqdm(train_dataloader):
    # Send data to device
    batch = {k: v.to(device) for k, v in batch.items()}

    # Get model outputs
    outputs = model(**batch)

    # Hidden state at position 0 ([CLS]) for every example in the batch
    cls_embeddings = outputs.last_hidden_state[:, 0, :]

    # `[:, 0, :]` is only a view into the full (batch, seq_len, hidden) tensor.
    # Keeping the view would keep every batch's full hidden states alive (on
    # the GPU when device == 'cuda'), and torch.save would later serialize
    # those full storages. Copy to a compact, independent CPU tensor instead.
    embeddings.append(cls_embeddings.to('cpu').clone())

  0%|          | 0/219 [00:00<?, ?it/s]

In [None]:
# Flatten the per-batch tensors into one list: one embedding per training question, in order
embedding_list = [
    embedding
    for embedding_batch in embeddings
    for embedding in embedding_batch
]
len(embedding_list)

7000

### **Write embeddings to file**



In [None]:
# Persist the per-question [CLS] embeddings (same order as the training split)
embeddings_path = "training_ques_embeddings.pt"
torch.save(embedding_list, embeddings_path)

### **Load embeddings and test**

#### Get a random test question from the validation dataset

In [None]:
import random

# Seeded, local RNG so "Restart & Run All" picks the same validation question
# every time without touching the global `random` state. Set SEED = None to
# sample a different question on each run.
SEED = 42
rng = random.Random(SEED)
random_idx = rng.randrange(len(raw_datasets['validation']))

test_question = raw_datasets['validation'][random_idx]['question']
test_question

'How many countries have governments that are republics?'

#### Tokenize test question

In [None]:
# Re-create the tokenizer with the same settings used for the training
# questions, so this section can run on its own.
from transformers import AutoTokenizer

checkpoint = 'bert-base-uncased'
max_seq_length = 50
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# NOTE(review): `input` shadows the Python builtin; the name is kept because
# later cells read it.
input = tokenizer(
    test_question,
    return_tensors='pt',
    max_length=max_seq_length,
    truncation=True,
    padding='max_length',
)

tokenizer.decode(input['input_ids'][0], skip_special_tokens=True)

'how many countries have governments that are republics?'

#### Get embedding for test question

In [None]:
from transformers import AutoModelForTextEncoding

# Second copy of the encoder; it is not moved to `device`, matching the
# tokenizer inputs above, which were never moved either.
model_2 = AutoModelForTextEncoding.from_pretrained(checkpoint)
model_2.eval()

with torch.inference_mode():
  output = model_2(**input)
  # [CLS] hidden state (position 0); keeps the leading batch dimension of 1
  test_embedding = output.last_hidden_state[:, 0, :].cpu()

#### Load training question embeddings

In [None]:
# map_location='cpu' lets embeddings saved from a GPU run load on a CPU-only
# machine (a plain torch.load would try to restore CUDA tensors first).
# weights_only=True uses the restricted unpickler; the file only holds a list
# of tensors, so nothing else needs to be deserialized.
loaded_embeddings = torch.load(
    "training_ques_embeddings.pt", map_location='cpu', weights_only=True
)

In [None]:
# Each training embedding is 1-D; the test embedding keeps its batch dimension of 1
(loaded_embeddings[0].shape, test_embedding.shape)

(torch.Size([768]), torch.Size([1, 768]))

#### Get the top 3 closest questions from the training dataset

In [None]:
from torch.nn.functional import cosine_similarity

# Number of nearest training questions to show
TOP_K = 3

# Stack into one (num_train, hidden) matrix and score every training question
# in a single call; test_embedding has shape (1, hidden) and broadcasts over
# the rows.
train_matrix = torch.stack(loaded_embeddings)
similarity_scores = cosine_similarity(test_embedding, train_matrix, dim=1)

# Indices of the highest cosine similarities, best match first
k = min(TOP_K, len(similarity_scores))
top_indices = torch.topk(similarity_scores, k=k).indices.tolist()

print(f"Test question: {tokenizer.decode(input['input_ids'][0], skip_special_tokens=True)}")
print(f"Test query: {raw_datasets['validation'][random_idx]['query']}")

for idx in top_indices:
  print(f"\nidx: {idx} | question: {raw_datasets['train'][idx]['question']}")
  print(f"idx: {idx} | query: {raw_datasets['train'][idx]['query']}\n")

Test question: how many countries have governments that are republics?
Test query: SELECT count(*) FROM country WHERE GovernmentForm  =  "Republic"

idx: 1052 | question: How many countries are there in total?
idx: 1052 | query: SELECT count(*) FROM country


idx: 5049 | question: How many states have smaller colleges than average?
idx: 5049 | query: SELECT count(DISTINCT state) FROM college WHERE enr  <  (SELECT avg(enr) FROM college)


idx: 5047 | question: How many states have a college with more students than average?
idx: 5047 | query: SELECT count(DISTINCT state) FROM college WHERE enr  >  (SELECT avg(enr) FROM college)

