In [20]:
import pandas as pd
import numpy as np
import torch
from sentence_transformers import util, SentenceTransformer
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import random

# Pick the best available accelerator instead of hard-coding Apple's MPS backend,
# so the notebook also runs on CUDA or CPU-only machines. On an MPS-capable
# machine the result is still 'mps'.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [21]:
# Load the pre-computed chunk table (chunk text, size statistics and the
# embedding stored as a string) and preview the first five rows.
emb_chunks_df = pd.read_csv(filepath_or_buffer='emb_chunks_df.csv')
emb_chunks_df.head(n=5)

Unnamed: 0,page_n,sentence_chunk,chunk_chars,chunk_words,chunk_tokens,embedding
0,0,RELATIONSHIP TERMS & CONDITIONS CORE BANKING A...,54,7,13.5,"[-0.28527069091796875, -0.5010668635368347, -0..."
1,2,Contents Important information 1 1 General 1. ...,1496,248,374.0,"[-0.2918225824832916, -0.7688528895378113, -0...."
2,2,"of your accounts, payments 49 and payment inst...",117,18,29.25,"[-0.12205161154270172, -0.7234632968902588, -0..."
3,3,Core Banking Agreement (“The Agreement”) conta...,243,36,60.75,"[-0.3948034644126892, -0.8225910663604736, -0...."
4,4,Core Banking Agreement 1 Important Information...,1037,156,259.25,"[-0.030174657702445984, -0.39982807636260986, ..."


In [22]:
# convert embeddings back to np.array
def _parse_embedding(raw):
    """Turn one stored embedding into a 1-D float array.

    The CSV keeps every vector as a string such as "[0.1, -0.2, ...]". Values that
    are no longer strings are passed through unchanged, so re-running this cell
    after the column has already been converted does not raise.
    """
    if isinstance(raw, str):
        return np.fromstring(raw.strip('[]'), sep=', ')
    return np.asarray(raw, dtype=float)


emb_chunks_df['embedding'] = emb_chunks_df['embedding'].apply(_parse_embedding)
# Stack the per-row vectors into a single (n_chunks, dim) float32 matrix on the target device.
embs = torch.tensor(np.stack(emb_chunks_df['embedding'].tolist(), axis=0), dtype=torch.float32).to(device)

# Row-wise records: row i of `embs` corresponds to pages_n_chunks[i].
pages_n_chunks = emb_chunks_df.to_dict(orient='records')

emb_chunks_df.head()

Unnamed: 0,page_n,sentence_chunk,chunk_chars,chunk_words,chunk_tokens,embedding
0,0,RELATIONSHIP TERMS & CONDITIONS CORE BANKING A...,54,7,13.5,"[-0.28527069091796875, -0.5010668635368347, -0..."
1,2,Contents Important information 1 1 General 1. ...,1496,248,374.0,"[-0.2918225824832916, -0.7688528895378113, -0...."
2,2,"of your accounts, payments 49 and payment inst...",117,18,29.25,"[-0.12205161154270172, -0.7234632968902588, -0..."
3,3,Core Banking Agreement (“The Agreement”) conta...,243,36,60.75,"[-0.3948034644126892, -0.8225910663604736, -0...."
4,4,Core Banking Agreement 1 Important Information...,1037,156,259.25,"[-0.030174657702445984, -0.39982807636260986, ..."


In [23]:
# Sanity check of the embedding matrix: its dimensions first, then the (abbreviated) values.
(embs.size(), embs)

(torch.Size([103, 1024]),
 tensor([[-0.2853, -0.5011, -0.1342,  ..., -0.5641, -0.0289,  0.3733],
         [-0.2918, -0.7689, -0.4976,  ..., -0.5951,  0.0964,  0.5677],
         [-0.1221, -0.7235, -0.7040,  ..., -0.1265, -0.2623,  0.3828],
         ...,
         [-1.1459, -0.9477, -0.4285,  ..., -0.5960, -0.3391, -0.1244],
         [-0.4814, -0.1554, -0.4521,  ..., -0.4333,  0.0816,  0.3699],
         [ 0.1392, -0.1616,  0.5991,  ...,  0.2702,  0.2376, -0.7100]],
        device='mps:0'))

In [24]:
# Sentence embedding model used to encode queries.
# NOTE(review): presumably the same model that produced the stored chunk embeddings - confirm.
# Model names kept at hand for swapping: mixedbread-ai/mxbai-embed-large-v1, all-mpnet-base-v2
emb_model = SentenceTransformer(
    model_name_or_path='mixedbread-ai/mxbai-embed-large-v1',
    device=device,
)

In [25]:
query = 'foreign currency exchange'

# Embed the query, then place it on the same device as the chunk matrix.
query_emb = emb_model.encode(query, convert_to_tensor=True)
query_emb = query_emb.to(device)

# Row 0 of the similarity matrix holds this query's score against every chunk.
# (util.dot_score is the alternative scoring function.)
scores = util.cos_sim(query_emb, embs)[0]

# Five best-scoring chunks as a (values, indices) pair.
top_results = torch.topk(scores, k=5)

print(f'*** Query: {query} ***\n')
for score, idx in zip(top_results.values, top_results.indices):
    chunk = pages_n_chunks[idx]
    print(f'Score: {score:.4f}')
    print(f'Text: {chunk["sentence_chunk"]}')
    print(f'Page number: {chunk["page_n"]}\n')


*** Query: foreign currency exchange ***

Score: 0.7262
Text: The exchange rate applied to your payments will appear on your statement FOREIGN CURRENCY C 26. Payments involving a foreign currency exchange Foreign currency exchange rate information  The exchange rates we use are variable exchange rates which are changing constantly throughout the day (for example, to reflect movements in foreign exchange markets). The exchange rate applied to your payments will appear on your statement. Unless otherwise agreed with you, the exchange rate we will apply to payments you make involving a currency exchange (including any future dated payments) and payments you receive which are in a different currency to the denomination of your account will be the Lloyds Bank Foreign Exchange Rate applicable at the time that your payment is processed. You can contact us to find out the rate which will apply and you can find details of how to contact us in the General Information On Payments, Charges & Conta

In [26]:
def retrieve_relevant_info(query: str, embeddings: torch.Tensor, model: SentenceTransformer=emb_model, n_to_retrieve: int=5) -> tuple[torch.Tensor, torch.Tensor]:
    """Find the chunk embeddings most similar to a query.

    Args:
        query: Free-text question to search for.
        embeddings: 2-D tensor with one row per chunk embedding.
        model: Sentence embedding model used to encode `query`.
        n_to_retrieve: Maximum number of matches to return; capped at the number
            of rows in `embeddings`.

    Returns:
        A `(scores, indices)` pair of 1-D tensors: the cosine similarities of the
        best matches in descending order and their row positions in `embeddings`.
    """
    # Keep the query on the same device as the embeddings so cos_sim never mixes devices.
    query_emb = model.encode(query, convert_to_tensor=True).to(embeddings.device)
    cos_scores = util.cos_sim(query_emb, embeddings)[0]
    # torch.topk raises when k is larger than the number of candidates.
    k = min(n_to_retrieve, cos_scores.shape[0])
    scores, indices = torch.topk(cos_scores, k)
    # Kept from the original: the raw score tensor is printed on every call.
    print(scores)
    return scores, indices

def print_topk(query: str, embeddings: torch.Tensor, pages_n_chunks: list[dict]=pages_n_chunks, n_to_retrieve: int=5):
    """Print the best-matching chunks for a query.

    Args:
        query: Free-text question to search for.
        embeddings: Chunk embedding matrix handed to `retrieve_relevant_info`.
        pages_n_chunks: Chunk records; each needs the keys 'sentence_chunk' and
            'page_n' and must be ordered like the rows of `embeddings`.
        n_to_retrieve: Number of matches to print.
    """
    # Forward n_to_retrieve; previously it was accepted but never passed on,
    # so the retrieval default was always used.
    scores, indices = retrieve_relevant_info(query, embeddings, n_to_retrieve=n_to_retrieve)

    print(f'--- Query: {query} ---')
    # Convert to plain Python numbers once instead of indexing with 0-d device tensors.
    for score, idx in zip(scores.tolist(), indices.tolist()):
        chunk = pages_n_chunks[idx]
        print(f'Score: {score:.4f}')
        print(f'Text: {chunk["sentence_chunk"]}')
        print(f'Page number: {chunk["page_n"]}\n')

In [27]:
# Exercise the retrieval helper with a differently worded query.
query = 'exchange rates abroad'
print_topk(query=query, embeddings=embs)

tensor([0.6958, 0.6391, 0.5893, 0.5642, 0.5568], device='mps:0')
--- Query: exchange rates abroad ---
Score: 0.6958
Text: The exchange rate applied to your payments will appear on your statement FOREIGN CURRENCY C 26. Payments involving a foreign currency exchange Foreign currency exchange rate information  The exchange rates we use are variable exchange rates which are changing constantly throughout the day (for example, to reflect movements in foreign exchange markets). The exchange rate applied to your payments will appear on your statement. Unless otherwise agreed with you, the exchange rate we will apply to payments you make involving a currency exchange (including any future dated payments) and payments you receive which are in a different currency to the denomination of your account will be the Lloyds Bank Foreign Exchange Rate applicable at the time that your payment is processed. You can contact us to find out the rate which will apply and you can find details of how to contac

In [28]:
# Checkpoint used for answer generation.
model_id = 'google/gemma-2b-it'

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the weights in half precision with the SDPA attention implementation,
# then move the model to the selected device.
llm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
    attn_implementation='sdpa',
)
llm_model = llm_model.to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [29]:
# Sample customer questions used to try out the pipeline.
queries = [
    "In what cases may I close my account?",
    "How do you pay interest?",
    "Can the terms and conditions be changed?",
    "What is meant by Business Day?",
    "How can I reach you?",
    "How do you receive my payment instructions?",
    "In which cases can you terminate my account?",
    "Who is authorised to give you instructions?",
]

In [30]:
def prompt_formatter(query: str, context_items: list[dict]) -> str:
    """Build the chat-formatted prompt that asks the LLM to answer `query`.

    Args:
        query: The user's question.
        context_items: Retrieved chunk records; the 'sentence_chunk' value of each
            one becomes a bullet point in the prompt's context section.

    Returns:
        The prompt as a string, wrapped as a single user turn by the tokenizer's
        chat template, with the generation prompt appended.
    """
    # One "- " bullet per retrieved chunk.
    chunk_texts = [entry['sentence_chunk'] for entry in context_items]
    context = '- ' + '\n- '.join(chunk_texts)

    template = """Based on the following context items, please answer the query.
Give yourself room to think by extracting relevant passages from the context before answering the query.
Don't return the thinking, only return the answer.
Make sure your answers are as explanatory as possible.
Use the following examples as reference for the ideal answer style, but don't use the below example answers as answers to the query.
\nExample 1:
Query: Who can provide instructions to the bank according to the terms and conditions?
Answer: According to the terms and conditions, only authorized individuals can give instructions to the bank.
\nExample 2:
Query: What are your rights regarding the termination of services as outlined in the terms and conditions?
Answer: The terms and conditions specify the rights granted to you in the event of termination, including any associated procedures or obligations.
\nExample 3:
Query: How does the bank handle refunds for incorrectly executed payment instructions, as per the terms and conditions?
Answer: The terms and conditions detail the process for obtaining refunds in the case of payment instructions being incorrectly executed by the bank.
\nExample 4:
Query: What measures are outlined in the terms and conditions to ensure the security of your accounts and payment instruments?
Answer: The terms and conditions lay out your obligations regarding the security of your accounts, payments, and payment instruments, along with any corresponding measures implemented by the bank.
\nNow use the following context items to answer the user query:
{context}
\nRelevant passages: <extract relevant passages from the context here>
User query: {query}
Answer:"""

    filled_prompt = template.format(query=query, context=context)

    # Use the tokenizer's chat template so the model sees its expected turn layout.
    messages = [{'role': 'user', 'content': filled_prompt}]
    return tokenizer.apply_chat_template(
        conversation=messages,
        tokenize=False,
        add_generation_prompt=True,
    )

In [31]:
def ask(query: str, temperature: float=0.2, max_new_tokens: int=256, format_answer_text: bool=True, return_context: bool=False):
    """Answer `query` with retrieval-augmented generation.

    The answer is streamed to stdout while it is generated and is also returned.

    Args:
        query: The user's question.
        temperature: Sampling temperature passed to `generate`.
        max_new_tokens: Upper bound on the number of generated tokens.
        format_answer_text: If True, return only the newly generated text with
            special tokens removed; if False, return the full decoded sequence
            (prompt included).
        return_context: If True, also return the retrieved context items.

    Returns:
        The answer text, or `(answer_text, context_items)` when `return_context`
        is True. Each context item is a copy of a chunk record with an extra
        'score' entry.
    """
    # -------- RETRIEVAL --------
    scores, indices = retrieve_relevant_info(query, embs, n_to_retrieve=10)
    # Build fresh dicts: writing 'score' into the shared pages_n_chunks records
    # would leak one query's scores into every later use of those records.
    context_items = [
        {**pages_n_chunks[idx], 'score': score}
        for idx, score in zip(indices.tolist(), scores.cpu())
    ]

    # -------- AUGMENTATION --------
    prompt = prompt_formatter(query, context_items)

    # -------- GENERATION --------
    # The prompt was produced by the chat template, which normally already inserts
    # the special tokens; adding them again here would duplicate them.
    model_inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).to(device)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    outputs = llm_model.generate(**model_inputs, streamer=streamer, temperature=temperature, do_sample=True, max_new_tokens=max_new_tokens)

    if format_answer_text:
        # Slice off the prompt tokens instead of string-replacing the prompt, which
        # breaks as soon as decoding does not reproduce the prompt text exactly.
        prompt_len = model_inputs['input_ids'].shape[1]
        output_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    else:
        output_text = tokenizer.decode(outputs[0])

    if return_context:
        return output_text, context_items
    return output_text

In [34]:
# Pick one of the sample questions at random (unseeded, so it can differ per run).
# To pin a question instead, use e.g. 'What is meant by Business Day?'
query = random.choice(queries)
print(f'Query: {query}')
# ask() streams the answer to stdout; the trailing ';' keeps the cell from echoing anything else.
ask(query=query, temperature=0.7, return_context=False);

Query: What is meant by Business Day?
tensor([0.8094, 0.5621, 0.5579, 0.5514, 0.5400, 0.5357, 0.5356, 0.5291, 0.5265,
        0.5241], device='mps:0')
According to the context, a Business Day means 9am to 5pm every Monday to Friday other than public or bank holidays in England and Wales, unless you are transacting through one of our branches which opens for shorter hours or we notify you of different times for the processing of payments.


In [33]:
torch.mps.empty_cache()