* `Import Libraries`

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import login
import torch
from pydantic import BaseModel, Field
import json
import re
from dotenv import load_dotenv
import os
import numpy as np
from typing import Literal

In [None]:
## Read the Hugging Face token from the environment and authenticate with the Hub.
## override=True lets values from the .env file replace variables that are already set.
_ = load_dotenv(override=True)
hf_token = os.environ.get('HF_KEY_TOKEN')
_ = login(token=hf_token)

* `Pydantic Schema for the Required Response`

In [None]:
class CustomerData(BaseModel):
    """Structured record describing a single bank customer.

    Every field is declared with ``Field(...)`` (no default), so all ten
    values must be supplied when the model is instantiated.
    """
    ## Demographics and credit standing
    CreditScore: float = Field(..., description='Credit score of the customer')
    Geography: str = Field(..., description='Geography')
    Gender: str = Field(..., description='Gender')
    Age: int = Field(..., description='Age of the customer')
    ## Relationship with the bank
    Tenure: int = Field(..., description='Number of years the customer has been with the bank')
    Balance: float = Field(..., description='Account balance')
    NumOfProducts: int = Field(..., description='Number of products the customer has')
    ## Yes/no flags, declared as booleans
    HasCrCard: bool = Field(..., description='Does the customer have a credit card (True for yes, False for no)')
    IsActiveMember: bool = Field(..., description='Is the customer an active member (True for yes, False for no)')
    EstimatedSalary: float = Field(..., description='Estimated salary of the customer')

* `Using the Google Gemma 1.1 2B Instruct Model`

In [None]:
## Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

## Load the tokenizer and the causal LM from the same checkpoint
MODEL_ID = 'google/gemma-1.1-2b-it'
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

## Move the weights to the selected device (the bare expression also displays the model in the notebook)
model.to(device)

In [None]:
## One-shot prompt template: literal braces are doubled, `{text}` is the only placeholder
_PROMPT_TEMPLATE = """
    Extract the following fields from the text and provide them in JSON format: CreditScore, Geography, Gender, Age, Tenure, Balance, NumOfProducts, HasCrCard, IsActiveMember, EstimatedSalary.

    Example:
    Text: "Jane Smith is a 35-year-old female from Canada with a credit score of 650. She has been with the bank for 3 years, has a balance of 2000.0 USD, holds 1 product, owns a credit card, is an active member, and earns an estimated salary of 75000.0 USD."
    JSON: {{
        "CreditScore": 650,
        "Geography": "Canada",
        "Gender": "Female",
        "Age": 35,
        "Tenure": 3,
        "Balance": 2000.0,
        "NumOfProducts": 1,
        "HasCrCard": true,
        "IsActiveMember": true,
        "EstimatedSalary": 75000.0
    }}

    Text: "{text}"
    JSON:
    """


def generate_prompt(text: str) -> str:
    """Build the one-shot extraction prompt for a customer description.

    Args:
        text: Free-text description of a customer. It is inserted verbatim
            between double quotes after the worked example.

    Returns:
        The instruction, one example (text plus its JSON), and the given text
        followed by an open ``JSON:`` cue.
    """
    return _PROMPT_TEMPLATE.format(text=text)

In [None]:
def extract_last_json_from_output(output: str) -> str:
    """Return the last brace-delimited span found in ``output``.

    The pattern is non-greedy and spans newlines, so each match runs from a
    ``{`` to the first following ``}``; a nested object is therefore cut at
    its first closing brace.

    Args:
        output: Text to search.

    Returns:
        The last matching ``{...}`` substring, or ``None`` when there is none.
    """
    final_match = None
    ## Walk every match and keep only the last one
    for final_match in re.finditer(r'\{.*?\}', output, flags=re.DOTALL):
        pass
    return final_match.group(0) if final_match is not None else None


def post_process_customer_data(data: "CustomerData") -> "CustomerData":
    """Normalise and validate an extracted customer record in place.

    Gender is capitalised, Geography is title-cased, and the two yes/no flags
    are converted to ``int``. The normalised values are validated before
    anything is written back, so a record that fails validation is left
    unmodified.

    Args:
        data: Record to normalise. Must expose the attributes ``Gender``,
            ``Geography``, ``HasCrCard``, ``IsActiveMember``,
            ``NumOfProducts`` and ``Tenure``.

    Returns:
        The same ``data`` object, with the normalised values assigned.

    Raises:
        ValueError: If NumOfProducts is not 1-4, Geography is not
            Spain/Germany/France, Gender is not Male/Female, Tenure is not an
            integer value in 0-10, or a flag is not 0/1 after conversion.
    """
    ## Normalise into locals first so that a failed check leaves `data` untouched
    gender = data.Gender.capitalize()
    geography = data.Geography.title()

    ## Convert HasCrCard and IsActiveMember to int (0 or 1)
    has_cr_card = int(data.HasCrCard)
    is_active_member = int(data.IsActiveMember)

    ## Some Validation
    if data.NumOfProducts not in (1, 2, 3, 4):
        raise ValueError(f"NumOfProducts must be 1 or 2 or 3 or 4, got {data.NumOfProducts}")

    if geography not in ('Spain', 'Germany', 'France'):
        raise ValueError(f"Geography must be Spain or Germany or France, got {geography}")

    if gender not in ('Male', 'Female'):
        raise ValueError(f"Gender must be Male or Female, got {gender}")

    if data.Tenure not in range(11):
        ## Report the offending Tenure value (not the Gender)
        raise ValueError(f"Tenure must be in [0-10] range, got {data.Tenure}")

    if has_cr_card not in (0, 1):
        raise ValueError(f"HasCrCard must be 0 or 1 range, got {has_cr_card}")

    if is_active_member not in (0, 1):
        raise ValueError(f"IsActiveMember must be 0 or 1 range, got {is_active_member}")

    ## Everything passed: write the normalised values back
    data.Gender = gender
    data.Geography = geography
    data.HasCrCard = has_cr_card
    data.IsActiveMember = is_active_member

    return data


def extract_features(text: str):
    """Extract a validated customer record from a free-text description.

    Builds the prompt, lets the model generate a continuation, pulls the last
    ``{...}`` span out of the generated text, parses it as JSON, instantiates
    ``CustomerData`` from it and post-processes the result.

    Args:
        text: Free-text description of a customer.

    Returns:
        The post-processed ``CustomerData`` instance, or ``None`` when no JSON
        is found or when parsing/validation fails (a message is printed in
        both cases).
    """
    ## Conversation with model
    prompt = generate_prompt(text)
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
    outputs = model.generate(**inputs, max_length=512)

    ## Decode only the newly generated tokens. For a causal LM the generated sequence starts
    ## with the prompt, and the prompt contains an example JSON that must never be mistaken
    ## for the model's answer when the model itself produces no JSON.
    prompt_length = inputs["input_ids"].shape[-1]
    result_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    ## Call (extract_last_json_from_output) function --> extracting the last JSON from output
    json_text = extract_last_json_from_output(result_text)
    if not json_text:
        print('JSON format not found in the output')
        return None

    try:
        result_json = json.loads(json_text)

        ## Apply Pydantic Class
        customer_data = CustomerData(**result_json)

        ## Call (post_process_customer_data) for normalising and validating
        return post_process_customer_data(customer_data)

    except (json.JSONDecodeError, TypeError, ValueError) as e:
        print(f'Failed to parse the structured data: {str(e)}')
        return None

In [None]:
## Sample of new data: a free-text description of one customer
sample_text = """
Mohammed Agoor is a 27-year-old male from the Spain with a credit score of 700. He has been with the bank for 5 years, has a balance of 5000.0 USD, holds 2 products, owns a credit card, is an active member, and earns an estimated salary of 100000.0 USD.
"""

## Call the (extract_features) function on the sample;
## the bare name on the last line displays the result in the notebook
structured_data = extract_features(text=sample_text)
structured_data

---