# World Bank Financial Survey Q&A Model Project 

This project develops a NLP powererd question-answering system that is trained on World Bank Survey Data containing financial information gathered from various federal banks across the globe. This notebook walks the user through gathering/processing the data and training/deploying the final model. 

### Dataset Description
The World Bank survey dataset comprises of structured financial questions sent to financial instituitions worldwide. The dataset includes multi-dimensional survey responses, hierarchial question structures, and financial metrics. For this project, we will use the questions that have long-form textual answers to train an NLP model, rather than using binary response questions.

### Project Architecture

##### Phase 1 - Data Processing 

- Transform unstructured survey data into structurerd NLP training pairs
    - Parse all relevant sheets from excel file
    - Properly handle hierarchical question structures to ensure each question answer pair is standalone
- Indentify and Flag PII using a ML model in dataset

##### Phase 2 - Model Development & Fine Tuning

- Fine-tune a Google FLAN-T5-Base NLP model 
- Optimize the model's performance on this specific World Bank survey domain
- Evaluate model performance using validation and test sample sets

##### Phase 3 - Deployement

- Deploy fine-tuned model to production environment on Azure/Huggingface

##### Future Steps (if time allows):
- implement API for interacting with model
- add some sort of sentiment analysis to clasify questions/answers (financial questions, admin questions, etc)
- get feedback on model performance (answer quality/hallucinations/knowledge gaps)  
- add additional survey questions to knowledge base 

In [None]:
# download data from World Bank Database
import requests

url = "https://datacatalogfiles.worldbank.org/ddh-published/0038632/2/DR0047737/2021_04_26_brss-public-release.xlsx"
response = requests.get(url)

with open("worldbank_data.xlsx", "wb") as f:
    f.write(response.content)

Now that data is downloaded, it needs to be converted from an xlsx file with row column format to something that works for t5 training (question:answer pairs).  

In [1]:
## read and process data
import pandas as pd
import re

# Remove extra unnecessary information from question
# For example, "Select all that apply"
def simplify_question(qText):
    if pd.isna(qText):
        return ""
    
    text = str(qText).strip()
    
    # split on common instruction starters and take first part
    for splitter in [" Please ", " If ", " Include ", " Specify ", " Describe ", " List "]:
        if splitter in text:
            text = text.split(splitter)[0]
            break
    
    # if there's a question mark, take up to first one
    if "?" in text:
        text = text.split("?")[0] + "?"
    
    return text.strip()

# loads all sheets at once
allSheets = pd.ExcelFile("worldbank_data.xlsx")

# store samples
samples = []

# process all sheets except first 2 and last 1
process = allSheets.sheet_names[2:-1]

# read first sheet and extract countries
dfFirst = pd.read_excel(allSheets, sheet_name=process[0], header=None)
countries = [str(c) for c in dfFirst.iloc[0, 2:].values if not pd.isna(c)]

for sheet in process:
    # read current sheet
    df = pd.read_excel(allSheets, sheet_name=sheet, header=None)
    
    # create parent and base vars
    parent = None
    currBase = None
    
    # iterate through every row except header
    # get question index and question text
    for idx, row in df.iloc[1:].iterrows():
        qIndex = row[0]
        qText = row[1]
        
        # if the question index is null but text does exist 
        # then the question is a parent question
        # assign parent question and then clear prev base and move onto next row
        if pd.isna(qIndex) and not pd.isna(qText):
            parent = simplify_question(qText)  # ← Simplify parent too
            currBase = None
            continue
        
        # regex starts with Q and captures groups delimited by _
        # group 1 is the main question number
        # group 2 is sub-question number
        # group 3 is for multi-part questions with extra text
        # non-capturing group is for sections of index which are unnecessary
        match = re.match(r'Q(\d+)_([0-9_]+?)([a-zA-Z_]+)?(?:_[A-Z]|_\d{4}|$)', str(qIndex))
        
        # if regex matched then process row, otherwise skip
        if match:
            baseNum = f"{match.group(1)}_{match.group(2)}"
            isMulti = bool(match.group(3)) or bool(re.search(r'_\d{4}', str(qIndex)))
            part = match.group(3) if match.group(3) else ""
        else:
            continue
        
        # if new base is different to current base, update base
        if baseNum and baseNum != currBase:
            # reset parent if new question isn't multi part
            if not isMulti:
                parent = None
            currBase = baseNum
        
        # loop through each column
        for colIdx, country in enumerate(countries):
            
            # get answer for current column
            answer = row[colIdx + 2]
            
            # skip column if there's no answer
            if pd.isna(answer):
                continue
            
            # Simplify the question text
            simplifiedQ = simplify_question(qText)  # ← KEY CHANGE
            
            # if question is multi-part combine parent question and question text
            if isMulti and parent:
                completeQ = f"{parent} {simplifiedQ}"
            # otherwise just append question text
            else:
                completeQ = simplifiedQ
            
            # fill in sample entry
            sample = {
                "input": f"Answer this question about {country}: {completeQ}".strip(),
                "target": str(answer).strip()
            }
            
            # append sample to list
            samples.append(sample)

Now that the data is in proper training format, it needs to be checked for PII. We will use Microsoft's Presidio pre-trained ML library to detect PII (https://github.com/microsoft/presidio).

In [None]:
# install dependecies
# !pip install presidio_analyzer presidio_anonymizer
# !python -m spacy download en_core_web_lg

In [None]:
from presidio_analyzer import AnalyzerEngine
from tqdm import tqdm
import json

# initialize analyzer
analyzer = AnalyzerEngine()

# specific countries and years are necessary to the survey data
# do not flag these as PII
excludeWords = set(countries)
excludeWords.update(['2011', '2012', '2013', '2014', '2015', '2016'])

# only include entries that the model has 70%+ confidnece is PII
CONFIDENCE = 0.7

# only track unique PII values
seenPII = set()

# storage for PII
potentialPII = []

# iterate through every sample
for idx, sample in enumerate(tqdm(samples, desc='finding pii')):

    # get input question and target
    inputText = sample["input"]
    targetText = sample["target"]

    # analyze input and target
    inputRes = analyzer.analyze(text=inputText, language='en')
    targetRes = analyzer.analyze(text=targetText, language='en')

    # filter out exclude list from text matches
    inputRes = [r for r in inputRes 
                if r.score >= CONFIDENCE
                and not any(inputText[r.start:r.end] in word or word in inputText[r.start:r.end] for word in excludeWords)] 
    targetRes = [r for r in targetRes 
                 if r.score >= CONFIDENCE 
                 and not any(targetText[r.start:r.end] in word or word in targetText[r.start:r.end] for word in excludeWords)]

    # if pii is found
    isNewPII = False
    for r in inputRes:
        if inputText[r.start:r.end] not in seenPII:
            isNewPII = True
            seenPII.add(inputText[r.start:r.end])
    for r in targetRes:
        if targetText[r.start:r.end] not in seenPII:
            isNewPII = True
            seenPII.add(targetText[r.start:r.end])

    if isNewPII:
        res = {
            "input": inputText,
            "target": targetText,
            "inputPII": [{"type": r.entity_type, "text": inputText[r.start:r.end], "score": r.score} for r in inputRes],
            "targetPII": [{"type": r.entity_type, "text": targetText[r.start:r.end], "score": r.score} for r in targetRes]
        }
        potentialPII.append(res)

# dump all potential flagged PII into a json file
with open('potentialPII.json', 'w', encoding='utf-8') as f:
    json.dump(potentialPII, f, indent=2, ensure_ascii=False)


finding pii: 100%|██████████| 107833/107833 [35:25<00:00, 50.73it/s] 


The code dumps all potential PII matches to a seperate JSON file saved to the current directory (potentiallyPII.json). This file can now be manually checked to determine which flagged keywords are false postives and which are actually PII. Once all PII is removed from the dataset, the T5 model training can begin.

In [None]:
# install dependencies
# !pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128
# !pip install transformers datasets accelerate


In [None]:
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM, 
    DataCollatorForSeq2Seq,
)
from datasets import Dataset
import random

# balances samples to significantly reduce training time for project constraints
# also helps prevent the model from learning to predict yes/no for every question
samplesSmall = [s for s in samples if len(s["target"].split()) < 3]
samplesLarge = [s for s in samples if len(s["target"].split()) >= 3]
random.seed(42)
samplesBalanced = (
    random.sample(samplesLarge, min(int(len(samples) * 0.7), len(samplesLarge))) + 
    random.sample(samplesSmall, min(int(len(samples) * 0.3), len(samplesSmall)))
)
random.shuffle(samplesBalanced)

# convert existing data to hugging face dataset
data = Dataset.from_list(samplesBalanced)

# Setup
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

def preprocess(samples):
    modelInputs = tokenizer(
        samples["input"],
        max_length=512,
        truncation=True,
        padding=False
    )
    targets = tokenizer(
        samples["target"],
        max_length=128,
        truncation=True,
        padding=False
    )
    modelInputs["labels"] = targets["input_ids"]
    return modelInputs

trainValSplit = data.train_test_split(test_size=0.2)
valTestSplit = trainValSplit["test"].train_test_split(test_size=0.5)

splits = {
    "train": trainValSplit["train"],
    "validation": valTestSplit['train'],
    "test": valTestSplit["test"]
}

finalData = {
    "train": splits["train"].map(preprocess, batched=True, remove_columns=["input", "target"]),
    "validation": splits["validation"].map(preprocess, batched=True, remove_columns=["input", "target"]),
    "test": splits["test"].map(preprocess, batched=True, remove_columns=["input", "target"])
}

dataCollator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6, weight_decay=0.01)

# learning rate sceduling - reduces the lr over all epochs
# helps model converge better by reducing oscillation
scheduler = CosineAnnealingLR(optimizer, T_max=7)

# Create dataloaders
train_dataloader = DataLoader(
    finalData["train"], 
    batch_size=4, 
    shuffle=True, 
    collate_fn=dataCollator
)

val_dataloader = DataLoader(
    finalData["validation"],
    batch_size=4,
    collate_fn=dataCollator
)

num_epochs = 7
device = "cuda"

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for step, batch in enumerate(progress_bar):
        # look up more info
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        
        # look up more info
        # added label_smoothing - makes the model less confident and improves generalization (ability to perform in unseen data)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        logits = outputs.logits
        loss_fct = CrossEntropyLoss(label_smoothing=0.1, ignore_index=-100)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))        
        
        # look up more info
        optimizer.zero_grad()
        loss.backward()

        # gradient clipping - prevents gradients from breaking if model updates by large amount
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # clear memory to fix crash at 1800 steps
        if step % 100 == 0:
            torch.cuda.empty_cache()

        total_loss += loss.item()
        progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
        
        if step % 100 == 0:
            print(f"\nStep {step}, Loss: {loss.item():.4f}")
    
    avg_train_loss = total_loss / len(train_dataloader)
    print(f"\nEpoch {epoch+1} - Avg Train Loss: {avg_train_loss:.4f}")

    scheduler.step()
    
    model.eval()
    total_val_loss = 0
    
    with torch.no_grad():
        for batch in val_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            total_val_loss += outputs.loss.item()
    
    avg_val_loss = total_val_loss / len(val_dataloader)
    print(f"Epoch {epoch+1} - Avg Val Loss: {avg_val_loss:.4f}")
    
model.save_pretrained("./flan-t5-small-label-smooth-balanced")
tokenizer.save_pretrained("./flan-t5-small-label-smooth-balanced")


Map: 100%|██████████| 30276/30276 [00:02<00:00, 13123.64 examples/s]
Map: 100%|██████████| 3784/3784 [00:00<00:00, 16158.58 examples/s]
Map: 100%|██████████| 3785/3785 [00:00<00:00, 16448.17 examples/s]
Epoch 1/7:   0%|          | 2/7569 [00:00<27:54,  4.52it/s, loss=5.3024]


Step 0, Loss: 5.1492


Epoch 1/7:   1%|▏         | 102/7569 [00:16<20:24,  6.10it/s, loss=4.5352]


Step 100, Loss: 4.4267


Epoch 1/7:   3%|▎         | 202/7569 [00:32<21:14,  5.78it/s, loss=5.5976]


Step 200, Loss: 5.0235


Epoch 1/7:   4%|▍         | 302/7569 [00:49<19:44,  6.14it/s, loss=3.8824]


Step 300, Loss: 3.5615


Epoch 1/7:   5%|▌         | 402/7569 [01:05<19:48,  6.03it/s, loss=4.9912]


Step 400, Loss: 4.5016


Epoch 1/7:   7%|▋         | 502/7569 [01:22<19:37,  6.00it/s, loss=4.5629]


Step 500, Loss: 3.4506


Epoch 1/7:   8%|▊         | 602/7569 [01:38<19:22,  5.99it/s, loss=4.9178]


Step 600, Loss: 5.1806


Epoch 1/7:   9%|▉         | 701/7569 [01:54<18:43,  6.11it/s, loss=5.1709]


Step 700, Loss: 3.9788


Epoch 1/7:  11%|█         | 802/7569 [02:10<18:31,  6.09it/s, loss=5.7303]


Step 800, Loss: 4.7128


Epoch 1/7:  12%|█▏        | 902/7569 [02:27<18:31,  6.00it/s, loss=4.9587]


Step 900, Loss: 3.6327


Epoch 1/7:  13%|█▎        | 1002/7569 [02:43<18:37,  5.88it/s, loss=5.5470]


Step 1000, Loss: 2.9185


Epoch 1/7:  15%|█▍        | 1102/7569 [03:00<19:03,  5.65it/s, loss=4.1318]


Step 1100, Loss: 3.0141


Epoch 1/7:  16%|█▌        | 1202/7569 [03:16<18:30,  5.73it/s, loss=3.0259]


Step 1200, Loss: 5.3474


Epoch 1/7:  17%|█▋        | 1302/7569 [03:33<18:06,  5.77it/s, loss=5.1523]


Step 1300, Loss: 5.4190


Epoch 1/7:  19%|█▊        | 1402/7569 [03:49<17:25,  5.90it/s, loss=2.2053]


Step 1400, Loss: 5.1574


Epoch 1/7:  20%|█▉        | 1502/7569 [04:05<16:59,  5.95it/s, loss=3.1918]


Step 1500, Loss: 5.0585


Epoch 1/7:  21%|██        | 1602/7569 [04:22<16:11,  6.14it/s, loss=3.4143]


Step 1600, Loss: 5.1191


Epoch 1/7:  22%|██▏       | 1702/7569 [04:39<16:05,  6.08it/s, loss=2.9043]


Step 1700, Loss: 5.1671


Epoch 1/7:  24%|██▍       | 1802/7569 [04:55<15:59,  6.01it/s, loss=2.5487]


Step 1800, Loss: 4.4686


Epoch 1/7:  25%|██▌       | 1902/7569 [05:11<14:55,  6.33it/s, loss=2.3397]


Step 1900, Loss: 4.0531


Epoch 1/7:  26%|██▋       | 2001/7569 [05:27<15:04,  6.16it/s, loss=2.3024]


Step 2000, Loss: 2.3024


Epoch 1/7:  28%|██▊       | 2101/7569 [05:44<15:30,  5.88it/s, loss=3.1964]


Step 2100, Loss: 3.1964


Epoch 1/7:  29%|██▉       | 2202/7569 [06:00<14:39,  6.11it/s, loss=3.6681]


Step 2200, Loss: 5.1633


Epoch 1/7:  30%|███       | 2302/7569 [06:17<15:20,  5.72it/s, loss=4.5968]


Step 2300, Loss: 3.0276


Epoch 1/7:  32%|███▏      | 2401/7569 [06:34<14:49,  5.81it/s, loss=3.7124]


Step 2400, Loss: 3.7124


Epoch 1/7:  33%|███▎      | 2502/7569 [06:51<15:40,  5.39it/s, loss=3.3107]


Step 2500, Loss: 5.1516


Epoch 1/7:  34%|███▍      | 2602/7569 [07:08<14:14,  5.81it/s, loss=2.6441]


Step 2600, Loss: 4.2398


Epoch 1/7:  36%|███▌      | 2701/7569 [07:25<15:03,  5.39it/s, loss=3.4073]


Step 2700, Loss: 2.7204


Epoch 1/7:  37%|███▋      | 2802/7569 [07:42<13:45,  5.78it/s, loss=3.2233]


Step 2800, Loss: 3.7710


Epoch 1/7:  38%|███▊      | 2902/7569 [08:00<14:21,  5.42it/s, loss=3.1842]


Step 2900, Loss: 2.3089


Epoch 1/7:  40%|███▉      | 3002/7569 [08:17<13:41,  5.56it/s, loss=3.3410]


Step 3000, Loss: 4.7596


Epoch 1/7:  41%|████      | 3101/7569 [08:34<12:58,  5.74it/s, loss=3.2989]


Step 3100, Loss: 3.2989


Epoch 1/7:  42%|████▏     | 3202/7569 [08:51<12:25,  5.86it/s, loss=5.5080]


Step 3200, Loss: 2.2624


Epoch 1/7:  44%|████▎     | 3301/7569 [09:08<12:22,  5.75it/s, loss=3.2328]


Step 3300, Loss: 3.2328


Epoch 1/7:  45%|████▍     | 3402/7569 [09:25<11:58,  5.80it/s, loss=5.0764]


Step 3400, Loss: 5.4340


Epoch 1/7:  46%|████▋     | 3502/7569 [09:43<11:05,  6.11it/s, loss=3.5665]


Step 3500, Loss: 3.4171


Epoch 1/7:  48%|████▊     | 3602/7569 [10:00<12:12,  5.41it/s, loss=4.5741]


Step 3600, Loss: 3.4549


Epoch 1/7:  49%|████▉     | 3701/7569 [10:16<11:06,  5.80it/s, loss=2.8294]


Step 3700, Loss: 2.8294


Epoch 1/7:  50%|█████     | 3802/7569 [10:33<10:35,  5.93it/s, loss=2.6851]


Step 3800, Loss: 3.1029


Epoch 1/7:  52%|█████▏    | 3902/7569 [10:50<10:01,  6.09it/s, loss=4.9797]


Step 3900, Loss: 4.1275


Epoch 1/7:  53%|█████▎    | 4002/7569 [11:07<10:01,  5.93it/s, loss=3.0110]


Step 4000, Loss: 3.2857


Epoch 1/7:  54%|█████▍    | 4102/7569 [11:24<10:12,  5.66it/s, loss=3.1558]


Step 4100, Loss: 3.1963


Epoch 1/7:  56%|█████▌    | 4202/7569 [11:41<09:27,  5.94it/s, loss=4.0025]


Step 4200, Loss: 2.4320


Epoch 1/7:  57%|█████▋    | 4301/7569 [11:58<09:11,  5.92it/s, loss=5.1172]


Step 4300, Loss: 5.0898


Epoch 1/7:  58%|█████▊    | 4402/7569 [12:15<09:00,  5.86it/s, loss=2.2347]


Step 4400, Loss: 4.1418


Epoch 1/7:  59%|█████▉    | 4502/7569 [12:32<09:05,  5.62it/s, loss=4.1881]


Step 4500, Loss: 1.8499


Epoch 1/7:  61%|██████    | 4602/7569 [12:49<08:30,  5.82it/s, loss=3.6372]


Step 4600, Loss: 2.7161


Epoch 1/7:  62%|██████▏   | 4701/7569 [13:06<08:21,  5.72it/s, loss=1.9624]


Step 4700, Loss: 1.9624


Epoch 1/7:  63%|██████▎   | 4802/7569 [13:23<08:03,  5.72it/s, loss=2.2358]


Step 4800, Loss: 5.0599


Epoch 1/7:  65%|██████▍   | 4901/7569 [13:39<07:29,  5.94it/s, loss=4.6938]


Step 4900, Loss: 4.6938


Epoch 1/7:  66%|██████▌   | 5002/7569 [13:56<07:09,  5.97it/s, loss=3.5705]


Step 5000, Loss: 4.0006


Epoch 1/7:  67%|██████▋   | 5102/7569 [14:13<06:57,  5.90it/s, loss=4.6136]


Step 5100, Loss: 4.3951


Epoch 1/7:  69%|██████▊   | 5202/7569 [14:30<06:56,  5.68it/s, loss=4.3218]


Step 5200, Loss: 1.9400


Epoch 1/7:  70%|███████   | 5302/7569 [14:47<06:20,  5.96it/s, loss=3.8912]


Step 5300, Loss: 3.2664


Epoch 1/7:  71%|███████▏  | 5402/7569 [15:04<06:10,  5.85it/s, loss=5.7806]


Step 5400, Loss: 5.3024


Epoch 1/7:  73%|███████▎  | 5502/7569 [15:21<06:06,  5.63it/s, loss=5.1289]


Step 5500, Loss: 5.3649


Epoch 1/7:  74%|███████▍  | 5602/7569 [15:37<05:28,  6.00it/s, loss=3.5915]


Step 5600, Loss: 5.0912


Epoch 1/7:  75%|███████▌  | 5702/7569 [15:54<05:21,  5.81it/s, loss=2.9038]


Step 5700, Loss: 3.7837


Epoch 1/7:  77%|███████▋  | 5802/7569 [16:11<05:07,  5.74it/s, loss=3.0099]


Step 5800, Loss: 3.6106


Epoch 1/7:  78%|███████▊  | 5902/7569 [16:28<04:53,  5.68it/s, loss=2.3801]


Step 5900, Loss: 5.1821


Epoch 1/7:  79%|███████▉  | 6002/7569 [16:45<04:30,  5.79it/s, loss=4.5526]


Step 6000, Loss: 5.0998


Epoch 1/7:  81%|████████  | 6101/7569 [17:01<04:14,  5.78it/s, loss=4.8710]


Step 6100, Loss: 4.8710


Epoch 1/7:  82%|████████▏ | 6202/7569 [17:18<04:00,  5.69it/s, loss=3.0193]


Step 6200, Loss: 3.8896


Epoch 1/7:  83%|████████▎ | 6302/7569 [17:35<03:39,  5.78it/s, loss=4.1206]


Step 6300, Loss: 4.5176


Epoch 1/7:  85%|████████▍ | 6402/7569 [17:52<03:26,  5.65it/s, loss=4.7562]


Step 6400, Loss: 2.8918


Epoch 1/7:  86%|████████▌ | 6502/7569 [18:09<03:12,  5.54it/s, loss=2.6743]


Step 6500, Loss: 4.0855


Epoch 1/7:  87%|████████▋ | 6602/7569 [18:26<02:48,  5.75it/s, loss=2.5326]


Step 6600, Loss: 1.7848


Epoch 1/7:  89%|████████▊ | 6702/7569 [18:43<02:28,  5.82it/s, loss=3.2580]


Step 6700, Loss: 2.3623


Epoch 1/7:  90%|████████▉ | 6802/7569 [19:00<02:33,  4.99it/s, loss=3.8185]


Step 6800, Loss: 2.7231


Epoch 1/7:  91%|█████████ | 6902/7569 [19:16<01:56,  5.74it/s, loss=4.0423]


Step 6900, Loss: 2.4312


Epoch 1/7:  93%|█████████▎| 7002/7569 [19:33<01:36,  5.88it/s, loss=4.1049]


Step 7000, Loss: 3.4869


Epoch 1/7:  94%|█████████▍| 7102/7569 [19:50<01:21,  5.72it/s, loss=4.4933]


Step 7100, Loss: 3.0369


Epoch 1/7:  95%|█████████▌| 7202/7569 [20:07<01:05,  5.58it/s, loss=2.2262]


Step 7200, Loss: 2.2538


Epoch 1/7:  96%|█████████▋| 7302/7569 [20:23<00:43,  6.07it/s, loss=3.1765]


Step 7300, Loss: 4.8730


Epoch 1/7:  98%|█████████▊| 7402/7569 [20:40<00:29,  5.69it/s, loss=4.1515]


Step 7400, Loss: 5.7678


Epoch 1/7:  99%|█████████▉| 7502/7569 [20:57<00:12,  5.38it/s, loss=2.6996]


Step 7500, Loss: 5.0806


Epoch 1/7: 100%|██████████| 7569/7569 [21:09<00:00,  5.96it/s, loss=5.4632]



Epoch 1 - Avg Train Loss: 3.8623
Epoch 1 - Avg Val Loss: 2.1116


Epoch 2/7:   0%|          | 2/7569 [00:00<23:38,  5.34it/s, loss=4.9636]


Step 0, Loss: 2.8015


Epoch 2/7:   1%|▏         | 102/7569 [00:17<21:33,  5.77it/s, loss=2.4070]


Step 100, Loss: 2.8391


Epoch 2/7:   3%|▎         | 202/7569 [00:33<21:21,  5.75it/s, loss=4.3700]


Step 200, Loss: 3.6890


Epoch 2/7:   4%|▍         | 302/7569 [00:50<20:02,  6.04it/s, loss=2.5060]


Step 300, Loss: 4.0156


Epoch 2/7:   5%|▌         | 402/7569 [01:07<21:34,  5.54it/s, loss=2.9012]


Step 400, Loss: 5.4762


Epoch 2/7:   7%|▋         | 502/7569 [01:24<19:25,  6.06it/s, loss=3.2295]


Step 500, Loss: 4.8582


Epoch 2/7:   8%|▊         | 602/7569 [01:41<21:22,  5.43it/s, loss=1.7444]


Step 600, Loss: 4.5097


Epoch 2/7:   9%|▉         | 702/7569 [01:58<19:58,  5.73it/s, loss=4.5195]


Step 700, Loss: 2.4461


Epoch 2/7:  11%|█         | 802/7569 [02:15<19:53,  5.67it/s, loss=3.2528]


Step 800, Loss: 3.5975


Epoch 2/7:  12%|█▏        | 902/7569 [02:32<19:33,  5.68it/s, loss=4.5006]


Step 900, Loss: 3.8696


Epoch 2/7:  13%|█▎        | 1002/7569 [02:49<18:38,  5.87it/s, loss=2.4519]


Step 1000, Loss: 4.3021


Epoch 2/7:  15%|█▍        | 1102/7569 [03:06<18:24,  5.86it/s, loss=3.5574]


Step 1100, Loss: 2.8928


Epoch 2/7:  16%|█▌        | 1202/7569 [03:23<18:03,  5.88it/s, loss=4.5812]


Step 1200, Loss: 2.4564


Epoch 2/7:  17%|█▋        | 1301/7569 [03:40<17:18,  6.04it/s, loss=2.8302]


Step 1300, Loss: 2.8302


Epoch 2/7:  19%|█▊        | 1401/7569 [03:57<18:14,  5.63it/s, loss=5.0627]


Step 1400, Loss: 3.2608


Epoch 2/7:  20%|█▉        | 1502/7569 [04:14<17:02,  5.93it/s, loss=1.9881]


Step 1500, Loss: 4.8714


Epoch 2/7:  21%|██        | 1602/7569 [04:31<16:35,  5.99it/s, loss=5.2028]


Step 1600, Loss: 4.4396


Epoch 2/7:  22%|██▏       | 1702/7569 [04:48<17:10,  5.69it/s, loss=2.5680]


Step 1700, Loss: 3.7377


Epoch 2/7:  24%|██▍       | 1802/7569 [05:06<19:20,  4.97it/s, loss=4.2260]


Step 1800, Loss: 2.1059


Epoch 2/7:  25%|██▌       | 1901/7569 [05:44<44:08,  2.14it/s, loss=5.6566]


Step 1900, Loss: 5.6566


Epoch 2/7:  26%|██▋       | 2002/7569 [06:03<14:56,  6.21it/s, loss=4.6067]


Step 2000, Loss: 4.8224


Epoch 2/7:  28%|██▊       | 2102/7569 [06:19<14:54,  6.11it/s, loss=2.9705]


Step 2100, Loss: 2.3763


Epoch 2/7:  29%|██▉       | 2202/7569 [06:35<15:01,  5.95it/s, loss=3.5923]


Step 2200, Loss: 4.7283


Epoch 2/7:  30%|███       | 2302/7569 [06:51<13:40,  6.42it/s, loss=2.5830]


Step 2300, Loss: 2.2411


Epoch 2/7:  32%|███▏      | 2402/7569 [07:08<13:58,  6.16it/s, loss=2.3030]


Step 2400, Loss: 3.8562


Epoch 2/7:  33%|███▎      | 2502/7569 [07:24<14:16,  5.92it/s, loss=2.6088]


Step 2500, Loss: 2.7337


Epoch 2/7:  34%|███▍      | 2602/7569 [07:40<13:46,  6.01it/s, loss=2.5661]


Step 2600, Loss: 3.4622


Epoch 2/7:  36%|███▌      | 2701/7569 [07:56<13:41,  5.92it/s, loss=4.3011]


Step 2700, Loss: 4.3011


Epoch 2/7:  37%|███▋      | 2802/7569 [08:13<13:36,  5.84it/s, loss=5.2815]


Step 2800, Loss: 1.7892


Epoch 2/7:  38%|███▊      | 2902/7569 [08:30<13:12,  5.89it/s, loss=2.6962]


Step 2900, Loss: 4.9848


Epoch 2/7:  40%|███▉      | 3002/7569 [08:47<12:36,  6.04it/s, loss=1.8218]


Step 3000, Loss: 2.1211


Epoch 2/7:  41%|████      | 3102/7569 [09:04<11:59,  6.21it/s, loss=2.1681]


Step 3100, Loss: 2.7618


Epoch 2/7:  42%|████▏     | 3202/7569 [09:21<11:47,  6.18it/s, loss=3.3362]


Step 3200, Loss: 1.6525


Epoch 2/7:  44%|████▎     | 3302/7569 [09:38<11:51,  6.00it/s, loss=2.0143]


Step 3300, Loss: 3.5871


Epoch 2/7:  45%|████▍     | 3402/7569 [09:54<11:40,  5.95it/s, loss=4.3437]


Step 3400, Loss: 5.1885


Epoch 2/7:  46%|████▋     | 3502/7569 [10:11<11:34,  5.85it/s, loss=3.2431]


Step 3500, Loss: 5.1578


Epoch 2/7:  48%|████▊     | 3602/7569 [10:28<12:19,  5.37it/s, loss=4.6302]


Step 3600, Loss: 4.2285


Epoch 2/7:  49%|████▉     | 3702/7569 [10:45<10:53,  5.92it/s, loss=3.5943]


Step 3700, Loss: 3.5842


Epoch 2/7:  50%|█████     | 3802/7569 [11:02<10:48,  5.81it/s, loss=1.9590]


Step 3800, Loss: 3.8404


Epoch 2/7:  52%|█████▏    | 3902/7569 [11:19<10:25,  5.86it/s, loss=3.1378]


Step 3900, Loss: 3.3029


Epoch 2/7:  53%|█████▎    | 4002/7569 [11:35<10:25,  5.71it/s, loss=3.4667]


Step 4000, Loss: 3.7223


Epoch 2/7:  54%|█████▍    | 4102/7569 [11:52<10:12,  5.66it/s, loss=2.6966]


Step 4100, Loss: 2.8895


Epoch 2/7:  56%|█████▌    | 4202/7569 [12:09<09:29,  5.92it/s, loss=4.5286]


Step 4200, Loss: 2.1420


Epoch 2/7:  57%|█████▋    | 4301/7569 [12:26<09:52,  5.52it/s, loss=3.0882]


Step 4300, Loss: 3.0882


Epoch 2/7:  58%|█████▊    | 4402/7569 [12:43<09:05,  5.81it/s, loss=2.9796]


Step 4400, Loss: 4.0353


Epoch 2/7:  59%|█████▉    | 4502/7569 [13:00<09:18,  5.49it/s, loss=2.8663]


Step 4500, Loss: 5.4027


Epoch 2/7:  61%|██████    | 4602/7569 [13:17<08:21,  5.91it/s, loss=3.3430]


Step 4600, Loss: 2.7530


Epoch 2/7:  62%|██████▏   | 4702/7569 [13:34<08:18,  5.75it/s, loss=4.8052]


Step 4700, Loss: 3.5755


Epoch 2/7:  63%|██████▎   | 4801/7569 [13:52<07:55,  5.82it/s, loss=3.7262]


Step 4800, Loss: 2.9224


Epoch 2/7:  65%|██████▍   | 4902/7569 [14:09<07:31,  5.91it/s, loss=4.6700]


Step 4900, Loss: 4.4626


Epoch 2/7:  66%|██████▌   | 5002/7569 [14:26<07:15,  5.90it/s, loss=1.7044]


Step 5000, Loss: 3.2837


Epoch 2/7:  67%|██████▋   | 5102/7569 [14:42<07:03,  5.82it/s, loss=4.9634]


Step 5100, Loss: 2.5136


Epoch 2/7:  69%|██████▊   | 5202/7569 [14:59<06:55,  5.70it/s, loss=2.4979]


Step 5200, Loss: 5.3107


Epoch 2/7:  70%|███████   | 5302/7569 [15:16<06:24,  5.90it/s, loss=3.7836]


Step 5300, Loss: 3.1689


Epoch 2/7:  71%|███████▏  | 5402/7569 [15:33<06:30,  5.56it/s, loss=1.8054]


Step 5400, Loss: 3.4476


Epoch 2/7:  73%|███████▎  | 5502/7569 [15:50<05:54,  5.83it/s, loss=1.6916]


Step 5500, Loss: 4.0163


Epoch 2/7:  74%|███████▍  | 5602/7569 [16:07<05:23,  6.08it/s, loss=2.3464]


Step 5600, Loss: 2.7913


Epoch 2/7:  75%|███████▌  | 5702/7569 [16:24<05:19,  5.85it/s, loss=3.3423]


Step 5700, Loss: 3.9179


Epoch 2/7:  77%|███████▋  | 5802/7569 [16:41<04:56,  5.96it/s, loss=2.9544]


Step 5800, Loss: 4.5985


Epoch 2/7:  78%|███████▊  | 5902/7569 [16:57<04:57,  5.60it/s, loss=4.8090]


Step 5900, Loss: 2.7901


Epoch 2/7:  79%|███████▉  | 6002/7569 [17:14<04:30,  5.80it/s, loss=4.5890]


Step 6000, Loss: 3.2330


Epoch 2/7:  81%|████████  | 6102/7569 [17:31<04:17,  5.70it/s, loss=5.2058]


Step 6100, Loss: 3.4629


Epoch 2/7:  82%|████████▏ | 6202/7569 [17:48<04:00,  5.69it/s, loss=5.6642]


Step 6200, Loss: 2.0717


Epoch 2/7:  83%|████████▎ | 6302/7569 [18:05<03:29,  6.05it/s, loss=2.9053]


Step 6300, Loss: 2.2172


Epoch 2/7:  85%|████████▍ | 6402/7569 [18:22<03:26,  5.66it/s, loss=4.7250]


Step 6400, Loss: 5.1932


Epoch 2/7:  86%|████████▌ | 6502/7569 [18:38<02:59,  5.96it/s, loss=2.9824]


Step 6500, Loss: 1.9469


Epoch 2/7:  87%|████████▋ | 6601/7569 [18:56<02:44,  5.88it/s, loss=2.0875]


Step 6600, Loss: 3.2489


Epoch 2/7:  89%|████████▊ | 6702/7569 [19:13<02:19,  6.20it/s, loss=2.8779]


Step 6700, Loss: 4.0169


Epoch 2/7:  90%|████████▉ | 6802/7569 [19:31<02:30,  5.09it/s, loss=2.1893]


Step 6800, Loss: 4.9080


Epoch 2/7:  91%|█████████ | 6902/7569 [19:49<01:56,  5.71it/s, loss=1.9933]


Step 6900, Loss: 5.1154


Epoch 2/7:  93%|█████████▎| 7002/7569 [20:10<01:52,  5.06it/s, loss=2.2035]


Step 7000, Loss: 4.2460


Epoch 2/7:  94%|█████████▍| 7102/7569 [20:27<01:24,  5.50it/s, loss=1.6002]


Step 7100, Loss: 5.4326


Epoch 2/7:  95%|█████████▌| 7202/7569 [20:49<01:14,  4.94it/s, loss=3.0046]


Step 7200, Loss: 4.3586


Epoch 2/7:  96%|█████████▋| 7302/7569 [21:08<00:48,  5.49it/s, loss=4.1804]


Step 7300, Loss: 3.0323


Epoch 2/7:  98%|█████████▊| 7401/7569 [21:26<00:36,  4.62it/s, loss=3.8627]


Step 7400, Loss: 3.8627


Epoch 2/7:  99%|█████████▉| 7501/7569 [21:46<00:15,  4.41it/s, loss=3.9116]


Step 7500, Loss: 3.9116


Epoch 2/7: 100%|██████████| 7569/7569 [21:59<00:00,  5.73it/s, loss=2.8034]



Epoch 2 - Avg Train Loss: 3.4619


Now that the model is fine-tuned on the initial dataset, it can be locally queried to it correctly provides predictions.

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import random

model = AutoModelForSeq2SeqLM.from_pretrained("./flan-t5-bsae-CUSTOM-TRAINED").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("./flan-t5-bsae-CUSTOM-TRAINED")

test_indices = random.sample(range(len(data)), 20)

for idx in test_indices:
    sample = data[idx]
    question = sample["input"]
    true_answer = sample["target"]
    
    inputs = tokenizer(question, return_tensors="pt", max_length=512, truncation=True).to("cuda")
    outputs = model.generate(**inputs, max_length=128, num_beams=4)
    predicted = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
    print(f"\nQ: {question[:70]}...")
    print(f"True: {true_answer[:60]}...")
    print(f"Pred: {predicted}")

# Test on custom questions
print("custom questions:")

custom_questions = [
    "Answer this question about United States: What body/agency grants banking licenses?",
    "Answer this question about France: What is the minimum capital requirement?",
    "Answer this question about Japan: Who regulates banks?"
]

for q in custom_questions:
    inputs = tokenizer(q, return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_length=128, num_beams=4)
    pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nQ: {q}")
    print(f"A: {pred}")


Q: Answer this question about Portugal: 14.1 What body/agency has the res...
True: The Central Bank for the retail banking sector. ComissÃ£o do...
Pred: The Financial Supervisory Authority of Portugal (Financial Supervisory Authority) is responsible for implementing, overseeing and enforcing any aspects of financial consumer protection laws and regulations. The Financial Supervisory Authority of Portugal (Financial Supervisory Authority) is responsible for implementing, overseeing and enforcing any aspects of financial consumer protection laws and regulations.

Q: Answer this question about Cayman Islands: 3.20.1 Which of the followi...
True: Limit of 1.25% of risk weighted assets...
Pred: Tier 2 capital is not recognised by the Cayman Islands Monetary Authority

Q: Answer this question about Lebanon: 12.1.1 a. Commercial banks...
True: Banking Control Commision of Lebanon and Special investigati...
Pred: Banking Control Commision of Lebanon and Special investigation Committee

Q: Ans

Now that we have made sure the model is working, we can upload it to a server. In this project, I'm using HuggingFace as its free and allows for easy testing. For production we would use Azure/AWS/GCP.

In [None]:

# log in to huggingface - enter
# you may need to run the authentication command directly in your terminal 
!pip install huggingface_hub
!hf auth login

^C


Now that you are logged in to huggingface, you must upload the trained model.

In [None]:
# upload model to hf:
model.push_to_hub("mian21/flan-t5-bsae-CUSTOM-TRAINED")
tokenizer.push_to_hub("mian21/flan-t5-bsae-CUSTOM-TRAINED")

model.safetensors: 100%|██████████| 308M/308M [03:58<00:00, 1.29MB/s]   
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


CommitInfo(commit_url='https://huggingface.co/mian21/flan-t5-bsae-CUSTOM-TRAINED/commit/9269846cf7bba0c68b074278b603102fde5356e8', commit_message='Upload tokenizer', commit_description='', oid='9269846cf7bba0c68b074278b603102fde5356e8', pr_url=None, repo_url=RepoUrl('https://huggingface.co/mian21/flan-t5-bsae-CUSTOM-TRAINED', endpoint='https://huggingface.co', repo_type='model', repo_id='mian21/flan-t5-bsae-CUSTOM-TRAINED'), pr_revision=None, pr_num=None)

The model can be queried directly from the huggingface server using API requests or loaded directly into your code using huggingface's autotrainer.

In [None]:
import requests

API_URL = "https://router.huggingface.co/hf-inference/models/mian21/flan-t5-bsae-CUSTOM-TRAINED"
headers = {"Authorization": "Bearer "}

payload = {
  "inputs": "question: What body/agency grants banking licenses in the United States?",
  "parameters": {"max_new_tokens": 128, "temperature": 0.2}
}

resp = requests.post(API_URL, headers=headers, json=payload, timeout=120)
print(resp.status_code, resp.text)


404 Not Found
