#### Analyse generated sequences

In [None]:
import pickle
from testing_utils import *

# Pickle file that holds the saved generation results
input_file = 'word_freq_results_imagined_colab3.pkl'

# Deserialise the results. Unpickling can execute arbitrary code, so only
# load files that come from a trusted source.
with open(input_file, mode='rb') as pickle_handle:
    results_dict = pickle.load(pickle_handle)


In [None]:
# Number of top-level entries in the loaded results
len(results_dict)

In [None]:
# Count how many result entries exist for each distinct 'temp' value.
# NOTE(review): `pd` is not imported in this notebook directly; presumably it
# arrives through `from testing_utils import *` - confirm.
pd.DataFrame(results_dict)['temp'].value_counts()

In [None]:
def analyze_sequences(results):
    """Classify generated sequences as 'real', 'valid' or 'neither'.

    Only environments with an index below ``results['model']`` are consulted.
    A generated sequence is

    * 'real' if it is among the first ``train_size`` training strings of any
      of those environments;
    * 'valid' if it is not real but occurs among the remaining training
      strings or among the testing strings of any of those environments;
    * 'neither' otherwise.

    Args:
        results: Mapping with the keys 'model' (number of environments to
            consult), 'training_strs' and 'testing_strs' (one sequence
            collection per environment), 'train_size' and 'seqs' (the
            generated sequences). Sequences must be hashable (e.g. strings)
            because membership is checked against sets.

    Returns:
        Dict with the keys 'real', 'valid' and 'neither'. Each value lists
        the matching entries of ``results['seqs']`` in their original order,
        duplicates included.
    """
    current_env = results['model']
    training_strs = results['training_strs']
    testing_strs = results['testing_strs']
    train_size = results['train_size']
    seqs = results['seqs']

    # Build the lookup pools once instead of re-slicing and scanning the
    # per-environment lists for every generated sequence.
    real_pool = set()
    valid_pool = set()
    for env in range(current_env):
        env_training = training_strs[env]
        real_pool.update(env_training[:train_size])
        valid_pool.update(env_training[train_size:])
        valid_pool.update(testing_strs[env])

    analysis_results = {'real': [], 'valid': [], 'neither': []}
    for seq in seqs:
        # 'real' takes precedence over 'valid' across all environments.
        if seq in real_pool:
            label = 'real'
        elif seq in valid_pool:
            label = 'valid'
        else:
            label = 'neither'
        analysis_results[label].append(seq)

    return analysis_results

# Classify the generated sequences of one entry of results_dict (index 24)
# and display the resulting 'real' / 'valid' / 'neither' lists
analysis_results = analyze_sequences(results_dict[24])
analysis_results


In [None]:
import matplotlib.pyplot as plt

# results_dict is expected to be a list of result dicts, each recording its
# sampling temperature under the 'temp' key.

# Each entry is classified with analyze_sequences; adapt that function if your data structure differs.

def analyze_sequences_by_temperature(results_dicts):
    """Count real/valid/neither sequences per temperature.

    Every entry of ``results_dicts`` is classified with ``analyze_sequences``
    and reduced to the number of sequences in each category. The counts are
    stored under the entry's 'temp' value; entries whose 'temp' equals -1 are
    left out.

    NOTE(review): when several entries share the same 'temp', each one
    overwrites the previous, so only the last entry's counts are returned
    for that temperature.

    Returns:
        Dict mapping temperature -> {'real': int, 'valid': int, 'neither': int}.
    """
    counts_by_temp = {}

    for entry in results_dicts:
        temperature = entry['temp']
        if temperature != -1:
            classified = analyze_sequences(entry)
            counts_by_temp[temperature] = {
                category: len(classified[category])
                for category in ('real', 'valid', 'neither')
            }

    return counts_by_temp

# Compute the per-temperature category counts over all entries of results_dict
temperature_results = analyze_sequences_by_temperature(results_dict)

# Visualize the results
def visualize_results(temperature_results):
    """Plot the 'real', 'valid' and 'neither' counts against temperature.

    ``temperature_results`` maps each temperature to a dict holding the
    counts under the keys 'real', 'valid' and 'neither'. Temperatures are
    drawn in ascending order, one line per category, and the figure is shown
    immediately.
    """
    ordered_temps = sorted(temperature_results)

    # Collect every series before any figure is created.
    series = [
        (
            legend_label,
            marker_style,
            [temperature_results[temp][category] for temp in ordered_temps],
        )
        for category, legend_label, marker_style in (
            ('real', 'Real', 'o'),
            ('valid', 'Valid', 's'),
            ('neither', 'Neither', '^'),
        )
    ]

    plt.figure(figsize=(6, 4))
    for legend_label, marker_style, counts in series:
        plt.plot(ordered_temps, counts, label=legend_label, marker=marker_style)
    plt.xlabel('Temperature')
    plt.ylabel('Count')
    plt.title('Effect of Temperature on Sequence Distribution')
    plt.legend()
    plt.show()

# Plot the per-temperature counts held in temperature_results
visualize_results(temperature_results)


In [None]:
# Display the raw per-temperature counts
temperature_results

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

class GPT:
    """Thin wrapper pairing a GPT-2 tokenizer with a GPT-2 language model."""

    def __init__(self, base_model):
        """Load tokenizer and model.

        Args:
            base_model: Identifier handed unchanged to
                ``GPT2Tokenizer.from_pretrained`` and
                ``GPT2LMHeadModel.from_pretrained``.
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained(base_model)
        self.model = GPT2LMHeadModel.from_pretrained(base_model)

    def continue_input(self, input_sequence, max_length=200, num_return_sequences=1, no_repeat_ngram_size=0,
                       do_sample=False, temperature=0.7, num_beams=1, top_k=9):
        """Generate a continuation of ``input_sequence`` and return it as text.

        Args:
            input_sequence: Prompt string; it is tokenized and passed to
                ``model.generate``.
            max_length, num_return_sequences, no_repeat_ngram_size,
            do_sample, temperature, num_beams, top_k: Forwarded unchanged to
                ``model.generate``. ``temperature`` and ``top_k`` are always
                forwarded, whatever the value of ``do_sample``; per the
                transformers generation documentation they are sampling
                parameters. ``top_k`` defaults to 9, the value that used to
                be hard-coded here.

        Returns:
            The decoded text of the first sequence returned by ``generate``.
            Only ``output[0]`` is decoded, even when
            ``num_return_sequences`` is greater than 1.
        """
        input_ids = self.tokenizer.encode(input_sequence, return_tensors='pt')

        # Generate text
        output = self.model.generate(
            input_ids,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
        )

        # Decode the first returned sequence only
        sequence = output[0].tolist()
        return self.tokenizer.decode(sequence)

In [None]:
# Build the wrapper from 'spatial_model_0' (the identifier handed to from_pretrained)
model = GPT('spatial_model_0')

In [None]:
# Split environment 0 of the first results entry: the first 50 training
# strings form train_seqs (labelled "real" by the cells below); the remaining
# training strings plus all testing strings form test_seqs (labelled "valid").
first_entry = results_dict[0]
env0_training = first_entry['training_strs'][0]
train_seqs = env0_training[0:50]
test_seqs = [*env0_training[50:], *first_entry['testing_strs'][0]]

In [None]:
def test_data_subset(test_data, train_data):
    """Return the test sequences whose start does not occur in the training data.

    The "start" of a sequence is everything before its first 'PATH:' marker
    (the whole sequence when the marker is absent).

    Args:
        test_data: Iterable of candidate sequence strings.
        train_data: Iterable of training sequence strings.

    Returns:
        List of the entries of ``test_data``, in their original order, whose
        start is not the start of any entry in ``train_data``.
    """
    # A set gives constant-time membership checks instead of scanning a list
    # of training starts once per test sequence.
    train_starts = {train_seq.split('PATH:')[0] for train_seq in train_data}
    return [test_seq for test_seq in test_data if test_seq.split('PATH:')[0] not in train_starts]

# For every held-out sequence whose start was never seen in training, prompt
# the model with that start and classify the beam-search completion.
for seq in test_data_subset(test_seqs, train_seqs):
    print(seq)
    # The prompt is everything up to and including the 'PATH:' marker
    seq = model.continue_input(seq[:seq.index('PATH:')+5], do_sample=False, num_beams=5)
    # Keep only the first generated line. partition() also handles output
    # that contains no newline, where index() would raise ValueError.
    seq = seq.partition('\n')[0]
    print(seq)

    if seq in train_seqs:
        print("real")
    elif seq in test_seqs:
        print("valid")
    else:
        print("invalid!")

In [None]:
# Sample 20 generations at temperature 0.2 and label every generated line
for i in range(20):
    generated_text = model.continue_input('FROM:', temperature=0.2, do_sample=True, max_length=200)
    seqs = generated_text.split('\n')

    for seq in seqs:
        if seq in train_seqs:
            verdict = "real"
        elif seq in test_seqs:
            verdict = "valid"
        else:
            verdict = "invalid!"
        print(seq)
        print(verdict)

In [None]:
# Sample 20 generations at temperature 0.5 and label every generated line
for i in range(20):
    seqs = model.continue_input('FROM:', temperature=0.5, do_sample=True).split('\n')

    for seq in seqs:
        print(seq)
        print(
            "real" if seq in train_seqs
            else "valid" if seq in test_seqs
            else "invalid!"
        )

In [None]:
# One sampled generation at temperature 0.8, displayed as a list of lines
model.continue_input('FROM:', temperature=0.8, do_sample=True).split('\n')

In [None]:
# One sampled generation at temperature 1.5, displayed as a list of lines
model.continue_input('FROM:', temperature=1.5, do_sample=True).split('\n')

In [None]:
# Load the replay results and keep only the rows whose test type is 'next_node'
df = pd.read_csv('replay_results_imagined_colab3.csv')
df = df[df['test_type'] == 'next_node']

# Group by 'temp', 'trained_on' and 'tested_on', and calculate the mean and SEM of accuracy
grouped = df.groupby(['temp', 'trained_on', 'tested_on'])
mean_df = grouped['accuracy'].mean().reset_index()
sem_df = grouped['accuracy'].sem().reset_index()

# The SEM values computed above are discarded: the whole 'accuracy' column of
# sem_df is overwritten with 0, so any error bars drawn from it have zero length.
# NOTE(review): confirm that suppressing the error bars is intentional.
sem_df['accuracy'] = 0

In [None]:
vals = mean_df['temp'].unique()
num_env = df['trained_on'].nunique()

# One subplot row per distinct 'temp' value. squeeze=False makes plt.subplots
# always return a 2-D array of Axes, so indexing below also works when there
# is a single 'temp' value (the default would return a bare Axes object).
fig, axes = plt.subplots(len(vals), 1, figsize=(8, 16), sharex=True, squeeze=False)
axes = axes[:, 0]

# Iterate over each 'temp' value and fill its subplot
for i, val in enumerate(vals):
    df_sample_mean = mean_df[mean_df['temp'] == val]
    df_sample_sem = sem_df[sem_df['temp'] == val]

    # NOTE(review): assumes 'tested_on' takes the values 0..num_env-1, where
    # num_env is the number of distinct 'trained_on' values - confirm.
    for tested_on in range(num_env):
        # Rows of this subplot that belong to the current 'tested_on' value
        mean_rows = df_sample_mean[df_sample_mean['tested_on'] == tested_on]
        means = mean_rows['accuracy']
        sems = df_sample_sem[df_sample_sem['tested_on'] == tested_on]['accuracy']
        trained_on_values = mean_rows['trained_on']

        # Plot error bars
        axes[i].errorbar(trained_on_values, means, yerr=sems, label=f'Tested on Env {tested_on}', marker='o')

    # NOTE(review): `string` is not imported in this notebook directly;
    # presumably it arrives through the star import - confirm.
    letter = string.ascii_lowercase[i]
    # NOTE(review): the title says "self-generated samples" and the file name
    # below says "Number of samples", yet `val` is a 'temp' value - confirm
    # the labels before publishing the figure.
    axes[i].set_title(f'{letter}) {val} self-generated samples')
    axes[i].set_ylabel('Accuracy')
    axes[i].set_ylim((0,1))
    axes[i].legend()

# Set common labels and ticks on the bottom subplot (the x axis is shared)
axes[-1].set_xlabel('Trained On Environment')
axes[-1].set_xticks(range(num_env))
# plt.suptitle('Mean Model Accuracy Across Trials with SEM')
plt.savefig('Number of samples effect three trials.png', dpi=500)
plt.show()
