In [1]:
import time
import torch

from torch import nn, Tensor
from torch.utils.data import DataLoader

# flow_matching
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.path import AffineProbPath
from flow_matching.solver import Solver, ODESolver
from flow_matching.utils import ModelWrapper

# visualization
import matplotlib.pyplot as plt

from matplotlib import cm


# To avoide meshgrid warning
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module='torch')


In [2]:
if torch.cuda.is_available():
    device = 'cuda:0'
    print('Using gpu')
else:
    device = 'cpu'
    print('Using cpu.')
torch.manual_seed(42)

Using cpu.


<torch._C.Generator at 0x10eae7590>

In [3]:
class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return torch.sigmoid(x) * x


# TODO: need to resolve temporal locality problem maybe with a CNN later.
class MLP(nn.Module):
    def __init__(self, input_dim: int, time_dim: int = 1, hidden_dim: int = 128):
        super().__init__()

        self.input_dim = input_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim

        self.main = nn.Sequential(
            nn.Linear(input_dim + time_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        sz = x.size()
        x = x.reshape(-1, self.input_dim)
        t = t.reshape(-1, self.time_dim).float()

        t = t.reshape(-1, 1).expand(x.shape[0], 1)
        h = torch.cat([x, t], dim=1)
        output = self.main(h)

        return output.reshape(*sz)

In [60]:
def collate_fn(batch):
    observations = [torch.as_tensor(x.observations) for x in batch]
    actions = [torch.as_tensor(x.actions) for x in batch]
    rewards = [torch.as_tensor(x.rewards) for x in batch]
    terminations = [torch.as_tensor(x.terminations) for x in batch]
    truncations = [torch.as_tensor(x.truncations) for x in batch]
    episode_lengths = torch.tensor([len(x.actions) for x in batch], dtype=torch.long)

    return {
        "id": torch.Tensor([x.id for x in batch]),
        "observations": torch.nn.utils.rnn.pad_sequence(
            observations,
            batch_first=True
        ),
        "actions": torch.nn.utils.rnn.pad_sequence(
            actions,
            batch_first=True
        ),
        "rewards": torch.nn.utils.rnn.pad_sequence(
            rewards,
            batch_first=True
        ),
        "terminations": torch.nn.utils.rnn.pad_sequence(
            terminations,
            batch_first=True
        ),
        "truncations": torch.nn.utils.rnn.pad_sequence(
            truncations,
            batch_first=True
        ),
        "episode_lengths": episode_lengths
    }

In [69]:
def create_trajectory_chunks(batch, horizon):
    """
    Processes a padded batch to create fixed-size trajectory chunks.
    """
    batch_size = batch['observations'].shape[0]
    all_chunks = []

    for i in range(batch_size):
        # Get the data for one episode and its true length
        obs = batch['observations'][i]      # Shape: (max_len, 8)
        act = batch['actions'][i]          # Shape: (max_len-1, 2)
        length = batch['episode_lengths'][i]       # Scalar, e.g., 495

        # A single episode can produce multiple chunks
        # We slide a window of size 'horizon' over the valid part of the episode
        for start_idx in range(length - horizon + 1):
            end_idx = start_idx + horizon

            # Slice the observation and action sequences to get a chunk
            obs_chunk = obs[start_idx:end_idx] # Shape: (horizon, 8)
            act_chunk = act[start_idx:end_idx] # Shape: (horizon, 2)
            
            # Combine them into a single (horizon, 10) tensor
            chunk = torch.cat([obs_chunk, act_chunk], dim=-1)

            # Flatten the chunk to the final 1000-D vector and add to our list
            all_chunks.append(chunk.flatten())

    if not all_chunks:
        return None

    return torch.stack(all_chunks)

In [72]:
# load minari dataset
import minari
minari_dataset = minari.load_dataset(dataset_id="LunarLanderContinuous-v3/ppo-1000-deterministic-v1")
dataloader = DataLoader(minari_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)
for batch in dataloader:
    processed_chunks = create_trajectory_chunks(batch, 100)
    print(processed_chunks.shape)


torch.Size([48515, 1000])
torch.Size([47538, 1000])
torch.Size([47346, 1000])
torch.Size([42836, 1000])


In [None]:
env = minari_dataset.recover_environment()
horizon = 100
action_dim = env.action_space.shape[0]
obs_dim = env.observation_space.shape[0]
input_dim = (obs_dim + action_dim) * horizon

# Training params
lr = 0.001
num_epochs = 500
print_every = 10
hidden_dim = 256

vf = MLP(input_dim=input_dim, time_dim=1, hidden_dim=hidden_dim).to(device)
path = AffineProbPath(scheduler=CondOTScheduler())
optim = torch.optim.Adam(vf.parameters(), lr=lr)

print("Starting training...")
for epoch in range(num_epochs):
    epoch_loss = 0.0
    start_time = time.time()

    for batch in dataloader:
        optim.zero_grad()

        x_1 = create_trajectory_chunks(batch, horizon)
        if x_1 is None:
            continue
        x_1 = x_1.to(device)
        x_0 = torch.randn_like(x_1).to(device)
        t = torch.rand(x_1.shape[0]).to(device)

        # 3. Forward pass and Loss
        path_sample = path.sample(t=t, x_0=x_0, x_1=x_1)
        predicted_velocity = vf(path_sample.x_t, path_sample.t)
        loss = torch.pow(predicted_velocity - path_sample.dx_t, 2).mean()

        # 4. Backward pass and Optimize
        loss.backward()
        optim.step()
        
        epoch_loss += loss.item()

    avg_epoch_loss = epoch_loss / len(dataloader)
    if (epoch + 1) % print_every == 0:
        elapsed = time.time() - start_time
        print(f"| Epoch {epoch+1:6d} | {elapsed:.2f} s/epoch | Loss {avg_epoch_loss:8.5f} ")
        start_time = time.time()

print("Training finished.")

Starting training...
| Epoch     10 | 4.65 s/epoch | Loss  1.03000 
| Epoch     20 | 4.65 s/epoch | Loss  1.00726 
| Epoch     30 | 4.52 s/epoch | Loss  0.98533 
| Epoch     40 | 4.82 s/epoch | Loss  0.97128 
| Epoch     50 | 4.70 s/epoch | Loss  0.96085 
| Epoch     60 | 4.64 s/epoch | Loss  0.95351 
| Epoch     70 | 4.59 s/epoch | Loss  0.94628 
| Epoch     80 | 4.62 s/epoch | Loss  0.94132 
| Epoch     90 | 4.88 s/epoch | Loss  0.93563 
| Epoch    100 | 5.05 s/epoch | Loss  0.93182 
| Epoch    110 | 4.83 s/epoch | Loss  0.92846 
| Epoch    120 | 4.81 s/epoch | Loss  0.92483 
| Epoch    130 | 4.75 s/epoch | Loss  0.92265 
| Epoch    140 | 4.81 s/epoch | Loss  0.92069 
| Epoch    150 | 4.80 s/epoch | Loss  0.91799 
| Epoch    160 | 4.95 s/epoch | Loss  0.91605 
| Epoch    170 | 4.70 s/epoch | Loss  0.91469 
| Epoch    180 | 4.62 s/epoch | Loss  0.91167 
| Epoch    190 | 4.66 s/epoch | Loss  0.90975 
| Epoch    200 | 4.61 s/epoch | Loss  0.90810 
| Epoch    210 | 4.61 s/epoch | Loss  0

In [74]:
# try sampling from trained model...

class WrappedModel(ModelWrapper):
    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
        return self.model(x, t)

wrapped_vf = WrappedModel(vf)

In [113]:
# step size for ode solver
step_size = 0.05

batch_size = 1  # batch size
T = torch.linspace(0,1,10)  # sample times
T = T.to(device=device)

x_init = torch.randn((batch_size, input_dim), dtype=torch.float32, device=device)
solver = ODESolver(velocity_model=wrapped_vf)  # create an ODESolver class
sol = solver.sample(time_grid=T, x_init=x_init, method='midpoint', step_size=step_size, return_intermediates=True)
sol.shape

torch.Size([10, 1, 1000])

In [146]:
def generate_trajectory():
    sol = solver.sample(time_grid=T, x_init=x_init, method='midpoint', step_size=step_size, return_intermediates=True)
    final_trajectory = sol[-1]
    reshaped_trajectory = final_trajectory.reshape(horizon, obs_dim + action_dim)
    print(reshaped_trajectory.shape)
    observations = reshaped_trajectory[:, :obs_dim]
    actions = reshaped_trajectory[:, obs_dim:obs_dim + action_dim]
    print(observations.shape, actions.shape)
    return observations, actions

In [144]:
observations, actions = generate_trajectory()

torch.Size([100, 10])
torch.Size([100, 8]) torch.Size([100, 2])


In [145]:
env = minari_dataset.recover_environment(eval_env = True, enable_wind=False, render_mode="human")
obs, _ = env.reset()
total_rew = 0
print(len(actions))
for i in range(len(actions)):
    action = actions[i].cpu().numpy()
    obs, rew, terminated, truncated, info = env.step(action)
    total_rew += rew
    print(f"Action: {action}, Reward: {rew}")
    env.render()
    if terminated or truncated:
        break


100
Action: [-1.0113046   0.09929121], Reward: 0.22038934106339525
Action: [-0.30197778 -1.6979274 ], Reward: 0.9631495138452533
Action: [ 1.6620966 -0.7585534], Reward: -1.4074033861991337
Action: [0.43808994 1.5055774 ], Reward: -2.2912380142502786
Action: [-1.7186     0.5056048], Reward: -0.3415482071275153
Action: [ 2.6137555  -0.49187252], Reward: -4.167473103899954
Action: [ 1.2885908  -0.13608862], Reward: -2.4754375993725146
Action: [-0.8797167  1.875761 ], Reward: -0.9152571326638952
Action: [-1.955152  -1.1314248], Reward: 1.0919989343134955
Action: [-0.7718348 -0.4404358], Reward: 0.160859930516267
Action: [1.1052104  0.37521407], Reward: -2.944002022035863
Action: [0.31785864 1.2918353 ], Reward: -2.3368015759477077
Action: [ 0.07928108 -1.3897114 ], Reward: -1.1875503234348161
Action: [-0.820808 -2.848034], Reward: 1.2992528323366355
Action: [-1.6282293   0.76094323], Reward: -0.5903354172017725
Action: [-0.6522331 -0.7577608], Reward: 0.8966462321723009
Action: [-0.213564

In [None]:

horizon = 100
minari_dataset = minari.load_dataset(dataset_id="LunarLanderContinuous-v3/ppo-1000-deterministic-v1")
env = minari_dataset.recover_environment(eval_env=True, render_mode="human")

# --- Random Agent Evaluation ---
print("\n--- Running Random Agent ---")
obs, _ = env.reset()
total_rew_random = 0
for _ in range(horizon):
    action = env.action_space.sample()
    
    obs, rew, terminated, truncated, info = env.step(action)
    total_rew_random += rew
    env.render()
        
    if terminated or truncated:
        print("Episode finished early.")
        break

print(f"Total reward from random agent: {total_rew_random:.2f}")
env.close()