# Grasping Learning Exploratory Notebook

In [None]:
# Install necessary packages (if not already installed)
# !pip install torch torchvision torchaudio
# !pip install torch-geometric
# !pip install wandb

# Import necessary libraries
import os
import torch
import wandb
import numpy as np
import matplotlib.pyplot as plt
from datasets import GraspingDataset, create_data_loaders
from models import initialize_model
from losses import GraspingLoss
from train_grasping import train_model, validate_model  # Import training and validation functions
import torch.optim as optim
from filtering import apply_filter  # Import the new filtering module

# Set up notebook for inline plotting
%matplotlib inline

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Configuration dictionary
cfg = {
    'seed': 42,
    'batch_size': 32,
    'validation_split': 0.2,
    'learning_rate': 0.001,
    'num_epochs': 10,
    'model_type': 'TCN',  # 'TCN', 'Transformer', 'STGCN', or 'RNNWithAttention'
    'num_channels': [64, 64, 64],
    'kernel_size': 2,
    'dropout': 0.2,
    'hidden_size': 128,
    'num_layers': 2,
    'experiment_name': 'Grasping_with_TCN',
    'filter_method': 'low_pass',  # Specify filter type ('ema', 'kalman', 'low_pass', 'savitzky_golay')
    'filter_params': {
        'cutoff_freq': 5.0,  # Parameters for the low pass filter
        'fs': 50.0,
        'order': 3
    }
    # Add more configuration parameters as needed
}

# Initialize WandB
wandb.init(project="Grasping_Project", entity="your_wandb_entity",
           name=cfg['experiment_name'], config=cfg)

## Data Loading and Preprocessing

In [None]:
# Build the (optionally filtered) dataset and the train/validation loaders.

# The dataset is asked for graph-style samples only when the STGCN model is selected.
use_graph = cfg.get('model_type') == 'STGCN'

# Filter settings come straight from the config; filter_params defaults to an
# empty dict when the key is absent.
filter_method = cfg.get('filter_method')
filter_params = cfg.get('filter_params', {})

dataset = GraspingDataset(
    data_path='path_to_your_data',  # Replace with actual data path
    use_graph=use_graph,
    filter_method=filter_method,
    filter_params=filter_params,
)

# Model input/output sizes are taken from axis 2 of the stored arrays.
# NOTE(review): presumably inputs/labels are (samples, time, features) — confirm in datasets.py.
input_dim, output_dim = dataset.inputs.shape[2], dataset.labels.shape[2]

# Create data loaders using the function from datasets.py
batch_size = cfg.get('batch_size', 32)
validation_split = cfg.get('validation_split', 0.2)
train_loader, val_loader = create_data_loaders(dataset, batch_size, validation_split)


## Model Initialization

In [None]:
# Initialize model using the function from models.py.
# The architecture name is read from cfg and falls back to 'TCN' when the key
# is missing; input_dim and output_dim are the sizes read from the dataset in
# the data-loading cell, and the full cfg is passed through to the factory.
model_type = cfg.get('model_type', 'TCN')
model = initialize_model(model_type, input_dim, output_dim, cfg)
# Move the model onto the device chosen in the setup cell.
# NOTE(review): assumes `model` is a torch.nn.Module, so .to() works in place — confirm in models.py.
model.to(device)


## Training

In [None]:
# Train the model using train_model function from train_grasping.py.
# The call receives both loaders, the full cfg dict and the target device.
# NOTE(review): presumably learning_rate / num_epochs are read from cfg inside
# train_model — verify against train_grasping.py.
train_model(model, train_loader, val_loader, cfg, device)

## Testing

In [None]:
# Evaluate on test data (if you have a separate test set)
# For demonstration, we'll use the validation set as test set, so the
# "Test Loss" printed below is computed on val_loader, not on held-out data.
loss_fn = GraspingLoss()
# use_graph is forwarded so validate_model can handle STGCN-style batches.
# NOTE(review): assumes validate_model returns a single number (it is
# formatted with :.4f below) — confirm in train_grasping.py.
avg_test_loss = validate_model(model, val_loader, loss_fn, device, use_graph=use_graph)
print(f"Test Loss: {avg_test_loss:.4f}")

## Visualize Some Predictions

In [None]:
# Visualize some predictions: run the model on one validation batch and plot
# target vs. prediction for a single sample.
data_iter = iter(val_loader)
batch = next(data_iter)
if use_graph:
    # Graph batches carry a third element holding the edge indices.
    data_sample, target_sample, edge_index = batch
else:
    data_sample, target_sample = batch
data_sample, target_sample = data_sample.to(device), target_sample.to(device)

# Forward pass without gradient tracking.
model.eval()
with torch.no_grad():
    if not use_graph:
        output_sample = model(data_sample)
    else:
        # The graph model is fed the batch with axes 1 and 2 swapped and a
        # trailing singleton axis added, plus the first entry of edge_index.
        data_sample_prepared = data_sample.permute(0, 2, 1).unsqueeze(-1)
        output_sample = model(data_sample_prepared, edge_index[0].to(device))

# Convert to CPU for plotting
data_sample, target_sample, output_sample = (
    data_sample.cpu(),
    target_sample.cpu(),
    output_sample.cpu(),
)

# Plot the first sample's target vs. prediction for the first output dimension
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(target_sample[0, :, 0], label='True')
ax.plot(output_sample[0, :, 0], label='Predicted')
ax.set_xlabel('Time Step')
ax.set_ylabel('Biomechanical Output (Dimension 0)')
ax.set_title('True vs. Predicted Output')
ax.legend()
plt.show()