In [None]:
import torch
import torch.nn as nn 
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
import numpy as np
import gymnasium as gym
import gymnasium_robotics
import multiprocessing
import pickle

NUM_EPISODES = 5000
MAX_STEPS = 300 

gym.register_envs(gymnasium_robotics)

env = gym.make(id='FrankaKitchen-v1', 
               max_episode_steps=MAX_STEPS,
               tasks_to_complete=['kettle'])

In [45]:
total_capacity = NUM_EPISODES * MAX_STEPS
X_data = np.zeros((total_capacity, 68), dtype=np.float32)
y_data = np.zeros((total_capacity, 59), dtype=np.float32)

In [None]:
def collect_data_worker(num_episodes_for_this_worker):
    env = gym.make('FrankaKitchen-v1', tasks_to_complete=["kettle"], max_episode_steps=MAX_STEPS)
    
    local_X = []
    local_y = []
    
    for episode in range(num_episodes_for_this_worker):
        obs_dict, _ = env.reset()
        state = obs_dict['observation']
        
        for step in range(MAX_STEPS):
            action = env.action_space.sample()
            
            next_obs_dict, _, term, trunc, _ = env.step(action)
            next_state = next_obs_dict['observation']
            
            # X: State + Action
            # y: Delta (Next - Current)
            local_X.append(np.concatenate([state, action]))
            local_y.append(next_state - state)
            
            state = next_state
            
            if term or trunc:
                print(f"Episode {episode} ended at step {step}. Terminated={term}, Truncated={trunc}")
                break
                
    env.close()

    return np.array(local_X, dtype=np.float32), np.array(local_y, dtype=np.float32)

In [50]:
num_workers = 16

print(f"Starting collection on {num_workers} CPU cores...")

chunk_size = NUM_EPISODES // num_workers
tasks = [chunk_size] * num_workers

tasks[-1] += NUM_EPISODES % num_workers

print(f"Work distribution: {tasks}")

with multiprocessing.Pool(processes=num_workers) as pool:
    results = pool.map(collect_data_worker, tasks)

Starting collection on 16 CPU cores...
Work distribution: [312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 320]
Episode 0 ended at step 114. Terminated=True, Truncated=False
Episode 1 ended at step 112. Terminated=True, Truncated=False
Episode 0 ended at step 299. Terminated=False, Truncated=True
Episode 0 ended at step 299. Terminated=False, Truncated=True
Episode 0 ended at step 299. Terminated=False, Truncated=True
Episode 0 ended at step 299. Terminated=False, Truncated=True
Episode 0 ended at step 299. Terminated=False, Truncated=True
Episode 0 ended at step 299. Terminated=False, Truncated=True
Episode 0 ended at step 299. Terminated=False, Truncated=True
Episode 0 ended at step 299. Terminated=False, Truncated=True
Episode 0 ended at step 299. Terminated=False, Truncated=TrueEpisode 0 ended at step 299. Terminated=False, Truncated=True

Episode 0 ended at step 299. Terminated=False, Truncated=True
Episode 0 ended at step 299. Terminated=False, Truncate

In [None]:
all_X = np.concatenate([r[0] for r in results])
all_y = np.concatenate([r[1] for r in results])

print(f"Total Data Collected: X shape {all_X.shape}, y shape {all_y.shape}")

data_to_save = {
    "X": all_X,
    "y": all_y
}

with open("data.pkl", "wb") as f:
    pickle.dump(data_to_save, f)

Total Data Collected: X shape (1464561, 68), y shape (1464561, 59)


In [None]:
class DynamicsModel(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Input: Current State + Action
        # Output: Predicted change in state (Delta)
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, state_dim)
        )

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        delta = self.net(x)
        return state + delta

In [None]:
DEVICE = "cuda"
LEARNING_RATE = 1e-4
EPOCHS = 30

x_tensor = torch.from_numpy(all_X)
y_tensor = torch.from_numpy(all_y)

full_dataset = TensorDataset(x_tensor, y_tensor)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

dataset = TensorDataset(x_tensor, y_tensor)

train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,     
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=256,
    shuffle=False,     
    num_workers=4,
    pin_memory=True
)

model = DynamicsModel(59, 9).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

In [64]:
for epoch in range(EPOCHS):
    model.train() 
    running_train_loss = 0.0
    
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(DEVICE, non_blocking=True)
        y_batch = y_batch.to(DEVICE, non_blocking=True)

        state_batch = x_batch[:, :59] 
        action_batch = x_batch[:, 59:]

        optimizer.zero_grad()

        preds = model(state_batch, action_batch) 
        
        loss = loss_fn(preds, y_batch)
    
        loss.backward()
    
        optimizer.step()
        
        running_train_loss += loss.item()
        
    avg_train_loss = running_train_loss / len(train_loader)

    model.eval()
    running_val_loss = 0.0
    
    with torch.no_grad():
        for x_val, y_val in val_loader:
            x_val = x_val.to(DEVICE, non_blocking=True)
            y_val = y_val.to(DEVICE, non_blocking=True)
            
            state_val = x_val[:, :59]
            action_val = x_val[:, 59:]
            
            val_preds = model(state_val, action_val)
            val_loss = loss_fn(val_preds, y_val)
            
            running_val_loss += val_loss.item()
            
    avg_val_loss = running_val_loss / len(val_loader)
    
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f}")

Epoch 1/30 | Train Loss: 0.002266 | Val Loss: 0.002232
Epoch 2/30 | Train Loss: 0.002210 | Val Loss: 0.002189
Epoch 3/30 | Train Loss: 0.002163 | Val Loss: 0.002133
Epoch 4/30 | Train Loss: 0.002125 | Val Loss: 0.002110
Epoch 5/30 | Train Loss: 0.002096 | Val Loss: 0.002082
Epoch 6/30 | Train Loss: 0.002076 | Val Loss: 0.002082
Epoch 7/30 | Train Loss: 0.002063 | Val Loss: 0.002044
Epoch 8/30 | Train Loss: 0.002053 | Val Loss: 0.002057
Epoch 9/30 | Train Loss: 0.002043 | Val Loss: 0.002048
Epoch 10/30 | Train Loss: 0.002035 | Val Loss: 0.002029
Epoch 11/30 | Train Loss: 0.002026 | Val Loss: 0.002027
Epoch 12/30 | Train Loss: 0.002017 | Val Loss: 0.001995
Epoch 13/30 | Train Loss: 0.002009 | Val Loss: 0.002012
Epoch 14/30 | Train Loss: 0.001998 | Val Loss: 0.002003
Epoch 15/30 | Train Loss: 0.001988 | Val Loss: 0.001984
Epoch 16/30 | Train Loss: 0.001976 | Val Loss: 0.001963
Epoch 17/30 | Train Loss: 0.001961 | Val Loss: 0.001982
Epoch 18/30 | Train Loss: 0.001946 | Val Loss: 0.001934
E