### Bartlett experiment with GPT-2

* This notebook fine-tunes GPT-2 on the story from Bartlett's (1932) experiment plus contextual data, in order to explore how generative models produce distortions.
* The contextual data is taken from the cnn_dailymail dataset of news articles (see https://www.tensorflow.org/datasets/catalog/cnn_dailymail for further details).
* We then explore recall of the Bartlett story: can substitutions and confabulations be observed in generative recall?
* How does the sampling temperature affect the level of distortion?
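
To make the role of temperature concrete, here is a minimal sketch of how it rescales a next-token distribution before sampling. The function name `softmax_with_temperature` and the toy scores are illustrative assumptions and are not part of the notebook's pipeline; the actual sampling happens inside the generation library.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Convert raw scores into sampling probabilities at a given temperature."""
    scaled = [x / temperature for x in logits]
    # Subtract the maximum before exponentiating for numerical stability
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical scores for three candidate next tokens (illustration only)
toy_logits = [2.0, 1.0, 0.1]

for t in [0.01, 0.5, 1.0, 1.5, 2.0]:
    probs = softmax_with_temperature(toy_logits, temperature=t)
    print(t, [round(p, 3) for p in probs])
```

Low temperatures concentrate almost all probability on the top-scoring token, so recall should stay close to the memorised text. Higher temperatures flatten the distribution, which makes unlikely words, and hence distortions, more probable.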

#### Installation:

In [None]:
!pip install simpletransformers
!pip install wordcloud

#### Imports:

In [None]:
import logging
from wordcloud import WordCloud
import matplotlib.pyplot as plt
from simpletransformers.language_modeling import (
    LanguageModelingModel,
    LanguageModelingArgs,
)
from simpletransformers.language_generation import LanguageGenerationModel, LanguageGenerationArgs

In [None]:
bartlett = ["One night two young men from Egulac went down to the river to hunt seals and while they were there it became foggy and calm.",
            "Then they heard war-cries, and they thought: 'Maybe this is a war-party'.",
            "They escaped to the shore, and hid behind a log.",
            "Now canoes came up, and they heard the noise of paddles, and saw one canoe coming up to them.",
            # The next two literals are joined into one sentence; the trailing space keeps "said:" and "What" apart
            "There were five men in the canoe, and they said: "
            "What do you think? We wish to take you along. We are going up the river to make war on the people.",
            "One of the young men said, 'I have no arrows.'",
            "'Arrows are in the canoe,' they said.",
            "'I will not go along. I might be killed. My relatives do not know where I have gone. But you,' he said, turning to the other, 'may go with them.'",
            "So one of the young men went, but the other returned home.",
            "And the warriors went on up the river to a town on the other side of Kalama.",
            "The people came down to the water and they began to fight, and many were killed.",
            "But presently the young man heard one of the warriors say, 'Quick, let us go home: that man has been hit.'", 
            "Now he thought: 'Oh, they are ghosts.' He did not feel sick, but they said he had been shot.",
            "So the canoes went back to Egulac and the young man went ashore to his house and made a fire.",
            "And he told everybody and said: 'Behold I accompanied the ghosts, and we went to fight. Many of our fellows were killed, and many of those who attacked us were killed. They said I was hit, and I did not feel sick.'",
            "He told it all, and then he became quiet. When the sun rose he fell down.", 
            "Something black came out of his mouth. His face became contorted. The people jumped up and cried. He was dead."
           ]

In [None]:
sents_list = bartlett

In [None]:
with open('train.txt', 'w') as fh:
    fh.write('\n'.join(sents_list))
    
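# The evaluation file is identical to the training file, so the evaluation
# loss measures how well the story is memorised, not generalisation
with open('test.txt', 'w') as fh:
    fh.write('\n'.join(sents_list))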

#### Train model

In [None]:
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

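model_args = LanguageModelingArgs()
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
# Many passes over a very small corpus
model_args.num_train_epochs = 400
# "simple" treats each line of the text file as one training example
model_args.dataset_type = "simple"
# GPT-2 is a causal language model, so masked language modelling is turned off
model_args.mlm = False
model_args.fp16 = False
model_args.save_model_every_epoch = False
model_args.save_best_model = True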

train_file = "train.txt"
test_file = "test.txt"

model = LanguageModelingModel(
    "gpt2", "gpt2", args=model_args, use_cuda=True, train_files=train_file, 
)

# Train the model
model.train_model(train_file, eval_file=test_file)

# Evaluate the model
result = model.eval_model(test_file)
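
The evaluation result can be read as a perplexity, which is the exponential of the mean per-token cross-entropy loss. The sketch below assumes that `result` is a dictionary with an `eval_loss` entry; that key name is an assumption about the library's output and should be checked against the printed `result`. The loss value used here is a made-up placeholder, not a result from this notebook.

```python
import math

def perplexity_from_loss(eval_loss):
    """Perplexity is exp of the mean per-token cross-entropy loss (in nats)."""
    return math.exp(eval_loss)

# Hypothetical loss value, for illustration only
example_loss = 0.05
print(perplexity_from_loss(example_loss))

# With the notebook's objects this would be (assuming the key exists):
# print(perplexity_from_loss(result["eval_loss"]))
```

Because `test.txt` is a copy of `train.txt`, a perplexity close to 1 here indicates that the story has been memorised; it says nothing about performance on unseen text.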

In [None]:
#!rm -rf outputs

#### Explore recall of story using trained model

In [None]:
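def generate_with_params(temperature=1.0):
    """Generate a continuation of the story's first sentence at the given temperature."""
    model_args = LanguageGenerationArgs()
    model_args.temperature = temperature
    model_args.do_sample = True
    model_args.max_length = 500
    model_args.num_beams = 1
    model_args.repetition_penalty = 1.05
    model_args.top_k = 50

    # Load the fine-tuned model saved in ./outputs by the training step
    model = LanguageGenerationModel("gpt2", "./outputs", args=model_args)

    # generate() returns a list of generated strings
    return model.generate(bartlett[0])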

In [None]:
generated = {}
generated['set_1'] = {}

# Sample four recalls of the story at each temperature
for temp in [0.01, 0.5, 1.0, 1.5, 2.0]:
    gen_list = []
    for _ in range(4):
        gen_list.append(generate_with_params(temperature=temp))
    generated['set_1'][temp] = gen_list

In [None]:
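import re

# Normalise the story in the same way as the generated text is normalised
# before plotting, then split it into words. Every word of the original story
# is treated as a stopword, so only words absent from the story are plotted.
bartlett_text = ' '.join(bartlett).lower().replace(':', ' ')
bartlett_text = re.sub(r"[']", '', bartlett_text)
bartlett_text = re.sub(r"[^a-zA-Z0-9-]", ' ', bartlett_text)
bartlett_stopwords = set(bartlett_text.split())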

In [None]:
import matplotlib
matplotlib.rcParams.update({'font.size': 14})

for temp in [0.01, 0.5, 1.0, 1.5, 2.0]:
    print("Temperature = {}:".format(temp))
    # Each item is the list returned by generate(); take its first string
    items = [item[0].lower() for item in generated['set_1'][temp]]
    text = ' '.join(items).replace(':', ' ')
    text = re.sub(r"[']", '', text)
    text = re.sub(r"[^a-zA-Z0-9-]", ' ', text)

    # Count the generated words that never occur in the original story
    novel_words = [t for t in text.split() if t not in bartlett_stopwords]
    print("Words not in the original story: {}".format(len(novel_words)))

    wordcloud = WordCloud(
        width=600, height=400, background_color="white",
        max_font_size=50, stopwords=bartlett_stopwords,
    ).generate(text)
    fig = plt.figure(figsize=(10, 5))
    plt.imshow(wordcloud, interpolation="bilinear")
    plt.axis("off")
    plt.title("Temperature = {}:".format(temp))
    fig.savefig('wordcloud_{}.png'.format(temp))
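
The word clouds show which new words appear, but not how the recalled text departs from the original. One possible complement, sketched here under assumptions, is a word-level alignment with the standard library's `difflib`. The helper names `tokenize` and `recall_diff` and the toy sentences are hypothetical, and reading "replace" as substitution and "insert" as confabulation is only a rough proxy.

```python
import difflib
import re

def tokenize(text):
    """Lowercase, drop apostrophes, and split into word tokens."""
    return re.findall(r"[a-z0-9-]+", text.lower().replace("'", ""))

def recall_diff(original, recalled):
    """Align two texts word by word.

    Returns the similarity ratio and a count of words per edit type.
    'equal', 'replace' and 'delete' count words of the original;
    'insert' counts words that appear only in the recalled text.
    """
    orig_tokens = tokenize(original)
    rec_tokens = tokenize(recalled)
    matcher = difflib.SequenceMatcher(a=orig_tokens, b=rec_tokens, autojunk=False)
    counts = {"equal": 0, "replace": 0, "delete": 0, "insert": 0}
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        counts[tag] += (j2 - j1) if tag == "insert" else (i2 - i1)
    return matcher.ratio(), counts

# Toy example (made-up recall, not model output)
toy_original = "Two young men went down to the river to hunt seals"
toy_recalled = "Two young men went down to the lake to hunt fish in the fog"
ratio, counts = recall_diff(toy_original, toy_recalled)
print(ratio, counts)
```

To apply this to the notebook's outputs, the original could be `' '.join(bartlett)` and the recalled text one of the strings stored in `generated['set_1'][temp]`, giving one ratio per sample and temperature.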