In [None]:
import os

os.chdir('../')
%pwd

In [None]:
import ast
import json
import os
import random

import requests
import vertexai
from dotenv import load_dotenv
from google.cloud import storage
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_vertexai import ChatVertexAI
from newsapi import NewsApiClient

# Load environment variables (GCP_PROJECT_ID, NEWS_API_KEY, BUCKET_NAME,
# IMGFLIP_USERNAME, IMGFLIP_PASSWORD are read below) from a .env file.
load_dotenv()

# Initialise the Vertex AI SDK for the configured GCP project.
vertexai.init(
    project=os.getenv("GCP_PROJECT_ID"),
)

# System prompt sent with every meme-generation request. Read as UTF-8
# explicitly so the result does not depend on the platform's default encoding.
with open("data/SYSTEM_PROMPT.md", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()


In [None]:
def fetch_news(query: str):
    """Return the most relevant French-language article whose title matches ``query``.

    Queries NewsAPI's "everything" endpoint through ``NewsApiClient`` using the
    key in the ``NEWS_API_KEY`` environment variable.

    Args:
        query: Text that must appear in the article title (``qintitle``).

    Returns:
        dict containing the article's ``"title"`` and ``"content"`` entries
        (only those present in the API response).

    Raises:
        IndexError: if NewsAPI returns no matching article.
    """
    newsapi = NewsApiClient(api_key=os.environ.get("NEWS_API_KEY"))
    articles = newsapi.get_everything(
        qintitle=query, language="fr", sort_by="relevancy"
    )["articles"]
    if not articles:
        raise IndexError(
            f"NewsAPI returned no French articles with {query!r} in the title"
        )
    # With sort_by="relevancy" the first article is the closest match.
    top_article = articles[0]
    wanted_fields = ("title", "content")
    return {key: value for key, value in top_article.items() if key in wanted_fields}

def fetch_json_from_gcs(bucket_name, file_name):
    """Download an object from Google Cloud Storage and return its raw bytes.

    Despite the name, the payload is not parsed or decoded: callers receive
    the object's bytes exactly as stored (here, a meme template's JSON file).

    Args:
        bucket_name: Name of the GCS bucket.
        file_name: Object name inside the bucket (e.g. ``"json/123.json"``).

    Returns:
        bytes: the object's contents.
    """
    storage_client = storage.Client()
    blob = storage_client.bucket(bucket_name).blob(file_name)
    # download_as_string() is a deprecated alias; download_as_bytes() returns
    # the same bytes.
    return blob.download_as_bytes()

def pick_random_meme(bucket_name, folder, rng=None):
    """Pick a random meme template id from the ``.jpg`` files under ``folder``.

    Template images are expected to be named ``<folder>/<template_id>.jpg``;
    the numeric file stem is returned as an ``int`` (process_query later sends
    it to imgflip as ``template_id``).

    Args:
        bucket_name: GCS bucket holding the template images.
        folder: Object-name prefix to list (e.g. ``"imgs"``).
        rng: Optional ``random.Random`` instance for reproducible picks;
            defaults to the module-level ``random`` functions.

    Returns:
        int: a randomly chosen template id.

    Raises:
        IndexError: if no ``.jpg`` with a numeric name exists under ``folder``.
    """
    chooser = rng if rng is not None else random
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)

    template_ids = []
    for blob in bucket.list_blobs(prefix=folder):
        if not blob.name.endswith(".jpg"):
            continue
        stem = blob.name.split("/")[-1].split(".")[0]
        # Skip images whose name is not a numeric id instead of crashing in int().
        if stem.isdecimal():
            template_ids.append(int(stem))

    if not template_ids:
        raise IndexError(
            f"No numeric .jpg templates found under gs://{bucket_name}/{folder}"
        )
    return chooser.choice(template_ids)

def create_meme(news, template_image, template_json, model="gemini-1.5-pro-001"):
    """Ask a Gemini chat model on Vertex AI to write captions for a meme template.

    The request is the module-level ``SYSTEM_PROMPT`` followed by one human
    message with three parts: the news text, the template image, and the
    template's description.

    Args:
        news: Text describing the news item (process_query passes ``str()`` of
            the dict returned by ``fetch_news``).
        template_image: URL/URI of the template image, sent as an
            ``image_url`` content part (process_query passes a ``gs://`` URI).
        template_json: Description of the template's text boxes. ``bytes``
            (as returned by ``fetch_json_from_gcs``) are decoded as UTF-8.
        model: Vertex AI model name.

    Returns:
        The message returned by ``ChatVertexAI.invoke``; process_query reads
        its ``.content``.
    """
    def _as_text(value):
        # LangChain "text" content parts must be str; GCS downloads arrive as bytes.
        if isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8")
        return value

    llm = ChatVertexAI(model=model)
    messages = [
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(
            content=[
                {"type": "text", "text": _as_text(news)},
                {"type": "image_url", "image_url": {"url": template_image}},
                {"type": "text", "text": _as_text(template_json)},
            ]
        ),
    ]
    return llm.invoke(messages)

def format_boxes(ai_response):
    """Flatten caption boxes into imgflip form-field names.

    Each item of ``ai_response`` must be a mapping with a ``"text"`` key; the
    item at position ``i`` becomes the form field ``boxes[i][text]``.
    """
    return {
        f"boxes[{index}][text]": box["text"]
        for index, box in enumerate(ai_response)
    }

def call_imgflip_api(boxes, meme_id, username, password, timeout=30):
    """Render a captioned meme through imgflip's ``caption_image`` endpoint.

    Args:
        boxes: Caption boxes (mappings with a ``"text"`` key), flattened into
            form fields by ``format_boxes``.
        meme_id: imgflip template id.
        username: imgflip account username.
        password: imgflip account password.
        timeout: Seconds to wait for imgflip before giving up. Without a
            timeout, ``requests`` can block the kernel indefinitely.

    Returns:
        dict: the decoded JSON response. Per the imgflip API docs, failures
        are reported in the body (``"success": false`` plus
        ``"error_message"``), so check ``"success"`` before reading
        ``["data"]["url"]``.

    Raises:
        requests.HTTPError: on a non-2xx HTTP status.
        requests.Timeout: if imgflip does not answer within ``timeout`` seconds.
    """
    payload = {
        **format_boxes(boxes),
        "username": username,
        "password": password,
        "template_id": meme_id,
    }
    response = requests.post(
        "https://api.imgflip.com/caption_image", data=payload, timeout=timeout
    )
    response.raise_for_status()
    return response.json()

In [None]:
def _parse_llm_boxes(text):
    """Parse the model's caption reply as data (never as code).

    Accepts JSON or a Python literal (e.g. single-quoted dicts), optionally
    wrapped in a Markdown code fence such as ```json ... ```.

    Raises:
        ValueError: if the reply is neither valid JSON nor a Python literal.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        # Drop an optional language tag (e.g. "json") on the opening fence line.
        first_line, sep, rest = cleaned.partition("\n")
        if sep and not first_line.strip().startswith(("[", "{")):
            cleaned = rest
        cleaned = cleaned.strip()

    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(cleaned)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as exc:
        raise ValueError(
            f"Could not parse model output as caption boxes: {text!r}"
        ) from exc


def process_query(query):
    """Generate a meme whose captions are based on a news article matching ``query``.

    Pipeline:
      1. Pick a random template id from ``gs://$BUCKET_NAME/imgs``.
      2. Fetch the matching French news article and the template description
         (``json/<id>.json``).
      3. Ask Gemini for captions.
      4. Render the captions through imgflip using the IMGFLIP_USERNAME /
         IMGFLIP_PASSWORD credentials.

    Args:
        query: Text that must appear in the news article's title.

    Returns:
        str: URL of the rendered meme (``data.url`` in imgflip's response).

    Raises:
        ValueError: if the model reply cannot be parsed as JSON / a Python literal.
        RuntimeError: if imgflip reports that captioning failed.
    """
    bucket = os.environ.get("BUCKET_NAME")
    meme_id = pick_random_meme(bucket, "imgs")
    json_path = f"json/{meme_id}.json"
    image_path = f"gs://{bucket}/imgs/{meme_id}.jpg"

    news = str(fetch_news(query))
    template_json = fetch_json_from_gcs(bucket, json_path)

    # Never eval() model output: the reply is shaped by third-party news text,
    # so a prompt-injected article could otherwise run arbitrary code here.
    boxes = _parse_llm_boxes(create_meme(news, image_path, template_json).content)

    username = os.environ.get("IMGFLIP_USERNAME")
    password = os.environ.get("IMGFLIP_PASSWORD")

    imgflip_response = call_imgflip_api(boxes, meme_id, username, password)
    # imgflip signals errors in the body; surface its message rather than a KeyError.
    if not imgflip_response.get("success") or "data" not in imgflip_response:
        reason = imgflip_response.get("error_message", imgflip_response)
        raise RuntimeError(f"imgflip caption_image failed: {reason}")

    return imgflip_response["data"]["url"]

# Demo: caption a random template using the top French article with "covid"
# in its title; the returned meme URL is displayed as the cell output.
process_query(query="covid")