In [1]:
# Dependencies: Replicate client (model calls), langchain-community (SQLDatabase), sqlite-utils.
# NOTE(review): sqlite-utils is not referenced by any cell shown here — confirm it is needed.
# NOTE(review): prefer `%pip install` (installs into the running kernel's environment) with pinned versions.
!pip install replicate
!pip install langchain-community
!pip install sqlite-utils



In [2]:
from langchain_community.utilities import SQLDatabase

from pathlib import Path

# To run in Colab, first upload TwitterDataset/social_media.db from the repo into a
# TwitterDataset/ folder in the working directory.
DB_PATH = Path("TwitterDataset") / "social_media.db"

# sqlite3 silently creates a new, empty database file when the path does not exist,
# which would hand the model an empty schema. Fail loudly instead.
if not DB_PATH.is_file():
    raise FileNotFoundError(f"SQLite database not found: {DB_PATH.resolve()}")

# Open the database read-only (SQLite URI mode): the SQL executed later in this
# notebook is generated by an LLM and must not be able to modify the data.
db = SQLDatabase.from_uri(
    f"sqlite:///file:{DB_PATH.as_posix()}?mode=ro&uri=true",
    sample_rows_in_table_info=0,  # schema only; no sample rows in the table info
)


def get_schema():
    """Return the table definitions (CREATE TABLE statements) of ``db`` as text."""
    return db.get_table_info()

In [3]:
# Fetch the schema text once; it is passed to every prompt builder below.
db_schema = get_schema()
print(db_schema)


CREATE TABLE location (
	"LocationID" INTEGER, 
	"Country" TEXT, 
	"State" TEXT, 
	"StateCode" TEXT, 
	"City" TEXT
)


CREATE TABLE tweet (
	"TweetId" TEXT, 
	"Weekday" TEXT, 
	"Hour" INTEGER, 
	"Day" INTEGER, 
	"Lang" TEXT, 
	"IsReshare" INTEGER, 
	"Reach" INTEGER, 
	"RetweetCount" INTEGER, 
	"Likes" INTEGER, 
	"Klout" INTEGER, 
	"Sentiment" REAL, 
	"Text" TEXT, 
	"LocationID" INTEGER, 
	"UserID" TEXT
)


CREATE TABLE user (
	"UserID" TEXT, 
	"Gender" TEXT
)


In [4]:
import getpass, os

# Reuse a token already exported in the environment (so re-running this cell does not
# prompt again); otherwise ask for it without echoing it into the notebook output.
replicate_api_token = (
    os.environ.get("REPLICATE_API_TOKEN") or getpass.getpass("Replicate API token: ")
).strip()
if not replicate_api_token:
    raise ValueError("A Replicate API token is required to call the model.")

# The replicate client picks up its credentials from this environment variable.
os.environ["REPLICATE_API_TOKEN"] = replicate_api_token

In [5]:
import replicate

# Replicate deployment through which every text-to-SQL prediction below is created.
DEPLOYMENT_NAME = "ibm-granite/granite-code-20b"
deployment = replicate.deployments.get(DEPLOYMENT_NAME)

# Examples of Zero-Shot Prompting

In [6]:
def make_zeroshot_prompt(db_schema, question):
    """Build a zero-shot text-to-SQL prompt.

    The prompt gives the model an analyst persona, the schema of the Twitter
    tables and the question, and asks for nothing but the SQL query.

    Args:
        db_schema: Schema text (e.g. the CREATE TABLE statements from ``get_schema()``).
        question: Natural-language question to translate into SQL.

    Returns:
        The prompt string, ending with a ``SQL Query:`` cue for the model.
    """
    prompt = f"""You are a social media analyst with 15 years of experience writing complex SQL queries. Consider the Twitter tables with the following schema:
    {db_schema}

    Write a SQL query that would answer the user's question; just return the SQL query and nothing else.

    Question: {question}

    SQL Query:"""
    return prompt

In [7]:
def get_answer_using_zeroshot(db_schema, question):
    """Generate SQL with the zero-shot prompt and run it against ``db``.

    Args:
        db_schema: Schema text embedded in the prompt (see ``get_schema()``).
        question: Natural-language question to answer.

    Returns:
        ``(sql_query, result)``: the generated SQL text and whatever ``db.run``
        returned for it.

    Raises:
        RuntimeError: if the Replicate prediction did not succeed.
        ValueError: if the model's output does not start with ``SELECT``/``WITH``.
    """
    prompt = make_zeroshot_prompt(db_schema, question)

    prediction = deployment.predictions.create(
        input={"prompt": prompt, "max_length": 100, "temperature": 0.0}
    )
    prediction.wait()
    # A failed/canceled prediction has no usable output; report the model error
    # instead of letting ''.join(None) raise a confusing TypeError.
    if prediction.status != "succeeded":
        raise RuntimeError(
            f"Replicate prediction {prediction.id} {prediction.status}: {prediction.error}"
        )

    # Output is joined from its text pieces; drop the surrounding whitespace the model emits.
    sql_query = "".join(prediction.output).strip()

    # The SQL is model-generated (untrusted). Coarse guard: only execute statements
    # that look like read-only queries. This is a prefix check, not a full SQL parse.
    if not sql_query.upper().startswith(("SELECT", "WITH")):
        raise ValueError(f"Refusing to run non-SELECT SQL generated by the model: {sql_query!r}")

    result = db.run(sql_query)

    return sql_query, result

## Queries for Zero-Shot Prompting

In [8]:
# Zero-shot baseline: a question answerable from the tweet table alone.
question = "How many tweets are in English?"
sql_query, result = get_answer_using_zeroshot(db_schema, question)
print(f"sql_query : {sql_query}\nresult : {result}")

sql_query : SELECT COUNT(*) FROM tweet WHERE Lang = 'en' 
result : [(16,)]


In [9]:
# Zero-shot: list a text column filtered on the IsReshare flag.
question = "Please list the texts of all tweets that are reshared."
sql_query, result = get_answer_using_zeroshot(db_schema, question)
print(f"sql_query : {sql_query}\nresult : {result}")

sql_query : SELECT Text
FROM tweet
WHERE IsReshare = 1; 
result : [('"Retweeting an amazing fact."',), ('"Sharing a valuable resource!"',), ('"Retweeting some late-night thoughts."',), ('"Exploring new ideas at work."',), ('"Retweeting an interesting debate."',), ('"El tiempo está increíble hoy."',), ('"TGIF! Plans for the weekend?"',), ('"Chilling with friends."',)]


In [10]:
# Zero-shot: needs the location table to resolve "Ontario".
question = "How many reshared tweets are there in Ontario?"
sql_query, result = get_answer_using_zeroshot(db_schema, question)
print(f"sql_query : {sql_query}\nresult : {result}")

sql_query : SELECT COUNT(*) FROM tweet WHERE IsReshare = 1 AND LocationID IN (SELECT LocationID FROM location WHERE State = 'Ontario'); 
result : [(1,)]


In [11]:
# Zero-shot: join + aggregate + ranking.
question = "Which city has the highest number of tweets?"
sql_query, result = get_answer_using_zeroshot(db_schema, question)
print(f"sql_query: {sql_query}\nresult: {result}")

sql_query: SELECT City, COUNT(*) as NumberOfTweets
FROM tweet
JOIN location ON tweet.LocationID = location.LocationID
GROUP BY City
ORDER BY NumberOfTweets DESC
LIMIT 1; 
result: [('Los Angeles', 3)]


In [12]:
# NOTE: zero-shot gets this one wrong. The recorded query filters
# `Weekday NOT LIKE '%Weekend%'`, but Weekday presumably holds day names
# (the few-shot examples filter on 'Saturday'/'Sunday'), so weekends are not excluded.
# The same question is retried with few-shot prompting below.
question = "What is the total number of tweets made by male users on weekdays?"
sql_query, result = get_answer_using_zeroshot(db_schema, question)
print(f"sql_query: {sql_query}\nresult: {result}")

sql_query: SELECT COUNT(*) 
FROM tweet 
JOIN user ON tweet.UserID = user.UserID 
WHERE user.Gender = 'Male' 
AND tweet.Weekday NOT LIKE '%Weekend%' 
result: [(11,)]


# Examples of Few-Shot Prompting

In [13]:
def make_fewshot_prompt(db_schema, user_question):
    """Build a few-shot text-to-SQL prompt.

    The prompt embeds the database schema, instructs the model to reply with a
    bare SQL query, and shows three worked question/query pairs over the
    ``tweet``, ``user`` and ``location`` tables before the user's question.

    Args:
        db_schema: Schema text (e.g. the CREATE TABLE statements from ``get_schema()``).
        user_question: Natural-language question to translate into SQL.

    Returns:
        The prompt string, ending with an ``Assistant:`` cue for the model.
    """
    prompt = f"""
    You are a social media data analyst with 15 years of experience writing complex SQL queries. Generate SQL queries by understanding user input.

    You will be given the schema of the database being queried: {db_schema}
    Your task is to come up with the SQL query from the plaintext provided by the user, which when queried on the above database will result in accurate output.
    You are only required to generate the SQL query. Do not generate the output from that SQL query or any other explanation.
    Do not generate any explanation or comments at all. Just give the SQL query you came up with as it is.
    A few examples are given below for your reference.

    Example 1: What is the total number of tweets made by female users on weekends?
    Answer Query 1: SELECT COUNT(*) AS total_tweets FROM tweet T JOIN user U ON T.UserID = U.UserID WHERE U.Gender = 'Female' AND T.Weekday IN ('Saturday', 'Sunday');

    Example 2: What is the average sentiment score for tweets made in the United States?
    Answer Query 2: SELECT AVG(T.Sentiment) AS avg_sentiment FROM tweet T JOIN location L ON T.LocationID = L.LocationID WHERE L.Country = 'United States';

    Example 3: List the top 5 users with the highest total reach across all their tweets.
    Answer Query 3: SELECT U.UserID, SUM(T.Reach) AS total_reach FROM tweet T JOIN user U ON T.UserID = U.UserID GROUP BY U.UserID ORDER BY total_reach DESC LIMIT 5;
    User: {user_question}
    Assistant:
    """

    return prompt

In [14]:
def get_answer_using_fewshot(db_schema, question):
    """Generate SQL with the few-shot prompt and run it against ``db``.

    Args:
        db_schema: Schema text embedded in the prompt (see ``get_schema()``).
        question: Natural-language question to answer.

    Returns:
        ``(sql_query, result)``: the generated SQL text and whatever ``db.run``
        returned for it.

    Raises:
        RuntimeError: if the Replicate prediction did not succeed.
        ValueError: if the model's output does not start with ``SELECT``/``WITH``.
    """
    prompt = make_fewshot_prompt(db_schema, question)

    prediction = deployment.predictions.create(
        input={"prompt": prompt, "max_length": 100, "temperature": 0.0}
    )
    prediction.wait()
    # A failed/canceled prediction has no usable output; report the model error
    # instead of letting ''.join(None) raise a confusing TypeError.
    if prediction.status != "succeeded":
        raise RuntimeError(
            f"Replicate prediction {prediction.id} {prediction.status}: {prediction.error}"
        )

    # Output is joined from its text pieces; drop the surrounding whitespace the model emits.
    sql_query = "".join(prediction.output).strip()

    # The SQL is model-generated (untrusted). Coarse guard: only execute statements
    # that look like read-only queries. This is a prefix check, not a full SQL parse.
    if not sql_query.upper().startswith(("SELECT", "WITH")):
        raise ValueError(f"Refusing to run non-SELECT SQL generated by the model: {sql_query!r}")

    result = db.run(sql_query)

    return sql_query, result

## Queries with Few-Shot Prompting

In [16]:
# Few-shot retry of the question zero-shot got wrong; Example 1 in the prompt
# shows how to filter Weekday by explicit day names.
question = "What is the total number of tweets made by male users on weekdays?"
sql_query, result = get_answer_using_fewshot(db_schema, question)
print(f"sql_query: {sql_query}\nresult: {result}")

sql_query: SELECT COUNT(*) AS total_tweets FROM tweet T JOIN user U ON T.UserID = U.UserID WHERE U.Gender = 'Male' AND T.Weekday IN ('Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday'); 
result: [(10,)]


In [None]:
# Few-shot version of the top-city question (this cell has not been executed).
question = "Which city has the highest number of tweets?"
sql_query, result = get_answer_using_fewshot(db_schema, question)
print(f"sql_query: {sql_query}\nresult: {result}")

In [17]:
# Few-shot: two aggregates in one query over a user join.
question = "Find the total reach and average likes for tweets that were reshared by male users."
sql_query, result = get_answer_using_fewshot(db_schema, question)
print(f"sql_query: {sql_query}\nresult: {result}")

sql_query: SELECT SUM(T.Reach) AS total_reach, AVG(T.Likes) AS avg_likes
FROM tweet T
JOIN user U ON T.UserID = U.UserID
WHERE U.Gender = 'Male' AND T.IsReshare = 1; 
result: [(6000, 34.5)]


In [18]:
# Few-shot: filter + GROUP BY on a joined column.
question = "Retrieve the list of users who posted tweets with a sentiment score below 0, grouped by their gender, along with the count of such tweets."
sql_query, result = get_answer_using_fewshot(db_schema, question)
print(f"sql_query: {sql_query}\nresult: {result}")

sql_query: SELECT U.Gender, COUNT(*) AS count_tweets 
FROM tweet T 
JOIN user U ON T.UserID = U.UserID 
WHERE T.Sentiment < 0 
GROUP BY U.Gender; 
result: [('Female', 4), ('Male', 3)]


In [19]:
# NOTE: the recorded few-shot query is wrong. It groups by UserID only (counting
# tweets over all days) instead of per (UserID, Day), and LIMIT 1 drops ties.
# Retried below with chain-of-thought prompting.
question = "Find the user(s) who posted the maximum number of tweets in a single day."
sql_query, result = get_answer_using_fewshot(db_schema, question)
print(f"sql_query: {sql_query}\nresult: {result}")

sql_query: SELECT T.UserID, COUNT(*) AS total_tweets 
FROM tweet T 
JOIN user U ON T.UserID = U.UserID 
GROUP BY T.UserID 
ORDER BY total_tweets DESC 
LIMIT 1; 
result: [('U001', 3)]


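# Chain-of-Thought Reasoning

The few-shot prompt is reused unchanged; each question now also spells out the intermediate steps the query should follow.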

In [22]:
# Chain-of-thought retry of the per-day question: the question text now lists the
# reasoning steps. NOTE: the recorded query still ends in LIMIT 1, so users tied
# for the maximum would not all be returned.
question = """Find the user(s) who posted the maximum number of tweets in a single day.
              First calculate tweets posted by each user on each day.
              Second sort the result in descending order of the number of tweets.
              Third findout the maximum number of tweets posted by a user in a single day.
              Finally select the user(s) who posted the maximum number of tweets in a single day."""
sql_query, result = get_answer_using_fewshot(db_schema, question)
print(f"sql_query: {sql_query}\nresult: {result}")

sql_query: SELECT T.UserID, T.Day, COUNT(*) AS total_tweets 
FROM tweet T 
GROUP BY T.UserID, T.Day 
ORDER BY total_tweets DESC 
LIMIT 1; 
result: [('U001', 8, 2)]


In [20]:
# Chain-of-thought: spell out numerator, denominator and ratio for a percentage question.
question = """Among all tweets with a positive sentiment, what is the percentage of all those posted by a male user?
              First calculate tweets with a positive sentiment those posted by a male user.
              Second calculate all tweets with a positive sentiment.
              Finally calculate the percentage of the first result of the second result."""
sql_query, result = get_answer_using_fewshot(db_schema, question)
print(f"sql_query: {sql_query}\nresult: {result}")

sql_query: SELECT (COUNT(*) * 100.0 / (SELECT COUNT(*) FROM tweet WHERE Sentiment > 0)) AS percentage FROM tweet T JOIN user U ON T.UserID = U.UserID WHERE U.Gender = 'Male' AND T.Sentiment > 0; 
result: [(66.66666666666667,)]
