# General

## Imports

In [2]:
import os
import json
import random
import functools
import numpy as np
import dspy
import phoenix as px

from collections import Counter
from typing import Dict, List
from sklearn.metrics import homogeneity_score, completeness_score, v_measure_score
from dspy.evaluate import Evaluate
from dspy.functional import TypedPredictor
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
from dspy.teleprompt import (
    BootstrapFewShotWithRandomSearch,
    BootstrapFewShot,
    LabeledFewShot,
    BootstrapFewShotWithOptuna,
    COPRO,
    BootstrapFinetune,
    MIPRO,
)


## Constants

In [3]:
import os

DATA_DIR = './data/'
DATA_DIR = './data/derived_datasets/'
PROP_CLUSTERING_PATH = os.path.join(DATA_DIR, 'proposition_clustering.json')
SALIENCE_DETECTION_PATH =os.path.join(DATA_DIR, 'salience.json')

## Phoenix & DSPy init

In [None]:
px.launch_app()
from openinference.instrumentation.dspy import DSPyInstrumentor
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor

endpoint = "http://localhost:6006/v1/traces"

resource = Resource(attributes={})
tracer_provider = trace_sdk.TracerProvider(resource=resource)
span_otlp_exporter = OTLPSpanExporter(endpoint=endpoint)
tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter=span_otlp_exporter))
trace_api.set_tracer_provider(tracer_provider=tracer_provider)
DSPyInstrumentor().instrument()

## Data

### Loading

In [6]:
prop_clustering = ""
prop_clustering = ""

with open(PROP_CLUSTERING_PATH, 'r') as f:
    prop_clustering = json.load(f)

with open(SALIENCE_DETECTION_PATH, 'r') as f:
    salience = json.load(f)

In [7]:
def calculate_num_documents(dict_data):
    return len([doc for test_text in dict_data.values() for doc in test_text["documents"]])

print(f'Number of Multinews (tests/ multi documents): {len(salience)}')
print(f'Number of document: {calculate_num_documents(salience)}')

Number of Multinews (tests/ multi documents): 98
Number of document: 275


## Model

In [92]:
# ChatGPT
gpt3_turbo = dspy.OpenAI(model='gpt-3.5-turbo-instruct', max_tokens=1000)
dspy.settings.configure(lm=gpt3_turbo, trace=[])

# Experiments

### Salience

#### Tokenizer

In [11]:
from transformers import AutoTokenizer, AutoConfig

# Load the tokenizer
model_type = 'meta-llama/Meta-Llama-3-8B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(model_type)
config = AutoConfig.from_pretrained(model_type)

max_length = config.max_position_embeddings
print(f"Maximum token limit: {max_length}")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Maximum token limit: 8192


In [12]:
# Counts the amount of tokens at each test
for test, data in salience.items():
    documents_tokens_amount = [len(tokenizer.tokenize(document)) for document in data["documents"]]
    spans_tokens_amount = len(tokenizer.tokenize(" ".join([span["docSpanText"] for span in data["salient_spans"]])))

    print(f"test={test}, documents_tokens:{documents_tokens_amount}, spans_tokens:{spans_tokens_amount}")

test=test49, documents_tokens:[9755, 1747, 226, 127, 288], spans_tokens:992
test=test34, documents_tokens:[7046, 1708, 46, 1662], spans_tokens:1037
test=test59, documents_tokens:[1356, 113, 2768], spans_tokens:756
test=test57, documents_tokens:[2744, 126, 45], spans_tokens:559
test=test79, documents_tokens:[652, 4641, 845, 427], spans_tokens:672
test=test95, documents_tokens:[228, 3221, 46, 58, 1049, 497], spans_tokens:702
test=test50, documents_tokens:[3031, 384, 392], spans_tokens:402
test=test63, documents_tokens:[3284, 12986], spans_tokens:578
test=test78, documents_tokens:[5639, 484, 765], spans_tokens:1182
test=test30, documents_tokens:[354, 389, 46, 432], spans_tokens:611
test=test28, documents_tokens:[645, 229, 118, 311], spans_tokens:359
test=test7, documents_tokens:[208, 248, 477], spans_tokens:550
test=test11, documents_tokens:[267, 195], spans_tokens:122
test=test46, documents_tokens:[200, 247], spans_tokens:258
test=test12, documents_tokens:[220, 34, 213], spans_tokens:318

#### Output Handeling

##### Offsets

In [13]:
import re
def extract_saliences_from_output(output:str):
    # Regular expression pattern to find saliences list
    pattern = r'Saliences:.*'

    # Find all matches
    output = re.sub(r'\s+', '', output)
    output = output.replace("Saliences:", "\nSaliences:")
    matches = re.findall(pattern, output)
    spans = matches[-1]

    pattern = r'\(\'(\d+),(\d+)-(\d+)\'\)'
    matches = re.findall(pattern, spans)

    extracted_data = []
    for match in matches:
        id_value = int(match[0])
        num1 = int(match[1])
        num2 = int(match[2])
        extracted_data.append((id_value, num1, num2))

    extracted_spans = list(set(extracted_data))

    return extracted_spans



In [14]:
def load_spans_from_documents(documents, spans):
    output_spans = []
    for span in spans:
        doc_id, start_range, end_range = span

        if doc_id < len(documents):
            output_spans.append(documents[doc_id][start_range:end_range])

    return output_spans

In [15]:
def convert_output_span_indexes_to_spans(documents, output_span_indexes):
    spans = extract_saliences_from_output(output_span_indexes)
    return load_spans_from_documents(documents, spans)

##### Full texts

In [16]:
import re
# General model
def extract_target_saliences_from_output(output):
    keyword = "Saliences:"
    last_index = output.rfind(keyword)

    if last_index != -1:
        substring = output[last_index + len(keyword):].strip()
        substring = substring.strip('"\n')

        if "[" in substring and "]" in substring:
            pattern = r'\[(.*?)\]'
            matches = re.findall(pattern, substring)
            substring = matches[0]

        return [span.strip('"\n') for span in substring.split("\", \"")]

    return None  # Handle case where keyword is not found



#### Dataset

In [17]:
PC_dataset = []

def clean_str(string :str):
    return string#.replace(r'.', r',').replace(r',,', r',').replace(r'"', r' ')

for topic in prop_clustering.keys():
    current = prop_clustering[topic]
    docSpanTexts = [item['docSpanText'] for item in  current['input_spans']]
    random.shuffle(docSpanTexts)
    spans_list = {item : index for index, item in zip(range(len(docSpanTexts)), docSpanTexts)}
    # input_spans = [clean_str(item) for item in docSpanTexts]
    input_spans = [f'<START_SPAN_{spans_list[item]}>: {clean_str(item)} <END_SPAN_{spans_list[item]}>' for item in docSpanTexts]
    clusters = []
    for cluster in prop_clustering[topic]['clusters']:
        salience_spans = {spans_list[span['docSpanText']]:int(cluster["clusterID"]) for span in cluster['spans']}
        clusters.append(salience_spans)
    flatten_cluster = str({k: v for cluster in clusters for k,v in cluster.items()})
    PC_dataset.append(dspy.Example(salience_spans=' '.join(input_spans), clusters=flatten_cluster).with_inputs('salience_spans'))

In [18]:
# Removes large documents from the dataset
threshold = 1000

new_salience_data = {}
salience_dataset = []
old_to_new_doc_index = {}
for test, data in salience.items():
    # spans = data[""]
    spans = data["salient_spans"]
    new_documents = []
    for index, document in enumerate(data["documents"]):
        if len(tokenizer.tokenize(document)) > threshold:
            new_spans = []
            for span in spans:
                if span["documentFile"].split(".")[0] != str(index):
                    new_spans.append(span)
            spans = new_spans
        else:
            curr_index = len(new_documents)
            old_to_new_doc_index[index] = curr_index
            new_documents.append(document)

    if len(new_documents) >= 2:
        new_spans = []
        for span in spans:
            record = {}
            record["documentFile"] = old_to_new_doc_index[int(span['documentFile'].split('.')[0])]
            record["docSpanOffsets"] = span['docSpanOffsets'].replace(', ', '-')
            record["docSpanText"] = span['docSpanText']
            new_spans.append(record)

        salience_dataset.append(dspy.Example(documents=new_documents, saliences=new_spans).with_inputs('documents'))
        new_salience_data[test] = {
            "documents": new_documents,
            "salient_spans": spans,
        }

new_salience_dataset = []
for example in salience_dataset:
    if len(example.documents) <=1 or len(example.saliences) <= 2:
        continue
    new_salience_dataset.append(example)

salience_dataset = new_salience_dataset

In [19]:
num_topics = len([topic for topic in new_salience_data.keys()])
train_size, dev_size = int(num_topics * 0.20), int(num_topics * 0.35)
test_size = num_topics - train_size - dev_size

def train_test_split(data_dict, start_index, end_index):
    items = list(data_dict.items())
    split_data = dict(items[start_index: end_index])

    return split_data

salience_train = salience_dataset[:train_size]
salience_dev   = salience_dataset[train_size: train_size + dev_size]
salience_test  = salience_dataset[train_size + dev_size:]

train_PC = PC_dataset[:train_size]
val_PC   = PC_dataset[train_size: train_size + dev_size]
test_PC  = PC_dataset[train_size + dev_size:]

In [20]:
print(f'Train size: {len(salience_train)}')
print(f'Dev size: {len(salience_dev)}')
print(f'Test size: {len(salience_test)}')

Train size: 15
Dev size: 26
Test size: 33


#### Metric

In [21]:
import ast

def convert_output_to_list(output):

    output = re.sub(r'\[(\w+)\]', r'<\1>', output)
    matches =  re.findall(r'\[(.*?)\]', output, re.DOTALL)

    if matches:
        output = matches[0]
    else:
      if output[-1] == '\'' or  output[-1] == '\"' or output[-1] == '\n':
        output += "]"
      else:
        output += "']"

    return eval(output)

def metric(real_spans, pred_spans, trace=None):
    # Tokenize the real and predicted spans
    real_tokens = [token for span in real_spans.saliences for token in tokenizer.tokenize(span["docSpanText"])]
    pred_tokens = [token for span in convert_output_to_list(pred_spans.saliences) for token in tokenizer.tokenize(span)]


    # Not optimal
    min_length = min(len(real_tokens), len(pred_tokens))
    real_tokens = real_tokens[:min_length]
    pred_tokens = pred_tokens[:min_length]

    # Count the occurrences of each token in real and predicted tokens
    real_counter = Counter(real_tokens)
    pred_counter = Counter(pred_tokens)

    # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN)
    TP = sum((real_counter & pred_counter).values())
    FP =  sum([(pred_counter - real_counter)[key] for key in pred_counter if key not in (pred_counter.keys() - real_counter.keys())])
    FN =  sum([pred_counter[key] for key in pred_counter.keys() - real_counter.keys()])

    # Calculate Precision, Recall, and F1 Score
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return  f1

#### Signature

In [23]:
import pydantic

class DictOfSpansClustersType(pydantic.BaseModel):
    span_cluster_dict: Dict[int, int]

class PropositionClusteringSignature(dspy.Signature):
    """Please Cluster salience spans into groups. Each group should contain spans that share the same information. all the input spans should be in the output clusters"""

    # salience_spans = dspy.InputField()
    # salience_spans = dspy.InputField(format=list)
    salience_spans : str = dspy.InputField(desc='Salience spans which each span starting with <START_SPAN_IDX> and end with <END_SPAN_IDX>, each salience span can have <SPAN_SEP> which seprate bettwen subspans inside the currnet IDX span.')
    # output_length : str = dspy.InputField(desc='Fact of the expected output length')
    clusters : Dict[int, int] = dspy.OutputField(desc='A dict in the following format {<SPAN_NUMBER_IDXi>:<CLUSTER IDXi>, <SPAN_NUMBER_IDXj>:<CLUSTER IDXj>, ...}. return ONLY a dict with spans and cluster!, return ONLY a dict with spans and cluster!, return ONLY a dict with spans and cluster!, return ONLY a dict with spans and cluster!')
    # clusters : Dict[int, int] = dspy.OutputField()
    # clusters : DictOfSpansClustersType = dspy.OutputField()

class CleanDictSignature(dspy.Signature):
    """Clean the dict to be only integers. The keys and the values must be type int, Return ONLY dict.

    ---

    Follow the examples:
    String dict: {"0": "cluster1","1": "cluster2","2": "cluster3","3": "cluster3","4": "cluster3","5": "cluster3","6": "cluster3","7": "cluster3","8": "cluster3","9": "cluster3","10": "cluster3","11": "cluster3","12": "cluster3","13": "cluster3","14": "cluster4"}
    Int dict:{0: 1,1: 2,2: 3,3: 3,4: 3,5: 3,6: 3,7: 3,8: 3,9: 3,10: 3,11: 3,12: 3,13: 3,14: 4}
    

    String dict: { '0': 0, '1': 0, '2': 0, '3': 0, '4': 0, '5': 0, '6': 0, '7': 0, '8': 0, '9': 0, '10': 0, '11': 0, '12': 1, '13': 1, '14': 2 }
    Int dict::{ 0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0, 11: 0, 12: 1, 13: 1, 14: 2 }
    

    String dict: "{ 'SPAN_NUMBER_0': 'CLUSTER_1', 'SPAN_NUMBER_1': 'CLUSTER_2', 'SPAN_NUMBER_2': 'CLUSTER_2', 'SPAN_NUMBER_3': 'CLUSTER_2', 'SPAN_NUMBER_4': 'CLUSTER_2', 'SPAN_NUMBER_5': 'CLUSTER_2', 'SPAN_NUMBER_6': 'CLUSTER_2', 'SPAN_NUMBER_7': 'CLUSTER_2', 'SPAN_NUMBER_8': 'CLUSTER_3', 'SPAN_NUMBER_9': 'CLUSTER_4', 'SPAN_NUMBER_10': 'CLUSTER_4' }"
    Int dict:  { 0: 1, 1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2, 8: 3, 9: 4, 10: 4 }

    
    String dict: "{<START_SPAN_0>: 0, <START_SPAN_1>: 0, <START_SPAN_2>: 1, <START_SPAN_3>: 0, <START_SPAN_4>: 2, <START_SPAN_5>: 2, <START_SPAN_6>: 2, <START_SPAN_7>: 1, <START_SPAN_8>: 2, <START_SPAN_9>: 2, <START_SPAN_10>: 3, <START_SPAN_11>: 2}"}
    Int dict:  {0: 0, 1: 0, 2: 1, 3: 0, 4: 2, 5: 2, 6: 2, 7: 1, 8: 2, 9: 2, 10: 3, 11: 2}
    ---
    """
    

    string_dict = dspy.InputField(desc='str dict which need to be cleaned up')
    int_dict = dspy.OutputField(desc='A dict with keys as type int, and values as type int.')

In [24]:
class PropositionalClusteringCOT(dspy.Module):
    def __init__(self):
        super().__init__()

        self.generate_clusters = dspy.ChainOfThoughtWithHint(PropositionClusteringSignature, activated=False)
        self.clean_clusters = dspy.Predict(CleanDictSignature)

    def forward(self, salience_spans):
        clusters = self.generate_clusters(salience_spans=salience_spans, hint=f'Output length should be {len(salience_spans)} spans.')
        try:
            eval(clusters.clusters)
            return dspy.Prediction(clusters=clusters.clusters)
        except:
            pass
        # with dspy.
        int_clusters = self.clean_clusters(string_dict=clusters.clusters)
        pred = dspy.Prediction(clusters=int_clusters.int_dict)
        return pred

In [25]:
from sklearn.metrics import homogeneity_score, completeness_score, v_measure_score
import numpy as np
from dspy.evaluate import Evaluate

def validate_propositional_clustering(example, pred, trace=None):
    """Validate clustering results using homogeneity, completeness, and V-measure."""
    
    true_labels = eval(example.clusters)
    predicted_labels = eval(pred.clusters)
    # print(pred.clusters)
    # list of cluster by span number if key=span_idx : value=cluster_idx if {1:0, 2:0, 0:1} then [1, 0, 0]
    true_labels_list = [true_labels[key] for key in sorted(true_labels.keys())]
    predicted_labels_list = [predicted_labels[key] for key in sorted(predicted_labels.keys())]
    # print(true_labels_list)
    # print(predicted_labels_list)
    if np.shape(true_labels_list) != np.shape(predicted_labels_list):
        print(f'Shape inst match.')
        # print(f'true_labels_list: {true_labels_list}, np.shape(true_labels_list)={np.shape(true_labels_list)}, \npredicted_labels_list:{predicted_labels_list}, np.shape(predicted_labels_list)={np.shape(predicted_labels_list)}')
        return 0
    homogeneity = homogeneity_score(true_labels_list, predicted_labels_list)
    completeness = completeness_score(true_labels_list, predicted_labels_list)
    v_measure = v_measure_score(true_labels_list, predicted_labels_list)
    
    print({
        'homogeneity': homogeneity,
        'completeness': completeness,
        'v_measure': v_measure
    })
    # return homogeneity + completeness + v_measure
    return {
        'homogeneity': homogeneity,
        'completeness': completeness,
        'v_measure': v_measure
    }

In [26]:
import ast

def format_output(data):
    if isinstance(data, list):
        return [record['docSpanText'] for record in data]
    return data

class SalienceSignature(dspy.Signature):
    """Below are documents on the same topic in different user messages. Please copy exactly salient sub-sentenial spans. Do not change the copied text."""
    documents = dspy.InputField(desc="Documents containing News information",
                                format=lambda x: x if isinstance(x, str) else list(x))
    saliences = dspy.OutputField(
        desc="Extracted sentences formated as python list, sentences copied as-is without any modification!",
        format=format_output)


In [61]:
def validate_correct_list(output):
  try:
      convert_output_to_list(output)
      return True
  except:
    return False

In [84]:
class SaliencePredict(dspy.Module):
    def __init__(self):
        super().__init__()
        # self.pred = dspy.TypedPredictor(SalienceSignature)
        # self.pred =  dspy.Predict(SalienceSignature, temperature=0.05)
        self.pred =  dspy.Predict(SalienceSignature,  temperature=0.5 , max_tokens=1000)

    def forward(self, documents):
      output = self.pred(documents=documents)
      dspy.Suggest(validate_correct_list(output.saliences), "Make sure Saliences is a valid json list of strings")
      return output

In [85]:
def convert_output_to_list(output):
    print(1)
    output = re.sub(r'\[(\w+)\]', r'<\1>', output)
    print(2)
    print(output)
    matches =  re.findall(r'\[(.*?)\]', output, re.DOTALL)
    print(3)

    if matches:
      print(4)
      output = matches[0]
      print(len(matches))
    else:
      print(5)
      if output[-1] == '\'' or  output[-1] == '\"' or output[-1] == '\n':
        print(7)
        output += "]"
      else:
        print(8)
        output += "']"

    print(output)
    return eval(output)

#### Assertion & Predictor

In [86]:
class Pipeline(dspy.Module):
    def __init__(self):
        super().__init__()
        self.salience_detection = assert_transform_module(SaliencePredict(), functools.partial(backtrack_handler, max_backtracks=1))
        self.salience_detection.load('./salience_optimized.json')
        self.propositional_clustering = PropositionalClusteringCOT()
        self.propositional_clustering.load('./propositional_clustering.json')

    def forward(self, documents):
      saliences = self.salience_detection(documents=documents)
      clusters = self.propositional_clustering(salience_spans=convert_output_to_list(saliences.saliences))
      return clusters

In [87]:
with_assertions = assert_transform_module(SaliencePredict(), functools.partial(backtrack_handler, max_backtracks=1))
with_assertions.load('./salience_optimized.json')

In [88]:
with_assertions

pred = Retry(StringSignature(documents -> saliences
    instructions='Below are documents on the same topic in different user messages. Please copy exactly salient sub-sentenial spans. Do not change the copied text.'
    documents = Field(annotation=str required=True json_schema_extra={'desc': 'Documents containing News information', 'format': <function SalienceSignature.<lambda> at 0x000001B6B1653790>, '__dspy_field_type': 'input', 'prefix': 'Documents:'})
    saliences = Field(annotation=SalienceOutput required=True json_schema_extra={'desc': 'Extracted sentences formated as python list (note strings must be inside double quotation marks), sentences copied as-is without any modification!', 'format': <function format_output at 0x000001B6B16539D0>, '__dspy_field_type': 'output', 'prefix': 'Saliences:'})
))

In [None]:
model = PropositionalClusteringCOT()
model.load('./propositional_clustering.json')
model

In [None]:
pipeline = Pipeline()
pipeline(**salience_test[0].inputs())

In [50]:
prop_clustering = PropositionalClusteringCOT().load('./propositional_clustering.json')


In [None]:
salience_scores = []
with_assertions = assert_transform_module(SaliencePredict(), functools.partial(backtrack_handler, max_backtracks=1))
with_assertions.load('./salience_optimized.json')
for example in salience_test:
    pred = with_assertions(**example.inputs())
    score = metric(example, pred)
    salience_scores.append(score)

In [93]:
scores = []
prop_clustering = PropositionalClusteringCOT()
prop_clustering.load('./propositional_clustering.json')
for example in test_PC:
    pred = prop_clustering(**example.inputs())
    score = validate_propositional_clustering(example, pred)
    scores.append(score)

{'homogeneity': 0.9999999999999998, 'completeness': 0.8343604297372121, 'v_measure': 0.9097017316893838}
{'homogeneity': 1.0, 'completeness': 0.8524398460347836, 'v_measure': 0.9203428093597348}
{'homogeneity': 1.0, 'completeness': 0.8077698004717359, 'v_measure': 0.8936644480519025}
{'homogeneity': 0.5734773267433709, 'completeness': 0.722346144208289, 'v_measure': 0.6393604453849324}
{'homogeneity': 1.0, 'completeness': 0.9687500000000001, 'v_measure': 0.9841269841269843}
Shape inst match.
{'homogeneity': 0.8333333333333333, 'completeness': 1.0, 'v_measure': 0.9090909090909091}
{'homogeneity': 1.0, 'completeness': 0.9624786378518295, 'v_measure': 0.9808806264565294}
{'homogeneity': 1.0, 'completeness': 0.8755776673507596, 'v_measure': 0.933661860654917}
{'homogeneity': 1.0, 'completeness': 0.8595824360039366, 'v_measure': 0.9244897342126939}
{'homogeneity': 1.0000000000000002, 'completeness': 0.9424351665604174, 'v_measure': 0.9703646049914161}
{'homogeneity': 0.9999999999999999, 'co