# Model 02: Fast evidence shortlisting by information extraction

This is actually not as good as 02a

In [3]:
# Change the working directory to project root
from pathlib import Path
import os
ROOT_DIR = Path.cwd()
while not ROOT_DIR.joinpath("src").exists():
    ROOT_DIR = ROOT_DIR.parent
os.chdir(ROOT_DIR)

In [4]:
# Dependencies
import json
from sklearn.model_selection import ParameterGrid
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
from typing import List, Dict, DefaultDict
from dataclasses import dataclass
from tqdm import tqdm
from multiprocessing.pool import ThreadPool as Pool
from math import floor

from src.data import load_as_dataframe, slice_by_claim, SetEncoder
from src.normalize import normalize_pipeline
from src.ner import \
    train_noun_relations, \
    get_evidence_by_noun, \
    retrieve_claim_evidence_by_noun, \
    view_claim_noun_phrases

In [5]:
import spacy
from spacy import displacy
nlp = spacy.load("en_core_web_sm")
nlp.pipe_names

  from .autonotebook import tqdm as notebook_tqdm


['tok2vec', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer', 'ner']

In [6]:
test = nlp("there")

In [7]:
for t in test:
    print(t.pos_)

ADV


# Load datasets

In [8]:
data_names = ["train-claims", "dev-claims", "evidence"]
train_claims, dev_claims, all_evidences \
    = load_as_dataframe(data_names, full_evidence=True)

Loaded train-claims
Loaded dev-claims
Loaded evidence


## Visualise dependencies

In [9]:
pairs = dev_claims.loc["claim-1896"].reset_index()
pair_index = 0
# pairs = train_claims.loc["claim-126"].reset_index()
# pair_index = 1
# pairs = train_claims.loc["claim-1937"].reset_index()
# pair_index = 2
# pairs = train_claims.loc["claim-2510"].reset_index()
# pair_index = 0
# pairs = dev_claims.loc["claim-2062"].reset_index()
# pair_index = 0
# pairs = dev_claims.loc["claim-139"].reset_index()
# pair_index = 0
# pairs = dev_claims.loc["claim-752"].reset_index()
# pair_index = 0
# pairs = dev_claims.loc["claim-506"].reset_index()
# pair_index = 4
# pairs = dev_claims.loc["claim-1975"].reset_index()
# pair_index = 0
pairs

Unnamed: 0,claim_text,claim_label,evidences,evidence_text
0,Greg Hunt CSIRO research shows carbon emission...,NOT_ENOUGH_INFO,evidence-798121,To further reduce U.S. carbon dioxide emission...
1,Greg Hunt CSIRO research shows carbon emission...,NOT_ENOUGH_INFO,evidence-305800,Almost 20 percent (8 GtCO2/year) of total gree...
2,Greg Hunt CSIRO research shows carbon emission...,NOT_ENOUGH_INFO,evidence-108068,It is estimated that increasing the carbon con...
3,Greg Hunt CSIRO research shows carbon emission...,NOT_ENOUGH_INFO,evidence-234076,There is a large potential for future reductio...
4,Greg Hunt CSIRO research shows carbon emission...,NOT_ENOUGH_INFO,evidence-1111240,"""Ireland to Plant 440 Million Trees in 20 Year..."


In [10]:
claim_doc = nlp(normalize_pipeline(pairs.iloc[pair_index]["claim_text"]))
evidence_doc = nlp(normalize_pipeline(pairs.iloc[pair_index]["evidence_text"]))
displacy.render(claim_doc, style="dep")
displacy.render(evidence_doc, style="dep")

## Info Tag extraction

In [11]:
@dataclass
class InfoTag:
    tag:str
    verb_pos:int = 0

In [12]:
def get_info_tags(doc, go_nouns:List[str] = []) -> List[InfoTag]:
    
    info_tags = list()
    seen = list()
    seen_tags = list()
    
    def add_tag(tag_txt:str, info_tags=info_tags, seen_tags=seen_tags):
        if tag_txt not in seen_tags:
            info_tags += [InfoTag(tag = tag_txt, verb_pos=verb_pos)]
            seen_tags += [tag_txt]
        return
    
    verb_pos = 0
    for token in (doc):
        
        # Skip if we have seen the token before
        if token.lemma_ in seen:
            continue
        
        # Increment the relative verb position
        if token.pos_ in ["VERB"]:
            verb_pos += 1
        
        # Include it if it is in the whitelist of go_nouns
        if token.lemma_ in go_nouns:
            # info_tags += [InfoTag(tag=token.lemma_, verb_pos=verb_pos)]
            tag_txt = token.lemma_
            add_tag(tag_txt)
        
        # Include it if it is a (proper)noun
        if token.pos_ in ["PROPN", "NOUN"]:
            # info_tags += [InfoTag(tag=token.lemma_, verb_pos=verb_pos)]
            tag_txt = token.lemma_
            add_tag(tag_txt)
        
        # Merge proper noun compounds
        if token.pos_ in ["PROPN"]:
            tag = [token.lemma_]
            seen += [token.lemma_]
            this_token = token
            while this_token.dep_ in ["compound"]:
                this_token = this_token.head
                tag += [this_token.lemma_]
                seen += [this_token.lemma_]
            tag_txt = " ".join(tag)
            add_tag(tag_txt)
            # continue
        
        # Merge noun compounds
        if (token.pos_ in ["NOUN"]
            and token.dep_ in ["compound"]
            and token.head.pos_ == "NOUN"
        ):
            tag = [token.lemma_, token.head.lemma_]
            seen += tag
            # info_tags += [InfoTag(tag = " ".join(tag), verb_pos=verb_pos)]
            tag_txt = " ".join(tag)
            add_tag(tag_txt)
            # continue
        
        # Add possessive or subject nouns/pronouns
        # if (
        #     token.pos_ in ["NOUN"]
        #     and token.dep_ in ["poss", "nsubj", "expl"]
        # ):
        #     tag = [token.lemma_]
        #     seen += tag
        #     info_tags += [InfoTag(tag = " ".join(tag), verb_pos=verb_pos)]
            # continue
        
        # Add nouns with determiners
        # if (
        #     token.pos_ in ["DET"]
        #     and token.dep_ in ["det"]
        #     and token.head.pos_ in ["NOUN"]
        # ):
        #     tag = [token.head.lemma_]
        #     seen += tag
        #     info_tags += [InfoTag(tag = " ".join(tag), verb_pos=verb_pos)]
            # continue
        
        # Add root nouns
        # if (
        #     token.pos_ in ["NOUN"]
        #     and token.dep_ in ["ROOT"]
        # ):
        #     tag = [token.lemma_]
        #     seen += tag
        #     info_tags += [InfoTag(tag = " ".join(tag), verb_pos=verb_pos)]
            # continue
        
        # Merge Nouns with adjective modifiers
        if (
            token.pos_ in ["ADJ"]
            and token.dep_ in ["amod"]
            and token.head.pos_ in ["NOUN"]
            and token.head.dep_ not in ["compound"]
        ):
            tag = [token.lemma_, token.head.lemma_]
            seen += tag
            # info_tags += [InfoTag(tag = " ".join(tag), verb_pos=verb_pos)]
            tag_txt = " ".join(tag)
            add_tag(tag_txt)
            # continue
        
        # Adjectives linked to verbs and nouns
        if (
            token.pos_ in ["ADJ"]
            and token.head.pos_ in ["VERB", "NOUN"]
        ):
            tag = [token.lemma_]
            seen += tag
            # info_tags += [InfoTag(tag = " ".join(tag), verb_pos=verb_pos)]
            tag_txt = " ".join(tag)
            add_tag(tag_txt)
            # continue
        
        continue

    return info_tags


In [13]:
go_nouns = [
    'warming',
    'climate',
    'temperature',
    'CO2',
    'level',
    'ice',
    'change',
    'sea',
    'carbon',
    'scientist',
    'absorb'
]

In [14]:
claim_tags = get_info_tags(claim_doc, go_nouns=go_nouns)
claim_tags

[InfoTag(tag='greg', verb_pos=0),
 InfoTag(tag='greg research', verb_pos=0),
 InfoTag(tag='hunt', verb_pos=0),
 InfoTag(tag='hunt research', verb_pos=0),
 InfoTag(tag='csiro', verb_pos=0),
 InfoTag(tag='csiro research', verb_pos=0),
 InfoTag(tag='carbon', verb_pos=1),
 InfoTag(tag='carbon emission', verb_pos=1),
 InfoTag(tag='cent', verb_pos=2),
 InfoTag(tag='year', verb_pos=2),
 InfoTag(tag='nature', verb_pos=3),
 InfoTag(tag='nature soil', verb_pos=3),
 InfoTag(tag='tree', verb_pos=3)]

In [15]:
evidence_tags = get_info_tags(evidence_doc, go_nouns=go_nouns)
evidence_tags

[InfoTag(tag='u', verb_pos=1),
 InfoTag(tag='carbon', verb_pos=1),
 InfoTag(tag='carbon dioxide', verb_pos=1),
 InfoTag(tag='emission', verb_pos=1),
 InfoTag(tag='percent', verb_pos=1),
 InfoTag(tag='kyoto', verb_pos=2),
 InfoTag(tag='kyoto protocol', verb_pos=2),
 InfoTag(tag='planting', verb_pos=3),
 InfoTag(tag='area', verb_pos=3),
 InfoTag(tag='size', verb_pos=3),
 InfoTag(tag='texas', verb_pos=3),
 InfoTag(tag='brazil', verb_pos=3),
 InfoTag(tag='year', verb_pos=3)]

## Map subject relations

In [16]:
from collections import defaultdict
def map_subject_relation(
    claim_tags:InfoTag,
    evidence_tags:InfoTag,
    verb_pos_threshold:int=0
) -> defaultdict:
    subj_relations = defaultdict(list)
    for c_tag in claim_tags:
        # if c_tag.verb_pos > verb_pos_threshold:
        #     continue
        for e_tag in evidence_tags:
            # if e_tag.verb_pos > verb_pos_threshold:
            #     continue
            subj_relations[c_tag.tag].append(e_tag.tag)
            subj_relations[e_tag.tag].append(c_tag.tag)
    return subj_relations
    

In [17]:
map_subject_relation(claim_tags, evidence_tags)

defaultdict(list,
            {'greg': ['u',
              'carbon',
              'carbon dioxide',
              'emission',
              'percent',
              'kyoto',
              'kyoto protocol',
              'planting',
              'area',
              'size',
              'texas',
              'brazil',
              'year'],
             'u': ['greg',
              'greg research',
              'hunt',
              'hunt research',
              'csiro',
              'csiro research',
              'carbon',
              'carbon emission',
              'cent',
              'year',
              'nature',
              'nature soil',
              'tree'],
             'carbon': ['greg',
              'greg research',
              'hunt',
              'hunt research',
              'csiro',
              'csiro research',
              'u',
              'carbon',
              'carbon',
              'carbon dioxide',
              'emission',
              

## Training

In [18]:
DATA_PATH = ROOT_DIR.joinpath("./data/*")
SAVE_PATH = ROOT_DIR.joinpath("./result/pipeline/shortlisting_v3/*")

### Info tag pipeline

In [19]:
def info_tag_pipeline(
    text:str,
    go_nouns:list = [],
    return_doc:bool=False
) -> List[InfoTag]:
    text = normalize_pipeline(text)
    doc = nlp(text)
    tags = get_info_tags(doc, go_nouns=go_nouns)
    
    if return_doc:
        return tags, doc
    else:
        return tags

### Get top nouns

In [20]:
def get_top_nouns(
    claims_paths:List[Path],
    save_path:Path = None,
    k:int = 10,
    stop_nouns:list = [],
    with_counts:bool = True,
    verbose:bool = True
):
    # Load the claims file
    claims = dict()
    for claims_path in claims_paths:
        with open(claims_path, mode="r") as f:
            claims.update(json.load(f))
    
    # Cumulator
    noun_counts = Counter()
    
    claim_obj = tqdm(claims.values(), desc="claims", disable=not verbose)
    for claim in claim_obj:
        claim_text = claim["claim_text"]
        claim_doc = nlp(claim_text)
        for token in claim_doc:
            if token.pos_ in ["NOUN", "PROPN"] \
                and token.lemma_ not in stop_nouns:
                noun_counts[token.lemma_] += 1
    
    top_k_nouns_counts = noun_counts.most_common(k)
    top_k_nouns = [noun_count[0] for noun_count in top_k_nouns_counts]

    if save_path:
        with open(save_path, mode="w") as f:
            json.dump(obj=top_k_nouns, fp=f, cls=SetEncoder)
            print(f"saved to: {save_path}")
    
    if with_counts:
        return top_k_nouns_counts
    else:
        return top_k_nouns

In [21]:
top_k_nouns_path = SAVE_PATH.with_name("train_top_k_nouns.json")

if top_k_nouns_path.exists():
    print(f"existing found: {top_k_nouns_path}")
    
    with open(top_k_nouns_path, mode="r") as f:
        top_k_nouns = json.load(f)
        print(f"loaded: {top_k_nouns_path}")
        
else:
    
    top_k_nouns = get_top_nouns(
        claims_paths=[DATA_PATH.with_name("train-claims.json")],
        save_path=top_k_nouns_path,
        k=50,
        stop_nouns=["year"],
        with_counts=True,
    )
# top_k_nouns

existing found: /Users/johnsonzhou/git/comp90042-project/result/pipeline/shortlisting_v3/train_top_k_nouns.json
loaded: /Users/johnsonzhou/git/comp90042-project/result/pipeline/shortlisting_v3/train_top_k_nouns.json


### Train subject relations

In [22]:
def train_subject_relations(
    claims_paths:List[Path],
    evidence_path:Path,
    save_path:Path = None,
    verb_pos_threshold:int=0,
    verbose:bool = True
) -> DefaultDict[str, set]:
    # Load the claims file
    claims = dict()
    for claims_path in claims_paths:
        with open(claims_path, mode="r") as f:
            claims.update(json.load(f))
    
    # Load the evidence file
    with open(evidence_path, mode="r") as f:
        evidences = json.load(f)
    
    # Cumulator
    subject_relations = defaultdict(set) #todo
    
    claim_obj = tqdm(claims.values(), desc="claims", disable=not verbose)
    for claim in claim_obj:
        claim_text = claim["claim_text"]
        for evidence_id in claim["evidences"]:
            evidence_text = evidences[evidence_id]
            
            claim_tags = info_tag_pipeline(text=claim_text)
            evidence_tags = info_tag_pipeline(text=evidence_text)
            
            related_tags = map_subject_relation(
                claim_tags=claim_tags,
                evidence_tags=evidence_tags,
                verb_pos_threshold=verb_pos_threshold
            )
            
            for tag, related_tags in related_tags.items():
                claim_obj.postfix = f"n_rel: {len(related_tags)}"
                subject_relations[tag].update(related_tags) #todo
    
    if save_path:
        with open(save_path, mode="w") as f:
            json.dump(obj=subject_relations, fp=f, cls=SetEncoder)
            print(f"saved to: {save_path}")
            
    return subject_relations

In [23]:
relation_save_path = SAVE_PATH.with_name("train_subject_relations.json")

if relation_save_path.exists():
    print(f"existing found: {relation_save_path}")
    
    # with open(relation_save_path, mode="r") as f:
    #     subject_relations = json.load(f)
    #     print(f"loaded: {relation_save_path}")
        
else:
    
    subject_relations = train_subject_relations(
        claims_paths=[DATA_PATH.with_name("train-claims.json")],
        evidence_path=DATA_PATH.with_name("evidence.json"),
        save_path=relation_save_path
    )
# subject_relations

existing found: /Users/johnsonzhou/git/comp90042-project/result/pipeline/shortlisting_v3/train_subject_relations.json


### Tag evidences

In [24]:
# For multiprocessing
def get_evidence_tags_pool(
    evidence,
    info_tag_pipeline=info_tag_pipeline,
    go_nouns=top_k_nouns
):
    evidence_id, evidence_text = evidence
        
    tags = info_tag_pipeline(text=evidence_text, go_nouns=go_nouns)

    return evidence_id, tags

In [25]:
def tag_evidences(
    evidence_path:Path,
    top_k_nouns_path:Path,
    save_path:Path = None,
    processes:int = 8,
    verbose:bool = True
) -> DefaultDict[str, set]:
    
    # Load the evidence file
    with open(evidence_path, mode="r") as f:
        evidences = json.load(f)
        
    # Load the top k nouns file
    with open(top_k_nouns_path, mode="r") as f:
        top_k_nouns = json.load(f)
    
    # Cumulator
    evidence_tags = defaultdict(set)
    
    # Multiprocessing, can't get it to work
    # evidences_iter = \
    #     tqdm([evidences.items()], desc="evidence", disable=not verbose, leave=False)
    
    # with Pool(processes=processes) as pool:
    #     pooled_evidence_tags = pool.imap(get_evidence_tags_pool, evidences_iter)
        
    # evidence_tags_iter = \
    #     tqdm(pooled_evidence_tags, desc="evidence tags", disable=not verbose)
    # for tags, evidence_id in evidence_tags_iter:
    #     for tag in tags:
    #         evidence_tags[tag.tag].add(evidence_id)
    
    evidences_iter = tqdm(evidences.items(), desc="claims", disable=not verbose)
    for evidence_id, evidence_text in evidences_iter:
        
        tags = info_tag_pipeline(text=evidence_text, go_nouns=top_k_nouns)
        
        for tag in tags:
            evidence_tags[tag.tag].add(evidence_id)
        
        evidences_iter.postfix = f"n_tags: {len(tags)}"
        
        continue
    
    if save_path:
        with open(save_path, mode="w") as f:
            json.dump(obj=evidence_tags, fp=f, cls=SetEncoder)
            print(f"saved to: {save_path}")

    return evidence_tags


In [26]:
tagged_evidences_path = SAVE_PATH.with_name("tagged_evidences.json")

if tagged_evidences_path.exists():
    print(f"existing found: {tagged_evidences_path}")
    
    # with open(tagged_evidences_path, mode="r") as f:
    #     tagged_evidences = json.load(f)
    #     print(f"loaded: {tagged_evidences_path}")
        
else:
    
    tagged_evidences = tag_evidences(
        evidence_path=DATA_PATH.with_name("evidence.json"),
        top_k_nouns_path=top_k_nouns_path,
        save_path=tagged_evidences_path
    )
# tagged_evidences

existing found: /Users/johnsonzhou/git/comp90042-project/result/pipeline/shortlisting_v3/tagged_evidences.json


### Tag evidences with char n-gram

In [27]:
def get_bidirectional_n_grams(doc, n:int=4):
    fwd_ngrams = [token.lemma_[:n] for token in doc if len(token.lemma_) >= n]
    rev_ngrams = [token.lemma_[-n:] for token in doc if len(token.lemma_) >= n]
    return fwd_ngrams, rev_ngrams

In [28]:
def get_ngram_evidences(
    evidence_path:Path,
    n_list:list = [4, 5, 6],
    save_path_fwd:Path = None,
    save_path_rev:Path = None,
    verbose:bool = True
):
    # Load the evidence file
    with open(evidence_path, mode="r") as f:
        evidences = json.load(f)
    
    # Cumulator
    fwd_evidence_ngrams = defaultdict(set)
    rev_evidence_ngrams = defaultdict(set)
    
    evidences_iter = tqdm(evidences.items(), desc="claims", disable=not verbose)
    for evidence_id, evidence_text in evidences_iter:
        
        text = normalize_pipeline(evidence_text)
        doc = nlp(text)
        
        for n in n_list:
            fwd_ngrams, rev_ngrams = get_bidirectional_n_grams(doc, n=n)
            
            for ngram in fwd_ngrams:
                fwd_evidence_ngrams[ngram].add(evidence_id)
                
            for ngram in rev_ngrams:
                rev_evidence_ngrams[ngram].add(evidence_id)
            
        continue
    
    if save_path_fwd and save_path_rev:
        with open(save_path_fwd, mode="w") as f:
            json.dump(obj=fwd_evidence_ngrams, fp=f, cls=SetEncoder)
            print(f"saved to: {save_path_fwd}")
        with open(save_path_rev, mode="w") as f:
            json.dump(obj=rev_evidence_ngrams, fp=f, cls=SetEncoder)
            print(f"saved to: {save_path_rev}")
    
    return fwd_evidence_ngrams, rev_evidence_ngrams

In [29]:
fwd_ngram_evidences_path = SAVE_PATH.with_name("train_fwd_ngram_evidences.json")
rev_ngram_evidences_path = SAVE_PATH.with_name("train_rev_ngram_evidences.json")

if fwd_ngram_evidences_path.exists():
    print(f"existing found: {fwd_ngram_evidences_path}")
    
    # with open(fwd_ngram_evidences_path, mode="r") as f:
    #     fwd_ngram_evidences = json.load(f)
    #     print(f"loaded: {fwd_ngram_evidences_path}")
        
    # with open(rev_ngram_evidences_path, mode="r") as f:
    #     rev_ngram_evidences = json.load(f)
    #     print(f"loaded: {rev_ngram_evidences_path}")

else:
    
    fwd_ngram_evidences, rev_ngram_evidences = get_ngram_evidences(
        evidence_path=DATA_PATH.with_name("evidence.json"),
        n_list=[3, 4, 5, 6, 7, 8],
        save_path_fwd=fwd_ngram_evidences_path,
        save_path_rev=rev_ngram_evidences_path
    )
    
# fwd_ngram_evidences

existing found: /Users/johnsonzhou/git/comp90042-project/result/pipeline/shortlisting_v3/train_fwd_ngram_evidences.json


### Get evidence shortlist by claim

In [49]:
def get_evidence_shortlist(
    claims_paths:List[Path],
    related_tags_path:Path,
    tagged_evidence_path:Path,
    fwd_ngram_evidences_path:Path,
    rev_ngram_evidences_path:Path,
    ngram_list:list = [3, 4, 5, 6, 7, 8],
    save_path:Path = None,
    evidence_count_threshold:int=999,
    ngram_evidence_count_threshold:int=10,
    verb_pos_threshold:int=0,
    min_tag_match_threshold:int=2,
    min_ngram_match_threshold:int=2,
    min_all_match_threshold:int=2,
    min_tag_match_ratio:float=0.6,
    min_tag_match_offset:int=2,
    max_retrieved:int=1000,
    stop_tags:list=[],
    verbose:bool = True
):
    # Load the claims file
    claims = dict()
    for claims_path in claims_paths:
        with open(claims_path, mode="r") as f:
            claims.update(json.load(f))
    
    # Load the related tags file
    with open(related_tags_path, mode="r") as f:
        related_tags = json.load(f)
        
    # Load the tagged evidences file
    with open(tagged_evidence_path, mode="r") as f:
        tagged_evidences = json.load(f)
    
    # Load the evidence ngrams files
    with open(fwd_ngram_evidences_path, mode="r") as f:
        fwd_ngram_evidences = json.load(f)
        print(f"loaded: {fwd_ngram_evidences_path}")

    with open(rev_ngram_evidences_path, mode="r") as f:
        rev_ngram_evidences = json.load(f)
        print(f"loaded: {rev_ngram_evidences_path}")
    
    # Cumulator
    claim_evidences = defaultdict(set)
    missed_retrievals = defaultdict(set)
    retrieval_counts = []
    retrieval_recalls = []
    all_unique_tags = set()
    
    claim_obj = tqdm(claims.items(), desc="claims", disable=not verbose)
    iter_count = 0
    for claim_id, claim in claim_obj:
        # if iter_count >= 10:
        #     break
        
        # iter_count += 1
        
        # if claim_id not in ["claim-1160"]:
        #     continue
        
        interest_e_ids = [
            "evidence-76391",
            # "evidence-578305",
        ]
        
        #! Keywords
        claim_keywords_evidences = defaultdict(set)
        
        #! Link to below
        staged_retrievals = set()
        
        # Get claim direct tags
        tags = set()
        
        claim_text = claim["claim_text"]
        claim_tags, claim_doc = info_tag_pipeline(text=claim_text, return_doc=True)
        for tag in claim_tags:
            tags.add(tag.tag)
            all_unique_tags.add(tag.tag) #!
        
        original_tags = tags.copy() #!

        # Keep a count of how many tags each evidence relates to
        retrieved_evidence_counts = Counter()
        
        # Match tags -----------------------------------------------------
        
        # Keep a count of how many tags each evidence relates to
        retrieved_tag_evidence_counts = Counter()
        
        #! Get tags for which there is a subject relation
        # This step is to improve recall
        # for c_tag in claim_tags:
        #     # Verb position threshold for subject attention
        #     if c_tag.verb_pos > verb_pos_threshold:
        #         continue
        #     subject_tags = set(related_tags.get(c_tag.tag, []))
        #     # expanded_tags = tags.intersection(subject_tags)
        #     tags.update(subject_tags)
        
        # Go through all the tags then retrieve evidences
        for tag in tags:
            
            # Stop tags
            # if tag in stop_tags:
            #     continue
            
            evidence_ids = list(set(tagged_evidences.get(tag, [])))
            n_evidence_ids = len(evidence_ids)
            
            # Count threshold
            # if n_evidence_ids < evidence_count_threshold:
                # print(f"skipping tag {tag} with n={n_evidence_ids}")
                #! claim_evidences[claim_id].update(evidence_ids)
            for e_id in evidence_ids:
                retrieved_evidence_counts[e_id] += 1
                # retrieved_tag_evidence_counts[e_id] += 1
                
                #! Test to see if we have the pos e_ids in here
                #! We do!
                # if e_id in interest_e_ids:
                #     print(tag, e_id, len(evidence_ids))
                
                #! Keywords
                claim_keywords_evidences[tag].add(e_id)
                
                #! Immediately stage if direct tag match is less than
                #! max_retrieved
                # if len(staged_retrievals) < max_retrieved \
                #     and len(evidence_ids) < max_retrieved:
                    
                #     staged_retrievals.update(evidence_ids)
            
            continue
        
            
        
        # Go through the retrieved evidences and add it to the claim
        # if it meets the minimum tag association threshold
        # this step is to improve precision
        # for e_id, e_count in sorted(
        #     retrieved_tag_evidence_counts.items(),
        #     key=lambda x: x[1], reverse=True
        # ):
        #     n_orig_tags = len(original_tags)
        #     # count_threshold = floor(min_tag_match_ratio * n_orig_tags) \
        #     #     if n_orig_tags > 3 else 0
        #     count_threshold = min_tag_match_threshold
        #     # count_threshold = n_orig_tags - min_tag_match_offset
        #     if e_count < count_threshold:
        #         continue
        #     claim_evidences[claim_id].add(e_id)
        #     continue
        
        
        # Match ngrams ------------------------------------------------------
        
        # Keep a count of how many tags each evidence relates to
        retrieved_ngram_evidence_counts = Counter()
        
        # Ensure to match each token once, forwards and reverse, longest first
        fwd_ngram_matched_tokens = set()
        rev_ngram_matched_tokens = set()
        
        for token in claim_doc:
            token_lemma = token.lemma_
            
            for n in sorted(ngram_list, reverse=True):
                ngram_matched_evidence_ids = set()
                
                # ngram_evidence_count_cutoff = \
                #     ngram_evidence_count_threshold if n == 6 \
                #     else ngram_evidence_count_threshold * 1 if n == 5 \
                #     else ngram_evidence_count_threshold * 1
                
                # Match forward
                fwd_lemma = token_lemma[:n]
                fwd_evidences = fwd_ngram_evidences.get(fwd_lemma, [])
                # if len(fwd_evidences) <= ngram_evidence_count_cutoff \
                #     and len(fwd_evidences) > 0 \
                #     and token_lemma not in fwd_ngram_matched_tokens:
                ngram_matched_evidence_ids.update(fwd_evidences)
                fwd_ngram_matched_tokens.add(token_lemma)
                
                #! NEW Keyword
                claim_keywords_evidences[fwd_lemma].update(fwd_evidences)
                    
                # Match reverse
                rev_lemma = token_lemma[-n:0]
                rev_evidences = rev_ngram_evidences.get(rev_lemma, [])
                # if len(rev_evidences) <= ngram_evidence_count_cutoff \
                #     and len(rev_evidences) > 0 \
                #     and token_lemma not in rev_ngram_matched_tokens:
                ngram_matched_evidence_ids.update(rev_evidences)
                rev_ngram_matched_tokens.add(token_lemma)
                
                #! NEW Keyword
                claim_keywords_evidences[rev_lemma].update(rev_evidences)
                    
                for e_id in ngram_matched_evidence_ids:
                    # retrieved_ngram_evidence_counts[e_id] += 1
                    retrieved_evidence_counts[e_id] += 1
                    

        
        # Wrapping up ------------------------------------------------------
            
        # # Dynamically find the cutoff to return the maximum specified
        # # number of evidences
        # if len(retrieved_evidence_counts.keys()) > 1:
        #     retrieved_cut_cutoff = 1
        #     staged_retrievals = []
        #     searching_cutoff = True
        #     while searching_cutoff:
        #         current_staged_retrievals = [
        #             e_id
        #             for e_id, e_count in sorted(
        #                 retrieved_evidence_counts.items(),
        #                 key=lambda x: x[1], reverse=True
        #             )
        #             if e_count >= retrieved_cut_cutoff
        #         ]
                
        #         if len(current_staged_retrievals) < 1:
        #             searching_cutoff = False
                
        #         staged_retrievals = current_staged_retrievals
                
        #         #! Check at each stage whether we still have the pos e_ids
        #         # print(retrieved_cut_cutoff, len(staged_retrievals), set(staged_retrievals).intersection(interest_e_ids))
                
        #         if len(current_staged_retrievals) > max_retrieved:
        #             retrieved_cut_cutoff += 1
        #             continue
            
        #         searching_cutoff = False
            
        #     # Add staged retrievals against the claim_id
        #     for e_id in staged_retrievals:
        #         claim_evidences[claim_id].add(e_id)
                
        
        #! NEW Keywords intersection algorithm NEW
        claim_bigram_keyword_evidences = defaultdict(set)
        
        for kw, e_id in claim_keywords_evidences.items():
            for kw_, e_id_ in claim_keywords_evidences.items():
                if (kw, kw_) in claim_bigram_keyword_evidences.keys() \
                    or (kw_, kw) in claim_bigram_keyword_evidences.keys():
                        continue
                        
                e_intersect = set(e_id).intersection(set(e_id_))
                
                # if len(e_intersect.intersection(set(interest_e_ids))) > 1:
                #     print(kw, kw_, e_intersect)
                # has_interest = set(interest_e_ids).intersection(e_intersect)
                # if len(has_interest) > 0:
                #     print(kw, kw_, len(has_interest), len(e_intersect))
                if len(e_intersect) > 0 and len(e_intersect) < max_retrieved:
                    claim_bigram_keyword_evidences[(kw, kw_)] = e_intersect
                    
        #! Triplet keywords -- incomplete
        # claim_bigram_keyword_evidences = defaultdict(set)
        
        # for kw, e_id in claim_keywords_evidences.items():
        #     for kw_, e_id_ in claim_keywords_evidences.items():
        #         for kw__, e_id__ in claim_keywords_evidences.items():
        #             if (kw, kw_) in claim_bigram_keyword_evidences \
        #                 or (kw_, kw) in claim_bigram_keyword_evidences:
        #                     continue
                        
        #         e_intersect = set(e_id).intersection(set(e_id_))
        #         # has_interest = set(interest_e_ids).intersection(e_intersect)
        #         # if len(has_interest) > 0:
        #         #     print(kw, kw_, len(has_interest), len(e_intersect))
        #         if len(e_intersect) > 0 and len(e_intersect) < max_retrieved:
        #             claim_bigram_keyword_evidences[(kw, kw_)] = e_intersect
                
        # opportunity = sorted(
        #     claim_bigram_keyword_evidences.items(),
        #     key=lambda x: len(x[1]),
        # )
        # for kw, e_ids in opportunity:
        #     print(kw, len(e_ids))
        
        # staged_retrievals = set() #! Moved to top
        searching_cutoff = True
        
        for bigram_kw, e_ids in sorted(
            claim_bigram_keyword_evidences.items(),
            key=lambda x: len(x[1]),
        ):
            if searching_cutoff == False:
                break
            
            if len(staged_retrievals) < max_retrieved:
                staged_retrievals.update(e_ids)
            else:
                searching_cutoff = True
        
        # Add staged retrievals against the claim_id
        for e_id in staged_retrievals:
            claim_evidences[claim_id].add(e_id)

        
        # Count how many evidences have been retrieved for this claim
        n_retrieved = len(claim_evidences[claim_id])
        
        # if n_retrieved < 1:
        #     print(claim_id, n_retrieved)
        
        # Calculate some statistics
        recall = 1
        if "evidences" in claim.keys():
            truth_evidences = set(claim["evidences"])
            retrieved_evidences = set(claim_evidences[claim_id])
            missed = truth_evidences.difference(retrieved_evidences)
            recall = (len(truth_evidences) - len(missed)) / len(truth_evidences)
            
            if recall < 0.5:
                print(claim_id, recall, original_tags)
                
            missed_retrievals[claim_id].update(missed)
        
        retrieval_counts.append(n_retrieved)
        retrieval_recalls.append(recall)
        claim_obj.postfix = f"n_retrieved: {n_retrieved}, recall: {recall}"
        
        continue
    
    if save_path:
        with open(save_path, mode="w") as f:
            json.dump(obj=claim_evidences, fp=f, cls=SetEncoder)
            print(f"saved to: {save_path}")
    
    return claim_evidences, missed_retrievals, retrieval_counts, retrieval_recalls, all_unique_tags

In [52]:
max_retrieval = 750
retrieved_evidences_path = SAVE_PATH.with_name(f"dev_retrieved_evidences_max_{max_retrieval}_set_intersect_no_rel.json")

if retrieved_evidences_path.exists():
    
    with open(retrieved_evidences_path, mode="r") as f:
        retrieved_evidences = json.load(f)
        print(f"loaded: {retrieved_evidences_path}")
        
else:
    
    retrieved_evidences, missed_retrievals, retrievals_counts, retrieval_recalls, all_unique_tags = \
    get_evidence_shortlist(
        # claims_paths=[DATA_PATH.with_name("train-claims.json")],
        claims_paths=[DATA_PATH.with_name("dev-claims.json")],
        # claims_paths=[DATA_PATH.with_name("train-claims.json"), DATA_PATH.with_name("dev-claims.json")],
        related_tags_path=relation_save_path,
        tagged_evidence_path=tagged_evidences_path,
        fwd_ngram_evidences_path=fwd_ngram_evidences_path,
        rev_ngram_evidences_path=rev_ngram_evidences_path,
        ngram_list=[4, 5, 6, 7, 8],
        min_ngram_match_threshold=1,
        ngram_evidence_count_threshold=1000000,
        # tagged_evidence_path=Path("./result/ner/evidence_by_noun.json"),
        save_path=retrieved_evidences_path,
        evidence_count_threshold=1000000,
        verb_pos_threshold=0,
        min_tag_match_threshold=1,
        min_all_match_threshold=4,
        max_retrieved=max_retrieval,
        stop_tags=[]
    )
# retrieved_evidences

loaded: /Users/johnsonzhou/git/comp90042-project/result/pipeline/shortlisting_v3/train_fwd_ngram_evidences.json
loaded: /Users/johnsonzhou/git/comp90042-project/result/pipeline/shortlisting_v3/train_rev_ngram_evidences.json


claims:   1%|▏         | 2/154 [00:09<11:57,  4.72s/it, n_retrieved: 758, recall: 0.4]

claim-375 0.4 {'reductio\xadn', 'carbon dioxide', 'carbon', 'amount', 'human', 'climate', 'total', 'total emission', 'cent', 'effect', 'prod\xaduce', 'annual emission', 'annual', 'global', 'global emission', 'australia'}


claims:  11%|█         | 17/154 [01:25<14:46,  6.47s/it, n_retrieved: 757, recall: 0.0]

claim-161 0.0 {'pressure', 'climate', 'continent', 'change', 'extreme', 'extreme melting', 'ground'}


claims:  16%|█▌        | 24/154 [02:10<09:57,  4.60s/it, n_retrieved: 816, recall: 0.3333333333333333]

claim-104 0.3333333333333333 {'carbon dioxide', 'carbon', 'atmospheric', 'increase', 'atmospheric dioxide', 'temperature'}


claims:  19%|█▉        | 30/154 [03:05<20:24,  9.88s/it, n_retrieved: 751, recall: 0.0]               

claim-2662 0.0 {'contention', 'level', 'carbon dioxide', 'carbon', 'earth', 'jr', 'university', 'wolfgang', 'atmosphere', 'wolfgang knorr', 'year', 'ken', 'ken ward', 'department', 'science', 'england', 'bristol'}


claims:  21%|██        | 32/154 [03:18<15:41,  7.72s/it, n_retrieved: 773, recall: 0.3333333333333333]

claim-2768 0.3333333333333333 {'term', 'albedo', 'term trend', 'long'}


claims:  22%|██▏       | 34/154 [03:24<11:25,  5.71s/it, n_retrieved: 752, recall: 0.2]               

claim-785 0.2 {'other case', 'many', 'particular', 'many case', 'global', 'particular trend', 'other', 'example', 'global warming', 'hurricane', 'linkage'}


claims:  23%|██▎       | 35/154 [03:25<08:40,  4.38s/it, n_retrieved: 755, recall: 0.0]

claim-2426 0.0 {'twentieth', 'global warming', 'global', 'twentieth century'}


claims:  27%|██▋       | 42/154 [04:27<15:16,  8.18s/it, n_retrieved: 800, recall: 0.4]

claim-540 0.4 {'just 1c', 'just', 'heatwave', 'europe'}


claims:  29%|██▊       | 44/154 [04:39<12:39,  6.91s/it, n_retrieved: 750, recall: 0.0] 

claim-1407 0.0 {'weight', 'book', 'canadian researcher', 'award', 'ross mckitrick', 'canadian', 'amount', 'height', 'ross', 'christopher', 'essex', 'temperature', 'storm'}


claims:  29%|██▉       | 45/154 [04:42<10:31,  5.79s/it, n_retrieved: 750, recall: 0.2]

claim-3070 0.2 {'tough', 'million', 'tough situation', 'family', 'government', 'bad', 'household', 'australia', 'risk', 'industry'}


claims:  31%|███       | 47/154 [04:50<08:22,  4.70s/it, n_retrieved: 750, recall: 0.4]

claim-1515 0.4 {'temperature record', 'global', 'one', 'global record', 'last', 'last year', 'temperature'}


claims:  35%|███▌      | 54/154 [05:57<12:18,  7.38s/it, n_retrieved: 778, recall: 0.3333333333333333]

claim-2611 0.3333333333333333 {'final', 'atmosphere', 'carbon dioxide', 'carbon', 'final amount', 'extra dioxide', 'century', 'extra', 'time', 'time scale'}


claims:  37%|███▋      | 57/154 [06:16<11:28,  7.09s/it, n_retrieved: 751, recall: 0.4]               

claim-1087 0.4 {'record temperature', 'way', 'overpeck', 'cold', 'particular', 'stratosphere', 'record', 'last', 'particular signature', 'last year', 'warming', 'cold temperature'}


claims:  38%|███▊      | 59/154 [06:28<10:21,  6.54s/it, n_retrieved: 791, recall: 0.2]

claim-2300 0.2 {'early', 'history', 'austria', 'mountain', 'early snowfall', 'centimetre', 'today'}


claims:  40%|████      | 62/154 [07:15<19:14, 12.55s/it, n_retrieved: 751, recall: 0.0]               

claim-3051 0.0 {'termperature', 'human contribution', 'el nino', 'termperature record', 'significant change', 'human', 'climate', 'significant', 'la nina', 'natural', 'el', 'la', 'recent', 'impact', 'evidence', 'climate change', 'natural influence', 'recent record'}


claims:  43%|████▎     | 66/154 [07:37<11:20,  7.74s/it, n_retrieved: 781, recall: 0.3333333333333333]

claim-2579 0.3333333333333333 {'past year', 'human', 'atmosphere', 'carbon dioxide', 'carbon', 'past', 'part'}


claims:  45%|████▌     | 70/154 [08:11<10:22,  7.42s/it, n_retrieved: 766, recall: 0.2]               

claim-1896 0.2 {'cent', 'greg research', 'nature', 'carbon', 'tree', 'carbon emission', 'year', 'csiro research', 'greg', 'nature soil', 'csiro', 'hunt', 'hunt research'}


claims:  51%|█████     | 78/154 [08:59<07:19,  5.78s/it, n_retrieved: 772, recall: 0.0]

claim-342 0.0 {'weather extreme', 'change', 'medium', 'such claim', 'member', 'different story', 'such', 'weather', 'evidence', 'different'}


claims:  52%|█████▏    | 80/154 [09:11<07:34,  6.15s/it, n_retrieved: 752, recall: 0.0]

claim-578 0.0 {'grown', 'young son', 'grown man', 'extreme heat', 'greenhouse', 'past summer', 'norm', 'young', 'extreme', 'past', 'emission', 'greenhouse gas', 'kind'}


claims:  57%|█████▋    | 88/154 [10:26<10:41,  9.72s/it, n_retrieved: 768, recall: 0.4]

claim-2577 0.4 {'recent article', 'alarmist', 'coral reef', 'recent', 'world climate report', 'greenhouse', 'world', 'reason', 'buildup', 'greenhouse gas', 'coral', 'deep trouble', 'deep'}


claims:  62%|██████▏   | 95/154 [11:09<05:15,  5.35s/it, n_retrieved: 750, recall: 0.0]

claim-1689 0.0 {'net feedback', 'evidence', 'cloud', 'net', 'cloud feedback'}


claims:  62%|██████▏   | 96/154 [11:19<06:25,  6.64s/it, n_retrieved: 770, recall: 0.0]

claim-443 0.0 {'overall picture', 'level', 'carbon dioxide', 'carbon', 'atmosphere', 'small part', 'overall', 'climate', 'climate scientist', 'today', 'small'}


claims:  67%|██████▋   | 103/154 [12:01<05:16,  6.21s/it, n_retrieved: 764, recall: 0.0]

claim-38 0.0 {'next', 'rise', 'sea', 'antarctica ice', 'antarctica', 'sea ice', 'next age', 'accumulation', 'scientist'}


claims:  68%|██████▊   | 104/154 [12:03<04:06,  4.93s/it, n_retrieved: 829, recall: 0.25]

claim-1643 0.25 {'book', 'al gore', 'al', 'contrarian', 'contrarian book'}


claims:  69%|██████▉   | 106/154 [12:19<04:31,  5.66s/it, n_retrieved: 758, recall: 0.25]

claim-1605 0.25 {'claim', 'sea', 'sea level', 'scientist'}


claims:  71%|███████▏  | 110/154 [12:54<06:05,  8.31s/it, n_retrieved: 769, recall: 0.0] 

claim-392 0.0 {'global', 'study', 'manmade', 'whole', 'whole concept', 'global warming'}


claims:  73%|███████▎  | 113/154 [13:12<04:39,  6.82s/it, n_retrieved: 755, recall: 0.2] 

claim-2583 0.2 {'impact', 'less', 'atmosphere', 'carbon dioxide', 'carbon', 'warming', 'unit'}


claims:  75%|███████▍  | 115/154 [13:26<04:29,  6.92s/it, n_retrieved: 761, recall: 0.2]

claim-492 0.2 {'natural', 'most', 'real effect', 'most claim', 'natural pattern', 'weather', 'global', 'real', 'weather pattern', 'global warming'}


claims:  75%|███████▌  | 116/154 [13:36<04:57,  7.83s/it, n_retrieved: 775, recall: 0.0]

claim-1420 0.0 {'warmth', 'atmosphere', 'ray', 'sun', 'earth', 'space', 'gas'}


claims:  76%|███████▌  | 117/154 [14:07<09:09, 14.86s/it, n_retrieved: 751, recall: 0.25]

claim-1089 0.25 {'monthly', 'sea ice', 'coral', 'bleaching', 'monthly low', 'coast', 'arctic', 'bad', 'january', 'record low', 'sea', 'large', 'barrier reef', 'regular low', 'warmth', 'death', 'tropical', 'record', 'regular', 'last', 'bleaching event', 'planet', 'last year', 'overall temperature', 'northeastern', 'ocean water', 'bad event', 'ocean', 'great reef', 'barrier', 'overall', 'large scale', 'september', 'northeastern australia', 'great', 'warm'}


claims:  77%|███████▋  | 118/154 [14:15<07:40, 12.79s/it, n_retrieved: 754, recall: 0.4] 

claim-1467 0.4 {'upper', 'carbon dioxide', 'carbon', 'ocean', 'upper layer', 'amount', 'year', 'ton'}


claims:  79%|███████▊  | 121/154 [14:29<04:13,  7.67s/it, n_retrieved: 1209, recall: 1.0]

claim-803 0.25 {'intergovernmental', 'change', 'hiatus', 'temperature', 'ipcc', 'temperature increase', 'change model', '21st', 'climate', 'team', 'beginning', '21st century', 'ipcc model', 'climate scientist', 'intergovernmental panel'}


claims:  80%|███████▉  | 123/154 [14:44<04:09,  8.04s/it, n_retrieved: 752, recall: 0.0] 

claim-846 0.0 {'gas', 'much', 'atmosphere', 'warming power', 'arctic', 'form', 'other', 'arctic permafrost', 'time', 'planet', 'warming', 'date', 'much carbon', 'other word'}


claims:  84%|████████▍ | 129/154 [15:00<01:09,  2.78s/it, n_retrieved: 873, recall: 0.4]

claim-181 0.4 {'basic food', 'basic', 'carbon dioxide', 'carbon', 'reef'}


claims:  84%|████████▍ | 130/154 [15:09<01:48,  4.52s/it, n_retrieved: 760, recall: 0.0]

claim-281 0.0 {'dr', 'bbc', 'scandal', 'global', 'dr jones', 'significant', 'significant warming', 'interview', 'global warming'}


claims:  85%|████████▌ | 131/154 [15:22<02:42,  7.07s/it, n_retrieved: 756, recall: 0.0]

claim-2809 0.0 {'massive', 'motion', 'heat', 'massive ocean', 'surface', 'century', 'year', 'deep layer', 'time', 'variability', 'time scale', 'deep'}


claims:  88%|████████▊ | 135/154 [15:36<01:28,  4.64s/it, n_retrieved: 760, recall: 0.2]

claim-988 0.2 {'unlikely', 'unlikely scenario', 'probable', 'say', 'sweet'}


claims:  89%|████████▉ | 137/154 [15:46<01:19,  4.67s/it, n_retrieved: 751, recall: 0.0]

claim-2282 0.0 {'overwhelming trend', 'overwhelming', 'retreat', 'isolated case', 'glacier', 'isolated'}


claims:  92%|█████████▏| 142/154 [16:07<00:44,  3.72s/it, n_retrieved: 686, recall: 0.4]

claim-897 0.0 {'world', 'united states', 'pollutionfree', 'report', 'pollutionfree nation', 'united'}
claim-3063 0.4 {'global', 'warming', 'warming myth'}


claims:  98%|█████████▊| 151/154 [17:29<00:18,  6.07s/it, n_retrieved: 756, recall: 0.4]               

claim-204 0.4 {'natural', 'atmospheric content', 'carbon dioxide', 'carbon', 'atmospheric', 'natural content', 'year'}


claims:  99%|█████████▊| 152/154 [17:33<00:10,  5.47s/it, n_retrieved: 796, recall: 0.4]

claim-1426 0.4 {'world', 'coral', 'constant decline', 'coral reef', 'constant', 'state'}


claims:  99%|█████████▉| 153/154 [17:48<00:08,  8.40s/it, n_retrieved: 755, recall: 0.0]

claim-698 0.0 {'hot', 'overestimation', 'systematic', 'ben santer', 'external', 'national laboratory climate scientist santer', 'national', 'model simulation', 'model', 'lawrence livermore scientist santer', 'post', 'systematic deficiency', 'recent', 'lawrence', 'ben', 'recent study', 'external forcing'}


claims: 100%|██████████| 154/154 [17:57<00:00,  7.00s/it, n_retrieved: 752, recall: 1.0]


saved to: /Users/johnsonzhou/git/comp90042-project/result/pipeline/shortlisting_v3/dev_retrieved_evidences_max_750_set_intersect_no_rel.json


In [53]:
print(f"avg recall: {np.mean(retrieval_recalls)}")
print(f"min recall: {np.min(retrieval_recalls)}")
print(f"avg retrieved: {np.mean(retrievals_counts)}")
print(f"max retrieved: {np.max(retrievals_counts)}")

avg recall: 0.6691558441558442
min recall: 0.0
avg retrieved: 775.077922077922
max retrieved: 1209


In [72]:
# len(all_unique_tags)

In [73]:
# len(fwd_ngram_evidences["elect"])

In [74]:
# len(tagged_evidences["south australia"])

In [75]:
# len(tagged_evidences.keys())

In [76]:
# [(tag, len(e)) for tag, e in sorted(tagged_evidences.items(), key=lambda x: len(x[1]), reverse=True)]