<a href="https://colab.research.google.com/github/klutzydrummer/Python_Projects/blob/main/streamlined_article_aquisition_module.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; config_base (defined further down) points
# at /content/drive/MyDrive/... for its config JSON files when this mount exists.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#!/usr/bin/env python3
"""
Stock-news article acquisition.

Searches Google News for stock-related articles about a ticker on sampled
dates, downloads and parses them with newspaper3k, filters them by keyword
overlap, and bulk-inserts the surviving articles into a MySQL-compatible
database through pymysql.
"""
import os
import shutil
import subprocess
import sys
import pip
from pathlib import Path
import warnings
warnings.filterwarnings("ignore", message="Setuptools is replacing distutils.")

pyrequirements_path = Path("requirements.txt")

if not pyrequirements_path.exists():
    # Third-party packages this notebook imports. The PyPI "asyncio" package is an
    # obsolete backport of what is now the standard-library asyncio, so it is not listed.
    pyrequirements_path.write_text(
        "\n".join([
            "numpy",
            "nltk",
            "aiomysql",
            "pymysql",
            "aiohttp",
            "joblib",
            "pandas",
            "pandas-datareader",
            "yfinance",
            "pyyaml >= 6.0.1",
            "tqdm >= 4.66.0",
            "googlenews >= 1.6.8",
            "newspaper3k >= 0.2.8",
            "pendulum >= 2.1.2",
            "psycopg2-binary >= 2.9.7",
        ]) + "\n",
        encoding="utf-8",
    )

    # pip.main() is not a supported public API; pip's documentation recommends
    # running pip as a subprocess of the current interpreter instead.
    install = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-r", str(pyrequirements_path)]
    )
    if install.returncode != 0:
        # Best effort, as before: report the failure instead of aborting the cell.
        warnings.warn(
            f"pip install -r {pyrequirements_path} exited with status {install.returncode}"
        )

import aiohttp
import asyncio
import dataclasses
import datetime
import functools
import hashlib
import itertools
import json
import logging
import math
import os
import shutil
import string
import sys
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path as path
from typing import Callable, List, Optional, Generator, Awaitable

import joblib
import nltk
from GoogleNews import GoogleNews
from newspaper import Article, Source, Config
import pandas as pd
import pendulum
from tqdm.asyncio import tqdm_asyncio
from tqdm.autonotebook import tqdm
from nltk.corpus import stopwords
from nltk.stem.wordnet import WordNetLemmatizer
from nltk.tokenize import word_tokenize
import pymysql

import yfinance
import random
import pandas_datareader as pdr


from dataclasses import asdict, fields, make_dataclass
from typing import Type

# Directory holding the connection JSON files: the persistent Drive folder when
# Google Drive is mounted, otherwise the Colab working directory.
if path("/content/drive").exists():
    config_base = "/content/drive/MyDrive/Machine_Learning_Digestor/config"
else:
    config_base = "/content"


def generate_random_dates(start_date, end_date, num_dates):
    """Return ``num_dates`` distinct random dates from the inclusive range [start_date, end_date].

    :param start_date: first date of the range; a datetime/date, or a "MM/DD/YYYY" string.
    :param end_date: last date of the range (inclusive); same accepted forms as start_date.
    :param num_dates: how many distinct dates to draw. ``0`` means "every date in the
        range", returned in random order.
    :return: list of dates (``start_date`` plus a whole number of days), in random order.
        An empty list when ``end_date`` precedes ``start_date`` and ``num_dates`` is 0.
    :raises ValueError: if ``num_dates`` is negative or exceeds the number of days available.
    """
    if isinstance(start_date, str):
        start_date = datetime.datetime.strptime(start_date, "%m/%d/%Y")
    if isinstance(end_date, str):
        end_date = datetime.datetime.strptime(end_date, "%m/%d/%Y")

    # Number of whole days in the inclusive range; clamp so a reversed range is empty.
    total_days = max((end_date - start_date).days + 1, 0)

    if num_dates < 0:
        raise ValueError("num_dates must be non-negative")
    if num_dates > total_days:
        raise ValueError("num_dates is greater than the number of available days")
    if num_dates == 0:
        num_dates = total_days

    # random.sample draws distinct offsets in one pass (no rejection loop) and
    # already returns them in random order.
    offsets = random.sample(range(total_days), num_dates)
    return [start_date + datetime.timedelta(days=offset) for offset in offsets]

def verbprint(*args, verbose=False, **kwargs):
    """Forward everything to ``print`` only when ``verbose`` is exactly ``True``."""
    if verbose is not True:
        return
    print(*args, **kwargs)

def ignore_first_arg(func: Callable[[str], str], _: object, arg: str) -> str:
    """Apply ``func`` to ``arg``; the middle positional argument is accepted and discarded."""
    transformed = func(arg)
    return transformed

@functools.lru_cache(maxsize=1)
def _english_stopwords() -> frozenset:
    """Return NLTK's English stopword list as a set, loaded once and then reused."""
    return frozenset(stopwords.words('english'))


@functools.lru_cache(maxsize=1)
def _wordnet_lemmatizer() -> "WordNetLemmatizer":
    """Return one shared WordNetLemmatizer instead of constructing one per call."""
    return WordNetLemmatizer()


def preprocess(text: str) -> str:
    """Normalise *text*: tokenize, lower-case, drop English stopwords, lemmatize.

    The surviving tokens are re-joined with single spaces. Needs the NLTK
    tokenizer, "stopwords" and "wordnet" data (downloaded by prereq()).
    """
    # Both resources are loop-invariant across calls, so they are cached once
    # rather than re-read / re-built on every invocation.
    stop_words = _english_stopwords()
    lemmatizer = _wordnet_lemmatizer()

    lowered = (token.lower() for token in word_tokenize(text))
    return " ".join(
        lemmatizer.lemmatize(token) for token in lowered if token not in stop_words
    )



def custom_time_to_str(datetime: datetime.datetime) -> str:
    """Render a date-like object as ``MM/DD/YYYY`` (month and day zero-padded, year as-is)."""
    month, day, year = (int(part) for part in (datetime.month, datetime.day, datetime.year))
    return "{:02}/{:02}/{}".format(month, day, year)

def chunks(lst, n) -> Generator:
    """Yield consecutive slices of ``lst`` holding ``n`` items each; the last may be shorter."""
    yield from (lst[offset:offset + n] for offset in range(0, len(lst), n))


@dataclasses.dataclass(init=False, repr=True, eq=True)
class PreppedArticles:
    """A newspaper3k article reduced to the fields this module persists.

    ``article_id`` is derived deterministically from ``stock_symbol`` and
    ``title`` (see ``custom_uuid``), so the same headline found for the same
    ticker always maps to the same id.
    """

    title: str
    text: str
    summary: str
    authors: list[str]
    publish_date: datetime.datetime  # naive local time when built via import_newspaper
    search_date: datetime.datetime   # the date whose search produced this article
    tags: set[str]
    keywords: set[str]
    stock_symbol: str
    search_string: str  # query string that found the article
    url: str
    article_id: uuid.UUID  # MD5-derived id of "<stock_symbol> <title>"

    @staticmethod
    def remove_non_ascii(a_str):
        """Return ``a_str`` with every character not in ``string.printable`` removed."""
        allowed = frozenset(string.printable)
        return ''.join(ch for ch in a_str if ch in allowed)

    @classmethod
    def str_to_uuid(cls, in_string: str) -> uuid.UUID:
        """Map ``in_string`` to a deterministic UUID built from its MD5 digest.

        Characters outside ``string.printable`` are dropped first so the text can
        be ASCII-encoded. The raw 128-bit digest is used as-is (no RFC 4122
        version bits are set), so this differs from ``uuid.uuid3``.
        """
        filtered_string = cls.remove_non_ascii(in_string)
        digest = hashlib.md5(filtered_string.encode("ascii")).digest()
        return uuid.UUID(bytes=digest)

    @classmethod
    def custom_uuid(cls, stock_symbol, title):
        """Return the article id for a (stock_symbol, title) pair."""
        return cls.str_to_uuid(f"{stock_symbol} {title}")

    def __init__(self, title: str, text: str, summary: str, authors: list[str], publish_date: datetime.datetime, search_date: datetime.datetime, tags: set[str], keywords: set[str], stock_symbol: str, search_string: str, url: str):
        self.title = title
        self.text = text
        self.summary = summary
        self.authors = authors
        self.publish_date = publish_date
        self.search_date = search_date
        self.tags = tags
        self.keywords = keywords
        self.stock_symbol = stock_symbol
        self.search_string = search_string
        # Computed rather than passed in, so the id always matches symbol + title.
        self.article_id = self.custom_uuid(stock_symbol=self.stock_symbol, title=self.title)
        self.url = url

    @classmethod
    def import_newspaper(cls, newspaper_article: "Article", stock_symbol: str, search_string: str, search_date: datetime.datetime) -> 'PreppedArticles':
        """Build an instance from a newspaper3k ``Article``.

        The pipeline calls ``parse()`` and ``nlp()`` on the article first, which
        is what populates ``summary`` and ``keywords``.

        :raises ValueError: if ``publish_date`` is not exactly a ``datetime.datetime``.
        """
        if type(newspaper_article.publish_date) is not datetime.datetime:
            raise ValueError(f"Publish date is not of type datetime:\n type:\n {type(newspaper_article.publish_date)}\n value:\n {newspaper_article.publish_date}")
        return cls(
            title=newspaper_article.title,
            text=newspaper_article.text,
            summary=newspaper_article.summary,
            authors=newspaper_article.authors,
            # Round-tripping through a POSIX timestamp turns an aware datetime into
            # naive local time (a naive one is read as local and comes back unchanged).
            publish_date=datetime.datetime.fromtimestamp(newspaper_article.publish_date.timestamp()),
            search_date=search_date,
            tags=set(newspaper_article.tags),
            keywords=set(newspaper_article.keywords),
            stock_symbol=stock_symbol,
            search_string=search_string,
            # canonical_link may be empty (presumably when the page declares none);
            # fall back to the fetched URL so the stored row still has a link.
            url=newspaper_article.canonical_link or newspaper_article.url,
        )

    def astuple(self) -> tuple:
        """Return the fields as a tuple, in declaration order."""
        return dataclasses.astuple(self)

    def asdict(self) -> dict:
        """Return the fields as a dict keyed by field name."""
        return dataclasses.asdict(self)


def search_google(search_string: str) -> list[str]:
    """Run one Google News query via the GoogleNews client and return its result links."""
    client = GoogleNews(lang="en", encode="utf-8")
    # Start from an empty result buffer; enableException(True) presumably makes
    # the client raise on scraping failures rather than only reporting them.
    client.clear()
    client.enableException(True)
    client.get_news(search_string)
    links = client.get_links()
    return links



class DatabaseHandler:
    """Small pymysql wrapper that bulk-inserts prepared articles into the ``articles`` table."""

    def __init__(self, conn_details):
        """Store connection settings.

        :param conn_details: mapping with ``user``, ``password``, ``host`` and
            ``database`` keys (read by ``connect``).
        """
        self.conn_details = conn_details
        # NOTE(review): loop, pool and _dataclass_cache are not used by the methods
        # below; they are kept only for compatibility with code that may read them.
        self.loop = asyncio.get_event_loop()
        self.pool = ThreadPoolExecutor(max_workers=5)  # Adjust max_workers based on your needs
        self._dataclass_cache = {}

    def connect(self):
        """Open a new pymysql connection, passing ``ssl={'require': True}``."""
        return pymysql.connect(
            user=self.conn_details['user'],
            password=self.conn_details['password'],
            host=self.conn_details['host'],
            database=self.conn_details['database'],
            ssl={'require': True}
        )

    def upload_prepped_articles(self, articles: List["PreppedArticles"]):
        """Insert *articles* into ``articles`` in a single transaction.

        Uses ``INSERT IGNORE``, so MySQL skips rows that would hit ignorable errors
        such as duplicate keys. List/set fields are stored as JSON arrays. If the
        insert or commit fails, the transaction is rolled back and the original
        exception propagates; the cursor and connection are closed in every case.
        An empty batch returns immediately without opening a connection.
        """
        if not articles:
            return

        query = """
        INSERT IGNORE INTO articles (
            article_id, authors, keywords, publish_date, search_date,
            search_string, stock_symbol, summary, tags, text, title, url
        ) VALUES (
            %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
        );
        """

        # Build every row before connecting so serialization errors never leave
        # an open connection behind.
        data_to_insert = [
            (
                str(article.article_id), json.dumps(article.authors), json.dumps(list(article.keywords)),
                article.publish_date, article.search_date, article.search_string,
                article.stock_symbol, article.summary, json.dumps(list(article.tags)),
                article.text, article.title, article.url
            )
            for article in articles
        ]

        connection = self.connect()
        cursor = None
        try:
            cursor = connection.cursor()
            cursor.executemany(query, data_to_insert)
            connection.commit()
        except Exception:
            connection.rollback()
            raise
        finally:
            if cursor is not None:
                cursor.close()
            connection.close()



def share_elements(set1, set2, n):
    """Return True when at least ``n`` items are common to both collections.

    Items are compared as lower-cased strings, so ``"Stock"`` matches ``"stock"``.
    """
    def normalise(items):
        return {str(item).lower() for item in items}

    overlap = normalise(set1) & normalise(set2)
    return len(overlap) >= n

async def validate_article(article: Article, mandatory_keywords: list[str], total_match_keywords: int, verbose=False) -> Article | None:
    """Return *article* if it passes every quality check, otherwise ``None``.

    Checks run in this order:
    1. ``publish_date`` is exactly a ``datetime.datetime``.
    2. The page is not a bot-check or "403 Client Error" page.
    3. The text is non-empty.
    4. At least ``total_match_keywords`` of the article's keywords appear
       (case-insensitively) in ``mandatory_keywords``.

    Any other exception raised while checking also counts as a rejection. When
    ``verbose`` is ``True`` the rejection reason is printed.
    """
    def rejection_reason():
        # First failed check's message, or None when the article is acceptable.
        if type(article.publish_date) is not datetime.datetime:
            return "Publish_date not of type datetime."
        if "Are you a robot?" in article.title or "403 Client Error" in article.text:
            return "Crawling not permitted."
        if article.text == "":
            return "Article text empty."
        if share_elements(article.keywords, mandatory_keywords, total_match_keywords) is not True:
            return "Article did not match enough mandatory keywords."
        return None

    try:
        problem = rejection_reason()
    except Exception as err:
        problem = err

    if problem is None:
        return article
    if verbose is True:
        print(problem)
    return None

def prereq():
    """Make sure the NLTK data packages this module relies on are available.

    Each resource is looked up with ``nltk.data.find``, which searches every
    directory on ``nltk.data.path``. It is downloaded only when missing, so a
    partially populated data directory gets completed instead of skipped.
    """
    required = {
        "stopwords": "corpora/stopwords",
        "wordnet": "corpora/wordnet",
        "punkt": "tokenizers/punkt",
        # NOTE: newer NLTK releases (3.8.2+) load the Punkt model used by
        # word_tokenize from "punkt_tab" rather than the pickled "punkt".
        "punkt_tab": "tokenizers/punkt_tab",
    }
    for package, resource in required.items():
        try:
            nltk.data.find(resource)
        except LookupError:
            nltk.download(package)

async def pipeline(stock_symbol: str, search_date_start: datetime.datetime, search_date_end: datetime.datetime, num_dates: int, postgres_connection_string: str, mandatory_keywords: list[str], total_match_keywords: int, logfile: path | None = None, loglevel: str = "", verbose: bool = False, dry_run: bool = False):
    """
    Asynchronous pipeline function to process stock-related data and upload the processed articles to a postgres database.

    :param stock_symbol: The stock symbol representing a specific stock, given as a string.
    :param search_date: The date for which to search stock-related data, given as a datetime object.
    :param llama_instance: An instance of the LlamaCpp class responsible for processing or analyzing data.
    :param embeddings_instance: An instance of the HuggingFaceInstructEmbeddings class for generating or managing embeddings.
    :param postgres_connection_string: The connection string for connecting to a PostgreSQL database, given as a string.
    :param logfile: Optional path to a file where log information will be stored. Defaults to None, indicating no logging to a file.
    :param loglevel: Optional string specifying the logging level, such as "INFO", "WARNING", "ERROR", etc. Defaults to an empty string, indicating default logging behavior.

    :return: None.
    """

    if logfile is None:
        logfile = path('/content/article_aquisition_log.log')
    logfile = path(logfile)
    logfile.parent.mkdir(parents=True, exist_ok=True)

    log_message = f"Loglevel set to {loglevel}"
    match loglevel:
        case 'debug':
            loglevel_type = logging.DEBUG
        case 'info':
            loglevel_type = logging.INFO
        case 'warn':
            log_message = "logging.WARN is deprecated, setting Logging level to logging.WARNING."
            loglevel_type = logging.WARNING
        case 'warning':
            loglevel_type = logging.WARNING
        case 'error':
            loglevel_type = logging.ERROR
        case 'critical':
            loglevel_type = logging.CRITICAL
        case 'fatal':
            loglevel_type = logging.FATAL
        case _:
            loglevel_type = logging.DEBUG
            log_message = f"You have typed an invalid loglevel: {loglevel}\n    Logging level set to DEBUG."
    if verbose is True:
        loglevel_type = logging.DEBUG
        log_message = "LLM verbosity set to True.\n    Logging level set to debug."

    logging.debug("logging level set")
    logging.basicConfig(filename=str(logfile), encoding='utf-8', level=loglevel_type)
    logging.getLogger().setLevel(loglevel_type)
    logging.info(log_message)

    config = Config()
    config.MAX_THREADS_PER_SOURCE = 10
    google_news_url = 'https://news.google.com'

    news_list = []
    generated_dates = generate_random_dates(search_date_start, search_date_end, num_dates)
    for search_date_obj in generated_dates:
        search_date_str = custom_time_to_str(search_date_obj)
        search_string = f"stock news {stock_symbol} {search_date_str}"
        sub_news_list = search_google(search_string)
        news_list.extend(sub_news_list)
        logging.debug(f"Search created for: {search_date_str}")
        logging.debug("All searches completed.")



        source = Source(google_news_url, config=config)
        logging.debug(f"Aquired {len(sub_news_list)} urls.")

        for url in sub_news_list:
            source.articles.append(Article(f"https://{url}"))

        source.download_articles()
        try:
            source.parse_articles()
            logging.debug("Bulk parse completed")
        except:
            for article in source.articles:
                try:
                    article.parse()
                except:
                    continue
            logging.debug("Indvidual articles parsed due to error in bulk parsing.")
        tasks = []
        for article in source.articles:
            if article.is_parsed is True:
                article.nlp()
                tasks.append(asyncio.create_task(validate_article(article=article, mandatory_keywords=mandatory_keywords, total_match_keywords=total_match_keywords)))
        logging.debug("Performed nlp on all articles.")
        valid_articles = [*filter(None, await asyncio.gather(*tasks))]
        logging.debug(f"Validated {len(valid_articles)} urls.")
        prepped_articles = [PreppedArticles.import_newspaper(newspaper_article=article, stock_symbol=stock_symbol, search_string=search_string, search_date=search_date_obj) for article in valid_articles]
        logging.debug("All valid articles prepared for upload to db.")


        db_handler = DatabaseHandler(ps_conn_details)


        try:
            db_handler.upload_prepped_articles(prepped_articles)
            logging.debug("Uploaded articles.")
        except Exception as err:
            logging.debug("Encountered Error. I'm not usre these errors are being logged by logging.ERROR")
            logging.error(err)
            pass
    logging.info("Processing pipeline completed.")

if __name__ == '__main__':
    stock_symbol = "MSFT"
    # start_date = datetime.datetime(day=1, month=1, year=2022)
    end_date = datetime.datetime.now()
    start_date = end_date -  datetime.timedelta(days=5)
    end_date = '5/24/2020'
    start_date = '1/2/1962'
    # search_date = "08/13/2023"
    mandatory_keywords = [
        "Stock",
        "Stocks",
        "Stock Market",
        "Shares",
        "Dividend",
        "Portfolio",
        "Investment",
        "Trading",
        "Exchange",
        "Securities",
        "Bull Market",
        "Bear Market",
        "IPO",
        "Equity",
        "Bonds",
        "Index",
        "Financial Market"
    ]
    total_match_keywords = 2

    # search_date_obj = datetime.datetime.strptime(search_date, "%m/%d/%Y")

    logfile = path('/content/article_aquisition_log.log')
    loglevel = 'info'

    with open(f"{config_base}/connection_string.json", "r") as f:
        postgres_connection_string = json.load(f)

    with open(f"{config_base}/ps_conn_details_preprocesser.json", "r") as f:
        ps_conn_details = json.load(f)
        del ps_conn_details["ssl"]

    # loglevel = 'debug'

    #stock_symbol_list = ["MMM","AOS","ABT","ABBV","ACN","ATVI","ADM","ADBE","ADP","AAP","AES","AFL","A","APD","AKAM","ALK","ALB","ARE","ALGN","ALLE","LNT","ALL","GOOGL","GOOG","MO","AMZN","AMCR","AMD","AEE","AAL","AEP","AXP","AIG","AMT","AWK","AMP","ABC","AME","AMGN","APH","ADI","ANSS","AON","APA","AAPL","AMAT","APTV","ACGL","ANET","AJG","AIZ","T","ATO","ADSK","AZO","AVB","AVY","AXON","BKR","BALL","BAC","BBWI","BAX","BDX","WRB","BRK.B","BBY","BIO","TECH","BIIB","BLK","BK","BA","BKNG","BWA","BXP","BSX","BMY","AVGO","BR","BRO","BF.B","BG","CHRW","CDNS","CZR","CPT","CPB","COF","CAH","KMX","CCL","CARR","CTLT","CAT","CBOE","CBRE","CDW","CE","CNC","CNP","CDAY","CF","CRL","SCHW","CHTR","CVX","CMG","CB","CHD","CI","CINF","CTAS","CSCO","C","CFG","CLX","CME","CMS","KO","CTSH","CL","CMCSA","CMA","CAG","COP","ED","STZ","CEG","COO","CPRT","GLW","CTVA","CSGP","COST","CTRA","CCI","CSX","CMI","CVS","DHI","DHR","DRI","DVA","DE","DAL","XRAY","DVN","DXCM","FANG","DLR","DFS","DIS","DG","DLTR","D","DPZ","DOV","DOW","DTE","DUK","DD","DXC","EMN","ETN","EBAY","ECL","EIX","EW","EA","ELV","LLY","EMR","ENPH","ETR","EOG","EPAM","EQT","EFX","EQIX","EQR","ESS","EL","ETSY","EG","EVRG","ES","EXC","EXPE","EXPD","EXR","XOM","FFIV","FDS","FICO","FAST","FRT","FDX","FITB","FSLR","FE","FIS","FI","FLT","FMC","F","FTNT","FTV","FOXA","FOX","BEN","FCX","GRMN","IT","GEHC","GEN","GNRC","GD","GE","GIS","GM","GPC","GILD","GL","GPN","GS","HAL","HIG","HAS","HCA","PEAK","HSIC","HSY","HES","HPE","HLT","HOLX","HD","HON","HRL","HST","HWM","HPQ","HUM","HBAN","HII","IBM","IEX","IDXX","ITW","ILMN","INCY","IR","PODD","INTC","ICE","IFF","IP","IPG","INTU","ISRG","IVZ","INVH","IQV","IRM","JBHT","JKHY","J","JNJ","JCI","JPM","JNPR","K","KDP","KEY","KEYS","KMB","KIM","KMI","KLAC","KHC","KR","LHX","LH","LRCX","LW","LVS","LDOS","LEN","LNC","LIN","LYV","LKQ","LMT","L","LOW","LYB","MTB","MRO","MPC","MKTX","MAR","MMC","MLM","MAS","MA","MTCH","MKC","MCD","MCK","MDT","MRK","META","MET","MTD","MGM","MCHP","MU","MSFT","MAA","MRNA","MH
K","MOH","TAP","MDLZ","MPWR","MNST","MCO","MS","MOS","MSI","MSCI","NDAQ","NTAP","NFLX","NWL","NEM","NWSA","NWS","NEE","NKE","NI","NDSN","NSC","NTRS","NOC","NCLH","NRG","NUE","NVDA","NVR","NXPI","ORLY","OXY","ODFL","OMC","ON","OKE","ORCL","OGN","OTIS","PCAR","PKG","PANW","PARA","PH","PAYX","PAYC","PYPL","PNR","PEP","PFE","PCG","PM","PSX","PNW","PXD","PNC","POOL","PPG","PPL","PFG","PG","PGR","PLD","PRU","PEG","PTC","PSA","PHM","QRVO","PWR","QCOM","DGX","RL","RJF","RTX","O","REG","REGN","RF","RSG","RMD","RVTY","RHI","ROK","ROL","ROP","ROST","RCL","SPGI","CRM","SBAC","SLB","STX","SEE","SRE","NOW","SHW","SPG","SWKS","SJM","SNA","SEDG","SO","LUV","SWK","SBUX","STT","STLD","STE","SYK","SYF","SNPS","SYY","TMUS","TROW","TTWO","TPR","TRGP","TGT","TEL","TDY","TFX","TER","TSLA","TXN","TXT","TMO","TJX","TSCO","TT","TDG","TRV","TRMB","TFC","TYL","TSN","USB","UDR","ULTA","UNP","UAL","UPS","URI","UNH","UHS","VLO","VTR","VRSN","VRSK","VZ","VRTX","VFC","VTRS","VICI","V","VMC","WAB","WBA","WMT","WBD","WM","WAT","WEC","WFC","WELL","WST","WDC","WRK","WY","WHR","WMB","WTW","GWW","WYNN","XEL","XYL","YUM","ZBRA","ZBH","ZION","ZTS"]
    temp_stock_symbol_list = list(pdr.get_nasdaq_symbols().index)
    # specific_targets = ["IBM"]
    specific_targets = ["MSFT", "AAPL", "TSLA"]
    specific_target_percentage = 10
    percentage_multiplier = int(len(temp_stock_symbol_list) // 100 * specific_target_percentage)
    specific_targets_multiplied = [item for item in specific_targets for _ in range(percentage_multiplier)]
    temp_stock_symbol_list.extend(specific_targets_multiplied)
    random.shuffle(temp_stock_symbol_list)
    stock_symbol_list = specific_targets
    stock_symbol_list.extend(temp_stock_symbol_list)
    print(stock_symbol_list)
    prereq()
    while True:
        for stock_symbol in tqdm(stock_symbol_list):
            print(f"stock_symbol: {stock_symbol}")
            await pipeline(
                stock_symbol=stock_symbol,
                search_date_start=start_date,
                search_date_end=end_date,
                num_dates=0,
                postgres_connection_string=postgres_connection_string,
                mandatory_keywords=mandatory_keywords,
                total_match_keywords=total_match_keywords,
                logfile=logfile,
                loglevel=loglevel,
                dry_run=False
            )