In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal

class BayesianNeuralNetwork(nn.Module):
    """MLP mapping an input vector to a Gaussian (mean, std) per output unit.

    Two ReLU hidden layers feed two linear heads: one for the mean ``mu`` and
    one for ``log_sigma``. Despite the name, the weights are ordinary point
    estimates (no weight posterior is maintained); the only uncertainty the
    model expresses is the predicted ``sigma``.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Number of (mu, sigma) pairs produced per input.
        hidden_dim: Width of both hidden layers.
        log_sigma_min: Lower clamp applied to the predicted log std so that
            ``exp`` never underflows to 0.
        log_sigma_max: Upper clamp applied to the predicted log std so that
            ``exp`` never overflows to inf.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128,
                 log_sigma_min=-20.0, log_sigma_max=10.0):
        super().__init__()
        self.log_sigma_min = log_sigma_min
        self.log_sigma_max = log_sigma_max

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3_mu = nn.Linear(hidden_dim, output_dim)
        self.fc3_log_sigma = nn.Linear(hidden_dim, output_dim)

        # Small-variance Gaussian initialisation of every weight, zero biases.
        # NOTE: this is only an initialisation scheme, not a Bayesian prior --
        # nothing pulls the weights back towards it during training.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(mu, sigma)``, each with last dimension ``output_dim``.

        ``sigma = exp(clamp(log_sigma))`` is therefore strictly positive and
        finite; an unclamped ``exp`` can underflow to 0, which
        ``torch.distributions.Normal`` rejects as an invalid scale.
        """
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        mu = self.fc3_mu(h)
        log_sigma = self.fc3_log_sigma(h).clamp(self.log_sigma_min, self.log_sigma_max)
        sigma = torch.exp(log_sigma)
        return mu, sigma

class BayesianDQNAgent:
    """Epsilon-greedy DQN agent whose Q-network predicts a Gaussian per action.

    The network outputs a mean ``mu`` and a standard deviation ``sigma`` for
    each action's Q-value. When not exploring, the agent draws one Q-value per
    action from ``Normal(mu, sigma)`` and acts greedily on that draw. Learning
    is online, one transition per ``learn`` call. The same network produces
    both the prediction and the bootstrapped target: there is no replay buffer
    and no separate target network.

    Args:
        state_dim: Length of the state vector fed to the network.
        action_dim: Number of discrete actions.
        hidden_dim: Width of the network's hidden layers.
        lr: Adam learning rate.
        gamma: Discount factor for the bootstrapped target.
        epsilon: Initial probability of taking a uniformly random action.
        epsilon_decay: Multiplicative factor applied to ``epsilon`` after every
            ``learn`` call while ``epsilon > epsilon_min``.
        epsilon_min: Floor that decay will not push ``epsilon`` below.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128, lr=1e-3, gamma=0.99,
                 epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        self.model = BayesianNeuralNetwork(state_dim, action_dim, hidden_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def act(self, state):
        """Choose an action for a single state.

        Args:
            state: Array-like of length ``state_dim``.

        Returns:
            int: Action index in ``[0, action_dim)``.
        """
        if np.random.rand() <= self.epsilon:
            return int(np.random.randint(self.action_dim))
        state = torch.as_tensor(state, dtype=torch.float32)
        # Action selection needs no gradients, so skip building an autograd graph.
        with torch.no_grad():
            mu, sigma = self.model(state)
            # Draw one plausible Q-value per action and act greedily on the draw.
            q_values = Normal(mu, sigma).sample()
        return torch.argmax(q_values).item()

    def learn(self, state, action, reward, next_state, done):
        """Take one gradient step on a single transition.

        Args:
            state: Array-like of length ``state_dim`` observed before acting.
            action: Integer index of the action that was taken.
            reward: Scalar reward received.
            next_state: Array-like of length ``state_dim`` observed after acting.
            done: Truthy if the episode ended, in which case the target does not
                bootstrap from ``next_state``.
        """
        # Add a leading batch dimension so the network outputs are
        # (1, action_dim) and ``gather(1, ...)`` indexes the action axis.
        # Without it, gather(1, ...) on the 1-D output raises IndexError.
        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        next_state = torch.as_tensor(next_state, dtype=torch.float32).unsqueeze(0)
        action = torch.tensor([[int(action)]], dtype=torch.long)
        reward = torch.tensor([float(reward)], dtype=torch.float32)
        done = torch.tensor([float(done)], dtype=torch.float32)

        # The bootstrapped target is treated as a constant: no gradient flows
        # through it.
        with torch.no_grad():
            next_mu, next_sigma = self.model(next_state)
            next_q_values = Normal(next_mu, next_sigma).sample()
            target_q_value = (
                reward + (1 - done) * self.gamma * next_q_values.max(dim=1).values
            )

        # rsample() (reparameterised) keeps the graph to mu and sigma. sample()
        # returns a detached tensor, and loss.backward() would then fail.
        mu, sigma = self.model(state)
        q_values = Normal(mu, sigma).rsample()
        expected_q_value = q_values.gather(1, action)

        # NOTE(review): the expected MSE of a reparameterised sample is
        # (mu - target)^2 + sigma^2, so this loss steadily shrinks sigma. A
        # Gaussian NLL loss would be needed for sigma to track uncertainty.
        loss = F.mse_loss(expected_q_value, target_q_value.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay exploration, but never below epsilon_min. Also leave an epsilon
        # that was already at or below the floor unchanged.
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

# --- Example usage with a toy environment ----------------------------------
# A "walk right" chain: the state is a one-hot vector, and every step moves the
# hot index one position to the right regardless of the action chosen. The
# episode ends (with reward 1) once the last position is reached.
state_dim = 10
action_dim = 4
agent = BayesianDQNAgent(state_dim, action_dim)


def _dummy_env_step(current):
    """Advance the toy chain by one position; return (next_state, reward, done).

    For the all-zero initial state, ``argmax`` of an all-False mask is 0, so the
    first step lands on index 1.
    """
    hot = (current == 1).argmax()
    following = np.zeros(state_dim)
    following[hot + 1] = 1
    finished = following[-1] == 1
    return following, (1 if finished else 0), finished


# Swap ``_dummy_env_step(state)`` for your environment's step function, e.g.
# ``next_state, reward, done = env.step(action)``.
for episode in range(1000):
    state = np.zeros(state_dim)  # Initial state
    done = False
    total_reward = 0

    while True:
        action = agent.act(state)
        next_state, reward, done = _dummy_env_step(state)

        agent.learn(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward
        if done:
            break

    print(f"Episode: {episode + 1}, Total Reward: {total_reward}")


ModuleNotFoundError: No module named 'torch'