In [29]:
import os
import dotenv
from groq import Groq
import pandas as pd
from pandasql import sqldf as pysql
import numpy as np

# Read variables from a local .env file (if any) into the process environment,
# then stop immediately when the Groq credential is absent or blank.
dotenv.load_dotenv()
api_key_llama = os.environ.get("GROQ_API_KEY")
if api_key_llama is None or api_key_llama == "":
    raise ValueError("GROQ_API_KEY environment variable not set")

def _extract_sql(response_text):
    """Strip Markdown code-fence markup from a model reply and return the bare SQL text."""
    text = response_text.replace("```", "").strip()
    # A fenced reply frequently opens with a language tag ("```sql ..."). Once the
    # backticks are removed that tag would be sent to the SQL engine as part of
    # the statement, so drop it when it is the first token.
    parts = text.split(None, 1)
    if parts and parts[0].lower() in ("sql", "sqlite"):
        text = parts[1].strip() if len(parts) == 2 else ""
    return text


def get_llama_assistance(prompt: str, formatted_metadata, table_name,
                         model="llama3-70b-8192", temperature=0.0):
    """Ask a Groq-hosted Llama model for a SQL query and run it with pandasql.

    Args:
        prompt: Natural-language question to be answered with SQL.
        formatted_metadata: Text describing the table's columns; it is pasted
            verbatim into the instructions sent to the model.
        table_name: Name the generated query should use in its FROM clause.
            pandasql looks this name up among the variables visible from this
            function, so a DataFrame with that name must exist in the notebook.
        model: Groq model identifier. Defaults to the model originally used.
        temperature: Sampling temperature. Defaults to 0.0 so the same question
            yields the same query; query generation is not a creative task.

    Returns:
        Whatever ``pysql`` returns for the generated query.

    Raises:
        ValueError: If the model reply contains no SQL after clean-up.
    """
    # "\\n" keeps a literal backslash-n in the prompt; a bare "\n" in a normal
    # f-string is an escape sequence and would become an actual line break.
    main_purpose = f"""
    As an SQL Query Expert, your primary role is to understand the given data, answer the questions based on the provided input and generate accurate SQL queries ONLY. 
    Remember, you only have to answer the Query for the given input in a single line, don't give any explanation, just the query. Don't use [\\n] please. 
    Here are the column names with respect to their information: 
    {formatted_metadata}
    The table name is {table_name}
    Here is/are the Questions:"""

    client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
    completion = client.chat.completions.create(
        model=model,
        # A single user turn is enough; no empty assistant turn is appended.
        messages=[
            {
                "role": "user",
                "content": f"{main_purpose} {prompt}"
            }
        ],
        temperature=temperature,
        # The answer is one SQL statement; asking for the model's whole 8192-token
        # window leaves no room for the prompt itself.
        max_tokens=1024,
        top_p=1,
        stream=True,
        stop=None,
    )

    # Streamed chunks may carry no text (content is None); treat those as empty.
    response_text = "".join(
        chunk.choices[0].delta.content or "" for chunk in completion
    )

    cleaned_response = _extract_sql(response_text)
    if not cleaned_response:
        raise ValueError("Model returned an empty response; no SQL query to run")

    # NOTE(review): the model-generated SQL is executed as-is, without validation.
    # Review the query before pointing this at anything other than local DataFrames.
    result = pysql(cleaned_response)

    return result

In [4]:
# Load the customer table from the CSV in the working directory and preview
# its first five rows.
data = pd.read_csv(filepath_or_buffer="test_data.csv")
data.head(n=5)

Unnamed: 0,ref_no,children,age_band,status,occupation,occupation_partner,home_status,self_employed,self_employed_partner,year_last_moved,...,investment_tax_saving_bond,home_loan,online_purchase_amount,discount_offering,gender,region,investment_in_commudity,investment_in_equity,investment_in_derivative,portfolio_balance
0,1,Zero,51-55,Partner,Manual Worker,Secretarial/Admin,Own Home,No,No,1972,...,19.99,0.0,0.0,1,Female,Wales,74.67,18.66,32.32,89.43
1,2,Zero,55-60,Single/Never Married,Retired,Retired,Own Home,No,No,1998,...,0.0,0.0,0.0,2,Female,North West,20.19,0.0,4.33,22.78
2,3,Zero,26-30,Single/Never Married,Professional,Other,Own Home,Yes,No,1996,...,0.0,3.49,0.0,2,Male,North,98.06,31.07,80.96,171.78
3,5,Zero,18-21,Single/Never Married,Professional,Manual Worker,Own Home,No,No,1997,...,0.0,0.0,0.0,2,Female,West Midlands,4.1,14.15,17.57,-41.7
4,6,Zero,45-50,Partner,Business Manager,Unknown,Own Home,No,No,1995,...,0.0,45.91,25.98,2,Female,Scotland,70.16,55.86,80.44,235.02


In [18]:
# Summarise the frame: row count, column names, non-null counts and dtypes.
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10155 entries, 0 to 10154
Data columns (total 31 columns):
 #   Column                           Non-Null Count  Dtype  
---  ------                           --------------  -----  
 0   ref_no                           10155 non-null  int64  
 1   children                         10155 non-null  object 
 2   age_band                         10155 non-null  object 
 3   status                           10155 non-null  object 
 4   occupation                       10155 non-null  object 
 5   occupation_partner               10155 non-null  object 
 6   home_status                      10155 non-null  object 
 7   self_employed                    10155 non-null  object 
 8   self_employed_partner            10155 non-null  object 
 9   year_last_moved                  10155 non-null  int64  
 10  tvarea                           10155 non-null  object 
 11  post_code                        10155 non-null  object 
 12  post_area         

In [19]:
# Hand-written reference query for the question later put to the model.
# NOTE(review): the result is aliased as avg_monthly_income, but the column being
# averaged is average_a_c_balance (an account balance) and the column listing
# shows no income column -- confirm this proxy is the intended one.
pysql('''SELECT AVG(average_a_c_balance) AS avg_monthly_income
FROM data
WHERE self_employed = 'No';
''')

Unnamed: 0,avg_monthly_income
0,31.988412


In [13]:
# List every column name available to the SQL queries.
data.columns

Index(['ref_no', 'children', 'age_band', 'status', 'occupation',
       'occupation_partner', 'home_status', 'self_employed',
       'self_employed_partner', 'year_last_moved', 'tvarea', 'post_code',
       'post_area', 'average_credit_card_transaction', 'balance_transfer',
       'term_deposit', 'life_insurance', 'medical_insurance',
       'average_a_c_balance', 'personal_loan', 'investment_in_mutual_fund',
       'investment_tax_saving_bond', 'home_loan', 'online_purchase_amount',
       'discount_offering', 'gender', 'region', 'investment_in_commudity',
       'investment_in_equity', 'investment_in_derivative',
       'portfolio_balance'],
      dtype='object')

In [25]:
# Work on a copy so the metadata step below cannot modify `data`, the frame the
# SQL queries run against.
df = data.copy()

In [26]:
# Build the column description that is pasted into the LLM prompt: one line per
# column giving its dtype and, for text/categorical columns, up to ten of its
# most frequent values (other columns show 'N/A').
columns_info = dict(df.dtypes.items())
categorical_columns = df.select_dtypes(include=['object', 'category']).columns
unique_values = {
    col: df[col].value_counts().index[:10].tolist()
    for col in categorical_columns
}
formatted_metadata = "\n".join(
    f"{col}: {dtype} (unique values: {unique_values.get(col, 'N/A')})"
    for col, dtype in columns_info.items()
)

In [27]:
# Show the exact column description text that will be sent to the model.
print(formatted_metadata)

ref_no: int64 (unique values: N/A)
children: object (unique values: ['Zero', '1', '2', '3', '4+'])
age_band: object (unique values: ['45-50', '36-40', '41-45', '31-35', '51-55', '55-60', '26-30', '61-65', '65-70', '22-25'])
status: object (unique values: ['Partner', 'Single/Never Married', 'Divorced/Separated', 'Widowed', 'Unknown'])
occupation: object (unique values: ['Professional', 'Retired', 'Secretarial/Admin', 'Housewife', 'Business Manager', 'Unknown', 'Manual Worker', 'Other', 'Student'])
occupation_partner: object (unique values: ['Unknown', 'Professional', 'Retired', 'Manual Worker', 'Business Manager', 'Secretarial/Admin', 'Housewife', 'Other', 'Student'])
home_status: object (unique values: ['Own Home', 'Rent from Council/HA', 'Rent Privately', 'Live in Parental Hom', 'Unclassified'])
self_employed: object (unique values: ['No', 'Yes'])
self_employed_partner: object (unique values: ['No', 'Yes'])
year_last_moved: int64 (unique values: N/A)
tvarea: object (unique values: ['C

In [30]:
# Ask the model the same question as the hand-written query and run its SQL
# against the `data` table.
# NOTE(review): the recorded output averaged online_purchase_amount, not the
# average_a_c_balance column used in the hand-written query -- the two answers
# are not comparable; confirm which column should stand in for "income".
get_llama_assistance("Find the average monthly income of customers who are not self-employed", formatted_metadata, "data")

Unnamed: 0,AVG(online_purchase_amount)
0,19.200767
