In [1]:
import json
import re
import time
from concurrent.futures import ThreadPoolExecutor

import ollama
import pandas as pd
from tqdm import tqdm

In [2]:
# Toy dataset: four sample tweets, each with an id and three label columns
# (PRODUCT, PLACE, PRICE) that start at 0.
df = pd.DataFrame(
    [
        (1, "I fucking hate starbucks coffee, so expensive", 0, 0, 0),
        (2, "There is no starbucks near my city", 0, 0, 0),
        (3, "I love starbucks caramel malchiato", 0, 0, 0),
        (4, "I dont like the new printed notes on the coffee...", 0, 0, 0),
    ],
    columns=['id', 'tweet', 'PRODUCT', 'PLACE', 'PRICE'],
)

def _coerce_label(value):
    """Map one model-emitted label onto 0 or 1; anything unparseable becomes 0."""
    try:
        return 1 if int(value) == 1 else 0
    except (TypeError, ValueError, OverflowError):
        return 0


def _align_batch_results(parsed, expected_count):
    """Turn a parsed JSON list into exactly ``expected_count`` clean label dicts."""
    if len(parsed) != expected_count:
        print(
            f"Warning: expected {expected_count} classifications, "
            f"got {len(parsed)}; padding/truncating to match"
        )
    aligned = []
    for position in range(expected_count):
        item = parsed[position] if position < len(parsed) else None
        if not isinstance(item, dict):
            # Missing or malformed entry: fall back to the all-zero default.
            item = {}
        aligned.append(
            {key: _coerce_label(item.get(key, 0)) for key in ("PRODUCT", "PLACE", "PRICE")}
        )
    return aligned


def classify_batch(tweets, batch_size=10, model="llama3"):
    """
    Classify several tweets with a single Ollama chat call.

    A numbered prompt is built from ``tweets`` and the model is asked for a
    JSON array with one ``{"PRODUCT", "PLACE", "PRICE"}`` object per tweet.

    Parameters
    ----------
    tweets : list of str
        Tweets to classify, in order.
    batch_size : int, optional
        Unused; kept so existing callers that pass it keep working.
    model : str, optional
        Name of the Ollama model to query.

    Returns
    -------
    list of dict
        Exactly ``len(tweets)`` dicts with integer 0/1 values for the keys
        ``PRODUCT``, ``PLACE`` and ``PRICE``, in the same order as ``tweets``.
        A reply that is too short, too long or contains malformed entries is
        padded, truncated or defaulted to zeros. If the call or the JSON
        parsing fails, every tweet gets the all-zero default.
    """
    tweets = list(tweets)
    if not tweets:
        # Nothing to classify: skip the API round-trip entirely.
        return []

    batch_prompt = """
You are an expert NLP classifier. Classify each tweet below into categories: PRODUCT, PLACE, and PRICE.
For each tweet, determine if it belongs to each category (1) or not (0).
Return your analysis in a valid JSON array where each item corresponds to one tweet in the same order:
[
  {"PRODUCT": 0 or 1, "PLACE": 0 or 1, "PRICE": 0 or 1},
  {"PRODUCT": 0 or 1, "PLACE": 0 or 1, "PRICE": 0 or 1},
  ...
]

Tweets to classify:
"""
    # Number the tweets so the model can keep its answers in order.
    batch_prompt += "".join(
        f"\n{number}. \"{tweet}\"\n" for number, tweet in enumerate(tweets, start=1)
    )
    batch_prompt += "\nOnly output the JSON array with no additional text."

    # Set before the try so the error path can always report it.
    response_content = None
    try:
        client = ollama.Client()
        reply = client.chat(
            model=model,
            messages=[{"role": "user", "content": batch_prompt}],
            options={"temperature": 0}
        )
        response_content = reply['message']['content'].strip()

        # Keep only the bracketed array in case the model added extra text.
        json_match = re.search(r'\[.*\]', response_content, re.DOTALL)
        if json_match:
            response_content = json_match.group(0)

        parsed = json.loads(response_content)
        if not isinstance(parsed, list):
            raise ValueError("model reply is not a JSON array")
    except Exception as e:
        print(f"Error processing batch: {e}")
        print(f"Response content: {response_content if response_content is not None else 'No response'}")
        # Best effort: give every tweet in the batch the default labels.
        return [{"PRODUCT": 0, "PLACE": 0, "PRICE": 0} for _ in tweets]

    return _align_batch_results(parsed, len(tweets))

def process_dataframe(df, batch_size=10):
    """
    Classify every tweet in ``df`` in batches and store the labels in ``df``.

    Tweets are read from the ``tweet`` column, sent to ``classify_batch`` in
    chunks of ``batch_size``, and the results are written to the ``PRODUCT``,
    ``PLACE`` and ``PRICE`` columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``tweet`` column. It is modified in place; the label
        columns are overwritten, or created if absent.
    batch_size : int, optional
        Number of tweets per ``classify_batch`` call; must be at least 1.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object that was passed in.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    tweets = df['tweet'].tolist()
    all_results = []

    for start in tqdm(range(0, len(tweets), batch_size)):
        batch = tweets[start:start + batch_size]
        batch_results = classify_batch(batch)
        if not isinstance(batch_results, list):
            batch_results = []

        # Force one result per tweet so a short or long reply for this batch
        # cannot shift the labels of later batches.
        aligned = batch_results[:len(batch)]
        aligned = aligned + [{} for _ in range(len(batch) - len(aligned))]
        all_results.extend(aligned)

        # Pause between batches to avoid overloading the API; no need to
        # wait after the final one.
        if start + batch_size < len(tweets):
            time.sleep(0.5)

    # Assign whole columns by position so the frame's index labels are
    # irrelevant (label-based df.at[i, ...] only works on a 0..n-1 index).
    for column in ('PRODUCT', 'PLACE', 'PRICE'):
        df[column] = [
            result.get(column, 0) if isinstance(result, dict) else 0
            for result in all_results
        ]

    return df

# Option 2: Process using multithreading for even better performance
def classify_tweet_mt(tweet, model="llama3"):
    """
    Classify a single tweet with one Ollama chat call (used by the thread pool).

    Parameters
    ----------
    tweet : str
        Text of the tweet to classify.
    model : str, optional
        Name of the Ollama model to query.

    Returns
    -------
    dict
        ``{"PRODUCT": 0|1, "PLACE": 0|1, "PRICE": 0|1}`` with integer values.
        If the call fails or the reply is not a JSON object, the all-zero
        default is returned and the error is printed.
    """
    prompt = f"""
You are an expert NLP classifier. Given the tweet below, please determine which categories it belongs to among PRODUCT, PLACE, and PRICE.
Output the answer in JSON format exactly as:
{{"PRODUCT": 0 or 1, "PLACE": 0 or 1, "PRICE": 0 or 1}}

Tweet: "{tweet}"
Only output the JSON.
"""
    try:
        client = ollama.Client()
        reply = client.chat(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            options={"temperature": 0}
        )
        response_content = reply['message']['content'].strip()

        # Keep only the braced object in case the model wrapped it in prose
        # or a code fence (classify_batch does the same for its array).
        json_match = re.search(r'\{.*\}', response_content, re.DOTALL)
        if json_match:
            response_content = json_match.group(0)

        parsed = json.loads(response_content)
        if not isinstance(parsed, dict):
            raise ValueError("model reply is not a JSON object")
    except Exception as e:
        print(f"Error classifying tweet: {e}")
        return {"PRODUCT": 0, "PLACE": 0, "PRICE": 0}

    # Reduce every label to a plain 0/1 int; missing or odd values become 0.
    labels = {}
    for key in ("PRODUCT", "PLACE", "PRICE"):
        try:
            labels[key] = 1 if int(parsed.get(key, 0)) == 1 else 0
        except (TypeError, ValueError, OverflowError):
            labels[key] = 0
    return labels

def process_with_threading(df, max_workers=5):
    """
    Classify every tweet in ``df`` concurrently, one model call per tweet.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``tweet`` column. It is modified in place; the
        ``PRODUCT``, ``PLACE`` and ``PRICE`` columns are overwritten, or
        created if absent.
    max_workers : int, optional
        Size of the thread pool that runs ``classify_tweet_mt``.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object that was passed in.
    """
    tweets = df['tweet'].tolist()

    # executor.map yields results in input order, so results[k] belongs to
    # the k-th row of df.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(tqdm(executor.map(classify_tweet_mt, tweets), total=len(tweets)))

    # Assign whole columns by position so the frame's index labels are
    # irrelevant (label-based df.at[i, ...] only works on a 0..n-1 index).
    for column in ('PRODUCT', 'PLACE', 'PRICE'):
        df[column] = [
            result.get(column, 0) if isinstance(result, dict) else 0
            for result in results
        ]

    return df

In [3]:
# Choose one of the two strategies below. Both write the PRODUCT/PLACE/PRICE
# columns of `df` and return that same frame.
# Option 1: Use batch processing (recommended for most cases).
# Several tweets are classified per model call; here, two per call.
result_df = process_dataframe(df, batch_size=2)

# Option 2: Use multithreading (if your API can handle concurrent requests).
# One model call per tweet, run through a thread pool.
# result_df = process_with_threading(df, max_workers=5)

print(result_df)

100%|██████████| 2/2 [00:12<00:00,  6.30s/it]

   id                                              tweet  PRODUCT  PLACE  \
0   1      I fucking hate starbucks coffee, so expensive        1      0   
1   2                 There is no starbucks near my city        0      1   
2   3                 I love starbucks caramel malchiato        1      0   
3   4  I dont like the new printed notes on the coffe...        1      1   

   PRICE  
0      1  
1      0  
2      0  
3      0  



