# Algorithm

Given: Document collection $D$

Input: An event $e$

1. Find all documents that talk about $e$; we denote this set as $D_e$
    1. Can be done with classical IR or neural methods
1. For all $d\in D_e$ do
    1. Extract a sequence of major subevents of $e$ as reported in $d$. We'll call them $e_{d,i}, i\in \{0,1,\dots,n\}$
1. Compare all cross-document pairs of subevents for equivalence, i.e. pairs $(e_{d_k,i}, e_{d_l,j})$ with $k\neq l$ (the notation is incomplete)
1. Extract connected components of matched pairs: These are all referring to the same subevent!
1. Subevent canonicalization (well not really, but ideally!): find the overarching name for each found component!
    1. E.g. many subevents go like "In 2003, the US launched their attack on Iraq following [...]". These need to be all grouped under an umbrella term, e.g. "US Invasion of Iraq in 2003"

In [None]:
import os

from huggingface_hub import login
from transformers import AutoModelForCausalLM, GenerationConfig
import torch
import numpy as np
import re
import random
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer
# Let pandas display long text cells (up to 400 characters) before truncating them.
pd.set_option('display.max_colwidth', 400)

# Authenticate against the Hugging Face Hub. The token is read from the
# HUGGINGFACE_TOKEN environment variable; a KeyError is raised if it is unset.
login(token=os.environ["HUGGINGFACE_TOKEN"])

# Maps compact event identifiers to their human-readable event names.
event_names = {"IraqWar":"Iraq War", "CrimeaCrisis":"Crimea Crisis", "CapitolRiot":"Capitol Riot"}


# Model config
***

In [None]:
# Hugging Face identifier of the chat model that is loaded by prepare_model_and_tokenizer.
model_name = "meta-llama/Llama-2-70b-chat-hf"
# dtype passed to from_pretrained when the model is loaded without quantization.
torch_dtype = torch.bfloat16
# When True, the model is loaded with the 4-bit quantization config instead of torch_dtype.
quantization = True

In [None]:
from transformers import BitsAndBytesConfig
import torch

# 4-bit NF4 quantization settings (with double quantization and bfloat16 compute dtype);
# handed to from_pretrained as quantization_config when `quantization` is enabled.
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.bfloat16
)

In [None]:
def prepare_model_and_tokenizer(state):
    """Load the tokenizer and the causal language model named by ``model_name``.

    The tokenizer is configured for left padding and falls back to the EOS
    token as padding token when none is defined. Depending on the module-level
    ``quantization`` flag the model is loaded either with ``nf4_config`` or in
    ``torch_dtype`` precision; only in the latter case is it moved to
    ``state.device``.

    Args:
        state: Object exposing a ``device`` attribute (read only in the
            non-quantized branch) -- presumably the per-worker state supplied
            by the inference context; verify against the caller.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    tok.padding_side = "left"
    if tok.pad_token_id is None:
        # No dedicated padding token: reuse the end-of-sequence token.
        tok.pad_token_id = tok.eos_token_id

    if not quantization:
        full_precision = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype)
        return full_precision.to(state.device), tok

    quantized = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=nf4_config)
    return quantized, tok

def tokenize_sample(sample, model, tokenizer):
    """Turn one prompt into token ids located on the model's device.

    Args:
        sample: Either a ready-made prompt string, or any other object, which
            is passed to ``tokenizer.apply_chat_template`` (i.e. a chat
            conversation).
        model: Object whose ``device`` attribute names the target device.
        tokenizer: Tokenizer used to encode the sample.

    Returns:
        The encoded input ids after ``.to(model.device)``.
    """
    if not isinstance(sample, str):
        # Conversations go through the model's chat template.
        token_ids = tokenizer.apply_chat_template(sample, return_tensors="pt")
    else:
        token_ids = tokenizer(sample, return_tensors="pt")["input_ids"]
    return token_ids.to(model.device)
    
def generate_single_token(batch, model, tokenizer):
    """Generate exactly one new token for the first sample of ``batch``.

    Only ``batch[0]`` is read; any further batch entries are ignored.

    Returns:
        A one-element list holding the decoded text of the last token of the
        generated sequence.
    """
    prompt_ids = tokenize_sample(batch[0], model, tokenizer)
    sequences = model.generate(
        prompt_ids,
        max_new_tokens=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    # The newly generated token is the last position of the first sequence.
    last_token = sequences[0, -1]
    return [tokenizer.decode(last_token)]

def generate_full(batch, model, tokenizer):
    """Generate a completion of up to 1000 new tokens for the first sample of ``batch``.

    Only ``batch[0]`` is read; any further batch entries are ignored.

    Returns:
        A one-element list holding the decoded completion (prompt tokens are
        cut off, special tokens are skipped).
    """
    prompt_ids = tokenize_sample(batch[0], model, tokenizer)
    prompt_length = prompt_ids.shape[1]
    sequences = model.generate(
        prompt_ids,
        max_new_tokens=1000,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Everything past the prompt length is newly generated text.
    completion = sequences[0, prompt_length:]
    return [tokenizer.decode(completion, skip_special_tokens=True)]


# Step 0: Prepare Distributed Inference
***

In [None]:
from distributed_inference import InferenceContext

In [None]:
# Create the inference context and start it with 9 processes. The model loader and the
# initial batch callback are passed in; the callback is swapped later through
# ic.set_on_batch_received.
# NOTE(review): presumably each process loads its own model copy -- confirm in distributed_inference.
ic = InferenceContext()
ic.start(prepare_model_and_tokenizer, generate_single_token, num_processes=9)

In [None]:
# Load the tokenizer in the main process only after the inference context has forked its
# workers; Hugging Face does not like it the other way around.
# This module-level tokenizer is used to build the prompt strings.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

In [None]:
# Prepare the sentence embedding model.
# Note: the module-level name `model` refers to this SentenceTransformer (used for
# embedding sub-event headlines), not to the causal LM used for generation.
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-mpnet-base-v2')

In [None]:
import json
import pandas as pd
import urllib
import time
from os.path import join

def _extract_timeline(skeleton):
    """Return the distinct lines of ``skeleton`` that start with a digit, in first-seen order."""
    numbered_lines = re.findall(r"^\d.*?$", skeleton, re.DOTALL | re.MULTILINE)
    # dict.fromkeys drops duplicates while keeping the original order
    return list(dict.fromkeys(numbered_lines))


def _text_before_quote(text):
    """Return ``text`` up to its first double quote, or all of it when it contains none."""
    return text.partition("\"")[0]


def _safe_dirname(name):
    """Turn a sub-event headline into a single path component (no trailing '.', no '/')."""
    return name.removesuffix(".").replace("/", "-")


def extract_level(viewpoint, subevent=None):
    """Extract major sub-events and their temporal relations for one viewpoint.

    Pipeline (all LLM calls go through the module-level ``ic``):

    1. Load the articles below ``newspaper_articles/<viewpoint>`` and keep
       those whose text has between 1000 and 8000 characters (exclusive).
    2. Ask the LLM which articles are about the event and keep the "Yes" ones.
    3. Let the LLM list the major events per article ("event skeleton") and
       split that list into single sub-events.
    4. Generate a short headline per sub-event; on the top level
       (``subevent is None``) additionally drop sub-events the LLM does not
       place between 2000 and 2012.
    5. Embed the headlines with the module-level sentence model, cluster them
       with HDBSCAN and name each cluster through the LLM.
    6. Repeatedly merge clusters whose names have a cosine similarity above
       0.7 and re-name the merged clusters.
    7. Query "happened after" for every ordered pair of major sub-events and
       write the results to disk.

    Args:
        viewpoint: Name of the sub-folder of ``newspaper_articles/`` to read;
            only articles whose derived ``viewpoint`` column equals this
            value are processed.
        subevent: Optional sub-event name. When given, the event is phrased
            as "<subevent> during the Iraq War", the time filter of step 4 is
            skipped and the output goes to a sub-folder named after it.

    Returns:
        None. The function returns early (after printing a message) when a
        stage leaves too little data to continue.

    Side effects:
        Prints progress information, swaps the batch callback of ``ic`` and
        writes ``relations.csv`` and ``graph.png`` plus one (empty) folder per
        major sub-event below ``extractions/Iraq War/<viewpoint>[/<subevent>]``.
    """
    # Function-scope imports: fail early, before any expensive LLM call is made.
    from itertools import product
    from urllib.parse import urlparse

    import networkx as nx
    from sklearn.cluster import HDBSCAN
    from sklearn.metrics.pairwise import cosine_distances, euclidean_distances

    main_event = "Iraq War"
    event = main_event if subevent is None else subevent + " during the " + main_event
    print(event)

    # Llama prompts get one trailing space after the chat template.
    llama_gap = " " if "llama" in model_name else ""

    def chat_prompt(messages, suffix):
        """Render ``messages`` with the chat template and append ``suffix``."""
        return tokenizer.apply_chat_template(messages, tokenize=False).rstrip() + suffix

    def few_shot_prompt(samples, template, query):
        """Build a few-shot chat prompt: solved ``samples`` followed by the open ``query``."""
        turns = []
        for sample, answer in samples.items():
            turns.append({"role": "user", "content": template.format(example=sample)})
            turns.append({"role": "assistant", "content": f"{answer}"})
        turns.append({"role": "user", "content": template.format(example=query)})
        return chat_prompt(turns, llama_gap)

    def run_step(callback, prompts):
        """Install ``callback`` on the workers and run ``prompts`` through it."""
        ic.set_on_batch_received(callback)
        # The original steps pause after swapping the callback -- presumably so the
        # workers can pick up the new handler before inference starts (TODO confirm).
        time.sleep(1)
        # All batch callbacks only look at batch[0], so batches must hold one prompt.
        return ic.run_inference(prompts, max_batch_size=1)

    def report_by_viewpoint(frame):
        """Print the number of articles in ``frame``, overall and per viewpoint."""
        print(f"{main_event}: {len(frame)}")
        for name, count in frame.groupby("viewpoint").size().to_dict().items():
            print(f"\t{name}: {count}")

    # ------------------------------------------------------------------
    # 1. Load and length-filter the articles
    # ------------------------------------------------------------------
    folder = "newspaper_articles/"
    frames = []
    for file_name in os.listdir(join(folder, viewpoint)):
        if file_name.startswith("."):
            continue
        print(folder, file_name)
        frame = pd.read_json(join(folder, viewpoint, file_name)).T
        # assumes file names carry the viewpoint as third "="-separated field -- TODO confirm
        frame["viewpoint"] = file_name.split("=")[2].split("-")[0]
        # the index is parsed as URL; its host name becomes the source
        frame["source"] = frame.index.map(lambda url: urlparse(url).hostname)
        frames.append(frame)
    articles = pd.concat(frames)

    print("Loaded articles:")
    print(f"\t{main_event}: {len(articles)}")

    articles = articles[articles["text"].apply(len) < 8000]
    print("Filtered out lengthy articles. Remaining:")
    report_by_viewpoint(articles)

    articles = articles[articles["text"].apply(len) > 1000]
    print("Filtered out short articles. Remaining:")
    report_by_viewpoint(articles)

    articles = articles[articles["viewpoint"] == viewpoint]

    # ------------------------------------------------------------------
    # 2. Keep only the articles that are about the event (D_e)
    # ------------------------------------------------------------------
    def conflict_filter_prompt(article):
        messages = [
            {"role": "user", "content": f"Here is a news article. Is it about the {event}? '{article}'\n Answer only with yes or no only if the article is clearly about the {event} (2003-2011)."},
        ]
        return chat_prompt(messages, llama_gap)

    answers = run_step(generate_single_token, [conflict_filter_prompt(a) for a in articles["text"]])

    # these are all articles about e:
    d_e = articles[[x == "Yes" for x in answers]].copy()
    print("Relevant articles", len(d_e))
    if len(d_e) == 0:
        print("No more source to cover!")
        return

    # ------------------------------------------------------------------
    # 3. Extract an event skeleton (list of sub-events) per article
    # ------------------------------------------------------------------
    def skeleton_prompt(article):
        messages = [
            {"role": "user", "content": f"Here is a news article: '{article}'\nList the major events of the {event} in chronological order as reported in the article. Keep it consise and remove everything unrelated."},
        ]
        # pre-fill the beginning of the assistant answer to force a numbered list
        return chat_prompt(messages, f" Here is a numbered list of the major events of the {event} in chronological order as reported in the article:\n\n")

    # full answers are needed here, hence generate_full
    answers = run_step(generate_full, [skeleton_prompt(a) for a in d_e["text"]])

    # the raw model output is kept as "event_skeleton", the parsed sub-events as "timeline"
    d_e["event_skeleton"] = answers
    d_e["timeline"] = d_e["event_skeleton"].apply(_extract_timeline)

    # filter out empty timelines
    d_e = d_e[d_e["timeline"].apply(len) > 0]
    print(len(d_e),"articles,", d_e["timeline"].apply(len).sum(), "subevents")

    if len(d_e) == 0:
        print("No timeline could be extracted, terminating")
        return
    if len(d_e) == 1:
        print("Nothing else to be done! This is the lowest you can go right now:\n")
        # positional access: the index holds URLs, so label 0 does not exist
        for p in d_e["timeline"].iloc[0]:
            print(p)
        return

    # one row per (article, sub-event); the article URL moves from the index to "url"
    unfolded = (
        d_e.explode("timeline")
        .rename(columns={"timeline": "subevent"})
        .reset_index(names="url")
    )

    # ------------------------------------------------------------------
    # 4. Headline per sub-event, optional time filter
    # ------------------------------------------------------------------
    examples = {
        "4. 5 November 2006: The court finds Hussein guilty on the charges and sentences him to death by hanging.":"Hussein sentenced to death (November 2006).",
        "5. 2003-2018: Thousands of American soldiers returned home with PTSD":"Return of American soldiers (2003-2018).",
        "1. Less than a month into the invasion - Baghdad fell to the US-led coalition forces.":"US-led coalition conquers Baghad (not specified).",
        "9. 2003: US-led forces used depleted uranium weapons in civilian-populated areas during the military campaign.":"US-led forces use depleted uranium weapons in populated areas (2003).",
        "2. September 11, 2001 - The terrorist attacks on the World Trade Center in New York City that led to the US invasion of Iraq.":"9/11 attacks (2001).",
        "2. 2003: Then-US Secretary of State Colin Powell presented alleged evidence of Iraqi weapons of mass destruction, including biological ones, at the United Nations Security Council.":"Colin Powell presents false WMD evidence at UN (2003).",
        "4. Radical Islamist group Jemaah Islamiyah convicted of crimes related to the 2002 Bali attack": "Conviction of islamist group Jemaah Islamiyah (not specified).",
        "5. 2014-2016 - The US responded to the rise of Islamic State (IS, formerly ISIS) with airstrikes, resulting in the deaths of more than 3,000 civilians.":"Rise of the Islamic State (IS) (2014-2016)."
    }
    naming_template = "I need a precise headline for this event description: '{example}'."

    answers = run_step(generate_full, [few_shot_prompt(examples, naming_template, x) for x in unfolded["subevent"]])
    unfolded["subevent_short"] = answers

    time_samples = {
        "US and Britain launch Iraq War (2003).": "Yes",
        "Saddam Hussein leads Sunni dictatorship in Iraq, goes to war with Iran (1980s).":"No",
        "Battle of Fallujah (2003).":"Yes",
        "James Mattis leads US Central Command (2010-2013).":"Yes",
        "654,965 excess deaths estimated in Iraq war (2003-2006).":"Yes",
        "Water Shortages Affect Locals":"No",
        "Bilderberg Group discusses Iraq war (2002).":"Yes",
        "US invests heavily in Iraq's economy (2003-2014).":"Yes",
        "US airstrikes kill over 3,000 civilians in response to Islamic State (IS) rise (2014-2016)":"No",
        "US-led coalition fights bloody battles for Fallujah (2004).":"Yes",
        "2018 IAEA inspection finds no evidence of nuclear weapons development at Iran's Turquzabad site.":"No",
        "US struggles to compete with China's Belt and Road initiative (not specified).":"No",
        "US begins air strikes (March 19, 2003).":"Yes",
        "Denmark's PM supports US invasion of Iraq (2002).":"Yes",
        "Poverty in Iraq persists (not specified).":"No",
        "Sunni-Shiite and ethnic tensions escalate (not specified).":"No",
        "Arab Spring influenced by Iraq War protests (2003).":"Yes",
        "Margaret Aldred chairs Iraq senior officials group (2002).":"Yes",
    }
    time_template = "Did '{example}' happen between 2000 and 2012?"

    if subevent is None:
        answers = run_step(generate_single_token, [few_shot_prompt(time_samples, time_template, x) for x in unfolded["subevent_short"]])
        print(np.unique(answers, return_counts=True))
        unfolded = unfolded[[x == "Yes" for x in answers]]
        print("After removing events that happened before or after 2000-2012, we retain",len(unfolded),"subevents")
    else:
        # The answers would be discarded in a recursion step, so the LLM is not queried at all.
        print("Time filtering not active, due to recursion step")

    # articles that contribute a single sub-event cannot form a document graph
    unfolded = unfolded.groupby("url").filter(lambda x: len(x) > 1)
    print("After removing small document graphs:",len(unfolded), "subevents")
    if len(unfolded) == 0:
        print("No subevents left to cluster, terminating")
        return

    # ------------------------------------------------------------------
    # 5. Cluster the headlines and name every cluster
    # ------------------------------------------------------------------
    # `model` is the module-level sentence embedding model
    embeddings = model.encode(unfolded["subevent_short"].to_list(), show_progress_bar=True)
    dists = euclidean_distances(embeddings, embeddings)

    # a tenth of the sub-events, but at least 2 and at most 5
    min_cluster_size = max(2, min(5, len(unfolded) // 10))

    cluster = HDBSCAN(metric="precomputed", cluster_selection_epsilon=0.1, min_cluster_size=min_cluster_size, min_samples=2, algorithm="auto", cluster_selection_method="eom")
    labels = cluster.fit_predict(dists)

    cluster_ids, counts = np.unique(labels, return_counts=True)
    print("Clusters", (cluster_ids, counts))

    # keep clusters with more than two members; label -1 collects the unclustered points
    matched_clusters = cluster_ids[counts > 2]
    components = [unfolded["subevent_short"][labels == x].to_list() for x in matched_clusters if x != -1]
    if len(components) < 1:
        print("Found no cluster of subevents, terminating")
        return

    def canonicalization_prompt(se):
        e = "\n".join(se)
        messages = [
            {
                "role": "user", "content": f"Here are sentences describing a specific process or sub-event during the {event}: "
                f"\n\n{e}\n\n"
                f"I need a short and concise headline that conveys the event described by the sentences. Note that \"{event}\" is too general! If possible, include the date of the event."
            },
        ]
        # the answer is pre-filled up to the opening quote; generation stops at the closing one
        return chat_prompt(messages, " Certainly, a precise headline would be \"")

    # stop generation at EOS or at any vocabulary token that ends with a double quote
    quote_id = [tokenizer.eos_token_id] + [v for k, v in tokenizer.get_vocab().items() if k.endswith("\"")]

    def generate_until(batch, model, tokenizer):
        """Greedy generation for batch[0] that ends at the first closing quote."""
        inputs = tokenize_sample(batch[0], model, tokenizer)
        out = model.generate(inputs, pad_token_id=tokenizer.eos_token_id, max_new_tokens=1000, eos_token_id=quote_id, do_sample=False)
        r = tokenizer.decode(out[0, inputs.shape[1]:], skip_special_tokens=True)
        return [r]

    # at most 10 member sentences are shown to the LLM per cluster
    answers = run_step(generate_until, [canonicalization_prompt(x[:10]) for x in components])
    # cut at the closing quote; answers that ended without one are kept in full
    subevent_names = [_text_before_quote(x) for x in answers]

    # ------------------------------------------------------------------
    # 6. Merge clusters with similar names until the similarities get too low
    # ------------------------------------------------------------------
    # Invariant: subevent_names[i] is the name of components[i].
    while True:
        embeddings = model.encode(subevent_names, show_progress_bar=False)
        sims = 1 - cosine_distances(embeddings, embeddings)
        sims[np.diag_indices_from(sims)] = 0

        # graph whose edges connect names that are similar enough to be merged
        g = nx.from_numpy_array(sims > 0.7)
        connected = [sorted(x) for x in nx.connected_components(g)]

        if not any(len(x) > 1 for x in connected):
            break

        # merge the sub-event lists of every connected group
        merged_components = [sum((components[i] for i in x), []) for x in connected]
        merged_names = [[subevent_names[i] for i in x] for x in connected]

        # only groups with several members need a new name
        large = [x for x in merged_names if len(x) != 1]
        answers = ic.run_inference([canonicalization_prompt(x) for x in large], max_batch_size=1)
        new_names = [_text_before_quote(x) for x in answers]

        for old, new in zip(large, new_names):
            print(f"Derived \"{new}\" from")
            for name in old:
                print("\t", name)

        # Compile the names in the same order as merged_components so that each name
        # stays attached to its own component.
        fresh = iter(new_names)
        subevent_names = [x[0] if len(x) == 1 else next(fresh) for x in merged_names]
        components = merged_components

    major_subevents = pd.DataFrame({"major_subevent":subevent_names, "subevents":components})
    print("Num major subevents:", len(major_subevents))

    if len(major_subevents) < 2:
        print("Too few subevents for event-event relation extraction")
        return

    # ------------------------------------------------------------------
    # 7. Event-event relations
    # ------------------------------------------------------------------
    (yes_id,),(no_id,) = tokenizer(["Yes","No"], add_special_tokens=False, return_attention_mask=False).input_ids

    def generate_probs(batch, model, tokenizer):
        """Return [[P("Yes"), P("No")]] for the next token of batch[0]."""
        inputs = tokenize_sample(batch[0], model, tokenizer)
        out = model.generate(inputs, pad_token_id=tokenizer.eos_token_id, max_new_tokens=1, output_scores=True, return_dict_in_generate=True)
        probs = torch.softmax(out.scores[0][0].cpu(), dim=-1)
        return [probs[[yes_id, no_id]].tolist()]

    def e2e_prompt(se1, se2, relation):
        messages = [
            {"role":"user", "content": f"I need information regarding the {event}. Did '{se1}' {relation} '{se2}'? Answer only with yes or no."}
        ]
        return chat_prompt(messages, " ")

    # every ordered pair of differently named major sub-events
    names = major_subevents["major_subevent"].to_list()
    triples = [x for x in product(names, names, ["happened after"]) if x[0] != x[1]]
    pairs = pd.DataFrame(triples, columns=["left_major","right_major","relation"])

    p = np.array(run_step(generate_probs, [e2e_prompt(*x) for x in triples]))
    pairs["prob_yes"] = p[:,0]
    pairs["prob_no"] = p[:,1]

    G = nx.MultiDiGraph()
    for row in pairs.itertuples(index=False):
        # drop relations the model considers unlikely
        if row.prob_no > 0.25:
            continue
        G.add_edge(row.left_major, row.right_major, label=row.relation, yes_prob=row.prob_yes, no_prob=row.prob_no)

    # for debugging
    a = nx.drawing.nx_agraph.to_agraph(G)

    # create folder for current event
    folder = "extractions/" + main_event + "/" + viewpoint + ("/" + subevent if subevent is not None else "")
    os.makedirs(folder, exist_ok=True)

    # Create one folder per major sub-event. A "/" inside a headline (e.g. "9/11 ...")
    # would otherwise be interpreted as a nested directory.
    for s in major_subevents["major_subevent"]:
        os.makedirs(join(folder, _safe_dirname(s)), exist_ok=True)

    # save relations
    pairs.to_csv(folder + "/relations.csv")

    # save graph
    a.draw(folder + "/graph.png", prog="dot")


In [None]:
# Top-level run: extract the major sub-events of the main event for the "RUS" viewpoint.
extract_level(viewpoint="RUS")

In [None]:
# Collect the non-hidden sub-folders below the RUS extraction folder.
p = "extractions/Iraq War/RUS/"
folders = list(
    filter(
        lambda x: os.path.isdir(p+x) and not x.startswith("."),
        os.listdir(p),
    )
)
folders

In [None]:
# Recursion step: run the extraction once more for every sub-event folder found above,
# passing the folder name as the sub-event name.
for f in folders:
    extract_level(viewpoint="RUS", subevent=f)