In [3]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install x-transformers
!pip install matplotlib
!pip install einops
!pip install wandb

from google.colab import drive
drive.mount('/content/drive')

!cd /content/drive/MyDrive/realistic-imu/src

import wandb
wandb.login()

import csv
import matplotlib.pyplot as plt
import os
import sys
import wandb
import pickle
import re
import itertools

# Add the source directory to the system path
sys.path.append('/content/drive/MyDrive/realistic-imu/src')

from trase_dataset import TraseDataset
from trase import Trase, TraseLoss


if 'ipykernel' in sys.modules:
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm


import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from torch.profiler import profile, record_function, ProfilerActivity


# --- Runtime configuration ---------------------------------------------------
# Release cached CUDA blocks left over from earlier cell executions.
torch.cuda.empty_cache()

# Float32 matmul precision mode used for the rest of the session.
torch.set_float32_matmul_precision('high')

# Everything below is placed on the CUDA device.
device = torch.device("cuda")

# --- Filesystem locations ----------------------------------------------------
data_path = "/content/drive/MyDrive/beyondamass/data/realistic-imu-dataset/"
base_path = "/content/drive/MyDrive/beyondamass/data/realistic-imu-dataset/models"

# --- Optimisation defaults ---------------------------------------------------
EPOCHS = 100
LEARNING_RATE = 1e-5
WEIGHT_DECAY = 1e-4

# --- Model architecture defaults ---------------------------------------------
INPUT_EMBEDDING_DIM = 408
D_MODEL = 2048
FEED_FORWARD_DIM = 2048
NUM_ENCODERS = 1
HEADS = 8
DROPOUT = 0.1

# --- Loss configuration ------------------------------------------------------
# Passed to TraseLoss as `total_var_weight`.
TOTAL_VAR_WEIGHT = 0

class CheckpointSaver:
    """Track the best dev loss seen so far and checkpoint the model on improvement.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, model, initial_best_loss=float("inf")):
        """Store the model to checkpoint and the loss value to beat.

        Args:
            model: Object whose ``state_dict()`` is written to disk.
            initial_best_loss: Starting threshold; the default of ``inf`` makes
                the first evaluation always produce a checkpoint.
        """
        # Lowest dev loss recorded so far.
        self.best_dev_loss = initial_best_loss
        # Model whose weights are saved.
        self.model = model

    def save_checkpoint(self, dev_loss: float) -> None:
        """Log ``dev_loss`` and save the weights when it matches or beats the best.

        The weights are written to ``weights/models-<run name>/best.pt``
        (relative to the current working directory) and registered with
        ``wandb.save``. Requires an active wandb run.

        Args:
            dev_loss: Dev-set loss of the current model, as a plain float.
        """
        # Collect everything for this evaluation into one dict so it is
        # written with a single wandb.log call; separate calls would place
        # "best_dev_loss" and "dev_loss" on different wandb steps.
        metrics = {"dev_loss": dev_loss}

        if dev_loss <= self.best_dev_loss:
            checkpoint_dir = f"weights/models-{wandb.run.name}"
            checkpoint_path = f"{checkpoint_dir}/best.pt"

            os.makedirs(checkpoint_dir, exist_ok=True)
            # NOTE(review): the training loop in this file passes a
            # torch.compile-wrapped model; its state_dict keys presumably
            # carry an "_orig_mod." prefix - confirm the loading code
            # accounts for that.
            torch.save(self.model.state_dict(), checkpoint_path)
            wandb.save(checkpoint_path, policy="end")

            self.best_dev_loss = dev_loss
            metrics["best_dev_loss"] = dev_loss

        wandb.log(metrics)


# --- Datasets and loaders ----------------------------------------------------
# One pickle file per split, all stored directly under data_path.
train_path, dev_path, test_path = (
    os.path.join(data_path, f"{split}.pkl") for split in ("train", "dev", "test")
)

train_dataset = TraseDataset(train_path)
dev_dataset = TraseDataset(dev_path)
# test_dataset = TraseDataset(test_path)


def identity_collate(batch):
    """Return the list of samples untouched instead of stacking them."""
    return batch


# batch_size=1 combined with the identity collate means each iteration yields
# a one-element list holding the raw dataset item.
train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=identity_collate,
)
dev_loader = DataLoader(
    dev_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=identity_collate,
)
# test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Loss shared by training and evaluation.
criterion = TraseLoss(total_var_weight=TOTAL_VAR_WEIGHT)


# Hyperparameter grid: one list of candidate values per search axis.
weight_decays = [1e-3]
learning_rates = [1e-5]
num_encoders_list = [6]

# Cartesian product of the axes as (weight_decay, learning_rate, num_encoders)
# tuples; the right-most axis varies fastest.
param_grid = [
    (weight_decay, learning_rate, num_encoders)
    for weight_decay in weight_decays
    for learning_rate in learning_rates
    for num_encoders in num_encoders_list
]

# Placeholder for per-configuration results.
results = []
# One wandb run, model, optimizer, scheduler and grad scaler per grid entry.
for wd, lr, ne in tqdm(param_grid, desc="Hyperparameter search"):
  run = wandb.init(
      project="generating-imu-data-two",
      config={
          "learning_rate": lr,
          "epochs": EPOCHS,
          "weight_decay": wd,
          "d_model": D_MODEL,
          "input_embedding_dim": INPUT_EMBEDDING_DIM,
          "num_encoders": ne,
          "feed_forward_dim": FEED_FORWARD_DIM,
          # NOTE(review): logged as 0 although the DROPOUT constant is 0.1 and
          # no dropout argument is passed to Trase below - confirm which value
          # the model actually uses.
          "dropout": 0,
          "heads": HEADS,
          "total_var_weight": TOTAL_VAR_WEIGHT
      }
  )

  model = Trase(d_model=D_MODEL,
              inp_emb_dim=INPUT_EMBEDDING_DIM,
              device=device,
              num_encoders=ne,
              dim_feed_forward=FEED_FORWARD_DIM,
              heads=HEADS).to(device)

  model = torch.compile(model)

  saver = CheckpointSaver(model)

  optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

  # scheduler.step() is called once per training batch, so the cosine period
  # has to be expressed in batches. With T_max=EPOCHS the learning rate would
  # complete a half-cycle every EPOCHS batches and keep oscillating instead of
  # decaying once over the whole run.
  total_scheduler_steps = max(1, EPOCHS * len(train_loader))
  scheduler = CosineAnnealingLR(optimizer, T_max=total_scheduler_steps)

  # 1) create a GradScaler for mixed precision
  scaler = GradScaler()

  def train_model():
      """Run one pass over ``train_loader`` and return the mean per-batch loss.

      Each batch goes through a mixed-precision forward pass, a scaled
      backward pass, an optimizer step, and one scheduler step.
      """
      model.train()
      running_loss = 0.0

      for batch in train_loader:
          # The loader yields a one-element list; unwrap the single sample.
          sample = batch[0]

          inputs = sample["inputs"].to(device)
          target_acc = sample["accelerations_output"].to(device)

          # Gyroscope targets are optional; only move them when present.
          target_gyro = sample["angular_velocities_output"]
          if target_gyro is not None:
              target_gyro = target_gyro.to(device)

          mask = sample["output_mask"].T.to(device)
          weights = sample["weights"].T.repeat_interleave(3, dim=0).to(device)

          optimizer.zero_grad()

          # Forward pass and loss are computed under CUDA autocast.
          with autocast("cuda"):
              kinematics, acc_mean, acc_std, gyro_mean, gyro_std = model(inputs)
              loss = criterion(
                  kinematics=kinematics * mask * weights,
                  acc_mean=acc_mean * mask * weights,
                  acc_std=acc_std * weights,
                  real_acc=target_acc * mask * weights,
                  gyro_mean=gyro_mean,
                  gyro_std=gyro_std,
                  real_gyro=target_gyro,
                  include_gyro=(target_gyro is not None),
              )

          # Scaled backward pass, optimizer step, then scaler update.
          scaler.scale(loss).backward()
          scaler.step(optimizer)
          scaler.update()

          # The learning-rate schedule advances once per batch.
          scheduler.step()

          running_loss += loss.item()

      return running_loss / len(train_loader)



  def evaluate_model(data_loader):
      """Return the mean per-batch loss of ``model`` over ``data_loader``.

      Runs in eval mode with gradients disabled. Every tensor is moved to
      ``device`` first, mirroring ``train_model``; ``.to(device)`` is a no-op
      for tensors that already live there.

      Args:
          data_loader: Loader yielding one-element lists of sample dicts with
              the same keys that ``train_model`` reads.

      Returns:
          Sum of the batch losses divided by ``len(data_loader)``.
      """
      curr_loss = 0.0
      model.eval()

      with torch.no_grad():
          for data in data_loader:
              data = data[0]
              mocap_data = data["inputs"].to(device)
              real_acc = data["accelerations_output"].to(device)

              # Gyroscope targets are optional; only move them when present.
              real_angular_vel = data["angular_velocities_output"]
              if real_angular_vel is not None:
                  real_angular_vel = real_angular_vel.to(device)

              mask = data["output_mask"].T.to(device)
              weights = data["weights"].T.repeat_interleave(3, dim=0).to(device)

              kinematics, acc_output, acc_std, gyro_output, gyro_std = model(mocap_data)

              loss = criterion(
                  kinematics=kinematics * mask * weights,
                  acc_mean=acc_output * mask * weights,
                  acc_std=acc_std * weights,
                  real_acc=real_acc * mask * weights,
                  gyro_mean=gyro_output,
                  gyro_std=gyro_std,
                  real_gyro=real_angular_vel,
                  include_gyro=(real_angular_vel is not None),
              )

              curr_loss += loss.item()

      return curr_loss / len(data_loader)


  # Baseline evaluation followed by the epoch loop for this configuration.
  # The whole section is guarded so the wandb run is always closed, even when
  # training is interrupted or raises.
  try:
      # save_checkpoint() already logs "dev_loss", so it is not logged
      # separately here (doing both recorded the baseline value twice).
      dev_loss_value = evaluate_model(dev_loader)
      saver.save_checkpoint(dev_loss_value)

      wandb.log({"train_loss": evaluate_model(train_loader)})

      # Evaluate on the dev set every `eval_interval` epochs.
      eval_interval = 1

      progress_bar = tqdm(range(EPOCHS), desc="Training Progress", position=0, leave=True)

      for epoch in progress_bar:
          train_loss = train_model()
          wandb.log({"train_loss": train_loss})

          dev_loss_value = None
          if (epoch + 1) % eval_interval == 0:
              dev_loss_value = evaluate_model(dev_loader)
              saver.save_checkpoint(dev_loss_value)

          # Learning rate of the first parameter group, shown in the progress bar.
          current_lr = optimizer.param_groups[0]['lr']

          progress_desc = f"Epoch {epoch + 1}/{EPOCHS} | Train Loss: {train_loss:.4f} | LR: {current_lr:.12f}"
          if dev_loss_value is not None:
              progress_desc += f" | Dev loss: {dev_loss_value:.4f}"
          progress_bar.set_description(progress_desc)
  except BaseException:
      # Covers KeyboardInterrupt as well: close the run with a non-zero exit
      # code instead of leaving it open, then propagate the original error.
      run.finish(exit_code=1)
      raise
  else:
      run.finish()



Looking in indexes: https://download.pytorch.org/whl/cu118
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjmeribe[0m ([33mstanford-curis-jmeribe[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Hyperparameter search:   0%|          | 0/1 [00:00<?, ?it/s]

W0605 18:03:50.091000 2369 torch/_dynamo/variables/builtin.py:783] [0/0] incorrect arg count <bound method BuiltinVariable.call_next of BuiltinVariable(next)> too many positional arguments and no constant handler
W0605 18:03:56.041000 2369 torch/_dynamo/variables/builtin.py:783] [1/0] incorrect arg count <bound method BuiltinVariable.call_next of BuiltinVariable(next)> too many positional arguments and no constant handler
W0605 18:03:57.329000 2369 torch/_inductor/utils.py:1137] [2/0] Not enough SMs to use max_autotune_gemm mode
W0605 18:04:00.366000 2369 torch/_dynamo/variables/builtin.py:783] [0/1] incorrect arg count <bound method BuiltinVariable.call_next of BuiltinVariable(next)> too many positional arguments and no constant handler
W0605 18:04:01.692000 2369 torch/_dynamo/variables/builtin.py:783] [1/1] incorrect arg count <bound method BuiltinVariable.call_next of BuiltinVariable(next)> too many positional arguments and no constant handler
W0605 18:04:04.605000 2369 torch/_dynam

Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]

W0605 18:04:49.143000 2369 torch/_dynamo/variables/builtin.py:783] [0/2] incorrect arg count <bound method BuiltinVariable.call_next of BuiltinVariable(next)> too many positional arguments and no constant handler
W0605 18:04:55.791000 2369 torch/_dynamo/variables/builtin.py:783] [1/2] incorrect arg count <bound method BuiltinVariable.call_next of BuiltinVariable(next)> too many positional arguments and no constant handler


KeyboardInterrupt: 