<a href="https://colab.research.google.com/github/hanyu-xiao/STA410xiaoCourseProject/blob/main/STA410xiaoProject.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
pip install torch torchbnn numpy matplotlib

Collecting torchbnn
  Downloading torchbnn-1.2-py3-none-any.whl.metadata (7.1 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting n

In [8]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchbnn as bnn
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

class Maze:
    def __init__(self, size=10):
        self.size = size
        self.grid = np.zeros((size, size))
        self.goal = (size-1, size-1)
        self.reset()

    def reset(self):
        self.agent_pos = (0, 0)
        return self.agent_pos

    def step(self, action):
        x, y = self.agent_pos
        if action == 0 and x > 0:           x -= 1
        elif action == 1 and y < self.size-1: y += 1
        elif action == 2 and x < self.size-1: x += 1
        elif action == 3 and y > 0:           y -= 1

        self.agent_pos = (x, y)
        done = (self.agent_pos == self.goal)
        reward = 1 if done else -0.01
        return self.agent_pos, reward, done

class BayesianDQN(torch.nn.Module):
    def __init__(self, state_dim=2, action_dim=4):
        super().__init__()
        self.fc1 = bnn.BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=state_dim, out_features=32)
        self.fc2 = bnn.BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=32, out_features=action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

def train_and_animate(maze, episodes=500, early_stop=True):
    model = BayesianDQN()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = bnn.BKLLoss(reduction='mean', last_layer_only=False)

    # Store uncertainty maps and metrics
    uncertainty_history = []
    episode_metrics = []
    best_steps = float('inf')
    patience = 20
    no_improve = 0

    for ep in range(episodes):
        state = maze.reset()
        done = False
        steps = 0

        while not done:
            state_tensor = torch.FloatTensor(state)
            q_values = model(state_tensor)
            action = torch.argmax(q_values).item()
            next_state, reward, done = maze.step(action)

            target = reward + 0.99 * torch.max(model(torch.FloatTensor(next_state)))
            loss = torch.nn.MSELoss()(q_values[action], target) + criterion(model)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            steps += 1
            state = next_state

        # Early stopping logic
        if early_stop:
            if steps < best_steps:
                best_steps = steps
                no_improve = 0
            else:
                no_improve += 1

            if no_improve >= patience:
                print(f"Early stopping at episode {ep} (no improvement for {patience} episodes)")
                break

        # Store uncertainty map every N episodes
        if ep % 5 == 0:
            uncertainty_map = generate_uncertainty_map(model, maze)
            uncertainty_history.append(uncertainty_map)
            episode_metrics.append((ep, steps, best_steps))
            print(f"Episode {ep}: Steps={steps}, Best={best_steps}")

    # Create and return animation
    return create_animation(uncertainty_history, episode_metrics, maze.size)

def generate_uncertainty_map(model, maze):
    uncertainty_map = np.zeros((maze.size, maze.size))
    for x in range(maze.size):
        for y in range(maze.size):
            state = torch.FloatTensor([x, y])
            q_samples = [model(state).detach().numpy() for _ in range(10)]
            uncertainty_map[x, y] = np.std(q_samples, axis=0).mean()
    return uncertainty_map

def create_animation(uncertainty_history, episode_metrics, maze_size):
    fig, ax = plt.subplots(figsize=(8, 6))
    plt.close()  # Prevents duplicate display

    # Create initial plot
    im = ax.imshow(uncertainty_history[0], cmap='hot', vmin=0, vmax=1)
    cbar = fig.colorbar(im, ax=ax, label='Uncertainty (Std Dev)')
    title = ax.set_title(f'Episode {episode_metrics[0][0]}\nSteps: {episode_metrics[0][1]}, Best: {episode_metrics[0][2]}')

    def update(frame):
        im.set_array(uncertainty_history[frame])
        title.set_text(f'Episode {episode_metrics[frame][0]}\nSteps: {episode_metrics[frame][1]}, Best: {episode_metrics[frame][2]}')
        return im, title

    ani = FuncAnimation(
        fig,
        update,
        frames=len(uncertainty_history),
        interval=300,
        blit=True
    )

    return HTML(ani.to_jshtml())

if __name__ == "__main__":
    maze = Maze(size=10)
    animation = train_and_animate(maze)
    display(animation)  # For Jupyter notebook display

Episode 0: Steps=173, Best=173
Episode 5: Steps=101, Best=101
Episode 10: Steps=840, Best=28
Episode 15: Steps=25, Best=25
Episode 20: Steps=259, Best=25
Episode 25: Steps=242, Best=25
Episode 30: Steps=130, Best=25
Early stopping at episode 35 (no improvement for 20 episodes)


In [6]:
ani

NameError: name 'ani' is not defined

In [None]:
# In scripts: Use ani.save('uncertainty_evolution.mp4')