<a href="https://colab.research.google.com/github/klutzydrummer/Python_Projects/blob/main/article_aquisition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the project's config and model files are reachable.
from google.colab import drive
drive.mount('/content/drive')
import json
# Load the database connection string; it is passed to AsyncPostgresDB and
# PGEmbedding further down.
with open("/content/drive/MyDrive/Machine_Learning_Digestor/config/connection_string.json", "r") as f:
    connection_string = json.load(f)

# --- Search parameters ---
# End of the search window; parsed later with the 'MM/DD/YYYY' format.
search_date = "07/10/2023"
# Keyword arguments later unpacked into `search_date.subtract(...)` to get the window start.
search_period = {"days": 0, "weeks": 0, "months": 0, "years": 10}
stock_symbol = "MSFT"
# NOTE(review): index_name is not referenced again in the visible notebook.
index_name = "test-index"

# Model name handed to HuggingFaceInstructEmbeddings.
embeddings_model_name = "hkunlp/instructor-large"

# File name of the LLM weights inside the Drive "models" folder; the context
# window is passed to LlamaCpp as n_ctx and reused as the summarization chunk size.
model_choice = "stablebeluga-7b.ggmlv3.q4_1.bin"
model_context_window = 2048

# Chunk size used when splitting text for the embedding store.
embedding_context_window = 512


# Global switch read by debugprint().
enable_debugprint = False
from pathlib import Path as path

# Drive and local roots mirror each other; the convert_to_* helpers swap one
# prefix for the other.
path_drive_base = path("/content/drive/MyDrive/Machine_Learning_Digestor/")
path_local_base = path("/content/Machine_Learning_Digestor/")
drive_model_path = path_drive_base.joinpath("models").joinpath(model_choice)
# SQLite file later handed to langchain's SQLiteCache.
local_cache_file = path("/content/Machine_Learning_Digestor/lanchain-cache.db")
local_cache_file.parent.mkdir(exist_ok=True, parents=True)

import os
import shutil
from google.colab import drive
# Second mount call; force_remount=False leaves an existing mount untouched.
drive.mount('/content/drive/', force_remount=False)

def convert_to_drive_path(local_file_path: path) -> path:
    """Map a path under the local project root to its Google Drive mirror.

    Args:
        local_file_path: A path located at or below ``path_local_base``.

    Returns:
        The same relative location re-rooted at ``path_drive_base``.

    Raises:
        ValueError: If ``local_file_path`` is not inside ``path_local_base``.
    """
    try:
        # A real prefix test: a plain substring check would also accept paths
        # that merely contain the base text somewhere (e.g. a sibling folder
        # sharing the prefix), and str.replace would rewrite every occurrence.
        relative_part = path(local_file_path).relative_to(path_local_base)
    except ValueError:
        raise ValueError(f"Source path not in Local path:\nSource: {local_file_path}\nLocal path: {path_local_base}") from None
    return path_drive_base / relative_part

def convert_to_local_path(drive_file_path: path) -> path:
    """Map a path under the Google Drive project root to its local mirror.

    Args:
        drive_file_path: A path located at or below ``path_drive_base``.

    Returns:
        The same relative location re-rooted at ``path_local_base``.

    Raises:
        ValueError: If ``drive_file_path`` is not inside ``path_drive_base``.
    """
    try:
        # A real prefix test: a plain substring check would also accept paths
        # that merely contain the base text somewhere, and str.replace would
        # rewrite every occurrence rather than just the leading root.
        relative_part = path(drive_file_path).relative_to(path_drive_base)
    except ValueError:
        raise ValueError(f"Source path not in GDrive path:\nSource: {drive_file_path}\nGDrive path: {path_drive_base}") from None
    return path_local_base / relative_part

async def download_from_drive(drive_file_path: path) -> path:
    """Copy a Drive file to its mirrored local location.

    Returns the local path the file was copied to. The copy only happens once
    the returned coroutine is awaited.
    """
    destination = convert_to_local_path(drive_file_path)
    # The mirrored folder may not exist locally yet.
    destination.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(drive_file_path, destination)
    return destination

async def upload_to_drive(local_file_path: path, overwrite=False) -> path:
    """Copy a local file to its mirrored Google Drive location.

    Args:
        local_file_path: File below the local project root.
        overwrite: Only the literal ``True`` allows replacing an existing file.

    Returns:
        The Drive path the file was copied to.

    Raises:
        ValueError: If the destination exists and ``overwrite`` is not True.
    """
    target = convert_to_drive_path(local_file_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    # Guard clause: refuse to clobber unless overwriting was requested.
    if overwrite is not True and target.exists() is True:
        raise ValueError("File already exists at destination and overwrite is not set to True.")
    shutil.copy(local_file_path, target)
    return target

async def path_future(file_path: path) -> path:
    """Return ``file_path`` unchanged from a coroutine.

    Lets callers await a path that needs no download the same way they await
    ``download_from_drive``.
    """
    already_available = file_path
    return already_available

# Local mirror location of the model file that lives on Drive.
local_model_path = convert_to_local_path(drive_model_path)

# NOTE(review): pyproject_path is not referenced again in the visible notebook.
pyproject_path = path("/content/pyproject.toml")
pyrequirements_path = path("/content/requirements.txt")

# Decide where the LangChain cache file comes from. The async helpers are
# called without `await`, so `cache_file` holds a coroutine object here; it is
# awaited in a later cell.
if local_cache_file.exists() is not True:
    drive_cache_file = convert_to_drive_path(local_cache_file)
    if drive_cache_file.exists() is True:
        # A cache exists on Drive: schedule copying it down.
        cache_file = download_from_drive(drive_cache_file)
    else:
        # No cache anywhere yet: just hand back the local path.
        cache_file = path_future(local_cache_file)
else:
    cache_file = path_future(local_cache_file)

# `local_model` (a coroutine) is only defined when the model is missing
# locally; load_llm() falls back to local_model_path otherwise.
if local_model_path.exists() is not True:
    local_model = download_from_drive(drive_model_path)

# Write the requirements file once and install from it. The install is skipped
# whenever the file already exists.
if pyrequirements_path.exists() is not True:
    with open(pyrequirements_path, "w") as project_file:
        project_file.write('''numpy
    asyncpg
    nltk

    pyyaml >= 6.0.1
    tqdm >= 4.66.0
    googlenews >= 1.6.8
    newspaper3k >= 0.2.8
    langchain >= 0.0.260
    sentence-transformers >= 2.2.2
    InstructorEmbedding >= 1.0.1
    asyncio >= 3.4.3
    lmql >= 0.0.6.6
    llama-cpp-python >= 0.1.77
    pendulum >= 2.1.2
    pandas >= 2.0.3
    torch >= 2.0.0
    psycopg2-binary >= 2.9.7''')

    # IPython shell escape: only valid inside a notebook cell.
    !cd /content; pip install -r "/content/requirements.txt"
    # %pip install torch==2.0.0
    # %pip uninstall -y numpy
    # %pip install numpy

In [None]:
from GoogleNews import GoogleNews
from newspaper import Article
from pydantic import BaseModel, Field, validator
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
from typing import List, Optional
import aiohttp
import asyncio
import dataclasses
import datetime
import itertools
import json
import pandas as pd
import pendulum
from concurrent.futures import ThreadPoolExecutor
import functools
import hashlib
import asyncpg

In [None]:
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import PGEmbedding
from langchain.document_loaders import TextLoader
from langchain.cache import SQLiteCache
from langchain import PromptTemplate
from langchain.chains import LLMChain

from langchain.vectorstores.base import VectorStore

from langchain.chains.summarize import load_summarize_chain
from langchain.llms import LlamaCpp
import langchain
try:
    await cache_file
except:
    pass

langchain.llm_cache = SQLiteCache(database_path=str(local_cache_file))

In [None]:
from nltk.corpus import stopwords
from nltk.stem.wordnet import WordNetLemmatizer
from nltk.tokenize import word_tokenize
import nltk
# Fetch the NLTK resources used by preprocess(): stopword list, WordNet and
# the punkt tokenizer data.
for nltk_resource in ('stopwords', 'wordnet', 'punkt'):
    nltk.download(nltk_resource)

def preprocess(text):
    """Normalize raw text into a space-joined string of lemmatized tokens.

    Steps: NLTK word tokenization, lowercasing, English stopword removal and
    WordNet lemmatization.

    Args:
        text: Raw text to normalize.

    Returns:
        The remaining tokens joined by single spaces.
    """
    def _dump(stage_tokens):
        # Intermediate token dumps are diagnostics only, so they follow the
        # notebook-wide `enable_debugprint` switch instead of always printing
        # every token of every article.
        if enable_debugprint:
            print(stage_tokens)

    # Tokenization
    tokens = word_tokenize(text)
    _dump(tokens)

    # Lowercasing
    tokens = [token.lower() for token in tokens]
    _dump(tokens)

    # Remove stopwords (a set gives constant-time membership tests)
    stop_words = set(stopwords.words('english'))
    tokens = [token for token in tokens if token not in stop_words]
    _dump(tokens)

    # Lemmatization
    lemmatizer = WordNetLemmatizer()
    tokens = [lemmatizer.lemmatize(token) for token in tokens]

    return " ".join(tokens)

In [None]:
import uuid
def str_to_uuid(in_string: str) -> uuid.UUID:
    """Derive a deterministic UUID from the MD5 digest of ``in_string``.

    Args:
        in_string: Text to hash; any Unicode text is accepted.

    Returns:
        A ``uuid.UUID`` whose 128 bits are the MD5 digest of the UTF-8 bytes.
    """
    # UTF-8 produces the same bytes as ASCII for ASCII-only text, so IDs that
    # were generated before stay identical, while titles containing curly
    # quotes, dashes, etc. no longer raise UnicodeEncodeError.
    digest = hashlib.md5(in_string.encode("utf-8")).digest()
    return uuid.UUID(bytes=digest)

def custom_uuid(stock_symbol, title):
    """Return the string form of the deterministic UUID for a symbol/title pair."""
    identity_key = f"{stock_symbol} {title}"
    derived = str_to_uuid(identity_key)
    return str(derived)

In [None]:
async def load_embeddings():
    """Return the shared embeddings model, constructing it on first use.

    The instance is kept in the module-level global ``embeddings_instance`` so
    later calls reuse it.

    Returns:
        The ``HuggingFaceInstructEmbeddings`` instance.
    """
    global embeddings_instance
    try:
        return embeddings_instance
    except NameError:
        # First call: the global has not been created yet. Only NameError can
        # come out of the lookup above, so nothing else is caught.
        embeddings_instance = HuggingFaceInstructEmbeddings(
            model_name=embeddings_model_name,
            embed_instruction="Represent this financial article for clustering: ")
        return embeddings_instance

# Alias used by later cells as `await embeddings()`.
embeddings = load_embeddings
# Calling `embeddings()` here without `await` would only create a coroutine
# object that never runs (Python then warns that it was never awaited), so no
# call is made; the model is built by the first `await embeddings()`.

In [None]:
async def load_llm():
    """Return the shared LlamaCpp model, constructing it on first use.

    The instance is kept in the module-level global ``llm_instance``. On first
    use the pending model download (``local_model``), if any, is awaited to
    obtain the model path.

    Returns:
        The ``LlamaCpp`` instance.
    """
    global llm_instance
    try:
        return llm_instance
    except NameError:
        # First call: the global has not been created yet.
        pass

    try:
        model_path = str(await local_model)
    except Exception:
        # `local_model` is undefined when the model was already present
        # locally, and a coroutine cannot be awaited twice; in either case
        # fall back to the expected local location.
        model_path = str(local_model_path)

    llm_instance = LlamaCpp(
        model_path=model_path,
        n_ctx=model_context_window,
        verbose=True
    )
    return llm_instance

# Alias used by later cells as `await summary_engine()`.
summary_engine = load_llm
# Calling `summary_engine()` here without `await` would only create a
# coroutine object that never runs, so no call is made; the model is built by
# the first `await summary_engine()`.

In [None]:
def custom_time_to_str(datetime: datetime.datetime):
    """Format a date as ``MM/DD/YYYY``.

    Month and day are zero-padded to two digits; the year is written as-is.
    """
    month, day, year = (int(part) for part in (datetime.month, datetime.day, datetime.year))
    return "{:02}/{:02}/{}".format(month, day, year)

In [None]:
# Normalize `search_date`: a non-empty string is parsed as MM/DD/YYYY into a
# pendulum DateTime; an empty string is rejected outright.
if type(search_date) is str:
    if search_date != "":
        search_date = pendulum.from_format(search_date, 'MM/DD/YYYY')
    else:
        raise ValueError("Search date required.")

# Build the GoogleNews client for the window
# [search_date - search_period, search_date].
if type(search_date) is pendulum.DateTime:
    end_date = custom_time_to_str(search_date)
    start_date = custom_time_to_str(search_date.subtract(**search_period))
    print(f"Searching for articles released between {start_date} and {end_date}")
    googlenews = GoogleNews(lang="en", encode="utf-8", start=start_date, end=end_date)
# NOTE(review): this branch looks unreachable — an empty string already raised
# above, so `search_date` can no longer be "" at this point.
elif type(search_date) is str and search_date == "":
    print("Searching for articles released anytime.")
    googlenews = GoogleNews(lang="en", encode="utf-8")
else:
    raise ValueError(f"Type of search_date not valid, type is: {type(search_date)}\nvalue is: {search_date}")
googlenews.clear()

In [None]:
# Presumably makes the client raise on request failures instead of returning
# empty results — confirm against the GoogleNews documentation.
googlenews.enableException(True)

search_string = f"stock news {stock_symbol}"
print(f"Creating search for: {search_string}")
# Keep using the client configured in the previous cell. Re-creating it as
# GoogleNews(lang="en") at this point would throw away the start/end date
# window announced above as well as the exception setting.
googlenews.get_news(search_string)
news_list = googlenews.get_links()
print(news_list)


In [None]:
@dataclasses.dataclass(init=True, repr=True, eq=True)
class InputArticleRaw:
    """One news article together with the search context it was found under.

    Instances are built through the dataclass-generated ``__init__`` (every
    field is required) or through :meth:`import_newspaper`.
    """

    # Preprocessed text used for embedding; filled by prepare_embeddable().
    embeddable: str
    title: str
    text: str
    authors: list[str]
    publish_date: datetime.datetime
    search_date: datetime.datetime
    tags: set[str]
    keywords: set[str]
    stock_symbol: str

    @staticmethod
    def import_newspaper(newspaper_article: "Article", stock_symbol: str, searched_at: "datetime.datetime | None" = None):
        """Build an instance from a parsed newspaper ``Article``.

        Args:
            newspaper_article: Article exposing ``title``, ``text``,
                ``authors``, ``publish_date``, ``tags`` and ``keywords``.
            stock_symbol: Ticker the article was collected for.
            searched_at: Search date to record. Defaults to the notebook-level
                ``search_date`` when omitted.

        Returns:
            A new ``InputArticleRaw`` with an empty ``embeddable``.

        Raises:
            ValueError: If the article's publish date is not a datetime.
        """
        if not isinstance(newspaper_article.publish_date, datetime.datetime):
            raise ValueError(f"Publish date is not of type datetime:\n type:\n {type(newspaper_article.publish_date)}\n value:\n {newspaper_article.publish_date}")
        if searched_at is None:
            # Fall back to the module-level search date configured for the run.
            searched_at = search_date
        result = InputArticleRaw(
            embeddable="",
            title=newspaper_article.title,
            text=newspaper_article.text,
            authors=newspaper_article.authors,
            # Round-trip through a POSIX timestamp to get plain datetimes.
            publish_date=datetime.datetime.fromtimestamp(newspaper_article.publish_date.timestamp()),
            search_date=datetime.datetime.fromtimestamp(searched_at.timestamp()),
            tags=newspaper_article.tags,
            # Build the set from the iterable itself: unpacking it with `*`
            # splits a single keyword into characters and raises TypeError
            # for two or more keywords.
            keywords=set(newspaper_article.keywords),
            stock_symbol=stock_symbol
        )
        return result

    def to_json(self):
        """Serialize all fields to a JSON string (non-JSON values via ``str``)."""
        return json.dumps(self.__dict__, default=str)

    def article_dict(self):
        """Return the fields as a JSON-compatible dict (round-trip of to_json)."""
        return json.loads(self.to_json())

    def prompt_builder(self, prompt: str = "", text=""):
        """Render every field as a ``---``-delimited section for an LLM prompt.

        Args:
            prompt: Optional leading instruction; omitted when empty.
            text: Replacement for the ``text`` field; empty means use
                ``self.text``.

        Returns:
            The sections joined by newlines.
        """
        if text == "":
            text = self.text
        out_list = [prompt] if prompt != "" else []
        for field_name, field_value in self.__dict__.items():
            if field_name == "text":
                field_value = text
            out_list.append(f"---\n{str(field_name)}\n---\n{str(field_value)}\n")
        return "\n".join(out_list)

    def update_text(self, text: str):
        """Replace ``text`` and return ``self`` so calls can be chained."""
        self.text = text
        return self

    def prepare_embeddable(self):
        """Fill ``embeddable`` with the preprocessed form of ``text``."""
        self.embeddable = preprocess(self.text)

    @property
    def metadata(self):
        """Field dict for the vector store.

        Drops ``keywords``, ``tags`` and ``embeddable`` and renders both dates
        as ISO-8601 strings.
        """
        metadata = dataclasses.asdict(self)
        for dropped in ('keywords', 'tags', 'embeddable'):
            del metadata[dropped]
        metadata['publish_date'] = self.publish_date.isoformat()
        metadata['search_date'] = self.search_date.isoformat()
        return metadata

################################################################################################################################################################################################

class AsyncPostgresDB:
    """Small asyncpg wrapper around one lazily opened connection.

    The notebook calls ``check_key_exists`` from many concurrently gathered
    tasks, while a single asyncpg connection runs one operation at a time.
    All connection use is therefore serialized through an ``asyncio.Lock``,
    which also stops several tasks from opening a connection each when none
    exists yet.
    """

    def __init__(self, database_url):
        self.database_url = database_url
        self.conn = None
        # Serializes connect/query/close across concurrent tasks.
        self._lock = asyncio.Lock()

    async def _ensure_connection(self):
        """
        Ensure that the database connection is established.

        Callers inside this class hold ``self._lock`` while calling this.
        """
        if self.conn is None or self.conn.is_closed():
            self.conn = await asyncpg.connect(self.database_url)

    async def check_key_exists(self, table_name, column_name, key_value):
        """
        Asynchronously check if a key exists in a specific table and column.

        Args:
        - table_name (str): The name of the table.
        - column_name (str): The name of the column.
        - key_value (str/int): The value to check for.

        Returns:
        - bool: True if the key exists, False otherwise.
        """
        # Identifiers cannot be bound as query parameters, so table_name and
        # column_name are interpolated and must come from trusted code; only
        # key_value is passed as a bound parameter.
        query = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE {column_name} = $1);"
        async with self._lock:
            await self._ensure_connection()
            return await self.conn.fetchval(query, key_value)

    async def close(self):
        """
        Close the database connection if it's open.
        """
        async with self._lock:
            if self.conn:
                await self.conn.close()
                self.conn = None


################################################################################################################################################################################################

def debugprint(*args, **kwargs):
    """Forward the arguments to print() only while ``enable_debugprint`` is truthy."""
    if not enable_debugprint:
        return
    print(*args, **kwargs)

def chunk_list(input_list, n):
    """Split ``input_list`` into ``n`` contiguous chunks of near-equal size.

    Chunk boundaries are computed with integer arithmetic. Accumulating a
    float step can drift just below ``len(input_list)`` and yield an extra
    chunk, and a negative ``n`` would never terminate.

    Args:
        input_list: Sliceable sequence to split.
        n: Number of chunks, a positive integer.

    Returns:
        A list of exactly ``n`` slices covering the input in order (some may
        be empty when ``n`` exceeds the length), or ``[]`` for empty input.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    total = len(input_list)
    if total == 0:
        return []
    # Boundary i sits at floor(i * total / n): the same cut points the float
    # step aims for, without rounding drift.
    return [input_list[i * total // n:(i + 1) * total // n] for i in range(n)]

async def acheck_articles(session, url: str, i: int, total: int, postgres: AsyncPostgresDB) -> InputArticleRaw | None:
    """Fetch, parse and validate one article link.

    Args:
        session: HTTP session used to resolve the link's final URL.
        url: Link to the article.
        i: Zero-based position of the link, used for progress output only.
        total: Number of links in the batch, used for progress output only.
        postgres: Database wrapper used to skip already stored articles.

    Returns:
        The article as ``InputArticleRaw``, or ``None`` when it could not be
        fetched, failed validation, or is already stored. Failures are only
        reported through ``debugprint``.
    """
    debugprint(f"Processing {i+1}/{total}")
    try:
        async with session.get(url, timeout=10) as response:
            debugprint("Fetching article...")
            # Only the final (post-redirect) URL is taken from this response;
            # newspaper downloads the page itself below.
            news_resolve = str(response.url)
            debugprint("Article found.")
            article = Article(news_resolve)
            # NOTE: download()/parse() are plain synchronous calls, so the
            # event loop is blocked while they run.
            article.download()
            article.parse()
            debugprint("Article parsed.")
            # Cheap local checks first, so rejected pages never cost a
            # database round-trip.
            if type(article.publish_date) is not datetime.datetime:
                raise ValueError("Publish_date not of type datetime.")
            # Reject bot-check and HTTP 403 pages. Previously these markers
            # only guarded the empty-text check, so such pages were accepted.
            if "Are you a robot?" in article.title or "403 Client Error" in article.text:
                raise ValueError("Article page is a bot check or 403 error page.")
            if article.text == "":
                raise ValueError("article text is empty string.")
            custom_uuid_to_check = custom_uuid(stock_symbol, article.title)
            key_exist = await postgres.check_key_exists(table_name="langchain_pg_embedding", column_name="custom_id", key_value=custom_uuid_to_check)
            if key_exist is True:
                raise ValueError("Article already exists in embeddings database.")
            debugprint(f"Appending article: {article.title}")
            return InputArticleRaw.import_newspaper(article, stock_symbol=stock_symbol)
    except Exception as err:
        # Deliberate best effort: one bad link must not abort the batch.
        debugprint(f"Encountered error: {err}\nargs: {session, url, i, total}")
        return None


async def aget_valid_articles(list_gnews_articles: List[str], postgres: AsyncPostgresDB, max_articles: int = 0) -> List[InputArticleRaw]:
    """Check a batch of article links concurrently and keep the valid ones.

    Args:
        list_gnews_articles: Links without a scheme; ``https://`` is prepended.
        postgres: Database wrapper passed on to ``acheck_articles``.
        max_articles: Upper bound on returned articles; ``0`` means no bound.

    Returns:
        The valid articles in input order, at most ``max_articles`` of them.
    """
    total = len(list_gnews_articles)
    # 0 means "no limit"; a negative value yields an empty result.
    limit = total if max_articles == 0 else max(max_articles, 0)
    if limit == 0:
        return []

    async with aiohttp.ClientSession() as session:
        tasks = []
        for i, news in enumerate(list_gnews_articles):
            news_url = f"https://{news}"
            debugprint(f"news_url: {news_url}")
            tasks.append(acheck_articles(session, news_url, i, total, postgres=postgres))
        articles = await asyncio.gather(*tasks)

    article_list = [article for article in articles if article is not None]
    # The limit can only be applied once results exist: while the tasks are
    # being created no article has been validated yet, so a check inside the
    # loop above could never trigger.
    if len(article_list) > limit:
        debugprint("Max list length reached.")
    return article_list[:limit]

In [None]:
# Split the links into batches; each batch gets its own aiohttp session
# inside aget_valid_articles.
chunked_news_list = chunk_list(news_list, 10)
# NOTE(review): `tasks` is not used in the visible code of this cell.
tasks = []
# Turn on debugprint() output from this point on.
enable_debugprint = True


# One database wrapper is shared by every batch; all batches run concurrently
# under a progress-reporting gather.
db = AsyncPostgresDB(connection_string)
chunked_valid_news_lists = await tqdm_asyncio.gather(*(aget_valid_articles(list_gnews_articles=newslist, postgres=db) for newslist in chunked_news_list))
await db.close()
# Flatten the per-batch results; both names are bound to the same list object.
article_list = flattened_list = list(itertools.chain.from_iterable(chunked_valid_news_lists))

In [None]:
def generate_text_splitter(chunk_size: int=512, chunk_overlap: int=0):
    """Build a RecursiveCharacterTextSplitter that measures length with ``len``.

    Args:
        chunk_size: Maximum chunk length.
        chunk_overlap: Overlap between consecutive chunks.
    """
    splitter_options = {
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
        "length_function": len,
    }
    return RecursiveCharacterTextSplitter(**splitter_options)

# NOTE(review): the splitter measures length with len() (characters), while
# model_context_window is also passed to LlamaCpp as n_ctx and is presumably
# a token count — confirm the intended unit.
chunk_size = model_context_window
text_splitter = generate_text_splitter(chunk_size=chunk_size)

# Obtain the shared LLM; it is constructed on the first await.
my_llm = await summary_engine()

summary_chain = load_summarize_chain(llm=my_llm, chain_type="map_reduce")

# Second-pass prompt used when a summary is still longer than chunk_size.
# `{{article}}` is escaped in the f-string so it survives as the template
# variable `{article}`.
text_limit_prompt = PromptTemplate.from_template(f"Take the below text and summarize it such that it still accurately conveys the same information, while falling below {chunk_size} tokens:\n{{article}}")
text_limit_chain = LLMChain(
    prompt=text_limit_prompt,
    llm=my_llm
)

def summarize_article(text: str, summary_chain=summary_chain, text_splitter=text_splitter, chunk_size=chunk_size):
    """Shorten ``text`` with the LLM when it exceeds ``chunk_size``.

    Text within the limit is returned unchanged. Longer text is split, run
    through the map-reduce summary chain and, if the summary is still too
    long, condensed once more with ``text_limit_chain``.

    Args:
        text: Article text.
        summary_chain: Chain called with the list of split documents.
        text_splitter: Splitter used to produce those documents.
        chunk_size: Length limit, compared against ``len()`` of the text.

    Returns:
        The original or summarized text.

    Raises:
        ValueError: If a chain does not yield a string.
    """
    debugprint(f"Summarizing article.")
    # Only over-long text needs summarizing. The previous `<=` test did the
    # opposite: it sent short articles to the LLM and left long ones as-is.
    if len(text) > chunk_size:
        doclist = text_splitter.create_documents([text,])
        result = summary_chain(doclist, return_only_outputs=True)
        result = result.get("output_text", result)
        if len(result) > chunk_size:
            limited = text_limit_chain(inputs={"article":result}, return_only_outputs=True)
            # LLMChain reports its output under "text" by default; the
            # "output_text" lookup is kept as a fallback.
            result = limited.get("text", limited.get("output_text", limited))
    else:
        result = text
    if type(result) is not str:
        raise ValueError("InputArticleRaw.text cannot be assigned non str value.")
    debugprint(f"Article summarization complete.")
    return str(result)

In [None]:
# Splitter sized for the embedding model rather than the LLM.
pg_embedding_text_splitter = generate_text_splitter(chunk_size=embedding_context_window)

# Summarize every article in place. update_text() returns the article itself,
# so this list holds the same objects as article_list.
sumarry_generator = [article.update_text(summarize_article(article.text)) for article in tqdm(article_list)]
# Fill each article's `embeddable` field from its (summarized) text.
for article in tqdm(article_list):
    article.prepare_embeddable()
# NOTE(review): summary_generator_w_progress is not used in the visible code.
summary_generator_w_progress = tqdm(sumarry_generator, total=len(article_list))
# One list of documents per article, each document carrying the article metadata.
summary_text_generator = [pg_embedding_text_splitter.create_documents(texts=[article.embeddable,], metadatas=[article.metadata,]) for article in sumarry_generator]
# Flatten to a single list of documents.
flattened_summary_text = [document for doc_list in summary_text_generator for document in doc_list]

In [None]:
print("Ensuring embeddings are loaded.")
my_embeddings = await embeddings()
print("Embeddings are loaded.")
test_embedding = my_embeddings.embed_query("test")
embedding_dimensions = len(test_embedding)
print(f"Embedding model produces vector of size {embedding_dimensions}")

db = PGEmbedding(
    embedding_function=my_embeddings,
    collection_name=stock_symbol,
    connection_string=connection_string
)

for each in tqdm(flattened_summary_text, total=len(flattened_summary_text)):
    text = each.page_content
    metadata = each.metadata
    title = metadata.get("title", ValueError("Title missing from metadata."))
    id = custom_uuid(stock_symbol, title)
    db_connect = db.connect()
    db.add_texts(
        texts=[text,],
        metadatas=[metadata,],
        ids=[id,]
    )
    db_connect.commit()
print("All documents in run are embedded.")