In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# default_exp downloaderlabeler
# all_slow

In [None]:
import ipywidgets as widgets
from ipywidgets import Button, Layout
import pickle
import json
import webbrowser
from steroidsornot.submissionshandler import SubmissionsHandler
from steroidsornot.database import Session
from steroidsornot.submission import Submission
from steroidsornot.firsttry import PrawClient, get_thumbnail, natural_path, steroids_path, uncertain_path, irrelevant_path

In [None]:
from fastcore.test import *

In [None]:
class DownloaderLabeler(object):
    """Interactive console tool that labels unlabeled submissions one by one.

    For every unlabeled submission the tool opens the post in the web browser,
    downloads its thumbnail and asks the user for a label. The thumbnail is
    written to the directory belonging to the chosen label and the label is
    committed to the database.

    Constructing an instance immediately starts the interactive loop; the
    constructor only returns once the loop has ended, at which point the
    session has been closed.
    """

    # Number of unlabeled submissions fetched from the database per batch.
    BATCH_SIZE = 100

    def __init__(self, session):
        """Report how much work is left, set up the clients and start labeling.

        Args:
            session: database session used to query and update `Submission`
                rows. It is closed when the labeling loop ends.
        """
        self.session = session
        self.unlabeled_submissions = self._unlabeled_query()
        left = self.unlabeled_submissions.count()
        print('-----------------------------------')
        print(f'{left} MORE TO GO')
        print('-----------------------------------')
        self.reddit = PrawClient().reddit()
        self.submissions_handler = SubmissionsHandler()

        self._main_loop()

    def _unlabeled_query(self):
        """Return a query selecting every submission that has no label yet."""
        return self.session.query(Submission).filter_by(label=None)

    def _print_summary(self, count):
        """Print how many submissions were labeled and how many are left."""
        left = self._unlabeled_query().count()
        print('-----------------------------------')
        print(f'SUCCESSFULLY LABELED {count} IMAGES')
        print(f'{left} MORE TO GO')
        print('-----------------------------------')

    def _main_loop(self):
        """Label submissions batch by batch until the user quits or work runs out.

        The loop ends when the user enters (q)uit, when no unlabeled
        submission can be fetched any more, or when a whole batch went by
        without a single submission getting a label. The session is closed in
        every case, including when an exception escapes.
        """
        count = 0
        stop_requested = False
        try:
            while not stop_requested:
                unlabeled_batch = self._get_unlabeled_batch()
                if not unlabeled_batch:
                    # Without this guard an empty batch would spin forever.
                    print('NO MORE SUBMISSIONS AVAILABLE TO LABEL')
                    break

                labeled_in_batch = 0
                for pushshift_submission, reddit_submission in unlabeled_batch:
                    if not self.submissions_handler.is_useful(vars(reddit_submission)):
                        # Submission isn't useful, label it as so and skip it.
                        print(f'AUTOREMOVED https://reddit.com{reddit_submission.permalink}')
                        pushshift_submission.label = 'autoremoved'
                        self.session.add(pushshift_submission)
                        self.session.commit()
                        labeled_in_batch += 1
                        continue

                    data = json.loads(pushshift_submission.data)
                    print(f'GETTING {data["full_link"]}')
                    webbrowser.open(data['full_link'])

                    response, filename = get_thumbnail(reddit_submission)
                    if response.status_code != 200:
                        # The submission stays unlabeled and will show up
                        # again in a later batch.
                        print(f'COULD NOT DOWNLOAD THUMBNAIL (HTTP {response.status_code}), SKIPPING')
                        continue

                    if self._label(reddit_submission, pushshift_submission, response, filename):
                        stop_requested = True
                        break
                    labeled_in_batch += 1

                count += labeled_in_batch
                if not stop_requested and labeled_in_batch == 0:
                    # Every submission of this batch was skipped, so the next
                    # batch would contain the very same rows: stop instead of
                    # re-processing them endlessly.
                    print('NO SUBMISSION IN THIS BATCH COULD BE LABELED, STOPPING')
                    break

            self._print_summary(count)
        finally:
            self.session.close()

    def _get_unlabeled_batch(self):
        """Fetch the next unlabeled submissions together with their Reddit data.

        Returns:
            A list of `(pushshift_submission, reddit_submission)` pairs. The
            pairs are matched on the submission id rather than on position,
            because `reddit.info()` may return fewer items than were asked for
            (PRAW documents that unmatched fullnames are not generated), which
            would shift every following pair if the two sequences were zipped.
            Submissions Reddit returns nothing for are left out of the batch.
        """
        # Materialize once so the query is not executed again later on.
        pushshift_submissions = self._unlabeled_query().limit(self.BATCH_SIZE).all()

        pushshift_by_id = {}
        for pushshift_submission in pushshift_submissions:
            data = json.loads(pushshift_submission.data)
            reddit_id = self.reddit.submission(url=data['full_link']).id
            pushshift_by_id[reddit_id] = pushshift_submission

        if not pushshift_by_id:
            return []

        fullnames = ['t3_' + reddit_id for reddit_id in pushshift_by_id]

        submissions_batch = []
        for reddit_submission in self.reddit.info(fullnames=fullnames):
            pushshift_submission = pushshift_by_id.get(reddit_submission.id)
            if pushshift_submission is not None:
                submissions_batch.append((pushshift_submission, reddit_submission))

        missing = len(pushshift_by_id) - len(submissions_batch)
        if missing:
            print(f'{missing} SUBMISSIONS WERE NOT RETURNED BY REDDIT AND WERE SKIPPED')

        return submissions_batch

    def _label(self, reddit_submission, pushshift_submission, response, filename):
        """Ask the user for a label and store it, saving the image when relevant.

        Args:
            reddit_submission: Reddit submission; its `title` is shown when
                the submission is marked as deleted.
            pushshift_submission: database row receiving the label.
            response: thumbnail download response; its `content` is written
                to disk.
            filename: file name of the image inside the label's directory.

        Returns:
            True when the user asked to quit (nothing is stored in that case),
            False otherwise.
        """
        letter = ''
        # Keep asking until one of the supported letters is entered; the
        # answer is accepted regardless of case or surrounding whitespace.
        while letter not in ('n', 's', 'u', 'i', 'd', 'q'):
            print('Which label? (n)atural, (s)teroids, (u)ncertain, (i)rrelevant or (d)eleted. Or (q)uit')
            letter = input().strip().lower()

        if letter == 'q':
            return True

        if letter == 'd':
            # Deleted submissions get a label but no image is saved.
            print(f'MARKING "{reddit_submission.title}" AS DELETED')
            pushshift_submission.label = 'deleted'
            self.session.add(pushshift_submission)
            self.session.commit()
            return False

        label, directory = {
            'n': ('natural', natural_path),
            's': ('steroids', steroids_path),
            'u': ('uncertain', uncertain_path),
            'i': ('irrelevant', irrelevant_path),
        }[letter]
        full_path = directory / filename

        # Write and close the file first so the database never points at an
        # image that has not been flushed to disk.
        with open(full_path, 'wb') as f:
            f.write(response.content)

        pushshift_submission.image_path = str(full_path.absolute())
        pushshift_submission.label = label
        self.session.add(pushshift_submission)
        self.session.commit()
        print(f'LABELED AS {label}')
        print(f'SAVED IMAGE TO {full_path}')
        return False
        

In [None]:
# This is a monkey patch for something I don't want to debug
def label_first_unlabeled_deleted():
    """Mark the first unlabeled submission returned by the database as 'deleted'.

    Manual workaround for a submission the labeler cannot get past. Does
    nothing besides printing a notice when no unlabeled submission exists.
    """
    session = Session()
    try:
        # .first() returns None on an empty result instead of raising.
        post = session.query(Submission).filter_by(label=None).first()
        if post is None:
            print('NO UNLABELED SUBMISSION FOUND')
            return
        post.label = 'deleted'
        session.add(post)
        session.commit()
    finally:
        session.close()

In [None]:
def label_count(label, session):
    """Return how many submissions carry `label`, queried through `session`.

    Passing `None` as `label` counts the submissions that are still unlabeled.
    """
    matching_submissions = session.query(Submission).filter_by(label=label)
    return matching_submissions.count()

In [None]:
def statistics():
    """Print a labeling progress report.

    The report contains the number of submissions per label, the number of
    unlabeled submissions, the percentage already labeled and an extrapolation
    of how many natural / steroids pictures to expect once everything is
    labeled (current share among labeled submissions times the total).
    """
    session = Session()
    try:
        steroids = label_count('steroids', session)
        natural = label_count('natural', session)
        autoremoved = label_count('autoremoved', session)
        uncertain = label_count('uncertain', session)
        irrelevant = label_count('irrelevant', session)
        deleted = label_count('deleted', session)
        unlabeled = label_count(None, session)
        total = session.query(Submission).count()
    finally:
        session.close()

    labeled = total - unlabeled

    # Guard the ratios: nothing may be labeled yet, or the table may be empty.
    ratio_natural = natural / labeled if labeled else 0.0
    ratio_steroids = steroids / labeled if labeled else 0.0

    expected_total_natural_pics = ratio_natural * total
    expected_total_steroid_pics = ratio_steroids * total

    # Integer arithmetic keeps the floored percentage exact.
    percent_done = labeled * 100 // total if total else 0

    print(f'''
    labeled: {labeled}
    
    steroids: {steroids}
    natural: {natural}
    autoremoved: {autoremoved}
    uncertain: {uncertain}
    irrelevant: {irrelevant}
    deleted: {deleted}
    
    unlabeled: {unlabeled}
    total: {total}
    percent done: {percent_done}%
    
    Expected total natural pics: {int(expected_total_natural_pics)}
    Expected total steroids pics: {int(expected_total_steroid_pics)}
    ''')
    


In [None]:
# Print the current labeling progress report.
statistics()

In [None]:
# Constructing the labeler starts the interactive labeling loop right away
# (it prompts through input()), so this cell blocks until that loop ends.
downloader_labeler = DownloaderLabeler(Session())