In [None]:
import praw
import csv
import re
import json   
import requests
import mysql.connector
import traceback

from datetime import datetime
from dateutil import tz

from tqdm.notebook import tqdm

# Module-level PRAW client; SubredditScraper.set_sort() uses it to reach subreddit listings.
# NOTE(review): constructed with no arguments, so credentials presumably come from
# a praw.ini file or environment variables — confirm the runtime provides them.
reddit = praw.Reddit()

class StockPost:
    """One Reddit post paired with a ticker symbol that it mentions.

    The constructor stores every argument on the instance; ``postURL`` is
    kept under the attribute name ``url``.
    """

    def __init__(self, postID, postURL, ups, downs, numComments, stock, date):
        # Where the post lives and which ticker it was matched against.
        self.postID, self.url, self.stock = postID, postURL, stock
        # Engagement counters as supplied by the caller.
        self.ups, self.downs, self.numComments = ups, downs, numComments
        # Creation time of the post, stored exactly as passed in.
        self.date = date

    def jsonEnc(self):
        """Return a plain dict of this post for JSON encoding.

        The ``date`` attribute is not part of the returned dict.
        """
        return dict(
            stock=self.stock,
            postID=self.postID,
            postURL=self.url,
            ups=self.ups,
            downs=self.downs,
            numComments=self.numComments,
        )

def jsonDefEncoder(obj):
    """Fallback serializer suitable for the ``default=`` hook of ``json.dumps``.

    Objects exposing a ``jsonEnc`` attribute are encoded by calling it;
    anything else is represented by its instance ``__dict__``.
    """
    try:
        encode = obj.jsonEnc
    except AttributeError:
        # No custom encoder on this object: fall back to its attribute dict.
        return obj.__dict__
    return encode()

### DATABASE FUNCTIONS ###

# Returns a connection object.
def connect_to_db(db_name, user='root', password='chalkHorseMountain', host='localhost'):
    """Open a MySQL connection to ``db_name`` and return it.

    Args:
        db_name: name of the database to select on the server.
        user: account name used to authenticate.
        password: password for ``user``.
        host: host name of the MySQL server.

    Returns:
        The connection object produced by ``mysql.connector.connect``.

    The ``user``, ``password`` and ``host`` defaults are the values this
    function previously hard-coded, so existing ``connect_to_db(name)`` calls
    behave exactly as before while other callers can now override them.
    """
    # NOTE(review): a real password is committed here as a default. Prefer
    # passing credentials in from configuration and rotating this one.
    return mysql.connector.connect(
        user=user,
        password=password,
        host=host,
        database=db_name,
    )
    
# Returns a boolean.
def table_exists(cursor, tbl_name):
    """Report whether a table named ``tbl_name`` exists in the current database.

    Args:
        cursor: an open database cursor; its ``execute`` and ``fetchone``
            methods are used.
        tbl_name: name of the table to look for.

    Returns:
        True when ``information_schema.tables`` holds exactly one matching
        row for the database selected on the connection, otherwise False.
    """
    # The table name is bound as a query parameter rather than concatenated
    # into the SQL text, so quotes or other special characters in the name
    # can neither break the statement nor inject SQL.
    cursor.execute(
        """
        SELECT COUNT(*)
        FROM information_schema.tables
        WHERE table_schema = DATABASE()
        AND table_name = %s;
        """,
        (tbl_name,),
    )
    return cursor.fetchone()[0] == 1
    
    
class SubredditScraper:
    """Scrape one subreddit for ticker mentions and store the matches in MySQL.

    Attributes:
        sub: name of the subreddit to read.
        sort: listing to read: 'new', 'top' or 'hot'.
        lim: value passed as ``limit`` to the listing call.
    """

    # CSV file whose first column holds the ticker symbols to search for.
    TICKERS_CSV = './../Tickers/tickers_crypto.csv'

    # Insert-or-update for one (post, ticker) match. Every value is bound as
    # a query parameter instead of being formatted into the SQL text.
    _UPSERT_SQL = """
        INSERT INTO reddit (post_id, symbol, num_comments, num_votes, date_posted)
        VALUES (%s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE num_comments = %s, num_votes = %s, date_posted = %s
    """

    def __init__(self, sub, sort='new', lim=900):
        self.sub = sub
        self.sort = sort
        self.lim = lim

    def set_sort(self):
        """Resolve the configured sort order to a subreddit listing.

        An unrecognized ``self.sort`` is replaced by 'hot' (the attribute is
        updated) and a notice is printed.

        Returns:
            A ``(sort_name, listing)`` tuple, where ``listing`` is whatever
            the matching ``new``/``top``/``hot`` call on the subreddit returns.
        """
        if self.sort not in ('new', 'top', 'hot'):
            self.sort = 'hot'
            print('Sort method was not recognized, defaulting to hot.')
        fetch = getattr(reddit.subreddit(self.sub), self.sort)
        return self.sort, fetch(limit=self.lim)

    def get_posts(self):
        """Get unique posts from a specified subreddit.

        Loads the ticker symbols, scans every non-meme post of the configured
        listing for ticker mentions, and upserts one row per (post, ticker)
        match into the ``reddit`` table of the ``TheSpatula`` database.

        Raises:
            RuntimeError: if the ``reddit`` table does not exist.
        """
        tickers = self._load_tickers(self.TICKERS_CSV)

        # Attempt to specify a sorting method.
        _, listing = self.set_sort()

        print(f'Collecting information from r/{self.sub}.')

        # Materialize the listing so the progress bar knows its total.
        matches = self._find_relevant_posts(list(listing), tickers)
        self._upload_posts(matches)

    @staticmethod
    def _load_tickers(path):
        """Read ticker symbols from the first column of the CSV at ``path``.

        Returns a dict mapping each symbol to an empty dict that will later
        collect the matching posts keyed by post id.
        """
        tickers = {}
        # newline='' is what the csv module expects for files it reads.
        with open(path, mode='r', newline='') as infile:
            for row in csv.reader(infile):
                if not row:
                    # A blank line yields an empty row; indexing it would fail.
                    continue
                symbol = row[0].split(',')[0]
                if not symbol.strip():
                    # An empty symbol would match any run of whitespace.
                    continue
                tickers[symbol] = {}
        return tickers

    @staticmethod
    def _ticker_pattern(stock):
        """Compile the regex that detects ``stock`` inside post text.

        The symbol is escaped so characters such as '.' or '+' are matched
        literally instead of being read as regex syntax.
        """
        # NOTE(review): whitespace is required on both sides, so a ticker at
        # the very start or end of the text is not matched — confirm intent.
        return re.compile(r"\s+\$?" + re.escape(stock) + r"\$?\s+")

    def _find_relevant_posts(self, posts, tickers):
        """Return a StockPost for every (post, ticker) pair that matches.

        ``tickers`` is filled in place (symbol -> {post id -> StockPost}) and
        the result is flattened in ticker order.
        """
        # Compile each pattern once instead of once per post.
        patterns = {stock: self._ticker_pattern(stock) for stock in tickers}

        ## Search posts for tickers ##
        for post in tqdm(posts, desc="[1/2] Scraping Posts"):
            if post.link_flair_text == 'Meme':
                continue
            for stock, pattern in patterns.items():
                try:
                    if pattern.search(post.selftext) or pattern.search(post.title):
                        tickers[stock][post.id] = StockPost(
                            post.id, post.permalink, post.ups, post.downs,
                            post.num_comments, stock, post.created_utc)
                except Exception:
                    # Best effort: report the failing ticker and keep going.
                    # Exception (not a bare except) lets Ctrl-C stop the run.
                    print(f"This Ticker threw an exception: {stock}")
                    traceback.print_exc()

        relevant = []
        for found in tickers.values():
            relevant.extend(found.values())
        return relevant

    @staticmethod
    def _local_date(utc_timestamp):
        """Convert a UTC epoch timestamp to the calendar date in local time."""
        moment = datetime.fromtimestamp(utc_timestamp, tz=tz.tzutc())
        return moment.astimezone(tz.tzlocal()).date()

    def _upload_posts(self, posts):
        """Upsert ``posts`` into the ``reddit`` table, committing per row."""
        ## Upload data to db ##
        cnx = connect_to_db("TheSpatula")
        try:
            cursor = cnx.cursor()
            try:
                # Raise explicitly: an assert would be stripped under -O.
                if not table_exists(cursor, "reddit"):
                    raise RuntimeError(
                        'Table "reddit" does not exist in database "TheSpatula".')

                for post in tqdm(posts, desc="[2/2] Updating Database"):
                    num_votes = post.ups + post.downs
                    date_posted = self._local_date(post.date)

                    ## Add post, if it exists already, update post ##
                    cursor.execute(self._UPSERT_SQL, (
                        post.postID, post.stock, post.numComments,
                        num_votes, date_posted,
                        post.numComments, num_votes, date_posted,
                    ))
                    cnx.commit()
            finally:
                cursor.close()
        finally:
            # Release the connection even when the upload fails part-way.
            cnx.close()
            
            
## get_posts() every subreddit with 10000 post limit ##
def deep_scrape():
    """Scrape every listed subreddit under each sort order.

    For each subreddit, a SubredditScraper with a 10000-post limit is run for
    the 'new', 'hot' and 'top' listings, in that order.
    """
    targets = [
        "CryptoCurrency", "CryptoMoonShots", "CryptoMarkets", "Crypto_com",
        "wallstreetbets", "Wallstreetbetsnew", "stocks",
        "RobinHoodPennyStocks", "pennystocks", "weedstocks", "trakstocks",
        "ausstocks", "shroomstocks", "Canadapennystocks",
    ]
    sort_orders = ('new', 'hot', 'top')

    for name in tqdm(targets, desc="DEEP SCRAPE"):
        for order in sort_orders:
            SubredditScraper(name, sort=order, lim=10000).get_posts()

        
if __name__ == '__main__':
    # Uncomment to scrape every configured subreddit instead of the quick run below.
    #deep_scrape()

    # The quick run used to reference an undefined name `sub`, which raised
    # NameError when this file was executed on its own.
    # NOTE(review): "CryptoCurrency" is an assumed target — set the subreddit
    # you actually want for the 20-post smoke run.
    quick_sub = "CryptoCurrency"
    SubredditScraper(quick_sub, lim=20, sort='hot').get_posts()