# World Bank Financial Survey Q&A Model Project 

This project develops an NLP-powered question-answering system trained on World Bank survey data containing financial information gathered from banking authorities in countries across the globe. This notebook walks the user through gathering and processing the data, then training and deploying the final model.

### Dataset Description
The World Bank survey dataset consists of structured financial questions sent to financial institutions worldwide. It includes multi-dimensional survey responses, hierarchical question structures, and financial metrics. For this project, we use only the questions with long-form textual answers to train the NLP model, rather than the binary (yes/no) response questions.
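As a rough illustration of the long-form versus binary distinction, the sketch below separates free-text answers from short ones. It is a hypothetical helper and is not part of the notebook. `is_long_form`, the short-answer vocabulary, and the three-word threshold are all assumptions. The threshold was chosen to mirror the word-count filter applied before training.

```python
# Hypothetical helper: the short-answer vocabulary and the 3-word threshold are
# assumptions, chosen to mirror the ">= 3 words" filter applied before training.
SHORT_ANSWERS = {"yes", "no", "n/a", "na", "none"}


def is_long_form(answer: str, min_words: int = 3) -> bool:
    """Return True if an answer looks like free text rather than a short/binary response."""
    text = str(answer).strip()
    # Compare case-insensitively and ignore a trailing period ("No." -> "no")
    if text.lower().rstrip(".") in SHORT_ANSWERS:
        return False
    # Answers with fewer than min_words words are treated as short responses
    return len(text.split()) >= min_words


answers = ["Yes", "No.", "Central bank", "Yes, financial ombudsman"]
print([a for a in answers if is_long_form(a)])  # ['Yes, financial ombudsman']
```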

### Project Architecture

##### Phase 1 - Data Processing 

- Transform raw survey data into structured NLP training pairs
    - Parse all relevant sheets from the Excel file
    - Handle hierarchical question structures so that each question-answer pair is standalone
- Identify and flag PII in the dataset using an ML model
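The hierarchical handling in Phase 1 works by remembering the most recent parent question (a row with text but no question index). That parent is then prefixed to each multi-part sub-question that follows. The pure-Python sketch below illustrates the idea on made-up rows. It relies on a few simplifying assumptions: the `(index, text)` tuple layout, the `flatten_questions` name, and the rule that an alphabetic index suffix marks a multi-part question. The notebook itself parses the index with a regular expression instead.

```python
from typing import Iterable, List, Optional, Tuple

# (question index or None for a parent row, question text)
Row = Tuple[Optional[str], str]


def flatten_questions(rows: Iterable[Row]) -> List[str]:
    """Turn parent/sub-question rows into standalone question strings.

    Simplifying assumptions for this sketch: a row whose index is None is a
    parent question, and a sub-question counts as multi-part when the last
    "_"-separated segment of its index is alphabetic (e.g. "Q3_20_a").
    """
    parent: Optional[str] = None
    questions: List[str] = []
    for index, text in rows:
        if index is None:
            # Parent row: remember its text for the sub-questions that follow
            parent = text
            continue
        is_multi = index.split("_")[-1].isalpha()
        if is_multi and parent:
            # Multi-part sub-question: prefix the parent so it stands alone
            questions.append(f"{parent} {text}")
        else:
            # Standalone question: any earlier parent no longer applies
            parent = None
            questions.append(text)
    return questions


# Made-up rows; the index format is hypothetical
rows = [
    ("Q1_1", "1.1 What body/agency grants banking licenses?"),
    (None, "Are the following items deducted from capital?"),
    ("Q3_20_a", "Goodwill"),
    ("Q3_20_b", "Deferred tax assets"),
]
for question in flatten_questions(rows):
    print(question)
```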

##### Phase 2 - Model Development & Fine-Tuning

- Fine-tune a Google FLAN-T5-Base NLP model 
- Optimize the model's performance on this specific World Bank survey domain
- Evaluate model performance using validation and test sample sets
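Validation loss says little about whether an answer is actually right. A text-overlap metric is therefore a useful complement when evaluating on the test split. The sketch below computes exact match and token-level F1 in the style of the SQuAD metric, without its article stripping. This is a swapped-in technique, not something the notebook implements, and the lowercase-and-strip-punctuation normalization is an assumption.

```python
import string
from collections import Counter


def normalize(text: str) -> list[str]:
    """Lowercase, remove ASCII punctuation, and split on whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    return text.split()


def exact_match(prediction: str, reference: str) -> bool:
    """True when the normalized token sequences are identical."""
    return normalize(prediction) == normalize(reference)


def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token precision and recall (bag-of-words overlap)."""
    pred_tokens = normalize(prediction)
    ref_tokens = normalize(reference)
    if not pred_tokens or not ref_tokens:
        # Two empty answers match; an empty vs. non-empty answer scores 0
        return float(pred_tokens == ref_tokens)
    # Multiset intersection counts each shared token at most min(count) times
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("Yes, financial ombudsman", "yes financial ombudsman"))  # True
print(round(token_f1("commercial banks", "commercial banks, state-owned commercial banks"), 3))  # 0.571
```

Averaging `token_f1` over the test split gives one number to track per epoch. Each token is counted at most as often as it occurs in the reference. As a result, a prediction that repeats a correct phrase many times loses precision, which partly penalizes degenerate repetition.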

##### Phase 3 - Deployment

- Deploy the fine-tuned model to a production environment on Azure or Hugging Face

##### Future Steps (if time allows):
- Implement an API for interacting with the model
- Add a classifier to categorize questions/answers by topic (financial questions, administrative questions, etc.)
- Gather feedback on model performance (answer quality, hallucinations, knowledge gaps)
- Add additional survey questions to the knowledge base

In [None]:
# download the survey data from the World Bank data catalog
import requests

url = "https://datacatalogfiles.worldbank.org/ddh-published/0038632/2/DR0047737/2021_04_26_brss-public-release.xlsx"
response = requests.get(url, timeout=60)
# fail loudly instead of saving an HTML error page as the .xlsx file
response.raise_for_status()

with open("worldbank_data.xlsx", "wb") as f:
    f.write(response.content)

Now that the data is downloaded, it needs to be converted from an XLSX file in row/column format into question-answer pairs suitable for T5 training.
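Conceptually, each sheet is a wide table with one row per question and one column per country. Every non-empty cell becomes one training pair. The minimal sketch below shows that reshaping in pure Python on a made-up two-country table; `to_pairs`, the rows, and the answers are hypothetical. The notebook does the same with pandas, taking country names from the header row of the first processed sheet and skipping empty cells.

```python
from typing import Dict, List, Optional, Sequence


def to_pairs(countries: Sequence[str],
             rows: Sequence[Sequence[Optional[str]]]) -> List[Dict[str, str]]:
    """Reshape a wide question-by-country table into input/target pairs.

    Each row is [question, answer for countries[0], answer for countries[1], ...];
    None marks a country that did not answer, and that cell is skipped.
    """
    pairs = []
    for question, *answers in rows:
        for country, answer in zip(countries, answers):
            if answer is None:
                continue  # no response from this country
            pairs.append({
                "input": f"Answer this question about {country}: {question}",
                "target": str(answer).strip(),
            })
    return pairs


# Made-up table: two countries, two questions, one unanswered cell
demo_countries = ["Albania", "Angola"]
demo_rows = [
    ["1.1 What body/agency grants banking licenses?", "Central bank", None],
    ["Example question 2?", "Answer A", "Answer B"],
]
for pair in to_pairs(demo_countries, demo_rows):
    print(pair)
```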

In [None]:
## read and process data
import pandas as pd
import re

# Remove unnecessary instructions from a question,
# e.g. trailing "Select all that apply" or "Please specify..." text
def simplify_question(qText):
    if pd.isna(qText):
        return ""
    
    text = str(qText).strip()
    
    # split on common instruction starters and keep only the first part
    for splitter in [" Please ", " If ", " Include ", " Specify ", " Describe ", " List "]:
        if splitter in text:
            text = text.split(splitter)[0]
            break
    
    # if there's a question mark, keep everything up to the first one
    if "?" in text:
        text = text.split("?")[0] + "?"
    
    return text.strip()

# load the workbook once; individual sheets are parsed from it below
allSheets = pd.ExcelFile("worldbank_data.xlsx")

# collected training samples
samples = []

# process all sheets except the first 2 and the last 1
process = allSheets.sheet_names[2:-1]

# read the first processed sheet and extract country names from its header row
# (assumes every processed sheet uses the same country column order);
# keep each country's column position so blank header cells can't misalign answers
dfFirst = pd.read_excel(allSheets, sheet_name=process[0], header=None)
header = dfFirst.iloc[0]
countryCols = [(col, str(header[col])) for col in range(2, len(header)) if not pd.isna(header[col])]
countries = [name for _, name in countryCols]

for sheet in process:
    df = pd.read_excel(allSheets, sheet_name=sheet, header=None)
    
    # most recent parent question and current base question number
    parent = None
    currBase = None
    
    # iterate through every row except the header, reading question index and text
    for idx, row in df.iloc[1:].iterrows():
        qIndex = row[0]
        qText = row[1]
        
        # a row with text but no question index is a parent question:
        # store it, clear the previous base, and move on to the next row
        if pd.isna(qIndex) and not pd.isna(qText):
            parent = simplify_question(qText)
            currBase = None
            continue
        
        # regex: index starts with Q and its groups are delimited by _
        # group 1 is the main question number
        # group 2 is the sub-question number
        # group 3 is the text suffix of multi-part questions
        # the non-capturing group matches trailing index parts that are not needed
        match = re.match(r'Q(\d+)_([0-9_]+?)([a-zA-Z_]+)?(?:_[A-Z]|_\d{4}|$)', str(qIndex))
        
        # skip rows whose index does not match the expected pattern
        if not match:
            continue
        baseNum = f"{match.group(1)}_{match.group(2)}"
        isMulti = bool(match.group(3)) or bool(re.search(r'_\d{4}', str(qIndex)))
        
        # when a new base question starts, update the base and
        # reset the parent if the new question is not multi-part
        if baseNum != currBase:
            if not isMulti:
                parent = None
            currBase = baseNum
        
        # simplify once per row rather than once per country
        simplifiedQ = simplify_question(qText)
        # multi-part questions are prefixed with their parent question
        completeQ = f"{parent} {simplifiedQ}" if isMulti and parent else simplifiedQ
        
        # create one sample per country that answered this question
        for col, country in countryCols:
            answer = row[col]
            if pd.isna(answer):
                continue
            samples.append({
                "input": f"Answer this question about {country}: {completeQ}".strip(),
                "target": str(answer).strip()
            })

# preview the first few generated inputs
print("Sample questions after simplification:")
for i, s in enumerate(samples[:5]):
    print(f"{i}: {s['input'][:100]}...")

Sample questions after simplification:
0: Answer this question about Albania: 1.1 What body/agency grants banking licenses?...
1: Answer this question about Angola: 1.1 What body/agency grants banking licenses?...
2: Answer this question about Antigua and Barbuda: 1.1 What body/agency grants banking licenses?...
3: Answer this question about Argentina: 1.1 What body/agency grants banking licenses?...
4: Answer this question about Armenia: 1.1 What body/agency grants banking licenses?...


Now that the data is in the proper training format, it needs to be checked for personally identifiable information (PII). We use Microsoft's Presidio library, which combines a pre-trained spaCy NER model with rule-based recognizers, to detect PII (https://github.com/microsoft/presidio).
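The filtering applied to Presidio's output in the next cells uses the fields each analyzer result exposes: `entity_type`, `start`, `end`, and `score`. A detection is kept only if its score clears the confidence threshold. Its matched text must also not overlap an allow-listed term such as a country name or survey year. The sketch below isolates that logic so it runs without Presidio installed. The `Result` named tuple is a stand-in for Presidio's own result class, and the example text, offsets, and scores are made up.

```python
from typing import Iterable, List, NamedTuple, Set


class Result(NamedTuple):
    """Stand-in for a Presidio analyzer result, using the same field names."""
    entity_type: str
    start: int
    end: int
    score: float


def filter_results(text: str, results: Iterable[Result], exclude: Set[str],
                   threshold: float = 0.7) -> List[Result]:
    """Keep confident detections whose matched text does not overlap an excluded word."""
    kept = []
    for r in results:
        span = text[r.start:r.end]
        if r.score < threshold:
            continue  # below the confidence cut-off
        # Two-way substring test: drop the hit if it is inside an excluded word
        # or an excluded word is inside it (e.g. "Kenya" within "Nairobi, Kenya")
        if any(span in word or word in span for word in exclude):
            continue
        kept.append(r)
    return kept


demo_text = "Answer this question about Kenya: contact John Smith at the regulator."
demo_results = [
    Result("LOCATION", 27, 32, 0.85),  # "Kenya"
    Result("PERSON", 42, 52, 0.85),    # "John Smith"
    Result("PERSON", 42, 46, 0.40),    # "John", low confidence
]
for r in filter_results(demo_text, demo_results, exclude={"Kenya", "2015"}):
    print(r.entity_type, demo_text[r.start:r.end])  # PERSON John Smith
```

The two-way test has a side effect. Very short detections (a single letter, say) are suppressed whenever they occur inside any country name, so the filter errs toward missing PII rather than over-flagging.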

In [None]:
# install dependencies
# !pip install presidio_analyzer presidio_anonymizer
# !python -m spacy download en_core_web_lg

In [None]:
from presidio_analyzer import AnalyzerEngine
from tqdm import tqdm
import json

# initialize the analyzer (its default NLP engine uses the spaCy model downloaded above)
analyzer = AnalyzerEngine()

# country names and survey years are essential to the survey data,
# so they should never be flagged as PII
excludeWords = set(countries)
excludeWords.update(['2011', '2012', '2013', '2014', '2015', '2016'])

# only keep detections the analyzer scores at 70%+ confidence
CONFIDENCE = 0.7

def find_pii(text):
    """Return confident, non-excluded detections as dicts of type, matched text, and score."""
    found = []
    for r in analyzer.analyze(text=text, language='en'):
        span = text[r.start:r.end]
        # drop low-confidence hits and any hit overlapping an excluded word
        if r.score < CONFIDENCE or any(span in word or word in span for word in excludeWords):
            continue
        found.append({"type": r.entity_type, "text": span, "score": r.score})
    return found

# PII values already reported; each unique value is flagged only once
seenPII = set()

# flagged samples for manual review
potentialPII = []

for sample in tqdm(samples, desc='finding pii'):
    inputPII = find_pii(sample["input"])
    targetPII = find_pii(sample["target"])

    # record the sample only if it contains at least one PII value not seen before
    newValues = {p["text"] for p in inputPII + targetPII} - seenPII
    if newValues:
        seenPII.update(newValues)
        potentialPII.append({
            "input": sample["input"],
            "target": sample["target"],
            "inputPII": inputPII,
            "targetPII": targetPII
        })

# write all flagged samples to a JSON file for manual review
with open('potentialPII.json', 'w', encoding='utf-8') as f:
    json.dump(potentialPII, f, indent=2, ensure_ascii=False)


finding pii: 100%|██████████| 107833/107833 [35:25<00:00, 50.73it/s] 


The code writes all potential PII matches to a separate JSON file in the current directory (potentialPII.json). This file can then be reviewed manually to determine which flagged values are false positives and which are actual PII. Once all PII has been removed from the dataset, T5 model training can begin.
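The notebook does not show the removal step itself. One possible approach assumes the review produces a list of confirmed PII strings. Every occurrence of those strings is then replaced with a placeholder across all samples. `redact_samples`, the `[REDACTED]` token, and the example values below are hypothetical. The scan has to cover every sample, not just the flagged ones, because `potentialPII.json` records only the first sample in which each PII value appears.

```python
import re
from typing import Dict, Iterable, List


def redact_samples(records: List[Dict[str, str]], confirmed_pii: Iterable[str],
                   placeholder: str = "[REDACTED]") -> List[Dict[str, str]]:
    """Return new records with every confirmed PII value replaced by a placeholder."""
    # Longest values first, so "John Smith" wins over a shorter "John" at the same position
    values = sorted({v for v in confirmed_pii if v}, key=len, reverse=True)
    if not values:
        return [dict(r) for r in records]
    pattern = re.compile("|".join(re.escape(v) for v in values))
    # Apply the replacement to every field (here "input" and "target")
    return [{key: pattern.sub(placeholder, text) for key, text in r.items()}
            for r in records]


demo_samples = [{
    "input": "Answer this question about Kenya: Who is the contact person?",
    "target": "John Smith, john.smith@example.org",
}]
confirmed = ["John Smith", "john.smith@example.org"]  # hypothetical review outcome
print(redact_samples(demo_samples, confirmed)[0]["target"])  # [REDACTED], [REDACTED]
```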

In [None]:
# install dependencies
# !pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128
# !pip install transformers datasets accelerate


In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
)
from datasets import Dataset

device = "cuda"

# keep only long-form answers (3+ words); short/binary responses are dropped
samplesFiltered = [s for s in samples if len(s["target"].split()) >= 3]

# convert the samples to a Hugging Face dataset
data = Dataset.from_list(samplesFiltered)

# load the pre-trained FLAN-T5-Base model and its tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base").to(device)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

# tokenize questions as model inputs and answers as labels;
# padding is deferred to the data collator so each batch is padded only as needed
def preprocess(batch):
    modelInputs = tokenizer(
        batch["input"],
        max_length=512,
        truncation=True,
        padding=False
    )
    targets = tokenizer(
        batch["target"],
        max_length=128,
        truncation=True,
        padding=False
    )
    modelInputs["labels"] = targets["input_ids"]
    return modelInputs

# 80/10/10 train/validation/test split (seeded so the split is reproducible)
trainValSplit = data.train_test_split(test_size=0.2, seed=42)
valTestSplit = trainValSplit["test"].train_test_split(test_size=0.5, seed=42)

splits = {
    "train": trainValSplit["train"],
    "validation": valTestSplit["train"],
    "test": valTestSplit["test"]
}

# tokenize every split and drop the raw text columns
finalData = {
    name: split.map(preprocess, batched=True, remove_columns=["input", "target"])
    for name, split in splits.items()
}

# pads inputs and labels per batch; label padding uses -100 so the loss ignores it
dataCollator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6, weight_decay=0.01)

# create dataloaders
train_dataloader = DataLoader(
    finalData["train"],
    batch_size=2,
    shuffle=True,
    collate_fn=dataCollator
)

val_dataloader = DataLoader(
    finalData["validation"],
    batch_size=2,
    collate_fn=dataCollator
)

num_epochs = 7
saveDir = "./flan-t5-bsae-CUSTOM-TRAINED"
best_val_loss = float("inf")

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for step, batch in enumerate(progress_bar):
        # move the batch tensors to the GPU
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        
        # forward pass: when labels are given, T5 builds the decoder inputs by
        # shifting the labels right and returns the cross-entropy loss
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        
        # backward pass: clear old gradients, backpropagate,
        # clip the gradient norm to 1.0 for stability, then update the weights
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
        
        if step % 100 == 0:
            print(f"\nStep {step}, Loss: {loss.item():.4f}")
    
    avg_train_loss = total_loss / len(train_dataloader)
    print(f"\nEpoch {epoch+1} - Avg Train Loss: {avg_train_loss:.4f}")
    
    # evaluate on the validation split without tracking gradients
    model.eval()
    total_val_loss = 0
    
    with torch.no_grad():
        for batch in val_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            total_val_loss += outputs.loss.item()
    
    avg_val_loss = total_val_loss / len(val_dataloader)
    print(f"Epoch {epoch+1} - Avg Val Loss: {avg_val_loss:.4f}")
    
    # save whenever validation loss improves, so an interrupted run still keeps its best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        model.save_pretrained(saveDir)
        tokenizer.save_pretrained(saveDir)


Step 0, Loss: 3.2753
Step 100, Loss: 3.9937
Step 200, Loss: 3.8435
Step 300, Loss: 3.7290
Step 400, Loss: 5.0440
Step 500, Loss: 3.7273
Step 600, Loss: 3.1444
Step 700, Loss: 2.8311
Step 800, Loss: 2.3301
Step 900, Loss: 2.9432
Step 1000, Loss: 0.5565
Step 1100, Loss: 3.8108
Step 1200, Loss: 3.0865
Step 1300, Loss: 0.7979
Step 1400, Loss: 3.1600
Step 1500, Loss: 2.7091
Step 1600, Loss: 3.4244
Step 1700, Loss: 3.4210
Step 1800, Loss: 3.9049
Step 1900, Loss: 2.5346
Step 2000, Loss: 1.5541
Step 2100, Loss: 3.2844
Epoch 1/7: 100%|██████████| 2198/2198 [33:20<00:00,  1.10it/s, loss=2.4665]
Epoch 1 - Avg Train Loss: 3.1506
Epoch 1 - Avg Val Loss: 2.5105

Step 0, Loss: 3.9791
Step 100, Loss: 3.0512
Step 200, Loss: 2.7804
Step 300, Loss: 3.2976
Step 400, Loss: 2.5305
Step 500, Loss: 3.0669
Step 600, Loss: 2.6273
Step 700, Loss: 3.3413
Step 800, Loss: 3.1580
Step 900, Loss: 2.2706
Step 1000, Loss: 2.0621
Step 1100, Loss: 0.3067
Step 1200, Loss: 2.7249
Step 1300, Loss: 3.9300
Step 1400, Loss: 2.5581
Step 1500, Loss: 3.6079
Step 1600, Loss: 2.4546
Step 1700, Loss: 1.7568
Step 1800, Loss: 2.5789
Step 1900, Loss: 2.0056
Step 2000, Loss: 3.1382
Step 2100, Loss: 2.6823
Epoch 2/7: 100%|██████████| 2198/2198 [33:15<00:00,  1.10it/s, loss=1.9555]
Epoch 2 - Avg Train Loss: 2.5054
Epoch 2 - Avg Val Loss: 2.2890

Step 0, Loss: 3.1608
Step 100, Loss: 1.0666
Step 200, Loss: 0.8652
Step 300, Loss: 2.5668
Step 400, Loss: 2.3368
Step 500, Loss: 2.6552
Step 600, Loss: 0.8995
Step 700, Loss: 3.4529
Step 800, Loss: 1.1801
Step 900, Loss: 2.9472
Step 1000, Loss: 2.4347
Step 1100, Loss: 2.9001
Step 1200, Loss: 1.0018
Step 1300, Loss: 4.5584
Step 1400, Loss: 2.4009
Step 1500, Loss: 1.6814
Step 1600, Loss: 0.3037
Step 1700, Loss: 0.9068
Step 1800, Loss: 1.5190
Step 1900, Loss: 3.5257
Step 2000, Loss: 3.6049
Step 2100, Loss: 3.0537
Epoch 3/7: 100%|██████████| 2198/2198 [34:11<00:00,  1.07it/s, loss=2.2333]
Epoch 3 - Avg Train Loss: 2.2278
Epoch 3 - Avg Val Loss: 2.1695

Step 0, Loss: 4.1606
Step 100, Loss: 2.5856
Step 200, Loss: 2.3126
Step 300, Loss: 2.0419
Step 400, Loss: 1.2103
Step 500, Loss: 0.5753
Step 600, Loss: 1.1748
Step 700, Loss: 0.9562
Step 800, Loss: 1.2357
Step 900, Loss: 2.0108
Step 1000, Loss: 2.5192
Step 1100, Loss: 2.6912
Step 1200, Loss: 0.0038
Step 1300, Loss: 1.6484
Step 1400, Loss: 2.7655
Step 1500, Loss: 2.2885
Step 1600, Loss: 1.4723
Step 1700, Loss: 2.5627
Step 1800, Loss: 2.6765
Step 1900, Loss: 2.2328
Step 2000, Loss: 3.0487
Step 2100, Loss: 2.2608
Epoch 4/7: 100%|██████████| 2198/2198 [36:58<00:00,  1.01s/it, loss=0.0120]
Epoch 4 - Avg Train Loss: 2.0007
Epoch 4 - Avg Val Loss: 2.0973

Step 0, Loss: 1.6481
Step 100, Loss: 2.0960
Step 200, Loss: 0.5362
Epoch 5/7:  11%|█         | 239/2198 [04:39<55:34,  1.70s/it, loss=0.9368]

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import random

# load the fine-tuned model and tokenizer saved during training
model = AutoModelForSeq2SeqLM.from_pretrained("./flan-t5-bsae-CUSTOM-TRAINED").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("./flan-t5-bsae-CUSTOM-TRAINED")

# sample from the held-out test split; sampling from the full dataset
# would mostly return examples the model was trained on
testSet = splits["test"]
test_indices = random.sample(range(len(testSet)), 20)

for idx in test_indices:
    sample = testSet[idx]
    question = sample["input"]
    true_answer = sample["target"]
    
    inputs = tokenizer(question, return_tensors="pt", max_length=512, truncation=True).to("cuda")
    outputs = model.generate(**inputs, max_length=128, num_beams=4)
    predicted = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    print(f"\nQ: {question[:70]}...")
    print(f"True: {true_answer[:60]}...")
    print(f"Pred: {predicted}")

# test on custom questions
print("custom questions:")

custom_questions = [
    "Answer this question about United States: What body/agency grants banking licenses?",
    "Answer this question about France: What is the minimum capital requirement?",
    "Answer this question about Japan: Who regulates banks?"
]

for q in custom_questions:
    inputs = tokenizer(q, return_tensors="pt", max_length=512, truncation=True).to("cuda")
    outputs = model.generate(**inputs, max_length=128, num_beams=4)
    pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nQ: {q}")
    print(f"A: {pred}")


Q: Answer this question about Kenya: 14.10 Does any law or regulation set...
True: Yes, financial ombudsman...
Pred: Yes, financial ombudsman

Q: Answer this question about Bulgaria: 3.20.3 Are the following items de...
True: 100% of T1. If the bank has AT1 - 60% of CET1 and 40% of AT1...
Pred: Deducted from CET 1

Q: Answer this question about Togo: 3.20.4 Are the following items deduct...
True: Pris en compte dans le T1...
Pred: La juste valeur n’est pas transposé dans le corpus réglementaire

Q: Answer this question about Taiwan, China: 3.1 Which regulatory capital...
True: commercial banks, state-owned commercial banks...
Pred: Commercial banks, state-owned commercial banks, state-owned commercial banks, state-owned commercial banks, state-owned commercial banks, state-owned commercial banks, state-owned commercial banks, state-owned commercial banks, state-owned commercial banks, state-owned commercial banks, state-owned commercial banks, state-owned commercial banks, state-owne