# World Model

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
import numpy as np
import pandas as pd
import flappy_bird_gymnasium
import matplotlib.pyplot as plt
import random
import gymnasium as gym

## Define Model

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import os

# ----- World Model -----
class WorldModel(nn.Module):
    def __init__(self, state_dim=12, action_dim=1, hidden_dim=128):
        super(WorldModel, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim + 1)  # next_state + reward
        )

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        return self.model(x)

# ----- Train World Model -----
def train_world_model(model, data, epochs=100, lr=1e-3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    for epoch in range(epochs):
        print(epoch)
        total_loss = 0
        for (s, a, next_s, r) in data:
            s_tensor = torch.tensor(s, dtype=torch.float32).unsqueeze(0)
            a_tensor = torch.tensor([a], dtype=torch.float32).unsqueeze(0)
            next_s = np.append(next_s, r)
            target = torch.tensor(next_s, dtype=torch.float32).unsqueeze(0)
            output = model(s_tensor, a_tensor)

            loss = loss_fn(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {total_loss:.4f}")

# ----- Save & Load -----
def save_model(model, path="saved_policies/world_model.pth"):
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

def load_model(model, path="saved_policies/world_model.pth"):
    model.load_state_dict(torch.load(path))
    model.eval()
    print(f"Model loaded from {path}")
    return model

# ----- Simulated Controller Example -----
def simulate(model, start_state, steps=10):
    state = torch.tensor(start_state, dtype=torch.float32).unsqueeze(0)

    done = False
    for _ in range(steps):
        action = torch.tensor([[0.0]])  # Try both 0.0 and 1.0 for exploration
        prediction = model(state, action)
        next_state = prediction[0, :-1].detach().numpy()
        reward = prediction[0, -1].item()
        print(f"Simulated Reward: {reward:.3f}")

        if reward < -1: 
            done = True
            break

        state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)

    return next_state, reward, done, False, {}


def select_best_action(model, current_state):
    model.eval()
    state_tensor = torch.tensor(current_state, dtype=torch.float32).unsqueeze(0)

    best_action = None
    best_score = -float('inf')

    for action_val in [0.0, 1.0]:  # Explore both actions
        action_tensor = torch.tensor([[action_val]], dtype=torch.float32)
        prediction = model(state_tensor, action_tensor)
        next_state_pred = prediction[0, :-1].detach().numpy()
        reward_pred = prediction[0, -1].item()

        # Simple scoring: reward only (can be replaced with a value function)
        score = reward_pred  

        if score > best_score:
            best_score = score
            best_action = int(action_val)

    return best_action, best_score

## Train Virtual World

In [None]:
env = gym.make("FlappyBird-v0", render_mode=None, use_lidar=False)

sample_size = 100_0

dummy_data = []


total_reward = 0
for i in range(sample_size):
    done = False
    
    obs, _ = env.reset()

    while not done:

        action = env.action_space.sample()

        next_obs, reward, done, truncated, info = env.step(action)
        dummy_data.append((obs, action, next_obs, reward))

        obs = next_obs

    if i % 100 == 0:
        print(f"Sample {i}, Avg reward (past 100): {total_reward/100:.2f}")
        total_reward = 0


env.close()

print(len(dummy_data))

model = WorldModel()
train_world_model(model, dummy_data, epochs=50)
save_model(model)

1829
0
Epoch 0, Loss: 33.3325
1
2
3
4
5
6
7
8
9
10
Epoch 10, Loss: 13.4931
11
12
13
14
15
16
17
18
19
20
Epoch 20, Loss: 11.3329
21
22
23
24
25
26
27
28
29
30
Epoch 30, Loss: 9.8330
31
32
33
34
35
36
37
38
39
40
Epoch 40, Loss: 9.1914
41
42
43
44
45
46
47
48
49
Model saved to saved_policies/world_model.pth


## Let the bot play the game

In [None]:
env = gym.make("FlappyBird-v0", render_mode="human", use_lidar=False)

# Load and simulate
loaded_model = WorldModel()
model = load_model(loaded_model)
model.eval()


# Use the trained model
obs, _ = env.reset()

done = False
while not done:

    action, score = select_best_action(model, obs)
    # print(action)

    obs, reward, done, truncated, info = env.step(action)
    # print(f"Action {action} gave {reward} reward...")
    # print(obs, reward, done, truncated, info)

env.close()


Model loaded from saved_policies/world_model.pth
0
1
0
0
1
0
0
1
0
0
1
0
0
0
1
0
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
