In [1]:
import sys
from pathlib import Path
from os.path  import join
from src.config import DATA_DIR, CONFIG_DIR, MODELS_DIR
from omegaconf import OmegaConf
from src.external.hptr.src.data_modules.agent_centric import AgentCentricPreProcessing
from src.external.hptr.src.data_modules.ac_global import AgentCentricGlobal
from src.mimolm import InputProjections, EarlyFusionEncoder, MotionDecoder
import torch
import torch.nn.functional as F
import lightning as pl

# Make the project root (the parent of the current working directory) importable.
# NOTE(review): the `src.*` imports above execute before this append, so it only
# affects imports made afterwards - confirm whether it should run first.
project_root = Path().resolve().parent  # Adjust as needed to point to the root folder
# Guard the append so re-running this cell does not pile up duplicate sys.path entries.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

print(Path.cwd())  # Show the working directory that project_root was derived from

[32m2025-02-20 09:55:41.798[0m | [1mINFO    [0m | [36msrc.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /home/harshavardhan-patil/Work/Projects/mimolm[0m
  from .autonotebook import tqdm as notebook_tqdm


/home/harshavardhan-patil/Work/Projects/mimolm/notebooks


In [4]:
from src.external.hptr.src.data_modules.data_h5_av2 import DataH5av2
from src.mimolm import MimoLM
from src.modeling.modules.lm_utils import interpolate_trajectory, cluster_rollouts, non_maximum_suppression

# Print tensors in plain decimal notation instead of scientific notation.
torch.set_printoptions(sci_mode=False)

# Build the data module over DATA_DIR and fetch its validation dataloader.
data_module = DataH5av2(DATA_DIR, batch_size=128)
data_module.setup(stage="validate")
val_loader = data_module.val_dataloader()

# Instantiate the model against the validation tensor sizes reported by the data module.
model = MimoLM(
    data_size=data_module.tensor_size_val,
    n_rollouts=1,
    learning_rate=8.e-8,
    sampling_rate=2,
)

# fast_dev_run=10 limits the test loop to 10 batches (logging/checkpointing suppressed).
trainer = pl.Trainer(fast_dev_run=10)
output = trainer.test(model=model, dataloaders=val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 10 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


In [None]:
# NOTE(review): a bare `break` outside a loop is a SyntaxError in Python. Presumably
# kept on purpose so "Run All" halts here before the scratch cells below - confirm,
# and consider removing this cell or replacing it with an explicit, documented stop.
break

In [3]:
import torch

# Seed the global RNG so the toy tensors in this walkthrough come out the same on
# every Restart & Run All (the original cell drew fresh random values each run).
torch.manual_seed(0)

# Toy tensor of shape (2, 3, 5, 1) with values drawn from {0.00, 0.01, ..., 0.99}.
token = torch.randint(100, (2, 3, 5, 1)) / 100.
token

tensor([[[[0.0100],
          [0.3600],
          [0.1000],
          [0.6700],
          [0.1900]],

         [[0.0700],
          [0.4900],
          [0.5600],
          [0.7900],
          [0.5100]],

         [[0.0700],
          [0.9600],
          [0.8500],
          [0.8000],
          [0.8800]]],


        [[[0.8700],
          [0.3300],
          [0.6800],
          [0.7200],
          [0.4500]],

         [[0.8100],
          [0.9300],
          [0.5300],
          [0.9100],
          [0.6600]],

         [[0.8000],
          [0.9000],
          [0.9000],
          [0.1000],
          [0.9500]]]])

In [4]:
# Merge dims 1 and 2 of `token` into a single axis: (2, 3, 5, 1) -> (2, 15, 1).
motion_embeddings = token.flatten(start_dim=1, end_dim=2)
# `query` is the same tensor object, under the name the attention cells below use.
query = motion_embeddings
print(motion_embeddings.shape, motion_embeddings, sep="\n")

torch.Size([2, 15, 1])
tensor([[[0.0100],
         [0.3600],
         [0.1000],
         [0.6700],
         [0.1900],
         [0.0700],
         [0.4900],
         [0.5600],
         [0.7900],
         [0.5100],
         [0.0700],
         [0.9600],
         [0.8500],
         [0.8000],
         [0.8800]],

        [[0.8700],
         [0.3300],
         [0.6800],
         [0.7200],
         [0.4500],
         [0.8100],
         [0.9300],
         [0.5300],
         [0.9100],
         [0.6600],
         [0.8000],
         [0.9000],
         [0.9000],
         [0.1000],
         [0.9500]]])


In [5]:
from src.modeling.modules.lm_utils import get_attention_mask

# Build the attention mask for the flattened sequence; the recorded result is a
# boolean (15, 15) matrix.
# NOTE(review): the literal 5 presumably matches the size-5 axis of `token` -
# confirm against get_attention_mask's signature.
attn_mask = get_attention_mask(5, query.size(1))
print(attn_mask.shape, attn_mask, sep="\n")

torch.Size([15, 15])
tensor([[False,  True,  True,  True,  True, False,  True,  True,  True,  True,
         False,  True,  True,  True,  True],
        [False, False,  True,  True,  True, False, False,  True,  True,  True,
         False, False,  True,  True,  True],
        [False, False, False,  True,  True, False, False, False,  True,  True,
         False, False, False,  True,  True],
        [False, False, False, False,  True, False, False, False, False,  True,
         False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False],
        [False,  True,  True,  True,  True, False,  True,  True,  True,  True,
         False,  True,  True,  True,  True],
        [False, False,  True,  True,  True, False, False,  True,  True,  True,
         False, False,  True,  True,  True],
        [False, False, False,  True,  True, False, False, False,  True,  True,
         False, False, False,  T

In [6]:
# Random 0/1 integer mask of shape (2, 3, 5), collapsed over its trailing dims to (2, 15).
padding_mask = torch.randint(0, 2, (2, 3, 5)).flatten(start_dim=1)
print(padding_mask.shape, padding_mask, sep="\n")

torch.Size([2, 15])
tensor([[0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0],
        [1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0]])


In [7]:
import torch.nn as nn

# Single-head attention over 1-dimensional embeddings; batch_first=True means
# inputs and outputs are laid out as (batch, seq, embed).
self_attn = nn.MultiheadAttention(
    embed_dim=1,
    num_heads=1,
    batch_first=True,
)

In [8]:
# Self-attention: `query` serves as query, key and value. In both boolean masks a
# True entry marks a key that must be ignored; the 0/1 integer padding mask is cast
# to bool for that purpose.
# When every key a query is allowed to see (attn_mask False) is also padded
# (key_padding_mask True), that query has nothing to attend to and its output is NaN -
# this is what produces the NaN rows in the recorded output.
attn_out_1, _ = self_attn(
    query=query,
    key=query,
    value=query,
    key_padding_mask=padding_mask.bool(),
    attn_mask=attn_mask,
)
print(attn_out_1.shape, attn_out_1, sep="\n")

torch.Size([2, 15, 1])
tensor([[[0.0063],
         [0.3032],
         [0.2576],
         [0.3483],
         [0.3709],
         [0.0063],
         [0.3017],
         [0.2541],
         [0.3476],
         [0.3693],
         [0.0063],
         [0.2962],
         [0.2519],
         [0.3475],
         [0.3673]],

        [[   nan],
         [0.2092],
         [0.4017],
         [0.3813],
         [0.4136],
         [   nan],
         [0.2092],
         [0.4021],
         [0.3804],
         [0.4126],
         [   nan],
         [0.2092],
         [0.4010],
         [0.3841],
         [0.4112]]], grad_fn=<TransposeBackward0>)


In [9]:
# Split the length-15 axis back into (3, 5), then fold the first two dims together:
# (2, 15, 1) -> (2, 3, 5, 1) -> (6, 5, 1).
attn_out_2 = attn_out_1.unflatten(1, (3, 5)).flatten(start_dim=0, end_dim=1)
print(attn_out_2.shape, attn_out_2, sep="\n")

torch.Size([6, 5, 1])
tensor([[[0.0063],
         [0.3032],
         [0.2576],
         [0.3483],
         [0.3709]],

        [[0.0063],
         [0.3017],
         [0.2541],
         [0.3476],
         [0.3693]],

        [[0.0063],
         [0.2962],
         [0.2519],
         [0.3475],
         [0.3673]],

        [[   nan],
         [0.2092],
         [0.4017],
         [0.3813],
         [0.4136]],

        [[   nan],
         [0.2092],
         [0.4021],
         [0.3804],
         [0.4126]],

        [[   nan],
         [0.2092],
         [0.4010],
         [0.3841],
         [0.4112]]], grad_fn=<UnsafeViewBackward0>)


In [10]:
# Toy tensor of shape (6, 10, 1) with values drawn from {0.00, 0.01, ..., 0.99};
# used as key/value input for the cross-attention cell.
fused_emb = torch.randint(0, 100, (6, 10, 1)).div(100.)
fused_emb

tensor([[[0.7100],
         [0.2300],
         [0.3800],
         [0.4900],
         [0.5300],
         [0.9300],
         [0.0900],
         [0.3300],
         [0.6300],
         [0.6800]],

        [[0.5500],
         [0.7000],
         [0.2400],
         [0.9500],
         [0.3100],
         [0.1500],
         [0.8400],
         [0.3400],
         [0.8100],
         [0.1000]],

        [[0.7500],
         [0.5600],
         [0.1000],
         [0.6200],
         [0.4000],
         [0.1800],
         [0.6400],
         [0.3500],
         [0.2100],
         [0.0500]],

        [[0.2600],
         [0.2200],
         [0.0600],
         [0.0100],
         [0.8800],
         [0.5100],
         [0.4800],
         [0.9100],
         [0.5800],
         [0.2500]],

        [[0.7500],
         [0.2300],
         [0.4900],
         [0.9500],
         [0.4800],
         [0.4800],
         [0.6900],
         [0.4100],
         [0.2800],
         [0.1400]],

        [[0.9000],
         [0.2100],
  

In [11]:
# Separate single-head attention module (its own weights) used for cross-attention;
# batch_first=True keeps the (batch, seq, embed) layout.
cross_attn = nn.MultiheadAttention(
    embed_dim=1,
    num_heads=1,
    batch_first=True,
)

In [12]:
# Cross-attention: the reshaped self-attention output (6, 5, 1) supplies the queries,
# while `fused_emb` (6, 10, 1) is both key and value. No masks are passed here.
# Note that this rebinds `query`, which previously held the flattened token tensor.
query, _ = cross_attn(
    query=attn_out_2,
    key=fused_emb,
    value=fused_emb,
)
print(query.shape, query, sep="\n")

torch.Size([6, 5, 1])
tensor([[[0.1355],
         [0.1363],
         [0.1362],
         [0.1364],
         [0.1365]],

        [[0.1352],
         [0.1364],
         [0.1362],
         [0.1366],
         [0.1367]],

        [[0.1046],
         [0.1054],
         [0.1052],
         [0.1055],
         [0.1055]],

        [[   nan],
         [0.1136],
         [0.1144],
         [0.1143],
         [0.1144]],

        [[   nan],
         [0.1333],
         [0.1338],
         [0.1338],
         [0.1338]],

        [[   nan],
         [0.1272],
         [0.1281],
         [0.1280],
         [0.1281]]], grad_fn=<TransposeBackward0>)


In [13]:
# Undo the earlier fold: (6, 5, 1) -> (2, 3, 5, 1) -> (2, 15, 1).
# NOTE(review): `query` is rebound from its own value, so this cell is not idempotent -
# it expects the (6, 5, 1) tensor produced by the cross-attention cell.
query = query.unflatten(0, (2, 3)).flatten(start_dim=1, end_dim=2)
print(query.shape, query, sep="\n")

torch.Size([2, 15, 1])
tensor([[[0.1355],
         [0.1363],
         [0.1362],
         [0.1364],
         [0.1365],
         [0.1352],
         [0.1364],
         [0.1362],
         [0.1366],
         [0.1367],
         [0.1046],
         [0.1054],
         [0.1052],
         [0.1055],
         [0.1055]],

        [[   nan],
         [0.1136],
         [0.1144],
         [0.1143],
         [0.1144],
         [   nan],
         [0.1333],
         [0.1338],
         [0.1338],
         [0.1338],
         [   nan],
         [0.1272],
         [0.1281],
         [0.1280],
         [0.1281]]], grad_fn=<UnsafeViewBackward0>)


In [14]:
# Toy 0/1 integer validity mask of shape (2, 8, 25).
target_valid = torch.randint(0, 2, [2, 8, 25])
# Right-pad the last axis with ones up to length 55. `new_ones` inherits dtype and
# device from `target_valid`, so the mask stays integer (concatenating with a float
# `torch.ones` silently promoted it to float32), and the leading dims are taken from
# the tensor itself rather than hard-coded.
if target_valid.shape[-1] != 55:
    target_valid = torch.cat(
        [
            target_valid,
            target_valid.new_ones((*target_valid.shape[:-1], 55 - target_valid.shape[-1])),
        ],
        dim=-1,
    )
# Boolean-mask indexing keeps only the entries whose mask value is 0 (invalid);
# displaying the shape shows how many were selected.
torch.randn([2, 8, 55, 1])[~target_valid.bool()].shape

torch.Size([201, 1])