### Analysis and visualisation of Bartlett results

In [None]:
import glob
import pickle
import matplotlib.pyplot as plt
from wordcloud import WordCloud
import string

# Reference story for the replay experiments (Bartlett's "War of the Ghosts" tale).
# Its vocabulary defines which words in generated texts count as "new", and its
# character length is used to truncate generated texts before comparing them.
bartlett = """One night two young men from Egulac went down to the river to hunt seals and while they were there it became foggy and calm. Then they heard war-cries, and they thought: "Maybe this is a war-party". They escaped to the shore, and hid behind a log. Now canoes came up, and they heard the noise of paddles, and saw one canoe coming up to them. There were five men in the canoe, and they said:
"What do you think? We wish to take you along. We are going up the river to make war on the people."
One of the young men said,"I have no arrows."
"Arrows are in the canoe," they said.
"I will not go along. I might be killed. My relatives do not know where I have gone. But you," he said, turning to the other, "may go with them."
So one of the young men went, but the other returned home.
And the warriors went on up the river to a town on the other side of Kalama. The people came down to the water and they began to fight, and many were killed. But presently the young man heard one of the warriors say, "Quick, let us go home: that man has been hit." Now he thought: "Oh, they are ghosts." He did not feel sick, but they said he had been shot.
So the canoes went back to Egulac and the young man went ashore to his house and made a fire. And he told everybody and said: "Behold I accompanied the ghosts, and we went to fight. Many of our fellows were killed, and many of those who attacked us were killed. They said I was hit, and I did not feel sick."
He told it all, and then he became quiet. When the sun rose he fell down. Something black came out of his mouth. His face became contorted. The people jumped up and cried.
He was dead."""

In [None]:
# Load the saved experiment results. Opening the path directly (instead of looping
# over a wildcard-free glob) fails immediately with FileNotFoundError when the file
# is missing, rather than leaving `d` undefined until a later cell.
# As used below, d[repetition][category][temperature_key] is a generated text.
RESULTS_PATH = 'bartlett_pkls/all_results_dict.pkl'

# NOTE: pickle.load can execute arbitrary code; only load result files you trust.
with open(RESULTS_PATH, 'rb') as f:
    d = pickle.load(f)
print(d)

In [None]:
def plot_wordclouds(results_dict, models, keys=('greedy', 0.25, 0.5, 0.75), exclusion_text=bartlett, flip=False,
                    max_chars=None, filename='wordcloud.png'):
    """Plot a grid of word clouds showing the words each model added relative to ``exclusion_text``.

    For every (model, key) pair the text ``results_dict[model][key]`` is truncated to
    ``max_chars`` characters, lowercased, has punctuation replaced by spaces, and has
    every word that appears in ``exclusion_text`` removed; the remaining words are drawn
    as a word cloud.

    Parameters
    ----------
    results_dict : dict
        Mapping ``model -> {key -> text}``. Presumably ``key`` is a sampling
        temperature or ``'greedy'`` (TODO confirm against the pickle producer).
    models : iterable
        Models to plot, in order (e.g. ``results_dict.keys()``).
    keys : sequence, default ('greedy', 0.25, 0.5, 0.75)
        Keys looked up in ``results_dict[model]``; a missing key gives a blank panel.
    exclusion_text : str, default bartlett
        Reference text whose words are excluded from every cloud.
    flip : bool, default False
        False: rows are keys and columns are models. True: rows are models and columns are keys.
    max_chars : int or None, default None
        Number of leading characters of each text to use; None means ``len(exclusion_text)``.
    filename : str, default 'wordcloud.png'
        Path the figure is saved to before it is shown.
    """
    models = list(models)
    keys = list(keys)
    if max_chars is None:
        max_chars = len(exclusion_text)

    # Punctuation becomes a space (not deleted) so e.g. "war-cries" splits into "war" and "cries".
    punct_to_space = str.maketrans(string.punctuation, ' ' * len(string.punctuation))

    def normalise(text):
        return text.translate(punct_to_space).lower()

    # "s" is what is left of possessives/contractions once the apostrophe becomes a space.
    # "ap" is excluded as well (reason not visible here — TODO document).
    exclusion_words = set(normalise(exclusion_text).split()) | {'ap', 's'}

    def preprocess_text(text):
        return ' '.join(word for word in normalise(text).split() if word not in exclusion_words)

    # The three known corpora share one colormap; any other model is drawn in grey.
    colormaps = {'shakespeare': 'winter', 'news': 'winter', 'papers': 'winter'}

    n_keys, n_models = len(keys), len(models)
    grid_rows, grid_cols = (n_models, n_keys) if flip else (n_keys, n_models)
    # squeeze=False always yields a 2-D array of Axes, so axs[r, c] works for a single row or column.
    fig, axs = plt.subplots(grid_rows, grid_cols, figsize=(grid_cols * 5, grid_rows * 5), squeeze=False)

    for col, model in enumerate(models):
        for row, key in enumerate(keys):
            ax = axs[col, row] if flip else axs[row, col]
            ax.axis('off')  # hide the frame even if the panel ends up empty
            text = preprocess_text(results_dict[model].get(key, '')[:max_chars])
            if not text:
                continue
            wordcloud = WordCloud(width=400, height=400, relative_scaling=0.5, normalize_plurals=False,
                                  max_font_size=60, background_color='white',
                                  colormap=colormaps.get(model, 'gray')).generate(text)
            ax.imshow(wordcloud, interpolation='bilinear')
            ax.set_title(f'{model} - {key}')

    # Lay out after the titles exist so tight_layout accounts for them.
    fig.tight_layout(pad=3.0)
    fig.savefig(filename, dpi=500)
    plt.show()

# Word clouds for repetition key 2, sampled temperatures only, one row per model (flip=True).
plot_wordclouds(d[2], flip=True, keys=[0.25, 0.5, 0.75], models=d[2].keys())


In [None]:
# Default grid for repetition key 2: rows are the keys ('greedy' + temperatures), columns are models.
plot_wordclouds(d[2], d[2].keys())

In [None]:
import matplotlib.pyplot as plt
import string

# Lowercased story text with punctuation turned into spaces ("war-cries" -> "war cries").
bartlett_processed = bartlett.translate(
    str.maketrans(string.punctuation, ' ' * len(string.punctuation))
).lower()
# The story's vocabulary: any generated word outside this set counts as "new".
bartlett_words = {word for word in bartlett_processed.split()}

def preprocess_and_count(text, exclusion_set):
    """Return how many word tokens of ``text`` are absent from ``exclusion_set``.

    Punctuation characters become spaces and the text is lowercased before it is
    split on whitespace. Repeated words are counted once per occurrence.
    """
    table = str.maketrans(string.punctuation, ' ' * len(string.punctuation))
    tokens = text.translate(table).lower().split()
    return sum(token not in exclusion_set for token in tokens)

# Per-category count of new word tokens in each repetition's greedy output.
# Categories are taken from the first repetition's entry in `d`.
new_words_counts = {category: [] for category in d[next(iter(d))]}
repetitions = sorted(d.keys())  # sorted so the x-axis is monotonic

for repetition in repetitions:
    for category, counts in new_words_counts.items():
        # Truncate to the original story's length before counting.
        greedy_text = d[repetition][category]['greedy'][:len(bartlett)]
        counts.append(preprocess_and_count(greedy_text, bartlett_words))

# Repetition keys are plotted as 5 x key replays, matching the other "Number of replays" plots.
replays = [5 * r for r in repetitions]

fig, ax = plt.subplots(figsize=(4, 3))
colors = {'shakespeare': 'red', 'news': 'purple', 'papers': 'blue'}
for category, counts in new_words_counts.items():
    # Unknown categories fall back to matplotlib's default colour cycle instead of raising KeyError.
    ax.plot(replays, counts, label=category, color=colors.get(category), marker='o')

ax.set_xlabel('Number of replays')
ax.set_ylabel('Number of new words')
ax.legend()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import string

# Vocabulary of the original story (lowercased, punctuation replaced by spaces).
bartlett_processed = bartlett.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).lower()
bartlett_words = set(word for word in bartlett_processed.split())

def preprocess_and_count(text, exclusion_set):
    """Count whitespace-separated tokens of ``text`` not found in ``exclusion_set``.

    The text is normalised first: every punctuation character becomes a space
    and everything is lowercased. Duplicates are counted individually.
    """
    cleaned = text.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation)))
    count = 0
    for token in cleaned.lower().split():
        if token not in exclusion_set:
            count += 1
    return count

# Mean and standard error (across categories) of the number of new word tokens
# in each repetition's greedy output, truncated to the original story's length.
categories = list(d[next(iter(d))])  # categories taken from the first repetition's entry
averages = []
sems = []
repetitions = sorted(d.keys())  # sorted so the x-axis is monotonic

for repetition in repetitions:
    counts_for_repetition = [
        preprocess_and_count(d[repetition][category]['greedy'][:len(bartlett)], bartlett_words)
        for category in categories
    ]
    averages.append(np.mean(counts_for_repetition))
    # The sample SEM (ddof=1) is undefined for fewer than two values; use 0 instead of NaN.
    n_categories = len(counts_for_repetition)
    if n_categories > 1:
        sem = np.std(counts_for_repetition, ddof=1) / np.sqrt(n_categories)
    else:
        sem = 0.0
    sems.append(sem)

fig, ax = plt.subplots(figsize=(4, 3))
# Repetition keys are plotted as 5 x key replays.
ax.errorbar([5 * r for r in repetitions], averages, yerr=sems, color='blue', marker='o', capsize=5)
ax.set_xlabel('Number of replays')
ax.set_ylabel('Number of new words')
fig.savefig('num_new_vs_num_replays.png', bbox_inches='tight', dpi=500)
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from nltk.metrics.distance import edit_distance
import string
from nltk import download

# Character-level Levenshtein distance between each category's greedy output
# (truncated to the story's length) and the original story; mean and SEM across categories.
average_distances = []
sems = []
repetitions = sorted(d.keys())  # sorted so the x-axis is monotonic

for repetition in repetitions:
    distances_for_repetition = [
        edit_distance(d[repetition][category]['greedy'][:len(bartlett)], bartlett)
        for category in d[repetition]
    ]
    average_distances.append(np.mean(distances_for_repetition))
    # The sample SEM (ddof=1) is undefined for fewer than two values; use 0 instead of NaN.
    n_categories = len(distances_for_repetition)
    if n_categories > 1:
        sem = np.std(distances_for_repetition, ddof=1) / np.sqrt(n_categories)
    else:
        sem = 0.0
    sems.append(sem)

fig, ax = plt.subplots(figsize=(4, 3))
# Repetition keys are plotted as 5 x key replays.
ax.errorbar([5 * r for r in repetitions], average_distances, yerr=sems, color='blue', marker='o', capsize=5)
ax.set_xlabel('Number of replays')
ax.set_ylabel('Levenshtein distance')
fig.savefig('edit_dist_vs_num_replays.png', bbox_inches='tight', dpi=500)
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import string

# Sampling settings to compare, and the x position of each ('greedy' is drawn at 0).
temperatures = ['greedy', 0.25, 0.5, 0.75]
temp_values = [0 if t == 'greedy' else t for t in temperatures]

repetitions = sorted(d)  # iterating a dict yields its keys, so this equals sorted(d.keys())

# For each repetition, one mean and one SEM per temperature are appended below.
new_words_counts_per_repetition = {rep: [] for rep in repetitions}
sems_per_repetition = {rep: [] for rep in repetitions}

def preprocess_text(text):
    """Lowercase ``text``, turn punctuation into spaces, and collapse runs of whitespace to one space."""
    table = str.maketrans(string.punctuation, ' ' * len(string.punctuation))
    return ' '.join(text.translate(table).lower().split())

# Categories averaged at each temperature. With a single category the SEM is 0 (no spread to estimate).
TEMPERATURE_CATEGORIES = ['shakespeare']

for repetition in repetitions:
    for temp in temperatures:
        counts_for_temp = []
        for category in TEMPERATURE_CATEGORIES:
            text = d[repetition][category].get(temp)
            if text:  # skip temperatures that are missing or empty for this repetition
                generated_words = preprocess_text(text[0:len(bartlett)])
                counts_for_temp.append(preprocess_and_count(generated_words, bartlett_words))

        # Mean and SEM across categories; the sample SEM (ddof=1) needs at least two values.
        if counts_for_temp:
            n_samples = len(counts_for_temp)
            sem = np.std(counts_for_temp, ddof=1) / np.sqrt(n_samples) if n_samples > 1 else 0
            new_words_counts_per_repetition[repetition].append(np.mean(counts_for_temp))
            sems_per_repetition[repetition].append(sem)

plt.figure(figsize=(4, 3.5))
# One colour per repetition, spread evenly across the viridis colormap.
colors = plt.cm.viridis(np.linspace(0, 1, len(repetitions)))

for rep, color in zip(repetitions, colors):
    counts = new_words_counts_per_repetition[rep]
    if len(counts) != len(temp_values):
        # Skip repetitions that lack a value for some temperature.
        continue
    plt.errorbar(
        temp_values, counts,
        yerr=sems_per_repetition[rep],
        label=f'{rep*5} replays',
        marker='o', linestyle='-', capsize=0, color=color,
    )

# 'greedy' sits at x=0 and is labelled as no sampling.
plt.xticks(temp_values, labels=['no\nsampling', 0.25, 0.5, 0.75])
plt.xlabel('Temperature')
plt.ylabel('Number of new words')
plt.legend()
plt.savefig('shakespeare_temp.png', bbox_inches='tight', dpi=500)
plt.show()


In [None]:
# Side-by-side reading: the original story, then the Shakespeare model's output
# for repetition key 6 at each sampling setting (first 800 characters of each).
print(bartlett[:1700])
print('\n\n\n')

for setting in ('greedy', 0.25, 0.5, 0.75):
    print(d[6]['shakespeare'][setting][:800])
    print('\n\n\n')