# Multimodal Sentiment Analysis for r/BrawlStars

This notebook implements an end-to-end workflow for building a multimodal (text, image, video) sentiment prediction model. It follows the eight steps listed below (Step 0 through Step 7), from raw data collection to a usable prediction pipeline.

**Workflow Overview:**

1.  **Step 0: Preparation:** Install libraries, set API keys, create folders.
2.  **Step 1: Data Collection:** Scrape Reddit data and download all media locally.
3.  **Step 2: AI-Powered Labeling:** Use the Gemini API to create the "golden dataset."
4.  **Step 3: Data Splitting:** Create `train`, `validation`, and `test` sets.
5.  **Step 4: Phase 1 Training:** Fine-tune specialist models (Text, Image, Video).
6.  **Step 5: Phase 2 Training:** Train the fusion model on embeddings.
7.  **Step 6: Evaluation:** Test the final system on unseen data.
8.  **Step 7: Prediction:** Build the final inference function for new, raw posts.

---

## Step 0: Preparation

First, we set up our environment. This involves installing all necessary libraries, setting up our API keys, and creating the directories where we'll store our data and media.

In [None]:
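# Install the dependencies once. Either uncomment the line below and run this cell,
# or run the same command (without the leading "!") in a terminal inside your virtual environment.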

# !pip install praw pandas requests google-generativeai scikit-learn transformers torch torchvision opencv-python-headless tqdm seaborn matplotlib pmaw

In [None]:
# --- Imports ---
import praw                     # For Reddit scraping
import pandas as pd             # For data manipulation
import requests                 # For downloading files
import os                       # For file/directory operations
import json                     # For handling API responses
from tqdm.auto import tqdm      # For progress bars
import google.generativeai as genai  # For Gemini API
from sklearn.model_selection import train_test_split # For splitting data
import cv2                      # For video processing (if needed)
import time
from pmaw import PushshiftAPI   # Import for Pushshift scraping
import datetime as dt           # For defining date ranges

# --- API Keys & Config ---
# !! IMPORTANT: Replace with your actual API keys
# !! Best practice: Store these in a .env file and use python-dotenv to load them
REDDIT_CLIENT_ID = "YOUR_CLIENT_ID_HERE"
REDDIT_CLIENT_SECRET = "YOUR_CLIENT_SECRET_HERE"
REDDIT_USER_AGENT = "BrawlStars Sentiment Scraper v1.0 by /u/YOUR_USERNAME"
GEMINI_API_KEY = "YOUR_GEMINI_API_KEY_HERE"

# --- Project Constants ---
SUBREDDIT_NAME = "Brawlstars"
POST_LIMIT = 1200  # Target number of posts (PMAW might return more or less)
LABEL_TARGET = 1000

# --- File & Directory Setup ---
MEDIA_DIR = "media"
IMAGE_DIR = os.path.join(MEDIA_DIR, "images")
VIDEO_DIR = os.path.join(MEDIA_DIR, "videos")
DATA_DIR = "data"

# Create directories if they don't exist
os.makedirs(IMAGE_DIR, exist_ok=True)
os.makedirs(VIDEO_DIR, exist_ok=True)
os.makedirs(DATA_DIR, exist_ok=True)



# --- File Paths ---
RAW_DATA_CSV = os.path.join(DATA_DIR, 'raw_data.csv')
LABELED_DATA_CSV = os.path.join(DATA_DIR, 'labeled_data_1k.csv')
TRAIN_SET_CSV = os.path.join(DATA_DIR, 'train_set.csv')
VALIDATION_SET_CSV = os.path.join(DATA_DIR, 'validation_set.csv')
TEST_SET_CSV = os.path.join(DATA_DIR, 'test_set.csv')
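
The comment above recommends keeping API keys out of the notebook. A minimal sketch of that idea follows. It assumes the keys are exported as environment variables under the same names as the constants, and `load_secret` is a hypothetical helper, not part of any library. If `python-dotenv` happens to be installed, a local `.env` file is loaded first; otherwise the sketch relies on the existing environment.

```python
import os

# Optional: load a local .env file when python-dotenv is available.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass  # Fall back to variables already present in the environment


def load_secret(name, default=None):
    """Return a setting from the environment, or raise if it is missing or empty."""
    value = os.environ.get(name, default)
    if not value:
        raise RuntimeError(f"Missing required setting: {name}")
    return value


# Example usage (assumes the variables were exported beforehand):
# REDDIT_CLIENT_ID = load_secret("REDDIT_CLIENT_ID")
# GEMINI_API_KEY = load_secret("GEMINI_API_KEY")
```

Failing fast on a missing key makes a misconfigured environment obvious before any scraping or labeling starts.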

---

## Step 1: Data Collection (Scraping and Downloading)

Here, we connect to the Reddit API using PRAW, scrape posts from the `hot` listing of r/BrawlStars, and, most importantly, download the associated image or video for each post. We save the *local path* to this media in our DataFrame. Cell 1.B offers an alternative scrape through PMAW (Pushshift).

In [None]:
# ========== CELL 1.A: Scraping with PRAW (Original) ==========

# Initialize PRAW (Reddit API client)
reddit = praw.Reddit(
    client_id=REDDIT_CLIENT_ID,
    client_secret=REDDIT_CLIENT_SECRET,
    user_agent=REDDIT_USER_AGENT,
)

print(reddit.read_only) # True when only a client ID and secret are supplied (read-only mode)

In [None]:
def download_media_praw(post):
    """
    Downloads the media (image or video) for a PRAW post and returns the local file path.
    Relies on PRAW's post object structure.
    """
    post_hint = getattr(post, 'post_hint', None)
    media_url = None
    local_path = None
    file_ext = ".unknown"

    try:
        if post_hint == 'image':
            media_url = post.url
            file_ext = os.path.splitext(media_url)[1]
            # Basic check for common image extensions
            if not file_ext or file_ext.lower() not in ['.jpg', '.jpeg', '.png', '.gif', '.webp']:
                file_ext = ".jpg" # Default if no/unrecognized extension
            local_path = os.path.join(IMAGE_DIR, f"{post.id}{file_ext}")

        elif post_hint == 'hosted:video':
            # PRAW provides a direct fallback URL
            if post.media and 'reddit_video' in post.media and post.media['reddit_video']:
                media_url = post.media['reddit_video']['fallback_url']
                file_ext = ".mp4"
                local_path = os.path.join(VIDEO_DIR, f"{post.id}{file_ext}")
            else:
                print(f"Warning: Post {post.id} hint is hosted:video but no media found.")
                return None # Skip if media info is missing
        
        elif post_hint == 'rich:video':
            # Skip external videos (YouTube, etc.)
            pass

        # If we have a URL and a path, download the file
        if media_url and local_path:
            if os.path.exists(local_path):
                # print(f"Media already exists: {local_path}")
                return local_path # Already downloaded

            # print(f"Downloading {media_url} to {local_path}")
            response = requests.get(media_url, stream=True, timeout=30) # Added timeout
            response.raise_for_status() 
            
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            return local_path

    except requests.exceptions.RequestException as e:
        print(f"Network error downloading media for post {post.id}: {e}")
        # Clean up potentially incomplete file
        if local_path and os.path.exists(local_path):
            try:
                os.remove(local_path)
            except OSError:
                pass
        return None
    except Exception as e:
        print(f"Error processing media for post {post.id} (URL: {getattr(post, 'url', 'N/A')} Hint: {post_hint}): {e}")
        return None
    
    return None # No downloadable media identified
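
One detail of the extension check is worth isolating: `os.path.splitext` applied to a full URL keeps any query string, so a URL ending in `.webp?width=640` would not match the allowed list and would fall back to `.jpg`. The sketch below parses the URL first. `pick_image_extension` and `ALLOWED_IMAGE_EXTS` are hypothetical names introduced for illustration, and the example URLs are made up.

```python
import os
from urllib.parse import urlparse

# Same extensions the download function accepts.
ALLOWED_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp"}


def pick_image_extension(url, default=".jpg"):
    """Return a lowercase image extension taken from the URL path only."""
    path = urlparse(url).path           # Drops the query string and fragment
    ext = os.path.splitext(path)[1].lower()
    return ext if ext in ALLOWED_IMAGE_EXTS else default


print(pick_image_extension("https://i.redd.it/abc123.PNG"))
print(pick_image_extension("https://preview.redd.it/abc123.webp?width=640"))
print(pick_image_extension("https://example.com/gallery"))
```

The helper could replace the two-line extension check in either download function without changing the rest of the logic.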

In [None]:
print(f"Starting PRAW scrape of {POST_LIMIT} posts from r/{SUBREDDIT_NAME}...")

all_posts_data_praw = []
subreddit = reddit.subreddit(SUBREDDIT_NAME)

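# Use tqdm for a progress bar. Reddit listings return at most about 1,000 items,
# so the loop may finish before the bar reaches POST_LIMIT.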
for post in tqdm(subreddit.hot(limit=POST_LIMIT), total=POST_LIMIT, desc="Scraping (PRAW)"):
    try:
        # 1. Download media and get local path using the PRAW-specific function
        local_media_path = download_media_praw(post)
        
        # 2. Store all relevant data
        post_data = {
            'id': post.id,
            'title': post.title,
            'text': post.selftext,
            'url': post.url, # The original URL (to image, or to post)
            'permalink': post.permalink,
            'score': post.score,
            'created_utc': post.created_utc,
            'post_hint': getattr(post, 'post_hint', 'text_only'),
            'local_media_path': local_media_path # Our new, crucial column
        }
        all_posts_data_praw.append(post_data)
        
    except Exception as e:
        print(f"Error processing post {post.id}: {e}")

# 3. Convert to DataFrame and save
df_raw_praw = pd.DataFrame(all_posts_data_praw)
df_raw_praw.to_csv(RAW_DATA_CSV, index=False) # Overwrite or use a different filename if needed

print(f"\nSuccessfully scraped and processed {len(df_raw_praw)} posts using PRAW.")
print(f"Raw data saved to: {RAW_DATA_CSV}")

# Show a summary of what we collected
print("\n--- PRAW Data Summary ---")
print(df_raw_praw.head())

print("\n--- PRAW Media Type Breakdown ---")
print(df_raw_praw['post_hint'].value_counts())

print("\n--- PRAW Downloaded Media Check ---")
print(f"{df_raw_praw['local_media_path'].notna().sum()} posts have associated local media.")
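
Building each row in a small function makes the record format testable without network access. The sketch below mirrors the dictionary built in the loop above. `build_post_record` is a hypothetical helper, and the `SimpleNamespace` object is only a stand-in for a PRAW submission with made-up values.

```python
from types import SimpleNamespace


def build_post_record(post, local_media_path):
    """Collect the fields we keep for one post into a flat dict."""
    return {
        "id": post.id,
        "title": post.title,
        "text": post.selftext,
        "url": post.url,
        "permalink": post.permalink,
        "score": post.score,
        "created_utc": post.created_utc,
        # Text-only posts carry no post_hint attribute, hence the default.
        "post_hint": getattr(post, "post_hint", "text_only"),
        "local_media_path": local_media_path,
    }


# Stand-in object with invented values, used only to exercise the function.
fake_post = SimpleNamespace(
    id="abc123",
    title="Example title",
    selftext="Example body",
    url="https://www.reddit.com/r/Brawlstars/comments/abc123/",
    permalink="/r/Brawlstars/comments/abc123/",
    score=1,
    created_utc=0.0,
)
record = build_post_record(fake_post, None)
print(record["post_hint"])
```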

In [1]:
# ========== CELL 1.B: Scraping with PMAW (Alternative) ==========

print(f"Starting PMAW scrape of ~{POST_LIMIT} posts from r/{SUBREDDIT_NAME}...")

# --- PMAW Configuration ---
PMAW_POST_LIMIT = POST_LIMIT  # Set how many posts PMAW should aim for
# Optional: Define a date range (e.g., scrape posts from the last 30 days)
# end_epoch = int(dt.datetime.now().timestamp())
# start_epoch = int((dt.datetime.now() - dt.timedelta(days=30)).timestamp())

# --- Initialize PMAW ---
api = PushshiftAPI()

# --- Fields to retrieve from Pushshift ---
# Select fields relevant to your task to minimize data transfer
fields = [
    'id',
    'title',
    'selftext', # Corresponds to 'text' in PRAW
    'url',      # URL of the link or media
    'permalink',
    'score',
    'created_utc',
    'domain',   # Helps identify media type (e.g., i.redd.it, v.redd.it)
    'is_video', # Boolean flag for videos
    'media_metadata' # Sometimes contains image info
]

# --- Perform the search ---
print("Querying Pushshift API...")
# Note: PMAW returns a generator. We need to convert it to a list.
# This can take time depending on the limit and Pushshift's responsiveness.
# Pushshift access has been restricted since 2023, so this call may return nothing.
submissions_generator = api.search_submissions(
    subreddit=SUBREDDIT_NAME,
    limit=PMAW_POST_LIMIT,
    # after=start_epoch, # Uncomment to use date range
    # before=end_epoch,  # Uncomment to use date range
    fields=fields,
    sort='desc',      # Get newest posts first
    sort_type='created_utc'
)

all_posts_data_pmaw = list(submissions_generator)

if not all_posts_data_pmaw:
    print("Error: No posts retrieved from Pushshift. Check parameters or API status.")
else:
    print(f"Retrieved {len(all_posts_data_pmaw)} posts from Pushshift.")

    # --- Convert to DataFrame ---
    df_raw_pmaw = pd.DataFrame(all_posts_data_pmaw)

    # --- Adapt PMAW data to match PRAW structure (where possible) ---
    df_raw_pmaw = df_raw_pmaw.rename(columns={'selftext': 'text'})

    # --- Infer 'post_hint' (simplified) ---
    # This is a basic inference; PRAW's 'post_hint' is more reliable.
    def infer_post_hint(row):
        # Missing fields arrive as NaN, which is truthy and not a string,
        # so normalise them before testing.
        def as_text(value):
            return value if isinstance(value, str) else ''

        domain, text, url = (as_text(row.get(k)) for k in ('domain', 'text', 'url'))
        is_video = row.get('is_video', False)

        if pd.notna(is_video) and bool(is_video):
            return 'hosted:video' # Assume v.redd.it videos are 'hosted'
        elif 'i.redd.it' in domain or 'i.imgur.com' in domain:
            return 'image'
        elif text.strip():
            # If there's non-empty selftext, prioritize as text_only even if there's a link
            return 'text_only'
        elif 'reddit.com' in url:
            # URL points back to Reddit with no text/image/video: likely text_only or a link to another post
            return 'text_only'
        else:
            return 'link' # Default for other links

    if not df_raw_pmaw.empty:
        df_raw_pmaw['post_hint'] = df_raw_pmaw.apply(infer_post_hint, axis=1)
    else:
        df_raw_pmaw['post_hint'] = None # Handle empty DataFrame case
        
    # --- Download Media (Adapted for PMAW data) ---
    # NOTE: Downloading v.redd.it videos reliably often requires external tools
    # like yt-dlp, because Pushshift does not expose the fallback URL directly.
    # This function will primarily handle images identified by URL.
    def download_media_pmaw(row):
        media_url = row.get('url', None)
        post_id = row.get('id', 'unknown')
        hint = row.get('post_hint', 'link')
        local_path = None
        file_ext = ".unknown"

        try:
            if hint == 'image' and media_url:
                file_ext = os.path.splitext(media_url)[1]
                if not file_ext or file_ext.lower() not in ['.jpg', '.jpeg', '.png', '.gif', '.webp']:
                    file_ext = ".jpg"
                local_path = os.path.join(IMAGE_DIR, f"{post_id}{file_ext}")
            
            elif hint == 'hosted:video' and media_url:
                # Basic placeholder - downloading v.redd.it needs more work
                print(f"Info: Video detected for {post_id} (URL: {media_url}). Downloading v.redd.it requires additional logic (e.g., yt-dlp) - skipping download.")
                # You might store the URL here anyway if your video model can handle URLs,
                # or implement yt-dlp download logic separately.
                return None # Skip download for now
                # Example placeholder if you wanted to try direct download (often fails for v.redd.it)
                # file_ext = ".mp4"
                # local_path = os.path.join(VIDEO_DIR, f"{post_id}{file_ext}") 
            
            # Download if it's an image and path is set
            if hint == 'image' and local_path:
                if os.path.exists(local_path):
                    return local_path
                
                response = requests.get(media_url, stream=True, timeout=30)
                response.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
                return local_path
                
        except requests.exceptions.RequestException as e:
            print(f"Network error downloading media for post {post_id}: {e}")
            if local_path and os.path.exists(local_path):
                try:
                    os.remove(local_path)
                except OSError:
                    pass
            return None
        except Exception as e:
            print(f"Error processing media for post {post_id} (URL: {media_url} Hint: {hint}): {e}")
            return None
            
        return None # No downloadable media identified/handled

    print("\nAttempting to download media identified by PMAW (primarily images)...")
    tqdm.pandas(desc="Downloading Media (PMAW)") 
    if not df_raw_pmaw.empty:
       df_raw_pmaw['local_media_path'] = df_raw_pmaw.progress_apply(download_media_pmaw, axis=1)
    else:
       df_raw_pmaw['local_media_path'] = None
       
    # --- Select and order columns to match PRAW output as closely as possible ---
    final_columns = [
        'id', 'title', 'text', 'url', 'permalink', 'score', 
        'created_utc', 'post_hint', 'local_media_path'
    ]
    # Add missing columns with None/NaN
    for col in final_columns:
        if col not in df_raw_pmaw.columns:
            df_raw_pmaw[col] = None 
            
    df_raw_pmaw = df_raw_pmaw[final_columns]

    # --- Save the PMAW-sourced data ---
    # Decide whether to overwrite the PRAW data or save to a new file
    PMAW_RAW_DATA_CSV = os.path.join(DATA_DIR, 'raw_data_pmaw.csv') 
    df_raw_pmaw.to_csv(PMAW_RAW_DATA_CSV, index=False)
    # Or uncomment below to overwrite the main raw file:
    # df_raw_pmaw.to_csv(RAW_DATA_CSV, index=False) 

    print(f"\nPMAW data processing complete. Saved to: {PMAW_RAW_DATA_CSV}")

    # --- Show Summary ---
    print("\n--- PMAW Data Summary ---")
    print(df_raw_pmaw.head())

    print("\n--- PMAW Inferred Media Type Breakdown ---")
    print(df_raw_pmaw['post_hint'].value_counts())

    print("\n--- PMAW Downloaded Media Check ---")
    print(f"{df_raw_pmaw['local_media_path'].notna().sum()} posts have associated local media (mostly images).")

# Note: running this cell on a fresh kernel raises
# "NameError: name 'POST_LIMIT' is not defined". Run the Step 0 configuration cell first.
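
The optional date range in the PMAW configuration is expressed as Unix epoch seconds. A small sketch of that calculation follows. `epoch_range` is a hypothetical helper, and the `now` parameter exists only so the result can be checked against a fixed date.

```python
import datetime as dt


def epoch_range(days, now=None):
    """Return (start_epoch, end_epoch) covering the last `days` days."""
    end = now or dt.datetime.now(dt.timezone.utc)
    start = end - dt.timedelta(days=days)
    return int(start.timestamp()), int(end.timestamp())


# Example: the last 30 days, ready to pass as after=/before= values.
start_epoch, end_epoch = epoch_range(30)
print(end_epoch - start_epoch)  # Number of seconds in the window
```

Using a timezone-aware `now` avoids ambiguity about the local clock, although the resulting epoch values are the same as those from a naive `datetime.now()` on the same machine.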

---

## Step 2: AI-Powered Labeling (Creating the "Golden Dataset")

Now we take our raw data and use the Gemini API to generate sentiment labels. We define a function that takes a post's text and its *local media file*, sends them to the API, and returns the response for JSON parsing. Labeling runs in resumable batches (`NEW_LABEL_TARGET` posts per run) until the golden dataset reaches roughly 1,000 posts.

**Note:** This step makes many API calls, so it can take a long time and cost money. For a first test, set `NEW_LABEL_TARGET` to a small number.

In [None]:
# Configure the Gemini API client
try:
    genai.configure(api_key=GEMINI_API_KEY)
    # Use the multimodal-capable model
    gemini_model = genai.GenerativeModel('gemini-2.5-flash') # Or 'gemini-1.5-pro'
    print("Gemini model configured successfully.")
except Exception as e:
    print(f"Error configuring Gemini: {e}. Please check your API key.")

In [None]:
# This is the detailed few-shot prompt for the API
# NOTE: The JSON examples use {{ and }} to escape the braces.

LABELING_PROMPT_TEMPLATE = """
You are an expert sentiment analyst for the game Brawl Stars. Your task is to analyze a Reddit post (which may include text, an image, and/or a video) and provide a structured JSON output.

Analyze the user's sentiment and categorize the post. The user's post content is provided first, followed by the media.

The 5 possible 'post_sentiment' values are:
1.  **Joy**: Happiness, excitement, pride (e.g., getting a new Brawler, winning a hard match, liking a new skin).
2.  **Anger**: Frustration, rage, annoyance (e.g., losing to a specific Brawler, bad teammates, game bugs, matchmaking issues).
3.  **Sadness**: Disappointment, grief (e.g., missing a shot, losing a high-stakes game, a favorite Brawler getting nerfed).
4.  **Surprise**: Shock, disbelief (e.g., a sudden clutch play, an unexpected new feature, a rare bug).
5.  **Neutral/Other**: Objective discussion, questions, news, or art that doesn't convey a strong emotion.

The 6 possible 'post_classification' values are:
1.  **Gameplay Clip**: A video or image showing a match, a specific play, or a replay.
2.  **Meme/Humor**: A meme, joke, or funny edit.
3.  **Discussion**: A text-based post asking a question or starting a conversation.
4.  **Feedback/Rant**: A post providing feedback, suggestions, or complaining about the game.
5.  **Art/Concept**: Fan art, skin concepts, or creative edits.
6.  **Achievement/Loot**: A screenshot of a new Brawler unlock, a high rank, or a Starr Drop reward.

--- EXAMPLES ---

[EXAMPLE 1]
Post Text: "This is the 5th time I've lost to an Edgar in a row. FIX YOUR GAME SUPERCELL!!"
Post Media: [Image showing a defeat screen]
Output:
{{
  "post_classification": "Feedback/Rant",
  "post_sentiment": "Anger",
  "sentiment_analysis": "The user is clearly angry, using all-caps ('FIX YOUR GAME') and expressing frustration at repeatedly losing to a specific Brawler (Edgar). The defeat screen image confirms the loss."
}}

[EXAMPLE 2]
Post Text: "I CAN'T BELIEVE I FINALLY GOT HIM!!"
Post Media: [Image showing a new legendary Brawler unlock]
Output:
{{
  "post_classification": "Achievement/Loot",
  "post_sentiment": "Joy",
  "sentiment_analysis": "The user is excited and happy, indicated by the all-caps text and the celebratory nature of unlocking a new legendary Brawler, which is a rare event."
}}

[EXAMPLE 3]
Post Text: "Check out this insane 1v3 I pulled off with Mortis"
Post Media: [Video showing a fast-paced gameplay clip where the player (Mortis) defeats three opponents]
Output:
{{
  "post_classification": "Gameplay Clip",
  "post_sentiment": "Joy",
  "sentiment_analysis": "The user is proud and excited about their 'insane 1v3' play. This is a clear expression of joy and pride in their own skill. The video clip demonstrates the achievement."
}}

--- TASK ---

Analyze the following post and provide ONLY the JSON output. Do not include '```json' or any other text outside the JSON block.

[POST CONTENT]
Title: {post_title}
Text: {post_text}

[POST MEDIA]
"""

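The doubled braces in the template matter because `str.format` treats single braces as placeholders. The short sketch below uses a cut-down, hypothetical `MINI_TEMPLATE` to show that `{{` and `}}` come out as literal braces, and that braces inside the substituted title are left untouched because only the template itself is parsed.

```python
# A cut-down stand-in for the labeling prompt, for illustration only.
MINI_TEMPLATE = """Output:
{{
  "post_sentiment": "Joy"
}}
Title: {post_title}
Text: {post_text}
"""

rendered = MINI_TEMPLATE.format(post_title="New {skin} idea", post_text="")
print(rendered)
```

A post title containing braces is therefore safe to pass in, while an unescaped brace in the template itself would raise a `KeyError` or `ValueError` at format time.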
In [None]:
def get_gemini_label(post_row):
    """
    Takes a row from the DataFrame, sends its text and local media to Gemini,
    and returns the raw response text (a JSON string) on success,
    or a dict of the form {"error": "..."} on failure.
    """
    post_title = post_row['title']
    post_text = post_row['text']
    media_path = post_row['local_media_path']
    post_hint = post_row['post_hint']

    # 1. Format the text part of the prompt
    # Handle potential NaN values in text/title before formatting
    safe_title = str(post_title) if pd.notna(post_title) else ""
    safe_text = str(post_text) if pd.notna(post_text) else ""
    prompt = LABELING_PROMPT_TEMPLATE.format(post_title=safe_title, post_text=safe_text)

    
    # 2. Prepare the media part
    media_payload = []
    uploaded_file_resource = None # Keep track of the file to delete later
    if pd.notna(media_path) and os.path.exists(media_path):
        try:
            # Use genai.upload_file for persistent storage and retrieval
            # print(f"Uploading {media_path}...")
            uploaded_file_resource = genai.upload_file(path=media_path)
            media_payload.append(uploaded_file_resource)
            
            # Wait for the file to be processed, especially important for videos
            # Add a timeout to prevent infinite loops
            processing_timeout = 60 # seconds
            start_time = time.time()
            while uploaded_file_resource.state.name == "PROCESSING":
                if time.time() - start_time > processing_timeout:
                     print(f"Warning: File processing timed out for {media_path}")
                     # Clean up the timed-out file
                     if uploaded_file_resource: genai.delete_file(uploaded_file_resource.name)
                     return {"error": "File processing timed out"}
                time.sleep(2)
                uploaded_file_resource = genai.get_file(uploaded_file_resource.name)
            
            if uploaded_file_resource.state.name == "FAILED":
                print(f"File upload failed: {media_path}")
                # No need to delete if it failed during upload/processing
                return {"error": "File upload failed"}

        except Exception as e:
            print(f"Error uploading/processing file {media_path}: {e}")
            # Attempt cleanup if resource exists
            if uploaded_file_resource: 
                try: genai.delete_file(uploaded_file_resource.name)
                except Exception: pass # Ignore delete error if upload failed badly
            return {"error": str(e)}
    else:
        media_payload.append("No media provided.")

    # 3. Combine prompt and media and make the API call
    try:
        full_prompt = [prompt] + media_payload
        response = gemini_model.generate_content(full_prompt)
        
        # Clean up the successfully processed file
        if uploaded_file_resource: 
            genai.delete_file(uploaded_file_resource.name)
        
        # Return the clean text response, ready for JSON parsing
        return response.text
    
    except Exception as e:
        print(f"Error during Gemini API call for post {post_row['id']}: {e}")
        # Clean up if an API call error occurred after upload
        if uploaded_file_resource: 
           try: genai.delete_file(uploaded_file_resource.name)
           except Exception: pass
        return {"error": str(e)}
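
The wait-for-processing loop inside `get_gemini_label` is a general polling pattern: check a state, sleep, re-check, and give up after a deadline. The sketch below isolates it so it can be exercised without the API. `wait_while_processing` is a hypothetical helper, the `"TIMEOUT"` return value is an assumption of this sketch, and the injectable `sleep` and `clock` arguments exist only to make it testable.

```python
import time


def wait_while_processing(get_state, timeout=60.0, interval=2.0,
                          sleep=time.sleep, clock=time.monotonic):
    """Poll get_state() until it stops returning "PROCESSING".

    Returns the final state, or "TIMEOUT" if the deadline passes first.
    """
    start = clock()
    state = get_state()
    while state == "PROCESSING":
        if clock() - start > timeout:
            return "TIMEOUT"
        sleep(interval)
        state = get_state()
    return state


# Simulated upload that becomes ready on the third check.
states = iter(["PROCESSING", "PROCESSING", "ACTIVE"])
final_state = wait_while_processing(lambda: next(states), sleep=lambda s: None)
print(final_state)
```

With the real client, `get_state` would re-fetch the uploaded file and return its state name, as the loop in the function above does.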



In [None]:
# --- Step 2 (Revised): AI-Powered Labeling (Interruptible & Resumable) ---

# --- Choose which raw data file to use for labeling --- 
# RAW_DATA_TO_LABEL = RAW_DATA_CSV # Use PRAW data
RAW_DATA_TO_LABEL = os.path.join(DATA_DIR, 'raw_data_pmaw.csv') # Use PMAW data

# --- 1. Define How Many New Posts to Label in This Batch ---
NEW_LABEL_TARGET = 500 

# --- 2. Find Out What's Already Labeled ---
try:
    df_old_labeled = pd.read_csv(LABELED_DATA_CSV)
    already_labeled_ids = set(df_old_labeled['id'])
    print(f"Loaded {len(df_old_labeled)} previously labeled posts from {LABELED_DATA_CSV}.")
except FileNotFoundError:
    df_old_labeled = pd.DataFrame() # Create an empty one if it's the first run
    already_labeled_ids = set()
    print(f"No existing file found at {LABELED_DATA_CSV}. Starting from scratch.")

# --- 3. Find Out What's Left to Label ---
try:
    df_all_raw = pd.read_csv(RAW_DATA_TO_LABEL)
    print(f"Reading raw data from: {RAW_DATA_TO_LABEL}")
except FileNotFoundError:
    print(f"Error: Raw data file not found at {RAW_DATA_TO_LABEL}. Please run Step 1 first.")
    raise
    
# Filter out posts we've already labeled
df_to_label = df_all_raw[~df_all_raw['id'].isin(already_labeled_ids)].copy()

# Limit this run to the NEW_LABEL_TARGET
df_to_label = df_to_label.head(NEW_LABEL_TARGET)

if len(df_to_label) == 0:
    print("No new posts to label based on the selected raw data file.")
else:
    print(f"Found {len(df_to_label)} new posts to label. Starting labeling...")
    
    # --- 4. Label New Posts (Interruptible Loop) ---
    df_to_label['gemini_raw_json'] = None # Create the column to fill
    
    try:
        # Use iterrows() for a row-by-row, interruptible loop
        for index, row in tqdm(df_to_label.iterrows(), total=len(df_to_label), desc="Generating AI Labels"):
            json_response = get_gemini_label(row)
            
            # Save result immediately to the DataFrame.
            # Use .at (single-cell setter) so an error dict is stored as one value
            df_to_label.at[index, 'gemini_raw_json'] = json_response

    except KeyboardInterrupt:
        print("\n\n--- INTERRUPTED BY USER ---")
        print("Labeling process stopped. Will parse and save completed posts...")

    finally:
        # --- 5. Process and Parse Whatever Was Completed ---
        
        # Filter to only the rows that were actually processed in this run
        df_newly_processed = df_to_label.dropna(subset=['gemini_raw_json']).copy()
        
        if len(df_newly_processed) == 0:
            print("No new posts were successfully labeled in this session.")
        else:
            print(f"\nParsing {len(df_newly_processed)} newly labeled posts...")
            
            # --- JSON Parsing ---
            parsed_labels = []
            for index, row in df_newly_processed.iterrows():
                raw_json = row['gemini_raw_json']
                label_entry = {'id': row['id'], 'labeling_error': None}
                try:
                    # Check if it's an error dictionary from get_gemini_label
                    if isinstance(raw_json, dict) and 'error' in raw_json:
                        label_entry['labeling_error'] = str(raw_json['error'])
                    elif isinstance(raw_json, str):
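                        # Remove an optional Markdown fence. str.lstrip/rstrip strip character
                        # sets, not prefixes, so use removeprefix/removesuffix (Python 3.9+).
                        clean_json_str = raw_json.strip().removeprefix('```json').removeprefix('```').removesuffix('```').strip()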
                        data = json.loads(clean_json_str)
                        label_entry['post_classification'] = data.get('post_classification')
                        label_entry['post_sentiment'] = data.get('post_sentiment')
                        label_entry['sentiment_analysis'] = data.get('sentiment_analysis')
                    else:
                        # Handle unexpected data types
                        label_entry['labeling_error'] = f"Unexpected data type in gemini_raw_json: {type(raw_json)}"

                except json.JSONDecodeError as e:
                    label_entry['labeling_error'] = f"JSON Decode Error: {e} - Raw: {raw_json[:100]}..." # Log snippet
                except Exception as e:
                    label_entry['labeling_error'] = f"General Error: {str(e)}"
                
                parsed_labels.append(label_entry)

            # Convert parsed data to a DataFrame and merge with the newly processed info
            df_labels = pd.DataFrame(parsed_labels)
            # Merge, keeping all columns from df_newly_processed
            df_new_labeled_final = pd.merge(df_newly_processed.drop(columns=['gemini_raw_json']), df_labels, on='id', how='left')

            # Filter out posts that had errors during labeling or parsing
            df_new_golden = df_new_labeled_final[df_new_labeled_final['labeling_error'].isna()].copy()
            df_errors = df_new_labeled_final[df_new_labeled_final['labeling_error'].notna()]

            print(f"Successfully parsed {len(df_new_golden)} new labels.")
            if len(df_errors) > 0:
                print(f"{len(df_errors)} posts had errors during labeling/parsing and will be skipped.")
                # Optional: print some errors for debugging
                # print("Example errors:")
                # print(df_errors[['id', 'labeling_error']].head())

            # --- 6. Combine Old and New Datasets and Save ---
            if len(df_new_golden) > 0:
                 # Ensure columns match before concatenating
                if not df_old_labeled.empty:
                    # Align columns based on the existing labeled data file
                    cols_existing = df_old_labeled.columns
                    cols_new = df_new_golden.columns
                    
                    # Add missing columns to new data (filled with NaN)
                    for col in cols_existing:
                        if col not in cols_new:
                            df_new_golden[col] = None
                            
                    # Add missing columns to old data (unlikely but possible)
                    for col in cols_new:
                         if col not in cols_existing:
                            df_old_labeled[col] = None
                    
                    # Reorder new data columns to match old data (including any columns just added to it)
                    df_new_golden = df_new_golden[df_old_labeled.columns]
                else:
                     # If it's the first run, define the columns based on the new data
                     df_old_labeled = pd.DataFrame(columns=df_new_golden.columns)

                # Concatenate the old labeled data with the new golden data
                df_combined = pd.concat([df_old_labeled, df_new_golden], ignore_index=True)
                
                # Save the final "golden dataset"
                df_combined.to_csv(LABELED_DATA_CSV, index=False)
                print(f"\nGolden dataset updated. Total labeled posts: {len(df_combined)}")
                print(f"Saved to: {LABELED_DATA_CSV}")

                # Display results
                print("\n--- Updated Labeled Data Tail (Last 5 rows) ---")
                print(df_combined[['id', 'title', 'post_sentiment', 'post_classification']].tail())

                print("\n--- Updated Sentiment Distribution ---")
                print(df_combined['post_sentiment'].value_counts())
            else:
                print("No new valid labels were generated in this session. Labeled file remains unchanged.")
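
The parsing step can also be written as a standalone function that strips an optional Markdown fence and checks the sentiment against the five values listed in the prompt. This is a sketch under assumptions: `parse_label` is a hypothetical helper, and rejecting unknown sentiment values is an extra check that the loop above does not perform.

```python
import json

# The five sentiment values defined in the labeling prompt.
ALLOWED_SENTIMENTS = {"Joy", "Anger", "Sadness", "Surprise", "Neutral/Other"}


def parse_label(raw):
    """Parse a model response into a dict, tolerating a ```json fence."""
    text = raw.strip()
    if text.startswith("```"):
        text = text[3:]
        if text.lower().startswith("json"):
            text = text[4:]
    if text.endswith("```"):
        text = text[:-3]
    data = json.loads(text.strip())  # Raises json.JSONDecodeError on bad JSON
    sentiment = data.get("post_sentiment")
    if sentiment not in ALLOWED_SENTIMENTS:
        raise ValueError(f"Unexpected post_sentiment: {sentiment!r}")
    return data


fenced = '```json\n{"post_sentiment": "Joy", "post_classification": "Meme/Humor"}\n```'
print(parse_label(fenced)["post_sentiment"])
```

Because `json.JSONDecodeError` is a subclass of `ValueError`, a caller can catch `ValueError` to handle both malformed JSON and unexpected labels.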

---

## Step 3: Data Splitting

We split our `labeled_data_1k.csv` into three distinct sets: `train_set` (for training), `validation_set` (for tuning), and `test_set` (for final evaluation). We use **stratification** on the `post_sentiment` column to ensure all three sets have a similar distribution of emotions.
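
Because the split happens in two stages, the validation fraction for the second stage has to be expressed relative to what remains after the test set is removed. The sketch below works through that arithmetic, assuming 10% test and 10% validation shares of the original data. `split_sizes` is a hypothetical helper and does not call scikit-learn.

```python
def split_sizes(n_total, test_frac=0.10, val_frac=0.10):
    """Return (n_train, n_val, n_test, relative_val_frac) for a two-stage split."""
    if n_total < 3:
        raise ValueError("Need at least 3 samples for train/validation/test")
    n_test = max(1, int(n_total * test_frac))   # At least one test sample
    n_val = max(1, int(n_total * val_frac))     # At least one validation sample
    n_train = n_total - n_test - n_val
    # Second-stage fraction: validation share of the non-test remainder.
    relative_val = n_val / (n_total - n_test)
    return n_train, n_val, n_test, relative_val


n_train, n_val, n_test, relative_val = split_sizes(1000)
print(n_train, n_val, n_test, round(relative_val, 4))
```

For 1,000 labeled posts this gives 100 test posts, and the second split then takes 100 of the remaining 900, a fraction of about 0.111 rather than 0.10.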

In [None]:
# Load the golden dataset
try:
    df_labeled = pd.read_csv(LABELED_DATA_CSV)
    print(f"Loaded {len(df_labeled)} labeled posts from {LABELED_DATA_CSV}")
    
    # --- Data Cleaning before split ---
    # Drop rows where post_sentiment might be missing (if any labeling errors occurred)
    df_labeled.dropna(subset=['post_sentiment'], inplace=True)
    print(f"Using {len(df_labeled)} posts with valid sentiment labels for splitting.")
    
    if len(df_labeled) < 10: # Need enough data to split reasonably
        print("Warning: Very few labeled posts. Splitting might result in tiny datasets.")
        # Handle this case - maybe stop or adjust split sizes
        # For now, we'll proceed but be aware of the issue.

    # Define split sizes
    TEST_SIZE = 0.10  # 10% for the final test set
    VALIDATION_SIZE = 0.10 # 10% for the validation set (relative to original size)

    # Check if there's enough data for stratification (at least 2 samples per class needed for sklearn)
    sentiment_counts = df_labeled['post_sentiment'].value_counts()
    if (sentiment_counts < 2).any():
        print("Warning: Some sentiment classes have fewer than 2 samples.")
        print("Falling back to an unstratified split. Consider labeling more data for rare classes.")
        print(sentiment_counts)
        # scikit-learn raises a ValueError when a stratified class has only one member,
        # so disable stratification for this split.
        stratify_param = None
    else:
        stratify_param = df_labeled['post_sentiment']
        
    # Adjust sizes if total data is very small
    n_total = len(df_labeled)
    n_test = max(1, int(n_total * TEST_SIZE)) # Ensure at least 1 test sample
    n_val = max(1, int(n_total * VALIDATION_SIZE)) # Ensure at least 1 val sample
    
    # Adjust relative validation size based on actual counts
    if n_total - n_test <= 0: # Handle edge case where test size >= total size
        # Stop here: train_test_split would fail anyway with nothing left to train on
        raise ValueError("Not enough data to create a non-empty training/validation set after the test split.")
    else:
        relative_val_size = n_val / (n_total - n_test)

    # 1. Split off the test set
    train_val_df, test_df = train_test_split(
        df_labeled,
        test_size=n_test, # Use calculated count
        random_state=42,
        stratify=stratify_param # Use the determined stratify parameter
    )

    # 2. Split the remaining into train and validation
    # Need to check stratification again for the smaller train_val_df
    if not train_val_df.empty:
        train_val_sentiment_counts = train_val_df['post_sentiment'].value_counts()
        if (train_val_sentiment_counts < 2).any():
             print("Warning: Stratifying train/validation split may be unreliable due to small class counts in the remainder.")
             stratify_train_val = None # Fallback to no stratification
        else:
             stratify_train_val = train_val_df['post_sentiment']
             
        train_df, val_df = train_test_split(
            train_val_df,
            test_size=relative_val_size, # Use calculated relative size
            random_state=42,
            stratify=stratify_train_val
        )
    else:
        # Handle case where train_val_df is empty (shouldn't happen with max(1) adjustments)
        train_df = pd.DataFrame(columns=df_labeled.columns)
        val_df = pd.DataFrame(columns=df_labeled.columns)
        print("Warning: train_val_df was empty after test split.")


    # 3. Save the splits
    train_df.to_csv(TRAIN_SET_CSV, index=False)
    val_df.to_csv(VALIDATION_SET_CSV, index=False)
    test_df.to_csv(TEST_SET_CSV, index=False)

    # 4. Report the results
    print("\nData splitting complete:")
    print(f"  Training set:   {len(train_df)} rows -> {TRAIN_SET_CSV}")
    print(f"  Validation set: {len(val_df)} rows -> {VALIDATION_SET_CSV}")
    print(f"  Test set:       {len(test_df)} rows -> {TEST_SET_CSV}")

    # Check distributions (use normalize=True for percentages)
    if not train_df.empty:
        print("\nTraining Set Sentiment Distribution:")
        print(train_df['post_sentiment'].value_counts(normalize=True))
    if not val_df.empty:
        print("\nValidation Set Sentiment Distribution:")
        print(val_df['post_sentiment'].value_counts(normalize=True))
    if not test_df.empty:
        print("\nTest Set Sentiment Distribution:")
        print(test_df['post_sentiment'].value_counts(normalize=True))

except FileNotFoundError:
    print(f"Error: Labeled data file not found at {LABELED_DATA_CSV}")
    print("Please run Step 2 successfully before running Step 3.")
except Exception as e:
     print(f"An error occurred during data splitting: {e}")

---

## Step 4: Phase 1 Training (Fine-Tuning the Specialists)

This is where we build our three specialist models. Each model (Text, Image, Video) is fine-tuned *independently* to predict the **overall post sentiment**. This teaches them what features in their modality (e.g., words, pixels, motion) correspond to emotions like "Joy" or "Anger" in the context of Brawl Stars.

**Note:** Each cell below is a self-contained PyTorch script for one specialist (Dataset, DataLoader, model definition, training loop, and validation loop). Treat the hyperparameters as starting points and tune them for your dataset size and hardware.
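
As a rough guide to which posts each specialist can learn from, the sketch below routes a post to modalities using the same `post_hint` values and file extensions that the dataset classes in this step filter on. The helper `modalities_for_post` and the example posts are hypothetical; unlike the real dataset classes, this sketch works on plain dictionaries and does not check that the file exists on disk or that the label is in the label map.

```python
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp')
VIDEO_EXTENSIONS = ('.mp4',)

def modalities_for_post(post):
    """Return the specialists that could train on this post (illustrative helper)."""
    modalities = ['text']  # the text specialist uses the title and body of every post
    path = post.get('local_media_path') or ''
    hint = post.get('post_hint')
    if hint == 'image' and path.lower().endswith(IMAGE_EXTENSIONS):
        modalities.append('image')
    elif hint == 'hosted:video' and path.lower().endswith(VIDEO_EXTENSIONS):
        modalities.append('video')
    return modalities

# Made-up example posts
posts = [
    {'title': 'New skin!', 'post_hint': 'image', 'local_media_path': 'media/a1.PNG'},
    {'title': 'Clutch play', 'post_hint': 'hosted:video', 'local_media_path': 'media/b2.mp4'},
    {'title': 'Balance rant', 'post_hint': None, 'local_media_path': None},
]
for post in posts:
    print(post['title'], '->', modalities_for_post(post))
```

Because every post contributes to the text specialist but only a subset has usable media, the image and video training sets are typically much smaller than the text one, which is why those cells use smaller batch sizes and more epochs.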

In [None]:
# --- Common Setup for Phase 1 ---
import torch

# Check if train_df exists from Step 3, otherwise load it
if 'train_df' not in locals() or train_df.empty:
    try:
        train_df = pd.read_csv(TRAIN_SET_CSV)
        if train_df.empty:
             raise ValueError("Loaded train_set.csv is empty.")
    except FileNotFoundError:
        print(f"Error: {TRAIN_SET_CSV} not found. Please run Step 3 successfully.")
        raise
    except ValueError as e:
         print(f"Error: {e} Check {TRAIN_SET_CSV}.")
         raise

# Make sure sentiment labels are derived correctly even if some classes were dropped during split
if 'df_labeled' not in locals() or df_labeled.empty:
    try:
         df_labeled = pd.read_csv(LABELED_DATA_CSV)
         df_labeled.dropna(subset=['post_sentiment'], inplace=True)
    except FileNotFoundError:
         print(f"Error: Cannot define labels. {LABELED_DATA_CSV} not found.")
         raise
    except Exception as e:
        print(f"Error reading {LABELED_DATA_CSV}: {e}")
        raise
        
SENTIMENT_LABELS = sorted(df_labeled['post_sentiment'].unique())
if not SENTIMENT_LABELS:
    raise ValueError("No valid sentiment labels found in the labeled data.")

label_to_id = {label: i for i, label in enumerate(SENTIMENT_LABELS)}
id_to_label = {i: label for i, label in enumerate(SENTIMENT_LABELS)}
NUM_LABELS = len(SENTIMENT_LABELS)

print(f"Found {NUM_LABELS} labels based on {LABELED_DATA_CSV}: {SENTIMENT_LABELS}")
print(f"Label map: {label_to_id}")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TEXT_MODEL_NAME = "distilbert-base-uncased"
IMAGE_MODEL_NAME = "openai/clip-vit-base-patch32"
VIDEO_PROCESSOR_NAME = "MCG-NJU/videomae-base" 
VIDEO_MODEL_NAME = "MCG-NJU/videomae-base-finetuned-kinetics"
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 2e-5

### 4.A: Text Model (DistilBERT)
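
The training cell in this section uses a linear learning-rate schedule with zero warmup steps. The sketch below reproduces the shape of that schedule in plain Python so the decay is easy to inspect. The function `linear_schedule_multiplier` is a hypothetical re-implementation based on the documented behaviour of `get_linear_schedule_with_warmup` (linear ramp during warmup, then linear decay to zero), and the step counts are assumed values, not measurements from this dataset.

```python
def linear_schedule_multiplier(step, total_steps, warmup_steps=0):
    """Factor applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        # Linear ramp from 0 to 1 during warmup
        return step / max(1, warmup_steps)
    # Linear decay from 1 to 0 over the remaining steps
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

BASE_LR = 2e-5
EPOCHS = 5
steps_per_epoch = 50  # assumption: e.g. 800 training posts with a batch size of 16
total_steps = steps_per_epoch * EPOCHS

for step in (0, total_steps // 2, total_steps):
    lr = BASE_LR * linear_schedule_multiplier(step, total_steps)
    print(f"step {step:3d}: lr = {lr:.2e}")
```

With no warmup, the first update already uses the full base rate; setting a small number of warmup steps is a common alternative when early updates are unstable.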

In [None]:
# --- Step 4.A: Text Model (DistilBERT) ---

import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    DistilBertTokenizer, 
    DistilBertForSequenceClassification,
    get_linear_schedule_with_warmup
)
from sklearn.metrics import accuracy_score
import os
from tqdm.auto import tqdm

print("--- Starting Phase 1: Text Model Training ---")

# --- 1. Configuration & Setup ---
TEXT_MODEL_NAME = "distilbert-base-uncased"
MODEL_SAVE_PATH = "./models/text_specialist"
BATCH_SIZE = 16
EPOCHS = 5  # Start with 3-5, you can increase if needed
LEARNING_RATE = 2e-5
MAX_LEN = 128 # Max token length for a post

# Ensure model save directory exists
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

# Set device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# --- 2. Load Data and Create Label Maps ---
try:
    # These should have been loaded or created in the common setup cell
    if 'train_df' not in locals() or train_df.empty:
        train_df = pd.read_csv(TRAIN_SET_CSV)
    if 'val_df' not in locals() or val_df.empty:
        val_df = pd.read_csv(VALIDATION_SET_CSV)
        
    # Ensure labels defined in common setup are used
    if 'label_to_id' not in locals() or not label_to_id:
        raise ValueError("Label map not defined in common setup cell.")

except FileNotFoundError:
    print("\n--- ERROR ---")
    print(f"Could not find {TRAIN_SET_CSV} or {VALIDATION_SET_CSV}")
    print("Please run Step 3: Data Splitting successfully.")
    raise
except ValueError as e:
     print(f"Error: {e}")
     raise

print(f"Using {NUM_LABELS} labels for text model: {label_to_id}")

# --- 3. Define Custom PyTorch Dataset ---
class TextSentimentDataset(Dataset):
    def __init__(self, dataframe, tokenizer, label_map, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        # Combine title and text for a richer input, handle NaNs
        self.texts = (dataframe['title'].fillna('') + " [SEP] " + dataframe['text'].fillna('')).tolist()
        # Map labels to integer ids; unmappable or missing labels become -1 and are filtered out below
        self.labels = dataframe['post_sentiment'].map(label_map).fillna(-1).astype(int).tolist()
        self.max_len = max_len
        
        # Filter out rows with invalid labels (-1)
        self.valid_indices = [i for i, label in enumerate(self.labels) if label != -1]

    def __len__(self):
        # Return the count of valid samples only
        return len(self.valid_indices)

    def __getitem__(self, index):
        # Map the requested index to the valid sample index
        original_index = self.valid_indices[index]
        
        text = str(self.texts[original_index])
        label = self.labels[original_index]

        # Call the tokenizer directly; encode_plus is deprecated in recent transformers releases
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# --- 4. Initialize Tokenizer, Datasets, and DataLoaders ---
print("Loading tokenizer and building datasets...")
tokenizer = DistilBertTokenizer.from_pretrained(TEXT_MODEL_NAME)

train_dataset = TextSentimentDataset(train_df, tokenizer, label_to_id, MAX_LEN)
val_dataset = TextSentimentDataset(val_df, tokenizer, label_to_id, MAX_LEN)

if len(train_dataset) == 0:
    raise ValueError("Training dataset is empty after filtering invalid labels. Check data splitting and labeling.")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

print(f"Text Train dataset size: {len(train_dataset)}")
print(f"Text Validation dataset size: {len(val_dataset)}")

# --- 5. Load Model, Optimizer, and Scheduler ---
print(f"Loading pre-trained model: {TEXT_MODEL_NAME}")
text_model = DistilBertForSequenceClassification.from_pretrained(
    TEXT_MODEL_NAME, 
    num_labels=NUM_LABELS
).to(device)

# Optimizer
optimizer = AdamW(text_model.parameters(), lr=LEARNING_RATE)

# Learning rate scheduler
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0, # Default, no warmup
    num_training_steps=total_steps
)

# --- 6. The Training & Validation Loop ---
print("--- Starting Training ---")

best_val_accuracy = 0.0

for epoch in range(EPOCHS):
    print(f"\n--- Epoch {epoch + 1} / {EPOCHS} ---")
    
    # --- Training Phase ---
    text_model.train()
    total_train_loss = 0
    train_progress_bar = tqdm(train_loader, desc="Training", leave=False)

    for batch in train_progress_bar:
        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Clear old gradients
        text_model.zero_grad()

        # Forward pass
        outputs = text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        # Get loss and logits
        loss = outputs.loss
        # Check if loss is valid (not NaN)
        if torch.isnan(loss):
             print("Warning: NaN loss detected during training. Skipping batch.")
             continue # Skip backprop for this batch
             
        total_train_loss += loss.item()

        # Backward pass
        loss.backward()
        # Clip gradients to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(text_model.parameters(), 1.0)
        
        # Update weights
        optimizer.step()
        # Update learning rate
        scheduler.step()
        
        train_progress_bar.set_postfix({'loss': loss.item()})

    if len(train_loader) > 0:
      avg_train_loss = total_train_loss / len(train_loader)
      print(f"Average Training Loss: {avg_train_loss:.4f}")
    else:
      print("No batches processed in training.")

    # --- Validation Phase ---
    if len(val_loader) == 0:
        print("Validation dataset is empty. Skipping validation.")
        # Optionally save model even without validation if needed
        # text_model.save_pretrained(MODEL_SAVE_PATH)
        # tokenizer.save_pretrained(MODEL_SAVE_PATH)
        continue # Skip to next epoch
        
    text_model.eval()
    total_val_loss = 0
    all_preds = []
    all_labels = []
    
    val_progress_bar = tqdm(val_loader, desc="Validation", leave=False)

    with torch.no_grad(): # No need to calculate gradients
        for batch in val_progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss
            logits = outputs.logits
            
            if loss is not None and not torch.isnan(loss):
                total_val_loss += loss.item()
            else:
                # Handle cases where loss might be None or NaN during eval
                 print("Warning: Invalid loss encountered during validation.")

            # Get predictions
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            labels_cpu = labels.cpu().numpy()
            
            all_preds.extend(preds)
            all_labels.extend(labels_cpu)

    # Ensure there are labels to calculate metrics on
    if not all_labels:
        print("No valid labels found in validation set for metric calculation.")
        continue
        
    avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) > 0 else 0
    val_accuracy = accuracy_score(all_labels, all_preds)

    print(f"Validation Loss: {avg_val_loss:.4f}")
    print(f"Validation Accuracy: {val_accuracy:.4f}")

    # --- Save the Best Model ---
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        print("New best model! Saving...")
        text_model.save_pretrained(MODEL_SAVE_PATH)
        tokenizer.save_pretrained(MODEL_SAVE_PATH)

print("\n--- Text Model Training Complete ---")
if best_val_accuracy > 0:
   print(f"Best Validation Accuracy: {best_val_accuracy:.4f}")
   print(f"Model and tokenizer saved to: {MODEL_SAVE_PATH}")
else:
   print("No best model saved (validation accuracy did not improve or validation was skipped).")

### 4.B: Image Model (CLIP Vision)
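
The image specialist in this section places a single linear layer on top of the pooled vision embedding and trains it with cross-entropy loss. The sketch below works through that computation by hand on a toy two-dimensional embedding. All numbers, and the helper names `linear_head`, `softmax`, and `cross_entropy`, are made up for illustration; the real model does the same arithmetic with `nn.Linear` and `nn.CrossEntropyLoss` on much larger tensors.

```python
import math

def linear_head(embedding, weights, bias):
    """Compute one logit per class: logits[k] = weights[k] . embedding + bias[k]."""
    return [sum(w * x for w, x in zip(row, embedding)) + b
            for row, b in zip(weights, bias)]

def softmax(logits):
    """Turn logits into probabilities (shifted by the max for numerical stability)."""
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target_id):
    """Negative log-probability assigned to the correct class."""
    return -math.log(softmax(logits)[target_id])

# Toy example: a 2-dimensional "pooled embedding" and 3 sentiment classes
embedding = [1.0, 2.0]
weights = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
bias = [0.0, 0.0, 0.5]

logits = linear_head(embedding, weights, bias)
predicted_id = max(range(len(logits)), key=lambda k: logits[k])
print("logits:", logits)
print("predicted class id:", predicted_id)
print("loss if the true class is 2:", round(cross_entropy(logits, 2), 4))
print("loss if the true class is 0:", round(cross_entropy(logits, 0), 4))
```

The loss is small when the largest logit belongs to the true class and grows as probability mass moves to the wrong classes, which is the signal that fine-tuning pushes back through the vision encoder.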

In [None]:
# --- Step 4.B: Image Model (CLIP-Vision) ---

import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    CLIPImageProcessor, 
    CLIPVisionModel,
    get_linear_schedule_with_warmup
)
from sklearn.metrics import accuracy_score
import os
from tqdm.auto import tqdm
from PIL import Image, UnidentifiedImageError

print("--- Starting Phase 1: Image Model Training (CLIP) ---")

# --- 1. Configuration & Setup ---
IMAGE_MODEL_NAME = "openai/clip-vit-base-patch32" 
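MODEL_SAVE_PATH = "./models/image_specialist.pth" # Save state dict
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True) # torch.save does not create parent directories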
BATCH_SIZE = 8 # Use a SMALLER batch size for images
EPOCHS = 10 # Train for a few more epochs since the dataset is tiny
LEARNING_RATE = 1e-5 # Use a smaller LR for fine-tuning vision models

# Set device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# --- 2. Load Data and Re-create Label Maps ---
try:
    if 'train_df' not in locals() or train_df.empty:
        train_df = pd.read_csv(TRAIN_SET_CSV)
    if 'val_df' not in locals() or val_df.empty:
        val_df = pd.read_csv(VALIDATION_SET_CSV)
    if 'label_to_id' not in locals() or not label_to_id:
         raise ValueError("Label map not defined in common setup cell.")
except FileNotFoundError:
    print(f"Error: Could not find {TRAIN_SET_CSV} or {VALIDATION_SET_CSV}")
    raise
except ValueError as e:
    print(f"Error: {e}")
    raise

print(f"Using {NUM_LABELS} labels for image model: {label_to_id}")

# --- 3. Define Custom Image Dataset ---
class ImageSentimentDataset(Dataset):
    def __init__(self, dataframe, processor, label_map):
        self.processor = processor
        self.label_map = label_map
        
        # Filter for valid image paths and map labels
        valid_data = []
        for _, row in dataframe.iterrows():
            path = row['local_media_path']
            hint = row['post_hint']
            sentiment = row['post_sentiment']
            
            # Check hint, path validity, and if label exists in map
            if hint == 'image' and pd.notna(path) and os.path.exists(path) and sentiment in label_map:
                 # Basic extension check (can be improved)
                 if path.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
                     label_id = label_map[sentiment]
                     valid_data.append({'path': path, 'label': label_id})
        
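        self.data = pd.DataFrame(valid_data)
        # Build the lists from valid_data directly: an empty DataFrame has no
        # 'path' or 'label' column, so column indexing would raise a KeyError.
        self.paths = [item['path'] for item in valid_data]
        self.labels = [item['label'] for item in valid_data]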

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        img_path = self.paths[index]
        label = self.labels[index]
        
        try:
            # Open the image file
            image = Image.open(img_path).convert("RGB")
        except (UnidentifiedImageError, FileNotFoundError, OSError, Exception) as e:
            print(f"Warning: Could not load/process image {img_path}. Using a blank image. Error: {e}")
            image = Image.new("RGB", (224, 224), (0, 0, 0)) # Return blank image
            
        # Process the image (resize, normalize, etc.)
        try:
           processed_image = self.processor(
               images=image, 
               return_tensors="pt"
           )
           pixel_values = processed_image['pixel_values'].squeeze() 
        except Exception as e:
            print(f"Warning: Error processing image {img_path} with CLIP processor. Using blank tensor. Error: {e}")
            # Use the processor's crop_size for the fallback tensor when available, else the CLIP default
            c, h, w = 3, 224, 224 # Default CLIP size
            crop_size = getattr(self.processor, 'crop_size', None)
            if isinstance(crop_size, dict):
                h = crop_size.get('height', h)
                w = crop_size.get('width', w)
            pixel_values = torch.zeros((c, h, w))

        return {
            'pixel_values': pixel_values,
            'labels': torch.tensor(label, dtype=torch.long)
        }

# --- 4. Define Custom Model with Classification Head ---
class CustomCLIPModel(nn.Module):
    def __init__(self, model_name, num_labels):
        super(CustomCLIPModel, self).__init__()
        self.clip_vision_model = CLIPVisionModel.from_pretrained(model_name)
        embedding_dim = self.clip_vision_model.config.hidden_size
        self.classifier = nn.Linear(embedding_dim, num_labels)
        
    def forward(self, pixel_values, labels=None):
        outputs = self.clip_vision_model(
            pixel_values=pixel_values
        )
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Use the logits' own class dimension instead of the global NUM_LABELS
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return {"loss": loss, "logits": logits}
        

# --- 5. Initialize Processor, Datasets, and DataLoaders ---
print("Loading processor and building datasets...")
image_processor = CLIPImageProcessor.from_pretrained(IMAGE_MODEL_NAME)

train_dataset = ImageSentimentDataset(train_df, image_processor, label_to_id)
val_dataset = ImageSentimentDataset(val_df, image_processor, label_to_id)

print(f"\n--- Image Dataset Sizes ---")
print(f"Total valid training images found: {len(train_dataset)}")
print(f"Total valid validation images found: {len(val_dataset)}")

if len(train_dataset) == 0:
    print("ERROR: No valid images found in training set after filtering. Cannot train image model.")
    print("Check file paths, image formats, and labels in train_set.csv.")
    # Optionally skip image training if this happens
    # proceed_without_image = True 
    raise ValueError("No valid training data for image model.")
else:
   proceed_without_image = False

if not proceed_without_image:
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # Only create val_loader if val_dataset is not empty
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE) if len(val_dataset) > 0 else None

    # --- 6. Load Model, Optimizer, and Scheduler ---
    print(f"Loading pre-trained model: {IMAGE_MODEL_NAME}")
    image_model = CustomCLIPModel(IMAGE_MODEL_NAME, NUM_LABELS).to(device)

    # Optimizer
    optimizer = AdamW(image_model.parameters(), lr=LEARNING_RATE)

    # Learning rate scheduler
    total_steps = len(train_loader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0, 
        num_training_steps=total_steps
    )

    # --- 7. The Training & Validation Loop ---
    print("--- Starting Training --- ")
    best_val_accuracy = 0.0

    for epoch in range(EPOCHS):
        print(f"\n--- Epoch {epoch + 1} / {EPOCHS} ---")
        
        # --- Training Phase ---
        image_model.train()
        total_train_loss = 0
        train_progress_bar = tqdm(train_loader, desc="Training", leave=False)

        for batch in train_progress_bar:
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)

            image_model.zero_grad()

            outputs = image_model(
                pixel_values=pixel_values,
                labels=labels
            )

            loss = outputs['loss']
            if torch.isnan(loss):
                 print("Warning: NaN loss detected during image training. Skipping batch.")
                 continue
            total_train_loss += loss.item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(image_model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            
            train_progress_bar.set_postfix({'loss': loss.item()})

        avg_train_loss = total_train_loss / len(train_loader) if len(train_loader) > 0 else 0
        print(f"Average Training Loss: {avg_train_loss:.4f}")

        # --- Validation Phase ---
        if val_loader is None or len(val_dataset) == 0:
            print("Validation set is empty. Skipping validation.")
            # Save model from the last epoch if no validation
            torch.save(image_model.state_dict(), MODEL_SAVE_PATH)
            continue
            
        image_model.eval()
        total_val_loss = 0
        all_preds = []
        all_labels = []
        
        val_progress_bar = tqdm(val_loader, desc="Validation", leave=False)

        with torch.no_grad():
            for batch in val_progress_bar:
                pixel_values = batch['pixel_values'].to(device)
                labels = batch['labels'].to(device)

                outputs = image_model(
                    pixel_values=pixel_values,
                    labels=labels
                )

                loss = outputs['loss']
                logits = outputs['logits']
                
                if loss is not None and not torch.isnan(loss):
                    total_val_loss += loss.item()

                preds = torch.argmax(logits, dim=1).cpu().numpy()
                labels_cpu = labels.cpu().numpy()
                
                all_preds.extend(preds)
                all_labels.extend(labels_cpu)
                
        if not all_labels: # Check if validation produced any results
             print("No valid samples processed during validation.")
             continue

        avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) > 0 else 0
        val_accuracy = accuracy_score(all_labels, all_preds)

        print(f"Validation Loss: {avg_val_loss:.4f}")
        print(f"Validation Accuracy: {val_accuracy:.4f}")

        # --- Save the Best Model --- 
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            print("New best model! Saving...")
            torch.save(image_model.state_dict(), MODEL_SAVE_PATH)

    print("\n--- Image Model Training Complete ---")
    if best_val_accuracy > 0:
        print(f"Best Validation Accuracy: {best_val_accuracy:.4f}")
        print(f"Model weights (state_dict) saved to: {MODEL_SAVE_PATH}")
    elif len(val_dataset) == 0:
        print("Model saved from last epoch as validation was skipped.")
        print(f"Model weights (state_dict) saved to: {MODEL_SAVE_PATH}")
    else:
        print("Validation accuracy never rose above 0, so no checkpoint was saved.")

else:
    print("Skipping Image Model training as no valid training data was found.")

### 4.C: Video Model
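
The video dataset in this section turns each clip into a fixed number of frames by sampling evenly spaced frame indices and, if some reads fail, padding with copies of the last frame. The sketch below shows both steps in plain Python. The helpers `sample_frame_indices` and `pad_frames` are hypothetical names, and the integer arithmetic only approximates the `np.linspace(..., dtype=int)` call used in the cell.

```python
def sample_frame_indices(total_frames, num_frames):
    """Evenly spaced frame indices from the first frame to the last one."""
    if total_frames <= 0 or num_frames <= 0:
        raise ValueError("total_frames and num_frames must be positive")
    if num_frames == 1:
        return [0]
    last = total_frames - 1
    # Integer arithmetic avoids floating-point rounding surprises
    return [(i * last) // (num_frames - 1) for i in range(num_frames)]

def pad_frames(frames, num_frames):
    """Trim or pad (by repeating the last frame) to exactly num_frames items."""
    if not frames:
        raise ValueError("cannot pad an empty frame list")
    padded = list(frames[:num_frames])
    while len(padded) < num_frames:
        padded.append(padded[-1])
    return padded

# A 160-frame clip sampled down to 16 frames
print(sample_frame_indices(160, 16))

# A very short clip: indices repeat, so the same frame is used more than once
print(sample_frame_indices(4, 16))

# Only two frames could be decoded; pad up to four
print(pad_frames(['frame_a', 'frame_b'], 4))
```

Sampling across the whole clip keeps the model's view of the video independent of its length, at the cost of skipping fast events that fall between sampled frames.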

In [None]:
from huggingface_hub import notebook_login

notebook_login()

In [None]:
# --- Step 4.C: Video Model (VideoMAE) ---

import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    VideoMAEImageProcessor,
    VideoMAEForVideoClassification,
    get_linear_schedule_with_warmup
)
from sklearn.metrics import accuracy_score
import os
from tqdm.auto import tqdm
from PIL import Image, UnidentifiedImageError
import cv2  # OpenCV for video processing
import numpy as np

print("--- Starting Phase 1: Video Model Training (VideoMAE) ---")

# --- 1. Configuration & Setup ---

VIDEO_PROCESSOR_NAME = "MCG-NJU/videomae-base" 
VIDEO_MODEL_NAME = "MCG-NJU/videomae-base-finetuned-kinetics" 
MODEL_SAVE_PATH = "./models/video_specialist"
BATCH_SIZE = 2 # Keep small for video
EPOCHS = 20 # Train longer for small dataset
LEARNING_RATE = 1e-5
NUM_FRAMES = 16

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# --- 2. Load Data and Re-create Label Maps ---
try:
    if 'train_df' not in locals() or train_df.empty:
        train_df = pd.read_csv(TRAIN_SET_CSV)
    if 'val_df' not in locals() or val_df.empty:
        val_df = pd.read_csv(VALIDATION_SET_CSV)
    if 'label_to_id' not in locals() or not label_to_id:
         raise ValueError("Label map not defined in common setup cell.")
except FileNotFoundError:
    print(f"Error: Could not find {TRAIN_SET_CSV} or {VALIDATION_SET_CSV}")
    raise
except ValueError as e:
    print(f"Error: {e}")
    raise

print(f"Using {NUM_LABELS} labels for video model: {label_to_id}")


# --- 3. Define Custom Video Dataset ---
class VideoSentimentDataset(Dataset):
    def __init__(self, dataframe, processor, label_map, num_frames):
        self.processor = processor
        self.label_map = label_map
        self.num_frames = num_frames
        
        # Filter for valid video paths and map labels
        valid_data = []
        for _, row in dataframe.iterrows():
            path = row['local_media_path']
            hint = row['post_hint']
            sentiment = row['post_sentiment']
            
            if hint == 'hosted:video' and pd.notna(path) and os.path.exists(path) and sentiment in label_map:
                 # Basic extension check
                 if path.lower().endswith('.mp4'): # Add other video formats if needed
                     label_id = label_map[sentiment]
                     valid_data.append({'path': path, 'label': label_id})
        
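        self.data = pd.DataFrame(valid_data)
        # Build the lists from valid_data directly: an empty DataFrame has no
        # 'path' or 'label' column, so column indexing would raise a KeyError.
        self.paths = [item['path'] for item in valid_data]
        self.labels = [item['label'] for item in valid_data]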

    def __len__(self):
        return len(self.paths)

    def _sample_frames(self, video_path):
        frames = []
        cap = None # Initialize cap outside try block for finally clause
        try:
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                raise IOError(f"Cannot open video file: {video_path}")
                
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames <= 0:
                raise IOError(f"Video file seems empty or has 0 frames: {video_path}")
            
            indices = np.linspace(0, total_frames - 1, num=self.num_frames, dtype=int)
            
            processed_frames = 0
            for idx in indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ret, frame = cap.read()
                if ret:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(frame_rgb))
                    processed_frames += 1
                # else: # Optional: Log if a frame read fails
                #    print(f"Warning: Failed to read frame {idx} from {video_path}")
            
            if processed_frames < self.num_frames // 2: # Check if we got a reasonable number of frames
                print(f"Warning: Read significantly fewer frames ({processed_frames}/{self.num_frames}) than expected from {video_path}")
            if not frames: 
                 raise IOError(f"Could not read any frames from: {video_path}")
            # Pad with copies of the last frame so exactly num_frames are returned
            # (frames is guaranteed non-empty by the check above)
            while len(frames) < self.num_frames:
                frames.append(frames[-1])

            return frames[:self.num_frames] # Ensure exactly num_frames are returned

        except Exception as e:
            print(f"Error processing video {video_path}: {e}. Returning blank frames.")
            # Use the processor's crop_size for the blank frames when available, else 224x224
            h, w = 224, 224 # Default
            crop_size = getattr(self.processor, 'crop_size', None)
            if isinstance(crop_size, dict):
                h = crop_size.get('height', h)
                w = crop_size.get('width', w)
            return [Image.new("RGB", (w, h), (0, 0, 0))] * self.num_frames
        finally:
            if cap is not None: 
                cap.release()

    def __getitem__(self, index):
        video_path = self.paths[index]
        label = self.labels[index]
        
        frames = self._sample_frames(video_path)
        
        try:
            # VideoMAE expects a list of images
            processed_video = self.processor(
                images=frames, 
                return_tensors="pt"
            )
            # The processor should handle the batch dimension correctly for videos
            pixel_values = processed_video['pixel_values'].squeeze(0) # Remove potential extra batch dim if added
            
        except Exception as e:
            print(f"Warning: Error processing frames from {video_path} with VideoMAE processor. Using blank tensor. Error: {e}")
            # Use the processor's crop_size for the fallback tensor when available, else 224x224
            t, c, h, w = self.num_frames, 3, 224, 224 # Default
            crop_size = getattr(self.processor, 'crop_size', None)
            if isinstance(crop_size, dict):
                h = crop_size.get('height', h)
                w = crop_size.get('width', w)
            pixel_values = torch.zeros((t, c, h, w)) # Shape (T, C, H, W)

        return {
            'pixel_values': pixel_values,
            'labels': torch.tensor(label, dtype=torch.long)
        }
        

# --- 4. Initialize Processor, Datasets, and DataLoaders ---
print(f"Loading processor from: {VIDEO_PROCESSOR_NAME}")
video_image_processor = VideoMAEImageProcessor.from_pretrained(VIDEO_PROCESSOR_NAME) 

train_dataset = VideoSentimentDataset(train_df, video_image_processor, label_to_id, NUM_FRAMES)
val_dataset = VideoSentimentDataset(val_df, video_image_processor, label_to_id, NUM_FRAMES)

print(f"\n--- Video Dataset Sizes ---")
print(f"Total valid training videos found: {len(train_dataset)}")
print(f"Total valid validation videos found: {len(val_dataset)}")

if len(train_dataset) == 0:
    print("WARNING: No valid videos found in training set after filtering. Skipping video model training.")
    video_model_trained = False
else:
    video_model_trained = True
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # Only create val_loader if val_dataset is not empty
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE) if len(val_dataset) > 0 else None


# --- 5. Load Model, Optimizer, and Scheduler ---
if video_model_trained:
    # Ensure the save directory exists before saving
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    
    print(f"Loading model from: {VIDEO_MODEL_NAME}")
    video_model = VideoMAEForVideoClassification.from_pretrained(
        VIDEO_MODEL_NAME, 
        num_labels=NUM_LABELS,
        ignore_mismatched_sizes=True # Important: re-initializes the classification head if its size differs from NUM_LABELS
    ).to(device)

    optimizer = AdamW(video_model.parameters(), lr=LEARNING_RATE)
    total_steps = len(train_loader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0, 
        num_training_steps=total_steps
    )
else:
    print("Skipping video model setup as no training data was found.")

# --- 6. The Training & Validation Loop ---
if video_model_trained:
    print("--- Starting Training --- ")
    best_val_accuracy = 0.0

    for epoch in range(EPOCHS):
        print(f"\n--- Epoch {epoch + 1} / {EPOCHS} ---")
        
        # --- Training Phase ---
        video_model.train() 
        total_train_loss = 0
        train_progress_bar = tqdm(train_loader, desc="Training", leave=False)

        for batch in train_progress_bar:
            # VideoMAE expects pixel_values shape (batch_size, num_frames, num_channels, height, width)
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)

            video_model.zero_grad()

            outputs = video_model(
                pixel_values=pixel_values,
                labels=labels
            )

            loss = outputs.loss
            if torch.isnan(loss):
                 print("Warning: NaN loss detected during video training. Skipping batch.")
                 continue
            total_train_loss += loss.item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(video_model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            
            train_progress_bar.set_postfix({'loss': loss.item()})
            
        avg_train_loss = total_train_loss / len(train_loader) if len(train_loader) > 0 else 0
        print(f"Average Training Loss: {avg_train_loss:.4f}")

        # --- Validation Phase ---
        if val_loader is None:
            print("Validation set is empty. Skipping validation.")
            # Without validation there is no "best" epoch, so keep the latest weights.
            # This ensures Step 5 can still load the video specialist.
            video_model.save_pretrained(MODEL_SAVE_PATH)
            video_image_processor.save_pretrained(MODEL_SAVE_PATH)
            continue
            
        video_model.eval()
        total_val_loss = 0
        all_preds = []
        all_labels = []
        
        val_progress_bar = tqdm(val_loader, desc="Validation", leave=False)

        with torch.no_grad():
            for batch in val_progress_bar:
                pixel_values = batch['pixel_values'].to(device)
                labels = batch['labels'].to(device)

                outputs = video_model(
                    pixel_values=pixel_values,
                    labels=labels
                )

                loss = outputs.loss
                logits = outputs.logits
                
                if loss is not None and not torch.isnan(loss):
                    total_val_loss += loss.item()

                preds = torch.argmax(logits, dim=1).cpu().numpy()
                labels_cpu = labels.cpu().numpy()
                
                all_preds.extend(preds)
                all_labels.extend(labels_cpu)
                
        if not all_labels: # Check if validation produced results
            print("No valid samples processed during validation.")
            continue
            
        avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) > 0 else 0
        val_accuracy = accuracy_score(all_labels, all_preds)

        print(f"Validation Loss: {avg_val_loss:.4f}")
        print(f"Validation Accuracy: {val_accuracy:.4f}")

        # --- Save the Best Model ---
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            print("New best model! Saving...")
            video_model.save_pretrained(MODEL_SAVE_PATH)
            video_image_processor.save_pretrained(MODEL_SAVE_PATH)

    print("\n--- Video Model Training Complete ---")
    if best_val_accuracy > 0:
        print(f"Best Validation Accuracy: {best_val_accuracy:.4f}")
    elif val_loader is None:
        print("Model may have been saved from the last epoch as validation was skipped.")
    else:
        print("Validation accuracy did not improve.")
    print(f"Model and processor potentially saved to: {MODEL_SAVE_PATH}")
else:
    print("Video Model training was skipped.")

---

## Step 5: Phase 2 Training (Training the Fusion Model)

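With the specialists trained, we *discard* their temporary classification heads and keep only their *output embeddings* (the feature vectors). For each post, the text, image, and video embeddings (768 dimensions each) are concatenated into a single 2,304-dimensional vector, with a zero vector standing in for any modality the post lacks. That vector is the input to our new, simple **Fusion Model**, a small MLP whose job is to learn how to combine the signals from text, image, and video to make the best final prediction.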
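
The code in this step combines the three signals by concatenating the per-modality embeddings and using zeros for a missing modality. The following is a minimal, dependency-free sketch of that layout, not the notebook's implementation: `fuse_embeddings` and `EMBED_DIM` are hypothetical names, and plain Python lists stand in for tensors.

```python
from typing import List, Optional

EMBED_DIM = 768  # Hypothetical constant; matches the per-modality size used in this step


def fuse_embeddings(
    text_emb: Optional[List[float]],
    image_emb: Optional[List[float]],
    video_emb: Optional[List[float]],
    dim: int = EMBED_DIM,
) -> List[float]:
    """Concatenate [text | image | video], zero-filling any missing modality."""
    fused: List[float] = []
    for emb in (text_emb, image_emb, video_emb):
        if emb is None:
            emb = [0.0] * dim  # Placeholder for a modality the post does not have
        if len(emb) != dim:
            raise ValueError(f"Expected {dim} values, got {len(emb)}")
        fused.extend(emb)
    return fused


# A text-only post: the image and video slots are all zeros.
text_only = fuse_embeddings([0.5] * EMBED_DIM, None, None)
print(len(text_only))  # 3 * 768 = 2304
```

Because the fused vector always has the same length, a single fusion network can handle text-only, image, and video posts alike.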

### 5.A: Create Embedding Dataset

In [None]:
# --- Step 5: Phase 2 Training (The Fusion Model) ---

import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    DistilBertTokenizer, 
    DistilBertModel,
    CLIPImageProcessor,
    CLIPVisionModel,
    VideoMAEImageProcessor,
    VideoMAEModel
)
from sklearn.metrics import accuracy_score
import os
from tqdm.auto import tqdm
from PIL import Image, UnidentifiedImageError
import cv2
import numpy as np

print("--- Starting Phase 2: Fusion Model Training ---")

# --- 0. Configuration & Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define model paths
TEXT_MODEL_PATH = "./models/text_specialist"
IMAGE_MODEL_PATH = "./models/image_specialist.pth"
VIDEO_MODEL_PATH = "./models/video_specialist"
FUSION_MODEL_SAVE_PATH = "./models/fusion_model.pth"

# Define embedding dimensions (based on the 'base' models we used)
TEXT_EMBED_DIM = 768  # (from DistilBERT-base)
IMAGE_EMBED_DIM = 768  # (from CLIP-ViT-base)
VIDEO_EMBED_DIM = 768  # (from VideoMAE-base)
COMBINED_EMBED_DIM = TEXT_EMBED_DIM + IMAGE_EMBED_DIM + VIDEO_EMBED_DIM

print(f"Combined embedding dimension will be: {COMBINED_EMBED_DIM}")

# Fusion model training config
BATCH_SIZE = 32
EPOCHS = 50 # We can train for more epochs, it's a very small model
LEARNING_RATE = 1e-4 # A slightly higher LR is fine for this MLP

# --- 1. Re-define CustomCLIPModel (to load weights) ---
# This MUST be the same class definition as in Step 4.B
class CustomCLIPModel(nn.Module):
    def __init__(self, model_name, num_labels):
        super(CustomCLIPModel, self).__init__()
        # Stored so forward() does not depend on a global; not part of the state_dict
        self.num_labels = num_labels
        self.clip_vision_model = CLIPVisionModel.from_pretrained(model_name)
        embedding_dim = self.clip_vision_model.config.hidden_size
        self.classifier = nn.Linear(embedding_dim, num_labels)
        
    def forward(self, pixel_values, labels=None):
        outputs = self.clip_vision_model(pixel_values=pixel_values)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            
        return {"loss": loss, "logits": logits, "embedding": pooled_output}


# --- 2. Load Data and Label Maps ---
try:
    if 'train_df' not in locals() or train_df.empty:
         train_df = pd.read_csv(TRAIN_SET_CSV)
    if 'val_df' not in locals() or val_df.empty:
         val_df = pd.read_csv(VALIDATION_SET_CSV)
    if 'label_to_id' not in locals() or not label_to_id:
        raise ValueError("Label map not defined.")
except FileNotFoundError:
    print(f"Error: Could not find {TRAIN_SET_CSV} or {VALIDATION_SET_CSV}")
    raise
except ValueError as e:
     print(f"Error: {e}")
     raise

print(f"Loaded {len(train_df)} train posts and {len(val_df)} validation posts.")
print(f"Using {NUM_LABELS} labels: {label_to_id}")


# --- 3. Load All Specialist Models and Processors (Base models for embeddings) ---

print("Loading specialist models (base versions for embeddings)...")

# --- Text Specialist ---
text_specialist_exists = os.path.exists(TEXT_MODEL_PATH)
if text_specialist_exists:
    try:
        text_tokenizer = DistilBertTokenizer.from_pretrained(TEXT_MODEL_PATH)
        text_model = DistilBertModel.from_pretrained(TEXT_MODEL_PATH).to(device)
        text_model.eval()
        print("Text specialist base model loaded.")
    except Exception as e:
        print(f"Warning: Could not load text specialist from {TEXT_MODEL_PATH}: {e}. Text embeddings will be zeros.")
        text_model = None
        text_tokenizer = None
else:
    print(f"Warning: Text specialist model not found at {TEXT_MODEL_PATH}. Text embeddings will be zeros.")
    text_model = None
    text_tokenizer = None

# --- Image Specialist ---
image_specialist_exists = os.path.exists(IMAGE_MODEL_PATH)
if image_specialist_exists:
    try:
        image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
        # Load the state dict into the custom model, then extract the base vision model
        temp_image_model = CustomCLIPModel("openai/clip-vit-base-patch32", NUM_LABELS) 
        temp_image_model.load_state_dict(torch.load(IMAGE_MODEL_PATH, map_location=device))
        image_model = temp_image_model.clip_vision_model.to(device)
        image_model.eval()
        print("Image specialist base model loaded.")
    except Exception as e:
        print(f"Warning: Could not load image specialist state_dict from {IMAGE_MODEL_PATH}: {e}. Image embeddings will be zeros.")
        image_model = None
        image_processor = None
else:
    print(f"Warning: Image specialist model not found at {IMAGE_MODEL_PATH}. Image embeddings will be zeros.")
    image_model = None
    image_processor = None

# --- Video Specialist ---
video_specialist_exists = os.path.exists(VIDEO_MODEL_PATH)
if video_specialist_exists:
    try:
        video_processor = VideoMAEImageProcessor.from_pretrained(VIDEO_MODEL_PATH)
        # Load the base model directly
        video_model = VideoMAEModel.from_pretrained(VIDEO_MODEL_PATH).to(device)
        video_model.eval()
        print("Video specialist base model loaded.")
    except Exception as e:
        print(f"Warning: Could not load video specialist from {VIDEO_MODEL_PATH}: {e}. Video embeddings will be zeros.")
        video_model = None
        video_processor = None
else:
    print(f"Warning: Video specialist model not found at {VIDEO_MODEL_PATH}. Video embeddings will be zeros.")
    video_model = None
    video_processor = None


# --- 4. Helper Functions for Embedding Extraction ---

def _sample_frames(video_path, num_frames=16):
    frames = []
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened(): raise IOError(f"Cannot open video: {video_path}")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0: raise IOError(f"Video empty: {video_path}")
        indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
        if not frames: raise IOError(f"Could not read frames: {video_path}")
        while len(frames) < num_frames:
             if frames: frames.append(frames[-1])
             else: raise IOError("No frames read, cannot duplicate.")
        return frames[:num_frames]
    except Exception as e:
        # print(f"Error sampling {video_path}: {e}. Using blank frames.") # Reduce verbosity
        h, w = 224, 224 
        return [Image.new("RGB", (w, h), (0, 0, 0))] * num_frames
    finally:
        if cap is not None: cap.release()

def get_embeddings(post_row):
    with torch.no_grad():
        # Text
        if text_model is not None and text_tokenizer is not None and pd.notna(post_row['title']):
            text_str = post_row.get('text', '')
            text = str(post_row['title']) + " [SEP] " + (str(text_str) if pd.notna(text_str) else '')
            try:
                encoding = text_tokenizer(text, return_tensors='pt', max_length=128, truncation=True, padding='max_length').to(device)
                outputs = text_model(**encoding)
                text_emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu()
            except Exception as e:
                 # print(f"Error getting text embedding for {post_row['id']}: {e}")
                 text_emb = torch.zeros(TEXT_EMBED_DIM)
        else:
            text_emb = torch.zeros(TEXT_EMBED_DIM)

        # Image
        if (image_model is not None and image_processor is not None and
            post_row['post_hint'] == 'image' and pd.notna(post_row['local_media_path']) and 
            os.path.exists(post_row['local_media_path'])):
            try:
                image = Image.open(post_row['local_media_path']).convert("RGB")
                processed_image = image_processor(images=image, return_tensors="pt").to(device)
                outputs = image_model(**processed_image)
                image_emb = outputs.pooler_output.squeeze().cpu()
            except Exception as e:  # Also covers UnidentifiedImageError, FileNotFoundError, OSError
                # print(f"Error processing image {post_row['local_media_path']}: {e}")
                image_emb = torch.zeros(IMAGE_EMBED_DIM)
        else:
            image_emb = torch.zeros(IMAGE_EMBED_DIM)

        # Video
        if (video_model is not None and video_processor is not None and
            post_row['post_hint'] == 'hosted:video' and pd.notna(post_row['local_media_path']) and 
            os.path.exists(post_row['local_media_path'])):
            try:
                frames = _sample_frames(post_row['local_media_path'])
                processed_video = video_processor(images=frames, return_tensors="pt").to(device)
                outputs = video_model(**processed_video)
                video_emb = outputs.last_hidden_state.mean(dim=1).squeeze().cpu() # Mean-pool over all spatio-temporal patch tokens
            except Exception as e:
                 # print(f"Error processing video {post_row['local_media_path']}: {e}")
                 video_emb = torch.zeros(VIDEO_EMBED_DIM)
        else:
            video_emb = torch.zeros(VIDEO_EMBED_DIM)
            
        # Concat
        # Ensure all are tensors before concatenating
        if not isinstance(text_emb, torch.Tensor): text_emb = torch.zeros(TEXT_EMBED_DIM)
        if not isinstance(image_emb, torch.Tensor): image_emb = torch.zeros(IMAGE_EMBED_DIM)
        if not isinstance(video_emb, torch.Tensor): video_emb = torch.zeros(VIDEO_EMBED_DIM)
        
        # Defensively move to CPU and detach before concatenating (should already be CPU tensors from above)
        combined_emb = torch.cat((text_emb.cpu().detach(), image_emb.cpu().detach(), video_emb.cpu().detach()))
        
        # Get label ID, handle missing labels
        label_id = label_to_id.get(post_row['post_sentiment'], -1) # Use -1 if sentiment not in map
        
        return combined_emb, label_id


# --- 5.A: Create Embedding Dataset ---

print("\n--- Starting Step 5.A: Creating Embedding Datasets ---")

def create_dataset(df, filename):
    embeddings = []
    labels = []
    valid_indices = [] # Keep track of rows that yield valid data
    
    print(f"Processing {filename} with {len(df)} rows...")
    for index, row in tqdm(df.iterrows(), total=len(df), desc=f"Creating {filename}"):
        emb, label = get_embeddings(row)
        # Skip rows whose sentiment is not in the label map
        if label == -1:
            # print(f"Warning: Skipping row {index} due to invalid label '{row['post_sentiment']}'.")
            continue
        # Optional stricter check (may be too strict depending on data):
        # also skip rows whose embedding is entirely zero
        # if emb.count_nonzero() == 0:
        #     continue
        embeddings.append(emb)
        labels.append(label)
        valid_indices.append(index)
            
            
    if not embeddings: # Check if any valid data was processed
        print(f"Error: No valid embeddings or labels generated for {filename}. Check source data and specialist models.")
        # Return empty tensors or raise error
        return torch.empty((0, COMBINED_EMBED_DIM)), torch.empty((0), dtype=torch.long)
        
    # Stack valid tensors
    embeddings = torch.stack(embeddings)
    labels = torch.tensor(labels, dtype=torch.long)
    
    # Save to disk
    save_path = os.path.join(DATA_DIR, filename)
    torch.save((embeddings, labels), save_path)
    print(f"Saved dataset with {len(labels)} valid entries to {save_path}")
    return embeddings, labels

train_embeddings, train_labels = create_dataset(train_df, "train_embeddings.pt")
val_embeddings, val_labels = create_dataset(val_df, "val_embeddings.pt")

if train_embeddings.nelement() == 0: # Check if tensor is empty
    raise ValueError("Training embedding dataset is empty. Cannot proceed.")

print("Embedding datasets creation process finished.")


# --- 5.B: Train Fusion Model ---

print("\n--- Starting Step 5.B: Training Fusion Model ---")

# 1. Define simple Dataset for embeddings
class EmbeddingDataset(Dataset):
    def __init__(self, embeddings, labels):
        self.embeddings = embeddings
        self.labels = labels
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]

train_emb_dataset = EmbeddingDataset(train_embeddings, train_labels)
train_emb_loader = DataLoader(train_emb_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Only create validation loader if embeddings exist
if val_embeddings.nelement() > 0:
    val_emb_dataset = EmbeddingDataset(val_embeddings, val_labels)
    val_emb_loader = DataLoader(val_emb_dataset, batch_size=BATCH_SIZE)
    print(f"Fusion Validation dataset size: {len(val_emb_dataset)}")
else:
    val_emb_loader = None
    print("Warning: Validation embedding dataset is empty. Skipping validation during fusion training.")

print(f"Fusion Training dataset size: {len(train_emb_dataset)}")

# 2. Define the Fusion Model architecture
class FusionModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer_1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.layer_2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.layer_1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.layer_2(x)
        return x

fusion_model = FusionModel(
    input_dim=COMBINED_EMBED_DIM, 
    hidden_dim=512,  # A reasonable hidden layer size
    output_dim=NUM_LABELS
).to(device)

# 3. Define Loss Function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(fusion_model.parameters(), lr=LEARNING_RATE)

# 4. Training Loop
print("Training fusion model...")
best_val_accuracy = 0.0

for epoch in range(EPOCHS):
    
    # --- Training ---
    fusion_model.train()
    total_train_loss = 0
    for embs, labels in train_emb_loader:
        embs = embs.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = fusion_model(embs)
        loss = criterion(outputs, labels)
        
        if torch.isnan(loss):
            print(f"Warning: NaN loss detected during fusion training epoch {epoch+1}. Skipping batch.")
            continue
            
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_emb_loader) if len(train_emb_loader) > 0 else 0

    # --- Validation ---
    if val_emb_loader is None:
        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Validation Skipped")
        # Save the model from the last epoch if validation is skipped
        torch.save(fusion_model.state_dict(), FUSION_MODEL_SAVE_PATH)
        continue # Skip validation logic

    fusion_model.eval()
    total_val_loss = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for embs, labels in val_emb_loader:
            embs = embs.to(device)
            labels = labels.to(device)
            
            outputs = fusion_model(embs)
            loss = criterion(outputs, labels)
            
            if loss is not None and not torch.isnan(loss):
                 total_val_loss += loss.item()
            
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            
    if not all_labels: # Check if validation processed anything
        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | No valid validation samples.")
        continue

    avg_val_loss = total_val_loss / len(val_emb_loader) if len(val_emb_loader) > 0 else 0
    val_accuracy = accuracy_score(all_labels, all_preds)
    
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.4f}")

    # Save best model
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        print(f"New best model! Saving to {FUSION_MODEL_SAVE_PATH}")
        torch.save(fusion_model.state_dict(), FUSION_MODEL_SAVE_PATH)

print("\n--- Fusion Training Complete ---")
if val_emb_loader is not None and best_val_accuracy > 0:
    print(f"Best Validation Accuracy: {best_val_accuracy:.4f}")
elif val_emb_loader is None:
    print("Model saved from last epoch as validation was skipped.")
else:
    print("Validation accuracy did not improve.")
print(f"Final fusion model potentially saved to: {FUSION_MODEL_SAVE_PATH}")

---

## Step 6: Evaluation (The Final Test)

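This is the moment of truth. We now evaluate on our **unseen** `test_set.csv`. Each post runs through the *entire pipeline* (Specialists -> Fusion Model), and the final prediction is compared to the true label. Because these posts played no part in training or model selection, the resulting accuracy, classification report, and confusion matrix are our final performance metrics.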
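
To make "compare the final prediction to the true label" concrete, here is a small, dependency-free sketch of two of the quantities involved. It is illustrative only: the notebook imports scikit-learn's `accuracy_score` and `confusion_matrix` for this purpose, and the helper names and toy label IDs below are invented.

```python
from typing import List, Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of predictions that match the true label."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("Inputs must be non-empty and of equal length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def confusion_counts(
    y_true: Sequence[int], y_pred: Sequence[int], num_labels: int
) -> List[List[int]]:
    """matrix[i][j] = number of samples with true label i predicted as label j."""
    matrix = [[0] * num_labels for _ in range(num_labels)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


# Toy example with three label IDs (0, 1, 2); the values are invented.
true_ids = [0, 1, 2, 2, 1, 0]
pred_ids = [0, 2, 2, 2, 1, 1]
print(accuracy(true_ids, pred_ids))             # 4 correct out of 6
print(confusion_counts(true_ids, pred_ids, 3))  # rows = true label, columns = predicted
```

Reading the matrix row by row shows which true sentiments the pipeline tends to mistake for which others, something a single accuracy figure hides.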

In [None]:
import torch
import torch.nn as nn
import pandas as pd
from transformers import (
    DistilBertTokenizer, 
    DistilBertModel,
    CLIPImageProcessor,
    CLIPVisionModel,
    VideoMAEImageProcessor,
    VideoMAEModel
)
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm
from PIL import Image, UnidentifiedImageError
import cv2
import numpy as np

print("--- Starting Step 6: Final Evaluation on Test Set ---")

# --- 0. Configuration & Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define model paths
TEXT_MODEL_PATH = "./models/text_specialist"
IMAGE_MODEL_PATH = "./models/image_specialist.pth"
VIDEO_MODEL_PATH = "./models/video_specialist"
FUSION_MODEL_SAVE_PATH = "./models/fusion_model.pth"

# Define embedding dimensions
TEXT_EMBED_DIM = 768
IMAGE_EMBED_DIM = 768
VIDEO_EMBED_DIM = 768
COMBINED_EMBED_DIM = TEXT_EMBED_DIM + IMAGE_EMBED_DIM + VIDEO_EMBED_DIM

# --- 1. Re-define Model Classes (from Step 4 & 5) ---

# Custom CLIP Model (from 4.B)
class CustomCLIPModel(nn.Module):
    def __init__(self, model_name, num_labels):
        super(CustomCLIPModel, self).__init__()
        self.clip_vision_model = CLIPVisionModel.from_pretrained(model_name)
        embedding_dim = self.clip_vision_model.config.hidden_size
        # Classifier not strictly needed for eval, but defining for consistency
        self.classifier = nn.Linear(embedding_dim, num_labels) 
        
    # We only need the base model's forward pass for embeddings
    def forward(self, pixel_values):
        outputs = self.clip_vision_model(pixel_values=pixel_values)
        return outputs # Return the full output sequence

# Fusion Model (from 5.B)
class FusionModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer_1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.layer_2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.layer_1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.layer_2(x)
        return x

# --- 2. Re-define Helper Functions (from Step 4 & 5) ---

def _sample_frames(video_path, num_frames=16):
    frames = []
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened(): raise IOError(f"Cannot open video: {video_path}")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0: raise IOError(f"Video empty: {video_path}")
        indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
        if not frames: raise IOError(f"Could not read frames: {video_path}")
        while len(frames) < num_frames:
             if frames: frames.append(frames[-1])
             else: raise IOError("No frames read, cannot duplicate.")
        return frames[:num_frames]
    except Exception as e:
        h, w = 224, 224 # Default size
        return [Image.new("RGB", (w, h), (0, 0, 0))] * num_frames
    finally:
        if cap is not None: cap.release()

# Updated embedding extraction for evaluation (takes loaded models as args)
def get_embeddings_for_eval(post_row, text_model, image_model, video_model, text_tokenizer, image_processor, video_processor, label_map):
    with torch.no_grad():
        # Text
        if text_model is not None and text_tokenizer is not None and pd.notna(post_row['title']):
            text_str = post_row.get('text', '')
            text = str(post_row['title']) + " [SEP] " + (str(text_str) if pd.notna(text_str) else '')
            try:
                encoding = text_tokenizer(text, return_tensors='pt', max_length=128, truncation=True, padding='max_length').to(device)
                outputs = text_model(**encoding)
                text_emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu()
            except Exception: text_emb = torch.zeros(TEXT_EMBED_DIM)
        else: text_emb = torch.zeros(TEXT_EMBED_DIM)

        # Image
        if (image_model is not None and image_processor is not None and
            post_row['post_hint'] == 'image' and pd.notna(post_row['local_media_path']) and 
            os.path.exists(post_row['local_media_path'])):
            try:
                image = Image.open(post_row['local_media_path']).convert("RGB")
                processed_image = image_processor(images=image, return_tensors="pt").to(device)
                # Pass through the base vision model only
                outputs = image_model(processed_image['pixel_values'])
                image_emb = outputs.pooler_output.squeeze().cpu()
            except Exception: image_emb = torch.zeros(IMAGE_EMBED_DIM)
        else: image_emb = torch.zeros(IMAGE_EMBED_DIM)

        # Video
        if (video_model is not None and video_processor is not None and
            post_row['post_hint'] == 'hosted:video' and pd.notna(post_row['local_media_path']) and 
            os.path.exists(post_row['local_media_path'])):
            try:
                frames = _sample_frames(post_row['local_media_path'])
                processed_video = video_processor(images=frames, return_tensors="pt").to(device)
                outputs = video_model(**processed_video)
                video_emb = outputs.last_hidden_state.mean(dim=1).squeeze().cpu()
            except Exception: video_emb = torch.zeros(VIDEO_EMBED_DIM)
        else: video_emb = torch.zeros(VIDEO_EMBED_DIM)
            
        # Concat
        if not isinstance(text_emb, torch.Tensor): text_emb = torch.zeros(TEXT_EMBED_DIM)
        if not isinstance(image_emb, torch.Tensor): image_emb = torch.zeros(IMAGE_EMBED_DIM)
        if not isinstance(video_emb, torch.Tensor): video_emb = torch.zeros(VIDEO_EMBED_DIM)
        combined_emb = torch.cat((text_emb.cpu().detach(), image_emb.cpu().detach(), video_emb.cpu().detach()))
        
        # Get true label ID
        true_label_id = label_map.get(post_row['post_sentiment'], -1)
        
        return combined_emb, true_label_id

# --- 3. Load Data and Label Maps ---
try:
    # Load train_df only to get the label map consistently
    if 'label_to_id' not in locals() or not label_to_id: 
         train_df_labels = pd.read_csv(TRAIN_SET_CSV)
         SENTIMENT_LABELS_EVAL = sorted(train_df_labels['post_sentiment'].dropna().unique())
         label_to_id = {label: i for i, label in enumerate(SENTIMENT_LABELS_EVAL)}
         id_to_label = {i: label for i, label in enumerate(SENTIMENT_LABELS_EVAL)}
         NUM_LABELS = len(SENTIMENT_LABELS_EVAL)
         print("(Re)loaded label map for evaluation.")
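    else:
         # Order labels by ID so report and confusion-matrix rows line up with predictions
         SENTIMENT_LABELS_EVAL = sorted(label_to_id, key=label_to_id.get)
         id_to_label = {i: label for label, i in label_to_id.items()}
         NUM_LABELS = len(label_to_id)
         print("Using existing label map for evaluation.")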

    test_df = pd.read_csv(TEST_SET_CSV)
    # Drop rows in test set with missing sentiment labels
    test_df.dropna(subset=['post_sentiment'], inplace=True)
    # Filter test set for labels that are actually in our map
    test_df = test_df[test_df['post_sentiment'].isin(label_to_id.keys())]
    
except FileNotFoundError:
    print(f"Error: Could not find {TRAIN_SET_CSV} or {TEST_SET_CSV}")
    raise
except Exception as e:
     print(f"Error loading data: {e}")
     raise

if test_df.empty:
     raise ValueError("Test dataset is empty after filtering for valid labels. Cannot evaluate.")

print(f"Loaded {len(test_df)} valid test posts.")
print(f"Using {NUM_LABELS} labels for evaluation: {label_to_id}")

# --- 4. Load All Trained Models and Processors ---
print("Loading all models and processors...")
models_loaded = {'text': False, 'image': False, 'video': False, 'fusion': False}

# Text
try:
    text_tokenizer_eval = DistilBertTokenizer.from_pretrained(TEXT_MODEL_PATH)
    text_specialist_eval = DistilBertModel.from_pretrained(TEXT_MODEL_PATH).to(device)
    text_specialist_eval.eval()
    models_loaded['text'] = True
    print("Text model loaded.")
except Exception as e: 
    print(f"Warning: Could not load text model: {e}. Text embeddings will be zeros.")
    text_tokenizer_eval = None
    text_specialist_eval = None

# Image
try:
    image_processor_eval = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Need to instantiate the custom class first
    temp_image_model_eval = CustomCLIPModel("openai/clip-vit-base-patch32", NUM_LABELS) 
    # Load the saved state dict
    temp_image_model_eval.load_state_dict(torch.load(IMAGE_MODEL_PATH, map_location=device))
    # Extract the base vision model
    image_specialist_eval = temp_image_model_eval.clip_vision_model.to(device)
    image_specialist_eval.eval()
    models_loaded['image'] = True
    print("Image model loaded.")
except Exception as e: 
    print(f"Warning: Could not load image model: {e}. Image embeddings will be zeros.")
    image_processor_eval = None
    image_specialist_eval = None

# Video
try:
    video_processor_eval = VideoMAEImageProcessor.from_pretrained(VIDEO_MODEL_PATH)
    video_specialist_eval = VideoMAEModel.from_pretrained(VIDEO_MODEL_PATH).to(device)
    video_specialist_eval.eval()
    models_loaded['video'] = True
    print("Video model loaded.")
except Exception as e: 
    print(f"Warning: Could not load video model: {e}. Video embeddings will be zeros.")
    video_processor_eval = None
    video_specialist_eval = None

# Fusion
try:
    fusion_model_eval = FusionModel(
        input_dim=COMBINED_EMBED_DIM, 
        hidden_dim=512, 
        output_dim=NUM_LABELS
    ).to(device)
    fusion_model_eval.load_state_dict(torch.load(FUSION_MODEL_SAVE_PATH, map_location=device))
    fusion_model_eval.eval()
    models_loaded['fusion'] = True
    print("Fusion model loaded.")
except Exception as e:
    print(f"CRITICAL ERROR: Could not load fusion model: {e}. CANNOT EVALUATE.")
    # Stop execution if fusion model fails
    raise SystemExit("Fusion model loading failed.")

if not any(models_loaded.values()):
    raise SystemExit("CRITICAL ERROR: No models loaded successfully. Cannot evaluate.")
elif not models_loaded['fusion']:
    raise SystemExit("CRITICAL ERROR: Fusion model did not load. Cannot evaluate.")
     
print("Model loading process finished.")

# --- 5. Run Evaluation Loop ---
all_predictions = []
all_true_labels = []

print("\nStarting evaluation on test set...")
with torch.no_grad():
    for index, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Evaluating Test Set"):
        # Get embeddings using the loaded models for evaluation
        combined_emb, true_label_id = get_embeddings_for_eval(
            row, 
            text_specialist_eval, 
            image_specialist_eval, 
            video_specialist_eval, 
            text_tokenizer_eval, 
            image_processor_eval, 
            video_processor_eval,
            label_to_id # Pass the correct label map
        )
        
        # Ensure the true label is valid before adding
        if true_label_id != -1:
            all_true_labels.append(true_label_id)
            
            # Pass embedding through the fusion model
            combined_emb = combined_emb.unsqueeze(0).to(device) # Add batch dim and move to device
            logits = fusion_model_eval(combined_emb)
            prediction_id = torch.argmax(logits, dim=1).item()
            all_predictions.append(prediction_id)
        # else: # Optional: Log skipped rows due to label issues
            # print(f"Skipping row {index} due to invalid true label during evaluation.")

# --- 6. Calculate and Print Metrics ---

if not all_true_labels or not all_predictions:
    print("\nError: No valid predictions or labels were generated. Cannot calculate metrics.")
    print(f"Length of true labels: {len(all_true_labels)}")
    print(f"Length of predictions: {len(all_predictions)}")
else:
    # Convert IDs back to labels for a readable report
    # Use the evaluation label set derived from the training data
    true_labels_names = [id_to_label[idx] for idx in all_true_labels]
    predictions_names = [id_to_label[idx] for idx in all_predictions]
    
    print("\n--- Final Model Performance on Test Set ---")
    accuracy = accuracy_score(true_labels_names, predictions_names)
    print(f"Overall Accuracy: {accuracy * 100:.2f}%")
    
    print("\nClassification Report:")
    # Ensure labels parameter uses the correct set of labels present in the data
    report = classification_report(true_labels_names, predictions_names, labels=SENTIMENT_LABELS_EVAL, zero_division=0)
    print(report)
    
    print("\nConfusion Matrix:")
    # Compute the matrix before the try block so the numeric fallback can always use it
    cm = confusion_matrix(true_labels_names, predictions_names, labels=SENTIMENT_LABELS_EVAL)
    try:
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', xticklabels=SENTIMENT_LABELS_EVAL, yticklabels=SENTIMENT_LABELS_EVAL, cmap='Blues')
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.title('Confusion Matrix on Test Set')
        plt.show()
    except Exception as e:
        print(f"Error generating confusion matrix plot: {e}")
        # Print the matrix numerically as a fallback
        print(pd.DataFrame(cm, index=SENTIMENT_LABELS_EVAL, columns=SENTIMENT_LABELS_EVAL))
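
To make the numbers in the report above easier to interpret, here is a minimal sketch of how accuracy and per-class recall fall out of a confusion matrix. It uses plain Python rather than scikit-learn, and the function names and the example label IDs are made up for illustration; they are not part of the notebook's pipeline.

```python
def confusion_counts(true_ids, pred_ids, num_labels):
    """Count (true, predicted) pairs; rows are true labels, columns are predictions."""
    matrix = [[0] * num_labels for _ in range(num_labels)]
    for true_id, pred_id in zip(true_ids, pred_ids):
        matrix[true_id][pred_id] += 1
    return matrix


def accuracy_from_matrix(matrix):
    """Correct predictions sit on the diagonal of the matrix."""
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return correct / total if total else 0.0


def recall_per_class(matrix):
    """Recall for class i is diagonal / row sum; 0.0 when the class never occurs."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(matrix)]


# Made-up label IDs for three hypothetical classes
example_true = [0, 0, 1, 2, 2, 2]
example_pred = [0, 1, 1, 2, 2, 0]

example_cm = confusion_counts(example_true, example_pred, num_labels=3)
print(example_cm)                        # one row per true label
print(accuracy_from_matrix(example_cm))  # fraction of diagonal entries
print(recall_per_class(example_cm))      # one value per class
```

The row/column orientation matches the heatmap above (true label on the y-axis, predicted label on the x-axis), and the `0.0` fallback plays the same role as `zero_division=0` in the classification report.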

---

## Step 7: Prediction on New, Unseen Data

This is the final step: creating a single function that can take a brand new, raw Reddit post URL, run it through the entire pipeline, and return the predicted sentiment.
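
Before the full implementation, the control flow can be sketched with stand-in stages. This is only an illustrative skeleton: `run_prediction_pipeline` and the lambda stubs are hypothetical names, not the notebook's real functions, and the "classifier" is a toy rule. The point it shows is the ordering (fetch, download, embed, classify) and that media cleanup happens in a `finally` block whether or not a stage fails.

```python
def run_prediction_pipeline(post_url, fetch_post, download_media, embed, classify, remove_file):
    """Hypothetical skeleton: each stage is passed in as a callable."""
    media_path = None
    try:
        post = fetch_post(post_url)          # 1. fetch the post
        media_path = download_media(post)    # 2. may return None for text-only posts
        embedding = embed(post, media_path)  # 3. combined embedding
        return classify(embedding)           # 4. fusion model + label decoding
    finally:
        # 5. always remove the temporary media file, even after an error
        if media_path is not None:
            remove_file(media_path)


# Stub stages so the sketch runs without Reddit access or trained models
removed_files = []
demo_label = run_prediction_pipeline(
    "https://example.invalid/post",
    fetch_post=lambda url: {"title": "Finally got one!", "url": url},
    download_media=lambda post: "/tmp/fake_media.jpg",
    embed=lambda post, path: [float(len(post["title"])), 1.0 if path else 0.0],
    classify=lambda emb: "Joy" if emb[0] > 5 else "Neutral",
    remove_file=removed_files.append,
)
print(demo_label, removed_files)
```

Passing the stages in as callables is just a convenience for the sketch; the real function in the next cell relies on the models and helpers loaded in earlier steps.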

In [None]:
# Ensure PRAW is initialized (from Step 1 or re-init here)
if 'reddit' not in locals():
    try:
        reddit = praw.Reddit(
            client_id=REDDIT_CLIENT_ID,
            client_secret=REDDIT_CLIENT_SECRET,
            user_agent=REDDIT_USER_AGENT,
        )
        print("PRAW re-initialized for prediction.")
    except Exception as e:
        print(f"Error initializing PRAW for prediction: {e}")
        reddit = None # Set to None if init fails

# Ensure models are loaded and in eval mode (should be from Step 6)
# Add checks here to ensure all necessary models (_eval versions) are loaded
required_models = {
    'text': 'text_specialist_eval' in locals() and text_specialist_eval is not None,
    'image': 'image_specialist_eval' in locals() and image_specialist_eval is not None,
    'video': 'video_specialist_eval' in locals() and video_specialist_eval is not None,
    'fusion': 'fusion_model_eval' in locals() and fusion_model_eval is not None
}

required_processors = {
    'text': 'text_tokenizer_eval' in locals() and text_tokenizer_eval is not None,
    'image': 'image_processor_eval' in locals() and image_processor_eval is not None,
    'video': 'video_processor_eval' in locals() and video_processor_eval is not None,
}

models_ready = all(required_models.values()) and all(required_processors.values())

if not models_ready:
    print("\n--- WARNING: Not all models/processors loaded successfully for prediction. --- ")
    print("Missing Text Components:", not (required_models['text'] and required_processors['text']))
    print("Missing Image Components:", not (required_models['image'] and required_processors['image']))
    print("Missing Video Components:", not (required_models['video'] and required_processors['video']))
    print("Missing Fusion Model:", not required_models['fusion'])
    print("Prediction function might fail or produce inaccurate results.")
    # Depending on requirements, you might want to raise an error here
    # raise SystemExit("Cannot proceed with prediction - models not ready.")

# Ensure label map exists
if 'id_to_label' not in locals() or not id_to_label:
    print("Error: id_to_label map not found. Cannot decode predictions.")
    # Attempt to reload from Step 6 variables if they exist
    if 'SENTIMENT_LABELS_EVAL' in locals():
        id_to_label = {i: label for i, label in enumerate(SENTIMENT_LABELS_EVAL)}
        print("Reloaded id_to_label map.")
    else:
        raise ValueError("Label map missing and cannot be reloaded.")

# Use the PRAW download function defined earlier
if 'download_media_praw' not in locals():
    raise NameError("Function 'download_media_praw' not defined. Run the Step 1 PRAW cells.")

# Use the evaluation embedding function defined in Step 6
if 'get_embeddings_for_eval' not in locals():
    raise NameError("Function 'get_embeddings_for_eval' not defined. Run the Step 6 cell.")

def predict_sentiment_for_new_post(post_url):
    if not models_ready or reddit is None:
        return "Error: Prediction environment not fully initialized (models/PRAW missing)."
        
    print(f"\nAnalyzing new post: {post_url}")
    local_media_path = None # Initialize path
    
    try:
        # 1. Scrape the single post using PRAW
        print("Fetching post data via PRAW...")
        try:
            # Extract submission ID from URL for robust fetching
            submission_id = praw.models.Submission.id_from_url(post_url)
            post = reddit.submission(id=submission_id)
            _ = post.title # PRAW fetches lazily; reading an attribute forces the request inside this try block
        except Exception as e:
            return f"Error: Could not fetch post data from PRAW. {e}"

        # 2. Download its media (use the PRAW download function)
        print("Attempting to download media...")
        local_media_path = download_media_praw(post)
        if local_media_path:
            print(f"Media downloaded to: {local_media_path}")
        else:
            print("No downloadable media found or download failed.")

        # 3. Create a dictionary (like a DataFrame row) for the post
        post_data = {
            'id': post.id, 
            'title': post.title, 
            'text': post.selftext, 
            'post_hint': getattr(post, 'post_hint', 'text_only'), 
            'local_media_path': local_media_path, 
            'post_sentiment': 'UNKNOWN' # Placeholder; the label ID returned for it is discarded below
        }

        # 4. Get embeddings using the evaluation function and loaded models
        print("Generating embeddings...")
        with torch.no_grad():
            # Pass the loaded models explicitly to the function
            combined_emb, _ = get_embeddings_for_eval(
                post_data, 
                text_specialist_eval, 
                image_specialist_eval, 
                video_specialist_eval, 
                text_tokenizer_eval, 
                image_processor_eval, 
                video_processor_eval,
                label_to_id # Pass label_to_id (needed internally by func)
            )
            combined_emb = combined_emb.unsqueeze(0).to(device) # Add batch dim & move

        # 5. Get final prediction from fusion model
        print("Making final prediction...")
        with torch.no_grad():
            logits = fusion_model_eval(combined_emb)
            prediction_id = torch.argmax(logits, dim=1).item()
            predicted_sentiment = id_to_label.get(prediction_id, "Unknown Label ID")
        
        print("Prediction complete.")
        return predicted_sentiment
        
    except Exception as e:
        # General error catching during the prediction process
        print(f"An unexpected error occurred during prediction: {e}")
        import traceback
        traceback.print_exc()
        return f"Error: Prediction failed. {e}"

    finally:
        # 6. Clean up the downloaded media file
        if local_media_path and os.path.exists(local_media_path):
            try:
                print(f"Cleaning up downloaded media: {local_media_path}")
                os.remove(local_media_path)
            except OSError as e:
                print(f"Error removing temporary file {local_media_path}: {e}")
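
The function above returns only the top label. If a confidence score is also wanted, one common option is to apply softmax to the fusion model's logits before taking the maximum. The sketch below shows the idea in plain Python on a list of floats; `softmax`, `top_prediction`, and the example label names are illustrative assumptions, and in the notebook itself the equivalent would operate on the `logits` tensor.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities; subtracting the max avoids overflow."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits, id_to_label_map):
    """Return (label, probability) for the highest-scoring class."""
    probs = softmax(logits)
    best_id = max(range(len(probs)), key=probs.__getitem__)
    return id_to_label_map.get(best_id, "Unknown Label ID"), probs[best_id]


# Hypothetical label map and logits, for illustration only
demo_id_to_label = {0: "Anger", 1: "Joy", 2: "Neutral"}
demo_logits = [2.0, 0.5, -1.0]

demo_top_label, demo_confidence = top_prediction(demo_logits, demo_id_to_label)
print(f"{demo_top_label} ({demo_confidence:.2%})")
```

Softmax does not change which class wins, so the predicted label is the same as with `argmax` on the raw logits. The probability is only a rough confidence signal unless the model has been calibrated.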



In [None]:
# --- Example Usage ---

# Ensure the models_ready check passed before calling
if models_ready and reddit is not None:
    # A post that is likely 'Joy' or 'Achievement'
    # Make sure these are valid, accessible Reddit post URLs
    test_url_1 = "https://www.reddit.com/r/Brawlstars/comments/1dbv10k/after_all_this_time_i_finally_got_one/"

    # A post that is likely 'Anger' or 'Rant'
    test_url_2 = "https://www.reddit.com/r/Brawlstars/comments/1ddc6n6/im_so_sick_of_this_game_breaking_bug/"

    # A text-only discussion post
    test_url_3 = "https://www.reddit.com/r/Brawlstars/comments/1e5w8us/is_it_just_me_or_is_ranked_really_easy_rn/"

    sentiment1 = predict_sentiment_for_new_post(test_url_1)
    print(f"\n>>> Prediction for post 1 ({test_url_1}): {sentiment1}\n")

    sentiment2 = predict_sentiment_for_new_post(test_url_2)
    print(f"\n>>> Prediction for post 2 ({test_url_2}): {sentiment2}\n")

    sentiment3 = predict_sentiment_for_new_post(test_url_3)
    print(f"\n>>> Prediction for post 3 ({test_url_3}): {sentiment3}\n")
    
else:
    print("\nSkipping example usage because prediction environment is not ready.")

print("\nNotebook execution complete (up to Step 7 example). Review results and potential warnings.")