In [5]:
import time
import torch

from torch import nn, Tensor
from torch.utils.data import DataLoader

# flow_matching
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.path import AffineProbPath
from flow_matching.solver import Solver, ODESolver
from flow_matching.utils import ModelWrapper

# visualization
import matplotlib.pyplot as plt

from matplotlib import cm


# To avoide meshgrid warning
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module='torch')

In [6]:
if torch.cuda.is_available():
    device = 'cuda:0'
    print('Using gpu')
else:
    device = 'cpu'
    print('Using cpu.')
torch.manual_seed(42)

Using gpu


<torch._C.Generator at 0x7f145c448c90>

In [7]:
class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return torch.sigmoid(x) * x


# TODO: need to resolve temporal locality problem maybe with a CNN later.
class MLP(nn.Module):
    def __init__(self, input_dim: int, time_dim: int = 1, hidden_dim: int = 128):
        super().__init__()

        self.input_dim = input_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim

        self.main = nn.Sequential(
            nn.Linear(input_dim + time_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        sz = x.size()
        x = x.reshape(-1, self.input_dim)
        t = t.reshape(-1, self.time_dim).float()

        t = t.reshape(-1, 1).expand(x.shape[0], 1)
        h = torch.cat([x, t], dim=1)
        output = self.main(h)

        return output.reshape(*sz)

In [8]:
def collate_fn(batch):
    return {
        "id": torch.Tensor([x.id for x in batch]),
        "observations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.observations) for x in batch],
            batch_first=True
        ),
        "actions": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.actions) for x in batch],
            batch_first=True
        ),
        "rewards": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.rewards) for x in batch],
            batch_first=True
        ),
        "terminations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.terminations) for x in batch],
            batch_first=True
        ),
        "truncations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.truncations) for x in batch],
            batch_first=True
        )
    }

In [10]:
# load minari dataset
import minari
minari_dataset = minari.load_dataset(dataset_id="LunarLanderContinuous-v3/ppo-1000-v1")
dataloader = DataLoader(minari_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)
env = minari_dataset.recover_environment()

In [None]:
horizon = 1000 # need to adjust this
action_dim = env.action_space.shape[0]
obs_dim = env.observation_space.shape[0]
input_dim = (obs_dim + action_dim) * horizon

# Training params
lr = 0.001
num_epochs = 100
print_every = 10
hidden_dim = 1024

vf = MLP(input_dim=input_dim, time_dim=1, hidden_dim=hidden_dim).to(device)
path = AffineProbPath(scheduler=CondOTScheduler())
optim = torch.optim.Adam(vf.parameters(), lr=lr)

print("Starting training...")
for epoch in range(num_epochs):
    epoch_loss = 0.0
    start_time = time.time()

    for batch in dataloader:
        # 1. Zero gradients for this batch
        optim.zero_grad()

        # 2. Prepare data
        observations = batch["observations"][:, :-1]
        expert_actions = batch["actions"]
        x_1 = torch.cat([observations, expert_actions], dim=-1)
        x_1 = x_1.reshape(x_1.shape[0], -1).to(device)
        x_0 = torch.randn_like(x_1).to(device)
        t = torch.rand(x_1.shape[0]).to(device)

        # 3. Forward pass and Loss
        path_sample = path.sample(t=t, x_0=x_0, x_1=x_1)
        predicted_velocity = vf(path_sample.x_t, path_sample.t)
        loss = torch.pow(predicted_velocity - path_sample.dx_t, 2).mean()

        # 4. Backward pass and Optimize
        loss.backward()
        optim.step()
        
        epoch_loss += loss.item()

    avg_epoch_loss = epoch_loss / len(dataloader)
    if (epoch + 1) % print_every == 0:
        elapsed = time.time() - start_time
        print(f"| Epoch {epoch+1:6d} | {elapsed:.2f} s/epoch | Loss {avg_epoch_loss:8.5f} ")
        start_time = time.time()

print("Training finished.")

Starting training...
| Epoch     10 | 1.12 s/epoch | Loss  1.06362 
| Epoch     20 | 1.06 s/epoch | Loss  1.05980 
| Epoch     30 | 0.76 s/epoch | Loss  1.05979 
| Epoch     40 | 0.89 s/epoch | Loss  1.06008 
| Epoch     50 | 1.09 s/epoch | Loss  1.05964 
| Epoch     60 | 1.08 s/epoch | Loss  1.05887 
| Epoch     70 | 0.77 s/epoch | Loss  1.05756 
| Epoch     80 | 1.08 s/epoch | Loss  1.05734 
| Epoch     90 | 1.06 s/epoch | Loss  1.05930 
| Epoch    100 | 1.06 s/epoch | Loss  1.05805 
Training finished.


In [16]:
# try sampling from trained model...

class WrappedModel(ModelWrapper):
    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
        return self.model(x, t)

wrapped_vf = WrappedModel(vf)

In [23]:
# step size for ode solver
step_size = 0.05

batch_size = 1  # batch size
T = torch.linspace(0,1,10)  # sample times
T = T.to(device=device)

x_init = torch.randn((batch_size, input_dim), dtype=torch.float32, device=device)
solver = ODESolver(velocity_model=wrapped_vf)  # create an ODESolver class
sol = solver.sample(time_grid=T, x_init=x_init, method='midpoint', step_size=step_size, return_intermediates=True)  # sample from the model

In [37]:
final_trajectory = sol.reshape(horizon, -1)
final_trajectory.shape
observations = final_trajectory[:, :obs_dim]
actions = final_trajectory[:, obs_dim:obs_dim + action_dim]
# print(observations.shape, actions.shape)

In [38]:
env = minari_dataset.recover_environment(eval_env = True, render_mode="human")
obs, _ = env.reset(seed=42)
total_rew = 0
for i in range(horizon):
    action = actions[i].cpu().numpy()
    obs, rew, terminated, truncated, info = env.step(action)
    total_rew += rew
    print(f"Action: {action}, Reward: {rew}")
    env.render()
    if terminated or truncated:
        break

Action: [-0.59474504 -2.3656845 ], Reward: 2.0804712989958616
Action: [-1.6046611  1.4039768], Reward: 0.529247597147845
Action: [-0.30896065  0.40414703], Reward: 1.2011792847636968
Action: [-1.4175168  0.1911168], Reward: 1.1590071000121895
Action: [-0.3444995  1.2422355], Reward: 0.23162445040557714
Action: [ 0.8719092  -0.16229032], Reward: -1.1592154582360479
Action: [0.90314084 0.65312976], Reward: -2.524799849197235
Action: [ 0.5631698 -0.9630722], Reward: -2.927212301543244
Action: [ 1.014588   -0.01312742], Reward: -1.8393398567428279
Action: [0.93166876 1.2738359 ], Reward: -2.716529957730306
Action: [2.109867   0.09654108], Reward: -1.7539641506018484
Action: [-2.3889139  0.9981402], Reward: -0.24904855373091095
Action: [-0.4826049   0.11079448], Reward: 0.4263253605946318
Action: [-0.30478856  1.472222  ], Reward: -0.5918381787485874
Action: [-0.44305983 -1.0094665 ], Reward: 0.861314760187754
Action: [0.5981448 0.865999 ], Reward: -2.45467567226086
Action: [-0.5309649 -1.7