### Bartlett embedding analysis

* Embed 'background data'
* Embed recalled stories for each model
* Project into 2D
* Do the recalled stories get closer to the background distribution?

#### Imports:

In [None]:
import glob
import pickle
import matplotlib.pyplot as plt
from wordcloud import WordCloud
import string
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
import torch
from sklearn.decomposition import PCA
from umap import UMAP
from scipy.spatial.distance import cosine, euclidean
import os
import random
import numpy as np
import pandas as pd
from scipy.stats import sem

# Path to the directory containing pickle files.
# Currently the working directory; 'bartlett_data' is kept as the commented-out alternative.
directory_path = '.' #'bartlett_data'

# Load embedding model for analysing text.
# `model` is used as a module-level global by embed_texts() further down.
model = SentenceTransformer('all-MiniLM-L6-v2')

#### Analyse embeddings:

In [None]:
# Bartlett story: the original text that every recalled story is compared against.
# It is embedded later and plotted under the legend label "Original story"; its
# character length (len(bartlett)) is also used to truncate the other texts.
bartlett = """One night two young men from Egulac went down to the river to hunt seals and while they were there it became foggy and calm. Then they heard war-cries, and they thought: "Maybe this is a war-party". They escaped to the shore, and hid behind a log. Now canoes came up, and they heard the noise of paddles, and saw one canoe coming up to them. There were five men in the canoe, and they said:
"What do you think? We wish to take you along. We are going up the river to make war on the people."
One of the young men said,"I have no arrows."
"Arrows are in the canoe," they said.
"I will not go along. I might be killed. My relatives do not know where I have gone. But you," he said, turning to the other, "may go with them."
So one of the young men went, but the other returned home.
And the warriors went on up the river to a town on the other side of Kalama. The people came down to the water and they began to fight, and many were killed. But presently the young man heard one of the warriors say, "Quick, let us go home: that man has been hit." Now he thought: "Oh, they are ghosts." He did not feel sick, but they said he had been shot.
So the canoes went back to Egulac and the young man went ashore to his house and made a fire. And he told everybody and said: "Behold I accompanied the ghosts, and we went to fight. Many of our fellows were killed, and many of those who attacked us were killed. They said I was hit, and I did not feel sick."
He told it all, and then he became quiet. When the sun rose he fell down. Something black came out of his mouth. His face became contorted. The people jumped up and cried.
He was dead."""

In [None]:
# Accumulator for one dict per recalled story; filled by the loading loop below.
records = []


def load_pickle_data(filepath):
    """Return the object stored in the pickle file at *filepath*.

    NOTE: unpickling can execute arbitrary code, so only point this at
    files from a trusted source.
    """
    with open(filepath, 'rb') as handle:
        return pickle.load(handle)

# Read and combine data from all pickle files in the directory.
# os.listdir() returns entries in arbitrary order, so sort them to make the
# row order of `df` reproducible between runs and machines.
for filename in sorted(os.listdir(directory_path)):
    if not filename.endswith('.pkl'):  # Ensures that we are reading only pickle files
        continue
    file_path = os.path.join(directory_path, filename)
    data = load_pickle_data(file_path)
    print(filename)
    print(data.keys())
    print([k for k, v in data.items() if len(v)])

    for category in ['Universe', 'Politics', 'Health', 'Sport', 'Technology', 'Nature']:
        # Checkpoint names end in "-<integer>"; ranking by that integer gives
        # the 1-based epoch index stored with every record.
        ckpts = sorted(data[category], key=lambda name: int(name.split('-')[-1]))
        for epoch, ckpt in enumerate(ckpts, start=1):
            for temp in [0, 0.5, 1, 1.5]:
                outputs = data[category][ckpt][temp]
                # An entry is either one story (a string) or a collection of
                # stories. isinstance() also recognises str subclasses, which a
                # `type(x) == str` check would have iterated character by character.
                stories = [outputs] if isinstance(outputs, str) else outputs
                records.extend(
                    {
                        'topic': category,
                        'epoch': epoch,
                        'temp': temp,
                        'text': story,
                    }
                    for story in stories
                )

df = pd.DataFrame(records)

In [None]:
#df[(df['topic'] == 'Politics') & (df['epoch'] == 5) & (df['temp'] == 0)]['text'].values

In [None]:
#df[df['topic'] == 'Universe'][df['temp'] == 0.5][df['epoch'] == 5]['text'].tolist()

In [None]:
# Background corpus: the 'train' split of 'tarekziade/wikipedia-topics', converted to a
# pandas DataFrame. The code below reads its 'categories' and 'text' columns.
dataset = load_dataset('tarekziade/wikipedia-topics')
wiki_df = dataset['train'].to_pandas()

def get_texts_by_category(category, dataframe):
    """Return the shuffled texts of all non-person articles tagged with *category*.

    Args:
        category: Category name to look for in each row's 'categories' entry.
        dataframe: DataFrame with a 'categories' column (a container of
            category names per row) and a 'text' column.

    Returns:
        A list with the 'text' values of the matching rows, in random order
        (``sample(frac=1)`` shuffles without a fixed seed).
    """
    row_categories = dataframe['categories']
    # Remove articles about people (these tend to have many categories applied
    # that reflect the content less).
    is_person = row_categories.apply(lambda cats: 'People' in cats).astype(bool)
    in_category = row_categories.apply(lambda cats: category in cats).astype(bool)
    # Both conditions are combined in one mask. The previous version applied the
    # category filter to the unfiltered frame, which threw the People exclusion away.
    selected = dataframe[in_category & ~is_person]
    return selected['text'].sample(frac=1).tolist()

def _background_sample(wiki_category, limit=1000):
    """Shuffled Wikipedia texts for one category, each cut to the Bartlett story's length."""
    shuffled = get_texts_by_category(wiki_category, wiki_df)
    clipped = [article[:len(bartlett)] for article in shuffled]
    return clipped[0:limit]


# Background ("semantic") data per topic. Note the Wikipedia category is
# 'Sports' while the recalled-story topic below is 'Sport'.
universe_txts = _background_sample('Universe')
politics_txts = _background_sample('Politics')
health_txts = _background_sample('Health')
sport_txts = _background_sample('Sports')
tech_txts = _background_sample('Technology')
nature_txts = _background_sample('Nature')

# Sampling temperature of the recalled stories analysed below.
temp = 0.5


def _recalled_stories(topic):
    """Recalled stories for *topic* at the global `temp` and epoch 5."""
    wanted = (df['topic'] == topic) & (df['temp'] == temp) & (df['epoch'] == 5)
    return df[wanted]['text'].tolist()


universe_stories = _recalled_stories('Universe')
politics_stories = _recalled_stories('Politics')
health_stories = _recalled_stories('Health')
sport_stories = _recalled_stories('Sport')
tech_stories = _recalled_stories('Technology')
nature_stories = _recalled_stories('Nature')

In [None]:
def embed_texts(texts):
    """Encode *texts* with the module-level `model`.

    Only the first 800 characters of each text are passed to the encoder.
    Returns whatever ``model.encode`` returns for the clipped list.
    """
    clipped = []
    for text in texts:
        clipped.append(text[:800])
    return model.encode(clipped)

def calculate_mean_embeddings(*embedding_lists):
    """Return an array holding, for each argument, its mean taken along axis 0.

    The means are stacked in argument order.
    """
    group_means = []
    for group in embedding_lists:
        group_means.append(np.mean(group, axis=0))
    return np.array(group_means)
   
def _embed_individually(texts):
    """Embed every text in its own embed_texts() call and stack the results.

    Each call receives a one-element list, so every stacked entry keeps the
    batch axis returned by the encoder (presumably shape (1, D) per text --
    the PCA cell flattens this later).
    """
    return np.array([embed_texts([single]) for single in texts])


# Background (Wikipedia) texts per topic.
universe_embeddings = _embed_individually(universe_txts)
politics_embeddings = _embed_individually(politics_txts)
health_embeddings = _embed_individually(health_txts)
sport_embeddings = _embed_individually(sport_txts)
tech_embeddings = _embed_individually(tech_txts)
nature_embeddings = _embed_individually(nature_txts)

# Recalled stories per topic.
universe_story_embeddings = _embed_individually(universe_stories)
politics_story_embeddings = _embed_individually(politics_stories)
health_story_embeddings = _embed_individually(health_stories)
sport_story_embeddings = _embed_individually(sport_stories)
tech_story_embeddings = _embed_individually(tech_stories)
nature_story_embeddings = _embed_individually(nature_stories)

In [None]:
bartlett_text = bartlett
# Embed the original story and force a single-row 2D shape so PCA.transform accepts it.
bartlett_embedding = embed_texts([bartlett]).reshape(1, -1)

# One ordering shared by the PCA input, the group means, `labels` and `colors`:
# the six background groups first, then the six recalled-story groups, both in
# the topic order Universe, Politics, Sport, Technology, Health, Nature.
# Defining it once keeps the concatenation and the means from drifting apart.
background_groups = [
    universe_embeddings, politics_embeddings, sport_embeddings,
    tech_embeddings, health_embeddings, nature_embeddings,
]
recalled_groups = [
    universe_story_embeddings, politics_story_embeddings, sport_story_embeddings,
    tech_story_embeddings, health_story_embeddings, nature_story_embeddings,
]
embedding_groups = background_groups + recalled_groups

# Fit PCA on all embeddings; the reshape drops the per-text batch axis so the
# input is (n_samples, n_features).
all_embeddings = np.concatenate(embedding_groups)
pca = PCA(n_components=2, random_state=1)
reduced_embeddings = pca.fit_transform(all_embeddings.reshape(all_embeddings.shape[0], all_embeddings.shape[-1]))

# Project the Bartlett embedding with the already fitted PCA model
reduced_bartlett_embedding = pca.transform(bartlett_embedding)

# Group means are computed in the original embedding space and then projected
# with the same fitted PCA, so they live in the same 2D coordinates as the points.
mean_embeddings = calculate_mean_embeddings(*embedding_groups)
reduced_means = pca.transform(mean_embeddings.reshape(mean_embeddings.shape[0], mean_embeddings.shape[-1]))

# Labels and colors for each group (same order as `embedding_groups`)
labels = ['Universe data', 'Politics data', 'Sport data', 'Technology data', 'Health data', 'Nature data',
          'Recalled (Universe)', 'Recalled (Politics)', 'Recalled (Sport)', 'Recalled (Technology)', 'Recalled (Health)', 'Recalled (Nature)'
         ]

# One colour per topic, repeated so a topic's background cloud and its
# recalled-story mean share a colour. (An earlier named-colour list was
# immediately overwritten by this palette and has been dropped as dead code.)
base = ['#6a00a8', '#e16462', '#b12a90', '#0d0887', '#f0f921', '#fca636']
colors = base + base

In [None]:
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

def _draw_embedding_layers(
    axes,
    embeddings,
    reduced_means,
    bartlett_embedding,
    labels,
    colors,
    group_sizes,
    n_background_groups,
    *,
    point_size,
    cloud_alpha,
    head_width,
    with_legend_labels,
    origin_kwargs,
):
    """Draw background clouds, recalled-story means with arrows, and the Bartlett point on *axes*."""
    # Background clouds: consecutive row slices of `embeddings`, one per group.
    start = 0
    for i, size in enumerate(group_sizes[:n_background_groups]):
        end = start + size
        axes.scatter(
            embeddings[start:end, 0],
            embeddings[start:end, 1],
            color=colors[i],
            alpha=cloud_alpha,
            s=point_size,
        )
        start = end

    origin_x = bartlett_embedding[0, 0]
    origin_y = bartlett_embedding[0, 1]

    # Recalled-story means (every mean after the background groups), each with
    # an arrow pointing from the original story to the mean.
    for i, mean in enumerate(reduced_means):
        if i < n_background_groups:
            continue
        label_kwargs = {"label": labels[i]} if with_legend_labels else {}
        axes.scatter(
            mean[0],
            mean[1],
            color=colors[i],
            marker="o",
            s=point_size,
            edgecolors="black",
            **label_kwargs,
        )
        axes.arrow(
            origin_x,
            origin_y,
            mean[0] - origin_x,
            mean[1] - origin_y,
            color="black",
            lw=0.5,
            length_includes_head=True,
            head_width=head_width,
        )

    # The original story itself.
    axes.scatter(
        origin_x,
        origin_y,
        color="black",
        marker="o",
        s=point_size,
        edgecolors="black",
        **origin_kwargs,
    )


def plot_embeddings_with_inset(
    embeddings,
    reduced_means,
    bartlett_embedding,
    labels,
    colors,
    zoom_xlim=(-0.05, 0.05),
    zoom_ylim=(0.08, 0.2),
    inset_bbox=(0.5, 0.5, 0.47, 0.47),
    group_sizes=None,
    n_background_groups=6,
    main_xlim=(-0.95, 0.6),
    main_ylim=(-0.63, 0.8),
    save_path='plots/Recalled 2D with inset.png',
):
    """Plot the 2D embeddings with a zoomed inset, save the figure, and show it.

    Args:
        embeddings: (N, 2) array of all projected points; rows are grouped
            consecutively in the order given by ``group_sizes``.
        reduced_means: (M, 2) array with one projected mean per group.
        bartlett_embedding: (1, 2) array for the single Bartlett point.
        labels: list of length M; only the entries for the recalled-story
            groups are used (as legend labels).
        colors: list of length M, one colour per group.
        zoom_xlim, zoom_ylim: x/y limits of the inset.
        inset_bbox: (x0, y0, width, height) of the inset in relative Axes
            coordinates.
        group_sizes: number of rows of ``embeddings`` per group. When omitted,
            the sizes are read from the notebook's global embedding arrays.
        n_background_groups: how many leading groups are background data drawn
            as point clouds; the remaining groups are drawn as means only.
        main_xlim, main_ylim: x/y limits of the main axes.
        save_path: output file for the figure; a falsy value skips saving.
    """
    if group_sizes is None:
        # Backward-compatible default: sizes of the global arrays, in the same
        # order in which they were concatenated before PCA.
        group_sizes = [
            len(universe_embeddings),
            len(politics_embeddings),
            len(sport_embeddings),
            len(tech_embeddings),
            len(health_embeddings),
            len(nature_embeddings),
            len(universe_story_embeddings),
            len(politics_story_embeddings),
            len(sport_story_embeddings),
            len(tech_story_embeddings),
            len(health_story_embeddings),
            len(nature_story_embeddings),
        ]

    fig, ax = plt.subplots(figsize=(5.8, 5))

    # 1-3) Main axes: background clouds, labelled means with arrows, Bartlett point.
    _draw_embedding_layers(
        ax,
        embeddings,
        reduced_means,
        bartlett_embedding,
        labels,
        colors,
        group_sizes,
        n_background_groups,
        point_size=25,
        cloud_alpha=0.35,
        head_width=0.01,
        with_legend_labels=True,
        origin_kwargs={"label": "Original story"},
    )
    ax.legend(fontsize=10, ncol=1, loc="upper left", markerscale=2)

    # 4) Inset axes: the same layers re-drawn with larger markers, no legend labels.
    axins = ax.inset_axes(
        inset_bbox,
        xlim=zoom_xlim,
        ylim=zoom_ylim,
        xticks=[],
        yticks=[],
    )
    _draw_embedding_layers(
        axins,
        embeddings,
        reduced_means,
        bartlett_embedding,
        labels,
        colors,
        group_sizes,
        n_background_groups,
        point_size=130,
        cloud_alpha=0.25,
        head_width=0.015,
        with_legend_labels=False,
        origin_kwargs={"linewidth": 2},
    )

    # 5) Draw the connecting box & zoom lines
    ax.indicate_inset_zoom(axins, edgecolor="black", linewidth=1)

    # 6) Apply the main-axes limits BEFORE saving. Previously the file was
    #    written first, so the saved image used autoscaled limits and did not
    #    match the figure that was shown.
    ax.set_xlim(*main_xlim)
    ax.set_ylim(*main_ylim)

    # 7) Save & show
    if save_path:
        out_dir = os.path.dirname(save_path)
        if out_dir:
            # savefig does not create missing directories.
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.show()


In [None]:
# Draw the PCA projection; the inset sits in the lower-left corner of the
# main axes and zooms in on the region around the original story.
plot_embeddings_with_inset(
    embeddings=reduced_embeddings,
    reduced_means=reduced_means,
    bartlett_embedding=reduced_bartlett_embedding,
    labels=labels,
    colors=colors,
    zoom_xlim=(-0.2, 0.02),
    zoom_ylim=(-0.02, 0.25),
    inset_bbox=[0.02, 0.02, 0.27, 0.4],
)


In [None]:
def mean_embedding(embeddings):
    """Return the mean of *embeddings* taken along axis 0."""
    centroid = np.mean(embeddings, axis=0)
    return centroid

# Background texts and recalled stories for every topic
categories = {
    'Universe': (universe_txts, universe_stories),
    'Politics': (politics_txts, politics_stories),
    'Health': (health_txts, health_stories),
    'Sport': (sport_txts, sport_stories),
    'Nature': (nature_txts, nature_stories),
    'Technology': (tech_txts, tech_stories)
}

# Embedding of the original story; [0] removes the batch dimension
bartlett_embedding = embed_texts([bartlett])[0]

# Per-topic cosine distances, keyed by topic name
category_results = {}

for category, (texts, stories) in categories.items():
    # One embedding per text (batch dimension removed); recalled stories are
    # first cut to the length of the original story.
    category_embeddings = np.array([embed_texts([txt])[0] for txt in texts])
    story_embeddings = np.array([embed_texts([txt[:len(bartlett)]])[0] for txt in stories])

    # Centroids of the background texts and of the recalled stories
    category_mean = mean_embedding(category_embeddings)
    story_mean = mean_embedding(story_embeddings)

    # Cosine distance of the original story, and of the recalled-story
    # centroid, to the background centroid
    distance_bartlett_category = cosine(bartlett_embedding, category_mean)
    distance_story_category = cosine(story_mean, category_mean)

    category_results[category] = dict(
        distance_bartlett_category=distance_bartlett_category,
        distance_story_category=distance_story_category,
    )

# Output the results dictionary
print(category_results)

#### Compare distances to category means of original and recalled stories

In [None]:
def _sem_or_zero(values):
    """Standard error of the mean (ddof=1), or 0.0 when there is at most one value."""
    if len(values) > 1:
        return sem(values, ddof=1)
    return 0.0


categories = {
    'Universe':   (universe_txts, universe_stories),
    'Politics':   (politics_txts, politics_stories),
    'Health':     (health_txts, health_stories),
    'Sport':      (sport_txts, sport_stories),
    'Nature':     (nature_txts, nature_stories),
    'Technology': (tech_txts, tech_stories)
}

bartlett_embedding = embed_texts([bartlett])[0]  # (D,)

category_results = {}

for category, (texts, stories) in categories.items():
    # Centroid of the background texts for this topic
    category_embeddings = np.vstack([embed_texts([t])[0] for t in texts])  # (N_texts, D)
    category_mean = category_embeddings.mean(axis=0)  # (D,)

    # Recalled stories, cut to the original story's length before embedding
    story_embeddings = np.vstack([embed_texts([s[:len(bartlett)]])[0] for s in stories])  # (N_stories, D)

    # Euclidean distance original story -> centroid (a single value per topic)
    dist_b_all = [np.linalg.norm(bartlett_embedding - category_mean)]

    # Euclidean distance of every recalled story -> centroid
    dist_r_all = [np.linalg.norm(story_emb - category_mean) for story_emb in story_embeddings]

    # Mean and spread per topic. The 'std_*' keys actually hold the SEM.
    summary = {}
    summary['mean_b'] = np.mean(dist_b_all)
    summary['std_b'] = _sem_or_zero(dist_b_all)
    summary['mean_r'] = np.mean(dist_r_all)
    summary['std_r'] = _sem_or_zero(dist_r_all)
    category_results[category] = summary


In [None]:
# Grouped bar chart: distance to each topic's background centroid for the
# original story versus the recalled stories (error bars show the SEM).
cats = list(category_results)
x = np.arange(len(cats))
width = 0.35


def _stat_column(key):
    """Collect one statistic across topics, in the order of `cats`."""
    return [category_results[c][key] for c in cats]


means_b = _stat_column('mean_b')
sems_b = _stat_column('std_b')
means_r = _stat_column('mean_r')
sems_r = _stat_column('std_r')

fig, ax = plt.subplots(figsize=(5,2))
shared_bar_style = dict(capsize=5, alpha=0.6)
ax.bar(x - width/2, means_b, width, yerr=sems_b, label='Original', color=base[0], **shared_bar_style)
ax.bar(x + width/2, means_r, width, yerr=sems_r, label='Recalled', color=base[5], **shared_bar_style)

ax.set_xticks(x)
ax.set_xticklabels(cats)
ax.set_ylabel('Distance')
ax.set_ylim(0.8, 1.2)
ax.legend(loc='upper center', ncols=2)
fig.tight_layout()
plt.savefig('plots/Recalled stories.png', dpi=300)
plt.show()
