In [1]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from src.metrics.intrinsic_dimension import IntrinsicDimension
from src.metrics.clustering import LabelClustering
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import pandas as pd
import datasets
# Matplotlib rcParams overrides for this notebook's figures; `plotter` feeds
# this dict to `plt.rcParams.update`.
plot_config = {
    # 'font.size': 12,
    # --- text sizes ---
    "axes.titlesize": 30,
    "axes.labelsize": 29,
    "xtick.labelsize": 20,
    "ytick.labelsize": 20,
    "legend.fontsize": 23,
    # --- canvas size (width, height) ---
    "figure.figsize": (10, 8),
    # --- line and marker geometry ---
    "lines.linewidth": 2.5,
    "lines.markersize": 10,
}


  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# On-disk locations of the two dataset splits read by `datasets.load_from_disk`.
train_path = Path("/orfeo/cephfs/scratch/area/ddoimo/open/geometric_lens/repo/diego/science_qa/train")
test_path = Path("/orfeo/cephfs/scratch/area/ddoimo/open/geometric_lens/repo/diego/science_qa/test")

# Load each split (train first, then test) and convert it to a pandas DataFrame.
train, test = (
    datasets.load_from_disk(split_dir).to_pandas()
    for split_dir in (train_path, test_path)
)

# Stack both splits row-wise. `pd.concat` keeps each frame's own row labels
# here (no `ignore_index`), so labels may repeat across the two splits.
df = pd.concat([train, test])

In [None]:
# NOTE(review): presumably the label(s) to keep in a later filter -- no cell
# visible in this part of the notebook reads this list; TODO confirm it is used.
subjects_to_keep = ["Physical and chemical change"]

In [24]:
# Number of test rows per `category` value, most frequent first.
test["category"].value_counts()

category
Literary devices                       1220
Units and measurement                   805
Genes to traits                         713
Reference skills                        657
Sentences, fragments, and run-ons       637
Classification                          438
Developing and supporting arguments     399
Formatting                              367
Traits and heredity                     338
Physical and chemical change            279
Designing experiments                   255
Force and motion                        224
Verb tense                              219
Heat and thermal energy                 214
Chemical reactions                      202
Basic economic principles               191
Word usage and nuance                   189
Heredity                                167
Weather and climate                     155
Pronouns                                153
Materials                               132
Phrases and clauses                     117
Velocity, acceleration,

In [54]:
# Categories whose rows are dropped from the test split.
category_to_exclude = [
    "Creative techniques",
    "Supply and demand",
    "Rhyming",
    "Short and long vowels",
    "Poetry elements",
    "Historical figures",
    "Pronouns and antecedents",
    "Cells",
    "Author's purpose and tone",
    "Thermal energy",
    "States of matter",
    "Earth events",
    "Kinetic and potential energy",
    "Social studies skills",
    "Economics",
    "Editing and revising",
    "Categories",
    "Government",
    "Shades of meaning",
    "Animals",
    "Comprehension strategies",
    "Research skills",
    "Plants",
    "Opinion writing",
    "Mixtures",
]

# Index labels of the test rows whose category is NOT in the exclusion list.
# The list is referenced with pandas' `@name` syntax instead of pasting its
# Python repr into the expression, so the query text never depends on how the
# labels happen to be quoted or escaped.
rows_to_keep = test.query("category not in @category_to_exclude").index

In [55]:
# Display the index labels of the test rows that passed the category filter.
rows_to_keep

Index([  28,   29,   30,   31,   32,   33,   34,   35,   36,   37,
       ...
       9775, 9776, 9777, 9778, 9779, 9780, 9781, 9783, 9784, 9785],
      dtype='int64', length=8521)

In [53]:
# Show one example question: position 20 among the test rows whose category is
# "Literary devices".
test.query("category == 'Literary devices'")["question"].iloc[20]

'Which figure of speech is used in this text?\nI would remind you that extremism in the defense of liberty is no vice. And let me remind you also that moderation in the pursuit of justice is no virtue.\n—Barry Goldwater, in his acceptance speech at the 1964 Republican National Convention'

In [3]:
def plotter(data, title, ylabel, names=None):
    """Draw one per-layer curve for each series in ``data`` on a single figure.

    Every series is rendered as a line with circular markers on top of it.
    The x axis is the layer position; its tick labels are 1-based.

    Parameters
    ----------
    data : sequence of 1-D numpy arrays
        One array per curve. The number of layers is read from
        ``data[0].shape[0]`` and used as the x range for every series.
    title : str
        Title of the figure.
    ylabel : str
        Label of the y axis.
    names : sequence of str, optional
        Legend label for each series, in the same order as ``data``.
        Defaults to ``["0 shot pt", "1 shot pt", "2 shot pt", "5 shot pt"]``.

    Notes
    -----
    * Series and labels are paired with ``zip``: any series beyond the number
      of labels is silently not drawn.
    * ``plot_config`` is written into the global ``plt.rcParams`` and therefore
      stays in effect for figures created after this call.
    """
    if names is None:
        names = ["0 shot pt", "1 shot pt", "2 shot pt", "5 shot pt"]

    # Grid style with visible axis edges and tick marks.
    sns.set_style(
        "whitegrid",
        rc={"axes.edgecolor": ".15", "xtick.bottom": True, "ytick.left": True},
    )
    # Apply the notebook-wide sizes BEFORE the figure exists: matplotlib picks
    # up figure size and font sizes from rcParams when the figure/axes are
    # created, so updating them right before `plt.show()` would only affect
    # the *next* figure.
    plt.rcParams.update(plot_config)

    plt.figure(dpi=200)
    layers = np.arange(0, data[0].shape[0])

    # One tick every 4 layers for fewer than 50 layers, every 8 otherwise.
    tick_step = 4 if layers.shape[0] < 50 else 8
    tick_positions = np.arange(0, layers.shape[0], tick_step)
    tick_labels = tick_positions + 1  # positions are 0-based, labels 1-based

    for series, label in zip(data, names):
        sns.scatterplot(x=layers, y=series, marker="o")
        sns.lineplot(x=layers, y=series, label=label)

    plt.xlabel("Layer")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xticks(ticks=tick_positions, labels=tick_labels)
    plt.tick_params(axis='y')
    plt.legend()
    plt.tight_layout()
    plt.show()

In [11]:
# Root directory of the evaluated results; it holds one `<n>shot` sub-directory
# per few-shot setting listed in `shot`.
_PATH = Path("/orfeo/cephfs/scratch/area/ddoimo/open/geometric_lens"
             "/repo/results/scienceqa/evaluated_test/llama-3-8b")

shot = [0, 1, 2, 5]

# Run the label clustering exactly once per few-shot setting, so the results
# line up one-to-one with the four legend labels `plotter` uses. Each run is
# slow (a progress bar over all layers), hence no redundant recomputation.
data_subjects = []
for n_shot in shot:
    clustering = LabelClustering(path=_PATH / f'{n_shot}shot')
    data_subjects.append(
        clustering.main(
            label="subjects",
            z=1.6,
            instance_per_sub=-1,
            full_tensor=True,
        )
    )

# Extract the 'adjusted_rand_score' entry of every result as a numpy array
# and plot one curve per few-shot setting.
ari = [np.array(result['adjusted_rand_score']) for result in data_subjects]
plotter(ari, "Label Clustering", "ARI")


Processing layers: 100%|██████████| 32/32 [01:15<00:00,  2.35s/it]
Processing layers: 100%|██████████| 32/32 [01:09<00:00,  2.17s/it]
Processing layers: 100%|██████████| 32/32 [01:16<00:00,  2.39s/it]
Processing layers: 100%|██████████| 32/32 [01:11<00:00,  2.22s/it]


In [7]:
# Read the pickled statistics stored under the 0-shot results directory.
# NOTE: unpickling executes arbitrary code -- only load files you trust.
stats_file = _PATH / "0shot/statistics_target.pkl"
with stats_file.open("rb") as handle:
    stats = pickle.load(handle)

In [9]:
# Show which top-level entries the loaded statistics object provides.
stats.keys()

dict_keys(['subjects', 'answers', 'predictions', 'contrained_predictions', 'accuracy', 'constrained_accuracy', 'few_shot_indices', 'metrics'])