In [None]:
import pandas as pd 
import json 

In [2]:
# Load the evaluation questions and keep only the State of the Union corpus,
# retaining just the question text and its JSON-encoded reference spans.
questions_df = (
    pd.read_csv('../data/questions_df.csv')
    .query('corpus_id == "state_of_the_union"')
    [['question', 'references']]
)
questions_df.head()

Unnamed: 0,question,references
0,What significant regulatory changes and propos...,"[{""content"": ""My administration announced we\u..."
1,What reasons did President Biden give for the ...,"[{""content"": ""But unfortunately, politics have..."
2,How many people are no longer denied health in...,"[{""content"": ""Over 100 million of you can no l..."
3,"Which country is Putin invading, causing chaos...","[{""content"": ""Overseas, Putin of Russia is on ..."
4,When did the murder rate experience the sharpe...,"[{""content"": ""Last year, the murder rate saw t..."


In [3]:
# Read the whole corpus into one string. The encoding is given explicitly so
# the decoded characters (and therefore every character offset computed
# against `text` later) do not depend on the platform's default encoding.
# NOTE(review): assumes the markdown file is UTF-8 encoded — confirm.
with open('../data/state_of_the_union.md', encoding='utf-8') as f:
    text = f.read()

In [None]:
from fixed_token_chunker import FixedTokenChunker


In [5]:
from transformers import AutoTokenizer, AutoModel
# Tokenizer checkpoint; the same checkpoint name is used for the embedding
# model further down in this notebook.
tokenizer_checkpoint = 'sentence-transformers/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

In [6]:
# Chunking parameters passed to FixedTokenChunker.
# NOTE(review): presumably both values are measured in tokenizer tokens — confirm
# against the FixedTokenChunker implementation.
CHUNK_SIZE = 100
CHUNK_OVERLAP = 20

chunker = FixedTokenChunker(
    tokenizer=tokenizer,
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)
# Split the full corpus into the chunks that will be embedded and retrieved.
chunks = chunker.split_text(text)

In [None]:
from sentence_transformers import SentenceTransformer

In [8]:
# Embedding model used for both the corpus chunks and the questions.
embedding_checkpoint = 'sentence-transformers/all-MiniLM-L6-v2'
model = SentenceTransformer(embedding_checkpoint)

# One embedding per chunk, in the same order as `chunks`.
text_embeddings = model.encode(chunks)

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

In [9]:
# Embed every question with the same model used for the chunks.
question_texts = questions_df['question'].tolist()
question_embeddings = model.encode(question_texts)

In [10]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

In [32]:
# Pairwise cosine similarity between every question (rows) and every chunk (columns).
similarities = cosine_similarity(question_embeddings, text_embeddings)
k = 1  # Number of closest texts to retrieve

# argsort is ascending, so the k most similar chunks are the last k columns;
# reverse them so the most similar chunk comes first in each row.
ascending_order = np.argsort(similarities, axis=1)
top_k_ascending = ascending_order[:, -k:]
closest_texts_indices = top_k_ascending[:, ::-1]

In [33]:
# Display the retrieved chunk indices: one row per question, k columns per row.
closest_texts_indices

array([[ 75],
       [ 80],
       [ 46],
       [  2],
       [ 97],
       [119],
       [115],
       [ 27],
       [125],
       [ 10],
       [ 97],
       [ 24],
       [ 25],
       [ 89],
       [ 83],
       [ 76],
       [ 77],
       [  3],
       [ 69],
       [  0],
       [ 77],
       [ 22],
       [ 64],
       [ 24],
       [  8],
       [ 77],
       [106],
       [ 20],
       [ 10],
       [121],
       [ 14],
       [ 33],
       [ 36],
       [ 51],
       [ 58],
       [ 43],
       [ 47],
       [ 70],
       [ 27],
       [ 17],
       [ 13],
       [ 67],
       [ 62],
       [102],
       [114],
       [  0],
       [116],
       [ 41],
       [ 18],
       [119],
       [ 65],
       [ 74],
       [ 51],
       [ 92],
       [ 76],
       [ 56],
       [ 17],
       [  3],
       [ 34],
       [ 46],
       [111],
       [119],
       [102],
       [ 46],
       [108],
       [ 88],
       [  0],
       [  7],
       [  0],
       [ 17],
       [ 91],
      

In [34]:
def get_ranges(text, chunks, chunk_indices):
    """Locate the selected chunks inside ``text`` as character ranges.

    Args:
        text: The string to search in.
        chunks: Sequence of chunk strings.
        chunk_indices: Iterable of positions into ``chunks`` to look up.

    Returns:
        A list of ``(start, end)`` tuples, one per chunk that was found, in
        the order of ``chunk_indices``. ``end`` is ``start + len(chunk)``.

    Notes:
        ``str.find`` is used, so a chunk whose text occurs more than once is
        mapped to its first occurrence. A chunk that does not occur verbatim
        in ``text`` is silently left out of the result.
    """
    located = (
        (text.find(chunks[position]), len(chunks[position]))
        for position in chunk_indices
    )
    return [
        (begin, begin + size)
        for begin, size in located
        if begin != -1
    ]

def sum_of_ranges(ranges):
    """Return the total length covered by ``(start, end)`` pairs.

    Each pair contributes ``end - start``; overlaps between pairs are not
    deduplicated. An empty input yields ``0``.
    """
    total = 0
    for lower, upper in ranges:
        total += upper - lower
    return total

def merge_intervals(intervals):
    """Merge overlapping or touching ``(start, end)`` intervals.

    The input list is left untouched: a sorted copy is used internally, so
    callers can keep using their original list afterwards.

    Args:
        intervals: List of ``(start, end)`` pairs in any order.

    Returns:
        A new list of intervals ordered by start. Whenever an interval starts
        at or before the end of the previous one (``curr_start <= last_end``),
        the two are combined into a single ``(start, max_end)`` tuple.
        An empty input returns ``[]``.
    """
    if not intervals:
        return []
    # sorted() instead of list.sort() so the caller's list is not reordered.
    ordered = sorted(intervals, key=lambda x: x[0])
    merged = [ordered[0]]
    for current in ordered[1:]:
        last_start, last_end = merged[-1]
        curr_start, curr_end = current
        # Touching intervals (curr_start == last_end) are merged as well.
        if curr_start <= last_end:
            # Extend the previous interval; max() handles full containment.
            merged[-1] = (last_start, max(last_end, curr_end))
        else:
            merged.append(current)
    return merged

def intersect_intervals(retrieved, targets):
    """Compute the pairwise intersections of two interval lists.

    Both inputs are sorted by start on private copies, so the caller's lists
    are not modified. The lists are then swept with two pointers, always
    advancing past the interval that ends first.

    Args:
        retrieved: List of ``(start, end)`` pairs.
        targets: List of ``(start, end)`` pairs.

    Returns:
        List of ``(start, end)`` tuples, in sweep order, for every pair of
        intervals with a non-empty overlap. Intervals that only touch at a
        single boundary point (``start == end``) have zero length and are
        not reported.

    Notes:
        The two-pointer sweep only visits every overlapping pair when the
        intervals inside each list do not overlap one another (for example
        after ``merge_intervals``).
    """
    # Sort copies rather than sorting in place, so inputs stay untouched.
    retrieved_sorted = sorted(retrieved, key=lambda x: x[0])
    targets_sorted = sorted(targets, key=lambda x: x[0])

    i, j = 0, 0
    intersections = []

    while i < len(retrieved_sorted) and j < len(targets_sorted):
        r_start, r_end = retrieved_sorted[i]
        t_start, t_end = targets_sorted[j]

        # Candidate overlap boundaries.
        start = max(r_start, t_start)
        end = min(r_end, t_end)

        # Strict comparison: a zero-length overlap is not an intersection.
        if start < end:
            intersections.append((start, end))

        # Advance whichever interval finishes first; it cannot overlap
        # anything further along in the other list.
        if r_end < t_end:
            i += 1
        else:
            j += 1

    return intersections


In [35]:
# Character ranges of the retrieved chunk(s) for each question.
retrieved_intervals = []
for row_indices in closest_texts_indices:
    retrieved_intervals.append(get_ranges(text, chunks, row_indices))

# Ground-truth character ranges for each question, parsed from the
# JSON-encoded `references` column.
target_intervals = []
for raw_references in questions_df['references']:
    parsed_references = json.loads(raw_references)
    target_intervals.append(
        [(ref['start_index'], ref['end_index']) for ref in parsed_references]
    )

In [36]:
# Collapse overlapping ranges per question so that lengths are not
# double-counted when they are summed later.
retrieved_intervals = list(map(merge_intervals, retrieved_intervals))
target_intervals = list(map(merge_intervals, target_intervals))

In [37]:
# Per-question retrieval metrics, measured in characters of overlap between
# the retrieved ranges and the reference ranges.
for retrieved, target in zip(retrieved_intervals, target_intervals):
    # Copies are passed so the per-question lists are never reordered here.
    intersections = intersect_intervals(retrieved.copy(), target.copy())
    total_intersection = sum_of_ranges(intersections)
    total_retrieved = sum_of_ranges(retrieved)
    total_target = sum_of_ranges(target)

    # Every denominator can be zero: get_ranges() drops chunks it cannot
    # locate, so `retrieved` may be empty. Report 0.0 instead of raising
    # ZeroDivisionError in those cases.
    total_union = total_retrieved + total_target - total_intersection
    IoU = total_intersection / total_union if total_union > 0 else 0.0
    precision = total_intersection / total_retrieved if total_retrieved > 0 else 0.0
    recall = total_intersection / total_target if total_target > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0

    print(f'IoU: {IoU:.2f} Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}')

IoU: 0.29 Precision: 0.34, Recall: 0.67, F1: 0.45
IoU: 0.59 Precision: 0.62, Recall: 0.94, F1: 0.74
IoU: 0.05 Precision: 0.06, Recall: 0.25, F1: 0.09
IoU: 0.22 Precision: 0.22, Recall: 1.00, F1: 0.35
IoU: 0.27 Precision: 0.27, Recall: 1.00, F1: 0.43
IoU: 0.38 Precision: 0.39, Recall: 0.93, F1: 0.55
IoU: 0.11 Precision: 0.11, Recall: 1.00, F1: 0.20
IoU: 0.00 Precision: 0.00, Recall: 0.00, F1: 0.00
IoU: 0.00 Precision: 0.00, Recall: 0.00, F1: 0.00
IoU: 0.22 Precision: 0.22, Recall: 1.00, F1: 0.36
IoU: 0.22 Precision: 0.22, Recall: 1.00, F1: 0.36
IoU: 0.00 Precision: 0.00, Recall: 0.00, F1: 0.00
IoU: 0.13 Precision: 0.13, Recall: 1.00, F1: 0.23
IoU: 0.09 Precision: 0.09, Recall: 1.00, F1: 0.17
IoU: 0.23 Precision: 0.23, Recall: 1.00, F1: 0.38
IoU: 0.00 Precision: 0.00, Recall: 0.00, F1: 0.00
IoU: 0.54 Precision: 0.54, Recall: 1.00, F1: 0.70
IoU: 0.30 Precision: 0.30, Recall: 1.00, F1: 0.46
IoU: 0.21 Precision: 0.21, Recall: 1.00, F1: 0.34
IoU: 0.00 Precision: 0.00, Recall: 0.00, F1: 0.00
