In [2]:
from melee_dataset import MeleeDataset
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch

In [3]:
train_dataset = MeleeDataset(data_path="data/train_mini_515")
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

In [4]:
print(f"obs_dim: {train_dataset[0][0].shape}")
print(f"act_dim: {train_dataset[0][1].shape}")

obs_dim: torch.Size([54])
act_dim: torch.Size([17])


In [5]:
class PolicyNet(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.ReLU(),
            nn.Linear(64, 64),      nn.ReLU(),
            nn.Linear(64, act_dim)
        )

    def forward(self, x):
        return self.net(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
obs_dim = train_dataset[0][0].shape[0]
act_dim = train_dataset[0][1].shape[0]
policy  = PolicyNet(obs_dim, act_dim).to(device)
opt     = optim.Adam(policy.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()  

Using device: cpu


In [6]:
losses = []
for epoch in range(5):
    total_loss = 0.0
    for batch in tqdm(train_loader):
        batch_states = batch[0].to(device)
        batch_actions = batch[1].to(device)
        pred = policy(batch_states)
        loss = loss_fn(pred, batch_actions)
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item() * batch_states.size(0)
        losses.append(loss.item())
    avg_loss = total_loss / len(train_dataset)
    print(f"Epoch {epoch+1} — Loss: {avg_loss:.4f}")

100%|██████████| 47239/47239 [25:01<00:00, 31.45it/s]  


Epoch 1 — Loss: 0.0198


100%|██████████| 47239/47239 [04:23<00:00, 179.37it/s]


Epoch 2 — Loss: 0.0053


100%|██████████| 47239/47239 [05:01<00:00, 156.52it/s]


Epoch 3 — Loss: 0.0051


100%|██████████| 47239/47239 [05:33<00:00, 141.60it/s]


Epoch 4 — Loss: 0.0051


100%|██████████| 47239/47239 [05:39<00:00, 139.26it/s]

Epoch 5 — Loss: 0.0050





In [8]:
import matplotlib.pyplot as plt

plt.yscale('log')
plt.plot(range(len(losses)), losses)

[<matplotlib.lines.Line2D at 0x223952a3df0>]

: 

In [7]:
# Save the trained model
torch.save(policy.state_dict(), "trained_policy_pog.pth")
print("Model saved to trained_policy.pth")


Model saved to trained_policy.pth
