# Utility for creating a testing dataset by downloading memes and OCRing them with Google Vision API

In [2]:
# Custom API for a Reddit meme feed; per the author it returns roughly the 2.5k newest memes from Reddit.
# Do NOT pass any params or a query string.
# Use only the GET request method.

MEME_API_URL="https://meme-feed-api.vercel.app/api/getRedditMemes"

In [3]:
import os
import requests
import json
import time
import urllib.request
import pytesseract

# create the memes dir if it doesn't exist
MEME_FOLDER = "memes"
# exist_ok=True makes this a no-op when the folder is already there and
# avoids the race between a separate "exists?" check and the creation
os.makedirs(MEME_FOLDER, exist_ok=True)

# Example API response:
# [{
# title	"Every Single Time"
# author	"Abschori"
# createdAt	1663673918
# fetchedAt	1663689431254
# contentUrl	"https://i.redd.it/qkz1jrr630p91.gif"
# id	"xj6fj8"
# likes	6026
# nsfw	false
# postLink	"https://www.reddit.com/r…j6fj8/every_single_time/"
# provider	"r"
# subreddit	"dankmemes"
# },...]

# Downloads OCRable (text based) memes to the meme folder until it holds `limit` images
def download_memes(limit=1000):
    """Fetch the meme feed and download text-based memes into MEME_FOLDER.

    NSFW entries and GIFs are skipped, the rest is processed from most to
    least liked, and each downloaded image is kept only when check_text()
    finds text in it.

    Args:
        limit: Maximum number of images to hold in MEME_FOLDER (default 1000).
            Images already in the folder count towards the limit; ``.txt``
            files (OCR results stored next to the images) do not.

    Returns:
        Always 1 (kept so existing callers keep working).

    Raises:
        Whatever ``requests`` raises when the feed cannot be fetched or the
        API answers with an HTTP error status, and ``ValueError`` when the
        response body is not valid JSON.
    """
    # local import so this block does not depend on other cells' imports
    from urllib.parse import urlparse

    def _extension(url):
        # Only look at the URL path: a "?query" or "#fragment" must never end
        # up in the file name, and an extension-less URL yields "" instead of
        # a chunk of the host name.
        return os.path.splitext(urlparse(url).path)[1]

    # Get memes from API; the timeout keeps the cell from hanging forever
    response = requests.get(MEME_API_URL, timeout=30)
    response.raise_for_status()
    memes = json.loads(response.text)

    # Filter out NSFW memes and .gifs (case-insensitive, query string ignored)
    candidates = [
        meme
        for meme in memes
        if not meme["nsfw"] and _extension(meme["contentUrl"]).lower() != ".gif"
    ]

    # sort by likes first, so we discard the bad ones
    candidates.sort(key=lambda meme: meme["likes"], reverse=True)

    # Count the images once up front instead of listing the folder on every
    # iteration; .txt files are OCR results, not memes.
    kept = sum(1 for name in os.listdir(MEME_FOLDER) if not name.endswith(".txt"))

    # Download the memes and discard the ones that are not text based
    for meme in candidates:
        # stop as soon as the folder holds `limit` images
        if kept >= limit:
            break

        image_url = meme["contentUrl"]
        image_path = f"{MEME_FOLDER}/{meme['id']}{_extension(image_url)}"

        # check if file already exists
        if os.path.exists(image_path):
            continue

        try:
            # download the image
            urllib.request.urlretrieve(image_url, image_path)
            print(f"Downloaded to {image_path} {kept + 1}/{limit}")
            # check if it contains text
            if check_text(image_path):
                kept += 1
            else:
                print(f"No text in {image_path} -> removing")
                os.remove(image_path)
        except Exception as error:
            # Best effort: report and move on to the next meme. `Exception`
            # (not a bare except) lets Ctrl+C still interrupt the download.
            print(f"Failed to download {meme['id']}: {error}")
            # A partial or unverified file would otherwise be treated as an
            # already-downloaded meme on the next run.
            if os.path.exists(image_path):
                os.remove(image_path)

    return 1


# Checks if an image contains text
# This is used to filter out memes that are not text based, so there are fewer wasted API calls to the Google Vision API
# Uses Tesseract, as it seems to be faster than EasyOCR
def check_text(image_path, min_chars=5):
    """Return True when Tesseract finds more than ``min_chars`` characters.

    Args:
        image_path: Path of the image, passed to pytesseract.image_to_string.
        min_chars: Threshold the number of recognised non-whitespace
            characters must exceed (default 5).

    Returns:
        True if the image is considered text based, False otherwise.
    """
    text = pytesseract.image_to_string(image_path)

    # Count only non-whitespace characters: OCR output that is just padding
    # (spaces, newlines, a form feed) must not make an image look text based.
    visible_chars = len("".join(text.split()))

    return visible_chars > min_chars


# Download memes unless the folder already holds enough images.
# .txt files are OCR results saved next to the images, so they are not counted.
image_count = sum(1 for name in os.listdir(MEME_FOLDER) if not name.endswith(".txt"))
if image_count < 1000:
    print("Downloading memes...")
    download_memes()
else:
    print("Memes are already downloaded")

Memes are already downloaded


Now that the memes are downloaded, let's OCR them with Google Vision API to finalize the testing dataset.

In [4]:
# setup logger
# @see https://stackoverflow.com/questions/6386698/how-to-write-to-a-file-using-the-logging-python-module
import logging
# %(msecs)03d zero-pads the milliseconds so e.g. 5 ms is logged as ",005"
# instead of the ambiguous ",5"
logging.basicConfig(filename="datasetMaker.log",
                    filemode='a',
                    format='%(asctime)s,%(msecs)03d %(name)s %(levelname)s %(message)s',
                    datefmt='%H:%M:%S',
                    level=logging.DEBUG)


In [5]:
import io
from google.cloud import vision

# load client
# Points GOOGLE_APPLICATION_CREDENTIALS at the key file "google_key.json"
# (relative to the working directory), overwriting any value already set.
# NOTE(review): presumably the Google client library reads this variable when
# the client below is constructed - confirm against the google-cloud-vision docs.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "google_key.json"
# Module-level client used by detect_text()
client = vision.ImageAnnotatorClient()

# Parts of this function code are from https://cloud.google.com/vision/docs/ocr#vision_text_detection-python
def detect_text(path):
    """Run text detection on one image file and return the detected text.

    Args:
        path: Path of the image file to send to the module-level ``client``.

    Returns:
        The description of the first text annotation, or an empty string
        when the response contains no text annotations.

    Raises:
        Exception: If the response carries an error message.
        OSError: If the image file cannot be read.
    """
    logging.info(f"Starting detection for {path}")
    with open(path, 'rb') as image_file:
        content = image_file.read()

    # fetch
    image = vision.Image(content=content)
    response = client.text_detection(image=image)

    logging.info(f"response for {path}: {response}")

    if response.error.message:
        logging.error(response.error.message)
        raise Exception(f"{response.error.message} for {path}. Stopping script")

    # An image without recognisable text yields no annotations; indexing [0]
    # would raise IndexError and abort the whole run.
    annotations = response.text_annotations
    if not annotations:
        logging.warning(f"No text detected in {path}")
        return ""

    # NOTE(review): the first annotation is taken as the full text, as in the
    # linked Google sample - confirm against the Vision API docs.
    return annotations[0].description

# OCR all memes
for meme in os.listdir(MEME_FOLDER):
    memePath = f"{MEME_FOLDER}/{meme}"

    # skip OCR results and anything that is not a regular file (e.g. a
    # sub-directory, which detect_text could not open)
    if meme.endswith(".txt") or not os.path.isfile(memePath):
        continue

    # each meme gets its text result saved to a .txt file in the same folder
    textPath = f"{MEME_FOLDER}/{meme.split('.')[0]}.txt"
    if not os.path.exists(textPath):
        try:
            text = detect_text(memePath)
            # explicit UTF-8: OCR text may contain characters the platform's
            # default encoding cannot represent
            with open(textPath, "w", encoding="utf-8") as f:
                f.write(text)
            print(f"Saved text to {textPath} with length {len(text)}")
        except Exception as e:
            print(f"Failed to OCR {meme}")
            print(e)
            # stop the run on the first failure; bare raise keeps the traceback
            raise
