### Modelling statistical learning

To explore the effect of sleep on statistical learning, Durrant et al. (2011) constructed two types of sequence, both made up of regular tones at differing frequencies. One type had a structure in which the preceding two tones determined the next, except for a few transitions which were random to avoid repetition. The other type was the reverse – most transitions were random. After listening to a structured sequence, participants were tested on their ability to distinguish between short structured and unstructured sequences. Delayed recall was then tested, after a night’s sleep for one group, and after a waking rest for the other. Durrant et al. (2011) found that sleep improved performance more than waking rest, suggesting systems consolidation promotes statistical learning.

Here, we generate a set of sequences based on the transition structure in Durrant et al. (2011). A model with the GPT-2 architecture is trained from scratch on the structured sequences only. At the end of each epoch of training, the perplexity is calculated for two test sets, one of structured and one of unstructured sequences. We find that the difference in perplexity between these two sets increases over time, corresponding to an improved ability to distinguish them. In addition, outputs from the trained model are structured in the same way as the training data.

Tested with conda_pytorch_latest_p36 kernel in AWS SageMaker.

#### Installation:

In [None]:
!pip install simpletransformers

#### Imports:

In [None]:
import pandas as pd
import random
import logging
from random import shuffle
from matplotlib import pyplot as plt
from statistical_learning_utils import *
from gpt import GPT

In [None]:
def get_random_sequence():
    """Return an unstructured tone sequence as a comma-separated string.

    The sequence holds 52 tones in total: two starting tones followed by 50
    more. Every tone is drawn independently with ``random.randint(1, 5)``,
    so no tone depends on the ones before it.
    """
    # Two starting tones, then 50 continuations; all come from the same
    # draw, so a single pass over 52 positions consumes the RNG identically.
    tones = [random.randint(1, 5) for _ in range(2 + 50)]
    return ','.join(map(str, tones))

!rm -rf durrant/
!mkdir durrant/

def _write_sequences(path, make_sequence, count):
    """Generate `count` sequences and write them to `path`, one per line.

    Args:
        path: Destination text file; it is overwritten if it exists.
        make_sequence: Zero-argument callable returning one sequence as a
            string.
        count: Number of sequences to generate.
    """
    sequences = [make_sequence() for _ in range(count)]
    shuffle(sequences)
    # The context manager closes the file even if the write fails.
    with open(path, "w") as handle:
        handle.write('\n'.join(sequences))


# Training and validation data use the structured generator only.
_write_sequences("durrant/train.txt", get_sequence, 2000)
_write_sequences("durrant/val.txt", get_sequence, 100)

# Two held-out test sets: one from the structured generator and one made of
# fully random sequences, so the two can be scored separately.
_write_sequences("durrant/structured_test.txt", get_sequence, 100)
_write_sequences("durrant/unstructured_test.txt", get_random_sequence, 100)

#### Train generative model

Train GPT-2 from scratch on the dataset created above.

In [None]:
# Held-out files that every trained model is scored on.
structured_test_file = "durrant/structured_test.txt"
unstructured_test_file = "durrant/unstructured_test.txt"

# One inner list per trial; each inner list holds one evaluation result per
# value of `eps` tried below.
all_unstructured = []
all_structured = []

for trial in range(3):

    perplexity_structured = []
    perplexity_unstructured = []

    # A new GPT object is created and trained for each value of `eps`, with
    # the trial index used as the seed.
    # NOTE(review): `eps` is presumably the number of training epochs --
    # confirm against gpt.GPT.train.
    for num in (1, 2, 3):
        gpt = GPT(vocab_size=10)
        model = gpt.train(
            segmented_sequence_list=[],
            best_model_dir='durrant',
            train_file="durrant/train.txt",
            test_file="durrant/val.txt",
            eps=num,
            seed=trial,
        )

        # Score the structured set first, then the unstructured set.
        perplexity_structured.append(model.eval_model(structured_test_file))
        perplexity_unstructured.append(model.eval_model(unstructured_test_file))

    all_unstructured.append(perplexity_unstructured)
    all_structured.append(perplexity_structured)

In [None]:
# all_structured = [[{'eval_loss': 0.7781103165061386, 'perplexity': tensor(2.1774)},
#   {'eval_loss': 0.5021468801998797, 'perplexity': tensor(1.6523)},
#   {'eval_loss': 0.34246792432702616, 'perplexity': tensor(1.4084)}],
#  [{'eval_loss': 0.7827572005766409, 'perplexity': tensor(2.1875)},
#   {'eval_loss': 0.4712955543288478, 'perplexity': tensor(1.6021)},
#   {'eval_loss': 0.3468914611472024, 'perplexity': tensor(1.4147)}],
#  [{'eval_loss': 0.7814871138996549, 'perplexity': tensor(2.1847)},
#   {'eval_loss': 0.5483069876093923, 'perplexity': tensor(1.7303)},
#   {'eval_loss': 0.35185586081610787, 'perplexity': tensor(1.4217)}]]

# all_unstructured = [[{'eval_loss': 0.9646942218144735, 'perplexity': tensor(2.6240)},
#   {'eval_loss': 1.3187024195988972, 'perplexity': tensor(3.7386)},
#   {'eval_loss': 1.5715690041765755, 'perplexity': tensor(4.8142)}],
#  [{'eval_loss': 0.9497873113479143, 'perplexity': tensor(2.5852)},
#   {'eval_loss': 1.2719914162600483, 'perplexity': tensor(3.5680)},
#   {'eval_loss': 1.5441029322000197, 'perplexity': tensor(4.6838)}],
#  [{'eval_loss': 0.9562439778704702, 'perplexity': tensor(2.6019)},
#   {'eval_loss': 1.2173525448198672, 'perplexity': tensor(3.3782)},
#   {'eval_loss': 1.5220492607281533, 'perplexity': tensor(4.5816)}]]

Let's sample some outputs from the trained model:

In [None]:
# gpt = GPT(base_model='durrant', base_model_name='gpt2')

In [None]:
# gpt.continue_input('1,5,', do_sample=False)

#### Plot perplexity against time

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Function to calculate the per-epoch mean and spread across trials
def calc_metrics(data):
    """Summarise evaluation results per epoch across repeated trials.

    Args:
        data: One list per trial. Each inner list holds one result dict per
            epoch, with 'eval_loss' and 'perplexity' entries. Values may be
            plain numbers or scalar tensors; each is converted with float().

    Returns:
        Tuple ``(eval_loss_mean, eval_loss_sem, perplexity_mean,
        perplexity_sem)`` of 1-D arrays with one entry per epoch.

        NOTE: despite the ``_sem`` names, the spread values are sample
        standard deviations (``ddof=1``). The division by sqrt(n_trials)
        that would turn them into standard errors is not applied, as in the
        original code. Divide by ``np.sqrt(len(data))`` to get a true SEM.
    """
    # zip(*data) transposes trial-major input into one tuple per epoch, so
    # each row of the arrays below is an epoch and each column a trial.
    per_epoch = list(zip(*data))
    # float() unwraps scalar tensors so NumPy always gets plain floats.
    eval_loss = np.array([[float(trial['eval_loss']) for trial in epoch] for epoch in per_epoch])
    perplexity = np.array([[float(trial['perplexity']) for trial in epoch] for epoch in per_epoch])

    eval_loss_mean = np.mean(eval_loss, axis=1)
    perplexity_mean = np.mean(perplexity, axis=1)
    eval_loss_sem = np.std(eval_loss, axis=1, ddof=1)
    perplexity_sem = np.std(perplexity, axis=1, ddof=1)
    return eval_loss_mean, eval_loss_sem, perplexity_mean, perplexity_sem

# Per-epoch summary statistics for each test set. The "*_sem" values are
# whatever calc_metrics returns as its spread measure.
(structured_eval_loss_mean, structured_eval_loss_sem,
 structured_perplexity_mean, structured_perplexity_sem) = calc_metrics(all_structured)
(unstructured_eval_loss_mean, unstructured_eval_loss_sem,
 unstructured_perplexity_mean, unstructured_perplexity_sem) = calc_metrics(all_unstructured)

# Line plot of perplexity against epoch, with error bars.
plt.figure(figsize=(4, 2.5))  # (width, height) in inches
# Derive the x axis from the data instead of assuming exactly three epochs.
epochs = range(1, len(structured_perplexity_mean) + 1)
# Colours follow the grouped bar chart (Structured = blue, Unstructured =
# red) so both figures use the same coding.
plt.errorbar(epochs, structured_perplexity_mean, yerr=structured_perplexity_sem, label='Structured', fmt='-o', color='blue', capsize=5)
plt.errorbar(epochs, unstructured_perplexity_mean, yerr=unstructured_perplexity_sem, label='Unstructured', fmt='-o', color='red', capsize=5)
plt.xlabel('Epoch')
plt.ylabel('Perplexity')
plt.legend()
# Saved under its own name: the grouped bar chart also writes
# 'perplexities.png' and would otherwise overwrite this figure.
plt.savefig('perplexities_line.png', dpi=500, bbox_inches='tight')
plt.show()


In [None]:
# Settings for the grouped bar chart
barWidth = 0.3
# One group per epoch; the count comes from the data, not a fixed 3.
epochs = np.arange(1, len(structured_perplexity_mean) + 1)
r1 = np.arange(len(epochs))          # left bar of each group (structured)
r2 = [x + barWidth for x in r1]      # right bar of each group (unstructured)

# Create grouped bar chart; error bars use the spread values computed earlier
plt.figure(figsize=(5, 2.5))
plt.bar(r1, structured_perplexity_mean, color='blue', alpha=0.5, width=barWidth, label='Structured', yerr=structured_perplexity_sem, capsize=5)
plt.bar(r2, unstructured_perplexity_mean, color='red', alpha=0.5, width=barWidth, label='Unstructured', yerr=unstructured_perplexity_sem, capsize=5)

# Add labels and legend; ticks sit midway between the two bars of a group
# and are labelled with the epoch numbers.
plt.xlabel('Epoch')
plt.ylabel('Perplexity')
plt.xticks([r + barWidth/2 for r in range(len(epochs))], [str(e) for e in epochs])
plt.legend()
plt.savefig('perplexities.png', dpi=500, bbox_inches='tight')

plt.show()

In [None]:
# Load the model saved in the 'durrant' directory.
gpt = GPT(base_model='durrant', base_model_name='gpt2')

# 20 rounds; in each round the model is prompted once with every start tone
# from 1 to 5, and all outputs are concatenated in call order.
# NOTE(review): outputs are joined with no separator -- confirm that
# continue_input's return value ends in a delimiter, otherwise the last tone
# of one sample fuses with the first tone of the next.
data = "".join(
    gpt.continue_input(str(start_tone), do_sample=True, temperature=0.1)
    for _ in range(20)
    for start_tone in range(1, 6)
)

data

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Parse the comma-separated model output into a flat list of integer tones,
# skipping empty fields (e.g. from a trailing comma).
data_list = [int(x) for x in data.split(',') if x]

tones = range(1, 6)

# Calculate transition probabilities
# Counts for every ((previous pair), next tone) combination, starting at 0
# so that unseen transitions are still represented.
transition_counts = {((i, j), k): 0 for i in tones for j in tones for k in tones}

# Slide a window of three over the data: two context tones and the next one.
# A tone outside 1..5 raises KeyError here, as before.
for first, second, third in zip(data_list, data_list[1:], data_list[2:]):
    transition_counts[((first, second), third)] += 1

# Total number of observations for each previous pair, computed once per pair.
pair_totals = {
    (i, j): sum(transition_counts[((i, j), k)] for k in tones)
    for i in tones for j in tones
}

# Normalise counts to probabilities; pairs never observed get probability 0.
transition_probabilities = {
    key: (count / pair_totals[key[0]] if pair_totals[key[0]] > 0 else 0)
    for key, count in transition_counts.items()
}

# Prepare data for plotting: rows are the 25 previous pairs in the order
# (1,1), (1,2), ..., (5,5); columns are the 5 possible next tones.
plot_data = np.zeros((25, 5))
for (prev_pair, next_num), probability in transition_probabilities.items():
    y_index = (prev_pair[0] - 1) * 5 + (prev_pair[1] - 1)
    x_index = next_num - 1
    plot_data[y_index, x_index] = probability

# Plot
fig, ax = plt.subplots(figsize=(5, 5))
# Plotting p with the reversed grey map gives the same image as plotting
# 1 - p with 'Greys' (high probability = light), but the colourbar now shows
# the probability itself, in line with the title.
cax = ax.matshow(plot_data, cmap='Greys_r')

# Set ticks
ax.set_xticks(range(5))
ax.set_xticklabels(range(1, 6))
ax.set_yticks(range(25))
ax.set_yticklabels([f'{i//5+1},{i%5+1}' for i in range(25)])

ax.set_xlabel('Next Number')
ax.set_ylabel('Previous Pair')
ax.set_title('Transition Probabilities')

plt.colorbar(cax)
plt.savefig('trps.png', dpi=500)
plt.show()


#### Visualise attention

In [None]:
!pip install bertviz

In [None]:
from bertviz import head_view, model_view
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the model and tokenizer saved in the 'durrant' directory; the model is
# asked to return attention weights along with its other outputs.
model = GPT2LMHeadModel.from_pretrained('durrant', output_attentions=True)
tokenizer = GPT2Tokenizer.from_pretrained('durrant')
# Example comma-separated tone sequence to visualise.
input_text='2,2,4,1,1,4,1,1,4,1,1,4,1,5,5,3,5,4,3'
inputs = tokenizer.encode(input_text, return_tensors='pt')  # Tokenize input text
outputs = model(inputs)  # Run model
# NOTE(review): this takes the last element of the outputs, which is expected
# to be the attention weights -- confirm for the installed transformers version.
attention = outputs[-1]  # Retrieve attention from model outputs
tokens = tokenizer.convert_ids_to_tokens(inputs[0])  # Convert input ids to token strings
model_view(attention, tokens)  # Display model view

In [None]:
# Display bertviz's head view for the same attention weights and tokens.
head_view(attention, tokens)