In [5]:
import numpy as np
import random
from tqdm import tqdm
import math
import torch

## Multiarmed bandits

In [6]:
# Function to convert episodes to trajectory sequence of codes
def episodes_to_trajectory(episodes):
    trajectory = []
    for episode in episodes:
        action = episode['action']
        reward = episode['reward']
        trajectory.extend([action, reward])
    return trajectory

def cumulative_reward_from_trajectory(trajectory):
    cumulative_reward = 0
    for i in range(0, len(trajectory), 2):
        cumulative_reward += trajectory[i + 1]
    return cumulative_reward

def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

In [7]:
def select_arm(epsilon):
    if random.random() < epsilon:
        return random.randint(0, n_arms - 1), 'random'  # Exploration
    else:
        return max(range(n_arms), key=lambda x: estimated_rewards[x]), 'greedy'  # Exploitation

def update_reward(arm, reward):
    counts[arm] += 1
    estimated_rewards[arm] += (reward - estimated_rewards[arm]) / counts[arm]

def generate_reward(arm, prob_high_value, high_value=10, low_value=1):
    if random.random() < prob_high_value[arm]:
        return high_value
    else:
        return low_value

def epsilon_schedule(t):
    return max(epsilon_min, epsilon_initial / math.log(t + decay_offset))

# Function to convert episodes to trajectory sequence of codes
def episodes_to_trajectory(episodes):
    trajectory = []
    for episode in episodes:
        action = episode['action'] + 2
        reward = 0 if episode['reward'] == 1 else 1
        trajectory.extend([action, reward])
    return trajectory

def cumulative_reward_from_trajectory(trajectory):
    cumulative_reward = 0
    for i in range(0, len(trajectory), 2):
        cumulative_reward += trajectory[i + 1]
    return cumulative_reward

num_trajectories = 1000
trajectories = []

for _ in tqdm(range(num_trajectories)):
    #print('hello')
    # Initialization
    n_arms = 10
    counts = [0] * n_arms
    #rewards = [0.0] * n_arms
    estimated_rewards = [0.0] * n_arms

    epsilon_initial = 1.0
    epsilon_min = 0.001
    decay_offset = 1e-2  # Offset to avoid division by zero in log function

    # Each arm has its own fixed reward probability
    prob_high_value = np.ones(n_arms) * 0.1

    # Randomly select 2 indices
    indices = np.random.choice(n_arms, 2, replace=False)
    prob_high_value[indices] = 0.9

    high_value = 10
    low_value = 1

    episodes = []

    # Run multiple episodes
    for episode in range(1, 101):  # Number of episodes (1-indexed for log function)
        # Decay epsilon over episodes
        epsilon = epsilon_schedule(episode)
        #epsilon=0.1
        arm, arm_type = select_arm(epsilon)
        if arm in indices:
            selected_high = 1
        else:
            selected_high = 0
        reward = generate_reward(arm, prob_high_value)
        update_reward(arm, reward)
        episodes.append({'state': {'counts': list(counts), 'rewards': list(estimated_rewards)}, 'action': arm, 'reward': reward, 'arm_type': arm_type, 'selected_high': selected_high})

    # print()

    trajectory = episodes_to_trajectory(episodes)
    cumulative_reward = cumulative_reward_from_trajectory(trajectory)
    #print(f"Probabilities: {prob_high_value}")
    trajectories.append({'trajectory': trajectory, 'cumulative_reward': cumulative_reward, 'probabilities': prob_high_value,
                         'selected_highs': [episode['selected_high'] for episode in episodes], 'arm_types': [episode['arm_type'] for episode in episodes]})


100%|██████████| 1000/1000 [00:00<00:00, 3314.87it/s]


In [8]:
import plotly.express as px

# Selected highs for each trajectory
selected_highs = [trajectory['selected_highs'] for trajectory in trajectories]
# Stack the selected highs
selected_highs = np.stack(selected_highs, axis=0)
# Mean over each position
selected_highs = np.mean(selected_highs, axis=0)
# Plotly line plot
fig = px.line(y=selected_highs, labels={'x': 'Episode', 'y': 'Probability of Selecting High Reward Arm'})
# Set y-axis to be between 0 and 1
fig.update_yaxes(range=[0, 1])
fig.show()

In [9]:
# Arm types
arm_types = [trajectory['arm_types'] for trajectory in trajectories]
# Convert 'random' to 0 and 'greedy' to 1
arm_types = [[0 if arm_type == 'random' else 1 for arm_type in trajectory] for trajectory in arm_types]
# Stack the arm types
arm_types = np.stack(arm_types, axis=0)
# Mean over each position
arm_types = np.mean(arm_types, axis=0)
# Plotly line plot
fig = px.line(y=arm_types, labels={'x': 'Episode', 'y': 'Probability of Selecting Greedy Action'})
fig.show()

In [10]:
# Mean of cumulative reward in trajectories
mean_cumulative_reward = np.mean([t['cumulative_reward'] for t in trajectories])
std_cumulative_reward = np.std([t['cumulative_reward'] for t in trajectories])
print(f"Mean Cumulative Reward: {mean_cumulative_reward:.2f} ± {std_cumulative_reward:.2f}")

Mean Cumulative Reward: 66.20 ± 9.11


In [11]:
probabilities = torch.stack([torch.tensor(t['probabilities']) for t in trajectories], dim=0)
print(probabilities.shape)

# Create stacked tensor of trajectories
trajectories_tensor = torch.stack([torch.tensor(t['trajectory']) for t in trajectories], dim=0)
trajectories_tensor.shape

torch.Size([1000, 10])


torch.Size([1000, 200])

In [12]:
# Get the reward tensor from the trajectories tensor
rewards_tensor = trajectories_tensor[:, 1::2].float()

# Get the mean reward at each position
mean_rewards = rewards_tensor.mean(0)

# Smooth the mean rewards 
smoothed_mean_rewards = mean_rewards.unfold(0, 10, 1).mean(1).view(-1)
# Add zeros to the start of the smoothed mean rewards
smoothed_mean_rewards = torch.cat((torch.ones(9)*mean_rewards[0], smoothed_mean_rewards))

fig = px.line(mean_rewards, labels={'x': 'Position', 'y': 'Mean Reward'}, title='Mean Reward at Each Position')
# Add smoothed mean rewards
fig.add_scatter(y=smoothed_mean_rewards, mode='lines', name='Smoothed Mean Reward')
fig.show()

In [14]:
import tqdm
import torch
from torch.utils.data import Dataset, DataLoader, random_split

class BanditDataset(Dataset):
    def __init__(self, games_tensor, probabilities_tensor):
        self.games_tensor = games_tensor
        self.probabilities_tensor = probabilities_tensor

    def __len__(self):
        return self.games_tensor.size(0)

    def __getitem__(self, idx):
        return self.games_tensor[idx], self.probabilities_tensor[idx]

# Assuming `all_games_tensor` is defined elsewhere
games_dataset = BanditDataset(trajectories_tensor, probabilities)

# Split the dataset into training and test sets (10% for testing)
test_size = int(0.1 * len(games_dataset))
train_size = len(games_dataset) - test_size
train_dataset, test_dataset = random_split(games_dataset, [train_size, test_size])

# Create DataLoaders for both training and testing
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [15]:
from transformer_lens import HookedTransformer, HookedTransformerConfig

batch_size = 32
num_epochs = 50
lr = 1e-4
betas = (0.9, 0.95)
max_grad_norm = 1.0
wd = 0.01

cfg = HookedTransformerConfig(
    n_layers=4,
    d_model=64,
    d_head=64,
    n_heads=4,
    d_mlp=256,
    d_vocab=12,
    n_ctx=200,
    act_fn="relu",
    normalization_type="LN",
    device='cpu'
)
model = HookedTransformer(cfg)


optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i / 100, 1.0))

def loss_fn(logits, tokens, per_token=False):
    # logit shape: [batch, pos, vocab]
    # token shape: [batch, pos]
    logits = logits[:, :-1, :]
    tokens = tokens[:, 1:].long()
    log_probs = logits.log_softmax(-1)
    correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]
    if per_token:
        return -correct_log_probs
    else:
        return -correct_log_probs.mean()

In [16]:
# Training loop 
train_losses = []
test_losses = []

for epoch in tqdm.tqdm(range(num_epochs)):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    for tokens, _ in train_loader:
        # Uncomment the following line if you are using a GPU
        # tokens = tokens.cuda()

        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        loss.backward()

        # Gradient clipping
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        running_loss += loss.item()
        train_losses.append(loss.item())

    # Calculate average loss for the training epoch
    average_train_loss = running_loss / len(train_loader)

    # Evaluate on test set
    model.eval()  # Set the model to evaluation mode
    test_loss = 0.0
    with torch.no_grad():
        for tokens, _ in test_loader:
            # Uncomment the following line if you are using a GPU
            # tokens = tokens.cuda()

            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            test_loss += loss.item()

    # Calculate average loss for the test epoch
    average_test_loss = test_loss / len(test_loader)
    test_losses.append(average_test_loss)
    print(f"Epoch {epoch}: Train loss: {average_train_loss:.4f} Test Loss: {average_test_loss:.4f}")

  2%|▏         | 1/50 [00:07<06:00,  7.37s/it]

Epoch 0: Train loss: 2.4735 Test Loss: 2.2331


  4%|▍         | 2/50 [00:14<05:50,  7.30s/it]

Epoch 1: Train loss: 2.1341 Test Loss: 2.0298


  6%|▌         | 3/50 [00:21<05:42,  7.28s/it]

Epoch 2: Train loss: 1.8532 Test Loss: 1.6532


  8%|▊         | 4/50 [00:29<05:37,  7.33s/it]

Epoch 3: Train loss: 1.4449 Test Loss: 1.2946


 10%|█         | 5/50 [00:36<05:29,  7.32s/it]

Epoch 4: Train loss: 1.2163 Test Loss: 1.1669


 12%|█▏        | 6/50 [00:43<05:21,  7.31s/it]

Epoch 5: Train loss: 1.1299 Test Loss: 1.1136


 14%|█▍        | 7/50 [00:51<05:15,  7.34s/it]

Epoch 6: Train loss: 1.0864 Test Loss: 1.0710


 16%|█▌        | 8/50 [00:58<05:08,  7.34s/it]

Epoch 7: Train loss: 1.0599 Test Loss: 1.0452


 18%|█▊        | 9/50 [01:06<05:02,  7.38s/it]

Epoch 8: Train loss: 1.0353 Test Loss: 1.0248


 20%|██        | 10/50 [01:13<05:00,  7.51s/it]

Epoch 9: Train loss: 1.0111 Test Loss: 1.0083


 22%|██▏       | 11/50 [01:21<04:52,  7.50s/it]

Epoch 10: Train loss: 1.0021 Test Loss: 0.9928


 24%|██▍       | 12/50 [01:28<04:45,  7.52s/it]

Epoch 11: Train loss: 0.9853 Test Loss: 0.9804


 26%|██▌       | 13/50 [01:36<04:36,  7.48s/it]

Epoch 12: Train loss: 0.9738 Test Loss: 0.9684


 28%|██▊       | 14/50 [01:43<04:28,  7.46s/it]

Epoch 13: Train loss: 0.9624 Test Loss: 0.9595


 30%|███       | 15/50 [01:51<04:22,  7.49s/it]

Epoch 14: Train loss: 0.9498 Test Loss: 0.9495


 32%|███▏      | 16/50 [01:58<04:13,  7.45s/it]

Epoch 15: Train loss: 0.9443 Test Loss: 0.9433


 34%|███▍      | 17/50 [02:06<04:06,  7.47s/it]

Epoch 16: Train loss: 0.9368 Test Loss: 0.9377


 36%|███▌      | 18/50 [02:13<04:00,  7.50s/it]

Epoch 17: Train loss: 0.9306 Test Loss: 0.9312


 38%|███▊      | 19/50 [02:21<03:51,  7.46s/it]

Epoch 18: Train loss: 0.9261 Test Loss: 0.9277


 40%|████      | 20/50 [02:28<03:42,  7.43s/it]

Epoch 19: Train loss: 0.9240 Test Loss: 0.9253


 42%|████▏     | 21/50 [02:35<03:34,  7.41s/it]

Epoch 20: Train loss: 0.9169 Test Loss: 0.9214


 44%|████▍     | 22/50 [02:43<03:27,  7.40s/it]

Epoch 21: Train loss: 0.9226 Test Loss: 0.9196


 46%|████▌     | 23/50 [02:50<03:19,  7.39s/it]

Epoch 22: Train loss: 0.9166 Test Loss: 0.9170


 48%|████▊     | 24/50 [02:57<03:12,  7.39s/it]

Epoch 23: Train loss: 0.9081 Test Loss: 0.9145


 50%|█████     | 25/50 [03:05<03:04,  7.39s/it]

Epoch 24: Train loss: 0.9090 Test Loss: 0.9103


 52%|█████▏    | 26/50 [03:13<02:59,  7.49s/it]

Epoch 25: Train loss: 0.9035 Test Loss: 0.9090


 54%|█████▍    | 27/50 [03:20<02:52,  7.50s/it]

Epoch 26: Train loss: 0.9035 Test Loss: 0.9101


 56%|█████▌    | 28/50 [03:28<02:46,  7.56s/it]

Epoch 27: Train loss: 0.8983 Test Loss: 0.9063


 58%|█████▊    | 29/50 [03:36<02:39,  7.60s/it]

Epoch 28: Train loss: 0.8962 Test Loss: 0.9045


 60%|██████    | 30/50 [03:43<02:30,  7.55s/it]

Epoch 29: Train loss: 0.8923 Test Loss: 0.9015


 62%|██████▏   | 31/50 [03:50<02:22,  7.50s/it]

Epoch 30: Train loss: 0.8934 Test Loss: 0.9027


 64%|██████▍   | 32/50 [03:58<02:15,  7.51s/it]

Epoch 31: Train loss: 0.8952 Test Loss: 0.9003


 66%|██████▌   | 33/50 [04:05<02:08,  7.54s/it]

Epoch 32: Train loss: 0.8923 Test Loss: 0.8988


 68%|██████▊   | 34/50 [04:13<02:01,  7.59s/it]

Epoch 33: Train loss: 0.8906 Test Loss: 0.8994


 70%|███████   | 35/50 [04:21<01:54,  7.66s/it]

Epoch 34: Train loss: 0.8896 Test Loss: 0.9009


 72%|███████▏  | 36/50 [04:29<01:47,  7.64s/it]

Epoch 35: Train loss: 0.8883 Test Loss: 0.8958


 74%|███████▍  | 37/50 [04:36<01:39,  7.62s/it]

Epoch 36: Train loss: 0.8848 Test Loss: 0.8959


 76%|███████▌  | 38/50 [04:44<01:31,  7.63s/it]

Epoch 37: Train loss: 0.8842 Test Loss: 0.8953


 78%|███████▊  | 39/50 [04:51<01:23,  7.61s/it]

Epoch 38: Train loss: 0.8817 Test Loss: 0.8921


 80%|████████  | 40/50 [04:59<01:16,  7.63s/it]

Epoch 39: Train loss: 0.8790 Test Loss: 0.8921


 82%|████████▏ | 41/50 [05:07<01:08,  7.59s/it]

Epoch 40: Train loss: 0.8777 Test Loss: 0.8903


 84%|████████▍ | 42/50 [05:14<01:00,  7.52s/it]

Epoch 41: Train loss: 0.8761 Test Loss: 0.8908


 86%|████████▌ | 43/50 [05:21<00:52,  7.45s/it]

Epoch 42: Train loss: 0.8752 Test Loss: 0.8901


 88%|████████▊ | 44/50 [05:29<00:44,  7.41s/it]

Epoch 43: Train loss: 0.8748 Test Loss: 0.8882


 90%|█████████ | 45/50 [05:36<00:37,  7.47s/it]

Epoch 44: Train loss: 0.8756 Test Loss: 0.8867


 92%|█████████▏| 46/50 [05:44<00:29,  7.47s/it]

Epoch 45: Train loss: 0.8732 Test Loss: 0.8900


 94%|█████████▍| 47/50 [05:51<00:22,  7.48s/it]

Epoch 46: Train loss: 0.8767 Test Loss: 0.8869


 96%|█████████▌| 48/50 [05:59<00:15,  7.53s/it]

Epoch 47: Train loss: 0.8750 Test Loss: 0.8870


 98%|█████████▊| 49/50 [06:06<00:07,  7.57s/it]

Epoch 48: Train loss: 0.8708 Test Loss: 0.8869


100%|██████████| 50/50 [06:14<00:00,  7.49s/it]

Epoch 49: Train loss: 0.8708 Test Loss: 0.8854





In [17]:
# Plotly line plot of training losses
fig = px.line(train_losses, labels={'x': 'Iteration', 'y': 'Loss'}, title='Training Loss')
fig.show()

In [21]:
def evaluate_trajectory(trajectory, original_seq):

    # Calculate cumulative reward for both
    cumulative_reward = 0
    original_cumulative_reward = 0
    for i in range(0, len(trajectory), 2):
        cumulative_reward += trajectory[i + 1]
        original_cumulative_reward += original_seq[i + 1]
    
    return cumulative_reward, original_cumulative_reward


evaluate_trajectory(start.squeeze(0).tolist(), tokens.squeeze(0).tolist())

(85, 68)

In [55]:
iter_test_loader = iter(test_loader)

In [61]:
# Get a random batch from the test loader
tokens, probs = next(iter_test_loader)

def generate_bandit_predictions(model, initial_tokens, arm_probabilities):
    model.eval()
    with torch.no_grad():
        # Initialise the prediction sequence with the first token
        prediction_sequence = initial_tokens[:, :1]
        
        # Loop through each token in the sequence
        for token_index in tqdm.tqdm(range(initial_tokens.shape[1] - 1)):
            
            # Odd indices: Generate the next prediction
            if (token_index + 1) % 2 == 0:
                model_logits = model(prediction_sequence)
                next_prediction = model_logits.argmax(-1)[:, -1]
                prediction_sequence = torch.cat([prediction_sequence, next_prediction.unsqueeze(1)], dim=1)
            
            # Even indices: Generate the reward based on probabilities
            else:
                previous_action = prediction_sequence[:, -1]
                random_numbers = torch.rand(previous_action.shape[0])
                reward_probabilities = arm_probabilities[torch.arange(previous_action.shape[0]), previous_action - 2]
                reward_token = (random_numbers < reward_probabilities).long().unsqueeze(1)
                prediction_sequence = torch.cat([prediction_sequence, reward_token], dim=1)

    return prediction_sequence, initial_tokens

# Example usage
# Assuming 'model' is a trained transformer model and 'tokens' and 'probs' are from test_loader
predicted_sequence, actual_sequence = generate_bandit_predictions(model, tokens, probs)

100%|██████████| 199/199 [00:06<00:00, 29.63it/s]


In [62]:
cum_rewards, original_rewards = evaluate_trajectory(predicted_sequence, actual_sequence)
print(f"Cumulative reward: {cum_rewards.float().mean()}, Original cumulative reward: {original_rewards.float().mean()}")

Cumulative reward: 59.5, Original cumulative reward: 54.17499923706055


In [65]:
# Run for all batches in test_loader, stack predicted_sequences and actual_sequences, then evaluate
predicted_sequences = []
actual_sequences = []
i = 0
for tokens, probs in tqdm.tqdm(train_loader):
    predicted_sequence, actual_sequence = generate_bandit_predictions(model, tokens, probs)
    predicted_sequences.append(predicted_sequence)
    actual_sequences.append(actual_sequence)
    i += 1
    if i > 4: break

predicted_sequences = torch.cat(predicted_sequences, dim=0)
actual_sequences = torch.cat(actual_sequences, dim=0)

cum_rewards, original_rewards = evaluate_trajectory(predicted_sequences, actual_sequences)
print(f"Cumulative reward: {cum_rewards.float().mean():.2f}, Original cumulative reward: {original_rewards.float().mean():.2f}")

100%|██████████| 199/199 [00:06<00:00, 30.59it/s]
100%|██████████| 199/199 [00:06<00:00, 29.18it/s]
100%|██████████| 199/199 [00:06<00:00, 29.28it/s]
100%|██████████| 199/199 [00:06<00:00, 29.09it/s]
100%|██████████| 199/199 [00:06<00:00, 29.50it/s]
 14%|█▍        | 4/29 [00:33<03:30,  8.43s/it]

Cumulative reward: 248.13, Original cumulative reward: 275.25





In [104]:
tokens, probs = next(iter(test_loader))
rand_idx = random.randint(0, tokens.shape[0] - 1)
tokens = tokens[rand_idx:rand_idx+1]
probs = probs[rand_idx:rand_idx+1]

logits = model(tokens).squeeze()[:, 2:]

# We only want every second logit starting from logit zero (action logits)
action_logits = logits[1::2, :]

action_logits.shape, logits.shape

(torch.Size([100, 10]), torch.Size([200, 10]))

In [105]:
# Plotly imshow plot of logits
fig = px.imshow(logits.detach().numpy().T, labels={'x': 'Position', 'y': 'Token', 'color': 'Logit'}, title='Logits for Each Token Position')
fig.show()

In [106]:
# Plotly imshow plot of logits
fig = px.imshow(action_logits.detach().numpy().T, labels={'x': 'Position', 'y': 'Token', 'color': 'Logit'}, title='Logits for Each Token Position')
fig.show()