<a href="https://colab.research.google.com/github/klutzydrummer/Python_Projects/blob/main/article_analyzer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive. The database credentials file that
# the entry-point code at the bottom of this file opens lives under
# /content/drive/MyDrive, so the mount has to happen before that code runs.
from google.colab import drive
drive.mount('/content/drive')

import dataclasses
import datetime
import json
import os
import shutil
import subprocess
import sys
import uuid
import warnings
from pathlib import Path

import pip
# Ignore warnings whose message starts with the setuptools/distutils notice.
warnings.filterwarnings("ignore", message="Setuptools is replacing distutils.")

pyrequirements_path = Path("requirements.txt")

# Bootstrap the third-party dependencies. The install only runs when the
# requirements file does not exist yet, so re-running this code in an
# environment that already has the file skips the (slow) pip step.
if not pyrequirements_path.exists():
    # One requirement per line, written without the stray leading
    # indentation a triple-quoted literal would pick up from the source.
    pyrequirements_path.write_text("\n".join((
        "numpy",
        "pymysql",
        "nltk",
        "pyarrow",
        "xformers",
        "transformers[torch]",
        "sentence_transformers",
        "setfit",
        "aiomysql",
        "pyyaml >= 6.0.1",
        "tqdm >= 4.66.0",
        "asyncio >= 3.4.3",
        "pendulum >= 2.1.2",
        "pandas >= 2.0.3",
        "psycopg2-binary >= 2.9.7",
    )))

    # pip has no supported in-process API; run it as a subprocess of the
    # current interpreter. check_call raises CalledProcessError when the
    # install fails, instead of silently continuing to the imports below.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-r", str(pyrequirements_path)]
    )

from pathlib import Path as path

# Use a pipeline as a high-level helper
from transformers import pipeline
import aiomysql
import pandas as pd
import asyncio
import ssl
from tqdm.notebook import tqdm
from typing import List


# init, repr and eq are the dataclass defaults, so the bare decorator
# generates the same __init__, __repr__ and __eq__.
@dataclasses.dataclass
class SentimentScore:
    """Sentiment result for one article.

    The four fields match the columns written by the INSERT into the
    ``sentiment_scores`` table.
    """

    # Identifier of the scored article.
    article_id: uuid.UUID
    # Stock symbol the article is associated with.
    stock_symbol: str
    # Publication timestamp of the article.
    publish_date: datetime.datetime
    # Numeric sentiment value for the article.
    sentiment_score: float

class AsyncPlanetscaleBatchFetcher:
    """Async helper around an aiomysql connection pool.

    Pages rows out of a table as pandas DataFrames and writes sentiment
    scores back. Build instances with ``await
    AsyncPlanetscaleBatchFetcher.create(...)`` so the pool exists before any
    query method is called; the plain constructor only stores settings.
    """

    def __init__(self, ps_conn_details, min_size=1, max_size=10):
        """Store connection settings; no connection is opened here.

        Args:
            ps_conn_details: Mapping with the keys ``user``, ``password``,
                ``host`` and ``database``.
            min_size: Minimum number of pooled connections.
            max_size: Maximum number of pooled connections.
        """
        self.ps_conn_details = ps_conn_details
        self.min_size = min_size
        self.max_size = max_size
        self.pool = None

    @classmethod
    async def create(cls, ps_conn_details, min_size=1, max_size=10):
        """Alternate constructor that also opens the connection pool."""
        instance = cls(ps_conn_details, min_size, max_size)
        await instance.create_pool()
        return instance

    async def create_pool(self):
        """Open the TLS-verified connection pool if it does not exist yet."""
        if self.pool is None:
            ca_path = '/etc/ssl/certs/ca-certificates.crt'
            ssl_context = ssl.create_default_context(cafile=ca_path)  # Adjust the path to your CA certificate
            self.pool = await aiomysql.create_pool(
                user=self.ps_conn_details['user'],
                password=self.ps_conn_details['password'],
                host=self.ps_conn_details['host'],
                db=self.ps_conn_details['database'],
                minsize=self.min_size,
                maxsize=self.max_size,
                ssl=ssl_context
            )

    async def close(self):
        """Close the pool and wait for its connections to shut down.

        Safe to call more than once, or before the pool was created.
        """
        if self.pool is not None:
            self.pool.close()
            await self.pool.wait_closed()
            self.pool = None

    async def fetch_batches(self, table_name, skip_field=None, second_table=None, batch_size=500, max_rows=-1):
        """Yield the rows of ``table_name`` as DataFrames of up to ``batch_size`` rows.

        Args:
            table_name: Table to read. Interpolated directly into the SQL
                text, as are ``skip_field`` and ``second_table``, so pass
                trusted identifiers only.
            skip_field: Column used to exclude rows already present in
                ``second_table``; the filter is applied only when both
                ``skip_field`` and ``second_table`` are given.
            second_table: Table whose ``skip_field`` values are excluded.
            batch_size: Maximum number of rows per yielded DataFrame.
            max_rows: Upper bound on the total number of rows yielded;
                ``-1`` means no limit.

        Yields:
            pandas.DataFrame whose columns are named after the query's
            result columns.

        NOTE(review): paging uses LIMIT/OFFSET without ORDER BY. If the
        caller inserts into ``second_table`` while iterating and those rows
        become visible to this connection, the filtered result shrinks and
        OFFSET paging could skip rows. Whether that can happen depends on
        the transaction isolation in use; confirm against the database.
        """
        base_query = f"SELECT * FROM {table_name}"

        # Add filtering if skip_field and second_table are provided
        if skip_field and second_table:
            base_query += f" WHERE {skip_field} NOT IN (SELECT {skip_field} FROM {second_table})"

        unlimited = max_rows == -1
        offset = 0
        rows_fetched = 0

        async with self.pool.acquire() as connection:
            async with connection.cursor() as cursor:
                while unlimited or rows_fetched < max_rows:
                    # Clamp the final page so the total never exceeds max_rows.
                    limit = batch_size if unlimited else min(batch_size, max_rows - rows_fetched)

                    await cursor.execute(f"{base_query} LIMIT {limit} OFFSET {offset}")
                    batch = await cursor.fetchall()
                    if not batch:
                        break

                    # The cursor already describes the result columns, so no
                    # separate sample query is needed to learn their names.
                    column_names = [desc[0] for desc in cursor.description]
                    yield pd.DataFrame(batch, columns=column_names)

                    rows_fetched += len(batch)
                    offset += len(batch)

                    # A short page means the result set is exhausted; skip
                    # the extra round trip that would return nothing.
                    if len(batch) < limit:
                        break

    async def upload_sentiment_scores_to_db(self, sentiment_scores: List["SentimentScore"]):
        """Insert the given scores into ``sentiment_scores`` and commit.

        Rows that collide on a unique key have their ``sentiment_score``
        overwritten. An empty list is a no-op and does not touch the pool.
        """
        if not sentiment_scores:
            return

        async with self.pool.acquire() as connection:
            async with connection.cursor() as cursor:
                insert_query = """INSERT INTO sentiment_scores (article_id, stock_symbol, publish_date, sentiment_score)
                                VALUES (%s, %s, %s, %s)
                                ON DUPLICATE KEY UPDATE sentiment_score = VALUES(sentiment_score);"""
                # article_id is sent as its string form.
                values = [(str(score.article_id), score.stock_symbol, score.publish_date, score.sentiment_score) for score in sentiment_scores]
                await cursor.executemany(insert_query, values)
                await connection.commit()

    async def average_sentiment_scores(self):
        """Compute the average sentiment score per stock symbol and publish date.

        Returns:
            The rows produced by the query, as returned by
            ``cursor.fetchall()``; each row holds ``stock_symbol``,
            ``publish_date`` and ``avg_sentiment_score``. Nothing is written
            to the database.
        """
        async with self.pool.acquire() as connection:
            async with connection.cursor() as cursor:
                query = """SELECT stock_symbol, publish_date, AVG(sentiment_score) AS avg_sentiment_score
                    FROM sentiment_scores
                    GROUP BY stock_symbol, publish_date;
                    """
                await cursor.execute(query)
                # Read the result instead of discarding it.
                averages = await cursor.fetchall()
                # Ends the transaction opened by the SELECT.
                await connection.commit()
        print("Averaged all sentiment scores per stock symbol per date.")
        return averages

async def sentiment_to_score(sentiment_result: dict) -> float | None:
    # print(sentiment_result)
    label = sentiment_result.get("label")
    sentiment_sign: int = int(-1)
    sentiment_score = sentiment_result.get("score")
    if type(sentiment_score) is not float:
        print(f"Sentiment score is not float.\n  {sentiment_result}")
        return None
    match label:
        case 'positive':
            sentiment_sign: int = int(1)
            sentiment_score = sentiment_score * sentiment_sign
        case 'neutral':
            # Use 0 as a padding value in ML model, set to small positive to differentiate between padding and neutral
            sentiment_sign: int = int(1)
            sentiment_score = 0.0001
        case'negative':
            sentiment_sign: int = int(-1)
            sentiment_score = sentiment_score * sentiment_sign
        case _:
            raise ValueError(f"Sentiment label is not 'positive'/'neutral'/'negative', label: {label}")

    return sentiment_score

async def assemble_SentimentScore(sentiment_result: dict | None, analysis_piece: dict) -> SentimentScore | None:
    """Combine a classifier result with article metadata into a SentimentScore.

    Args:
        sentiment_result: Classifier output for one article, or ``None``
            when no result is available.
        analysis_piece: The remaining ``SentimentScore`` fields as keyword
            arguments. The dict is left unmodified.

    Returns:
        A ``SentimentScore``, or ``None`` when ``sentiment_result`` is
        ``None`` or cannot be converted into a score.
    """
    if sentiment_result is None:
        return None

    sentiment_score = await sentiment_to_score(sentiment_result)
    if sentiment_score is None:
        return None

    # Merge into a new dict rather than writing the score into the caller's
    # dict; an existing 'sentiment_score' key is still overridden.
    return SentimentScore(**{**analysis_piece, 'sentiment_score': sentiment_score})

if __name__ == '__main__':

    pipe = pipeline("text-classification", model="mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis")

    # Load db credentials
    with open("/content/drive/MyDrive/Machine_Learning_Digestor/config/ps_conn_details.json", "r") as f:
        ps_conn_details = json.load(f)

    impact_score_db = await AsyncPlanetscaleBatchFetcher.create(ps_conn_details)
    impact_score_db = fetcher
    async for article_data in impact_score_db.fetch_batches(table_name='articles', batch_size=500, max_rows=-1, skip_field="article_id", second_table="sentiment_scores"):
        print(article_data)
        analysis_parts_generator = (
        {
            "article_id": row['article_id'],
            "stock_symbol": row['stock_symbol'],
            "publish_date": row['publish_date']
        }
        for _, row in article_data.iterrows())
        summaries = article_data['summary']
        tasks = []
        sentiment_results = []
        for summary in list(summaries):
            try:
                output = pipe(summary)
                sentiment_results.append(output[0])
            except:
                sentiment_results.append(None)
                continue

        upload_tasks = []
        for sentiment_result, analysis_piece in tqdm(zip(sentiment_results, analysis_parts_generator), total=len(article_data)):
            task = asyncio.create_task(assemble_SentimentScore(sentiment_result=sentiment_result, analysis_piece=analysis_piece))
            tasks.append(task)
            sentiment_scores = [*filter(None, await asyncio.gather(*tasks))]

            upload_task = asyncio.create_task(impact_score_db.upload_sentiment_scores_to_db(sentiment_scores))
            upload_tasks.append(upload_task)
        await asyncio.gather(*upload_tasks)