In [1]:
import sys
import os

# Define the path to the modules directory
root_path = os.path.abspath("..")

# Add the path to the sys.path if it's not already there
if root_path not in sys.path:
    sys.path.append(root_path)

# Import modules from the 'modules' package
from modules.mamba import MambaModule
from modules.rnn import RNNModule
from modules.transformer import TransformerModule
from dataset import FrankaDataset

# Other imports
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import glob
from const import *
import yaml
import shutil
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def copy_and_modify_yaml(input_path, output_dir, work_dir):
    """
    Copies a YAML file, modifies the resnet_checkpoint path, and returns the new path.

    Args:
    - input_path (str): The path to the input YAML file.
    - output_dir (str): The directory where the new YAML file will be saved.
    - work_dir (str): The new directory path to replace the resnet_checkpoint path.

    Returns:
    - str: The path to the new YAML file.
    """
    # Ensure output directory exists
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Read the input YAML file
    with open(input_path, "r") as file:
        yaml_data = yaml.safe_load(file)

    # Modify the resnet_checkpoint path
    if "resnet_checkpoint" in yaml_data:
        old_checkpoint_path = yaml_data["resnet_checkpoint"]
        filename = Path(old_checkpoint_path).name
        new_checkpoint_path = Path(work_dir) / filename
        yaml_data["resnet_checkpoint"] = str(new_checkpoint_path)

    # Create the new file path
    input_filename = Path(input_path).name
    new_filename = f"temp_{input_filename}"
    new_file_path = Path(output_dir) / new_filename

    # Write the modified data to the new YAML file
    with open(new_file_path, "w") as file:
        yaml.safe_dump(yaml_data, file)

    return str(new_file_path)

In [3]:
ds = FrankaDataset(data_dir=os.path.join(WORK_DIR, "data-w-camera"),
                      episode_length=200,
                      limited_gpu_memory=True,
                      stride=1,
                      prediction_distance=1,
                      window_size=10
                      )
dl = DataLoader(ds, batch_size=1, shuffle=False)

In [4]:
item = next(iter(dl))

In [5]:
item['sensor_data'].shape

torch.Size([1, 11, 35])

In [9]:
model = RNNModule.load_from_checkpoint(
    checkpoint_path=RNN_LARGE_CHECKPOINT,
    hparams_file=copy_and_modify_yaml(RNN_LARGE_PARAMS, WORK_DIR, WORK_DIR),
)

/mnt/BigHD_1/loucas/miniconda3/envs/diss/lib/python3.11/site-packages/lightning/pytorch/utilities/migration/utils.py:56: The loaded checkpoint was produced with Lightning v2.3.3, which is newer than your current Lightning version: v2.3.0


In [7]:
import matplotlib.pyplot as plt

# Initialize lists to store target and output values
targets_sensor_6 = []
targets_sensor_7 = []
outputs_sensor_6 = []
outputs_sensor_7 = []

# Assuming dl is your dataloader and model is your trained model
for i, batch in enumerate(dl):
    if i >= 200:
        break
    target = batch["sensor_data"][0, -1:, [6, 7]]
    output = model.predict(batch)
    print(output, target)

    # Append the target values to the lists
    targets_sensor_6.append(target[0, 0].item())
    targets_sensor_7.append(target[0, 1].item())

    # Extracting values from the output tensor
    outputs_sensor_6.append(output[0, 0, 0].item())
    outputs_sensor_7.append(output[0, 0, 1].item())

# Plotting
plt.figure(figsize=(14, 7))

# Plot for Sensor 6
plt.subplot(2, 1, 1)
plt.plot(targets_sensor_6, label="Target Sensor 6")
plt.plot(outputs_sensor_6, label="Output Sensor 6")
plt.title("Sensor 6 - Target vs Output")
plt.xlabel("Batch")
plt.ylabel("Value")
plt.legend()

# Plot for Sensor 7
plt.subplot(2, 1, 2)
plt.plot(targets_sensor_7, label="Target Sensor 7")
plt.plot(outputs_sensor_7, label="Output Sensor 7")
plt.title("Sensor 7 - Target vs Output")
plt.xlabel("Batch")
plt.ylabel("Value")
plt.legend()

plt.tight_layout()
plt.show()

tensor([[[ 0.0113, -0.0148]]], device='cuda:0') tensor([[ 0.0018, -0.0248]], device='cuda:0')
tensor([[[ 0.0113, -0.0148]]], device='cuda:0') tensor([[-0.0010, -0.0134]], device='cuda:0')
tensor([[[ 0.0113, -0.0148]]], device='cuda:0') tensor([[-0.0095, -0.0090]], device='cuda:0')
tensor([[[ 0.0113, -0.0148]]], device='cuda:0') tensor([[-0.0096, -0.0233]], device='cuda:0')
tensor([[[ 0.0113, -0.0148]]], device='cuda:0') tensor([[-0.0100, -0.0373]], device='cuda:0')
tensor([[[ 0.0113, -0.0148]]], device='cuda:0') tensor([[-0.0098, -0.0489]], device='cuda:0')
tensor([[[ 0.0113, -0.0148]]], device='cuda:0') tensor([[-0.0037, -0.1625]], device='cuda:0')
tensor([[[ 0.0113, -0.0148]]], device='cuda:0') tensor([[-0.0006,  0.0122]], device='cuda:0')
tensor([[[ 0.0113, -0.0148]]], device='cuda:0') tensor([[0.4520, 0.0578]], device='cuda:0')
tensor([[[ 0.0113, -0.0148]]], device='cuda:0') tensor([[ 0.1309, -0.0677]], device='cuda:0')
tensor([[[ 0.0113, -0.0148]]], device='cuda:0') tensor([[ 0.10

KeyboardInterrupt: 