In [5]:
import numpy as np
import random
from tqdm import tqdm
import math
import torch

## Multiarmed bandits

In [6]:
# Function to convert episodes to trajectory sequence of codes
def episodes_to_trajectory(episodes):
    trajectory = []
    for episode in episodes:
        action = episode['action']
        reward = episode['reward']
        trajectory.extend([action, reward])
    return trajectory

def cumulative_reward_from_trajectory(trajectory):
    cumulative_reward = 0
    for i in range(0, len(trajectory), 2):
        cumulative_reward += trajectory[i + 1]
    return cumulative_reward

def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

In [7]:
def select_arm(epsilon):
    if random.random() < epsilon:
        return random.randint(0, n_arms - 1), 'random'  # Exploration
    else:
        return max(range(n_arms), key=lambda x: estimated_rewards[x]), 'greedy'  # Exploitation

def update_reward(arm, reward):
    counts[arm] += 1
    estimated_rewards[arm] += (reward - estimated_rewards[arm]) / counts[arm]

def generate_reward(arm, prob_high_value, high_value=10, low_value=1):
    if random.random() < prob_high_value[arm]:
        return high_value
    else:
        return low_value

def epsilon_schedule(t):
    return max(epsilon_min, epsilon_initial / math.log(t + decay_offset))

# Function to convert episodes to trajectory sequence of codes
def episodes_to_trajectory(episodes):
    trajectory = []
    for episode in episodes:
        action = episode['action'] + 2
        reward = 0 if episode['reward'] == 1 else 1
        trajectory.extend([action, reward])
    return trajectory

def cumulative_reward_from_trajectory(trajectory):
    cumulative_reward = 0
    for i in range(0, len(trajectory), 2):
        cumulative_reward += trajectory[i + 1]
    return cumulative_reward

num_trajectories = 1000
trajectories = []

for _ in tqdm(range(num_trajectories)):
    #print('hello')
    # Initialization
    n_arms = 10
    counts = [0] * n_arms
    #rewards = [0.0] * n_arms
    estimated_rewards = [0.0] * n_arms

    epsilon_initial = 1.0
    epsilon_min = 0.001
    decay_offset = 1e-2  # Offset to avoid division by zero in log function

    # Each arm has its own fixed reward probability
    prob_high_value = np.ones(n_arms) * 0.1

    # Randomly select 2 indices
    indices = np.random.choice(n_arms, 2, replace=False)
    prob_high_value[indices] = 0.9

    high_value = 10
    low_value = 1

    episodes = []

    # Run multiple episodes
    for episode in range(1, 101):  # Number of episodes (1-indexed for log function)
        # Decay epsilon over episodes
        epsilon = epsilon_schedule(episode)
        #epsilon=0.1
        arm, arm_type = select_arm(epsilon)
        if arm in indices:
            selected_high = 1
        else:
            selected_high = 0
        reward = generate_reward(arm, prob_high_value)
        update_reward(arm, reward)
        episodes.append({'state': {'counts': list(counts), 'rewards': list(estimated_rewards)}, 'action': arm, 'reward': reward, 'arm_type': arm_type, 'selected_high': selected_high})

    # print()

    trajectory = episodes_to_trajectory(episodes)
    cumulative_reward = cumulative_reward_from_trajectory(trajectory)
    #print(f"Probabilities: {prob_high_value}")
    trajectories.append({'trajectory': trajectory, 'cumulative_reward': cumulative_reward, 'probabilities': prob_high_value,
                         'selected_highs': [episode['selected_high'] for episode in episodes], 'arm_types': [episode['arm_type'] for episode in episodes]})


100%|██████████| 1000/1000 [00:00<00:00, 3314.87it/s]


In [8]:
import plotly.express as px

# Selected highs for each trajectory
selected_highs = [trajectory['selected_highs'] for trajectory in trajectories]
# Stack the selected highs
selected_highs = np.stack(selected_highs, axis=0)
# Mean over each position
selected_highs = np.mean(selected_highs, axis=0)
# Plotly line plot
fig = px.line(y=selected_highs, labels={'x': 'Episode', 'y': 'Probability of Selecting High Reward Arm'})
# Set y-axis to be between 0 and 1
fig.update_yaxes(range=[0, 1])
fig.show()

In [9]:
# Arm types
arm_types = [trajectory['arm_types'] for trajectory in trajectories]
# Convert 'random' to 0 and 'greedy' to 1
arm_types = [[0 if arm_type == 'random' else 1 for arm_type in trajectory] for trajectory in arm_types]
# Stack the arm types
arm_types = np.stack(arm_types, axis=0)
# Mean over each position
arm_types = np.mean(arm_types, axis=0)
# Plotly line plot
fig = px.line(y=arm_types, labels={'x': 'Episode', 'y': 'Probability of Selecting Greedy Action'})
fig.show()

In [10]:
# Mean of cumulative reward in trajectories
mean_cumulative_reward = np.mean([t['cumulative_reward'] for t in trajectories])
std_cumulative_reward = np.std([t['cumulative_reward'] for t in trajectories])
print(f"Mean Cumulative Reward: {mean_cumulative_reward:.2f} ± {std_cumulative_reward:.2f}")

Mean Cumulative Reward: 66.20 ± 9.11


In [11]:
probabilities = torch.stack([torch.tensor(t['probabilities']) for t in trajectories], dim=0)
print(probabilities.shape)

# Create stacked tensor of trajectories
trajectories_tensor = torch.stack([torch.tensor(t['trajectory']) for t in trajectories], dim=0)
trajectories_tensor.shape

torch.Size([1000, 10])


torch.Size([1000, 200])

In [12]:
# Get the reward tensor from the trajectories tensor
rewards_tensor = trajectories_tensor[:, 1::2].float()

# Get the mean reward at each position
mean_rewards = rewards_tensor.mean(0)

# Smooth the mean rewards 
smoothed_mean_rewards = mean_rewards.unfold(0, 10, 1).mean(1).view(-1)
# Add zeros to the start of the smoothed mean rewards
smoothed_mean_rewards = torch.cat((torch.ones(9)*mean_rewards[0], smoothed_mean_rewards))

fig = px.line(mean_rewards, labels={'x': 'Position', 'y': 'Mean Reward'}, title='Mean Reward at Each Position')
# Add smoothed mean rewards
fig.add_scatter(y=smoothed_mean_rewards, mode='lines', name='Smoothed Mean Reward')
fig.show()

In [13]:
import tqdm
import torch
from torch.utils.data import Dataset, DataLoader, random_split

class BanditDataset(Dataset):
    def __init__(self, games_tensor, probabilities_tensor):
        self.games_tensor = games_tensor
        self.probabilities_tensor = probabilities_tensor

    def __len__(self):
        return self.games_tensor.size(0)

    def __getitem__(self, idx):
        return self.games_tensor[idx], self.probabilities_tensor[idx]

# Assuming `all_games_tensor` is defined elsewhere
games_dataset = BanditDataset(trajectories_tensor, probabilities)

# Split the dataset into training and test sets (10% for testing)
test_size = int(0.1 * len(games_dataset))
train_size = len(games_dataset) - test_size
train_dataset, test_dataset = random_split(games_dataset, [train_size, test_size])

# Create DataLoaders for both training and testing
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [78]:
from transformer_lens import HookedTransformer, HookedTransformerConfig

batch_size = 32
num_epochs = 50
lr = 1e-4
betas = (0.9, 0.95)
max_grad_norm = 1.0
wd = 0.01

cfg = HookedTransformerConfig(
    n_layers=4,
    d_model=64,
    d_head=64,
    n_heads=4,
    d_mlp=256,
    d_vocab=12,
    n_ctx=200,
    act_fn="relu",
    normalization_type="LN",
    device='cpu'
)
model = HookedTransformer(cfg)


optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i / 100, 1.0))

def loss_fn(logits, tokens, per_token=False):
    # logit shape: [batch, pos, vocab]
    # token shape: [batch, pos]


    # logits = logits[:, 1:]
    # tokens = tokens[:, :-1].long()
    logits = logits[:, :-1, :]
    tokens = tokens[:, 1:].long()
    log_probs = logits.log_softmax(-1)
    correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]
    if per_token:
        return -correct_log_probs
    else:
        return -correct_log_probs.mean()

In [79]:
# Training loop 
train_losses = []
test_losses = []

for epoch in tqdm.tqdm(range(num_epochs)):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    for tokens, _ in train_loader:
        # Uncomment the following line if you are using a GPU
        # tokens = tokens.cuda()

        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        loss.backward()

        # Gradient clipping
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        running_loss += loss.item()
        train_losses.append(loss.item())

    # Calculate average loss for the training epoch
    average_train_loss = running_loss / len(train_loader)

    # Evaluate on test set
    model.eval()  # Set the model to evaluation mode
    test_loss = 0.0
    with torch.no_grad():
        for tokens, _ in test_loader:
            # Uncomment the following line if you are using a GPU
            # tokens = tokens.cuda()

            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            test_loss += loss.item()

    # Calculate average loss for the test epoch
    average_test_loss = test_loss / len(test_loader)
    test_losses.append(average_test_loss)
    print(f"Epoch {epoch}: Train loss: {average_train_loss:.4f} Test Loss: {average_test_loss:.4f}")

  2%|▏         | 1/50 [00:07<05:56,  7.27s/it]

Epoch 0: Train loss: 2.5082 Test Loss: 2.2936


  4%|▍         | 2/50 [00:14<05:45,  7.19s/it]

Epoch 1: Train loss: 2.2105 Test Loss: 2.0768


  6%|▌         | 3/50 [00:21<05:36,  7.15s/it]

Epoch 2: Train loss: 2.0110 Test Loss: 1.8507


  8%|▊         | 4/50 [00:28<05:29,  7.16s/it]

Epoch 3: Train loss: 1.6690 Test Loss: 1.3881


 10%|█         | 5/50 [00:35<05:17,  7.06s/it]

Epoch 4: Train loss: 1.3226 Test Loss: 1.1873


 12%|█▏        | 6/50 [00:42<05:04,  6.91s/it]

Epoch 5: Train loss: 1.1892 Test Loss: 1.1123


 14%|█▍        | 7/50 [00:48<04:53,  6.82s/it]

Epoch 6: Train loss: 1.1262 Test Loss: 1.0768


 16%|█▌        | 8/50 [00:55<04:48,  6.86s/it]

Epoch 7: Train loss: 1.0912 Test Loss: 1.0468


 18%|█▊        | 9/50 [01:02<04:40,  6.85s/it]

Epoch 8: Train loss: 1.0647 Test Loss: 1.0335


 20%|██        | 10/50 [01:09<04:33,  6.84s/it]

Epoch 9: Train loss: 1.0468 Test Loss: 1.0192


 22%|██▏       | 11/50 [01:16<04:26,  6.82s/it]

Epoch 10: Train loss: 1.0256 Test Loss: 1.0143


 24%|██▍       | 12/50 [01:22<04:18,  6.80s/it]

Epoch 11: Train loss: 1.0204 Test Loss: 0.9929


 26%|██▌       | 13/50 [01:29<04:12,  6.82s/it]

Epoch 12: Train loss: 1.0030 Test Loss: 0.9851


 28%|██▊       | 14/50 [01:36<04:07,  6.86s/it]

Epoch 13: Train loss: 0.9895 Test Loss: 0.9760


 30%|███       | 15/50 [01:43<04:01,  6.90s/it]

Epoch 14: Train loss: 0.9771 Test Loss: 0.9658


 32%|███▏      | 16/50 [01:50<03:56,  6.97s/it]

Epoch 15: Train loss: 0.9722 Test Loss: 0.9629


 34%|███▍      | 17/50 [01:57<03:48,  6.92s/it]

Epoch 16: Train loss: 0.9639 Test Loss: 0.9530


 36%|███▌      | 18/50 [02:04<03:41,  6.93s/it]

Epoch 17: Train loss: 0.9597 Test Loss: 0.9514


 38%|███▊      | 19/50 [02:11<03:35,  6.94s/it]

Epoch 18: Train loss: 0.9488 Test Loss: 0.9439


 40%|████      | 20/50 [02:18<03:27,  6.93s/it]

Epoch 19: Train loss: 0.9424 Test Loss: 0.9412


 42%|████▏     | 21/50 [02:25<03:22,  6.97s/it]

Epoch 20: Train loss: 0.9387 Test Loss: 0.9394


 44%|████▍     | 22/50 [02:32<03:13,  6.92s/it]

Epoch 21: Train loss: 0.9330 Test Loss: 0.9369


 46%|████▌     | 23/50 [02:39<03:07,  6.93s/it]

Epoch 22: Train loss: 0.9271 Test Loss: 0.9313


 48%|████▊     | 24/50 [02:46<02:59,  6.91s/it]

Epoch 23: Train loss: 0.9232 Test Loss: 0.9287


 50%|█████     | 25/50 [02:53<02:53,  6.95s/it]

Epoch 24: Train loss: 0.9227 Test Loss: 0.9254


 52%|█████▏    | 26/50 [03:00<02:48,  7.01s/it]

Epoch 25: Train loss: 0.9214 Test Loss: 0.9226


 54%|█████▍    | 27/50 [03:07<02:40,  6.98s/it]

Epoch 26: Train loss: 0.9146 Test Loss: 0.9230


 56%|█████▌    | 28/50 [03:14<02:32,  6.94s/it]

Epoch 27: Train loss: 0.9109 Test Loss: 0.9190


 58%|█████▊    | 29/50 [03:21<02:26,  7.00s/it]

Epoch 28: Train loss: 0.9063 Test Loss: 0.9178


 60%|██████    | 30/50 [03:28<02:18,  6.92s/it]

Epoch 29: Train loss: 0.9072 Test Loss: 0.9153


 62%|██████▏   | 31/50 [03:34<02:11,  6.90s/it]

Epoch 30: Train loss: 0.9038 Test Loss: 0.9160


 64%|██████▍   | 32/50 [03:41<02:03,  6.88s/it]

Epoch 31: Train loss: 0.9007 Test Loss: 0.9104


 66%|██████▌   | 33/50 [03:48<01:57,  6.94s/it]

Epoch 32: Train loss: 0.8989 Test Loss: 0.9157


 68%|██████▊   | 34/50 [03:55<01:51,  6.94s/it]

Epoch 33: Train loss: 0.8985 Test Loss: 0.9095


 70%|███████   | 35/50 [04:02<01:44,  6.94s/it]

Epoch 34: Train loss: 0.8973 Test Loss: 0.9070


 72%|███████▏  | 36/50 [04:09<01:36,  6.90s/it]

Epoch 35: Train loss: 0.8954 Test Loss: 0.9097


 74%|███████▍  | 37/50 [04:16<01:29,  6.90s/it]

Epoch 36: Train loss: 0.8980 Test Loss: 0.9086


 76%|███████▌  | 38/50 [04:23<01:21,  6.83s/it]

Epoch 37: Train loss: 0.8944 Test Loss: 0.9065


 78%|███████▊  | 39/50 [04:29<01:14,  6.79s/it]

Epoch 38: Train loss: 0.8879 Test Loss: 0.9052


 80%|████████  | 40/50 [04:36<01:08,  6.81s/it]

Epoch 39: Train loss: 0.8888 Test Loss: 0.9021


 82%|████████▏ | 41/50 [04:43<01:01,  6.83s/it]

Epoch 40: Train loss: 0.8859 Test Loss: 0.9028


 84%|████████▍ | 42/50 [04:50<00:54,  6.86s/it]

Epoch 41: Train loss: 0.8866 Test Loss: 0.9005


 86%|████████▌ | 43/50 [04:57<00:48,  6.87s/it]

Epoch 42: Train loss: 0.8844 Test Loss: 0.9017


 88%|████████▊ | 44/50 [05:04<00:41,  6.90s/it]

Epoch 43: Train loss: 0.8849 Test Loss: 0.9028


 90%|█████████ | 45/50 [05:11<00:34,  6.95s/it]

Epoch 44: Train loss: 0.8822 Test Loss: 0.8983


 92%|█████████▏| 46/50 [05:18<00:27,  6.97s/it]

Epoch 45: Train loss: 0.8806 Test Loss: 0.9026


 94%|█████████▍| 47/50 [05:25<00:20,  6.94s/it]

Epoch 46: Train loss: 0.8806 Test Loss: 0.9030


 96%|█████████▌| 48/50 [05:32<00:13,  6.94s/it]

Epoch 47: Train loss: 0.8819 Test Loss: 0.9014


 98%|█████████▊| 49/50 [05:39<00:06,  6.94s/it]

Epoch 48: Train loss: 0.8790 Test Loss: 0.8982


100%|██████████| 50/50 [05:46<00:00,  6.92s/it]

Epoch 49: Train loss: 0.8782 Test Loss: 0.8965





In [80]:
# Plotly line plot of training losses
fig = px.line(train_losses, labels={'x': 'Iteration', 'y': 'Loss'}, title='Training Loss')
fig.show()

In [81]:
trajectories_tensor.shape

torch.Size([1000, 200])

In [107]:
# Get a random batch from the test loader
tokens, probs = next(iter(test_loader))
rand_idx = random.randint(0, len(tokens) - 1)
tokens = tokens[rand_idx, :].unsqueeze(0)
probs = probs[rand_idx, :]

In [108]:
tokens, probs

(tensor([[10,  0,  2,  0,  2,  1, 10,  0,  2,  0,  7,  0, 11,  1, 11,  0,  8,  1,
           8,  1,  8,  1,  8,  1,  8,  1,  8,  1,  8,  1,  3,  0,  8,  1,  3,  0,
           8,  0,  8,  1,  8,  1,  8,  0,  8,  1,  8,  1,  8,  1,  8,  1,  8,  1,
           8,  1,  8,  1,  2,  0,  8,  1,  8,  1,  8,  1,  8,  1,  4,  1,  4,  0,
           8,  0,  8,  1,  8,  1,  8,  1,  8,  1,  8,  1,  8,  1, 10,  0,  8,  1,
           8,  1,  8,  1,  8,  1,  8,  1, 10,  0,  8,  1,  8,  1,  4,  0,  8,  1,
           8,  0, 10,  0,  8,  1,  8,  1,  8,  1, 11,  0, 11,  0,  8,  0,  8,  1,
           8,  1,  8,  1,  8,  1,  8,  1,  8,  1,  8,  1,  7,  0,  8,  1,  8,  1,
           8,  1,  8,  0,  8,  1,  8,  0,  8,  1,  8,  1,  8,  0,  8,  1,  8,  1,
           6,  1,  6,  0,  8,  1,  8,  1,  8,  1,  8,  1, 11,  1,  8,  1,  8,  1,
           8,  1, 11,  0,  8,  1,  8,  1,  8,  0,  8,  1,  8,  0,  2,  0,  8,  1,
           4,  1]]),
 tensor([0.1000, 0.1000, 0.9000, 0.1000, 0.1000, 0.1000, 0.9000, 0.1000, 0.10

In [113]:
# Get one batch of tokens from the test loader
#tokens = next(iter(test_loader))
# tokens = trajectories_tensor[0, :].unsqueeze(0)
# print(tokens.shape)

# TO-DO: bug here with probs not lining up with tokens

# idx = 103
# tokens = trajectories_tensor[idx, :].unsqueeze(0)
# probs = probabilities[idx]

model.eval()
with torch.no_grad():
    # Go through tokens from just one to get the predictions
    start = tokens[:, :1]
    for i in range(tokens.shape[1]-1):
        logits = model(tokens[:, :i+1])
        logits_with_previous = model(start)
        
        # Get the next prediction
        next_token = logits.argmax(-1)[:, -1]
        next_token_with_previous = logits_with_previous.argmax(-1)[:, -1]   
        # Get the true next token
        true_next_token = tokens[:, i+1]
        print(f"Tokens = {tokens[:, :i+1]}, Start = {start}")
        print(f"Predicted: {next_token.item()} True: {true_next_token.item()}, Predicted with start: {next_token_with_previous.item()}\n")

        # Add to start
        #if true_next_token > 1:
        if (i+1) % 2 == 0:
            start = torch.cat([start, next_token.unsqueeze(1)], dim=1)
        else:
            # Get reward from probs
            rand_num = np.random.rand()
            # Previous action
            prev_action = start[:, -1]
            #print(f"Previous action: {prev_action.item()}")
            reward = 1 if rand_num < probs[prev_action-2] else 0
            start = torch.cat([start, torch.tensor([reward]).unsqueeze(0)], dim=1)
            #start = torch.cat([start, true_next_token.unsqueeze(1)], dim=1)

print()
print(f"Final tokens: {tokens}")
print(f"Final start: {start}")

Tokens = tensor([[10]]), Start = tensor([[10]])
Predicted: 0 True: 0, Predicted with start: 0

Tokens = tensor([[10,  0]]), Start = tensor([[10,  0]])
Predicted: 10 True: 2, Predicted with start: 10

Tokens = tensor([[10,  0,  2]]), Start = tensor([[10,  0, 10]])
Predicted: 0 True: 0, Predicted with start: 0

Tokens = tensor([[10,  0,  2,  0]]), Start = tensor([[10,  0, 10,  0]])
Predicted: 2 True: 2, Predicted with start: 10

Tokens = tensor([[10,  0,  2,  0,  2]]), Start = tensor([[10,  0, 10,  0,  2]])
Predicted: 0 True: 1, Predicted with start: 0

Tokens = tensor([[10,  0,  2,  0,  2,  1]]), Start = tensor([[10,  0, 10,  0,  2,  0]])
Predicted: 2 True: 10, Predicted with start: 2

Tokens = tensor([[10,  0,  2,  0,  2,  1, 10]]), Start = tensor([[10,  0, 10,  0,  2,  0,  2]])
Predicted: 0 True: 0, Predicted with start: 0

Tokens = tensor([[10,  0,  2,  0,  2,  1, 10,  0]]), Start = tensor([[10,  0, 10,  0,  2,  0,  2,  0]])
Predicted: 2 True: 2, Predicted with start: 2

Tokens = ten

In [114]:
def evaluate_trajectory(trajectory, original_seq):

    # Calculate cumulative reward for both
    cumulative_reward = 0
    original_cumulative_reward = 0
    for i in range(0, len(trajectory), 2):
        cumulative_reward += trajectory[i + 1]
        original_cumulative_reward += original_seq[i + 1]
    
    return cumulative_reward, original_cumulative_reward


evaluate_trajectory(start.squeeze(0).tolist(), tokens.squeeze(0).tolist())

(85, 70)

In [117]:
# Get a random batch from the test loader
tokens, probs = next(iter(test_loader))


model.eval()
with torch.no_grad():
    # Go through tokens from just one to get the predictions
    start = tokens[:, :1]
    for i in range(tokens.shape[1]-1):
        logits = model(tokens[:, :i+1])
        logits_with_previous = model(start)
        
        # Get the next prediction
        next_token = logits.argmax(-1)[:, -1]
        next_token_with_previous = logits_with_previous.argmax(-1)[:, -1]   
        # Get the true next token
        true_next_token = tokens[:, i+1]
        #print(f"Tokens = {tokens[:, :i+1]}, Start = {start}")
        #print(f"Predicted: {next_token.item()} True: {true_next_token.item()}, Predicted with start: {next_token_with_previous.item()}\n")

        # Add to start
        if (i+1) % 2 == 0:
            start = torch.cat([start, next_token], dim=1)
        else:
            # Get reward from probs
            rand_num = np.random.rand()
            # Previous action
            prev_action = start[:, -1]
            print(f"Previous action: {prev_action}") 
            reward = 1 if rand_num < probs[prev_action-2] else 0
            start = torch.cat([start, torch.tensor([reward])], dim=1)
            #start = torch.cat([start, true_next_token.unsqueeze(1)], dim=1)

# print()
# print(f"Final tokens: {tokens}")
# print(f"Final start: {start}")

Previous action: tensor([ 3,  3,  7, 10,  9,  7,  9,  5, 10,  5,  5,  6, 11,  7, 10,  5,  9, 10,
         8,  4,  2,  5,  3,  2,  6,  3,  9,  5,  7,  2, 10,  6])


RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [118]:
prev_action = torch.tensor([3,  3,  7, 10,  9,  7,  9,  5, 10,  5,  5,  6, 11,  7, 10,  5,  9, 10, 
                            8,  4,  2,  5,  3,  2,  6,  3,  9,  5,  7,  2, 10,  6])

In [624]:
import numpy as np
import torch

def create_bandit_environment(n_arms=10, high_prob_indices=2, high_prob=0.9, low_prob=0.1):
    """Creates a multi-armed bandit environment with specified probabilities."""
    prob_high_value = np.ones(n_arms) * low_prob
    indices = np.random.choice(n_arms, high_prob_indices, replace=False)
    prob_high_value[indices] = high_prob
    return prob_high_value

def generate_reward(prob_high_value, arm, high_value=10, low_value=1):
    """Generates a reward based on the probability distribution of the bandit."""
    if np.random.random() < prob_high_value[arm]:
        return high_value
    else:
        return low_value

def predict_next_action(model, start, num_steps, prob_high_value, high_value=10, low_value=1):
    """Predicts the next action using the model and updates the trajectory accordingly."""
    trajectory = start

    with torch.no_grad():
        for _ in range(num_steps):
            # Predict next action using the model
            model.eval()
            logits = model(trajectory)
            print(logits.shape)
            print(logits[:, -1, :])
            next_action = logits[:, -1, :].argmax(-1)
            # Multinomial sample from logits
            #next_action = torch.multinomial(logits[:, -1, :].softmax(-1), num_samples=1)

            # Get the reward based on the predicted action
            reward = generate_reward(prob_high_value, next_action.item()-2, high_value, low_value)
            reward_token = 0 if reward == low_value else 1

            # Update trajectory with action and reward
            trajectory = torch.cat([trajectory, torch.tensor([[next_action, reward_token]])], dim=1)
    
    return trajectory

# Example usage
n_arms = 10
prob_high_value = create_bandit_environment(n_arms)
print(prob_high_value)
rand_start_choice = np.random.randint(0, n_arms) + 2
print(rand_start_choice)
rand_start_reward = generate_reward(prob_high_value, rand_start_choice-2)
rand_start_reward = 0 if rand_start_reward == 1 else 1
print(rand_start_reward)
start = test_trajectory[:, :2] #torch.tensor([[rand_start_choice, rand_start_reward]])  # Example starting guess and reward
num_steps = 99

# Assuming 'model' is a trained transformer model
generated_trajectory = predict_next_action(model, start, num_steps, prob_high_value)

# The final generated trajectory will be in 'generated_trajectory'
generated_trajectory

[0.1 0.1 0.1 0.9 0.1 0.1 0.9 0.1 0.1 0.1]
3
0
torch.Size([1, 2, 12])
tensor([[-2.0954, -2.1159,  0.5268,  0.6583,  0.6008,  0.4590,  1.0056,  0.2833,
          0.4402,  0.5252,  1.1907,  0.3881]])
torch.Size([1, 4, 12])
tensor([[-1.9039, -2.6054,  0.7114,  0.8786,  0.6053,  0.9363,  0.8407,  0.3091,
          0.7437,  0.5741,  1.3829,  0.6372]])
torch.Size([1, 6, 12])
tensor([[-2.5277, -3.2695,  0.7161,  1.3218,  0.9452,  1.0653,  0.9351,  0.3008,
          0.9261,  0.6255,  1.9129,  0.7582]])
torch.Size([1, 8, 12])
tensor([[-2.3159, -3.6037,  0.9554,  1.4340,  0.9950,  1.3485,  0.9444,  0.2678,
          0.7177,  0.4380,  2.2705,  0.5800]])
torch.Size([1, 10, 12])
tensor([[-1.9962, -3.2587,  0.5809,  1.2515,  0.8739,  1.0422,  0.6572,  0.2181,
          0.6376,  0.9095,  2.4841,  0.6176]])
torch.Size([1, 12, 12])
tensor([[-1.9197, -3.2313,  0.7977,  1.3682,  0.8981,  0.7755,  0.4315, -0.1561,
          0.7240,  0.5988,  2.5332,  0.5654]])
torch.Size([1, 14, 12])
tensor([[-2.3106, -3.3

tensor([[10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0,
         10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0,
         10,  0, 10,  0, 10,  0, 10,  0, 10,  1, 10,  0, 10,  0, 10,  0, 10,  0,
         10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0,
         10,  0, 10,  0, 10,  1, 10,  1, 10,  0, 10,  0, 10,  0, 10,  0, 10,  1,
         10,  0, 10,  1, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0,
         10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0,
         10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0,
         10,  0, 10,  0, 10,  1,  0,  1,  0,  0, 10,  0, 10,  0, 10,  0, 10,  0,
         10,  0, 10,  1,  0,  0, 10,  1, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0,
         10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  0, 10,  1,
          0,  0]])

In [625]:
def loss_fn(logits, tokens, per_token=False):
    # logit shape: [batch, pos, vocab]
    # token shape: [batch, pos]


    # logits = logits[:, 1:]
    # tokens = tokens[:, :-1].long()
    logits = logits[:, :-1, :]
    tokens = tokens[:, 1:].long()
    log_probs = logits.log_softmax(-1)
    correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]
    if per_token:
        return -correct_log_probs
    else:
        return -correct_log_probs.mean()

In [97]:
test_trajectory, _ = next(iter(test_loader))
test_trajectory = test_trajectory[0, :].unsqueeze(0)

In [102]:
logits = model(test_trajectory)
# Calculate loss
loss = loss_fn(logits, test_trajectory, per_token=True)
# Plotly lineplot
fig = px.line(loss.detach().cpu().squeeze().numpy(), labels={'x': 'Position', 'y': 'Loss'}, title='Loss at Each Position')
# Label each point on the line with the correct token from test_trajectory
fig.update_traces(text=test_trajectory.squeeze().numpy(), textposition='top right')
# Do the rolling mean
smoothed_loss = loss.detach().cpu().squeeze().numpy().reshape(-1, 199).mean(1)
smoothed_loss = np.concatenate([np.ones(9)*smoothed_loss[0], smoothed_loss])
# Add smoothed loss to the plot
fig.add_scatter(y=smoothed_loss, mode='lines', name='Smoothed Loss')
fig.show()

In [627]:
# Get rewards from the generated trajectory
rewards = generated_trajectory[:, 1::2].float()
# Moving average of rewards
moving_avg_rewards = [rewards.squeeze()[i-10:i].mean() for i in range(11, len(rewards.squeeze())+1)]
# Plotly line plot
fig = px.line(y=moving_avg_rewards, labels={'x': 'Position', 'y': 'Reward'}, title='Rewards in Generated Trajectory')
fig.show()