In [13]:
from melee_dataset import MeleeDataset
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
from torch.distributions import Categorical
from torch.distributions import Bernoulli,Normal
import matplotlib.pyplot as plt
from PolicyNet import PolicyNet
import torch.nn.functional as F

In [9]:
train_dataset = MeleeDataset(data_path="data/train_mini_515")
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

In [10]:
print(f"obs_dim: {train_dataset[0][0].shape}")
print(f"act_dim: {train_dataset[0][1].shape}")

obs_dim: torch.Size([54])
act_dim: torch.Size([17])


In [16]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
obs_dim = train_dataset[0][0].shape[0]
act_dim = train_dataset[0][1].shape[0]
policy  = PolicyNet(obs_dim, act_dim).to(device)
opt     = optim.Adam(policy.parameters(), lr=1e-3)

Using device: cpu


In [None]:
losses = []
for epoch in range(6):
    total_loss = 0.0
    for states, actions in train_loader:
        states = states.to(device)       # [B, obs_dim]
        actions = actions.to(device)     # [B] integers in [0…act_dim-1]

        mu, logstd = policy(states)               # mu: [B,17], logstd: [17]
        logstd = torch.clamp(logstd, min=-20, max=2)
        std    = F.softplus(logstd) + 1e-6  

        dist = Normal(loc=mu, scale=std)
        
        # log_prob has shape [B], one log‐prob per sample
        logp    = dist.log_prob(actions).sum(dim=-1)

        # 4) negative log-likelihood
        loss    = -logp.mean()

        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item() * states.size(0)
        losses.append(loss.item())
    avg_loss = total_loss / len(train_dataset)
    print(f"Epoch {epoch+1} — Loss: {avg_loss:.4f}")

Epoch 1 — Loss: -136.6079
Epoch 2 — Loss: -136.6132
Epoch 3 — Loss: -136.5927


In [None]:
# Save the trained model
torch.save(policy.state_dict(), "trained_policy_distribution_combined.pth")
print("Model saved to trained_policy.pth")


Model saved to trained_policy.pth


: 

In [13]:


plt.yscale('log')
plt.plot(range(len(losses)), losses)

[<matplotlib.lines.Line2D at 0x29cf0c65960>]

: 