## Dataset setup

In [11]:
import pandas as pd

In [12]:
# DATASET_NAMES and DATASETS are parallel lists: DATASETS[i] holds the frame
# read from the CSV that corresponds to DATASET_NAMES[i].
DATASET_NAMES = ['sara', 'echr']

sara_df = pd.read_csv('DATASETS/sara_annotated.csv')
echr_df = pd.read_csv('DATASETS/echr_annotated.csv')

DATASETS = [sara_df, echr_df]

# Derived instead of hard-coded so the count cannot drift from the lists above.
NUM_DATASETS = len(DATASETS)
assert NUM_DATASETS == len(DATASET_NAMES), "DATASET_NAMES and DATASETS must stay parallel"

## Smart 2-shot dictionary generation

In [None]:
from sentence_transformers import SentenceTransformer
import numpy as np

In [None]:
# Single shared encoder instance (all-mpnet-base-v2 checkpoint); the functions
# below read it as a module-level global instead of taking it as a parameter.
sentence_embedding_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

In [None]:
def most_similar(curr_i, curr_embed, embedding_space):
    """Return the key of the embedding nearest to ``curr_embed`` (Euclidean).

    Args:
        curr_i: Key of the query item; an entry with this key is never returned.
        curr_embed: Query embedding vector.
        embedding_space: Mapping of key -> embedding vector to search.

    Returns:
        The key with the smallest Euclidean distance to ``curr_embed``, or
        ``-1`` when no entry other than ``curr_i`` qualifies (e.g. the mapping
        is empty). On equal distances the entry iterated first wins.
    """
    best_key = -1
    # Start from the largest float32 so any smaller distance replaces it.
    best_dist = np.finfo(np.float32).max

    others = ((key, vec) for key, vec in embedding_space.items() if key != curr_i)
    for key, vec in others:
        dist = np.linalg.norm(curr_embed - vec)
        if not dist < best_dist:
            continue
        best_key, best_dist = key, dist

    return best_key

In [None]:
# returns dictionary matching each datapt to its few_shot examples
def gen_similar_two_shot(dataset):
    """Choose one non-ambiguous and one ambiguous in-context example per row.

    For every row, candidates are first taken from the contiguous run of
    neighbouring rows that share the row's statute. For each label:
      * exactly one same-statute candidate -> use it;
      * several -> use the one whose fact-pattern embedding is closest;
      * none -> fall back to the closest "statute + fact pattern" embedding
        among all other rows with that label.

    Args:
        dataset: DataFrame with columns 'statute', 'fact_pattern',
            'ambiguity_exists' (True/False) and 'reason_for_ambiguity'.
            The neighbour scan stops at the first row with a different
            statute, so rows sharing a statute must be contiguous
            (the original comment states the datasets are sorted by statute).

    Returns:
        dict mapping row position -> [non_ambiguous_example, ambiguous_example],
        each example being the tuple
        (statute, fact_pattern, ambiguity_exists, reason_for_ambiguity).

    Raises:
        ValueError: if a row's 'ambiguity_exists' is neither True nor False, or
            if no example of one of the two labels can be found for a row.
            (Previously the latter only printed a message and then indexed
            with -1, silently using the dataset's last row as the example.)
    """
    statutes = dataset['statute']
    fact_patterns = dataset['fact_pattern']
    labels = dataset['ambiguity_exists']
    n_rows = len(dataset)

    # Embeddings grouped by label key ('True' / 'False'), then by row position.
    statute_fact_embeddings = {'True': {}, 'False': {}}
    fact_embeddings = {'True': {}, 'False': {}}
    label_keys = []  # label_keys[i] is the label key of row i

    # NOTE(review): tqdm is not imported in the cells visible here -- confirm an
    # earlier cell provides it (e.g. `from tqdm.auto import tqdm`).
    for d_i in tqdm(range(n_rows)):
        label = labels.iloc[d_i]
        if label == True:
            key = 'True'
        elif label == False:
            key = 'False'
        else:
            raise ValueError(
                "DATASET ERROR: row {i} has ambiguity_exists={v!r}; expected True or False".format(
                    i=d_i, v=label))
        label_keys.append(key)

        statute_fact_pattern = '''Statute: {s}
Fact pattern: {f}'''.format(s=statutes.iloc[d_i],
                            f=fact_patterns.iloc[d_i])

        statute_fact_embeddings[key][d_i] = sentence_embedding_model.encode(statute_fact_pattern)
        fact_embeddings[key][d_i] = sentence_embedding_model.encode(fact_patterns.iloc[d_i])

    def pick_example(d_i, key, same_statute):
        """Return the row position of the best example with label `key` for row d_i (-1 if none)."""
        candidates = same_statute[key]
        if len(candidates) == 1:
            return candidates[0]

        own_key = label_keys[d_i]
        if not candidates:
            # No same-statute example: search every row with this label,
            # comparing on statute + fact pattern.
            return most_similar(d_i,
                                statute_fact_embeddings[own_key][d_i],
                                statute_fact_embeddings[key])

        # Several same-statute examples: the statute is identical, so compare
        # on the fact pattern alone.
        return most_similar(d_i,
                            fact_embeddings[own_key][d_i],
                            {pe_i: fact_embeddings[key][pe_i] for pe_i in candidates})

    two_shot = {}
    for d_i in tqdm(range(n_rows)):
        same_statute = {'True': [], 'False': []}  # row positions sharing d_i's statute
        d_statute = statutes.iloc[d_i]

        # Walk outwards from d_i in both directions until the statute changes.
        for pre_i in range(d_i - 1, -1, -1):
            if statutes.iloc[pre_i] != d_statute:
                break
            same_statute[label_keys[pre_i]].append(pre_i)

        for post_i in range(d_i + 1, n_rows):
            if statutes.iloc[post_i] != d_statute:
                break
            same_statute[label_keys[post_i]].append(post_i)

        ambig_i = pick_example(d_i, 'True', same_statute)
        non_ambig_i = pick_example(d_i, 'False', same_statute)

        if ambig_i == -1 or non_ambig_i == -1:
            # -1 would otherwise be treated by .iloc as "last row" below.
            raise ValueError(
                "ERROR GEN SIMILAR TWO SHOT: row {i} has no {kind} example available".format(
                    i=d_i,
                    kind="ambiguous" if ambig_i == -1 else "non-ambiguous"))

        # present in-context examples as non-ambiguous first, then ambiguous
        two_shot[d_i] = [(statutes.iloc[non_ambig_i],
                          fact_patterns.iloc[non_ambig_i],
                          labels.iloc[non_ambig_i],
                          dataset['reason_for_ambiguity'].iloc[non_ambig_i]),
                         (statutes.iloc[ambig_i],
                          fact_patterns.iloc[ambig_i],
                          labels.iloc[ambig_i],
                          dataset['reason_for_ambiguity'].iloc[ambig_i])]

    return two_shot

In [None]:
# One two-shot dictionary per dataset, in the same order as DATASETS.
TWO_SHOT_DICS = [gen_similar_two_shot(dataset) for dataset in DATASETS]