## Langchain v4

### Requirements
pip install langchain-google-genai langchain-community sqlalchemy psycopg2 pandas openpyxl plotly

In [2]:
import os
import re
import getpass
import chromadb
from chromadb.utils import embedding_functions
import pandas as pd
from typing_extensions import TypedDict, Annotated
from sqlalchemy import create_engine
from langchain.chat_models import init_chat_model
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
import matplotlib.pyplot as plt
from sqlalchemy import inspect
from langchain.chains import LLMChain
from langchain_core.output_parsers import StrOutputParser
from langchain_google_genai import ChatGoogleGenerativeAI


### Declaring the variables

In [None]:
if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = "AIzaSyDXcXJMv35uDfVrqWNoelcWo7-uB4LAz0Y"
from langchain.chat_models import init_chat_model
llm = init_chat_model("gemini-2.0-flash", model_provider="google_genai")
databasetype="PostgreSQL"
dbpassword = "PASSWORD"
dbschema="data"
query = "Show the flow from region codes to their respective countries (ISO3), and further break it down by GDELT action types such as protest, aid, support, etc., based on their frequency of occurrence."

### Connecting to PostgreSQL

In [4]:
db_uri = f"postgresql+psycopg2://postgres:{dbpassword}@localhost:5432/postgres"
db = SQLDatabase.from_uri(db_uri)

### Choosing the schema in which all the tables are present

In [5]:
engine = create_engine(db_uri)
db = SQLDatabase(engine, include_tables=None, schema=dbschema)  


In [6]:

prompt = ChatPromptTemplate.from_template(
    "Generate an SQL query to list all tables in a schema named '{schema}' for a {db_type} database."
)

# Setup chain
chain = LLMChain(llm=llm, prompt=prompt, output_parser=StrOutputParser())
sql_query = chain.invoke({"schema": dbschema, "db_type": databasetype})

raw_output = sql_query["text"] 
sql_clean = re.search(r"```sql\n(.*?)```", raw_output, re.DOTALL)

if sql_clean:
    final_query = sql_clean.group(1).strip()
else:
    final_query = raw_output.strip()  # fallback if no code block found

print("Clean SQL Query:\n", final_query)

  chain = LLMChain(llm=llm, prompt=prompt, output_parser=StrOutputParser())


Clean SQL Query:
 SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'data'
  AND table_type = 'BASE TABLE';


In [7]:
df_tables = pd.read_sql_query(final_query, engine)
print(df_tables)

         table_name
0             views
1  conflictforecast
2             gdelt
3               imf
4               wbg


### Making the master list

In [65]:
inspector = inspect(engine)

results = []
tables = inspector.get_table_names(schema=dbschema)
for table in tables:
    columns = inspector.get_columns(table, schema=dbschema)
    for column in columns:
        col_name = column["name"]
        col_type = str(column["type"])
        

        prompt = ChatPromptTemplate.from_template(
            "Explain the meaning of the dataset feature named '{feature}'. Keep it short and contextually accurate."
        )
        chain = LLMChain(llm=llm, prompt=prompt, output_parser=StrOutputParser())
        meaning = chain.invoke({"feature": col_name})["text"]
        
        results.append({
            "Source name": table,
            "Feature name": col_name,
            "Data Type": col_type,
            "Meaning": meaning.strip()
        })
feature_df=pd.DataFrame(results)
feature_df.to_excel("indicator_definitions.xlsx", index=False)
print("✔ Saved meanings to indicator_definitions.xlsx")

KeyboardInterrupt: 

In [8]:

feature_df = pd.read_excel("indicator_definitions.xlsx") 
feature_df = feature_df[[
    "Source name", "Feature name", "Data Type", "Meaning"
]].rename(columns={
    "Feature name": "indicator",
    "Meaning": "description",
    "Source name": "source",
    "Data Type": "data_type"
})

# Clean the 'description' column
feature_df["description"] = feature_df["description"].apply(lambda x: str(x).strip().lower())

client = chromadb.Client()
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")

if "indicator_metadata" in [c.name for c in client.list_collections()]:
    client.delete_collection("indicator_metadata")

# Create collection
collection = client.create_collection(name="indicator_metadata", embedding_function=embedding_fn)


  from .autonotebook import tqdm as notebook_tqdm


### Dropping duplicates to add to collection

In [9]:
feature_df = feature_df.drop_duplicates(subset="indicator")


### Adding the documents

In [10]:
collection.add(
    documents=feature_df["description"].tolist(),
    ids=feature_df["indicator"].tolist(),
    metadatas=[{"indicator": x} for x in feature_df["indicator"]]
)

# Confirm load
print("Documents added:", collection.count())


result = collection.query(query_texts=[query], n_results=5)

# Show results
if result["ids"][0]:
    for id_, doc in zip(result["ids"][0], result["documents"][0]):
        print(f"{id_}: {doc}")
else:
    print("No matches found.")

Documents added: 237
IMF_COMPRICES_PCOALSA_USD: 'imf_comprices_pcoalsa_usd' likely represents the **average price of coal sold by south africa, measured in us dollars, sourced from the imf's commodity price database (comprices).**
IMF_COMPRICES_PCOFFOTM: 'imf_comprices_pcoffotm' likely represents the **percentage change, compared to the previous month (pcoffotm), in the imf's composite index of primary commodity prices.**
IMF_COM_TOT_pct_change_winsorized: "imf_com_tot_pct_change_winsorized" likely represents the **percentage change in the imf's commodity total price index, where extreme values (outliers) have been adjusted (winsorized) to reduce their impact on the data.**  this aims to provide a more robust measure of commodity price fluctuations.
IMF_COMPRICES_PMAIZMT: 'imf_comprices_pmaizmt' likely represents the **average price of maize (corn) in us dollars per metric ton, sourced from the imf's commodity prices database.**
IMF_COM_TOT_pct_change: 'imf_com_tot_pct_change' likely r

### Search for relevation indicators

In [11]:
user_question = query
# Search in Chroma
result = collection.query(query_texts=[user_question], n_results=3)

if result["ids"][0]:
    context = "\n\n".join([
        f"Indicator: {id_}\nDescription: {doc}"
        for id_, doc in zip(result["ids"][0], result["documents"][0])
    ])
    print("Context from vector DB:\n", context)
else:
    print("No relevant features found in vector DB.")
    context = ""


Context from vector DB:
 Indicator: IMF_COMPRICES_PCOALSA_USD
Description: 'imf_comprices_pcoalsa_usd' likely represents the **average price of coal sold by south africa, measured in us dollars, sourced from the imf's commodity price database (comprices).**

Indicator: IMF_COMPRICES_PCOFFOTM
Description: 'imf_comprices_pcoffotm' likely represents the **percentage change, compared to the previous month (pcoffotm), in the imf's composite index of primary commodity prices.**

Indicator: IMF_COM_TOT_pct_change_winsorized
Description: "imf_com_tot_pct_change_winsorized" likely represents the **percentage change in the imf's commodity total price index, where extreme values (outliers) have been adjusted (winsorized) to reduce their impact on the data.**  this aims to provide a more robust measure of commodity price fluctuations.


### Creating the prompt

In [31]:
system_message = """
You are an expert SQL query generator. Given a natural language question, generate a syntactically correct {dialect} SQL query using only the tables and columns provided in the schema below. 
Unless the user specifies a number, always return at most {top_k} results.
Use LIMIT {top_k}.
All quoted column names in {databasetype} are case-sensitive. Always wrap column names in double quotes.

Schema:
{table_info}

Context:
{context}

Instructions:

- Use only relevant columns from the schema for the given question.
- Do NOT reference imaginary columns, inferred fields, or table aliases unless explicitly needed.
- Only exclude empty or null **rows** when required for aggregations, filters, or visualizations.
- Wrap each indicator name (e.g., for WBG or GDELT) in single quotes: `'Indicator_Name'`.

**Cleaning Numeric Strings**:
- If a column contains numeric values stored as text, with commas (e.g., '1,234.56') or empty strings (''), clean using:
  NULLIF(REPLACE("ColumnName"::TEXT, ',', ''), '')::NUMERIC
  -'pop' is numeric and use NOT NULL for this 
- This ensures commas are removed and blanks are treated as NULL before casting.
- NEVER compare numeric columns to empty strings directly (e.g., `"col" != ''`). Always use the above `NULLIF(...)` pattern.

**Dates**:
- For textual date columns like "yearmon" in 'Mon YYYY' format, convert using:
  TO_DATE("yearmon", 'Mon YYYY')
- For standard ISO strings, use `::DATE` or `TO_DATE(...)` appropriately.

**Multi-Column Aggregation (e.g., *_per_million_pop)**:
- When comparing values across columns, use a UNION ALL structure:
  Example:
    SELECT 'Column A' AS label, AVG(...) AS value FROM ...
    UNION ALL
    SELECT 'Column B', AVG(...) FROM ...
    ...
    ORDER BY value DESC

**ORDER BY in UNION ALL**:
- Use only the columns present in the **final SELECT** (e.g., `label`, `value`) for ordering.
- Do NOT order by names like "Column A" that are not globally accessible.

**General Rules**:
- Use only columns explicitly mentioned in the schema.
- Ensure all SQL is executable and avoids runtime errors.
- Use `DESC` for ranking or top-N queries unless explicitly stated otherwise.
"""

user_prompt = "Question: {input} (Please exclude missing or blank values)"

query_prompt_template = ChatPromptTemplate.from_messages([
    ("system", system_message),
    ("user", user_prompt),
])

### Defining the state and functions

In [22]:
class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str

class QueryOutput(TypedDict):
    query: Annotated[str, "Syntactically valid SQL query."]

def write_query(state: State):
    prompt = query_prompt_template.invoke({
        "dialect": db.dialect,
        "top_k": 25,
        "table_info": db.get_table_info(),
        "input": state["question"],
        "context": context,
        "databasetype": databasetype
    })
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return {"query": result["query"]}

def execute_query(state: State):
    tool = QuerySQLDatabaseTool(db=db)
    return {"result": tool.invoke(state["query"])}

def generate_answer(state: State):
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f"Question: {state['question']}\n"
        f"SQL Query: {state['query']}\n"
        f"SQL Result: {state['result']}"
    )
    response = llm.invoke(prompt)
    return {"answer": response.content}



### Execute pipelines

In [52]:

state: State = {
    "question": query,
    "query": "",
    "result": "",
    "answer": ""
}

# Run full flow
state.update(write_query(state))
print("\n Generated SQL:\n", state["query"])

state.update(execute_query(state))
print("\n SQL Result:\n", state["result"])

state.update(generate_answer(state))
print("\n Final Answer:\n", state["answer"])


 Generated SQL:
 SELECT "region_code", iso3, 'GDELT_PROTEST' AS action_type, COUNT("GDELT_PROTEST") AS count FROM data.gdelt WHERE "GDELT_PROTEST" IS NOT NULL GROUP BY "region_code", iso3 UNION ALL SELECT "region_code", iso3, 'GDELT_AID' AS action_type, COUNT("GDELT_AID") AS count FROM data.gdelt WHERE "GDELT_AID" IS NOT NULL GROUP BY "region_code", iso3 UNION ALL SELECT "region_code", iso3, 'GDELT_SUPPORT' AS action_type, COUNT("GDELT_SUPPORT") AS count FROM data.gdelt WHERE "GDELT_SUPPORT" IS NOT NULL GROUP BY "region_code", iso3 ORDER BY "region_code", count DESC LIMIT 25

 SQL Result:
 [('EAP', 'LAO', 'GDELT_PROTEST', 435), ('EAP', 'PRK', 'GDELT_SUPPORT', 435), ('EAP', 'KHM', 'GDELT_PROTEST', 435), ('EAP', 'TLS', 'GDELT_PROTEST', 435), ('EAP', 'SLB', 'GDELT_PROTEST', 435), ('EAP', 'TUV', 'GDELT_PROTEST', 435), ('EAP', 'KIR', 'GDELT_PROTEST', 435), ('EAP', 'TON', 'GDELT_PROTEST', 435), ('EAP', 'FJI', 'GDELT_PROTEST', 435), ('EAP', 'CHN', 'GDELT_PROTEST', 435), ('EAP', 'FSM', 'GDELT

### Run the SQL query

In [53]:
response=state["query"]
cleaned_sql = response.strip().strip("```sql").strip("```")
from sqlalchemy import create_engine
import pandas as pd

engine = create_engine(f"postgresql+psycopg2://postgres:{dbpassword}@localhost:5432/postgres")

try:
    df = pd.read_sql_query(cleaned_sql, engine)
    print(df.head())
except Exception as e:
    print("Error running SQL:", e)


  region_code iso3    action_type  count
0         EAP  LAO  GDELT_PROTEST    435
1         EAP  PRK  GDELT_SUPPORT    435
2         EAP  KHM  GDELT_PROTEST    435
3         EAP  TLS  GDELT_PROTEST    435
4         EAP  SLB  GDELT_PROTEST    435


### Plotting result on graph

In [54]:
from matplotlib import pyplot as plt
import pandas as pd
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI  
import plotly.graph_objects as go

prompt = ChatPromptTemplate.from_messages([
    ("system", 
     "You are a helpful data assistant. Based on the user's query and DataFrame column names, "
     "return which column should be on the X-axis, which on the Y-axis, and what type of plot to use. "
     "Use the exact column names from the dataframe."
     "Supported plots: bar, line, scatter, sankey, pie.\n"
     "Output strictly in the format: x=column_name, y=column_name, type=plot_type"),
    ("human", f"The DataFrame columns are: {list(df.columns)}.\nQuery: '{query}'")
])

title_prompt = ChatPromptTemplate.from_messages([
    ("system", 
     "You are a helpful assistant that creates concise and clear titles for data visualizations. "
     "Given a user query, return a short and informative plot title."),
    ("human", f"User query: {query}")
])

title_response = llm.invoke(title_prompt.format())
plot_title = title_response.content.strip()
print("Generated title:", plot_title)

response = llm.invoke(prompt.format())
print("Gemini response:", response.content)

# Step 2: Parse Gemini Response
x_col_raw, y_col_raw, plot_type = None, None, None
try:
    for part in response.content.split(','):
        part = part.strip()
        if part.lower().startswith('x='):
            x_col_raw = part.split('=')[1].strip()
        elif part.lower().startswith('y='):
            y_col_raw = part.split('=')[1].strip()
        elif part.lower().startswith('type='):
            plot_type = part.split('=')[1].strip().lower()
except Exception as e:
    raise ValueError("Unable to parse Gemini response") from e

# Step 3: Match column case from DataFrame
def match_column(name, columns):
    for col in columns:
        if col.lower() == name.lower():
            return col
    return None

x_col = match_column(x_col_raw, df.columns)
y_col = match_column(y_col_raw, df.columns)

if x_col is None or y_col is None:
    raise ValueError(f"Columns '{x_col_raw}' or '{y_col_raw}' not found in DataFrame.")




import numpy as np
from scipy.stats import skew, kurtosis
# Step 3.5: Optional log transform for X-axis based on distribution
log_x = False
try:
    x_vals = df[x_col].dropna()

    # Check: Is column numeric and not datetime?
    if pd.api.types.is_numeric_dtype(x_vals) and not pd.api.types.is_datetime64_any_dtype(df[x_col]):
        min_val = x_vals.min()
        max_val = x_vals.max()
        variation_ratio = max_val / min_val if min_val > 0 else np.inf
        skewness = abs(skew(x_vals))
        excess_kurtosis = abs(kurtosis(x_vals))

        print(f"📊 Variation ratio: {variation_ratio:.2f}, Skew: {skewness:.2f}, Kurtosis: {excess_kurtosis:.2f}")

        if min_val <= 0:
            print(f"⚠️ Skipping log10: {x_col} contains zero or negative values.")
        elif np.isfinite(variation_ratio) and (variation_ratio > 1000 or skewness > 1.5 or excess_kurtosis > 3):
            df = df.copy()  # avoid overwriting original df
            df[x_col] = np.log10(x_vals)
            log_x = True
            print(f"✅ Applied log10 to X-axis: {x_col}")
        else:
            print("ℹ️ Log transformation not required based on distribution.")
    else:
        print(f"🚫 Skipped log10: {x_col} is not numeric or is a datetime column.")

except Exception as e:
    print(f"⚠️ Error while checking/applying log transform on '{x_col}': {e}")

# Step 4: Plot
def xlabel():
    return f"log10({x_col})" if log_x else x_col

# If datetime-like column, convert safely before plotting
if pd.api.types.is_datetime64_any_dtype(df[x_col]) or 'date' in x_col.lower() or 'year' in x_col.lower():
    try:
        df[x_col] = pd.to_datetime(df[x_col], errors='coerce')
        df = df.dropna(subset=[x_col])
        print(f"🕒 Converted {x_col} to datetime.")
    except Exception as e:
        print("⚠️ Date conversion failed:", e)

if plot_type == "bar":
    plt.figure(figsize=(8, 6))
    plt.bar(df[x_col], df[y_col], color="skyblue")
    plt.xlabel(xlabel())
    plt.ylabel(y_col)
    plt.title(plot_title)
    plt.xticks(rotation=15)
    plt.tight_layout()
    plt.show()

elif plot_type == "line":
    plt.figure(figsize=(8, 6))
    plt.plot(df[x_col], df[y_col], marker='o')
    plt.xlabel(xlabel())
    plt.ylabel(y_col)
    plt.title(plot_title)
    plt.xticks(rotation=15)
    plt.tight_layout()
    plt.show()

elif plot_type == "scatter":
    plt.figure(figsize=(8, 6))
    plt.scatter(df[x_col], df[y_col], color='green')
    plt.xlabel(xlabel())
    plt.ylabel(y_col)
    plt.title(plot_title)
    plt.xticks(rotation=15)
    plt.tight_layout()
    plt.show()

elif plot_type == "pie":
    pie_df = df[[x_col, y_col]].dropna()
    if pie_df[y_col].sum() <= 0:
        print("⚠️ Pie chart values invalid.")
    else:
        plt.figure(figsize=(8, 6))
        plt.pie(pie_df[y_col], labels=pie_df[x_col], autopct='%1.1f%%', startangle=140)
        plt.title(plot_title)
        plt.axis('equal')
        plt.tight_layout()
        plt.show()

elif plot_type == "sankey":
    categorical_cols = [
        col for col in df.columns
        if df[col].dtype == object or pd.api.types.is_categorical_dtype(df[col])
    ]
    numeric_cols = [col for col in df.columns if pd.api.types.is_numeric_dtype(df[col])]
    if not numeric_cols:
        print("ℹ️ No numeric column found. Using dummy values.")
        df["__value__"] = 1
        numeric_cols = ["__value__"]

    value_col = numeric_cols[0]
    while len(categorical_cols) < 2:
        df[f'Dummy_{len(categorical_cols)}'] = 'All'
        categorical_cols.append(f'Dummy_{len(categorical_cols)}')

    flow_cols = categorical_cols[:5]
    df = df.dropna(subset=flow_cols + [value_col])

    all_labels, label_map, index_counter, links = [], {}, 0, []
    for col in flow_cols:
        for val in df[col].unique():
            if val not in label_map:
                label_map[val] = index_counter
                all_labels.append(val)
                index_counter += 1

    for i in range(len(flow_cols) - 1):
        grouped = df.groupby([flow_cols[i], flow_cols[i+1]])[value_col].sum().reset_index()
        for _, row in grouped.iterrows():
            links.append({
                'source': label_map[row[flow_cols[i]]],
                'target': label_map[row[flow_cols[i+1]]],
                'value': row[value_col]
            })

    fig = go.Figure(data=[go.Sankey(
        node=dict(label=all_labels, pad=15, thickness=20),
        link=dict(
            source=[l['source'] for l in links],
            target=[l['target'] for l in links],
            value=[l['value'] for l in links]
        )
    )])
    fig.update_layout(title_text="Generalized Sankey Diagram", font_size=10)
    fig.show()

else:
    print(f"❌ Unsupported plot type: {plot_type}")


Generated title: GDELT Action Types by Region and Country
Gemini response: x=region_code, y=iso3, type=sankey
🚫 Skipped log10: region_code is not numeric or is a datetime column.



is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, pd.CategoricalDtype) instead

