In [1]:
import time
import torch

from torch import nn, Tensor
from torch.utils.data import DataLoader

# flow_matching
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.path import AffineProbPath
from flow_matching.solver import Solver, ODESolver
from flow_matching.utils import ModelWrapper

# visualization
import matplotlib.pyplot as plt

from matplotlib import cm


# To avoide meshgrid warning
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module='torch')


In [2]:
if torch.cuda.is_available():
    device = 'cuda:0'
    print('Using gpu')
else:
    device = 'cpu'
    print('Using cpu.')
torch.manual_seed(42)

Using cpu.


<torch._C.Generator at 0x10eae7590>

In [3]:
class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return torch.sigmoid(x) * x


# TODO: need to resolve temporal locality problem maybe with a CNN later.
class MLP(nn.Module):
    def __init__(self, input_dim: int, time_dim: int = 1, hidden_dim: int = 128):
        super().__init__()

        self.input_dim = input_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim

        self.main = nn.Sequential(
            nn.Linear(input_dim + time_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        sz = x.size()
        x = x.reshape(-1, self.input_dim)
        t = t.reshape(-1, self.time_dim).float()

        t = t.reshape(-1, 1).expand(x.shape[0], 1)
        h = torch.cat([x, t], dim=1)
        output = self.main(h)

        return output.reshape(*sz)

In [4]:
def collate_fn(batch):
    observations = [torch.as_tensor(x.observations) for x in batch]
    actions = [torch.as_tensor(x.actions) for x in batch]
    rewards = [torch.as_tensor(x.rewards) for x in batch]
    terminations = [torch.as_tensor(x.terminations) for x in batch]
    truncations = [torch.as_tensor(x.truncations) for x in batch]
    return {
        "id": torch.Tensor([x.id for x in batch]),
        "observations": torch.nn.utils.rnn.pad_sequence(
            observations,
            batch_first=True
        ),
        "actions": torch.nn.utils.rnn.pad_sequence(
            actions,
            batch_first=True
        ),
        "rewards": torch.nn.utils.rnn.pad_sequence(
            rewards,
            batch_first=True
        ),
        "terminations": torch.nn.utils.rnn.pad_sequence(
            terminations,
            batch_first=True
        ),
        "truncations": torch.nn.utils.rnn.pad_sequence(
            truncations,
            batch_first=True
        )
    }   

In [6]:
# load minari dataset
import minari
minari_dataset = minari.load_dataset(dataset_id="LunarLanderContinuous-v3/ppo-1000-deterministic-v1")
dataloader = DataLoader(minari_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)
env = minari_dataset.recover_environment()

In [30]:
horizon = 250
action_dim = env.action_space.shape[0]
obs_dim = env.observation_space.shape[0]
input_dim = (obs_dim + action_dim) * horizon

# Training params
lr = 0.001
num_epochs = 500
print_every = 10
hidden_dim = 256

vf = MLP(input_dim=input_dim, time_dim=1, hidden_dim=hidden_dim).to(device)
path = AffineProbPath(scheduler=CondOTScheduler())
optim = torch.optim.Adam(vf.parameters(), lr=lr)

print("Starting training...")
for epoch in range(num_epochs):
    epoch_loss = 0.0
    start_time = time.time()

    for batch in dataloader:
        # 1. Zero gradients for this batch
        optim.zero_grad()

        # 2. Prepare data
        observations = batch["observations"][:, :-1][:, :horizon]
        expert_actions = batch["actions"][:, :horizon]
        expert_actions = batch["actions"][:, :horizon]
        x_1 = torch.cat([observations, expert_actions], dim=-1)
        x_1 = x_1.reshape(x_1.shape[0], -1).to(device)
        x_0 = torch.randn_like(x_1).to(device)
        t = torch.rand(x_1.shape[0]).to(device)

        # 3. Forward pass and Loss
        path_sample = path.sample(t=t, x_0=x_0, x_1=x_1)
        predicted_velocity = vf(path_sample.x_t, path_sample.t)
        loss = torch.pow(predicted_velocity - path_sample.dx_t, 2).mean()

        # 4. Backward pass and Optimize
        loss.backward()
        optim.step()
        
        epoch_loss += loss.item()

    avg_epoch_loss = epoch_loss / len(dataloader)
    if (epoch + 1) % print_every == 0:
        elapsed = time.time() - start_time
        print(f"| Epoch {epoch+1:6d} | {elapsed:.2f} s/epoch | Loss {avg_epoch_loss:8.5f} ")
        start_time = time.time()

print("Training finished.")

Starting training...
| Epoch     10 | 0.43 s/epoch | Loss  1.03202 
| Epoch     20 | 0.43 s/epoch | Loss  1.02859 
| Epoch     30 | 0.39 s/epoch | Loss  1.02167 
| Epoch     40 | 0.39 s/epoch | Loss  1.01973 
| Epoch     50 | 0.39 s/epoch | Loss  1.01940 
| Epoch     60 | 0.39 s/epoch | Loss  1.01963 
| Epoch     70 | 0.42 s/epoch | Loss  1.01684 
| Epoch     80 | 0.41 s/epoch | Loss  1.01512 
| Epoch     90 | 0.40 s/epoch | Loss  1.01169 
| Epoch    100 | 0.40 s/epoch | Loss  1.01212 
| Epoch    110 | 0.41 s/epoch | Loss  1.00878 
| Epoch    120 | 0.42 s/epoch | Loss  1.00858 
| Epoch    130 | 0.39 s/epoch | Loss  1.00506 
| Epoch    140 | 0.40 s/epoch | Loss  1.00555 
| Epoch    150 | 0.39 s/epoch | Loss  1.00294 
| Epoch    160 | 0.39 s/epoch | Loss  1.00307 
| Epoch    170 | 0.40 s/epoch | Loss  1.00188 
| Epoch    180 | 0.39 s/epoch | Loss  1.00036 
| Epoch    190 | 0.40 s/epoch | Loss  1.00065 
| Epoch    200 | 0.39 s/epoch | Loss  0.99817 
| Epoch    210 | 0.40 s/epoch | Loss  0

In [36]:
# try sampling from trained model...

class WrappedModel(ModelWrapper):
    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
        return self.model(x, t)

wrapped_vf = WrappedModel(vf)

In [37]:
# step size for ode solver
step_size = 0.05

batch_size = 1  # batch size
T = torch.linspace(0,1,10)  # sample times
T = T.to(device=device)

x_init = torch.randn((batch_size, input_dim), dtype=torch.float32, device=device)
solver = ODESolver(velocity_model=wrapped_vf)  # create an ODESolver class
sol = solver.sample(time_grid=T, x_init=x_init, method='midpoint', step_size=step_size, return_intermediates=True)  # sample from the model

In [38]:
final_trajectory = sol.reshape(horizon, -1)
final_trajectory.shape
observations = final_trajectory[:, :obs_dim]
actions = final_trajectory[:, obs_dim:obs_dim + action_dim]
print(observations.shape, actions.shape)

torch.Size([250, 8]) torch.Size([250, 2])


In [41]:
env = minari_dataset.recover_environment(eval_env = True, render_mode="human")
obs, _ = env.reset()
total_rew = 0
for i in range(horizon):
    action = actions[i].cpu().numpy()
    obs, rew, terminated, truncated, info = env.step(action)
    total_rew += rew
    print(f"Action: {action}, Reward: {rew}")
    env.render()
    if terminated or truncated:
        break

import pygame
pygame.display.quit()
pygame.quit()

Action: [0.2158604  0.60467863], Reward: -1.4394063672600463
Action: [ 0.551068  -1.0015647], Reward: -2.0564348103179486
Action: [-1.0130321  0.3193489], Reward: -0.02581130638174045
Action: [1.1032864 1.2012154], Reward: -4.341134774900683
Action: [-1.760066  -1.6100212], Reward: 0.6875222002680641
Action: [-1.0467013  -0.92995286], Reward: 0.814113899677318
Action: [ 0.5016967 -1.0278525], Reward: -1.6822322741013511
Action: [0.95547706 0.44048885], Reward: -1.2967076224971605
Action: [ 0.17151117 -0.7747519 ], Reward: -0.15411553887180957
Action: [-0.25561076 -0.12994733], Reward: 0.4426983293079729
Action: [0.30285662 0.21694663], Reward: -0.9423567858843171
Action: [-0.33751944 -1.7288631 ], Reward: 1.247951803727717
Action: [0.29053512 0.15206575], Reward: -1.1299692777025485
Action: [0.40577304 0.86196953], Reward: -2.2788200959333222
Action: [ 0.39176193 -0.48698473], Reward: -1.1072202172536663
Action: [0.57743746 0.15333663], Reward: -0.820941460335041
Action: [-0.18786347  