In [1]:
import pandas as pd
import numpy as np
import torch
from sentence_transformers import util, SentenceTransformer
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import random

# Pick the best available accelerator instead of assuming Apple-silicon MPS,
# so the notebook also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
# Load the saved chunk table from CSV and preview its first five rows.
emb_chunks_df = pd.read_csv(filepath_or_buffer='emb_chunks_df.csv')
emb_chunks_df.head(n=5)

Unnamed: 0,page_n,text,embedding
0,0,CORE BANKING AGREEMENT RELATIONSHIP TERMS & CO...,"[-0.289400577545166, -0.6761614680290222, -0.2..."
1,1,|Information about us and our regulators|29|\n...,"[-0.322034627199173, -0.8583006262779236, -0.5..."
2,2,|Your obligations relating to the security of ...,"[-0.09293383359909058, -0.43664729595184326, -..."
3,3,"(""The Agreement"") contains terms, conditions a...","[-0.36704424023628235, -0.7181680202484131, 0...."
4,4,|RELATIONSHIP TERMS & CONDITIONS|PRODUCT & SER...,"[-0.19214218854904175, -0.26558810472488403, -..."


In [3]:
# Convert the stringified embeddings back to np.array.
# Only values that are still strings are parsed, so re-running this cell after the
# column has already been converted no longer fails on `.strip`.
emb_chunks_df['embedding'] = emb_chunks_df['embedding'].apply(
    lambda x: np.fromstring(x.strip('[]'), sep=', ') if isinstance(x, str) else np.asarray(x)
)

# Stack the per-row vectors into one 2-D float32 tensor (one row per chunk) on `device`.
embs = torch.tensor(np.stack(emb_chunks_df['embedding'].tolist(), axis=0), dtype=torch.float32).to(device)

# One dict per DataFrame row, keyed by column name; later cells index this list by row position.
pages_n_chunks = emb_chunks_df.to_dict(orient='records')

emb_chunks_df.head()

Unnamed: 0,page_n,text,embedding
0,0,CORE BANKING AGREEMENT RELATIONSHIP TERMS & CO...,"[-0.289400577545166, -0.6761614680290222, -0.2..."
1,1,|Information about us and our regulators|29|\n...,"[-0.322034627199173, -0.8583006262779236, -0.5..."
2,2,|Your obligations relating to the security of ...,"[-0.09293383359909058, -0.43664729595184326, -..."
3,3,"(""The Agreement"") contains terms, conditions a...","[-0.36704424023628235, -0.7181680202484131, 0...."
4,4,|RELATIONSHIP TERMS & CONDITIONS|PRODUCT & SER...,"[-0.19214218854904175, -0.26558810472488403, -..."


In [4]:
# Sanity check: show the embedding matrix's shape together with (an abbreviated view of) its values.
(embs.shape, embs)

(torch.Size([198, 1024]),
 tensor([[-0.2894, -0.6762, -0.2548,  ..., -0.2197,  0.0575,  0.2950],
         [-0.3220, -0.8583, -0.5748,  ..., -0.4800, -0.3515,  0.4240],
         [-0.0929, -0.4366, -0.6115,  ..., -0.1398, -0.7231,  0.4906],
         ...,
         [-0.8480, -0.5612, -0.6425,  ..., -0.5232, -0.2608,  0.8832],
         [-0.7390, -1.1355, -0.4723,  ..., -0.5159, -0.4490,  0.7281],
         [-0.2831,  0.1584, -0.2007,  ..., -0.0353,  0.0897,  0.0589]],
        device='mps:0'))

In [5]:
# Sentence-embedding model used to encode queries.
# NOTE(review): presumably the same model that produced the stored embeddings — confirm.
# Candidates noted by the author: mixedbread-ai/mxbai-embed-large-v1, all-mpnet-base-v2
emb_model = SentenceTransformer(
    model_name_or_path='mixedbread-ai/mxbai-embed-large-v1',
    device=device,
)

In [6]:
# Ad-hoc semantic search: embed one query and list the five most similar chunks.
query = 'foreign currency exchange'

query_emb = emb_model.encode(query, convert_to_tensor=True).to(device)

# Keep the first row of the similarity result (alternative scorer: util.dot_score).
scores = util.cos_sim(a=query_emb, b=embs)[0]

top_results = torch.topk(scores, k=5)

print(f'*** Query: {query} ***\n')
for score, idx in zip(top_results.values, top_results.indices):
    chunk = pages_n_chunks[idx]
    print(f'Score: {score:.4f}')
    print(f'Text: {chunk["sentence_chunk"]}')
    print(f'Page number: {chunk["page_n"]}\n')


*** Query: foreign currency exchange ***

Score: 0.7045
Text: ## Unavailable currencies
Page number: 137

Score: 0.6990
Text: |Summary of execution times| |
|---|---|
|Payments in sterling to a financial institution in the UK|End of the next Business Day (or end of the second Business Day if your instructions were initiated in paper form)|
|Payments in euro to a financial institution in the UK or EEA|End of the next Business Day (or end of the second Business Day if your instructions were initiated in paper form)|
|Payments in EEA currencies other than euro to a financial institution in the UK or EEA (but not including payments in sterling to be made to a financial institution in the UK)|End of the fourth Business Day|
|Payments to be made to a financial institution outside of the UK or EEA|Please contact your relationship team for details|
Page number: 132

Score: 0.6914
Text: From time to time we are unable to make payments in certain currencies (you can contact us to check whether p

In [7]:
def retrieve_relevant_info(query: str, embeddings: torch.Tensor, model: SentenceTransformer=emb_model, n_to_retrieve: int=5) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the cosine-similarity scores and row indices of the chunks closest to `query`.

    Args:
        query: Free-text question to embed.
        embeddings: 2-D tensor with one embedding per row to search over.
        model: Encoder used to embed `query`.
        n_to_retrieve: Maximum number of rows to return; clamped to the number of
            rows in `embeddings` so an over-large value no longer raises.

    Returns:
        `(scores, indices)` as returned by `torch.topk`: scores in descending
        order and the matching row indices into `embeddings`.
    """
    # Move the query vector next to the stored embeddings so the similarity
    # computation never mixes devices.
    query_emb = model.encode(query, convert_to_tensor=True).to(embeddings.device)
    cos_scores = util.cos_sim(query_emb, embeddings)[0]

    # torch.topk raises when k exceeds the number of candidates; clamp instead.
    k = max(0, min(n_to_retrieve, cos_scores.shape[0]))
    scores, indices = torch.topk(cos_scores, k)

    # Kept from the original: callers' recorded outputs include this score dump.
    print(scores)
    return scores, indices

def print_topk(query: str, embeddings: torch.Tensor, pages_n_chunks: list[dict]=pages_n_chunks, n_to_retrieve: int=5):
    """Print the score, text and page number of the chunks most similar to `query`.

    Args:
        query: Free-text question to search for.
        embeddings: Tensor of chunk embeddings passed on to `retrieve_relevant_info`.
        pages_n_chunks: Row records; each must provide the keys "sentence_chunk"
            and "page_n" and be positionally aligned with `embeddings`.
        n_to_retrieve: Number of results to print. Previously this argument was
            accepted but ignored (the retrieval default was always used).
    """
    scores, indices = retrieve_relevant_info(query, embeddings, n_to_retrieve=n_to_retrieve)

    print(f'--- Query: {query} ---')
    for score, idx in zip(scores, indices):
        print(f'Score: {score:.4f}')
        print(f'Text: {pages_n_chunks[idx]["sentence_chunk"]}')
        print(f'Page number: {pages_n_chunks[idx]["page_n"]}\n')

In [8]:
# Exercise the retrieval helper on a second query.
query = 'exchange rates abroad'
print_topk(query=query, embeddings=embs)

tensor([0.6753, 0.6521, 0.6400, 0.6327, 0.6092], device='mps:0')
--- Query: exchange rates abroad ---
Score: 0.6753
Text: The exchange rates we use are variable exchange rates which are changing constantly throughout the day (for example, to reflect movements in foreign exchange markets). The exchange rate applied to your payments will appear on your statement. Unless otherwise agreed with you, the exchange rate we will apply to payments you make involving a currency exchange (including any future dated payments) and payments you receive which are in a different currency to the denomination of your account will be the Lloyds Bank Foreign Exchange Rate applicable at the time that your payment is processed. You can contact us to find out the rate which will apply and you can find details of how to contact us in the General Information On Payments, Charges & Contacts or by contacting your relationship team.
Page number: 136

Score: 0.6521
Text: From time to time we are unable to make paym

In [9]:
# Generator model: checkpoint id on the Hugging Face hub.
model_id = 'google/gemma-2b-it'

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the weights in float16 with SDPA attention, then move the model to `device`.
llm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
    attn_implementation='sdpa',
)
llm_model = llm_model.to(device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
# Sample customer questions used to exercise the retrieval + generation pipeline.
queries = [
    "In what cases may I close my account?",
    "How do you pay and charge interest?",
    "Can the terms and conditions be changed?",
    "What is meant by business day?",
    "What is a PIN?",
    "To whom can you disclose my confidential information?",
    "How can I reach you?",
    "How do you receive my payment instructions?",
    "In which cases can you terminate my account?",
    "Who is authorised to give you instructions?",
]

In [11]:
def prompt_formatter(query: str, context_items: list[dict]) -> str:
    """Build the chat-formatted prompt that combines retrieved context with the user's query.

    Args:
        query: The user's question.
        context_items: Retrieved records; each must provide a "sentence_chunk" key
            whose text becomes one bullet of the context section.

    Returns:
        The prompt string produced by `tokenizer.apply_chat_template` (not tokenized,
        with the generation prompt appended).
    """
    # One "- " bullet per retrieved chunk (empty string when nothing was retrieved).
    context = '\n'.join(f'- {item["sentence_chunk"]}' for item in context_items)

    # The instruction text deliberately contains no <start_of_turn>/<end_of_turn>
    # markers: apply_chat_template below adds the turn structure, and writing the
    # markers here as well would nest control tokens inside the user turn.
    base_prompt = """You are a helpful assistant to customers about a bank's terms and conditions.
Give yourself room to think by extracting relevant passages from the context before answering the query.
Don't return the thinking, only return the answer.
Make sure your answers are as clear and concise as possible.
Use the following couple of examples as reference for the ideal answer style, but don't use the below example answers as answers to the query.
\nExample 1:
User query: I'm considering opening a new savings account with a competitive interest rate. However, I noticed a clause regarding minimum balance requirements. Could you elaborate on the potential implications of not maintaining this minimum balance?
AI answer: That's a prudent inquiry!  Many banks offer attractive interest rates on savings accounts, but they may stipulate a minimum balance requirement.  Failing to maintain this minimum can trigger various consequences, including incurring fees or forfeiting the advertised interest rate. Carefully review the minimum balance stipulation within the T&Cs to ensure it aligns with your financial situation.
\nExample 2:
User query: My bank has been sending frequent notifications regarding mobile banking security. While I appreciate the reminder, is utilizing mobile banking inherently risky?
AI answer: Mobile banking offers undeniable convenience but does necessitate vigilance. While not inherently risky, online transactions always carry a certain level of risk.  To mitigate these risks, ensure your mobile device is equipped with a strong password and avoid using public Wi-Fi networks for banking activities. Your bank's security notifications serve as a valuable reminder to prioritize online safety measures.
\nNow based on the following context items:
{context};
\n And answer the user's query:
User query: {query}
AI answer:"""

    base_prompt = base_prompt.format(context=context, query=query)

    # Wrap the text as a single user turn so the model sees the format it was trained on.
    dialogue_template = [
        {
            'role': 'user',
            'content': base_prompt
        }
    ]
    prompt = tokenizer.apply_chat_template(conversation=dialogue_template, tokenize=False, add_generation_prompt=True)

    return prompt

In [12]:
def ask(query: str, temperature: float=0.2, max_new_tokens: int=256, format_answer_text: bool=True, return_context: bool=False):
    """Answer `query` with retrieval-augmented generation.

    Retrieves the five most similar chunks, builds a prompt from them, and
    generates an answer with the LLM while streaming tokens to stdout.

    Args:
        query: The user's question.
        temperature: Sampling temperature. Values <= 0 switch to greedy decoding
            instead of failing inside `generate`.
        max_new_tokens: Upper bound on generated tokens.
        format_answer_text: When True, strip the prompt and the literal
            '<bos>' / '<eos>' markers from the decoded text.
        return_context: When True, also return the retrieved context items.

    Returns:
        The decoded text, or `(text, context_items)` when `return_context` is True.
        Each context item is a copy of the matching record with an added 'score'.
    """
    # -------- RETRIEVAL --------
    scores, indices = retrieve_relevant_info(query, embs, n_to_retrieve=5)
    # Copy each record before attaching its score so the shared `pages_n_chunks`
    # entries are not mutated by every call.
    context_items = [
        {**pages_n_chunks[int(idx)], 'score': score.cpu()}
        for score, idx in zip(scores, indices)
    ]

    # -------- AUGMENTATION --------
    prompt = prompt_formatter(query, context_items)

    # -------- GENERATION --------
    # The chat-templated prompt already carries its special tokens; adding them
    # again at tokenization time would duplicate the BOS token.
    inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).to(device)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    if temperature > 0:
        sampling_kwargs = {'do_sample': True, 'temperature': temperature}
    else:
        # Sampling requires a strictly positive temperature; fall back to greedy.
        sampling_kwargs = {'do_sample': False}

    outputs = llm_model.generate(**inputs, streamer=streamer, max_new_tokens=max_new_tokens, **sampling_kwargs)
    output_text = tokenizer.decode(outputs[0])

    if format_answer_text:
        output_text = output_text.replace(prompt, '').replace('<bos>', '').replace('<eos>', '')

    if return_context:
        return output_text, context_items

    return output_text

In [13]:
# Seed PyTorch's RNG so the sampled generations below are repeatable across
# "Restart & Run All" (best effort; exact repeatability can still vary by backend).
torch.manual_seed(0)

# To spot-check a single question instead: query = random.choice(queries)
for query in queries:
    print(f'Query: {query}')
    # The answer is streamed to stdout by `ask`; its return value is not needed here.
    ask(query, temperature=0.7, return_context=False)
    print('\n')

Query: In what cases may I close my account?
tensor([0.6776, 0.6749, 0.6690, 0.6513, 0.6424], device='mps:0')
According to the context, the following circumstances may lead to the closure of an account:

- You do not maintain the minimum balance requirement for that particular savings account.
- You become overdrawn in relation to a product, or exceed any overdraft limit.
- There is a dispute between the holders of the account.
- You are insolvent.
- You are found to be using the account illegally, fraudulently, or outside of the terms of the agreement.


Query: How do you pay and charge interest?
tensor([0.7863, 0.6881, 0.6769, 0.6568, 0.6566], device='mps:0')
The context does not provide information about how the bank pays and charges interest, so I cannot answer this question from the provided context.


Query: Can the terms and conditions be changed?
tensor([0.7511, 0.7375, 0.7299, 0.6941, 0.6721], device='mps:0')
Yes, the terms and conditions can be changed at any time by the bank

In [14]:
# Release cached accelerator memory held by PyTorch. Guarded so the cell only
# touches a backend that is actually available on this machine.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()