### Modelling distortions in narratives

An (overfitted) transformer-based model such as GPT-2 can memorise its training data. Here we explore distortions in the resulting model when trained on narratives, comparing the results to Raykov et al. (2023).

In [None]:
!pip install simpletransformers torch

In [None]:
import pandas as pd
from random import shuffle
import random
from gpt import GPT
from story_utils import *
import pickle
import matplotlib.pyplot as plt

# Seed Python's `random` module so that calls to shuffle() are repeatable between runs.
random.seed(1)

In [None]:
def prepare_data(num_typical=100, num_char=50, num_variants=20):
    """Split the story corpus into typical, shortened and lengthened groups.

    Stories are taken from ``get_stories()`` in order: the first
    ``num_typical`` are kept as they are, the next ``num_variants`` are
    truncated, and the ``num_variants`` after those are extended. Every
    returned story has the ``" END "`` marker appended. Each atypical story
    is printed as it is built.

    Args:
        num_typical: Number of stories kept unmodified.
        num_char: Number of characters removed from each short variant and
            (at most) appended to each long variant.
        num_variants: Number of stories in each of the two atypical groups.

    Returns:
        A ``(typical, atypical_short, atypical_long)`` tuple of lists of
        strings.
    """
    stories = get_stories()

    # Slice boundaries of the three consecutive groups.
    short_start = num_typical
    long_start = num_typical + num_variants
    long_stop = num_typical + 2 * num_variants

    typical = [story + " END " for story in stories[:short_start]]

    # Atypically short stories are num_char characters shorter.
    atypical_short = []
    for story in stories[short_start:long_start]:
        # story[:-0] is the empty string, so num_char == 0 must skip the
        # slice instead of wiping out the whole story.
        trimmed = story[:-num_char] if num_char else story
        shortened = trimmed + " END "
        print(shortened)
        print("...........")
        atypical_short.append(shortened)

    # Atypically long stories are up to num_char characters longer;
    # the additional characters are the start of whatever
    # get_random_sentence(stories) returns.
    atypical_long = []
    for story in stories[long_start:long_stop]:
        extension = get_random_sentence(stories)[0:num_char]
        lengthened = story + extension + " END "
        print(lengthened)
        print("...........")
        atypical_long.append(lengthened)

    return typical, atypical_short, atypical_long

In [None]:
def compute_length_difference(stories):
    """
    Computes the mean length difference (output minus input) for a list of stories.

    Args:
        stories: Iterable of pairs whose first item is the input text and
            whose second item is the output text. Only ``len()`` of each
            item is used.

    Returns:
        The average of ``len(output) - len(input)`` over all pairs, or NaN
        when ``stories`` is empty (there is no meaningful mean, and raising
        ZeroDivisionError would abort the caller after a long run).
    """
    differences = [len(pair[1]) - len(pair[0]) for pair in stories]
    if not differences:
        return float("nan")
    return sum(differences) / len(differences)


In [None]:
def _recall_stories(model, stories, echo_story=False):
    """Prompt ``model`` with the first ten words of each story and collect its output.

    Args:
        model: Object exposing ``continue_input(text)`` that returns a string.
        stories: Iterable of story strings.
        echo_story: When True, print each story before it is prompted.

    Returns:
        A list of ``[story, generated]`` pairs, where ``generated`` is the
        model output cut just before its first ``'END'``. Outputs that do not
        contain ``'END'`` are skipped.
    """
    pairs = []
    for s in stories:
        if echo_story:
            print(s)
        start = " ".join(s.split()[0:10])
        gen = model.continue_input(start)
        if 'END' not in gen:
            continue
        gen = gen[0:gen.index('END')]
        print(start)
        print(gen)
        print(s)
        # NOTE(review): `s` is stored as passed in. If the caller's stories
        # still carry an " END " suffix while `gen` is cut before 'END', the
        # length difference has a constant offset -- confirm this is intended.
        pairs.append([s, gen])
    return pairs


def test_model(save_name, typical, atypical_short, atypical_long):
    """Generate continuations for each story group and plot the mean length change.

    Loads the model saved under ``'outputs_stories'``, prompts it with the
    first ten words of every story, and draws a bar chart of the average
    ``len(output) - len(input)`` per group.

    Args:
        save_name: Path the figure is written to; its parent directory is
            created if it does not exist yet.
        typical: List of unmodified stories.
        atypical_short: List of shortened stories.
        atypical_long: List of lengthened stories.

    Returns:
        Dict with keys ``'typical'``, ``'atypical_short'`` and
        ``'atypical_long'``, each mapping to a list of
        ``[story, generated]`` pairs.
    """
    import os

    model = GPT(base_model='outputs_stories', base_model_name='gpt2-medium')

    # Only the typical group echoes each story before prompting.
    results_dict = {
        'typical': _recall_stories(model, typical, echo_story=True),
        'atypical_short': _recall_stories(model, atypical_short),
        'atypical_long': _recall_stories(model, atypical_long),
    }

    # Average length difference per category, in plotting order. A category
    # with no usable generations gets NaN instead of a division by zero.
    plot_order = ['atypical_short', 'typical', 'atypical_long']
    differences = [
        compute_length_difference(results_dict[key]) if results_dict[key] else float('nan')
        for key in plot_order
    ]

    # Plotting the results
    categories = ['Atypical Short', 'Typical', 'Atypical Long']

    plt.figure()
    plt.bar(categories, differences)
    plt.xlabel('Story Category')
    plt.ylabel('Average Length Difference (Output - Input)')
    plt.title('Length Difference by Story Category')
    plt.axhline(y=0, color='black')

    # savefig does not create missing directories; make sure the target
    # folder exists so the run is not lost at the very last step.
    out_dir = os.path.dirname(save_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_name)
    plt.show()

    return results_dict

In [None]:
# Parameter sweep over corpus settings. Each list currently holds a single
# value, so the body runs exactly once.
for num_typical in [100]:
    for num_char in [20]:
        for num_variants in [10]:

            # IPython shell escape: delete any existing 'outputs_stories'
            # directory (the best_model_dir passed to gpt.train below).
            !rm -rf outputs_stories

            typical, atypical_short, atypical_long = prepare_data(num_typical=num_typical, 
                                                                  num_char=num_char, 
                                                                  num_variants=num_variants)
            # Pool the three groups and randomise their order.
            sents_list = typical + atypical_short + atypical_long
            shuffle(sents_list)

            # One story per line. The same text is written to both files,
            # so the test file is identical to the training file.
            with open("train.txt", "w") as fh:
                fh.write('\n'.join(sents_list))

            with open("test.txt", "w") as fh:
                fh.write('\n'.join(sents_list))

            gpt = GPT(base_model='gpt2-medium')

            # NOTE(review): eps=50 is presumably the number of training
            # epochs -- confirm against GPT.train.
            gpt.train(segmented_sequence_list=[], 
                      best_model_dir='outputs_stories', 
                      train_file="train.txt",
                      test_file="test.txt", 
                      eps=50)

            # test_model loads the model from 'outputs_stories' and saves the
            # bar chart to the path given as its first argument.
            test_model(f'./plots/{num_typical}typicals_{num_char}chars_{num_variants}_variants.png', 
                       typical, 
                       atypical_short, 
                       atypical_long)

In [None]:
# !rm -rf outputs_stories

# num_typical = 50
# num_char = 50
# num_variants = 10
# typical, atypical_short, atypical_long = prepare_data(num_typical=num_typical, 
#                                                       num_char=num_char, 
#                                                       num_variants=num_variants)
# sents_list = typical + atypical_short + atypical_long
# shuffle(sents_list)

# with open("train.txt", "w") as fh:
#     fh.write('\n'.join(sents_list))

# with open("test.txt", "w") as fh:
#     fh.write('\n'.join(sents_list))

# gpt = GPT(base_model='gpt2-medium')

# gpt.train(segmented_sequence_list=[], 
#           best_model_dir='outputs_stories', 
#           train_file="train.txt",
#           test_file="test.txt", 
#           eps=50)

# test_model(f'./plots/{num_typical}typicals_{num_char}chars_{num_variants}_variants.png', 
#            typical, 
#            atypical_short, 
#            atypical_long)

In [None]:
# with open('story_outputs.pkl', 'wb') as handle:
#     pickle.dump(results_dict, handle)

# with open('story_outputs_backup.pkl', 'rb') as handle:
#     results_dict = pickle.load(handle)