In [7]:
import os

# Enable MPS fallback for unsupported operations.
# PyTorch reads PYTORCH_ENABLE_MPS_FALLBACK when it loads its MPS backend
# (i.e. during `import torch`), so it must be set *before* torch is imported;
# assigning it afterwards has no effect. In a notebook this also means no
# earlier cell may have imported torch already — restart the kernel if so.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import json

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class SequenceDataset(Dataset):
    """Fixed-length observation windows paired with the action at the window's last step.

    Sample ``k`` is ``(obs_seq, action_target)``. ``obs_seq`` holds the
    ``context_length`` consecutive observations ending at some step ``t``, and
    ``action_target`` is ``actions[t]``. Windows never cross an episode boundary.
    A truthy ``dones[i]`` ends the current episode at step ``i``, and the next
    episode starts at ``i + 1``. Episodes shorter than ``context_length``
    therefore contribute no samples.

    Args:
        data: Mapping with ``"observations"``, ``"actions"`` and ``"dones"``
            sequences, one entry per time step. Observations and actions are
            converted to float32 arrays.
        context_length: Number of observations per window; must be >= 1.

    Raises:
        ValueError: If ``context_length < 1`` or the three sequences do not
            all have the same length.
    """

    def __init__(self, data, context_length=5):
        if context_length < 1:
            raise ValueError(f"context_length must be >= 1, got {context_length}")

        self.observations = np.array(data["observations"], dtype=np.float32)
        self.actions = np.array(data["actions"], dtype=np.float32)
        self.dones = np.array(data["dones"])
        self.context_length = context_length

        # A length mismatch would otherwise yield silently truncated windows
        # or an IndexError deep inside the DataLoader.
        n_obs, n_act, n_done = len(self.observations), len(self.actions), len(self.dones)
        if not (n_obs == n_act == n_done):
            raise ValueError(
                "observations, actions and dones must have the same length, "
                f"got {n_obs}, {n_act}, {n_done}"
            )

        # Step t is a valid window end when [t - context_length + 1, t] lies
        # entirely inside the current episode (which began at episode_start).
        self.valid_indices = []
        episode_start = 0
        for t, done in enumerate(self.dones):
            if t - episode_start + 1 >= context_length:
                self.valid_indices.append(t)
            if done:
                episode_start = t + 1

    def __len__(self):
        """Number of valid windows across all episodes."""
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return ``(obs_seq, action_target)`` as float32 tensors for sample ``idx``.

        ``obs_seq`` has ``context_length`` rows along its first dimension.

        Raises:
            IndexError: If ``idx`` is past the end of the dataset.
        """
        if idx >= len(self.valid_indices):
            raise IndexError(f"Index {idx} out of bounds for dataset of size {len(self)}")

        end_idx = self.valid_indices[idx]
        start_idx = end_idx - self.context_length + 1

        obs_seq = self.observations[start_idx:end_idx + 1]
        action_target = self.actions[end_idx]

        # torch.tensor copies, so returned tensors never alias the dataset arrays.
        return torch.tensor(obs_seq), torch.tensor(action_target)

class TransformerPolicy(nn.Module):
    """Transformer-encoder policy mapping an observation window to a single action.

    The input ``x`` is ``(batch, seq_len, input_dim)`` (the encoder uses
    ``batch_first=True``). Each step is linearly embedded and passed through
    ``num_layers`` encoder layers. The encoding at the *last* position then goes
    through an MLP head ending in ``tanh``. That output is multiplied by
    ``action_scale``, so every action component lies in
    ``[-action_scale, action_scale]``.

    Args:
        input_dim: Size of each observation vector.
        action_dim: Size of the predicted action vector.
        hidden_dim: Model width; must be divisible by ``nhead``.
        num_layers: Number of encoder layers.
        nhead: Attention heads per encoder layer.
        action_scale: Multiplier applied to the ``tanh`` output.
        max_seq_len: If given, add learned positional embeddings for sequences
            up to this length. With the default ``None`` the model has no
            positional information. Self-attention is then order-agnostic, so
            the order of the earlier steps in the window is invisible and only
            the last position is distinguished (because it is the one read out).
            The default is kept for compatibility with checkpoints saved without
            positional embeddings.

    Raises:
        ValueError: If ``max_seq_len`` is given and < 1, or (in ``forward``) if
            the input is longer than ``max_seq_len``.
    """

    def __init__(self, input_dim=6, action_dim=3, hidden_dim=128, num_layers=3,
                 nhead=4, action_scale=5.0, max_seq_len=None):
        super().__init__()
        self.action_scale = action_scale
        self.embedding = nn.Linear(input_dim, hidden_dim)

        # Created right after `embedding` and only when requested. The default
        # model therefore keeps the same parameter order, the same random init
        # under a fixed seed, and the same state_dict keys as before.
        if max_seq_len is None:
            self.register_parameter("pos_embedding", None)
        else:
            if max_seq_len < 1:
                raise ValueError(f"max_seq_len must be >= 1, got {max_seq_len}")
            self.pos_embedding = nn.Parameter(torch.zeros(1, max_seq_len, hidden_dim))
            nn.init.normal_(self.pos_embedding, std=0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 4,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.action_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, action_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map ``(batch, seq_len, input_dim)`` observations to ``(batch, action_dim)`` actions."""
        x = self.embedding(x)
        if self.pos_embedding is not None:
            seq_len = x.size(1)
            if seq_len > self.pos_embedding.size(1):
                raise ValueError(
                    f"sequence length {seq_len} exceeds max_seq_len {self.pos_embedding.size(1)}"
                )
            x = x + self.pos_embedding[:, :seq_len]
        x = self.transformer(x)
        last_step = x[:, -1, :]  # predict from the most recent step's encoding
        return self.action_scale * self.action_head(last_step)

def _mps_device_or_none():
    """Return the MPS device, or print why MPS is unavailable and return None."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device.")
    return None


def _train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``; return the per-sample mean loss.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    # Accumulate on the device so each step doesn't force a host sync via .item().
    loss_sum = torch.zeros((), device=device)
    n_samples = 0
    for obs_seq, actions in dataloader:
        obs_seq = obs_seq.to(device)
        actions = actions.to(device)

        optimizer.zero_grad()
        loss = criterion(model(obs_seq), actions)
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a short final batch is not
        # over-represented in the epoch average.
        batch_size = obs_seq.size(0)
        loss_sum += loss.detach() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return loss_sum.item() / n_samples


def train(data_path="expert_demonstrations.json",
          save_path="transformer_policy_mps.pth",
          epochs=100,
          batch_size=128,
          lr=3e-4,
          context_length=5,
          seed=None):
    """Fit a TransformerPolicy to expert demonstrations on the MPS device.

    The policy is trained by regressing the actions with an MSE loss and the
    AdamW optimizer. After each epoch the per-sample mean training loss is
    printed. The final ``state_dict`` is saved to ``save_path``. If MPS is
    unavailable, the reason is printed and the function returns without
    training.

    Args:
        data_path: JSON file with ``"observations"``, ``"actions"`` and
            ``"dones"`` lists (consumed by ``SequenceDataset``).
        save_path: Destination for ``torch.save(model.state_dict())``.
        epochs: Number of passes over the dataset.
        batch_size: Windows per optimisation step.
        lr: AdamW learning rate.
        context_length: Observations per input window.
        seed: If not None, passed to ``torch.manual_seed`` before training.
            This fixes the model init and the shuffle order.

    Raises:
        ValueError: If the data yields no training windows (for example, every
            episode is shorter than ``context_length``).
    """
    device = _mps_device_or_none()
    if device is None:
        return

    if seed is not None:
        torch.manual_seed(seed)

    with open(data_path, "r") as f:
        data = json.load(f)

    dataset = SequenceDataset(data, context_length=context_length)
    if len(dataset) == 0:
        raise ValueError(
            f"No training windows in {data_path!r}: every episode is shorter "
            f"than context_length={context_length}"
        )
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # Required for MPS compatibility
    )

    model = TransformerPolicy().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        epoch_loss = _train_one_epoch(model, dataloader, optimizer, criterion, device)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

    torch.save(model.state_dict(), save_path)

In [8]:
# Launch training. Inside a Jupyter kernel `__name__` is "__main__", so running
# this cell always starts training (as the log below shows). The guard only
# prevents training if this code is later moved into an importable module.
if __name__ == "__main__":
    train()

Epoch 1, Loss: 0.1965
Epoch 2, Loss: 0.0206
Epoch 3, Loss: 0.0136
Epoch 4, Loss: 0.0109
Epoch 5, Loss: 0.0089
Epoch 6, Loss: 0.0081
Epoch 7, Loss: 0.0073
Epoch 8, Loss: 0.0065
Epoch 9, Loss: 0.0060
Epoch 10, Loss: 0.0057
Epoch 11, Loss: 0.0051
Epoch 12, Loss: 0.0047
Epoch 13, Loss: 0.0047
Epoch 14, Loss: 0.0041
Epoch 15, Loss: 0.0040
Epoch 16, Loss: 0.0039
Epoch 17, Loss: 0.0034
Epoch 18, Loss: 0.0033
Epoch 19, Loss: 0.0032
Epoch 20, Loss: 0.0030
Epoch 21, Loss: 0.0028
Epoch 22, Loss: 0.0025
Epoch 23, Loss: 0.0025
Epoch 24, Loss: 0.0025
Epoch 25, Loss: 0.0023
Epoch 26, Loss: 0.0022
Epoch 27, Loss: 0.0021
Epoch 28, Loss: 0.0019
Epoch 29, Loss: 0.0019
Epoch 30, Loss: 0.0018
Epoch 31, Loss: 0.0017
Epoch 32, Loss: 0.0016
Epoch 33, Loss: 0.0016
Epoch 34, Loss: 0.0016
Epoch 35, Loss: 0.0015
Epoch 36, Loss: 0.0015
Epoch 37, Loss: 0.0014
Epoch 38, Loss: 0.0013
Epoch 39, Loss: 0.0012
Epoch 40, Loss: 0.0012
Epoch 41, Loss: 0.0012
Epoch 42, Loss: 0.0012
Epoch 43, Loss: 0.0012
Epoch 44, Loss: 0.00