### Bartlett experiment with GPT-2

* This notebook fine-tunes GPT-2 on the story from the Bartlett experiment (1932) plus contextual data, in order to explore how generative models produce distortions
* This context is taken from one of six categories of a Wikipedia dataset
* We then explore recall of the Bartlett story - can substitutions and confabulations be observed in generative recall?
* How does the sampling temperature affect the level of distortion?

#### Installation:

In [None]:
!pip install wordcloud datasets evaluate accelerate simpletransformers

In [None]:
# Run the wandb CLI's `disabled` command through an IPython shell escape.
# NOTE(review): presumably this stops the training script below from trying to
# log to Weights & Biases -- confirm against the installed wandb version.
!wandb disabled

#### Imports:

In [None]:
import logging
from wordcloud import WordCloud
import matplotlib.pyplot as plt
import gc
from random import shuffle
from datasets import load_dataset
import pandas as pd
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import pickle
import os
import glob

In [None]:
# Story text used throughout the notebook: it is appended to every topic's
# training file, its opening words are the generation prompt, and its vocabulary
# is the reference set for the "new words" count. The literal is runtime data and
# must be kept exactly as is.
bartlett = """One night two young men from Egulac went down to the river to hunt seals and while they were there it became foggy and calm. Then they heard war-cries, and they thought: "Maybe this is a war-party". They escaped to the shore, and hid behind a log. Now canoes came up, and they heard the noise of paddles, and saw one canoe coming up to them. There were five men in the canoe, and they said:
"What do you think? We wish to take you along. We are going up the river to make war on the people."
One of the young men said,"I have no arrows."
"Arrows are in the canoe," they said.
"I will not go along. I might be killed. My relatives do not know where I have gone. But you," he said, turning to the other, "may go with them."
So one of the young men went, but the other returned home.
And the warriors went on up the river to a town on the other side of Kalama. The people came down to the water and they began to fight, and many were killed. But presently the young man heard one of the warriors say, "Quick, let us go home: that man has been hit." Now he thought: "Oh, they are ghosts." He did not feel sick, but they said he had been shot.
So the canoes went back to Egulac and the young man went ashore to his house and made a fire. And he told everybody and said: "Behold I accompanied the ghosts, and we went to fight. Many of our fellows were killed, and many of those who attacked us were killed. They said I was hit, and I did not feel sick."
He told it all, and then he became quiet. When the sun rose he fell down. Something black came out of his mouth. His face became contorted. The people jumped up and cried.
He was dead."""

In [None]:
class GPT:
    """Small wrapper around a GPT-2 checkpoint that generates a continuation of a prompt."""

    def __init__(self, base_model):
        """Load the tokenizer and LM-head model.

        base_model: model name or checkpoint directory, passed unchanged to
                    `from_pretrained` of both the tokenizer and the model.
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained(base_model)
        self.model = GPT2LMHeadModel.from_pretrained(base_model)

    def continue_input(self, input_sequence, max_length=200, num_return_sequences=1, no_repeat_ngram_size=10,
                       do_sample=False, temperature=0, num_beams=1):
        """Generate a continuation of `input_sequence` and return it as decoded text.

        input_sequence:        prompt string; it is encoded and fed to `generate`.
        max_length:            forwarded to `generate` as `max_length`.
        num_return_sequences:  forwarded to `generate`; only the first returned
                               sequence is decoded and returned.
        no_repeat_ngram_size:  forwarded to `generate`.
        do_sample:             False for deterministic decoding, True for sampling.
        temperature:           sampling temperature; only used when `do_sample` is True.
        num_beams:             forwarded to `generate`.

        Returns the decoded text of the first generated sequence.
        Raises ValueError if sampling is requested with a non-positive temperature.
        """
        sampling_with_temperature = do_sample and temperature is not None
        if sampling_with_temperature and temperature <= 0:
            # Fail early with a clear message instead of deep inside the library.
            raise ValueError(
                f"temperature must be > 0 when do_sample=True (got {temperature!r})"
            )

        input_ids = self.tokenizer.encode(input_sequence, return_tensors='pt')

        generate_kwargs = dict(
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
            do_sample=do_sample,
            # GPT-2 has no dedicated padding token; state the end-of-sequence fallback explicitly.
            pad_token_id=self.tokenizer.eos_token_id,
        )
        if sampling_with_temperature:
            # Temperature is a sampling-only setting, so it is not forwarded
            # for deterministic decoding (where the default of 0 is meaningless).
            generate_kwargs['temperature'] = temperature

        # Generate text
        output = self.model.generate(input_ids, **generate_kwargs)

        # Decode the first returned sequence
        return self.tokenizer.decode(output[0].tolist())

In [None]:
# Fine-tune a causal language model by running ../scripts/run_clm.py in a shell.
#
# name_or_path: forwarded as --model_name_or_path
# num_epochs:   forwarded as --num_train_epochs
# output_dir:   forwarded as --output_dir; the training text is read from
#               ./{output_dir}/train.txt
# save_steps:   accepted but never referenced in the body (the command passes
#               --save_strategy 'epoch' instead)
# lr:           forwarded as --learning_rate
# seed:         forwarded as --seed
#
# NOTE(review): the same file is passed as --train_file and --validation_file,
# so the evaluation is computed on the training text itself.
def train_model_script(name_or_path='gpt2-medium',
                       num_epochs=50,
                       output_dir='bartlett',
                       save_steps=100000,
                       lr=5e-04,
                       seed=0):
    # Run a garbage-collection pass in the notebook process before launching training.
    gc.collect()
    train_path = f'./{output_dir}/train.txt'
    # IPython shell escape: the {name} placeholders below are filled in from the
    # Python variables of this function before the command is executed.
    ! python3 ../scripts/run_clm.py \
        --model_name_or_path {name_or_path} \
        --train_file {train_path} \
        --validation_file {train_path} \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --do_train \
        --do_eval \
        --output_dir {output_dir} \
        --overwrite_output_dir \
        --num_train_epochs {num_epochs} \
        --save_strategy 'epoch' \
        --learning_rate {lr} \
        --seed {seed}

In [None]:
# Load the 'tarekziade/wikipedia-topics' dataset with `datasets.load_dataset`.
dataset = load_dataset('tarekziade/wikipedia-topics')

In [None]:
# Convert the 'train' split to a pandas DataFrame; later cells read its
# 'categories' and 'text' columns.
df = dataset['train'].to_pandas()

In [None]:
def get_texts_by_category(category, dataframe):
    """Return the texts of all non-person articles tagged with `category`, in random order.

    category:   category label to look for in each row's 'categories' collection.
    dataframe:  DataFrame with a 'categories' column (a collection of labels per
                row) and a 'text' column.

    Articles tagged 'People' are excluded, because they tend to have many
    categories applied that reflect the content less. Both conditions are
    combined into a single mask; filtering `dataframe` twice in sequence would
    silently drop the first condition.

    Returns a list of strings shuffled with `Series.sample(frac=1)`.
    """
    categories = dataframe['categories']
    # astype(bool) keeps the masks boolean even when the frame is empty.
    is_person = categories.apply(lambda labels: 'People' in labels).astype(bool)
    in_category = categories.apply(lambda labels: category in labels).astype(bool)

    selected = dataframe.loc[in_category & ~is_person, 'text']
    return selected.sample(frac=1).tolist()

In [None]:
topics = ['Universe', 'Politics', 'Health', 'Sport', 'Technology', 'Nature']

def train_models(bartlett_count):
    """Fine-tune one model per topic on that topic's texts plus copies of the story.

    bartlett_count: how many copies of the `bartlett` text are added to each
                    topic's training texts before shuffling.

    For every entry of `topics`, the directory 'bartlett_{topic}' is recreated
    from scratch, the shuffled texts are written to its train.txt (one text
    after another, joined by newlines) and `train_model_script` is run for
    5 epochs with that directory as output. Returns None.
    """
    import shutil  # local import: only needed here, for the directory reset

    for topic in topics:
        # Copy the list so the shared per-topic texts are not modified.
        txts_subset = txts_for_topics[topic][:]
        print(len(txts_subset))
        txts_subset += [bartlett]*bartlett_count
        shuffle(txts_subset)

        # Start from an empty output directory so stale checkpoints from a
        # previous run cannot be picked up later.
        output_dir = f'bartlett_{topic}'
        shutil.rmtree(output_dir, ignore_errors=True)
        os.makedirs(output_dir, exist_ok=True)

        # Explicit UTF-8: the texts are not ASCII-only in general, and the
        # platform default encoding is not guaranteed to handle them.
        with open(os.path.join(output_dir, 'train.txt'), 'w', encoding='utf-8') as fh:
            fh.write('\n'.join(txts_subset))

        train_model_script(num_epochs=5,
                          output_dir=output_dir)


#### Train models and collect recall data

In [None]:
def get_results_across_epochs(topics,
                              base_models_dir,
                              temps = (0.5, 1.0, 1.5),
                              n_samples = 1,
                              prompt = "One night two young men from Egulac",
                              max_length = 500):
    """
    Generate continuations of `prompt` from every saved checkpoint of every topic model.

    topics:            list of topic names, e.g. ['Universe','Politics',...]
    base_models_dir:   path to the parent folder containing all bartlett_<topic> dirs
    temps:             sequence of sampling temperatures
    n_samples:         how many samples per temp (per checkpoint)
    prompt:            text the model is asked to continue
    max_length:        `max_length` passed to generation

    Returns a nested dict: results[topic][checkpoint_name][key], where key 0
    holds the single deterministic (non-sampled) output string and each
    temperature in `temps` holds a list of `n_samples` sampled strings.
    """
    results = {}

    for topic in topics:
        model_dir = os.path.join(base_models_dir, f"bartlett_{topic}")
        # find all checkpoint-<step> subdirs, ordered by step; entries whose
        # suffix is not a number are skipped instead of crashing the sort
        ckpt_paths = sorted(
            (p for p in glob.glob(os.path.join(model_dir, "checkpoint-*"))
             if os.path.basename(p).split("-")[-1].isdigit()),
            key=lambda p: int(os.path.basename(p).split("-")[-1])
        )
        print(f"Checkpoints for topic {topic}:", ckpt_paths)

        results[topic] = {}
        for ckpt in ckpt_paths:
            ckpt_name = os.path.basename(ckpt)
            print(f"\n=== {topic} @ {ckpt_name} ===")
            gpt = GPT(base_model=ckpt)

            # all outputs of this checkpoint, keyed by temperature (0 = deterministic)
            ckpt_results = {}

            # 1) greedy / deterministic
            out_det = gpt.continue_input(
                prompt,
                max_length=max_length,
                do_sample=False,
                no_repeat_ngram_size=10
            )
            print(f"{topic} {ckpt_name} [greedy]: {out_det}")
            ckpt_results[0] = out_det

            # 2) sampled at each temperature
            for temp in temps:
                samples = []
                for i in range(n_samples):
                    out = gpt.continue_input(
                        prompt,
                        max_length=max_length,
                        do_sample=True,
                        temperature=temp,
                        no_repeat_ngram_size=10
                    )
                    print(f"{topic} {ckpt_name} T={temp} sample#{i}: {out}")
                    samples.append(out)
                ckpt_results[temp] = samples

            results[topic][ckpt_name] = ckpt_results

            # Drop this checkpoint's model before the next one is loaded, so
            # two models are not held in memory at the same time.
            del gpt
            gc.collect()

    return results


topics = ['Universe', 'Politics', 'Health', 'Sport', 'Technology', 'Nature']

base_models_dir = '.'


def _first_texts(category, n_texts=1000, n_chars=1000):
    """Return up to `n_texts` shuffled texts of `category`, each cut to `n_chars` characters."""
    return [text[:n_chars] for text in get_texts_by_category(category, df)[:n_texts]]


universe_txts = _first_texts('Universe')
politics_txts = _first_texts('Politics')
health_txts = _first_texts('Health')
# The dataset category is 'Sports', while the topic key used everywhere else is 'Sport'.
sport_txts = _first_texts('Sports')
tech_txts = _first_texts('Technology')
nature_txts = _first_texts('Nature')

# Training texts per topic, keyed by the names listed in `topics`.
txts_for_topics = {
    'Universe': universe_txts,
    'Politics': politics_txts,
    'Health': health_txts,
    'Sport': sport_txts,
    'Technology': tech_txts,
    'Nature': nature_txts,
}

# Five independent repetitions: retrain every topic model, generate from all of
# its checkpoints, and pickle the generations of that repetition.
for run_idx in range(5):
    train_models(1)
    results_dict = get_results_across_epochs(topics, base_models_dir)

    out_fn = f'combined_results_all_epochs_run{run_idx}.pkl'
    with open(out_fn, 'wb') as f:
        pickle.dump(results_dict, f)
    print(f"Saved to {out_fn}")


#### Test new words count vs. epoch and temperature

In [None]:
import os, glob, pickle, re, string
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def count_new_words(text, known_words):
    """Count the words of `text` that do not occur in `known_words`.

    text:         string to analyse; it is split on whitespace.
    known_words:  set of reference words, already lower-cased and stripped of
                  surrounding punctuation.

    Each token is stripped of leading/trailing punctuation and lower-cased.
    Tokens that are empty after stripping (punctuation-only tokens) are not
    words and are therefore not counted.
    """
    count = 0
    for raw_token in text.split():
        word = raw_token.strip(string.punctuation).lower()
        if word and word not in known_words:
            count += 1
    return count


# Vocabulary of the original story, normalised the same way as generated tokens.
orig_words = set(w.strip(string.punctuation).lower() for w in bartlett.split())

records = []
data_dir = '.'
# Result files of all runs, ordered by run number.
run_files = sorted(
    glob.glob(os.path.join(data_dir, 'combined_results_all_epochs_run*.pkl')),
    key=lambda p: int(re.search(r'run(\d+)', p).group(1))
)
for run_path in run_files:
    # NOTE: unpickling can execute arbitrary code; only load files that were
    # written by this notebook.
    with open(run_path, 'rb') as f:
        results = pickle.load(f)
    for topic, ckpt_dict in results.items():
        # Rank checkpoints by step number; rank i+1 is used as the epoch index.
        ckpts = sorted(ckpt_dict.keys(), key=lambda name: int(name.split('-')[-1]))
        epoch_map = {ck: i+1 for i, ck in enumerate(ckpts)}
        for ck, temp_dict in ckpt_dict.items():
            epoch = epoch_map[ck]
            for temp, samples in temp_dict.items():
                # One entry may be a single string rather than a list of samples.
                for sample in (samples if isinstance(samples, list) else [samples]):
                    records.append({
                        'topic': topic,
                        'epoch': epoch,
                        'temp': temp,
                        'new_word_count': count_new_words(sample, orig_words)
                    })

# Explicit columns keep the frame usable even when no result file was found.
# NOTE(review): this rebinds `df`, which earlier cells use for the Wikipedia
# DataFrame; re-running those cells after this one needs `df` to be reloaded.
df = pd.DataFrame(records, columns=['topic', 'epoch', 'temp', 'new_word_count'])

# Filter to final epoch
final_epoch = df['epoch'].max()
df_final = df[df['epoch'] == final_epoch]


In [None]:
def plot_mean_sem(frame, group_col, xlabel, out_path):
    """Plot mean and SEM of 'new_word_count' per value of `group_col`.

    frame:      DataFrame with a 'new_word_count' column and a `group_col` column.
    group_col:  column to group by; its values become the x positions and ticks.
    xlabel:     label of the x axis.
    out_path:   file the figure is saved to (dpi=300).

    Returns the aggregated DataFrame with columns [group_col, 'mean', 'sem'].
    """
    agg = (
        frame
        .groupby(group_col)['new_word_count']
        .agg(['mean', 'sem'])
        .reset_index()
    )

    plt.figure(figsize=(2.2, 1.5))
    plt.errorbar(
        agg[group_col],
        agg['mean'],
        yerr=agg['sem'],
        marker='o',
        linestyle='-',
        capsize=5,
        color='purple'
    )
    plt.xlabel(xlabel)
    plt.ylabel('New words')
    plt.xticks(agg[group_col])
    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.show()
    return agg


# New words vs. temperature, using only generations from the final epoch.
agg_temp = plot_mean_sem(df_final, 'temp', 'Temperature', 'New words by temp.png')


# New words vs. epoch, at one fixed sampling temperature.
chosen_temp = 0.5
df_tp       = df[df['temp'] == chosen_temp]

agg_epoch = plot_mean_sem(df_tp, 'epoch', 'Epoch', 'New words by epoch.png')
