In [3]:
from langchain_community.llms import Ollama

In [5]:
# Mistral served through Ollama. A higher temperature makes the model answer
# more creatively (default: 0.8); 0 is used here for the least creative output.
llm = Ollama(
    model="mistral",
    temperature=0,
    # num_ctx=4096  # default is 2048
)

# NOTE(review): the recorded reply quotes a forecast dated March 20, 2023
# "as of my last update", so it is not live weather data.
llm.invoke("Please tell me the current weather in New York city for today")

' As of my last update, here is the current weather forecast for New York City (Manhattan) on March 20, 2023:\n\n- Temperature: 51°F (10.6°C)\n- Wind speed: 8 mph (13 km/h)\n- Humidity: 49%\n- Precipitation: 0% chance of rain\n- Sky condition: Partly cloudy\n\nPlease check a reliable weather source for the most up-to-date information, as conditions can change.'

In [7]:
# Ask the model to identify itself; the returned string is shown as the cell output.
llm.invoke("What's your name")

" I don't have a personal name. I'm a model of the Mistral AI's family, and you can call me Mistral AI Assistant. How can I assist you today?"

In [8]:
# Free-form SQL request: no schema and no output-format constraints are supplied.
# NOTE(review): in the recorded reply the Oracle example compares
# UPPER(first_name) against 'MATTS', which would not match 'Matt' -- answers
# from this unconstrained prompt need checking before use.
llm.invoke("generate some sql queries to select all rows in a table having Matt as their first_name")

" Assuming you have a table named `people` with columns `first_name`, `last_name`, and other relevant columns, here are the SQL queries to select all rows where the `first_name` is 'Matt':\n\n1. For MySQL:\n\n```sql\nSELECT * FROM people WHERE first_name = 'Matt';\n```\n\n2. For PostgreSQL:\n\n```sql\nSELECT * FROM people WHERE first_name = 'Matt';\n```\n\n3. For SQL Server:\n\n```sql\nSELECT * FROM people WHERE first_name = N'Matt'; -- Use the N prefix for Unicode strings in SQL Server\n```\n\n4. For Oracle:\n\n```sql\nSELECT * FROM people WHERE UPPER(first_name) = 'MATTS'; -- Oracle is case-insensitive by default, so we use uppercase to ensure matching\n```"

In [23]:
# Prompt template for Postgres text-to-SQL generation.
# It is filled in with str.format(schema=..., query=...), so {schema} and
# {query} are the only placeholders; any literal brace added to this text
# would have to be doubled ({{ or }}) to survive formatting.
sql_generator_prompt = """
You are a SQL generator for Postgres database. 
Given a schema and a query, return the SQL query that will return the results of the query.
Here's the schema:
{schema}
Here's the query for which you need to generate an SQL query to return its results:
{query}

You should only return the sql query as response and nothing else. 
For e.g. if the query says "What is the total count of employees?" the response should be  "SELECT COUNT(*) FROM employee;" and nothing else
There will be limitations in postgres that limit you from writing every kind of query so keep that in mind while giving the result. For e.g. in postgres aggregate and window functions are not allowed in WHERE clause.
"""


# DDL text substituted into the {schema} placeholder of the SQL-generation
# prompt. It describes three tables: employee, plus salary_details and
# time_sheet, each of which points back to employee through employee_id.
# NOTE(review): in salary_details and time_sheet, employee_id carries both an
# inline REFERENCES clause and a separate table-level FOREIGN KEY on the same
# column; one of the two looks redundant -- confirm before reusing this DDL.
sql_schema = """
-- Create employees table
CREATE TABLE employee (
    id SERIAL PRIMARY KEY,
    first_name VARCHAR(50) NOT NULL,
    last_name VARCHAR(50) NOT NULL,
    department VARCHAR(100),
    hire_date DATE
);

-- Create salary_details table
CREATE TABLE salary_details (
    id SERIAL PRIMARY KEY,
    employee_id INTEGER REFERENCES employee(id),
    salary_period DATE NOT NULL,
    gross_salary DECIMAL(10, 2) NOT NULL,
    deductions DECIMAL(10, 2) DEFAULT 0,
    take_home_pay DECIMAL(10, 2) NOT NULL,
    FOREIGN KEY (employee_id) REFERENCES employee(id)
);

-- Create time_sheet table
CREATE TABLE time_sheet (
    id SERIAL PRIMARY KEY,
    employee_id INTEGER REFERENCES employee(id),
    month DATE NOT NULL,
    total_hours DECIMAL(5, 2) NOT NULL,
    FOREIGN KEY (employee_id) REFERENCES employee(id)
);
"""


In [14]:
# Single-table count. print() is used so the returned SQL is shown as plain
# text rather than as an escaped string repr.
query = "How many employees do we have"
print(
    llm.invoke(
        sql_generator_prompt.format(query=query, schema=sql_schema)
    )
)

 SELECT COUNT(*) FROM employee;


In [15]:
# Count with a filter on take_home_pay, which lives in salary_details.
# NOTE(review): the recorded SQL counts joined rows, so an employee with
# several salary_details rows above 6000 would be counted more than once;
# COUNT(DISTINCT employee.id) would avoid that.
query = "How many employees are there whose take home pay is more than 6000"
print(
    llm.invoke(
        sql_generator_prompt.format(query=query, schema=sql_schema)
    )
)

 SELECT COUNT(employee.id) FROM employee
JOIN salary_details ON employee.id = salary_details.employee_id
WHERE salary_details.take_home_pay > 6000;


In [16]:
# Combines a pay filter (salary_details) with hours logged (time_sheet).
# NOTE(review): the question does not say over which period the 140 hours
# apply; the recorded SQL sums every time_sheet row for the employee.
query = "Give me the names of the employees whose take home pay is more than 6000 but has only logged in less than 140 hours of work"
print(
    llm.invoke(
        sql_generator_prompt.format(query=query, schema=sql_schema)
    )
)

 SELECT first_name, last_name
FROM employee
JOIN salary_details ON employee.id = salary_details.employee_id
WHERE take_home_pay > 6000 AND
(
    SELECT SUM(total_hours)
    FROM time_sheet
    WHERE employee_id = employee.id
    GROUP BY employee_id
) < 140;


In [17]:
# Asks for a change in take_home_pay over time (several salary_details rows
# per employee).
# NOTE(review): the recorded SQL references sd1.hire_date, but hire_date is a
# column of employee, not salary_details, so it would fail as written.
query = "Give me the names of the employees whose take home pay has increased by 10% throughout the duration of their employment"
print(
    llm.invoke(
        sql_generator_prompt.format(query=query, schema=sql_schema)
    )
)

 SELECT first_name, last_name
FROM employee
JOIN salary_details sd1 ON employee.id = sd1.employee_id
WHERE NOT EXISTS (
    SELECT * FROM salary_details sd2
    WHERE sd2.employee_id = employee.id AND sd2.take_home_pay * 1.10 >= sd1.take_home_pay
    AND sd1.hire_date <= sd2.salary_period
)
ORDER BY hire_date;


In [24]:
# Yearly aggregation over time_sheet.total_hours.
# NOTE(review): the recorded SQL places SUM(...) inside a WHERE clause, which
# Postgres rejects -- the very limitation the prompt text warns about.
query = "Who all have worked more than 500 hours in a year"
print(
    llm.invoke(
        sql_generator_prompt.format(query=query, schema=sql_schema)
    )
)

 SELECT e.id, e.first_name, e.last_name
FROM employee AS e
JOIN time_sheet AS ts ON e.id = ts.employee_id
WHERE (ts.month >= DATE_TRUNC('year', CURRENT_DATE) AND ts.total_hours > 500)
OR EXISTS (
    SELECT 1
    FROM time_sheet AS t2
    WHERE e.id = t2.employee_id
    AND t2.month < DATE_TRUNC('year', CURRENT_DATE)
    AND SUM(t2.total_hours) > 500
);
