# Script to Compare Embeddings

In this notebook, we compare the embeddings of the three annotation sets (ML object labels, ML caption labels, and human labels) by computing their pairwise similarity.

Author: Nardiena A. Pratama

In [None]:
!pip install wandb seaborn 
# !pip install 'accelerate==0.31.0'
!pip install 'sentence-transformers==3.0.1'

In [None]:
!pip install wordsegment autocorrect 
!pip install spacy==3.8.0
!python -m spacy download en_core_web_trf

In [None]:
import os
from io import StringIO, BytesIO

import boto3
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sentence_transformers import SentenceTransformer

from helper_scripts.utility_functions import *
from helper_scripts.preprocess import *

## Set AWS Credentials

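Set `BUCKET_NAME` below to the name of the S3 bucket that holds the data (the AWS credentials themselves come from the IAM role attached to the instance). Do not put quotation marks around the value, because they would become part of the bucket name.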

In [None]:
# Export the S3 bucket name as an environment variable; it is read back with
# os.getenv('BUCKET_NAME') when the S3 client is set up.
# NOTE(review): `aws_bucket_name` looks like a placeholder — replace with the real bucket name.
%env BUCKET_NAME=aws_bucket_name

## Connect to AWS

In [None]:
# Create a session using the default credentials (IAM role attached to the instance)
session = boto3.Session()

# S3 client shared by every download/upload cell in this notebook
s3 = session.client('s3')

# Bucket name comes from the BUCKET_NAME environment variable set above.
# Fail fast with a clear message instead of passing None to boto3 later on.
bucket_name = os.getenv('BUCKET_NAME')
if not bucket_name:
    raise RuntimeError(
        "BUCKET_NAME is not set; run the `%env BUCKET_NAME=...` cell before connecting to AWS."
    )


In [None]:
# Base name of the embeddings file to analyse (finetuning - minilm 12, 5 epochs)
model_name = "finetuning_all-MiniLM-L12-v2_embeddings"

# Fetch the embeddings CSV from S3 and decode it to text
key = f"/data/outputs_50/{model_name}.csv"
response = s3.get_object(Bucket=bucket_name, Key=key)
csv_content = response['Body'].read().decode('utf-8')

# Parse the CSV and give the two ML columns the *_labels names used by the similarity cells
data = pd.read_csv(StringIO(csv_content)).rename(
    columns={
        "ml_captions": "ml_caption_labels",
        "ml_labels": "ml_object_labels",
    }
)

data

In [None]:
# Local directory of the model to load (finetuned - minilm 12 - 5 epochs)
model_path = "models/finetuning_all-MiniLM-L12-v2/v1"

# Flip to True to fetch the model files from the S3 bucket before loading them
download_from_s3 = False
if download_from_s3:
    # The S3 path mirrors the local path under /data/outputs_50/
    download_model_from_s3(
        bucket_name,
        f"/data/outputs_50/{model_path}",
        model_path,
    )

model = SentenceTransformer(model_path)
print(model_path)

In [None]:
# Output column -> the pair of label columns whose embeddings are compared
similarity_pairs = {
    'obj_capt_sim': ('ml_object_labels', 'ml_caption_labels'),
    'obj_human_sim': ('ml_object_labels', 'human_labels'),
    'capt_human_sim': ('ml_caption_labels', 'human_labels'),
}

# Score every row for each pair with the loaded sentence-transformer model
for sim_column, (first_col, second_col) in similarity_pairs.items():
    data[sim_column] = data.apply(
        lambda row: compare_embedding_similarity(row, first_col, second_col, model),
        axis=1,
    )

# Snapshot of the scored frame; this copy is what gets written out next
data_embed = data.copy()

In [None]:
# Serialise the scored DataFrame to CSV in memory
csv_buffer = StringIO()
data_embed.to_csv(csv_buffer, index=False)

# Destination key; the following cell reads the comparison file back from this same key
file_path = f"/data/outputs_50/{model_name}_comparison.csv"

# Actually upload the CSV. Previously the buffer was built but never sent to S3,
# so the success message below was misleading and the read-back could fail or be stale.
s3.put_object(
    Bucket=bucket_name,
    Key=file_path,
    Body=csv_buffer.getvalue().encode('utf-8'),
)

print(f"DataFrame saved as CSV and uploaded to {file_path} successfully.")


In [None]:
# Read the comparison CSV back from S3
key = f"/data/outputs_50/{model_name}_comparison.csv"
print(key)
response = s3.get_object(Bucket=bucket_name, Key=key)
csv_content = response['Body'].read().decode('utf-8')
data = pd.read_csv(StringIO(csv_content))

# Read the embedding columns as arrays (they are stored as strings in the CSV)
for embed_column in ("ml_object_embed", "ml_caption_embed", "human_embed"):
    data[embed_column] = data.apply(
        lambda row, column=embed_column: convert_str_to_array(row[column]),
        axis=1,
    )

data

In [None]:
# Sanity check: shape of the first parsed ML-object embedding
data["ml_object_embed"].iat[0].shape

## Univariate Analysis

### ML Objects VS ML Captions

In [None]:
# Summary statistics of the ML-object vs ML-caption similarity scores
data["obj_capt_sim"].describe()

In [None]:
# Histogram of the ML-object vs ML-caption similarity scores
hist_ax = data['obj_capt_sim'].plot.hist(title="Similarity Histogram: ML Objects and ML Captions")
hist_ax.set_xlabel("Similarity")
hist_ax.set_ylabel("Frequency")

# Save the figure in both vector and raster form
for figure_path in (
    "figs/obj_capt_sim_histogram.svg",
    "figs/obj_capt_sim_histogram.png",
):
    plt.savefig(figure_path, dpi=300)

plt.show()

### ML Objects VS Human Labels

In [None]:
# Summary statistics of the ML-object vs human-label similarity scores
data["obj_human_sim"].describe()

In [None]:
# Histogram of the ML-object vs human-label similarity scores
hist_ax = data['obj_human_sim'].plot.hist(title="Similarity Histogram: ML Objects and Human Labels")
hist_ax.set_xlabel("Similarity")
hist_ax.set_ylabel("Frequency")

# Save the figure in both vector and raster form
for figure_path in (
    "figs/obj_human_sim_histogram.svg",
    "figs/obj_human_sim_histogram.png",
):
    plt.savefig(figure_path, dpi=300)

plt.show()

### ML Captions VS Human Labels

In [None]:
# Summary statistics of the ML-caption vs human-label similarity scores
data["capt_human_sim"].describe()

In [None]:
# Histogram of the ML-caption vs human-label similarity scores
hist_ax = data['capt_human_sim'].plot.hist(title="Similarity Histogram: ML Captions and Human Labels")
hist_ax.set_xlabel("Similarity")
hist_ax.set_ylabel("Frequency")

# Save the figure in both vector and raster form
for figure_path in (
    "figs/capt_human_sim_histogram.svg",
    "figs/capt_human_sim_histogram.png",
):
    plt.savefig(figure_path, dpi=300)

plt.show()

## Multivariate Analysis

In [None]:
# Preview the three similarity scores next to the region used for grouping below
data[
    [
        "obj_capt_sim",
        "obj_human_sim",
        "capt_human_sim",
        "region",
    ]
]

In [None]:
# Pairwise scatter plots of the three similarity scores, coloured by region.

# Common range applied to every x and y axis in the grid
axis_limits = (0, 1)

g = sns.pairplot(
    data[['obj_capt_sim', 'obj_human_sim', 'capt_human_sim', 'region']],
    hue='region',
    kind='scatter',
    aspect=1,
)
sns.move_legend(
    g, "lower center",
    bbox_to_anchor=(.5, -0.01),  # Position the legend
    ncol=4,  # Number of columns in the legend
    title=None,  # No title
    frameon=False,  # No frame around the legend
    fontsize=18,  # Adjust the font size as needed
    markerscale=2  # Adjust the size of the legend icons (e.g., 2x larger)
)

# Readable axis labels, in the same order as the similarity columns above
custom_xlabels = ['ML Obj & ML Captions', 'ML Obj & Human Labels', 'ML Captions & Human Labels']
custom_ylabels = list(custom_xlabels)

# Label, bound and style every panel in a single pass
for i, row_label in enumerate(custom_ylabels):
    for j, col_label in enumerate(custom_xlabels):
        ax = g.axes[i, j]
        ax.set_xlabel(col_label, fontsize=18)
        ax.set_ylabel(row_label, fontsize=18)
        ax.set_xlim(axis_limits)
        ax.set_ylim(axis_limits)
        ax.tick_params(axis='both', which='major', labelsize=14)  # Tick label font size

# Resize first so tight_layout computes the spacing for the final canvas size
plt.gcf().set_size_inches(14, 14)
plt.tight_layout()

plt.savefig(f"figs/{model_name}_pairwise_similarity_plot.svg", dpi=300)
plt.savefig(f"figs/{model_name}_pairwise_similarity_plot.png", dpi=300)

plt.show()


In [None]:
# Keep only the similarity scores and the grouping column for the box plots
subset = data[['obj_capt_sim', 'obj_human_sim', 'capt_human_sim','region']].copy()

# ML-object vs ML-caption similarity, split by region
box_ax = sns.boxplot(data=subset, x='obj_capt_sim', hue='region')
box_ax.set_xlim(0, 1)

plt.tight_layout()

# Save the figure in both vector and raster form
for figure_path in (
    f"figs/{model_name}_obj_capt_sim_boxplot.svg",
    f"figs/{model_name}_obj_capt_sim_boxplot.png",
):
    plt.savefig(figure_path, dpi=300)

plt.show()

In [None]:
# ML-object vs human-label similarity, split by region
box_ax = sns.boxplot(data=subset, x='obj_human_sim', hue='region')
box_ax.set_xlim(0, 1)

plt.tight_layout()

# Save the figure in both vector and raster form
for figure_path in (
    f"figs/{model_name}_obj_human_sim_boxplot.svg",
    f"figs/{model_name}_obj_human_sim_boxplot.png",
):
    plt.savefig(figure_path, dpi=300)

plt.show()


In [None]:
# ML-caption vs human-label similarity, split by region
box_ax = sns.boxplot(data=subset, x='capt_human_sim', hue='region')
box_ax.set_xlim(0, 1)

plt.tight_layout()

# Save the figure in both vector and raster form
for figure_path in (
    f"figs/{model_name}_capt_human_sim_boxplot.svg",
    f"figs/{model_name}_capt_human_sim_boxplot.png",
):
    plt.savefig(figure_path, dpi=300)

plt.show()

# END