In [1]:
import time
import torch

from torch import nn, Tensor
from torch.utils.data import DataLoader

# flow_matching
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.path import AffineProbPath
from flow_matching.solver import Solver, ODESolver
from flow_matching.utils import ModelWrapper

# visualization
import matplotlib.pyplot as plt

from matplotlib import cm


# To avoide meshgrid warning
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module='torch')


In [2]:
if torch.cuda.is_available():
    device = 'cuda:0'
    print('Using gpu')
else:
    device = 'cpu'
    print('Using cpu.')
torch.manual_seed(42)

Using gpu


<torch._C.Generator at 0x7f93eed6d670>

In [3]:
class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return torch.sigmoid(x) * x


# TODO: need to resolve temporal locality problem maybe with a CNN later.
class MLP(nn.Module):
    def __init__(self, input_dim: int, time_dim: int = 1, hidden_dim: int = 128):
        super().__init__()

        self.input_dim = input_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim

        self.main = nn.Sequential(
            nn.Linear(input_dim + time_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        sz = x.size()
        x = x.reshape(-1, self.input_dim)
        t = t.reshape(-1, self.time_dim).float()

        t = t.reshape(-1, 1).expand(x.shape[0], 1)
        h = torch.cat([x, t], dim=1)
        output = self.main(h)

        return output.reshape(*sz)

In [4]:
# def collate_fn(batch):
#     batch_observations = []
#     batch_actions = []

#     for episode_data in batch:


def collate_fn(batch):
    return {
        "id": torch.Tensor([x.id for x in batch]),
        "observations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.observations) for x in batch],
            batch_first=True
        ),
        "actions": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.actions) for x in batch],
            batch_first=True
        ),
        "rewards": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.rewards) for x in batch],
            batch_first=True
        ),
        "terminations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.terminations) for x in batch],
            batch_first=True
        ),
        "truncations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.truncations) for x in batch],
            batch_first=True
        )
    }

In [5]:
# load minari dataset
import minari
minari_dataset = minari.load_dataset(dataset_id="LunarLanderContinuous-v3/ppo-1000-v1")
dataloader = DataLoader(minari_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)
env = minari_dataset.recover_environment()
episode = minari_dataset[0]
episode

EpisodeData(id=0, total_steps=1000, observations=ndarray of shape (1001, 8) and dtype float32, actions=ndarray of shape (1000, 2) and dtype float32, rewards=ndarray of 1000 floats, terminations=ndarray of 1000 bools, truncations=ndarray of 1000 bools, infos=dict with the following keys: [])

In [11]:
# preprocess dataset to the right horizon


In [None]:
horizon = 100 # need to adjust this
action_dim = env.action_space.shape[0]
obs_dim = env.observation_space.shape[0]
input_dim = (obs_dim + action_dim) * horizon

# Training params
lr = 0.001
num_epochs = 1500
print_every = 100
hidden_dim = 1024

vf = MLP(input_dim=input_dim, time_dim=1, hidden_dim=hidden_dim).to(device)
path = AffineProbPath(scheduler=CondOTScheduler())
optim = torch.optim.Adam(vf.parameters(), lr=lr)

print("Starting training...")
for epoch in range(num_epochs):
    epoch_loss = 0.0
    start_time = time.time()

    for batch in dataloader:
        # 1. Zero gradients for this batch
        optim.zero_grad()

        # 2. Prepare data
        observations = batch["observations"][:, :-1][:, :horizon]
        expert_actions = batch["actions"][:, :horizon]
        x_1 = torch.cat([observations, expert_actions], dim=-1)
        x_1 = x_1.reshape(x_1.shape[0], -1).to(device)
        x_0 = torch.randn_like(x_1).to(device)
        t = torch.rand(x_1.shape[0]).to(device)

        # 3. Forward pass and Loss
        path_sample = path.sample(t=t, x_0=x_0, x_1=x_1)
        predicted_velocity = vf(path_sample.x_t, path_sample.t)
        loss = torch.pow(predicted_velocity - path_sample.dx_t, 2).mean()

        # 4. Backward pass and Optimize
        loss.backward()
        optim.step()
        
        epoch_loss += loss.item()

    avg_epoch_loss = epoch_loss / len(dataloader)
    if (epoch + 1) % print_every == 0:
        elapsed = time.time() - start_time
        print(f"| Epoch {epoch+1:6d} | {elapsed:.2f} s/epoch | Loss {avg_epoch_loss:8.5f} ")
        start_time = time.time()

print("Training finished.")

Starting training...
| Epoch    100 | 0.76 s/epoch | Loss  0.95060 
| Epoch    200 | 0.84 s/epoch | Loss  0.86491 
| Epoch    300 | 0.83 s/epoch | Loss  0.82484 
| Epoch    400 | 0.84 s/epoch | Loss  0.80036 
| Epoch    500 | 0.85 s/epoch | Loss  0.79039 
| Epoch    600 | 0.83 s/epoch | Loss  0.77608 
| Epoch    700 | 0.84 s/epoch | Loss  0.76352 
| Epoch    800 | 0.84 s/epoch | Loss  0.75767 
| Epoch    900 | 0.82 s/epoch | Loss  0.75547 
| Epoch   1000 | 0.84 s/epoch | Loss  0.74468 
| Epoch   1100 | 0.84 s/epoch | Loss  0.73785 
| Epoch   1200 | 0.84 s/epoch | Loss  0.73716 
| Epoch   1300 | 0.84 s/epoch | Loss  0.72871 
| Epoch   1400 | 0.84 s/epoch | Loss  0.73089 
| Epoch   1500 | 0.84 s/epoch | Loss  0.72767 
| Epoch   1600 | 0.84 s/epoch | Loss  0.72332 
| Epoch   1700 | 0.84 s/epoch | Loss  0.72246 
| Epoch   1800 | 0.84 s/epoch | Loss  1.22273 
| Epoch   1900 | 0.85 s/epoch | Loss  1.23737 
| Epoch   2000 | 0.83 s/epoch | Loss  1.39969 
| Epoch   2100 | 0.84 s/epoch | Loss  1

KeyboardInterrupt: 

In [36]:
# try sampling from trained model...

class WrappedModel(ModelWrapper):
    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
        return self.model(x, t)

wrapped_vf = WrappedModel(vf)

In [37]:
# step size for ode solver
step_size = 0.05

batch_size = 1  # batch size
T = torch.linspace(0,1,10)  # sample times
T = T.to(device=device)

x_init = torch.randn((batch_size, input_dim), dtype=torch.float32, device=device)
solver = ODESolver(velocity_model=wrapped_vf)  # create an ODESolver class
sol = solver.sample(time_grid=T, x_init=x_init, method='midpoint', step_size=step_size, return_intermediates=True)  # sample from the model

In [38]:
final_trajectory = sol.reshape(horizon, -1)
final_trajectory.shape
observations = final_trajectory[:, :obs_dim]
actions = final_trajectory[:, obs_dim:obs_dim + action_dim]
print(observations.shape, actions.shape)

torch.Size([100, 8]) torch.Size([100, 2])


In [39]:
env = minari_dataset.recover_environment(eval_env = True, render_mode="human")
obs, _ = env.reset(seed=42)
total_rew = 0
for i in range(horizon):
    action = actions[i].cpu().numpy()
    obs, rew, terminated, truncated, info = env.step(action)
    total_rew += rew
    print(f"Action: {action}, Reward: {rew}")
    env.render()
    if terminated or truncated:
        break

import pygame
pygame.display.quit()
pygame.quit()

Action: [-1.1722629  0.7583265], Reward: 0.6375981558857995
Action: [1.9599218  0.45798123], Reward: -4.504294383485461
Action: [-1.0760934  2.048606 ], Reward: 0.2667735376406608
Action: [0.59545684 0.6822942 ], Reward: -5.112996145088616
Action: [-1.4163488  2.2833805], Reward: -0.08246028055910301
Action: [ 1.9525062 -1.7280059], Reward: -1.4656804134121604
Action: [ 0.63257545 -0.09153331], Reward: -1.854936977588261
Action: [ 0.3800331  -0.53515935], Reward: -3.110143281546466
Action: [-0.21625257  2.0234437 ], Reward: 0.060902082400816654
Action: [ 0.33920717 -2.706087  ], Reward: -0.6641208879853298
Action: [-1.2724048  0.7058647], Reward: 0.28389711074927393
Action: [1.8796579 0.4381677], Reward: -2.6834044219850055
Action: [-1.132049   2.0635967], Reward: -0.3745501055491036
Action: [0.6013028 0.717165 ], Reward: -3.587055311932185
Action: [-1.3698475  2.3062391], Reward: -0.8832771738943268
Action: [ 2.0063336 -1.759538 ], Reward: -2.083958208590944
Action: [ 0.6986403  -0.10