In [1]:
import os
import ast
import sys
from sentence_transformers import CrossEncoder
import yaml
from datetime import datetime

from pymongo import MongoClient
import logging
logger = logging.getLogger(__name__)
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import re
from json import JSONDecodeError
import torch
from torch import nn
import gradio as gr
import pandas as pd
from tqdm import trange
import numpy as np
from dotenv import load_dotenv
from sklearn.metrics import silhouette_score
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from scipy.cluster.hierarchy import linkage, fcluster

# Import libraries for working with language models and Google Gemini
from langchain_core.prompts import PromptTemplate
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_groq import ChatGroq
from langchain_core.prompts import PromptTemplate

# --- Configuration ----------------------------------------------------------
# Project settings (e.g. the MongoDB database/collection names used by
# load_database) are kept in a YAML file one directory up.
config_file = '../config.yaml'
with open(config_file, 'r') as fin:
    config = yaml.safe_load(fin)

# Secrets are read from the environment, which load_dotenv may populate from a .env file.
load_dotenv()
groq_api_key = os.getenv('GROQ_API_KEY')
chat_model = "llama3-8b-8192"

# Required settings: a missing variable raises KeyError here, in this order.
GEMINI_KEY, MONGO_URI, HF_KEY, EMBEDDER_API = (
    os.environ[variable]
    for variable in ('GEMINI_KEY', 'MONGO_URI', 'HUGGINGFACE_API_KEY', 'HF_EMBEDDING_MODEL_URL')
)

# --- Clients ----------------------------------------------------------------
genai.configure(api_key=GEMINI_KEY)
mongo_client = MongoClient(MONGO_URI)                             # MongoDB client
default_llm = genai.GenerativeModel('gemini-1.5-flash-latest')    # default LLM

  from tqdm.autonotebook import tqdm, trange


In [2]:
def load_database():
    """Load the train and test article collections from MongoDB into one DataFrame.

    Database and collection names are read from ``config["database"]``. Train rows
    come first, then test rows, and the result gets a fresh RangeIndex.

    Returns:
        pd.DataFrame: All articles (train followed by test).

    Raises:
        Exception: Whatever the MongoDB driver raises while fetching a collection;
            it is logged first and then re-raised (previously the error was masked by
            a NameError on an unbound variable).
    """
    db = mongo_client[config["database"]["name"]]

    def fetch(collection_key, label):
        # find() only returns a lazy cursor; materialising it with list() here makes
        # connection/query errors surface inside the try instead of later.
        try:
            documents = list(db[config["database"][collection_key]].find())
        except Exception as error:
            logger.error(f"Unable to fetch {label} data from MongoDB. Check your connection the database...\nERROR: {error}\n")
            raise
        logger.info(f"{label.capitalize()} data successfully fetched from MongoDB\n")
        return documents
    # end def

    df_train = pd.DataFrame.from_dict(fetch("train_collection", "train"))
    df_test = pd.DataFrame.from_dict(fetch("test_collection", "test"))

    # Row bind the training and test dataframes
    return pd.concat([df_train, df_test], axis=0).reset_index(drop=True)
# end def

def get_text_embeddings(df):
    """Deserialize the 'embeddings' column into a numpy array (one row per article).

    Each cell is expected to hold a Python-literal list such as "[0.1, 0.2, ...]".
    """
    print("Fetching embeddings...\n")
    vectors = [ast.literal_eval(serialized) for serialized in df['embeddings']]
    return np.array(vectors)
# end def

def get_predicted_cluster(dataframe, test_id):
    """Return the 'Cluster_label' of the first row whose 'st_id' equals ``test_id``.

    Raises:
        KeyError: If no row has that 'st_id'. This is the same exception type as
            before, but with a descriptive message instead of a bare ``KeyError: 0``.
    """
    labels = dataframe.loc[dataframe['st_id'] == test_id, 'Cluster_label']
    if labels.empty:
        raise KeyError(f"No article with st_id {test_id!r} in the dataframe")
    # Positional access works regardless of the frame's index labels
    return labels.iloc[0]
# end def

In [3]:
# utils
def clean_output(output):
    """Parse the JSON payload (list or object) out of a raw LLM response.

    Args:
        output: Raw model text, possibly wrapped in Markdown fences or prose. A value
            that is already a parsed list/dict is returned unchanged.

    Returns:
        The decoded JSON value. If the whole string is not valid JSON, the first
        decodable JSON array/object found in the text is returned.

    Raises:
        json.JSONDecodeError: If no JSON array/object can be decoded anywhere in the text.
    """
    # Callers sometimes pass output that a JsonOutputParser has already parsed
    if isinstance(output, (list, dict)):
        return output

    try:
        return json.loads(output)
    except JSONDecodeError as first_error:
        decode_error = first_error

    # raw_decode parses one complete JSON value starting at a given offset, so nested
    # lists/objects are handled correctly (unlike a `[^\]]*` style regex).
    decoder = json.JSONDecoder()
    for opener in re.finditer(r'[\[{]', output):
        try:
            parsed, _ = decoder.raw_decode(output, opener.start())
        except JSONDecodeError:
            continue
        return parsed
    raise decode_error

def clean_sort_timeline(timelines, df_retrieve):
    """Merge raw per-article LLM timelines into one date-sorted list of events.

    Args:
        timelines (dict): Raw LLM outputs (JSON text), parsed with ``clean_output``.
        df_retrieve (pd.DataFrame): Articles in the order the LLM numbered them
            ("Article 1" is row 0).

    Returns:
        list[dict]: Events whose 1-based 'Article' number is replaced by 'Article_id',
        sorted by 'Date'. The following events are dropped:
        - events with an unknown date ('NAN' or non-string);
        - events with a missing or out-of-range article number.
        Unknown trailing date parts ('XX' or '00') are trimmed, e.g.
        '2023-05-XX' -> '2023-05' and '2023-XX-XX' -> '2023'.
    """
    # Article ids live in 'st_id' in this project's frames; 'id' is kept as a fallback
    id_column = 'st_id' if 'st_id' in df_retrieve.columns else 'id'

    # Flatten: each article may yield a list of events or a single event object
    generated_timeline = []
    for line in timelines.values():
        indiv_timeline = clean_output(line)
        if isinstance(indiv_timeline, list):
            generated_timeline.extend(indiv_timeline)
        else:
            generated_timeline.append(indiv_timeline)
        # end if
    # end for

    timeline = []
    for event in generated_timeline:
        if not isinstance(event, dict):
            continue
        date = event.get('Date')
        # The prompt asks the model for 'NAN' when no date can be determined
        if not isinstance(date, str) or date.lower() == 'nan':
            continue
        try:
            article_number = int(event.pop('Article'))
        except (KeyError, TypeError, ValueError):
            continue
        # A negative iloc index would silently attribute the event to the wrong article
        if not 1 <= article_number <= len(df_retrieve):
            continue
        event['Article_id'] = df_retrieve.iloc[article_number - 1][id_column]
        timeline.append(event)
    # end for

    # Sort on the raw dates first, then trim the parts the LLM was unsure of
    timeline.sort(key=lambda event: event['Date'])
    for event in timeline:
        parts = event['Date'].split('-')
        while len(parts) > 1 and parts[-1].upper() in ('XX', '00'):
            parts.pop()
        event['Date'] = '-'.join(parts)
    # end for
    return timeline
# end def

def format_timeline_date(date_str, formats=['%Y', '%Y-%m-%d', '%Y-%m']):
    """Render a timeline date string in a human-friendly form.

    The formats in ``formats`` are tried in order. Depending on which format parses:
    - '%Y' is shown as '2023';
    - '%Y-%m-%d' is shown as '17 May 2023';
    - '%Y-%m' is shown as 'May 2023'.
    A format that parses but has no display mapping is skipped.

    Args:
        date_str (str): The date string to be formatted.
        formats (list): Candidate ``strptime`` formats, tried in order.

    Returns:
        str: The formatted date, or ``date_str`` unchanged if nothing matched.
    """
    display_formats = {
        '%Y': '%Y',
        '%Y-%m-%d': '%d %B %Y',
        '%Y-%m': '%B %Y',
    }
    for candidate in formats:
        try:
            parsed = datetime.strptime(date_str, candidate)
            target = display_formats.get(candidate)
            if target is not None:
                return parsed.strftime(target)
        except ValueError:
            continue
    # end for
    return date_str
# end def


In [4]:
# Generate the cosine similarity between the embeddings of 2 articles
def get_text_similarity(test_embedding, db_embedding):
    """Compute the cosine similarity between two 1-D embedding vectors.

    Args:
        test_embedding: Sequence/array/tensor of floats for the first text.
        db_embedding: Same, for the second text.
        Either argument may also be a serialized list such as "[0.1, 0.2]" (the form
        stored in the 'embeddings' column). Such strings are parsed with
        ``ast.literal_eval``, never ``eval``.

    Returns:
        torch.Tensor: 0-d float32 tensor holding the similarity.
    """
    if isinstance(test_embedding, str):
        test_embedding = ast.literal_eval(test_embedding)
    if isinstance(db_embedding, str):
        db_embedding = ast.literal_eval(db_embedding)

    # A common dtype avoids dtype-mismatch errors between list and float64 array inputs
    first = torch.as_tensor(test_embedding, dtype=torch.float32)
    second = torch.as_tensor(db_embedding, dtype=torch.float32)
    return nn.CosineSimilarity(dim=0)(first, second)
# end def

# Find the most similar articles (top 20 by default) based on their text embedding
def articles_ranked_by_text(test_article, database, top_n=20):
    """Rank ``database`` articles by cosine similarity to the test article's embedding.

    Args:
        test_article (pd.DataFrame): Single-row frame with a serialized 'embeddings'
            value and (normally) its 'st_id'.
        database (pd.DataFrame): Candidate articles with 'st_id', 'Title', 'Text',
            'Publication_date', 'article_url' and serialized 'embeddings' columns.
        top_n (int): Maximum number of articles to return.

    Returns:
        list[dict]: Up to ``top_n`` dicts with keys 'id', 'Title', 'Text', 'Date',
        'Article_URL' and 'cosine_score', most similar first. The test article itself
        is excluded by id. If ``test_article`` has no 'st_id' column, the top hit is
        assumed to be the test article and dropped (the previous behaviour).
    """
    logger.info("Computing similarities between article texts...")
    test_embedding = ast.literal_eval(test_article['embeddings'].iloc[0])
    test_id = test_article['st_id'].iloc[0] if 'st_id' in test_article.columns else None

    article_collection = []
    for i in trange(len(database)):
        row = database.iloc[i]
        # Excluding by id is safer than popping the top hit, which may be a near-duplicate
        if test_id is not None and row['st_id'] == test_id:
            continue
        article_collection.append({
            'id': row['st_id'],
            'Title': row['Title'],
            'Text': row['Text'],
            'Date': row['Publication_date'],
            'Article_URL': row['article_url'],
            'cosine_score': get_text_similarity(test_embedding, ast.literal_eval(row['embeddings'])),
        })
    # end for

    # Sort by cosine similarity in descending order
    article_collection.sort(key=lambda article: article['cosine_score'], reverse=True)

    if test_id is None:
        article_collection = article_collection[1:]
    # end if
    return article_collection[:top_n]
# end def

# Re-rank candidate articles with a cross encoder and keep the top_k (6 by default)
def re_rank_articles(unique_articles, test_article, cross_encoder_model="cross-encoder/ms-marco-TinyBERT-L-2-v2", top_k=6):
    """Score each candidate against the test article with a cross-encoder and keep the best.

    Args:
        unique_articles (list[dict]): Candidates with a 'Text' key (e.g. from
            ``articles_ranked_by_text``). Each dict is updated in place with a
            'reranked_score' entry.
        test_article (pd.DataFrame): Single-row frame whose 'Text' is the query article.
        cross_encoder_model (str): sentence-transformers cross-encoder checkpoint name.
        top_k (int): Number of articles to keep.

    Returns:
        list[dict]: Up to ``top_k`` candidates, highest cross-encoder score first.
    """
    # Nothing to score: skip loading the model and predicting on an empty batch
    if not unique_articles:
        return []
    # end if

    cross_encoder = CrossEncoder(
        cross_encoder_model, max_length=512, device="cpu"
    )
    test_article_text = test_article['Text'][0]
    # (query, candidate) text pairs for the cross encoder
    text_pairs = [(test_article_text, doc['Text']) for doc in unique_articles]

    similarity_scores = cross_encoder.predict(text_pairs).tolist()
    for article, score in zip(unique_articles, similarity_scores):
        article['reranked_score'] = score
    # end for

    return sorted(unique_articles, key=lambda x: x['reranked_score'], reverse=True)[:top_k]
# end def

# Collect the test article's title/text and the df row labels of its related articles
def get_article_dict(test_article, best_articles, df, min_articles=2):
    """Build the input dict for timeline generation.

    Args:
        test_article (pd.DataFrame): Single-row frame with 'Title' and 'Text'.
        best_articles (list[dict]): Related articles, each with an 'id' (an 'st_id').
        df (pd.DataFrame): Frame to look the ids up in; its row labels are returned.
        min_articles (int): Minimum number of related articles required.

    Returns:
        dict | str: ``{'Title', 'indexes', 'Text'}``, where 'indexes' are the ``df``
        row labels of the related articles in ``df`` order. Returns the sentinel
        string "generate_similar_error" when fewer than ``min_articles`` related
        articles are found. The previous check measured the dict itself (always
        3 keys) and so could never trigger.
    """
    ids = [article['id'] for article in best_articles]
    similar_indexes = df.index[df['st_id'].isin(ids)].tolist()

    if len(similar_indexes) < min_articles:
        print("There are insufficient relevant articles to construct a meaningful timeline. ... Exiting execution now\n")
        return "generate_similar_error"
    # end if

    return {
        'Title': test_article['Title'][0],
        'indexes': similar_indexes,
        'Text': test_article['Text'][0],
    }
# end def

# version 1: decide whether an article warrants a timeline
def det_generate_timeline(input_data,score_threshold=3,llm=default_llm,
    safety_settings = {
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE
    }):
    """Ask the LLM to rate (1-5) whether an article needs a timeline.

    Args:
        input_data: Mapping or single-row DataFrame with 'Title' and 'Text'.
        score_threshold (int): Minimum score for which a timeline is generated.
        llm: Model exposing ``generate_content`` (Gemini by default).
        safety_settings (dict): Passed through to ``generate_content``.

    Returns:
        dict: ``{"det": bool, "score": int, "reason": str | None}``. 'reason' is only
        set when the score is below ``score_threshold``.

    Raises:
        json.JSONDecodeError: If no JSON object can be recovered from the response.
    """

    def first_value(value):
        # A DataFrame column is a Series; formatting it into the prompt would send
        # pandas' repr (index, dtype, long text truncated) instead of the real text.
        return value.iloc[0] if isinstance(value, pd.Series) else value
    # end def

    def clean_llm_output(llm_output):
        text = llm_output.parts[0].text.strip()
        # Strip only a surrounding Markdown fence; a blanket replace('json', '')
        # would also alter the model's own reason text.
        text = re.sub(r'^```(?:json)?\s*|\s*```$', '', text, flags=re.IGNORECASE)
        try:
            return json.loads(text)
        except JSONDecodeError:
            # Fall back to the outermost {...} span when the JSON is wrapped in prose
            match = re.search(r'\{.*\}', text, re.DOTALL)
            if match is None:
                raise
            return json.loads(match.group(0))
    # end def

    # Initialise Pydantic object to force LLM return format
    class Event(BaseModel):
        score: int = Field(description="The need for this article to have a timeline")
        Reason: str = Field(description = "The main reason for your choice why a timeline is needed or why it is not needed")

    # Initialise Json output parser
    output_parser = JsonOutputParser(pydantic_object=Event)

    # Define the template
    template = \
    '''
    You are a highly intelligent AI tasked with analyzing articles to determine whether generating a timeline of events leading up to the key event in the article would be beneficial. 
    Consider the following factors to make your decision:
    1. **Significance of the Event**:
       - Does the event have a significant impact on a large number of people, industries, or countries?
       - Are the potential long-term consequences of the event important?

    2. **Controversy or Debate**:
       - Is the event highly controversial or has it sparked significant debate?
       - Has the event garnered significant media attention and public interest?

    3. **Complexity**:
       - Does the event involve multiple factors, stakeholders, or causes that make it complex?
       - Does the event have deep historical roots or is it the culmination of long-term developments?

    4. **Personal Relevance**:
       - Does the event directly affect the reader or their community?
       - Is the event of particular interest to the reader due to economic implications, political affiliations, or social issues?

    5. Educational Purposes:
       - Would a timeline provide valuable learning or research information?

    Here is the information for the article:
    Title:{title}
    Text: {text}

    Based on the factors above, decide whether generating a timeline of events leading up to the key event in this article would be beneficial. 
    Your answer will include the need for this article to have a timeline with a score 1 - 5, 1 means unnecessary, 5 means necessary. It will also include the main reason for your choice.
    {format_instructions}    
    ANSWER:
    '''

    format_instructions = output_parser.get_format_instructions()

    # Create the prompt template
    prompt = PromptTemplate(
        input_variables   = ["text", "title"],
        partial_variables = {"format_instructions": format_instructions},
        template=template,
    )

    headline = first_value(input_data["Title"])
    body     = first_value(input_data["Text"])

    final_prompt = prompt.format(title=headline, text=body)

    response = llm.generate_content(
        final_prompt,
        safety_settings=safety_settings,
    )
    final_response = clean_llm_output(response)

    # The model may return the score as a string such as "4"
    score = int(final_response['score'])

    # If LLM approves
    if score >= score_threshold:
        logger.info("Timeline is appropriate for this chosen article.\n")
        return {"det": True, "score": score, "reason": None}
    # end if

    reason_text = final_response['Reason']
    logger.info("A timeline for this article is not required. \n")
    # Log the model's reasoning one sentence per line
    for part in reason_text.split(". "):
        logger.info(f"{part}\n")
    # end for

    logger.info("Hence I gave this a required timeline score of " + str(score))
    reason = "A timeline for this article is not required. \n" \
                + "\n" + reason_text + "\n" + "\nHence this timeline received a necessity score of " \
                + str(score) + "\n"
    return {"det": False, "score": score, "reason": reason}
# end def

def generate_and_sort_timeline(similar_articles_dict, df_train, df_test, llm=default_llm):
    """Extract events from each retrieved article with the LLM and merge them into one timeline.

    Args:
        similar_articles_dict (dict): Output of ``get_article_dict``; its 'indexes' are
            row labels of ``df_train`` for the related articles.
        df_train (pd.DataFrame): Candidate article pool ('st_id', 'combined',
            'Publication_date' columns are used).
        df_test (pd.DataFrame): The article(s) the timeline is built for.
        llm: Model exposing ``generate_content`` (Gemini by default).

    Returns:
        tuple: ``(timeline, df_retrieve)``.
        - ``timeline`` is a list of ``{'Date', 'Event', 'Article_id'}`` dicts sorted by
          date. Events with an unknown date or an invalid article number are dropped.
        - ``df_retrieve`` is the frame whose 1-based row positions the LLM's article
          numbers refer to.
    """
    # Every Gemini request below uses these settings (all safety filters set to BLOCK_NONE)
    safety_settings = {
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE
    }

    class Event(BaseModel):
        Date: str = Field(description="The date of the event in YYYY-MM-DD format")
        Event: str = Field(description="A detailed description of the important event")
        Article: int = Field(description="The article number from which the event was extracted")

    output_parser = JsonOutputParser(pydantic_object=Event)
    format_instructions = output_parser.get_format_instructions()

    template = '''
    Given an article, containing a publication date, title, and content, your task is to construct a detailed timeline of events leading up to the main event described in the article.
    Begin by thoroughly analyzing the title, content, and publication date of the article to understand the main event in the article. 
    the dates are represented in YYYY-MM-DD format. Identify events, context, and any time references such as "last week," "last month," or specific dates. 
    The article could contain more or one key events. 
    If the article does not provide a publication date or any events leading up to the main event, return NAN in the Date field, and 0 i the Article Field

    Construct the Timeline:
    Chronological Order: Organize the events chronologically, using the publication dates and time references within the articles.
    Detailed Descriptions: Provide detailed descriptions of each event, explaining how it relates to the main event of the first article.
    Contextual Links: Use information from the articles to link events together logically and coherently.
    Handle Ambiguities: If an article uses ambiguous time references, infer the date based on the publication date of the article and provide a clear rationale for your inference.

    Contextual Links:
    External Influences: Mention any external influences (e.g., global conflicts, economic trends, scientific discoveries) that might have indirectly affected the events.
    Internal Issues: Highlight any internal issues or developments (e.g., political changes, organizational restructuring, societal movements) within the entities involved that might have impacted the events.
    Efforts for Improvement: Note any indications of efforts to improve the situation (e.g., policy changes, strategic initiatives, collaborative projects) despite existing challenges.

    Be as thorough and precise as possible, ensuring the timeline accurately reflects the sequence and context of events leading to the main event.

    Article:
    {text}

    {format_instructions}
    Check and ensure again that the output follows the format instructions above very strictly. 
    '''

    prompt = PromptTemplate(
        input_variables=["text"],
        partial_variables={"format_instructions": format_instructions},
        template=template
    )

    def generate_individual_timeline(date_text_triples):
        # (article number, publication date, article text) for a single article
        article_number, publication_date, article_text = date_text_triples
        article_details = f'Article {article_number}: Publication date: {publication_date} Article Text: {article_text}'
        final_prompt = prompt.format(text=article_details)
        response = llm.generate_content(final_prompt, safety_settings=safety_settings)

        # Accept the reply if it looks like JSON; otherwise retry once
        if '[' in response.parts[0].text or '{' in response.parts[0].text:
            return response.parts[0].text
        # end if
        retry_response = llm.generate_content(final_prompt, safety_settings=safety_settings)
        try:
            return retry_response.parts[0].text
        except ValueError:
            print("ERROR: There were issues with the generation of the timeline. The timeline could not be generated")
            return None
    # end def

    def process_articles(df_train, similar_articles_dict, df_test):
        # Related articles followed by the test article, reversed so the test article is "Article 1"
        df_retrieve = pd.concat([df_train.loc[similar_articles_dict['indexes']], df_test], axis=0).iloc[::-1].reset_index(drop=True)

        # Tuple order must match generate_individual_timeline: (number, date, text)
        date_text_triples = list(zip(range(1, len(df_retrieve) + 1),
                                     df_retrieve['Publication_date'].tolist(),
                                     df_retrieve['combined'].tolist()))

        dict_of_timelines = {}
        # Run the whole batch again once if a worker raises IndexError (e.g. an empty response)
        for attempt in range(2):
            try:
                with ThreadPoolExecutor(max_workers=len(date_text_triples)) as executor:
                    futures = {executor.submit(generate_individual_timeline, triple): i for i, triple in enumerate(date_text_triples)}
                    for future in as_completed(futures):
                        dict_of_timelines[futures[future]] = future.result()
                    # end for
                break
            except IndexError:
                if attempt == 1:
                    raise
            # end try
        # end for
        return dict_of_timelines, df_retrieve
    # end def

    def trim_unknown_date_parts(date):
        # '2023-XX-XX' / '2023-00-00' -> '2023'; '2023-05-XX' / '2023-05-00' -> '2023-05'
        parts = date.split('-')
        while len(parts) > 1 and parts[-1].upper() in ('XX', '00'):
            parts.pop()
        return '-'.join(parts)
    # end def

    timeline_dic, df_retrieve = process_articles(df_train, similar_articles_dict, df_test)
    print("The first timeline has been generated\n")

    # Walk results in article order (not thread-completion order) for a deterministic sort
    generated_timeline = []
    for key in sorted(timeline_dic):
        raw_output = timeline_dic[key]
        if raw_output is None:  # generation failed for this article
            continue
        indiv_timeline = clean_output(raw_output)
        if isinstance(indiv_timeline, list):
            generated_timeline.extend(indiv_timeline)
        else:
            generated_timeline.append(indiv_timeline)
        # end if
    # end for

    timeline = []
    for event in generated_timeline:
        if not isinstance(event, dict):
            continue
        date = event.get('Date')
        # The prompt asks for 'NAN' when no date can be determined
        if not isinstance(date, str) or date.lower() == 'nan':
            continue
        try:
            article_number = int(event.pop('Article'))
        except (KeyError, TypeError, ValueError):
            continue
        # Article 0 / out-of-range numbers would otherwise wrap to the wrong row via negative iloc
        if not 1 <= article_number <= len(df_retrieve):
            continue
        event["Article_id"] = df_retrieve.iloc[article_number - 1].st_id
        timeline.append(event)
    # end for

    timeline.sort(key=lambda x: x['Date'])
    for event in timeline:
        event['Date'] = trim_unknown_date_parts(event['Date'])
    # end for
    return timeline, df_retrieve
# end def

def reduce_by_date(timeline_list, llm=default_llm, link_limit=2):
    '''
    Merge duplicate events that share the same date using the LLM.

    Args:
        timeline_list (list[dict]): Events for one day, each ``{'Event', 'Article_id'}``.
        llm: Model exposing ``generate_content`` (Gemini by default).
        link_limit (int): Maximum number of article ids kept per merged event.

    Returns:
        list[dict]: Events for that day as ``{'Event', 'Article_id'}``. If the model's
        reply is blocked or contains no parseable JSON list, ``timeline_list`` is
        returned unchanged.
    '''
    def extract_content_from_json(string):
        # Use a regular expression to find the content within the first and last square brackets
        match = re.search(r'\[.*\]', string, re.DOTALL)

        if match:
            json_content = match.group(0)
            try:
                # Load the extracted content into a JSON object
                json_data = json.loads(json_content)
                return json_data
            except json.JSONDecodeError as e:
                print("Failed to decode JSON:", e)
                return None
        else:
            print("No valid JSON content found.")
            return None
        # end if
    # end def

    timeline_string = json.dumps(timeline_list)

    class Event(BaseModel):
            Event: str = Field(description="A detailed description of the event")
            Article_id: list = Field(description="The article id(s) from which the events were extracted")

    parser = JsonOutputParser(pydantic_object=Event)

    template = '''You are a news article editor tasked with simplifying a section of a timeline of events on the same day. 
Given this snippet a timeline, filter out duplicated events. 
IF events convey the same overall meaning, I want you to merge these events into one event to avoid redundancy, and add the article ids to a list. 
However, the events are all different, do not combine them, I want you to return it as it is, however, follow the specified format instructions below. 
Furthermore, if the event is not considered to be an event worthy of an audience reading the timeline, do not include it.
Take your time and evaluate the timeline slowly to make your decision.

Timeline snippet:
{text}

{format_instructions}
Ensure that the format follows the example output format strictly before returning the output.'''

    prompt = PromptTemplate(
            input_variables=["text"],
            template=template,
            partial_variables={"format_instructions": parser.get_format_instructions()}
        )

    final_prompt = prompt.format(text=timeline_string)
    response = llm.generate_content(final_prompt,
            safety_settings={
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, 
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE
            }
        )

    try:
        response_text = response.parts[0].text
    except (IndexError, ValueError):
        print("Failed to read the model response; keeping the events unmerged.")
        return timeline_list
    # end try

    data = extract_content_from_json(response_text)
    if data is None:
        # Unparseable reply: fall back to the unmerged events rather than crashing
        return timeline_list
    # end if

    for event_dic in data:
        article_id_list = event_dic.get('Article_id')
        # Only truncate real lists; slicing a bare id string would corrupt the id
        if isinstance(article_id_list, list) and len(article_id_list) > link_limit:
            event_dic['Article_id'] = article_id_list[:link_limit]
        # end if
    # end for
    return data
# end def

# Creating a set of unique days. This will reduce the number of tokens to the model, and makes it easier to handle in the output
def first_timeline_enhancement(timeline):
    '''
    Collapse same-day events of a timeline (list of event dicts) into a flat list.

    Events are bucketed by 'Date' and the dates are visited in sorted order.
    Days with a single event are kept as is; days with several events go through
    ``reduce_by_date``. Every output event is ``{'Date', 'Event', 'Article_id'}``,
    where a bare string 'Article_id' is wrapped in a one-element list.
    '''
    # Bucket events by date, preserving their original order within a day
    events_by_date = {}
    for item in timeline:
        events_by_date.setdefault(item['Date'], []).append(
            {'Event': item['Event'], 'Article_id': item['Article_id']}
        )
    # end for

    enhanced_timeline = []
    for day in sorted(events_by_date):
        day_events = events_by_date[day]
        merged = day_events if len(day_events) == 1 else reduce_by_date(day_events)
        for item in merged:
            ids = item['Article_id']
            enhanced_timeline.append({
                'Date': day,
                'Event': item['Event'],
                'Article_id': [ids] if isinstance(ids, str) else ids,
            })
        # end for
    # end for
    return enhanced_timeline
# end def

def pair_article_urls(enhanced_timeline, df_retrieve):
    """Format event dates and replace article ids with ``{'url', 'title'}`` pairs.

    ``enhanced_timeline`` is modified in place and returned. For each event:
    - 'Date' is reformatted with ``format_timeline_date``;
    - 'Article_id' is removed;
    - 'Article_URL' is added as a list of ``{'url', 'title'}`` dicts.
    Ids that are not present in ``df_retrieve['st_id']`` (e.g. produced by the LLM
    when merging events) are skipped instead of raising IndexError. If an id occurs
    several times, its first row is used.
    """
    # Build id -> url/title lookups once instead of filtering the frame twice per id
    lookup = (df_retrieve.drop_duplicates(subset='st_id', keep='first')
              .set_index('st_id'))
    url_by_id = lookup['article_url'].to_dict()
    title_by_id = lookup['Title'].to_dict()

    for event in enhanced_timeline:
        event['Date'] = format_timeline_date(event['Date'])
        url_title_pairs = []
        for article_id in event.pop('Article_id'):
            if article_id not in url_by_id:
                continue
            url_title_pairs.append({'url': url_by_id[article_id], 'title': title_by_id[article_id]})
        # end for
        event['Article_URL'] = url_title_pairs
    # end for
    return enhanced_timeline
# end def

# Function to get out events that need to be summarised
def get_needed_summaries(timeline):
    """Pick out the events whose description is longer than 20 words.

    Returns:
        dict: ``{position in timeline: event dict}``. The event dicts are the same
        objects as in ``timeline``, not copies.
    """
    word_limit = 20
    return {
        position: event
        for position, event in enumerate(timeline)
        if len(event['Event'].split()) > word_limit
    }
# end def

# Function to summarise the events 
def groq_summariser(events_ls):
    """Summarise event descriptions with the Groq-hosted chat model.

    Args:
        events_ls (list[str]): Event descriptions to summarise.

    Returns:
        list[dict]: ``{'Event', 'Event_Summary'}`` dicts as produced by the model.
        The model may return a different number of items, so callers must check
        the length. An empty input returns ``[]`` without calling the API.
    """
    if not events_ls:
        return []
    # end if

    # Define summarised event output
    class summarized_event(BaseModel):
        Event: str = Field(description="Event in a timeline")
        Event_Summary: str = Field(description="Short Summary of event")

    parser = JsonOutputParser(pydantic_object=summarized_event)

    chat = ChatGroq(temperature=0, model_name=chat_model)

    template = '''
You are a news article editor.
Given a list of events from a timeline, you are tasked to provide a short summary of these series of events. 
For each event, you should return the event, and the summary.

Series of events:
{text}

{format_instructions}
    '''

    prompt = PromptTemplate(
        template=template,
        input_variables=["text"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )

    chain = prompt | chat | parser

    event_str = json.dumps(events_ls)
    result = chain.invoke({"text": event_str})
    if isinstance(result, list):
        return result
    # end if
    if isinstance(result, dict):
        # Already parsed by JsonOutputParser (json.loads on a dict would raise TypeError).
        # It is either a wrapper such as {"events": [...]} or a single summarised event.
        for value in result.values():
            if isinstance(value, list):
                return value
        # end for
        return [result]
    # end if
    # Any other payload (e.g. a JSON string) still needs extraction
    return clean_output(result)
# end def

# function to combine the events that were just summarised
def merge_event_summaries(events_ls, llm_answer, timeline, need_summary_timeline):
    """Attach model summaries to the events that needed them.

    Args:
        events_ls (list[str]): The event texts that were sent for summarisation.
        llm_answer (list[dict]): Model output, one ``{'Event_Summary': ...}`` per event.
        timeline (list[dict]): Full timeline; updated in place and returned.
        need_summary_timeline (dict): ``{timeline position: event}`` from
            ``get_needed_summaries``, in the same order as ``events_ls``.

    Returns:
        list[dict]: ``timeline`` with 'Event_Summary' added to the summarised events.
        Summaries are matched by position, so if the model output length is still
        wrong after one retry, no summaries are attached (rather than attaching them
        to the wrong events).
    """
    if len(llm_answer) != len(need_summary_timeline):
        print("Groq had an error where timeline summary output length not equal to input but trying to resolve")
        llm_answer = groq_summariser(events_ls)
        if len(llm_answer) != len(need_summary_timeline):
            print("Groq summary output is still misaligned; leaving events unsummarised")
            return timeline
        # end if
    # end if

    for (position, event), answer in zip(need_summary_timeline.items(), llm_answer):
        summary = answer.get('Event_Summary') if isinstance(answer, dict) else None
        if summary is not None:
            event['Event_Summary'] = summary
        # end if
        if 0 <= position < len(timeline):
            timeline[position] = event
        # end if
    # end for
    return timeline
# end def

# Combining the functions for the second enhancement (event summary)
def second_timeline_enhancement(timeline):
    """Stage 2 enhancement: add short LLM summaries to the long events of a timeline."""
    events_to_summarise = get_needed_summaries(timeline)
    texts = [entry['Event'] for entry in events_to_summarise.values()]
    summaries = groq_summariser(texts)
    return merge_event_summaries(texts, summaries, timeline, events_to_summarise)
# end def

# Function to run the full timeline generation + enhancement pipeline
def generate_save_timeline(relevant_articles, df_train, df_test):
    """Build the enhanced timeline from re-ranked related articles.

    Args:
        relevant_articles (list[dict]): Related articles, each with an 'id'
            (e.g. the output of ``re_rank_articles``).
        df_train (pd.DataFrame): Candidate pool; the related articles' row labels
            are looked up here.
        df_test (pd.DataFrame): Single-row frame for the article the timeline is about.

    Returns:
        list[dict] | str: The enhanced timeline, or "Error02" when too few related
        articles were found.
    """
    # get_article_dict takes (test_article, best_articles, df). The returned indexes
    # must be df_train labels because generate_and_sort_timeline uses df_train.loc[indexes].
    similar_articles = get_article_dict(df_test, relevant_articles, df_train)
    if similar_articles == "generate_similar_error":
        return "Error02"
    # end if
    generated_timeline, df_retrieve = generate_and_sort_timeline(similar_articles, df_train, df_test)
    print("Proceeding to Stage 1/2 of enhancement...\n")
    first_enhanced_timeline = first_timeline_enhancement(generated_timeline)
    second_enhanced_timeline = pair_article_urls(first_enhanced_timeline, df_retrieve)
    print("Proceeding to Stage 2/2 of enhancement...\n")
    final_timeline = second_timeline_enhancement(second_enhanced_timeline)
    print("Timeline enhanced.. \n")
    return final_timeline
# end def


In [5]:
test_id = "st_1155048"
# Cosine-distance cut-off at which the average-linkage dendrogram is split into clusters
max_d = 0.58

df = load_database()
test_article = df[df['st_id'] == test_id].reset_index(drop=True)
assert not test_article.empty, f"Article {test_id} was not found in the database"
# Candidate pool: every article except the test one (row labels are kept from df)
df_train = filtered_df = df[df["st_id"] != test_id]

timeline_necessary = det_generate_timeline(test_article)
if not timeline_necessary["det"]:
    # The model judged a timeline unnecessary; show its explanation
    print(timeline_necessary["reason"])
else:
    embeddings = get_text_embeddings(df)

    # Pre computed hierarchical clustering
    Z = linkage(embeddings, method='average', metric='cosine')
    cluster_labels = fcluster(Z, max_d, criterion='distance')
    df['Cluster_label'] = cluster_labels
    predicted_cluster = get_predicted_cluster(df, test_id)
    cluster_df = df[df['Cluster_label'] == predicted_cluster].reset_index(drop=True)
    similar_articles_by_text_embedding = articles_ranked_by_text(test_article, cluster_df)
    # Here only top 6 selected
    best_articles = re_rank_articles(similar_articles_by_text_embedding, test_article)
    similar_articles = get_article_dict(test_article, best_articles, df)

    if isinstance(similar_articles, str):
        # get_article_dict returned its error sentinel (it has already printed why)
        final_timeline = None
    else:
        generated_timeline, df_retrieve = generate_and_sort_timeline(similar_articles, df_train, test_article)
        print("Proceeding to Stage 1/2 of enhancement...\n")
        first_enhanced_timeline = first_timeline_enhancement(generated_timeline)
        second_enhanced_timeline = pair_article_urls(first_enhanced_timeline, df_retrieve)
        print("Proceeding to Stage 2/2 of enhancement...\n")
        final_timeline = second_timeline_enhancement(second_enhanced_timeline)
        print("Timeline enhanced.. \n")


Fetching embeddings...



100%|██████████| 180/180 [00:00<00:00, 407.43it/s]


The first timeline has been generated

Proceeding to Stage 1/2 of enhancement...

Proceeding to Stage 2/2 of enhancement...



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Timeline enhanced.. 

