In [1]:
import json
import sys
import os
# Add the parent directory to sys.path
sys.path.append(os.path.abspath('..'))
from src.tile_definitions import TILE_MAPPING

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

# Dataset Class

In [2]:
class Level1GameDataset(Dataset):
    def __init__(self, json_file):
        with open(json_file, "r") as f:
            self.data = json.load(f)
            
        print(f"Loaded {len(self.data)} gameplay samples from {json_file}")
        
        # Action to integer mapping
        self.action_mapping = {"UP": 0, "DOWN": 1, "LEFT": 2, "RIGHT": 3}

        # Create reverse mapping (tile_type -> ID) from TILE_MAPPING
        self.tile_type_to_id = {}
        for tile_id, (tile_type, _, _, _, _) in TILE_MAPPING.items():
            self.tile_type_to_id[tile_type] = tile_id
            
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        state = sample["state"]
        action = self.action_mapping[sample["action"]]
        
        # Extract state components
        position = torch.tensor(state["position"], dtype=torch.float32)
        chips_collected = torch.tensor([state["player_collected_chips"]], dtype=torch.float32)
        total_chips_collected = torch.tensor([state["total_collected_chips"]], dtype=torch.float32)
        socket_unlocked = torch.tensor([int(state["socket_unlocked"])], dtype=torch.float32)
        nearest_chip = torch.tensor(state["nearest_chip"], dtype=torch.float32)
        exit_location = torch.tensor(state["exit_position"], dtype=torch.float32)
        
        # Process the full grid
        full_grid = []
        for row in state["full_grid"]:
            processed_row = []
            for tile_type in row:
                # Map to integer using the updated tile definitions
                processed_row.append(self.tile_type_to_id.get(tile_type, 1))
            full_grid.append(processed_row)
        
        full_grid_tensor = torch.tensor(full_grid, dtype=torch.float32)
        
        local_grid = []
        for row in state["local_grid"]:
            processed_row = []
            for tile_type in row:
                # Map to integer using the updated tile definitions
                processed_row.append(self.tile_type_to_id.get(tile_type, 1))
            local_grid.append(processed_row)
        
        local_grid_tensor = torch.tensor(local_grid, dtype=torch.float32)
        
        # Additional state information
        alive = torch.tensor([float(state.get("alive", True))], dtype=torch.float32)
        remaining_chips = torch.tensor([state.get("remaining_chips", 0)], dtype=torch.float32)
        other_player_pos = torch.tensor(state.get("other_player_position", [-1, -1]), dtype=torch.float32)
        
        # Concatenate all state information into a single vector
        state_vector = torch.cat([
            position,
            chips_collected, 
            total_chips_collected,
            socket_unlocked,
            nearest_chip,
            exit_location,
            full_grid_tensor.flatten(),
            local_grid_tensor.flatten(),
            alive,
            remaining_chips,
            other_player_pos
        ])
        
        return state_vector, torch.tensor(action, dtype=torch.long)

## Load Dataset and Test

In [38]:
dataset = Level1GameDataset("../data/human_play_data_level0.json")
sample_vector, sample_action = dataset[0]  
print(f"🔢 Sample Vector: {sample_vector.size()}")
print(f"📊 Loaded {len(dataset)} samples.")

Loaded 10052 gameplay samples from ../data/human_play_data_level0.json
🔢 Sample Vector: torch.Size([191])
📊 Loaded 10052 samples.


In [39]:
# Split into train and validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Behavior Cloning Model

In [5]:
import torch.nn as nn
import torch.optim as optim

In [55]:
# FNC Network for behavior cloning
class BehaviorCloningModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(BehaviorCloningModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),                    
            nn.Linear(128, output_size)
        )

    def forward(self, x):
        return self.fc(x)

In [56]:
input_size = len(dataset[0][0])  # Size of State Vector
output_size = 4  # Possible Actions (UP, DOWN, LEFT, RIGHT)
model = BehaviorCloningModel(input_size, output_size)
print(f"input_size {input_size}, output_size: {output_size}")

input_size 191, output_size: 4


## Training Loop

In [57]:
training_config = {
    "epochs": 300,
    "model_name": "BehaviorCloningModel",
    "input_size": input_size,
    "output_size": output_size,
    "optimizer": "Adam",
    "learning_rate": 0.001,
    "batch_size": 32,
    "criterion": "CrossEntropyLoss",
    "activation": "ReLU",
    "scheduler": "ReduceLROnPlateau"
}

In [58]:
import wandb
wandb.login(key="294ac5de6babc54da53b9aadb344b3bb173b314d")
# change name for each run
wandb.init(project="bc_surrogate_partner_lv1", name="level1_run_3.8", config=training_config)
wandb.config.update(training_config)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/neo/.netrc


In [59]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=training_config["learning_rate"])
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, verbose=True)



In [60]:
# Training loop
for epoch in range(training_config["epochs"]):
    # Training phase
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    for state_vectors, actions in train_loader:
        optimizer.zero_grad()
        outputs = model(state_vectors)
        loss = criterion(outputs, actions)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        train_total += actions.size(0)
        train_correct += (predicted == actions).sum().item()
    
    train_loss = train_loss / len(train_loader)
    train_accuracy = 100 * train_correct / train_total
    
    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    
    with torch.no_grad():
        for state_vectors, actions in val_loader:
            outputs = model(state_vectors)
            loss = criterion(outputs, actions)
            
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            val_total += actions.size(0)
            val_correct += (predicted == actions).sum().item()
    
    val_loss = val_loss / len(val_loader)
    val_accuracy = 100 * val_correct / val_total
    
    # Update learning rate
    scheduler.step(val_loss)
    
    # Log metrics
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_accuracy": train_accuracy,
        "val_loss": val_loss,
        "val_accuracy": val_accuracy,
        "learning_rate": optimizer.param_groups[0]["lr"]
    })
    
    # Print progress
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{training_config['epochs']}, "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%, "
              f"LR: {optimizer.param_groups[0]['lr']:.6f}")

# Close wandb
wandb.finish()

Epoch 10/300, Train Loss: 0.7237, Train Acc: 71.16%, Val Loss: 0.7280, Val Acc: 71.76%, LR: 0.001000
Epoch 20/300, Train Loss: 0.4443, Train Acc: 82.75%, Val Loss: 0.5123, Val Acc: 82.65%, LR: 0.001000
Epoch 30/300, Train Loss: 0.3816, Train Acc: 85.19%, Val Loss: 0.4793, Val Acc: 85.83%, LR: 0.001000
Epoch 40/300, Train Loss: 0.3716, Train Acc: 85.39%, Val Loss: 0.4762, Val Acc: 86.47%, LR: 0.001000
Epoch 50/300, Train Loss: 0.3371, Train Acc: 86.85%, Val Loss: 0.7341, Val Acc: 81.10%, LR: 0.001000
Epoch 60/300, Train Loss: 0.2785, Train Acc: 88.72%, Val Loss: 0.5645, Val Acc: 87.02%, LR: 0.000500
Epoch 70/300, Train Loss: 0.2697, Train Acc: 88.63%, Val Loss: 0.5673, Val Acc: 87.92%, LR: 0.000500
Epoch 80/300, Train Loss: 0.2360, Train Acc: 89.89%, Val Loss: 0.6508, Val Acc: 87.37%, LR: 0.000250
Epoch 90/300, Train Loss: 0.2324, Train Acc: 89.80%, Val Loss: 0.6703, Val Acc: 87.77%, LR: 0.000250
Epoch 100/300, Train Loss: 0.2132, Train Acc: 90.55%, Val Loss: 0.7465, Val Acc: 87.42%, LR

0,1
epoch,▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
learning_rate,█████████▄▄▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▅▆▇▇▇▇▇▇███████████████████████████████
train_loss,█▇▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▅▇▇▇█▇█████████████████████████████████
val_loss,█▂▁▁▁▁▁▂▂▂▂▃▃▄▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇

0,1
epoch,300.0
learning_rate,0.0
train_accuracy,92.02835
train_loss,0.17822
val_accuracy,86.72302
val_loss,0.97292


In [61]:
# Model save directory
model_path = "../model/lv1_bc_model_3.8.pth"
# Save model
torch.save(model.state_dict(), model_path)
print("Model saved to " + model_path)

Model saved to ../model/lv1_bc_model_3.8.pth


In [24]:
def predict_action(model, state):
    model.eval()
    with torch.no_grad():
        state_vector = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        output = model(state_vector)
        action_idx = torch.argmax(output).item()
    
    action_mapping = {0: "UP", 1: "DOWN", 2: "LEFT", 3: "RIGHT"}
    return action_mapping[action_idx]

test_state, _ = dataset[5]  # Check the first sample in the dataset
predicted_action = predict_action(model, test_state)
print(f"AI Predicted Action: {predicted_action}")

AI Predicted Action: LEFT


  state_vector = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
