In [8]:
import time
import torch

from torch import nn, Tensor
from torch.utils.data import DataLoader

# flow_matching
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.path import AffineProbPath
from flow_matching.solver import Solver, ODESolver
from flow_matching.utils import ModelWrapper

# visualization
import matplotlib.pyplot as plt

from matplotlib import cm


# To avoide meshgrid warning
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module='torch')

In [6]:
if torch.cuda.is_available():
    device = 'cuda:0'
    print('Using gpu')
else:
    device = 'cpu'
    print('Using cpu.')
torch.manual_seed(42)

Using cpu.


<torch._C.Generator at 0x120c70550>

In [None]:
class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return torch.sigmoid(x) * x


# TODO: need to resolve temporal locality problem maybe with a CNN later.
class MLP(nn.Module):
    def __init__(self, input_dim: int, time_dim: int = 1, hidden_dim: int = 128):
        super().__init__()

        self.input_dim = input_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim

        self.main = nn.Sequential(
            nn.Linear(input_dim + time_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        sz = x.size()
        x = x.reshape(-1, self.input_dim)
        t = t.reshape(-1, self.time_dim).float()

        t = t.reshape(-1, 1).expand(x.shape[0], 1)
        h = torch.cat([x, t], dim=1)
        output = self.main(h)

        return output.reshape(*sz)

In [9]:
def collate_fn(batch):
    return {
        "id": torch.Tensor([x.id for x in batch]),
        "observations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.observations) for x in batch],
            batch_first=True
        ),
        "actions": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.actions) for x in batch],
            batch_first=True
        ),
        "rewards": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.rewards) for x in batch],
            batch_first=True
        ),
        "terminations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.terminations) for x in batch],
            batch_first=True
        ),
        "truncations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.truncations) for x in batch],
            batch_first=True
        )
    }

In [None]:
# load minari dataset
import minari
minari_dataset = minari.load_dataset(dataset_id="LunarLanderContinuous-v3/ppo-1000-v1")
dataloader = DataLoader(minari_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)
env = minari_dataset.recover_environment()




In [20]:
horizon = 1000
action_dim = env.action_space.shape[0]
obs_dim = env.observation_space.shape[0]
input_dim = (obs_dim + action_dim) * horizon

# training parameters
lr = 0.001
batch_size = 256
iterations = 20001
hidden_dim = 512

# init the velocity field model
vf = MLP(input_dim=input_dim, time_dim=1, hidden_dim=hidden_dim).to(device)

# instantiate an affine path object
path = AffineProbPath(scheduler=CondOTScheduler())

# init optimizer
optim = torch.optim.Adam(vf.parameters(), lr=lr)

# train
start_time = time.time()
for i in range(iterations):
    for batch in dataloader:
        observations = batch["observations"][:, :-1]
        expert_actions = batch["actions"]
        x_1 = torch.cat([observations, expert_actions], dim=-1)
        x_1 = x_1.reshape(x_1.shape[0], -1).to(device)
        print(x_1.shape)
        x_0 = torch.rand_like(x_1).to(device)

        t = torch.rand(x_1.shape[0]).to(device)

        # sample probability path
        path_sample = path.sample(t=t, x_0=x_0, x_1=x_1)

        # flow matching l2 loss
        loss = torch.pow( vf(path_sample.x_t,path_sample.t) - path_sample.dx_t, 2).mean() 

torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([232, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([232, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([232, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([232, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([232, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([232, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([232, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([232, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([232, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([256, 10000])
torch.Size([232, 10000])


KeyboardInterrupt: 