In [None]:
import torch
from flow_matching.supervised.alphas_betas import LinearAlpha, LinearBeta
from flow_matching.supervised.prob_paths import GaussianConditionalProbabilityPath
from flow_matching.supervised.training import FlowTrainer

from flow_matching.whar.sampler import WHARSampler
from flow_matching.whar.vector_field import WHARUnet

In [3]:
# Seed PyTorch's global RNG so the stochastic steps further down (path sampling,
# weight initialisation, training) start from a fixed state on every fresh run.
# NOTE(review): this seeds torch only; if the sampler or trainer also draw from
# NumPy or `random`, seed those too — confirm against their implementations.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Build the WHAR data sampler with its default configuration.
# In the recorded run, construction logs its cache checks (download, sessions,
# windowing), the train/val/test split sizes, and the loading of windows and
# normalization parameters — see the output below this cell.
sampler = WHARSampler()

Creating config hash...
Checking download...
Download exists.
Checking sessions...
Sessions exist.
Validating common format...
[########################################] | 100% Completed | 746.40 ms
Common format validated.
Checking windowing...
Loading config hash...
Windowing exists.
subject_ids: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29]
activity_ids: [0 1 2 3 4 5]
train: 7046 | val: 782 | test: 1671
Postprocessing...


Loading windows: 100%|██████████| 9499/9499 [00:07<00:00, 1296.00it/s]


Getting normalization parameters...
Loading config hash...
Creating normalization parameters hash...
Loading normalization parameters hash...


Loading samples: 100%|██████████| 9499/9499 [00:01<00:00, 6598.88it/s]


In [None]:
# Per-sample shape reported by the sampler. It is handed to the probability path
# as `p_simple_shape`, and its first entry sizes the network's input channels.
shape = sampler.get_shape()
print(shape)

# Number of classes = number of entries in the class-weight mapping computed
# from the training loader.
num_classes = len(sampler.dataset.get_class_weights(sampler.train_loader))
print(num_classes)

# Gaussian conditional probability path over the sampler's data, using the
# linear alpha/beta schedules.
path = GaussianConditionalProbabilityPath(
    p_data=sampler, p_simple_shape=shape, alpha=LinearAlpha(), beta=LinearBeta()
)
path = path.to(device)

# U-Net vector field conditioned on the class label.
unet = WHARUnet(
    in_channels=shape[0],
    channels=[16, 32, 64],
    num_blocks=4,
    emb_dim=16,
    num_classes=num_classes,
)
unet = unet.to(device)

# `null_class=num_classes` picks the first index past the real labels
# (0 .. num_classes - 1).
# NOTE(review): confirm WHARUnet reserves an embedding slot for that extra index.
# NOTE(review): the original inline note next to `eta` read "(1 / num_classes)",
# which does not equal 0.001 — presumably an alternative setting; confirm the
# intended value and what `eta` controls in FlowTrainer.
trainer = FlowTrainer(path=path, model=unet, null_class=num_classes, eta=0.001)

[9, 5, 64]
6


In [None]:
# Run the training loop on the selected device with the hyperparameters below.
trainer.train(
    num_epochs=1000,
    device=device,
    lr=1e-3,
    batch_size=32,
)

In [8]:
# Persist only the model's weights (its state_dict), not the module object,
# to "unet.pt" relative to the current working directory.
# NOTE(review): the tensors are saved from whatever device the model lives on;
# pass `map_location` to `torch.load` when restoring on a machine without it.
torch.save(unet.state_dict(), "unet.pt")