In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install x-transformers
!pip install matplotlib
!pip install einops
!pip install wandb

from google.colab import drive
drive.mount('/content/drive')

!cd /content/drive/MyDrive/realistic-imu/src

import wandb
wandb.login()

import csv
import matplotlib.pyplot as plt
import os
import sys
import wandb
import pickle
import re

# Add the source directory to the system path
sys.path.append('/content/drive/MyDrive/realistic-imu/src')

from trase_dataset import TraseDataset
from trase import Trase, TraseLoss


if 'ipykernel' in sys.modules:
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm


import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from torch.profiler import profile, record_function, ProfilerActivity


# Release cached CUDA allocations left over from earlier work in this kernel.
torch.cuda.empty_cache()

# Allow the faster, reduced-internal-precision float32 matmul kernels.
torch.set_float32_matmul_precision('high')

# Seed torch's random number generators so that weight initialisation,
# dropout and DataLoader shuffling are repeatable from one run to the next.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda")

# --- Optimisation hyperparameters -------------------------------------------
EPOCHS = 500
LEARNING_RATE = 1e-5
WEIGHT_DECAY = 1e-3

# --- Filesystem layout ------------------------------------------------------
data_path = "/content/drive/MyDrive/realistic-imu/data/realistic-imu-dataset/"
# Saved models live in a "models" folder inside the dataset directory.
base_path = os.path.join(data_path, "models")

# --- Model / loss hyperparameters (forwarded to Trase and TraseLoss) --------
D_MODEL = 1024
INPUT_EMBEDDING_DIM = 408
NUM_ENCODERS = 1
FEED_FORWARD_DIM = 2048
DROPOUT = 0.1
HEADS = 8
TOTAL_VAR_WEIGHT = 1e-2


# Start the Weights & Biases run and record every hyperparameter, including
# the seed, so the run can be reproduced from its logged configuration.
run = wandb.init(
    project="generating-imu-data-two",
    config={
        "learning_rate": LEARNING_RATE,
        "epochs": EPOCHS,
        "weight_decay": WEIGHT_DECAY,
        "d_model": D_MODEL,
        "input_embedding_dim": INPUT_EMBEDDING_DIM,
        "num_encoders": NUM_ENCODERS,
        "feed_forward_dim": FEED_FORWARD_DIM,
        "dropout": DROPOUT,
        "heads": HEADS,
        "total_var_weight": TOTAL_VAR_WEIGHT,
        "seed": SEED,
    }
)

In [None]:
# Collect the constructor arguments from the notebook-level hyperparameters,
# build the network, then move it onto the training device.
trase_kwargs = dict(
    d_model=D_MODEL,
    inp_emb_dim=INPUT_EMBEDDING_DIM,
    device=device,
    num_encoders=NUM_ENCODERS,
    dim_feed_forward=FEED_FORWARD_DIM,
    dropout=DROPOUT,
    heads=HEADS,
)
model = Trase(**trase_kwargs)
model = model.to(device)

In [None]:
# model = torch.compile(model)  # compilation is currently switched off


# Pickled dataset splits, all stored directly under `data_path`.
train_path, dev_path, test_path = (
    os.path.join(data_path, split_file)
    for split_file in ("train.pkl", "dev.pkl", "test.pkl")
)

train_dataset = TraseDataset(train_path)
dev_dataset = TraseDataset(dev_path)
# test_dataset = TraseDataset(test_path)


def identity_collate(batch):
    """Hand the list of samples back untouched instead of stacking them."""
    return batch


# Every loader yields one sample at a time, wrapped in a single-element list.
loader_kwargs = dict(batch_size=1, collate_fn=identity_collate)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
dev_loader = DataLoader(dev_dataset, shuffle=False, **loader_kwargs)
# test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Loss function; the total-variation term is weighted by TOTAL_VAR_WEIGHT.
criterion = TraseLoss(total_var_weight=TOTAL_VAR_WEIGHT)


# Last expression of the cell: moves the model to `device` and echoes it.
model.to(device)

In [None]:
class CheckpointSaver:
    """Keep the weights of the model with the lowest dev loss seen so far.

    Checkpoints are written to ``weights/models-<wandb run name>/best.pt``
    (relative to the current working directory) and handed to ``wandb.save``.
    """

    def __init__(self, model, initial_best_loss=float("inf")):
        """Remember the model to snapshot and the loss a new result must match or beat."""
        self.model = model
        self.best_dev_loss = initial_best_loss

    def save_checkpoint(self, dev_loss: float):
        """Log ``dev_loss`` and store a checkpoint when it is no worse than the best.

        A tie with the current best also overwrites the checkpoint. When a
        checkpoint is written, ``best_dev_loss`` is logged before ``dev_loss``.
        """
        is_new_best = dev_loss <= self.best_dev_loss
        if is_new_best:
            checkpoint_dir = f"weights/models-{wandb.run.name}"
            checkpoint_file = f"{checkpoint_dir}/best.pt"

            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(self.model.state_dict(), checkpoint_file)
            wandb.save(checkpoint_file)

            self.best_dev_loss = dev_loss
            wandb.log({"best_dev_loss": dev_loss})

        wandb.log({"dev_loss": dev_loss})


saver = CheckpointSaver(model)

In [None]:
# Gradient scaler used together with autocast for mixed-precision training.
scaler = GradScaler()

# AdamW over every model parameter.
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# Cosine learning-rate decay; T_max counts calls to scheduler.step().
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

def train_model():
    """Run one training epoch over ``train_loader`` and return the mean batch loss.

    Each batch is a single-element list (the loader uses ``batch_size=1`` with
    an identity collate), so the sample dict is unwrapped with ``batch[0]``.
    The forward pass and loss run under CUDA autocast; the backward pass and
    optimizer step go through the module-level ``GradScaler``.

    The learning-rate scheduler is advanced once per epoch, after all batches:
    it is constructed with ``T_max=EPOCHS``, so stepping it on every batch would
    finish the cosine schedule after ``EPOCHS`` batches instead of ``EPOCHS``
    epochs and then make the learning rate oscillate.

    Returns:
        float: sum of the per-batch losses divided by ``len(train_loader)``.
    """
    model.train()
    running_loss = 0.0

    for batch in train_loader:
        sample = batch[0]

        mocap_data = sample["inputs"].to(device)
        real_acc = sample["accelerations_output"].to(device)

        # Gyroscope targets are optional; the loss is told whether to use them.
        real_angular_vel = sample["angular_velocities_output"]
        if real_angular_vel is not None:
            real_angular_vel = real_angular_vel.to(device)

        mask = sample["output_mask"].T.to(device)
        # Each weight row is repeated three times along dim 0 before use.
        weights = sample["weights"].T.repeat_interleave(3, dim=0).to(device)

        optimizer.zero_grad()

        # Forward pass and loss under autocast for mixed precision.
        with autocast("cuda"):
            kinematics, acc_output, acc_std, gyro_output, gyro_std = model(mocap_data)
            loss = criterion(
                kinematics=kinematics * mask * weights,
                acc_mean=acc_output * mask * weights,
                acc_std=acc_std * weights,
                real_acc=real_acc * mask * weights,
                gyro_mean=gyro_output,
                gyro_std=gyro_std,
                real_gyro=real_angular_vel,
                include_gyro=(real_angular_vel is not None)
            )

        # Scale the loss, backpropagate, then step through the scaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    # One scheduler step per epoch, matching T_max=EPOCHS.
    scheduler.step()

    return running_loss / len(train_loader)



def evaluate_model(data_loader):
    """Compute the mean loss of ``model`` over ``data_loader`` without updating it.

    The model is switched to eval mode and gradients are disabled. Every tensor
    taken from a sample is moved to ``device`` first, mirroring ``train_model``,
    so evaluation does not depend on the dataset already returning tensors on
    the right device.

    Args:
        data_loader: loader yielding single-element lists that wrap one sample
            dict each (the same layout ``train_model`` consumes).

    Returns:
        float: sum of the per-batch losses divided by ``len(data_loader)``.
    """
    running_loss = 0.0
    model.eval()

    with torch.no_grad():
        for batch in data_loader:
            sample = batch[0]

            mocap_data = sample["inputs"].to(device)
            real_acc = sample["accelerations_output"].to(device)

            # Gyroscope targets are optional; the loss is told whether to use them.
            real_angular_vel = sample["angular_velocities_output"]
            if real_angular_vel is not None:
                real_angular_vel = real_angular_vel.to(device)

            mask = sample["output_mask"].T.to(device)
            weights = sample["weights"].T.repeat_interleave(3, dim=0).to(device)

            kinematics, acc_output, acc_std, gyro_output, gyro_std = model(mocap_data)

            loss = criterion(
                kinematics=kinematics * mask * weights,
                acc_mean=acc_output * mask * weights,
                acc_std=acc_std * weights,
                real_acc=real_acc * mask * weights,
                gyro_mean=gyro_output,
                gyro_std=gyro_std,
                real_gyro=real_angular_vel,
                include_gyro=(real_angular_vel is not None)
            )

            running_loss += loss.item()

    return running_loss / len(data_loader)

In [None]:
# Evaluate on the dev split every EVAL_EVERY epochs.
EVAL_EVERY = 1

# Baseline metrics before any training. save_checkpoint() logs "dev_loss"
# itself, so it is not logged separately here (that recorded the value twice).
dev_loss_value = evaluate_model(dev_loader)
saver.save_checkpoint(dev_loss_value)

wandb.log({"train_loss": evaluate_model(train_loader)})

progress_bar = tqdm(range(EPOCHS), desc="Training Progress", position=0, leave=True)

for epoch in progress_bar:
    train_loss = train_model()
    wandb.log({"train_loss": train_loss})

    # Dev evaluation also logs the loss and checkpoints the model when it is
    # at least as good as the best seen so far.
    if (epoch + 1) % EVAL_EVERY == 0:
        dev_loss_value = evaluate_model(dev_loader)
        saver.save_checkpoint(dev_loss_value)
    else:
        dev_loss_value = None

    # Learning rate currently set on the optimizer's first parameter group.
    current_lr = optimizer.param_groups[0]['lr']

    progress_desc = f"Epoch {epoch + 1}/{EPOCHS} | Train Loss: {train_loss:.4f} | LR: {current_lr:.12f}"
    if dev_loss_value is not None:
        progress_desc += f" | Dev loss: {dev_loss_value:.4f}"
    progress_bar.set_description(progress_desc)

In [None]:
os.makedirs(base_path, exist_ok=True)

# Saved models are named exactly "model_<N>.pkl". The pattern is applied with
# fullmatch so that names which merely start with or contain it (for example
# "model_3.pkl.bak") are not mistaken for saved models.
_model_file_re = re.compile(r"model_(\d+)\.pkl")


def _model_number(filename):
    """Return the integer N of a file named exactly ``model_N.pkl``."""
    return int(_model_file_re.fullmatch(filename).group(1))


model_files = [f for f in os.listdir(base_path) if _model_file_re.fullmatch(f)]

if model is not None:
    # Save under the next free index; numbering starts at 1 in an empty folder.
    max_num = max((_model_number(f) for f in model_files), default=0)
    new_model_name = f"model_{max_num + 1}.pkl"
    save_path = os.path.join(base_path, new_model_name)
    # NOTE(review): this pickles the whole module object, which ties the file
    # to the current class definitions; consider saving model.state_dict().
    with open(save_path, 'wb') as file:
        pickle.dump(model, file)
    print(f"Model saved to: {save_path}")
else:
    if not model_files:
        raise FileNotFoundError("No model files found in the directory.")
    # Load the highest-numbered saved model.
    latest_model_file = max(model_files, key=_model_number)
    load_path = os.path.join(base_path, latest_model_file)
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # from a trusted location.
    with open(load_path, 'rb') as file:
        model = pickle.load(file)
    print(f"Loaded model from: {load_path}")