In [1]:
import json
import argparse
import os
import sys
from tqdm import tqdm
from collections import defaultdict
import pandas as pd
import numpy as np
from nltk.stem import PorterStemmer

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, BertModel
from transformers import GPT2Tokenizer, GPT2Model

from swisscom_ai.research_keyphrase.preprocessing.postagging import PosTaggingCoreNLP
from swisscom_ai.research_keyphrase.model.input_representation import InputTextObj
from swisscom_ai.research_keyphrase.model.extractor import extract_candidates, extract_verb_candidates

from graphviz import Digraph

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
host = 'localhost'
port = 9000
pos_tagger = PosTaggingCoreNLP(host, port)

In [3]:
# load stopwords
stopwords = []
with open('UGIR_stopwords.txt', "r") as f:
    for line in f:
        if line:
            stopwords.append(line.replace('\n', ''))


In [4]:
stemmer = PorterStemmer()

In [5]:
def get_col_sum_token_level(attention_map):
    tokens_score = torch.sum(attention_map, dim=0)
    return tokens_score


def redistribute_global_attention_score(attention_map, tokens_score):
    new_attention_map = attention_map * tokens_score.unsqueeze(0)
    return new_attention_map


def normalize_attention_map(attention_map):
    attention_map_sum = attention_map.sum(dim=0, keepdim=True)
    attention_map_sum += 1e-10
    attention_map_normalized = attention_map / attention_map_sum
    return attention_map_normalized


def get_row_sum_token_level(attention_map):
    tokens_score = torch.sum(attention_map, dim=1)
    return tokens_score


def aggregate_phrase_scores(index_list, tokens_scores):
    total_score = 0.0
    
    for p_index in index_list:
        part_sum = tokens_scores[p_index[0]:p_index[1]].sum()
        total_score += part_sum

    return total_score

# Checkpoint
# Get the highest global attention score for each phrase
def get_phrase_source(candidates_indices, index_list, attention_map, start_index, end_index, eos_indices):
    total_score = torch.zeros(attention_map.shape[1], dtype=torch.float32, device="cuda" if torch.cuda.is_available() else "cpu")
    if len(index_list) == 0 or start_index == 0:
        return None, 0.0, None, None

    # Get global attention scores from all indices in the index_list
    for p_index in index_list:
        # Skip indices that are out of the range of this sentence
        if p_index[0] < start_index or p_index[1] > end_index:
            continue

        part_sum = attention_map[p_index[0]:p_index[1]].sum(axis=0)
        total_score += part_sum
    
    
    # Set scores for itself to 0
    for p_index in index_list:
        total_score[p_index[0]:p_index[1]] = 0.0
        
    last_index = index_list[-1][0] 

    best_phrase = None
    best_score = float('-inf')
    best_sentence = None
    last_source_index = None
    for phrase in candidates_indices.keys():
        try:
            phrase_indices = candidates_indices[phrase]
            if len(phrase_indices) == 0:
                continue
        except KeyError:
            continue
        phrase_score = aggregate_phrase_scores(phrase_indices, total_score)

        

        # If the phrase score is better than the best score, update the best phrase and score
        if phrase_score > best_score:

            # Find end of sentence indices that lower or equal to the last index of the phrase
            last_index2 = phrase_indices[-1][1]
            for i, idx in enumerate(phrase_indices):
                if idx[0] > last_index:
                    last_index2 = phrase_indices[i-1][1]
                    break

            # Get the location of the last index in the phrase
            for i, eos_idx in enumerate(eos_indices):
                if last_index2 <= eos_idx:
                    sentence = i+1
                    break
                
            best_score = phrase_score
            best_phrase = phrase
            best_sentence = sentence
            last_source_index = last_index2
    return best_phrase, best_score, best_sentence, last_source_index

# Get the highest global attention score for each phrase
def get_phrase_relation(candidates_indices, index_list, attention_map, start_index, end_index, source_index):

    total_score = torch.zeros(attention_map.shape[1], dtype=torch.float32, device="cuda" if torch.cuda.is_available() else "cpu")
    if len(index_list) == 0 or start_index == 0:
        return ""


    # Get global attention scores from all indices in the index_list
    for p_index in index_list:
        # Skip indices that are out of the range of this sentence
        if p_index[0] < start_index or p_index[1] > end_index:
            continue
        
        part_sum = attention_map[p_index[0]:p_index[1]].sum(axis=0)
        total_score += part_sum
    
    # Set scores for itself to 0
    for p_index in index_list:
        total_score[p_index[0]:p_index[1]] = 0.0

    
    best_phrase = ""
    best_score = float('-inf')
    total_score[:source_index] = 0.0  # Set scores before source index to 0
    for phrase in candidates_indices.keys():
        try:
            phrase_indices = candidates_indices[phrase]
            if len(phrase_indices) == 0:
                continue
        except KeyError:
            continue
        phrase_score = aggregate_phrase_scores(phrase_indices, total_score)

        
        # If the phrase score is better than the best score, update the best phrase and score
        if phrase_score > best_score:
            best_score = phrase_score
            best_phrase = phrase

    return best_phrase

def get_phrase_indices(text_tokens, phrase, prefix):
    text_tokens = [t.replace(prefix, '') for t in text_tokens]

    phrase = phrase.replace(' ', '')

    matched_indices = []
    matched_index = []
    target = phrase
    for i in range(len(text_tokens)):
        cur_token = text_tokens[i]
        sub_len = min(len(cur_token), len(phrase))
        if cur_token[:sub_len].lower() == target[:sub_len]:
            matched_index.append(i)
            target = target[sub_len:]
            if len(target) == 0:
                matched_indices.append([matched_index[0], matched_index[-1] + 1])
                target = phrase
        else:
            matched_index = []
            target = phrase
            if cur_token[:sub_len].lower() == target[:sub_len]:
                matched_index.append(i)
                target = target[sub_len:]
                if len(target) == 0:
                    matched_indices.append([matched_index[0], matched_index[-1] + 1])
                    target = phrase

    return matched_indices

def get_verb_indices(text_tokens, phrase, prefix):
    # Remove prefix from text tokens
    text_tokens = [t.replace(prefix, '') for t in text_tokens]

    # Split phrase into tokens (words)
    phrase_tokens = phrase.split(' ')

    matched_indices = []

    # Loop over text tokens to find phrase matches
    for i in range(len(text_tokens) - len(phrase_tokens) + 1):
        # Slice text tokens of phrase length
        window = text_tokens[i:i+len(phrase_tokens)]
        # Compare tokens ignoring case
        if all(w.lower() == p.lower() for w, p in zip(window, phrase_tokens)):
            matched_indices.append([i, i + len(phrase_tokens)])

    return matched_indices

def remove_repeated_sub_word(candidates_pos_dict):
    for phrase in candidates_pos_dict.keys():
        split_phrase = phrase.split()
        if len(split_phrase) > 1:
            for word in split_phrase:
                if word in candidates_pos_dict:
                    single_word_positions = candidates_pos_dict[word]
                    phrase_positions = candidates_pos_dict[phrase]
                    single_word_alone_positions = [pos for pos in single_word_positions if not any(
                        pos[0] >= phrase_pos[0] and pos[1] <= phrase_pos[1] for phrase_pos in phrase_positions)]
                    candidates_pos_dict[word] = single_word_alone_positions

    return candidates_pos_dict

In [6]:
def get_same_len_segments(total_tokens_ids, max_len):
    num_of_seg = (len(total_tokens_ids) // max_len) + 1
    seg_len = int(len(total_tokens_ids) / num_of_seg)
    segments = []
    attn_masks = []
    for _ in range(num_of_seg):
        if len(total_tokens_ids) > seg_len:
            segment = total_tokens_ids[:seg_len]
            total_tokens_ids = total_tokens_ids[seg_len:]
        else:
            segment = total_tokens_ids
        segments.append(segment)
        attn_masks.append([1] * len(segments[-1]))

    return segments, attn_masks

def read_jsonl(path):
    data = []
    with open(path, 'r') as f:
        for line in f:
            item = json.loads(line.strip())
            data.append(item)
    return data


def get_candidates(core_nlp, text):
    tagged = core_nlp.pos_tag_raw_text(text)
    text_obj = InputTextObj(tagged, 'en')
    candidates = extract_candidates(text_obj)
    return candidates

# ADDED: Function to get verb candidates from text using CoreNLP
def get_verb_candidates(core_nlp, text):
    tagged = core_nlp.pos_tag_raw_text(text)
    text_obj = InputTextObj(tagged, 'en')
    print(text_obj.pos_tagged)
    candidates = extract_verb_candidates(text_obj)
    return candidates

def get_score_full(candidates, references, maxDepth=15):
    precision = []
    recall = []
    reference_set = set(references)
    referencelen = len(reference_set)
    true_positive = 0
    for i in range(maxDepth):
        if len(candidates) > i:
            kp_pred = candidates[i]
            if kp_pred in reference_set:
                true_positive += 1
            precision.append(true_positive / float(i + 1))
            recall.append(true_positive / float(referencelen))
        else:
            precision.append(true_positive / float(len(candidates)))
            recall.append(true_positive / float(referencelen))
    return precision, recall


def evaluate(candidates, references):
    results = {}
    precision_scores, recall_scores, f1_scores = {5: [], 10: [], 15: []}, \
                                                 {5: [], 10: [], 15: []}, \
                                                 {5: [], 10: [], 15: []}
    for candidate, reference in zip(candidates, references):
        p, r = get_score_full(candidate, reference)
        for i in [5, 10, 15]:
            precision = p[i - 1]
            recall = r[i - 1]
            if precision + recall > 0:
                f1_scores[i].append((2 * (precision * recall)) / (precision + recall))
            else:
                f1_scores[i].append(0)
            precision_scores[i].append(precision)
            recall_scores[i].append(recall)

    print("########################\nMetrics")
    for i in precision_scores:
        print("@{}".format(i))
        print("F1:{}".format(np.mean(f1_scores[i])))
        print("P:{}".format(np.mean(precision_scores[i])))
        print("R:{}".format(np.mean(recall_scores[i])))

        top_n_p = 'precision@' + str(i)
        top_n_r = 'recall@' + str(i)
        top_n_f1 = 'f1@' + str(i)
        results[top_n_p] = np.mean(precision_scores[i])
        results[top_n_r] = np.mean(recall_scores[i])
        results[top_n_f1] = np.mean(f1_scores[i])
    print("#########################")

    return results



In [7]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2', output_hidden_states=True, output_attentions=True)

In [8]:
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str)
parser.add_argument('--plm', type=str)
parser.add_argument('--mode', type=str)
args = parser.parse_args([
    '--dataset', 'SemEval2017',
    '--plm', 'GPT2',
    '--mode', 'Both'
])

In [9]:
if args.dataset == 'Inspec' or args.dataset == 'inpsec':
        data_path = 'data/Inspec.jsonl'
        doc_type = 'short'
elif args.dataset == 'SemEval2010' or args.dataset == 'semeval2010':
    data_path = 'data/SemEval2010.jsonl'
    doc_type = 'long'
elif args.dataset == 'SemEval2017' or args.dataset == 'semeval2017':
    data_path = 'data/SemEval2017.jsonl'
    doc_type = 'short'
elif args.dataset == 'Krapivin' or args.dataset == 'krapivin':
    data_path = 'data/krapivin.jsonl'
    doc_type = 'long'
dataset = read_jsonl(data_path)
# rank_short_documents(args, dataset, model,tokenizer)

In [10]:
import stanza
stanza.download('en')  # only the first time
nlp = stanza.Pipeline(lang='en', processors='tokenize')



Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.10.0.json: 433kB [00:00, 117MB/s]                     
2025-07-27 22:05:58 INFO: Downloaded file to C:\Users\asia_\stanza_resources\resources.json
2025-07-27 22:05:58 INFO: Downloading default packages for language: en (English) ...
2025-07-27 22:05:59 INFO: File exists: C:\Users\asia_\stanza_resources\en\default.zip
2025-07-27 22:06:02 INFO: Finished downloading models and saved to C:\Users\asia_\stanza_resources
2025-07-27 22:06:02 INFO: Checking for updates to resources.json in case models have been updated.  Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES
Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.10.0.json: 433kB [00:00, 65.8MB/s]                    
2025-07-27 22:06:02 INFO: Downloaded file to C:\Users\asia_\stanza_resources\resources.json
2025-07-27 22:06:02 INFO: Loading the

In [11]:
def get_eos_indices(text):
    doc = nlp(text)
    sentences = [sentence.text for sentence in doc.sentences]
    sent_token_lens = [len(tokenizer.tokenize(sent)) for sent in sentences]
    
    # Calculate cumulative sum of tokens to get EOS indices
    eos_token_indices = []
    cum_sum = 0
    for length in sent_token_lens:
        cum_sum += length
        eos_token_indices.append(cum_sum - 1)
    return eos_token_indices

In [12]:
ignore = ['.', '(', ')', '[', ']', '{', '}', '"', "'", '?', '!', ':', ';', '_', '*', '%', '$', '#', '@']

In [None]:
def get_flow(args, data, model, tokenizer):
    if args.plm == 'BERT':
        prefix = '##'
    elif args.plm == 'GPT2':
        prefix = 'Ġ'

    # layer_head_predicted_top15 = defaultdict(list)
    sentence_predicted_top = defaultdict(list)
    sentence_predicted_source = defaultdict(list)
    sentence_predicted_relation = defaultdict(list)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"device: {device}")

    model.to(device)
    model.eval()

    # for data in tqdm(dataset):
    with torch.no_grad():
        new_text = '.'+data['text']
        tokenized_text = tokenizer(new_text, return_tensors='pt')

        outputs = model(**tokenized_text.to(device))

        attentions = outputs.attentions

        candidates = get_candidates(pos_tagger, new_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        
        
        # print(candidates)
        text_tokens = tokenizer.convert_ids_to_tokens(tokenized_text['input_ids'].squeeze(0))

        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(text_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices

        candidates_indices = remove_repeated_sub_word(candidates_indices)

        """   Get Verb Candidates   """
        verb_candidates = get_verb_candidates(pos_tagger, new_text)
        filtered_phrases = []
        for phrase in verb_candidates:
            words = phrase.split(' ')
            # Remove all words in ignore
            filtered_words = [w for w in words if w not in ignore]
            if filtered_words:  # Only add if not empty
                filtered_phrases.append(' '.join(filtered_words))
        print(verb_candidates)
        verb_candidates_indices = {}
        for phrase in filtered_phrases:
            matched_indices = get_verb_indices(text_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            verb_candidates_indices[phrase] = matched_indices

        verb_candidates_indices = remove_repeated_sub_word(verb_candidates_indices)
        # print(verb_candidates_indices)
        """   END   """   

        layer = 10
        head = 0
        head2 = 0

        n_layer_attentions = attentions[layer].squeeze(0)
        attention_map = n_layer_attentions[head]
        relation_map = n_layer_attentions[head2]
        eos_indices = get_eos_indices(new_text)
        start_index = 0
        for i, index in enumerate(eos_indices):
            
            # global_attention_scores = get_col_sum_token_level(attention_map)
            end_index = index + 1
            interested_map = attention_map[:end_index,:end_index]
            
            global_attention_scores = get_col_sum_token_level(interested_map)

            if args.plm == "BERT":
                global_attention_scores[-1] = 0
            elif args.plm == "GPT2":
                global_attention_scores[0] = 0

            # redistributed_attention_map = redistribute_global_attention_score(attention_map,
            #                                                                     global_attention_scores)

            redistributed_attention_map = normalize_attention_map(interested_map)

            proportional_attention_scores = get_row_sum_token_level(redistributed_attention_map)

            # if args.mode == 'Both':
            #     final_tokens_score = global_attention_scores + proportional_attention_scores
            # elif args.mode == 'Global':
            #     final_tokens_score = global_attention_scores
            # elif args.mode == 'Proportional':
            #     final_tokens_score = proportional_attention_scores
            final_tokens_score = global_attention_scores + proportional_attention_scores
            final_tokens_score[:start_index] = 0.0
            # padded_final_tokens_score = F.pad(final_tokens_score, (start_index, 0), mode='constant', value=0)
            phrase_score_dict = {}
            for phrase in candidates_indices.keys():
                try:
                    phrase_indices = candidates_indices[phrase]
                    if len(phrase_indices) == 0:
                        continue
                except KeyError:
                    continue

                final_phrase_score = aggregate_phrase_scores(phrase_indices, final_tokens_score)
                
                if len(phrase.split()) == 1:
                    final_phrase_score = final_phrase_score / len(phrase_indices)
                phrase_score_dict[phrase] = final_phrase_score

            # print(phrase_score_dict)
            sorted_scores = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
            # stemmed_sorted_scores = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
            #                             phrase, score in sorted_scores]
            # print(stemmed_sorted_scores)
            set_scores_list = []
            for phrase, score in sorted_scores:
                # print(phrase)
                if phrase not in set_scores_list:
                    set_scores_list.append((phrase,score))

            pred_stemmed_phrases = set_scores_list[:5]
            
            source_score_list = []
            relation_list = []
            """   Get Source   """

            for phrase in pred_stemmed_phrases:
                phrase_indices = candidates_indices[phrase[0]]
                if len(phrase_indices) == 0:
                    continue
                
                phrase_source, source_score, source_sentence, source_index = get_phrase_source(candidates_indices, phrase_indices, attention_map, start_index, end_index, eos_indices)
                source_score_list.append((phrase_source, source_score, source_sentence))

                """   Get Relation   """
                
                phrase_relation = get_phrase_relation(verb_candidates_indices, phrase_indices, relation_map, start_index, end_index, source_index)
                
                relation_list.append(phrase_relation)

                """   END   """
                
            """   END   """
        
            # layer_head_predicted_top15[(layer, head)].append(pred_stemmed_phrases)
            sentence_predicted_top[i].append(pred_stemmed_phrases)
            sentence_predicted_source[i].append(source_score_list)
            sentence_predicted_relation[i].append(relation_list)
            # for i in pred_stemmed_phrases:
            #     if len(sentence_predicted_top[i]) == 0:
            #         sentence_predicted_top[i] = []
            
            start_index = end_index
        return sentence_predicted_top, sentence_predicted_source, sentence_predicted_relation    
        

In [35]:
def get_flow2(args, data, model, tokenizer):
    if args.plm == 'BERT':
        prefix = '##'
    elif args.plm == 'GPT2':
        prefix = 'Ġ'

    # layer_head_predicted_top15 = defaultdict(list)
    sentence_predicted_top = defaultdict(list)
    sentence_predicted_source = defaultdict(list)
    sentence_predicted_relation = defaultdict(list)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"device: {device}")

    model.to(device)
    model.eval()

    # for data in tqdm(dataset):
    with torch.no_grad():
        new_text = '.'+data['text']
        tokenized_text = tokenizer(new_text, return_tensors='pt')

        outputs = model(**tokenized_text.to(device))

        attentions = outputs.attentions

        candidates = get_candidates(pos_tagger, new_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        
        
        # print(candidates)
        text_tokens = tokenizer.convert_ids_to_tokens(tokenized_text['input_ids'].squeeze(0))

        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(text_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices

        candidates_indices = remove_repeated_sub_word(candidates_indices)

        """   Get Verb Candidates   """
        verb_candidates = get_verb_candidates(pos_tagger, new_text)
        filtered_phrases = []
        for phrase in verb_candidates:
            words = phrase.split(' ')
            # Remove all words in ignore
            filtered_words = [w for w in words if w not in ignore]
            if filtered_words:  # Only add if not empty
                filtered_phrases.append(' '.join(filtered_words))
        print(verb_candidates)
        verb_candidates_indices = {}
        for phrase in filtered_phrases:
            matched_indices = get_verb_indices(text_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            verb_candidates_indices[phrase] = matched_indices

        verb_candidates_indices = remove_repeated_sub_word(verb_candidates_indices)
        # print(verb_candidates_indices)
        """   END   """   

        layer = 10
        head = 0
        head2 = 0

        n_layer_attentions = attentions[layer].squeeze(0)
        attention_map = n_layer_attentions[head]
        relation_map = n_layer_attentions[head2]
        eos_indices = get_eos_indices(new_text)
        start_index = 0
        print(f"eos_indices: {eos_indices}")
        for i, index in enumerate(eos_indices):
            
            # global_attention_scores = get_col_sum_token_level(attention_map)
            end_index = index + 1
            interested_map = attention_map[:end_index+1,:end_index+1]
            
            global_attention_scores = get_col_sum_token_level(interested_map)

            if args.plm == "BERT":
                global_attention_scores[-1] = 0
            elif args.plm == "GPT2":
                global_attention_scores[0] = 0

            # redistributed_attention_map = redistribute_global_attention_score(attention_map,
            #                                                                     global_attention_scores)

            redistributed_attention_map = normalize_attention_map(interested_map)

            proportional_attention_scores = get_row_sum_token_level(redistributed_attention_map)

            # if args.mode == 'Both':
            #     final_tokens_score = global_attention_scores + proportional_attention_scores
            # elif args.mode == 'Global':
            #     final_tokens_score = global_attention_scores
            # elif args.mode == 'Proportional':
            #     final_tokens_score = proportional_attention_scores
            final_tokens_score = global_attention_scores + proportional_attention_scores
            final_tokens_score[:start_index] = 0.0
            # padded_final_tokens_score = F.pad(final_tokens_score, (start_index, 0), mode='constant', value=0)
            phrase_score_dict = {}
            for phrase in candidates_indices.keys():
                try:
                    phrase_indices = candidates_indices[phrase]
                    if len(phrase_indices) == 0:
                        continue
                except KeyError:
                    continue

                final_phrase_score = aggregate_phrase_scores(phrase_indices, final_tokens_score)
                
                if len(phrase.split()) == 1:
                    final_phrase_score = final_phrase_score / len(phrase_indices)
                phrase_score_dict[phrase] = final_phrase_score

            # print(phrase_score_dict)
            sorted_scores = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
            # stemmed_sorted_scores = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
            #                             phrase, score in sorted_scores]
            # print(stemmed_sorted_scores)
            set_scores_list = []
            for phrase, score in sorted_scores:
                # print(phrase)
                if phrase not in set_scores_list:
                    set_scores_list.append((phrase,score))

            pred_stemmed_phrases = set_scores_list[:5]
            
            source_score_list = []
            relation_list = []
            """   Get Source   """

            for phrase in pred_stemmed_phrases:
                phrase_indices = candidates_indices[phrase[0]]
                if len(phrase_indices) == 0:
                    continue
                
                phrase_source, source_score, source_sentence, source_index = get_phrase_source(candidates_indices, phrase_indices, attention_map, start_index, end_index, eos_indices)
                source_score_list.append((phrase_source, source_score, source_sentence))

                """   Get Relation   """
                
                phrase_relation = get_phrase_relation(verb_candidates_indices, phrase_indices, relation_map, start_index, end_index, source_index)
                
                relation_list.append(phrase_relation)

                """   END   """
                
            """   END   """
        
            # layer_head_predicted_top15[(layer, head)].append(pred_stemmed_phrases)
            sentence_predicted_top[i].append(pred_stemmed_phrases)
            sentence_predicted_source[i].append(source_score_list)
            sentence_predicted_relation[i].append(relation_list)
            # for i in pred_stemmed_phrases:
            #     if len(sentence_predicted_top[i]) == 0:
            #         sentence_predicted_top[i] = []
            
            start_index = end_index
        return sentence_predicted_top, sentence_predicted_source, sentence_predicted_relation    
        

In [14]:
practicadata_path = 'data/SemEval2017.jsonl'

dataset = read_jsonl(data_path)

In [15]:
a = {
    'text': (
        "Great Wall of China, extensive bulwark erected in ancient China, one of the largest building-construction projects ever undertaken. ",
        "The Great Wall actually consists of numerous walls—many of them parallel to each other—built over some two millennia across northern China and southern Mongolia. ",
        "The most extensive and best-preserved version of the wall dates from the Ming dynasty (1368–1644) and runs for some 5,500 miles (8,850 km) east to west from Mount Hu near Dandong, southeastern Liaoning province, to Jiayu Pass west of Jiuquan, northwestern Gansu province. ",
        "This wall often traces the crestlines of hills and mountains as it snakes across the Chinese countryside, and about one-fourth of its length consists solely of natural barriers such as rivers and mountain ridges. ",
        "Nearly all of the rest (about 70 percent of the total length) is actual constructed wall, with the small remaining stretches constituting ditches or moats. ",
        "Although lengthy sections of the wall are now in ruins or have disappeared completely, it is still one of the more remarkable structures on Earth. ",
        "The Great Wall was designated a UNESCO World Heritage site in 1987."
    )
}

# Join into a single string
a['text'] = ''.join(a['text'])

print(a['text'])


Great Wall of China, extensive bulwark erected in ancient China, one of the largest building-construction projects ever undertaken. The Great Wall actually consists of numerous walls—many of them parallel to each other—built over some two millennia across northern China and southern Mongolia. The most extensive and best-preserved version of the wall dates from the Ming dynasty (1368–1644) and runs for some 5,500 miles (8,850 km) east to west from Mount Hu near Dandong, southeastern Liaoning province, to Jiayu Pass west of Jiuquan, northwestern Gansu province. This wall often traces the crestlines of hills and mountains as it snakes across the Chinese countryside, and about one-fourth of its length consists solely of natural barriers such as rivers and mountain ridges. Nearly all of the rest (about 70 percent of the total length) is actual constructed wall, with the small remaining stretches constituting ditches or moats. Although lengthy sections of the wall are now in ruins or have di

In [16]:
def trim_by_ratio(data, min_ratio=0.5):
    nums = data[0]  # Expecting input like [[(index, value), ...]]
    result = [nums[0]]

    for i in range(1, len(nums)):
        current_val = nums[i][1]
        prev_val = nums[i - 1][1]

        if current_val == 0:
            break
        if current_val < min_ratio * prev_val:
            break

        result.append(nums[i])

    return [result]

In [17]:
def make_table_label(sentence_label, phrase, attention_str):
    return f'''<
        <TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">
            <TR><TD><FONT COLOR="red" POINT-SIZE="8">{sentence_label}:</FONT></TD></TR>
            <TR><TD><FONT COLOR="blue" POINT-SIZE="16">{phrase[0]}</FONT></TD></TR>
            <TR><TD><FONT COLOR="red" POINT-SIZE="6">{attention_str}</FONT></TD></TR>
        </TABLE>
    >'''

In [36]:
predicts, sources, relations = get_flow(args, a, model,tokenizer)

device: cuda
[[('.', 'LESS')], [('great', 'NNP'), ('wall', 'NNP'), ('of', 'LESS'), ('china', 'NNP'), (',', 'LESS'), ('extensive', 'JJ'), ('bulwark', 'NN'), ('erected', 'VBN'), ('in', 'LESS'), ('ancient', 'JJ'), ('china', 'NNP'), (',', 'LESS'), ('one', 'CD'), ('of', 'LESS'), ('the', 'DT'), ('largest', 'JJS'), ('building-construction', 'JJ'), ('projects', 'NNS'), ('ever', 'RB'), ('undertaken', 'VBN'), ('.', 'LESS')], [('the', 'DT'), ('great', 'NNP'), ('wall', 'NNP'), ('actually', 'RB'), ('consists', 'VBZ'), ('of', 'LESS'), ('numerous', 'JJ'), ('walls', 'NNS'), ('--', 'LESS'), ('many', 'JJ'), ('of', 'LESS'), ('them', 'PRP'), ('parallel', 'VBP'), ('to', 'LESS'), ('each', 'DT'), ('other', 'JJ'), ('--', 'LESS'), ('built', 'VBN'), ('over', 'IN'), ('some', 'DT'), ('two', 'CD'), ('millennia', 'NNS'), ('across', 'IN'), ('northern', 'JJ'), ('china', 'NNP'), ('and', 'CC'), ('southern', 'JJ'), ('mongolia', 'NNP'), ('.', 'LESS')], [('the', 'DT'), ('most', 'RBS'), ('extensive', 'JJ'), ('and', 'CC'), 

In [None]:
# Checkpoint
dot = Digraph()

existing_nodes = set()
phrase_sentences = defaultdict(set)
phrase_attention = defaultdict(set)
dot.attr(rankdir='TB', size='8,5')

threshold = 0.5
print(len(predicts))
# Create nodes and edges
for i in range(len(predicts)):
    p = predicts[i]
    s = sources[i]
    r = relations[i]

    if len(p) == 0:
        continue

    p = trim_by_ratio(p,0.25) 
    s = s[:len(p[0])]
    r = r[:len(p[0])]

    with dot.subgraph() as sgraph:
    #     sgraph.attr(rank='same')
        for phrase, source, relation in zip(p[0], s[0], r[0]):
            
            if phrase[1] > threshold:

                phrase_sentences[phrase[0]].add(i + 1)
                sentence_str = ', '.join(map(str, sorted(phrase_sentences[phrase[0]])))
                phrase_attention[phrase[0]].add(round(float(phrase[1]),3))
                attention_str = ', '.join(map(str, sorted(phrase_attention[phrase[0]])))
                sentence_label = f'sentence {sentence_str}'

                if phrase[0] not in existing_nodes:
                    existing_nodes.add(phrase[0])
                    dot.node(phrase[0], make_table_label(sentence_label, phrase, attention_str), shape='box')

                    if source[0]:
                        phrase_sentences[source[0]].add(source[2])
                        sentence_str = ', '.join(map(str, sorted(phrase_sentences[source[0]])))
                        phrase_attention[source[0]].add(round(float(source[1]),3))
                        attention_str = ', '.join(map(str, sorted(phrase_attention[source[0]])))
                        sentence_label = f'sentence {sentence_str}'
                        
                        if source[0] not in existing_nodes:
                            existing_nodes.add(source[0])
                        
                        dot.node(source[0], make_table_label(sentence_label, source, attention_str), shape='box')
                        dot.edge(source[0], phrase[0], label=relation if source[2]== i else '')
                # else:
                #     dot.node(phrase[0], make_table_label(sentence_label, phrase, attention_str), shape='box')

    # with dot.subgraph() as sgraph:
    #     sgraph.attr(rank='same')
    #     for phrase, _ in p[0]:
    #         if phrase not in existing_nodes:
    #             existing_nodes.add(phrase)
    #             dot.node(phrase, phrase)


dot.render(view=True)  # Or just display in notebook: dot

7


'Digraph.gv.pdf'

In [33]:
import spacy

# Load English model
nlp2 = spacy.load("en_core_web_sm")

text = a['text']

doc = nlp2(text)

# # Print dependency parse info for tokens
# for token in doc:
#     print(token.text, token.dep_, token.head.text)

# Function to check if two tokens are connected syntactically
def are_connected(token1, token2):
    # Check if token1 is ancestor of token2 or vice versa
    return token1 in token2.ancestors or token2 in token1.ancestors

# Find tokens for 'Ming' and 'ruins'
token_ming = [t for t in doc if t.text == "wall"][0]
token_ruins = [t for t in doc if t.text == "length"][0]
token_structures = [t for t in doc if t.text == "structures"][0]

print("Are 'Dynasty' and 'ruins' connected?", are_connected(token_ming, token_ruins))  # likely False
print("Are 'Ming' and 'structures' connected?", are_connected(token_ming, token_structures))  # likely True


Are 'Dynasty' and 'ruins' connected? False
Are 'Ming' and 'structures' connected? False


# MAKING TREE