In [11]:
import pandas as pd
import ollama
import json
import time
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

In [12]:
CATEGORIES = [
    'PRODUCT', 'PLACE', 'PRICE', 'PUBLICITY', 
    'POSTCONSUMPTION', 'PURPOSE', 'PARTNERSHIPS', 
    'PEOPLE', 'PLANET'
]

In [13]:
df = pd.read_csv("df_x_selected.csv")
for category in CATEGORIES:
    df[category] = 0
    
df.rename(columns={"text_english": "tweet"}, inplace=True)    
df = df.head(120)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 120 entries, 0 to 119
Data columns (total 15 columns):
 #   Column             Non-Null Count  Dtype 
---  ------             --------------  ----- 
 0   id                 120 non-null    int64 
 1   text               120 non-null    object
 2   date               120 non-null    object
 3   likes              120 non-null    int64 
 4   detected_language  120 non-null    object
 5   tweet              120 non-null    object
 6   PRODUCT            120 non-null    int64 
 7   PLACE              120 non-null    int64 
 8   PRICE              120 non-null    int64 
 9   PUBLICITY          120 non-null    int64 
 10  POSTCONSUMPTION    120 non-null    int64 
 11  PURPOSE            120 non-null    int64 
 12  PARTNERSHIPS       120 non-null    int64 
 13  PEOPLE             120 non-null    int64 
 14  PLANET             120 non-null    int64 
dtypes: int64(11), object(4)
memory usage: 14.2+ KB


In [14]:
def classify_batch(tweets, batch_size=10):
    """
    Process multiple tweets in a single API call by constructing a batch prompt
    """
    # Instantiate the Ollama client
    client = ollama.Client()
    
    # Construct the categories string for the prompt
    categories_str = ", ".join(CATEGORIES)
    
    # Construct a batch prompt that asks for classification of multiple tweets
    batch_prompt = f"""
You are an expert NLP classifier. Classify each tweet below into these categories: {categories_str}.
For each tweet, determine if it belongs to each category (1) or not (0).
Return your analysis in a valid JSON array where each item corresponds to one tweet in the same order:
[
  {{"{CATEGORIES[0]}": 0 or 1, "{CATEGORIES[1]}": 0 or 1, ..., "{CATEGORIES[-1]}": 0 or 1}},
  {{"{CATEGORIES[0]}": 0 or 1, "{CATEGORIES[1]}": 0 or 1, ..., "{CATEGORIES[-1]}": 0 or 1}},
  ...
]

Tweets to classify:
"""
    
    # Add each tweet with an index
    for i, tweet in enumerate(tweets):
        batch_prompt += f"\n{i+1}. \"{tweet}\"\n"
    
    batch_prompt += "\nOnly output the JSON array with no additional text."
    
    # Make the API call
    try:
        result_str = client.chat(
            model="llama3",
            messages=[{"role": "user", "content": batch_prompt}],
            options={"temperature": 0}
        )
        
        # Extract the content from the response
        response_content = result_str['message']['content'].strip()
        
        # Find JSON in the response (in case model adds extra text)
        import re
        json_match = re.search(r'\[.*\]', response_content, re.DOTALL)
        if json_match:
            response_content = json_match.group(0)
        
        # Parse the JSON array
        results = json.loads(response_content)
        return results
    
    except Exception as e:
        print(f"Error processing batch: {e}")
        print(f"Response content: {response_content if 'response_content' in locals() else 'No response'}")
        # Return default values for the batch
        return [{category: 0 for category in CATEGORIES} for _ in tweets]

def process_dataframe(df, batch_size=10):
    """
    Process the entire dataframe in batches of specified size
    """
    tweets = df['tweet'].tolist()
    all_results = []
    
    # Process in batches
    for i in tqdm(range(0, len(tweets), batch_size)):
        batch = tweets[i:i+batch_size]
        batch_results = classify_batch(batch)
        all_results.extend(batch_results)
        
        # Add a small delay to avoid overloading the API
        time.sleep(0.5)
    
    # Update the dataframe with results
    for i, result in enumerate(all_results):
        if i < len(df):
            for category in CATEGORIES:
                df.at[i, category] = result.get(category, 0)
    
    return df

# Option 2: Process using multithreading for even better performance
def classify_tweet_mt(tweet):
    """Single tweet classification function for multithreading"""
    client = ollama.Client()
    
    # Construct the categories string and JSON template for the prompt
    categories_str = ", ".join(CATEGORIES)
    json_template = ", ".join([f'"{category}": 0 or 1' for category in CATEGORIES])
    
    prompt = f"""
You are an expert NLP classifier. Given the tweet below, please determine which categories it belongs to among {categories_str}.
Output the answer in JSON format exactly as:
{{{json_template}}}

Tweet: "{tweet}"
Only output the JSON.
"""
    try:
        result_str = client.chat(
            model="llama3",
            messages=[{"role": "user", "content": prompt}],
            options={"temperature": 0}
        )
        response_content = result_str['message']['content'].strip()
        return json.loads(response_content)
    except Exception as e:
        print(f"Error classifying tweet: {e}")
        return {category: 0 for category in CATEGORIES}

def process_with_threading(df, max_workers=5):
    """Process tweets using multithreading"""
    tweets = df['tweet'].tolist()
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(tqdm(executor.map(classify_tweet_mt, tweets), total=len(tweets)))
    
    # Update the dataframe with results
    for i, result in enumerate(results):
        if i < len(df):
            for category in CATEGORIES:
                df.at[i, category] = result.get(category, 0)
    
    return df

In [15]:
# Choose your preferred method
# Option 1: Use batch processing (recommended for most cases)
result_df = process_dataframe(df, batch_size=10)

# Option 2: Use multithreading (if your API can handle concurrent requests)
# result_df = process_with_threading(df, max_workers=5)

print(result_df)

100%|██████████| 12/12 [10:22<00:00, 51.91s/it]

                      id                                               text  \
0    1902948287058973003  PEANUTS + STARBUCKS\n\n이렇게 귀여운 마카롱이라니!\n#스누피마카...   
1    1899431047382659156  Conversamos con una trabajadora de Starbucks e...   
2    1902133288707485945  Soylatte𖠚ᐝ\n\n#starbucks \n#photo https://t.co...   
3    1901929250351145424  リピ多めだったピザトースト🍕美味しかったなぁ🤤\n#starbucks https://t....   
4    1899628351528288580  今日発売の\n春空ミルクコーヒーフラペチーノ…\n中のストロベリーボールを\nストローで割っ...   
..                   ...                                                ...   
115  1016875474868363264  Plastic straws will soon be phased out of all ...   
116  1825555363619574219  Starbucks is now trading at 25x earnings with ...   
117  1732075961963729223  Remember to remember your reusable cup when yo...   
118  1534133308891705300  1/ Emotional Connection\n\nEmosional brand yg ...   
119  1534133926855266300  2/ Beli merchandise = loyalty\n\nMerchandise j...   

                               date  likes detected




In [17]:
result_df.to_csv("df_x_classified.csv", index=False)

In [20]:
result_df.describe()

Unnamed: 0,id,likes,PRODUCT,PLACE,PRICE,PUBLICITY,POSTCONSUMPTION,PURPOSE,PARTNERSHIPS,PEOPLE,PLANET
count,120.0,120.0,120.0,120.0,120.0,120.0,120.0,120.0,120.0,120.0,120.0
mean,1.814385e+18,324.016667,0.925,0.316667,0.1,0.583333,0.183333,0.116667,0.125,0.5,0.058333
std,2.493864e+17,2131.344772,0.264496,0.467127,0.301258,0.527931,0.388562,0.322369,0.332106,0.673633,0.235355
min,1.016325e+18,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,1.900366e+18,2.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,1.902106e+18,6.5,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
75%,1.90285e+18,21.5,1.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0
max,1.902948e+18,23052.0,1.0,1.0,1.0,2.0,1.0,1.0,1.0,3.0,1.0
