In [2]:
import torch
import torch.nn as nn


class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # LSTM layer
        # batch_first=True means the input and output tensors are provided as (batch, seq, feature)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        # Fully connected layer to get the final output
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Initialize hidden state and cell state with zeros
        # h0: (num_layers * num_directions, batch_size, hidden_size)
        # c0: (num_layers * num_directions, batch_size, hidden_size)
        # These tensors will be on CPU by default, matching the model's device if not moved.
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)

        # Pass the input and hidden/cell states to the LSTM
        # out: (batch_size, seq_length, hidden_size) - output features (h_t) from the last layer of LSTM, for each t
        # hn: (num_layers * num_directions, batch_size, hidden_size) - final hidden state
        # cn: (num_layers * num_directions, batch_size, hidden_size) - final cell state
        out, (hn, cn) = self.lstm(x, (h0, c0))

        # We take the output of the last time step to make a prediction for the next step.
        # out[:, -1, :] accesses the output of all batches at the last time step.
        # Shape: (batch_size, hidden_size)
        out = self.fc(out[:, -1, :])
        return out


# Example of how to instantiate and use the model for a two-armed bandit task:

# Define model parameters for the two-armed bandit task
# Input: sequence of (previous_choice, previous_reward)
#   - Previous choice (1 or 2): one-hot encoded (e.g., choice 1 -> [1,0], choice 2 -> [0,1]). Size 2.
#   - Previous reward (0 or 1): single numerical value. Size 1.
#   - Total input_size = 2 (for choice) + 1 (for reward) = 3.
input_dim = 3
hidden_dim = 20  # Number of features in the hidden state (hyperparameter)
layer_dim = 1  # Number of LSTM layers (hyperparameter)
# Output: prediction for the next choice (logits for choice 1, choice 2). Size 2.
output_dim = 2

# Instantiate the model
model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)
print("Model Architecture:\n", model)

# Create

# Perform a forward pass
# The model is on CPU by default, and dummy_input is also on CPU
dummy_input = torch.randn(5, 10, input_dim)  # (batch_size, seq_length, input_dim)
predictions = model(dummy_input)

print("\nShape of dummy input:", dummy_input.shape)
print("Shape of predictions:", predictions.shape)

# Example for a sequence-to-one prediction (like sentiment analysis or time series forecasting one step ahead)
# If you need sequence-to-sequence (e.g., machine translation, or predicting multiple steps ahead),
# you might want to pass the entire 'out' tensor to the fully connected layer,
# possibly reshaping it first or applying the fc layer time-step-wise.
# For example: self.fc(out) if self.fc is adapted for sequence input.

Model Architecture:
 LSTMModel(
  (lstm): LSTM(3, 20, batch_first=True)
  (fc): Linear(in_features=20, out_features=2, bias=True)
)

Shape of dummy input: torch.Size([5, 10, 3])
Shape of predictions: torch.Size([5, 2])


In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np


# -------------------------------
# Simulate bandit task
# -------------------------------
def generate_bandit_data(n_trials=200, p_arm1=0.7, p_arm2=0.3):
    choices = []
    rewards = []
    for _ in range(n_trials):
        action = np.random.choice([1, 2])
        reward = np.random.rand() < (p_arm1 if action == 1 else p_arm2)
        choices.append(action)
        rewards.append(int(reward))
    return np.array(choices), np.array(rewards)


# -------------------------------
# LSTM-based Agent
# -------------------------------
class LSTMPolicy(nn.Module):
    def __init__(self, input_size=4, hidden_size=32, output_size=2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(x, hidden)
        logits = self.fc(out)
        probs = self.softmax(logits)
        return probs, hidden


# -------------------------------
# Helper: Convert to one-hot
# -------------------------------
def one_hot(index, num_classes):
    return np.eye(num_classes)[index]


# -------------------------------
# Prepare sequence input
# -------------------------------
def prepare_input(choices, rewards):
    # Map choices 1/2 -> 0/1
    choices = choices - 1
    x = []
    for c, r in zip(choices, rewards):
        choice_oh = one_hot(c, 2)
        reward_oh = one_hot(r, 2)
        x.append(np.concatenate([choice_oh, reward_oh]))
    return torch.tensor(x, dtype=torch.float32).unsqueeze(0)  # shape: (1, seq_len, 4)


# -------------------------------
# Train LSTM Agent using REINFORCE
# -------------------------------
def train_lstm_agent(n_epochs=1000, n_trials=200, lr=1e-2):
    model = LSTMPolicy()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    all_losses = []

    for epoch in range(n_epochs):
        choices, rewards = generate_bandit_data(n_trials=n_trials)
        input_seq = prepare_input(choices, rewards)

        model.zero_grad()
        probs, _ = model(input_seq)
        log_probs = torch.log(probs.squeeze(0) + 1e-8)

        # True choices (mapped to 0/1)
        actions = torch.tensor(choices - 1, dtype=torch.long)

        # Negative log-likelihood weighted by rewards (REINFORCE-style)
        selected_log_probs = log_probs[range(n_trials), actions]
        loss = -torch.sum(
            selected_log_probs * torch.tensor(rewards, dtype=torch.float32)
        )

        loss.backward()
        optimizer.step()
        all_losses.append(loss.item())

        if (epoch + 1) % 100 == 0:
            avg_reward = np.mean(rewards)
            print(
                f"Epoch {epoch+1}, Loss: {loss.item():.2f}, Avg Reward: {avg_reward:.2f}"
            )

    return model, all_losses


# -------------------------------
# Run training
# -------------------------------
if __name__ == "__main__":
    model, losses = train_lstm_agent()

  return torch.tensor(x, dtype=torch.float32).unsqueeze(0)  # shape: (1, seq_len, 4)


Epoch 100, Loss: 0.08, Avg Reward: 0.51
Epoch 200, Loss: 0.06, Avg Reward: 0.58
Epoch 300, Loss: 0.02, Avg Reward: 0.59
Epoch 400, Loss: 0.01, Avg Reward: 0.51
Epoch 500, Loss: 0.01, Avg Reward: 0.52
Epoch 600, Loss: 0.01, Avg Reward: 0.51
Epoch 700, Loss: 0.01, Avg Reward: 0.47
Epoch 800, Loss: 0.00, Avg Reward: 0.54
Epoch 900, Loss: 0.00, Avg Reward: 0.49
Epoch 1000, Loss: 0.00, Avg Reward: 0.49


In [20]:
model.lstm.all_weights[0][1].shape

torch.Size([128, 32])

In [24]:
model.fc.weight[1].shape

torch.Size([32])

In [8]:
generate_bandit_data(200, 0.7, 0.3)

(array([1, 2, 2, 1, 1, 2, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2, 2, 1, 1,
        1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 1, 1, 1, 2,
        1, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 1, 2, 1, 2,
        1, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1,
        2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 1, 2, 1, 1, 2, 2,
        2, 2, 2, 1, 1, 2, 1, 1, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1,
        2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1,
        1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 1, 2, 1, 1,
        1, 1]),
 array([1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0,
        1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1,
        0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1,
        1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1,
        0, 1, 1, 1, 0,