In [1]:
#Imports
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2,preprocess_input
from tensorflow.keras.layers import Input,GlobalMaxPooling2D,Dense
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import img_to_array,load_img
import json
import numpy as np
import cv2
from cv2 import resize
from os import path, listdir
import praw,requests,re
import time
import psaw
import datetime as dt
import os
import sys

In [3]:
#Functions
def compare2images(original,duplicate):
    """
    Decide whether two images should be treated as identical.

    Both arguments are image arrays such as those returned by ``cv2.imread``,
    which yields None when a file cannot be decoded.

    :param original: Reference image array, or None.
    :param duplicate: Image array to compare against ``original``, or None.
    :return: True if either image is None (so empty/unreadable downloads are
             flagged for deletion), or if both arrays have the same shape and
             every element is equal; False otherwise.
    """
    if original is None or duplicate is None:
        return True  # delete empty / unreadable pictures
    if original.shape != duplicate.shape:
        return False
    # Exact element-wise equality. A cv2.subtract-based check saturates at 0 for uint8
    # data, so it would only prove original <= duplicate pixel-wise (an all-black image
    # would "match" anything of the same size). array_equal also works for any channel
    # count, unlike unpacking cv2.split(...) into exactly b, g, r.
    return bool(np.array_equal(original, duplicate))

def submissions_pushshift_praw(subreddit, start=None, end=None, limit=20000, extra_query="", reddit_instance=None):
    """
    Return a list of PRAW submission objects posted to a subreddit during a given period.

    Pushshift is queried for matching submissions, and each returned ID is turned into a
    PRAW submission object via ``reddit.submission(id=...)``.
    This function serves as a replacement for the now deprecated PRAW `submissions()` method.

    :param subreddit: A subreddit name to fetch submissions from.
    :param start: A Unix time integer. Posts fetched will be AFTER this time. (default: None -> no lower bound)
    :param end: A Unix time integer. Posts fetched will be BEFORE this time. (default: None -> now)
    :param limit: Number of results requested from Pushshift (default: 20000). A limit needs to be
                  sent explicitly, or Pushshift will return only 25.
    :param extra_query: Optional search string sent as Pushshift's ``q`` parameter. If empty, the
                        function grabs everything from the defined time period. (default: empty string)
    :param reddit_instance: The ``praw.Reddit`` client used to build submission objects.
                            (default: None -> the notebook-level ``reddit`` client)
    :return: A list of PRAW submissions in the order Pushshift returned them. The request uses
             ``sort_type=score&sort=asc``, i.e. ascending by score (not newest first).
    :raises requests.HTTPError: If Pushshift responds with an HTTP error status.

    For more information on PRAW, see: https://github.com/praw-dev/praw
    For more information on Pushshift, see: https://github.com/pushshift/api
    """
    # Only touch the global client when no explicit one was supplied.
    client = reddit if reddit_instance is None else reddit_instance

    # Default time values if none are defined (credit to u/bboe's PRAW `submissions()` for this section)
    # NOTE(review): an 8-hour (28800 s) shift is added to both bounds; its purpose is not
    # documented here - confirm it is still wanted.
    utc_offset = 28800
    now = int(time.time())
    start = max(int(start) + utc_offset if start else 0, 0)
    end = min(int(end) if end else now, now) + utc_offset

    # Let requests build and URL-encode the query string so an extra_query containing
    # spaces, '&' or '#' cannot corrupt the other parameters.
    params = {
        'subreddit': subreddit,
        'after': start,
        'before': end,
        'sort_type': 'score',
        'sort': 'asc',
        'limit': limit,
        'q': extra_query,
    }
    response = requests.get('https://api.pushshift.io/reddit/submission/search/', params=params)
    # Fail with a clear HTTP error instead of a KeyError/JSON error on an error page.
    response.raise_for_status()
    returned_submissions = response.json()['data']

    # Convert each Pushshift record into a PRAW submission object by its ID.
    return [client.submission(id=submission['id']) for submission in returned_submissions]

In [4]:
# Read the Reddit API credentials stored under the "keys" entry of config.json.
with open('config.json') as config_file:
    config = json.load(config_file)['keys']

# Sign into Reddit using API Key. Each credential field in the config has the same name
# as the corresponding praw.Reddit keyword argument, so they are forwarded by name.
reddit = praw.Reddit(
    user_agent="Downloading images from r/art for a machine learning project",
    **{field: config[field] for field in ('client_id', 'client_secret', 'username', 'password')}
)


In [11]:
#187mb for 200 pics, approx 18.7gb for 20000
#Relatively arbitrary start date, representative of modern times
Jan12018 = int(dt.datetime(2018,1,1).timestamp())
#Pass a PRAW instance so that scores are accurate
api = psaw.PushshiftAPI(reddit)
n = 30000
# Directory the images are saved into; created up front so the first write cannot fail.
PICS_DIR = "pics/"
# Seconds to wait per image request. requests.get has no timeout by default, so without
# this a single stalled host would hang the loop and the Timeout handling would never fire.
REQUEST_TIMEOUT = 30
# Fixed width of the text progress bar, independent of how many posts are processed.
BAR_WIDTH = 50

os.makedirs(PICS_DIR, exist_ok=True)
print("Looking for posts using Pushshift...")
posts = list(api.search_submissions(after = Jan12018, subreddit='art', limit = n*10))
print(f"Number of posts found:  {len(posts)}")
files=[]
counter = 0
total = len(posts)
# Some downloaded images are "deleted" placeholders; they are kept here and removed in a
# later pass by comparing against a template image.
for post in posts:
    counter += 1
    # Measure progress against the posts actually returned (not n) so it ends at 100%.
    fraction = counter / total
    sys.stdout.write('\r')
    sys.stdout.write("Downloading: [{:{}}] {:.1f}%".format("=" * int(BAR_WIDTH * fraction), BAR_WIDTH, 100 * fraction))
    sys.stdout.flush()
    url = post.url
    #Save score for ML training, and post id for unique file names
    file_name = str(post.score) + "_" + str(post.id) + ".jpg"
    try:
        #use requests to get image
        r = requests.get(url, timeout=REQUEST_TIMEOUT)
        # Skip error pages (404, 5xx, ...) instead of saving their body as a .jpg.
        r.raise_for_status()
        fullfilename = os.path.join(PICS_DIR, file_name)
        #save image
        with open(fullfilename, "wb") as f:
            f.write(r.content)
        # Record the file only after it has actually been written.
        files.append(file_name)
    # RequestException is the base class of ConnectionError, Timeout, HTTPError, InvalidURL,
    # MissingSchema, etc., so one bad URL is reported and skipped instead of aborting the run.
    except requests.exceptions.RequestException as e:
        print(f"\n{url}: {e}")

#Number of files downloaded not always the same due to connection errors
print(f'\nNumber of files downloaded: {len(files)}')

Looking for posts using Pushshift...
Number of posts found:  30000
Number of files downloaded: 29995


In [None]:
# NOTE(review): this rebinds `path`, which the imports cell bound to os.path via
# `from os import path`. The name is kept for any later cell that reads it; use os.path
# explicitly for path operations.
path = "pics/"
# Example of the image saved for a deleted post; downloads identical to it are removed.
DELETED_TEMPLATE = "exampledeleted.jpg"
# Fixed width of the text progress bar, independent of how many files are scanned.
BAR_WIDTH = 50

files = [f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))]
print(f"Original Number of files: {len(files)}")

# Load the template once, outside the loop. If it cannot be read, cv2.imread returns None
# and compare2images would report every file as a match, deleting the whole dataset,
# so fail loudly instead.
deletedtemplate = cv2.imread(DELETED_TEMPLATE)
if deletedtemplate is None:
    raise FileNotFoundError(f"Could not read template image {DELETED_TEMPLATE!r}")

#Check if image is just a blank image by comparing to template
cull = []
counter = 0
length = len(files)
for file in files:
    counter += 1
    # Progress against the full file count; safe for a single file (no length-1 divisor).
    fraction = counter / length
    sys.stdout.write('\r')
    sys.stdout.write("Scanning: [{:{}}] {:.1f}%".format("=" * int(BAR_WIDTH * fraction), BAR_WIDTH, 100 * fraction))
    sys.stdout.flush()
    fullfilename = os.path.join(path, file)
    # cv2.imread returns None for files it cannot decode; compare2images treats None as a
    # match, so unreadable downloads are removed along with deleted-image placeholders.
    if compare2images(deletedtemplate, cv2.imread(fullfilename)):
        os.remove(fullfilename)
        cull.append(file)

# Drop the removed names in a single pass (calling files.remove per name is O(len(files)) each).
culled = set(cull)
files = [f for f in files if f not in culled]
print(f"\nRemoved {len(cull)} files")
print(f"Final Number of files: {len(files)}")
