# Import Required Libraries

In [None]:
%pip install rdflib
%pip install pymongo
%pip install google-generativeai
%pip install timm einops

In [6]:
import logging
import os
from io import BytesIO
from typing import List

import google.generativeai as genai
import requests
import torch
from PIL import Image
from pymongo import MongoClient
from rdflib import Graph, Namespace
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM

# Define Constants

In [7]:
# Constants
RDF_FILE_PATH = "metanettypes.ttl"  # RDF Turtle file path. Downloaded from https://github.com/alammehwish/AmnesticForgery/blob/master/metanettypes.ttl
# Google Generative AI API key. It is read from the GOOGLE_API_KEY environment variable
# instead of being hard-coded, so no secret ends up in the notebook or in version control.
# An empty string means the variable is not set.
API_KEY = os.environ.get("GOOGLE_API_KEY", "")

# Initialize the Model and Processor

We initialize the Florence-2 model (`microsoft/Florence-2-large`) and its processor from the `transformers` library. The model generates text from an image, and the processor prepares the model inputs and post-processes the generated output.

In [None]:
# Function to initialize the model and processor
def initialize_model_and_processor(device: str, model_name: str = "microsoft/Florence-2-large") -> tuple:
    """
    Load a vision-to-text model (Florence-2 by default) and its processor onto a device.

    Args:
        device (str): The device to run the model on (e.g. "cuda:0" or "cpu").
        model_name (str): Hugging Face model id to load. Defaults to "microsoft/Florence-2-large".

    Returns:
        tuple: (model, processor) - The initialized model and processor.
    """
    # Half precision on any CUDA device (not only "cuda:0") to reduce GPU memory use;
    # CPU inference stays in float32.
    torch_dtype = torch.float16 if str(device).startswith("cuda") else torch.float32

    # trust_remote_code=True allows loading model/processor classes defined in the
    # checkpoint repository itself.
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype, trust_remote_code=True).to(device)
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    return model, processor

# Use the first CUDA GPU when one is present, otherwise fall back to the CPU,
# then load the captioning model and processor onto the chosen device.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
model, processor = initialize_model_and_processor(device)

# Generate Meme Description

We use the model and processor to generate a text description of a meme image from a task prompt (for example `<MORE_DETAILED_CAPTION>`). This description of the meme's contents is the input to the frame-matching step below.

In [9]:
# Function to generate description based on image and prompt
def generate_description(model, processor, image: "Image.Image", prompt: str, device: str) -> str:
    """
    Generate a description of an image for the given task prompt.

    Args:
        model (AutoModelForCausalLM): The pre-trained model.
        processor (AutoProcessor): The processor that builds model inputs and post-processes outputs.
        image (PIL.Image.Image): The input image.
        prompt (str): The task prompt (e.g. "<MORE_DETAILED_CAPTION>"). It is also the key under
            which ``processor.post_process_generation`` returns the parsed answer.
        device (str): The device the model runs on ("cuda:0" or "cpu").

    Returns:
        str: The generated description.
    """
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    # Moving to the device alone does not change dtype. Cast the pixel values to the model's
    # dtype so a float16 model (as loaded on CUDA by initialize_model_and_processor) receives
    # matching inputs instead of failing with an input/weight dtype mismatch.
    pixel_values = inputs["pixel_values"].to(device, model.dtype)

    # Deterministic beam search (no sampling) for reproducible captions.
    generated_ids = model.generate(
        input_ids=input_ids,
        pixel_values=pixel_values,
        max_new_tokens=4096,
        num_beams=3,
        do_sample=False,
    )

    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=prompt, image_size=(image.width, image.height)
    )

    return parsed_answer[prompt]

# Configure Google Generative AI

In [10]:
# Function to configure Google Generative AI API
def configure_google_ai(api_key: str) -> None:
    """
    Configure the Google Generative AI API with the provided API key.

    Args:
        api_key (str): The API key for Google Generative AI.

    Raises:
        ValueError: If ``api_key`` is empty, ``None`` or only whitespace. Failing here is clearer
            than an authentication error on the first model call.
    """
    if not api_key or not api_key.strip():
        raise ValueError("A non-empty Google Generative AI API key is required.")
    genai.configure(api_key=api_key)

# Authenticate the Gemini client once; later cells reuse this configuration.
configure_google_ai(api_key=API_KEY)

# Load RDF Frames

Next, we load the available frames from an RDF (Turtle) file. The file defines MetaNet frames (subjects typed `metanet:Frame`), which we later match against the generated meme description.

In [11]:
# Function to load frames from an RDF file
def load_frames(rdf_file: str) -> List[str]:
    """
    Load frame names from an RDF Turtle file.

    Args:
        rdf_file (str): Path to the RDF Turtle file.

    Returns:
        List[str]: For every subject typed
        ``https://w3id.org/framester/metanet/schema/Frame``, the last "/"-separated
        segment of its IRI.
    """
    from rdflib import RDF  # rdflib's built-in rdf: namespace (rdf:type lives here)

    graph = Graph()
    graph.parse(rdf_file, format="turtle")

    metanet = Namespace("https://w3id.org/framester/metanet/schema/")

    return [
        str(frame).split("/")[-1]
        for frame in graph.subjects(predicate=RDF.type, object=metanet.Frame)
    ]

# Read the frame names once; every meme processed below is matched against this list.
frames = load_frames(rdf_file=RDF_FILE_PATH)

# Query the Generative AI Model for an Explanation, Fitted Frames, and Their Justifications

Using the description generated from the image and the list of available frames, we query the generative model for an explanation of the meme, the frames that best fit it, and a justification for each chosen frame.

In [12]:
def _parse_frames_response(response_text: str):
    """Split an LLM reply into its explanation and its per-frame justifications.

    Args:
        response_text (str): The stripped text returned by the LLM.

    Returns:
        tuple | None: ``(explanation, fitted_frames)``, where ``fitted_frames`` maps each frame
        name to the justification written after it. Returns ``None`` when the reply does not
        contain the ``'Fitted Frames:'`` marker.
    """
    if "Fitted Frames:" not in response_text:
        return None

    # Split on the FIRST marker only, so bullet lines after a repeated marker are not dropped.
    explanation, _, frames_section = response_text.partition("Fitted Frames:")
    # A bolded marker ("**Fitted Frames:**") would otherwise leave "**" at the end of the explanation.
    explanation = explanation.strip().rstrip("*").rstrip()

    fitted_frames = {}
    for raw_line in frames_section.strip().split("\n"):
        # Drop indentation and markdown bold so "- **Frame**: reason" parses like "- Frame: reason".
        line = raw_line.strip().replace("**", "")
        # Only bullet lines ("- Frame Name: Explanation", or with a "* " bullet) carry a
        # justification. The comma-separated frame list line is skipped.
        if not line.startswith(("- ", "* ")):
            continue
        frame, separator, reason = line[2:].partition(": ")
        if separator:
            fitted_frames[frame.strip()] = reason.strip()
    return explanation, fitted_frames


def query_model_with_explanation_and_fitted_frames(meme_description: str, frames: List[str], meme_context: str = None, meme_lang: str = "EN") -> dict:
    """
    Ask Gemini to explain a meme and to pick the available frames that fit it.

    If the first reply lacks the ``'Fitted Frames:'`` marker, the request is sent once more with
    explicit formatting instructions. If that reply also lacks it, the whole reply becomes the
    explanation and no frames are returned.

    Args:
        meme_description (str): The meme description (e.g. an image caption).
        frames (List[str]): The available frame names; only these can appear in the result.
        meme_context (str): Optional context about the meme; when given, it is included in the prompt.
        meme_lang (str): Language of the meme text. "ES" and "FR" add a note asking the model to
            translate the text itself; any other value adds nothing. Defaults to "EN".

    Returns:
        dict: ``{"explanation": str, "fitted_frames": dict}``, where ``fitted_frames`` maps each
        chosen frame name (restricted to ``frames``) to the model's justification for it.

    Note:
        Errors raised by the Gemini client are not caught here and propagate to the caller.
    """

    frames_text = ", ".join(frames)

    if meme_context is None:
        prompt = (
            f"Here is a description for a specific meme: '{meme_description}'\n\n"
            f"The following frames are available: {frames_text}.\n"
            "Based on the previous description, provide a detailed explanation of the meme.\n"
            "On the next line, explicitly write: 'Fitted Frames:' followed by the available frames in which this meme fits, separated by commas.\n"
            "Then, for EACH fitted frame, provide a separate explanation of WHY that specific frame was chosen, using the following format:\n"
            "- Frame Name: Explanation for this frame\n"
        )
    else:
        prompt = (
            f"Here is the context for a specific meme: '{meme_context}'\n\n"
            f"Here is the meme description: '{meme_description}'\n\n"
            f"The following frames are available: {frames_text}.\n"
            "Based on the previous context and description, provide a detailed explanation of the meme.\n"
            "On the next line, explicitly write: 'Fitted Frames:' followed by the available frames in which this meme fits, separated by commas.\n"
            "Then, for EACH fitted frame, provide a separate explanation of WHY that specific frame was chosen, using the following format:\n"
            "- Frame Name: Explanation for this frame\n"
        )

    # Non-English memes: warn the model that the caption's translation may be unreliable.
    language_names = {"ES": "Spanish", "FR": "French"}
    if meme_lang in language_names:
        prompt += (
            "\n\n"
            f"Please note that the meme was written in {language_names[meme_lang]}, so the translation provided in the meme description may not always be accurate. In that case, ignore it and translate it by your own.\n"
        )

    large_language_model = genai.GenerativeModel("gemini-1.5-flash")
    generation_config = genai.types.GenerationConfig(temperature=1.2)

    response_text = large_language_model.generate_content(
        prompt,
        generation_config=generation_config,
    ).text.strip()
    parsed = _parse_frames_response(response_text)

    if parsed is None:
        # The reply ignored the requested format: ask once more, restating the format and
        # repeating the original prompt (the first reply itself is not sent back).
        follow_up_prompt = (
            "The response did not follow the expected format. Please reformat it as follows:\n"
            "1. Provide a detailed explanation of the meme.\n"
            "2. Explicitly write: 'Fitted Frames:' followed by the available frames in which this meme fits, separated by commas.\n"
            "3. For EACH fitted frame, provide an explanation using this format:\n"
            "- Frame Name: Explanation for this frame\n"
            "Here is the original prompt for reference:\n"
            f"{prompt}"
        )
        response_text = large_language_model.generate_content(
            follow_up_prompt,
            generation_config=generation_config,
        ).text.strip()
        parsed = _parse_frames_response(response_text)

    if parsed is None:
        # Still unstructured: keep the whole reply as the explanation, with no frames.
        explanation, fitted_frames = response_text, {}
    else:
        explanation, fitted_frames = parsed

    # The LLM sometimes invents frames that are not in the list, so keep only known ones.
    known_frames = set(frames)
    filtered_frames = {
        frame: reason for frame, reason in fitted_frames.items() if frame in known_frames
    }

    return {"explanation": explanation, "fitted_frames": filtered_frames}

# Generate Match Vector

In [13]:
# Function to generate a match vector between the frames and the fitted frames
def generate_match_vector(frames: List[str], fitted_frames: List[str]) -> List[int]:
    """
    Generate a binary vector marking which of ``frames`` appear in ``fitted_frames``.

    Args:
        frames (List[str]): The list of available frames; fixes the order of the vector.
        fitted_frames (List[str]): The frames that fit the meme. Accepted forms: frame names;
            a dict keyed by frame name (as returned in ``"fitted_frames"`` by the query step);
            or a list of ``{"name": ..., "reasoning": ...}`` dicts (the stored format).

    Returns:
        List[int]: ``1`` at position i if ``frames[i]`` is fitted, else ``0``.
    """
    # Build the lookup set once so each membership test is O(1) instead of a list scan.
    fitted_names = {
        item.get("name") if isinstance(item, dict) else item
        for item in fitted_frames
    }
    return [1 if frame in fitted_names else 0 for frame in frames]

# Connect to MongoDB

In [14]:
# Connect to MongoDB.
# The connection string embeds database credentials, so it is read from the MONGODB_URI
# environment variable instead of being hard-coded (and committed) in the notebook.
connection_string = os.environ.get("MONGODB_URI")
if not connection_string:
    raise RuntimeError("The MONGODB_URI environment variable must be set to a MongoDB connection string.")
client = MongoClient(connection_string)
db = client['memes']
imgflip_collection = db['imgflip']
kym_collection = db['kym']
reddit_spanish_collection = db['reddit-spanish-memes']
reddit_french_collection = db['reddit-french-memes']

# Main code

After running all of the cells above, run this section to analyze the memes in each MongoDB collection.

In [None]:
# Function to fetch images from URLs
def fetch_images(origin, collection, timeout=30):
    """Yield ``(document, image)`` pairs for every meme in ``collection`` whose image can be loaded.

    The image URL is read from ``image_url`` for "imgflip"/"kym" documents and from ``url`` for
    "reddit-spanish-memes"/"reddit-french-memes" documents. Documents without a URL are skipped.
    Images that cannot be downloaded or decoded are skipped with a logged warning.

    Args:
        origin (str): Name of the source collection; selects which field holds the URL.
        collection: The pymongo collection to iterate.
        timeout (float): Seconds to wait for each HTTP request before giving up. Defaults to 30.

    Yields:
        tuple: ``(doc, image)``, where ``image`` is a fully decoded RGB ``PIL.Image.Image``.

    Raises:
        ValueError: If ``origin`` is not a supported collection name (raised when iteration starts).
    """
    url_fields = {
        "imgflip": "image_url",
        "kym": "image_url",
        "reddit-spanish-memes": "url",
        "reddit-french-memes": "url",
    }
    if origin not in url_fields:
        raise ValueError(f"Unknown meme origin: {origin!r}")
    url_field = url_fields[origin]

    for doc in collection.find():
        image_url = doc.get(url_field)
        if not image_url:
            continue

        try:
            # Without a timeout a stalled server would block the whole pipeline forever.
            response = requests.get(image_url, timeout=timeout)
            response.raise_for_status()
            # Image.open is lazy; convert() forces decoding here (inside the try) and gives the
            # captioning model a consistent 3-channel RGB image.
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as exc:
            # Best-effort: one broken image must not stop the batch, but record why it was skipped.
            logging.getLogger(__name__).warning("Failed to fetch image from %s: %s", image_url, exc)
            continue

        yield doc, image

# Function to process images and add attributes
def add_attributes(origin, collection, model, processor, device):
    """Generate and store a description, explanation and fitted frames for each meme.

    For every document whose image can be fetched (via ``fetch_images``), this:
    - captions the image with ``generate_description``;
    - asks Gemini for an explanation and the fitted frames;
    - writes ``gen_description``, ``gen_explanation`` and ``gen_fitted_frames`` back to the document.

    Documents that already have all three fields with truthy values are skipped. An empty
    ``gen_fitted_frames`` list counts as missing, so such documents are processed again.

    Processing is best-effort: a failure on one document is logged and the loop moves on.

    Args:
        origin (str): Source collection name: "imgflip", "kym", "reddit-spanish-memes" or
            "reddit-french-memes". It selects how the LLM is queried.
        collection: The pymongo collection to read from and update.
        model: The captioning model passed to ``generate_description``.
        processor: The matching processor passed to ``generate_description``.
        device (str): The device the model runs on.

    Raises:
        ValueError: If ``origin`` is not one of the supported names.
    """
    # Extra keyword arguments for query_model_with_explanation_and_fitted_frames, per origin.
    query_kwargs_by_origin = {
        "imgflip": lambda doc: {},
        "kym": lambda doc: {"meme_context": doc.get("description")},
        "reddit-spanish-memes": lambda doc: {"meme_lang": "ES"},
        "reddit-french-memes": lambda doc: {"meme_lang": "FR"},
    }
    if origin not in query_kwargs_by_origin:
        raise ValueError(f"Unknown meme origin: {origin!r}")
    build_query_kwargs = query_kwargs_by_origin[origin]

    task_prompt = "<MORE_DETAILED_CAPTION>"
    logger = logging.getLogger(__name__)

    # The progress total counts every document, including ones fetch_images skips
    # (no URL or failed download), so the bar may finish below 100%.
    total_docs = collection.count_documents({})
    for doc, image in tqdm(fetch_images(origin, collection), desc="Assigning attributes", total=total_docs):
        # Skip documents that already have the generated attributes
        if doc.get("gen_description") and doc.get("gen_explanation") and doc.get("gen_fitted_frames"):
            continue

        try:
            meme_description = generate_description(model, processor, image, task_prompt, device)
            result = query_model_with_explanation_and_fitted_frames(
                meme_description, frames, **build_query_kwargs(doc)
            )

            # Store fitted frames as a list of {"name", "reasoning"} dicts rather than a dict.
            fitted_frames = [
                {"name": frame, "reasoning": reasoning}
                for frame, reasoning in result["fitted_frames"].items()
            ]

            collection.update_one(
                {"_id": doc["_id"]},  # Match document by ID
                {"$set": {
                    "gen_description": meme_description,
                    "gen_explanation": result["explanation"],
                    "gen_fitted_frames": fitted_frames,
                }}
            )
        except Exception as exc:
            # Best-effort batch: record the failure instead of silently dropping it.
            logger.warning("Failed to process and update document with ID %s: %s", doc.get("_id"), exc)

# Run the full analysis over every meme collection, one after another.
print("Starting the meme analysis pipeline...")
print("Loaded Frames:", len(frames))

pipeline_sources = [
    ("imgflip", imgflip_collection),
    ("kym", kym_collection),
    ("reddit-spanish-memes", reddit_spanish_collection),
    ("reddit-french-memes", reddit_french_collection),
]
for source_name, source_collection in pipeline_sources:
    add_attributes(source_name, source_collection, model, processor, device)