## Causal Relations Predictions

This notebook extracts causal relations between events using semantic similarity and the COMET commonsense model.

1. Load the COMET model from Hugging Face.
2. Encode events with SBERT so they can be compared by semantic similarity.
3. Generate possible effects of event A with COMET.
4. Measure the similarity between event B and the COMET-generated effects.
5. Decide whether there is a causal relation based on a similarity threshold.

#### Load Data

In [1]:
import json
import os
from tqdm import tqdm
import pandas as pd

# Load the clustered events. Columns include event_id, event_type, trigger,
# event_summary, arguments, dependencies and cluster_id.
data = pd.read_parquet("../data/cluster_output.parquet")

In [2]:
# Quick look at the first rows of the loaded events.
data.head()

Unnamed: 0,event_id,event_type,trigger,event_summary,arguments,dependencies,cluster_id
0,0,Economic Warning,warn of inflation risks,Federal Reserve officials warned that the Trum...,"{'agent': 'Federal Reserve officials', 'cause'...",[],0
1,1,Trade Policy,slapped tariffs,President Trump announced new tariffs on the t...,"{'agent': 'President Trump', 'cause': 'Trump a...",[],219
2,2,Diplomatic Agreement,suspending the tariffs,Trump agreed to suspend tariffs on Mexico and ...,"{'agent': 'President Trump', 'cause': 'Agreeme...",[{'description': 'Suspensions were a response ...,2
3,3,Economic Analysis,will push up inflation and depress growth,Economists predict the new tariffs will increa...,"{'agent': 'Economists', 'cause': 'Implementati...",[{'description': 'Price increases and growth d...,3
4,4,Monetary Policy,held policy rate steady,The Federal Reserve decided to keep interest r...,"{'agent': 'Federal Reserve', 'cause': 'Uncerta...",[{'description': 'Tariff-induced economic unce...,4


In [3]:
# Number of distinct cluster ids (nunique excludes NaN by default).
data["cluster_id"].nunique()

31996

**Update dependency links to point to cluster IDs instead of event IDs**

In [None]:
# Lookup table event_id -> cluster_id (if an event_id repeats, the last row wins).
event_to_cluster = data.set_index("event_id")["cluster_id"].to_dict()

In [5]:
# update dependencies for a single event
def update_dependencies_to_cluster(dependencies, event_to_cluster):
    """Replace the event_id in each dependency with the target event's cluster_id.

    Args:
        dependencies: Iterable of dependency dicts, each expected to carry an
            ``"event_id"`` key. ``None`` is treated as "no dependencies".
        event_to_cluster: Mapping from event_id to cluster_id.

    Returns:
        A new list of shallow-copied dependency dicts whose ``"event_id"`` now
        holds a cluster_id. Dependencies without an ``"event_id"``, or whose
        target event is not in ``event_to_cluster``, are dropped. The input
        dicts are not modified.
    """
    if dependencies is None:
        return []

    updated_deps = []
    for dep in dependencies:
        # lookup cluster_id of the target event (missing key -> None -> dropped)
        cluster_id = event_to_cluster.get(dep.get("event_id"))
        if cluster_id is None:
            continue
        updated_dep = dep.copy()
        updated_dep["event_id"] = cluster_id
        updated_deps.append(updated_dep)

    return updated_deps

# Apply the event_id -> cluster_id remapping row-wise.
# The original event-level dependencies are preserved in `dependencies_raw`,
# and the cluster-level ones are always derived from that column. Re-running
# this cell therefore does not re-map already-converted cluster_ids as if they
# were event_ids.
if "dependencies_raw" not in data.columns:
    data["dependencies_raw"] = data["dependencies"]
data["dependencies"] = data["dependencies_raw"].apply(
    lambda deps: update_dependencies_to_cluster(deps, event_to_cluster)
)

In [7]:
# deduplicate dependencies
def deduplicate_dependencies(dependencies):
    """Keep only the first dependency for each target cluster_id, preserving order."""
    first_by_target = {}
    for dependency in dependencies:
        # setdefault only stores the first dependency seen for a given target
        first_by_target.setdefault(dependency["event_id"], dependency)
    return list(first_by_target.values())

# Drop repeated links to the same target cluster within each event.
data["dependencies"] = data["dependencies"].map(deduplicate_dependencies)

**Filter to the top 1000 clusters (ranked by number of events)**

In [27]:
# Keep the n most populous clusters (here n = 1000), ranked by number of events.
# value_counts() is sorted in descending order of count.
n = 1000
top_clusters = data["cluster_id"].value_counts().head(n).index.tolist()

In [28]:
# Only retain events that belong to one of the top-n clusters selected above.
data_top = data[data["cluster_id"].isin(top_clusters)].reset_index(drop=True)

In [29]:
# Pick one representative event per cluster: the event whose trigger has the
# most words, with ties broken by the event_summary with the most words.
data_top["trigger_len"] = data_top["trigger"].apply(lambda x: len(x.split()))
data_top["summary_len"] = data_top["event_summary"].apply(lambda x: len(x.split()))
data_top = data_top.sort_values(["cluster_id", "trigger_len", "summary_len"], ascending=[True, False, False])

# remove duplicates: after the sort, the first row of each cluster is its representative
data_top_nodup = data_top.drop_duplicates(subset=["cluster_id"], keep="first", ignore_index=True)

### Extract Cluster-Level Dependencies (based on `dependencies` field)

In [57]:
def extract_cluster_dependency_edges(rep_event_df):
    """Extract dependency edges between clusters using representative events.

    Each row of ``rep_event_df`` is the representative event of one cluster.
    ``cluster_id`` must be unique, and each row must provide ``trigger``,
    ``event_summary`` and (optionally) ``dependencies``. Each dependency is a
    dict whose ``event_id`` already holds the *target cluster_id*, plus
    ``relation_type`` and ``description``.

    Relation mapping:
        * ``RESPONSE_TO``: the edge is flipped (target -> source) and labelled
          ``"causes"``, i.e. the event being responded to is treated as the cause.
        * ``INFLUENCED`` / ``TRIGGERED``: source -> target, ``"causes"``.
        * ``RELATED_TO``: source -> target, ``"related_to"``.
        * Any other relation type is ignored.

    Skipped: self-loops, dependencies with no target id, ``None`` dependency
    lists, and targets whose cluster has no row in ``rep_event_df``.

    Returns:
        List of edge dicts with keys source_cluster_id, target_cluster_id,
        relation_type, description, source_trigger, source_summary,
        target_trigger, target_summary.
    """
    causal_types = {"INFLUENCED", "TRIGGERED"}
    forward_types = causal_types | {"RELATED_TO"}
    dependency_edges = []

    # Precompute lookup for efficiency
    cluster_to_event = rep_event_df.set_index("cluster_id").to_dict('index')

    for _, row in rep_event_df.iterrows():
        source_cluster_id = row["cluster_id"]
        source_trigger = row["trigger"]
        source_summary = row["event_summary"]

        # Series.get returns the stored value even when it is None, so a
        # default argument alone does not protect the loop below.
        deps = row.get("dependencies")
        if deps is None:
            deps = []

        for dep in deps:
            target_cluster_id = dep.get("event_id")
            relation_type = dep.get("relation_type")
            description = dep.get("description")

            # Skip if invalid or self-loop
            if target_cluster_id is None or source_cluster_id == target_cluster_id:
                continue

            target_row = cluster_to_event.get(target_cluster_id)
            if target_row is None:
                continue  # target cluster has no representative event here

            target_trigger = target_row["trigger"]
            target_summary = target_row["event_summary"]

            if relation_type == "RESPONSE_TO":
                # Flip the source and target: the responded-to event is the cause
                dependency_edges.append({
                    "source_cluster_id": target_cluster_id,
                    "target_cluster_id": source_cluster_id,
                    "relation_type": "causes",
                    "description": description,
                    "source_trigger": target_trigger,
                    "source_summary": target_summary,
                    "target_trigger": source_trigger,
                    "target_summary": source_summary
                })
            elif relation_type in forward_types:
                dependency_edges.append({
                    "source_cluster_id": source_cluster_id,
                    "target_cluster_id": target_cluster_id,
                    "relation_type": "causes" if relation_type in causal_types else "related_to",
                    "description": description,
                    "source_trigger": source_trigger,
                    "source_summary": source_summary,
                    "target_trigger": target_trigger,
                    "target_summary": target_summary
                })

    print(f"Extracted {len(dependency_edges)} cluster-level dependency edges.")
    return dependency_edges

In [58]:
# Cluster-level edges taken directly from the annotated `dependencies` field.
cluster_dependencies_edges = extract_cluster_dependency_edges(rep_event_df=data_top_nodup)

Extracted 82 cluster-level dependency edges.


### Load Model

In [39]:
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Fix the RNG seed: COMET decoding below uses sampling (do_sample=True), so
# without a seed the generated effects (and hence the inferred links) change
# from run to run.
SEED = 42
torch.manual_seed(SEED)

# Load COMET model
comet_model_name = "mismayil/comet-bart-ai2"
comet_tokenizer = AutoTokenizer.from_pretrained(comet_model_name)
comet_model = AutoModelForSeq2SeqLM.from_pretrained(comet_model_name)

# Load SBERT for semantic similarity
sbert_model = SentenceTransformer("all-mpnet-base-v2")

**Build Fast Lookup for Events**


In [40]:
# cluster_id -> {column: value} for that cluster's representative event row
cluster_to_event = data_top_nodup.set_index(keys="cluster_id").to_dict(orient="index")

In [61]:
# Directed (source, target) cluster pairs already covered by dependency edges;
# the inference step below skips these pairs.
existing_links = {
    (edge['source_cluster_id'], edge['target_cluster_id'])
    for edge in cluster_dependencies_edges
}

In [62]:
def generate_comet_relations(
    event_text,
    num_return_sequences=3,
    max_length=50,
    temperature=0.5,
    top_k=50,
    top_p=0.95,
    relations=('xEffect', 'isAfter'),
):
    """Sample COMET completions for ``event_text`` under each relation.

    The prompt for a relation is ``"<event_text> <relation>"``.

    Args:
        event_text: Head event text fed to COMET.
        num_return_sequences: Number of sampled completions per relation.
        max_length: Maximum generated sequence length passed to ``generate``.
        temperature, top_k, top_p: Sampling parameters passed to ``generate``.
        relations: COMET relation names to query. Callers in this notebook read
            the ``'xEffect'`` and ``'isAfter'`` keys, so keep them when overriding.

    Returns:
        Dict mapping each relation name to a list of decoded, stripped strings.
    """
    relation_results = {}
    for relation in relations:
        prompt = f"{event_text} {relation}"
        # Move inputs to the model's device so this keeps working if the model
        # is placed on a GPU.
        inputs = comet_tokenizer([prompt], return_tensors='pt').to(comet_model.device)
        outputs = comet_model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )
        relation_results[relation] = [
            comet_tokenizer.decode(output, skip_special_tokens=True).strip()
            for output in outputs
        ]
    return relation_results

# SBERT Similarity
def sbert_similarity(text1, text2):
    """Return the cosine similarity of the SBERT embeddings of two texts as a float."""
    first, second = (sbert_model.encode(text, convert_to_tensor=True) for text in (text1, text2))
    score = util.cos_sim(first, second)
    return score.item()

def infer_cluster_links_with_comet_sbert(cluster_to_event, existing_links, similarity_threshold=0.7):
    """Link cluster representative events using COMET + SBERT, skipping existing links.

    For every ordered pair (A, B) of distinct clusters not already linked in
    either direction, the COMET ``xEffect`` / ``isAfter`` outputs generated
    from A's summary are compared with B's summary by SBERT cosine similarity.
    If the best score is >= ``similarity_threshold``, an edge A -> B is emitted
    with relation ``"causes"`` (xEffect) or ``"happens_before"`` (isAfter).

    COMET is sampled once per source cluster, because its output does not
    depend on the target, and every text is embedded once. Generation cost is
    therefore O(n) rather than O(n^2) COMET calls, and all targets of a source
    are scored against the same sampled effects.

    Args:
        cluster_to_event: cluster_id -> representative event dict; each value
            needs 'trigger' and 'event_summary'.
        existing_links: Set of (source_cluster_id, target_cluster_id) pairs to skip.
        similarity_threshold: Minimum cosine similarity to emit an edge.

    Returns:
        List of inferred edge dicts.
    """
    cluster_ids = list(cluster_to_event.keys())
    n = len(cluster_ids)
    inferred_edges = []

    if n > 0:
        # Embed every cluster's summary once: shape (n, dim)
        summary_embeddings = sbert_model.encode(
            [cluster_to_event[cid]['event_summary'] for cid in cluster_ids],
            convert_to_tensor=True,
        )

    for i in tqdm(range(n), desc="Linking clusters with COMET + SBERT"):
        source_cluster_id = cluster_ids[i]
        event_a = cluster_to_event[source_cluster_id]

        # Run COMET on the source event once. Candidates keep the order
        # (all xEffect, then all isAfter) so argmax ties resolve to the first one.
        comet_relations = generate_comet_relations(event_a['event_summary'])
        candidates = [
            (rel_type, effect)
            for rel_type in ['xEffect', 'isAfter']
            for effect in comet_relations[rel_type]
        ]
        sim_matrix = None
        if candidates:
            effect_embeddings = sbert_model.encode(
                [effect for _, effect in candidates], convert_to_tensor=True
            )
            # (num_candidates, n): similarity of every effect to every cluster summary
            sim_matrix = util.cos_sim(effect_embeddings, summary_embeddings)

        for j in range(n):
            if i == j:
                continue  # Skip self-pairs

            target_cluster_id = cluster_ids[j]

            # Skip if already linked (in either direction)
            if (source_cluster_id, target_cluster_id) in existing_links or (target_cluster_id, source_cluster_id) in existing_links:
                continue

            event_b = cluster_to_event[target_cluster_id]

            # -------- Select best match -------- #
            if sim_matrix is not None:
                column = sim_matrix[:, j]
                best_k = int(torch.argmax(column))
                comet_effects_score = column[best_k].item()
                best_comet_type, best_comet_effect = candidates[best_k]
            else:
                comet_effects_score, best_comet_type, best_comet_effect = 0.0, None, None

            # -------- Save link if above threshold -------- #
            if comet_effects_score >= similarity_threshold:
                relation_type = "causes" if best_comet_type == 'xEffect' else "happens_before"
                explanation = (
                    f"Event '{event_a['trigger']}' {relation_type} Event '{event_b['trigger']}' "
                    f"via COMET {best_comet_type}: '{best_comet_effect}'."
                )

                inferred_edges.append({
                    "source_cluster_id": source_cluster_id,
                    "target_cluster_id": target_cluster_id,
                    "relation_type": relation_type,
                    "confidence_score": comet_effects_score,
                    "explanation": explanation,
                    "source_trigger": event_a['trigger'],
                    "source_summary": event_a['event_summary'],
                    "target_trigger": event_b['trigger'],
                    "target_summary": event_b['event_summary'],
                    "evidence_type": f"comet_{best_comet_type}"
                })

    print(f"Inferred {len(inferred_edges)} new relations using COMET + SBERT.")
    return inferred_edges

In [None]:
def precompute_comet_sbert_embeddings(rep_event_df):
    """Run COMET once per representative event and cache the SBERT embeddings.

    Returns a dict keyed by cluster_id. Each value holds:
        - the raw COMET outputs per relation;
        - their SBERT embeddings;
        - the SBERT embedding of the event summary;
        - the event's trigger and summary text.
    All embeddings are stored as nested Python lists (moved to CPU first).
    """
    def embed(texts):
        # SBERT tensor -> CPU -> plain nested lists
        return sbert_model.encode(texts, convert_to_tensor=True).cpu().tolist()

    cache = {}
    rows = tqdm(rep_event_df.iterrows(), total=len(rep_event_df), desc="Precomputing COMET+SBERT")
    for _, row in rows:
        summary = row['event_summary']
        relations = generate_comet_relations(summary)
        cache[row['cluster_id']] = {
            'comet_relations': relations,
            'comet_embeddings': {rel: embed(effects) for rel, effects in relations.items()},
            'summary_embedding': embed(summary),
            'trigger': row['trigger'],
            'event_summary': summary
        }

    print(f"Precomputed COMET and SBERT embeddings for {len(rep_event_df)} events.")
    return cache

def save_comet_sbert_cache(cache, save_path='../data/cluster/comet_sbert_cache.jsonl'):
    """Write the COMET+SBERT cache as JSON Lines, one ``{cluster_id: entry}`` object per line.

    The parent directory of ``save_path`` is created if it is missing. Cluster
    ids become JSON object keys (strings); read them back with ``int(...)``.
    """
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as f:
        for cluster_id, entry in cache.items():
            # str() so numpy integer ids serialize too (json rejects np.int64 keys)
            json.dump({str(cluster_id): entry}, f)
            f.write('\n')
    print(f"Saved COMET + SBERT cache to {save_path}")

In [48]:
# Precompute and save COMET+SBERT embeddings for every representative event
# (one generate_comet_relations call per event; ~28 min in the recorded run).
cache = precompute_comet_sbert_embeddings(rep_event_df=data_top_nodup)
save_comet_sbert_cache(cache=cache, save_path='../data/cluster/comet_sbert_cache.jsonl')

Precomputing COMET+SBERT: 100%|██████████| 1000/1000 [28:10<00:00,  1.69s/it]


Precomputed COMET and SBERT embeddings for 1000 events.
Saved COMET + SBERT cache to ../data/cluster/comet_sbert_cache.jsonl


**Load the Cached Data**

In [49]:
def load_comet_sbert_cache(load_path='../data/cluster/comet_sbert_cache.jsonl'):
    """Read a JSON Lines COMET+SBERT cache back into memory.

    Each line is a ``{cluster_id: entry}`` object. Only its first key is used;
    that key is converted to ``int`` and the stored embeddings become
    ``torch`` tensors.
    """
    cache = {}
    with open(load_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            key, entry = next(iter(json.loads(raw_line).items()))
            cache[int(key)] = {
                'comet_relations': entry['comet_relations'],
                'comet_embeddings': {
                    rel: torch.tensor(vectors)
                    for rel, vectors in entry['comet_embeddings'].items()
                },
                'summary_embedding': torch.tensor(entry['summary_embedding']),
                'trigger': entry['trigger'],
                'event_summary': entry['event_summary']
            }
    print(f"Loaded COMET + SBERT cache for {len(cache)} clusters.")
    return cache

In [64]:
# Reload the cache from disk (embeddings come back as torch tensors).
cache_loaded = load_comet_sbert_cache(load_path='../data/cluster/comet_sbert_cache.jsonl')

Loaded COMET + SBERT cache for 1000 clusters.


**Pairwise comparison to infer relations**

In [None]:
def infer_cluster_links_with_cached_embeddings(cache, existing_links, similarity_threshold=0.7):
    """Pairwise linking using precomputed COMET + SBERT embeddings.

    For each ordered pair (source, target) of distinct clusters not already
    linked in either direction, the target's summary embedding is compared by
    cosine similarity with every cached COMET effect embedding of the source.
    The best xEffect / isAfter match is kept. If its score reaches
    ``similarity_threshold``, an edge source -> target is emitted with
    relation ``"causes"`` (xEffect) or ``"happens_before"`` (isAfter).

    Similarities are computed one source at a time as an (effects x clusters)
    matrix against all target summaries at once, instead of one cos_sim call
    per pair.

    Args:
        cache: cluster_id -> entry with 'comet_relations', 'comet_embeddings'
            (per relation), 'summary_embedding', 'trigger' and 'event_summary'.
            Embeddings may be tensors (loaded cache) or nested lists (freshly
            precomputed cache).
        existing_links: Set of (source_cluster_id, target_cluster_id) pairs to skip.
        similarity_threshold: Minimum cosine similarity to emit an edge.

    Returns:
        List of inferred edge dicts.
    """
    cluster_ids = list(cache.keys())
    n = len(cluster_ids)
    inferred_edges = []
    relation_names = {'xEffect': 'causes', 'isAfter': 'happens_before'}

    if n > 0:
        # (n, dim) matrix with every cluster's summary embedding
        summary_matrix = torch.stack([
            torch.as_tensor(cache[cid]['summary_embedding'], dtype=torch.float32)
            for cid in cluster_ids
        ])

    for i in tqdm(range(n), desc="Linking clusters using cached embeddings"):
        source_cluster_id = cluster_ids[i]
        source_data = cache[source_cluster_id]

        # Per relation: best similarity and index of the best effect, for every target.
        per_relation = []
        for rel_type in ['xEffect', 'isAfter']:
            effect_embeddings = torch.as_tensor(
                source_data['comet_embeddings'][rel_type], dtype=torch.float32
            )
            if effect_embeddings.numel() == 0:
                continue
            similarities = util.cos_sim(effect_embeddings, summary_matrix)  # (num_effects, n)
            best_sims, best_idxs = similarities.max(dim=0)
            per_relation.append((rel_type, best_sims.tolist(), best_idxs.tolist()))

        for j in range(n):
            if i == j:
                continue  # Skip self-pair

            target_cluster_id = cluster_ids[j]

            # Skip existing dependency links (in either direction)
            if (source_cluster_id, target_cluster_id) in existing_links or (target_cluster_id, source_cluster_id) in existing_links:
                continue

            best_score, best_relation, best_effect = 0.0, None, None
            for rel_type, best_sims, best_idxs in per_relation:
                if best_sims[j] > best_score:
                    best_score = best_sims[j]
                    best_relation = relation_names[rel_type]
                    best_effect = source_data['comet_relations'][rel_type][best_idxs[j]]

            if best_score >= similarity_threshold:
                target_data = cache[target_cluster_id]
                explanation = (
                    f"Event '{source_data['trigger']}' {best_relation} Event '{target_data['trigger']}' "
                    f"via COMET: '{best_effect}'"
                )
                inferred_edges.append({
                    "source_cluster_id": source_cluster_id,
                    "target_cluster_id": target_cluster_id,
                    "relation_type": best_relation,
                    "confidence_score": best_score,
                    "explanation": explanation,
                    "source_trigger": source_data['trigger'],
                    "source_summary": source_data['event_summary'],
                    "target_trigger": target_data['trigger'],
                    "target_summary": target_data['event_summary'],
                    # NOTE: uses the mapped relation name (e.g. "comet_causes"),
                    # whereas infer_cluster_links_with_comet_sbert uses the raw
                    # COMET relation (e.g. "comet_xEffect").
                    "evidence_type": f"comet_{best_relation}"
                })

    print(f"Inferred {len(inferred_edges)} new cluster-level links.")
    return inferred_edges

In [66]:
# Score every ordered cluster pair not already covered by a dependency edge.
comet_sbert_edges = infer_cluster_links_with_cached_embeddings(
    cache=cache_loaded,
    existing_links=existing_links,
    similarity_threshold=0.7,
)

Linking clusters using cached embeddings: 100%|██████████| 1000/1000 [01:08<00:00, 14.60it/s]

Inferred 359 new cluster-level links.





In [69]:
# Dependency edges first, followed by the COMET/SBERT-inferred edges.
combined_edges = [*cluster_dependencies_edges, *comet_sbert_edges]

print(
    f"Combined {len(cluster_dependencies_edges)} dependency edges "
    f"and {len(comet_sbert_edges)} COMET/SBERT edges."
)
print(f"Total combined edges: {len(combined_edges)}")

Combined 82 dependency edges and 359 COMET/SBERT edges.
Total combined edges: 441


In [70]:
def save_combined_edges(edges, save_path='../data/cluster/combined_cluster_edges.jsonl'):
    """Write ``edges`` to ``save_path`` as JSON Lines (one edge dict per line).

    The parent directory of ``save_path`` is created if it does not exist yet.
    """
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as f:
        for edge in edges:
            json.dump(edge, f)
            f.write('\n')
    print(f"Saved {len(edges)} combined edges to {save_path}")

In [71]:
# Persist the combined edges as JSON Lines.
save_combined_edges(edges=combined_edges, save_path='../data/cluster/combined_cluster_edges.jsonl')

Saved 441 combined edges to ../data/cluster/combined_cluster_edges.jsonl


In [72]:
# Inspect the first 10 COMET/SBERT-inferred edges.
comet_sbert_edges[:10]

[{'source_cluster_id': 85,
  'target_cluster_id': 198,
  'relation_type': 'happens_before',
  'confidence_score': 0.7020261287689209,
  'explanation': "Event 'sweeping tariffs on goods' happens_before Event 'imposed tariffs' via COMET: 'US President Donald Trump orders tariffs on products from Canada and China'",
  'source_trigger': 'sweeping tariffs on goods',
  'source_summary': 'US President Donald Trump ordered sweeping tariffs on goods from Mexico, Canada, and China to control illegal immigration and fentanyl flow into the US.',
  'target_trigger': 'imposed tariffs',
  'target_summary': 'US President Donald Trump imposed tariffs on China and has considered similar measures against Taiwanese semiconductors.',
  'evidence_type': 'comet_happens_before'},
 {'source_cluster_id': 85,
  'target_cluster_id': 272,
  'relation_type': 'happens_before',
  'confidence_score': 0.7403940558433533,
  'explanation': "Event 'sweeping tariffs on goods' happens_before Event 'threatened to slap a 60 p