In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
import copy
import sys
from torch.utils.data import DataLoader, TensorDataset, Subset
from sklearn.model_selection import train_test_split
from embedding_layers import SkeletalInputEmbedding
from encoder_layers import TransformerEncoder
from decoder_layers import TransformerDecoder
import TF_helper_functions as hf
import pickle
import ipyvolume as ipv
import numpy as np
import ipywidgets as widgets
from IPython.display import display


In [7]:
datapath='/home/maleen/research_data/Transformers/datasets/training/'
# Base path and file information
base_name='24_07_25_training_norm'

weights_path='/home/maleen/research_data/Transformers/models/TF_tokenised/24_07_31_v1_best_model.pth'

filename=datapath+base_name+'.pkl'

def load_results_from_pickle(filename):
    with open(filename, 'rb') as f:
        return pickle.load(f)

def process_all_datasets(results, input_length=60, predict_length=60):
    all_X_pos, all_X_vel, all_X_acc = [], [], []
    all_Y_pos, all_Y_vel, all_Y_acc = [], [], []
    discarded_frames = {}

    for i in range(1, 7):  # Assuming you have 6 datasets
        dataset_key = f'dataset{i}'
        norm_pos = results[f'{dataset_key}_normpos']
        norm_vel = results[f'{dataset_key}_normvel']
        norm_acc = results[f'{dataset_key}_normacc']

        # Generate sequences for this dataset
        X_pos, X_vel, X_acc, Y_pos, Y_vel, Y_acc = hf.generate_sequences(norm_pos, norm_vel, norm_acc, input_length, predict_length)
        
        all_X_pos.append(X_pos)
        all_X_vel.append(X_vel)
        all_X_acc.append(X_acc)
        all_Y_pos.append(Y_pos)
        all_Y_vel.append(Y_vel)
        all_Y_acc.append(Y_acc)

        # Calculate discarded frames
        total_frames = norm_pos.shape[0]
        used_frames = X_pos.shape[0] + input_length + predict_length - 1
        discarded = total_frames - used_frames
        discarded_frames[dataset_key] = discarded

    # Combine sequences from all datasets
    combined_X_pos = np.concatenate(all_X_pos)
    combined_X_vel = np.concatenate(all_X_vel)
    combined_X_acc = np.concatenate(all_X_acc)
    combined_Y_pos = np.concatenate(all_Y_pos)
    combined_Y_vel = np.concatenate(all_Y_vel)
    combined_Y_acc = np.concatenate(all_Y_acc)

    return (combined_X_pos, combined_X_vel, combined_X_acc, 
            combined_Y_pos, combined_Y_vel, combined_Y_acc, 
            discarded_frames)




In [8]:
input_length = 30
predict_length = 20
datasetnum=6

# Load the results
results = load_results_from_pickle(filename)

medians_pos = results['combined_medians_pos']
iqrs_pos = results['combined_iqrs_pos']


# Process all datasets and get combined sequences
(combined_X_pos, combined_X_vel, combined_X_acc, 
 combined_Y_pos, combined_Y_vel, combined_Y_acc, 
 discarded_frames) = process_all_datasets(results, input_length, predict_length)

print("Combined sequences shapes:")
print(f"X_pos shape: {combined_X_pos.shape}")
print(f"X_vel shape: {combined_X_vel.shape}")
print(f"X_acc shape: {combined_X_acc.shape}")
print(f"Y_pos shape: {combined_Y_pos.shape}")
print(f"Y_vel shape: {combined_Y_vel.shape}")
print(f"Y_acc shape: {combined_Y_acc.shape}")

print("\nDiscarded frames per dataset:")
for dataset, frames in discarded_frames.items():
    print(f"{dataset}: {frames} frames")


Combined sequences shapes:
X_pos shape: (7559, 30, 6, 3)
X_vel shape: (7559, 30, 6, 3)
X_acc shape: (7559, 30, 6, 3)
Y_pos shape: (7559, 20, 6, 3)
Y_vel shape: (7559, 20, 6, 3)
Y_acc shape: (7559, 20, 6, 3)

Discarded frames per dataset:
dataset1: 0 frames
dataset2: 0 frames
dataset3: 0 frames
dataset4: 0 frames
dataset5: 0 frames
dataset6: 0 frames


In [9]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the saved model weights
checkpoint = torch.load(weights_path, map_location=device)
embed_dim = 128
num_heads = 8
num_layers = 6
num_joints = 6
dropout_rate = 0.1
autoregressiveloops=20
batch_size = 1
dof=3
input_dim = num_joints * dof

# Initialize the models with the same configuration as during training
embedding = SkeletalInputEmbedding(input_dim).to(device)
#t_embedding = TargetEmbedding(num_joints=num_joints, dof=3, embed_dim=embed_dim,device=device).to(device)
encoder = TransformerEncoder(embed_dim, num_heads, num_layers, dropout_rate).to(device)
decoder = TransformerDecoder(embed_dim, num_heads, num_layers, num_joints, dropout_rate).to(device)

# Load state dicts
embedding.load_state_dict(checkpoint['embedding_state_dict'])
#t_embedding.load_state_dict(checkpoint['t_embedding_state_dict'])
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])

embedding.eval()
#t_embedding.eval()
encoder.eval()
decoder.eval()

# X_pos, X_vel, X_acc, Y_pos, Y_vel, Y_acc = generate_sequences(norm_pos, norm_vel, norm_acc, input_length, predict_length)

# Convert to PyTorch tensors
X_pos_tensor = torch.tensor(combined_X_pos, dtype=torch.float32)
X_vel_tensor = torch.tensor(combined_X_vel, dtype=torch.float32)
X_acc_tensor = torch.tensor(combined_X_acc, dtype=torch.float32)


Y_pos_tensor = torch.tensor(combined_Y_pos, dtype=torch.float32)
Y_vel_tensor = torch.tensor(combined_Y_vel, dtype=torch.float32)
Y_acc_tensor = torch.tensor(combined_Y_acc, dtype=torch.float32)

# Create the DataLoader for inference data
dataset = TensorDataset(X_pos_tensor, X_vel_tensor, X_acc_tensor, Y_pos_tensor, Y_vel_tensor)
inference_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Prepare for autoregressive decoding
predicted_positions = []

criterion = hf.MaskedMSELoss()

# Perform inference across all batches
for batch in inference_loader:
    X_pos_batch, X_vel_batch, X_acc_batch, Y_pos_batch, Y_vel_tensor = [b.to(device) for b in batch]

    # Encoder pass
   
    inputembeddings = embedding(X_pos_batch, X_vel_batch)
    memory = encoder(inputembeddings, src_key_padding_mask=None)

    
    # Initialize the start token for decoding
    current_pos = X_pos_batch[:, -1:, :, :]
    current_vel= X_vel_batch[:, -1:, :, :]


    for i in range(autoregressiveloops):
        # Embed the current position
        Y_expected= Y_pos_batch[:,i:i+1,:,:]
    

        # # #Running whole model
        # if i > 0:
        #     X_mask_batch_ar = torch.cat([X_mask_batch[:, 1:, :], current_mask], dim=1)
        #     X_pos_batch_ar = torch.cat([X_pos_batch[:, 1:, :, :], current_pos], dim=1)
    
        #     src_key_padding_mask = ~X_mask_batch_ar.view(batch_size, input_length * num_joints)
        #     input_embeddings = embedding(X_pos_batch_ar, X_mask_batch_ar)
        #     memory = encoder(input_embeddings, src_key_padding_mask=src_key_padding_mask)
        # ##
        
        current_embeddings = embedding(current_pos, current_vel)
        
        # Decoder pass
        output = decoder(current_embeddings, memory, tgt_key_padding_mask=None, memory_key_padding_mask=None)
        print(output.shape)
    
        # Update current_pos for the next prediction

        old_pos= current_pos
        current_pos = output[:, :, :, :].detach()  # only take the last timestep

        #current_vel = (current_pos-old_pos)/0.1

        # velocity_error=Y_vel_tensor[:, i, :, :]-current_vel
        current_vel = Y_vel_tensor[:, i, :, :]
        #
        # print('velocity_error: ', velocity_error)
    
        predicted_positions.append(current_pos.squeeze().cpu().numpy())

        
    
        output = output.where(~torch.isnan(output), torch.zeros_like(output))
        # masked_output = output * Xmask
    
        Y_expected = Y_expected.where(~torch.isnan(Y_expected), torch.zeros_like(Y_expected))
        # masked_y_pos = Y_pos_batch * Ymask

        #print(Y_expected)
    
        # Compute loss
        
        loss = criterion(output, Y_expected)

        print(loss)

    break
# Convert the list of predicted positions to a more manageable form, e.g., a NumPy array
predicted_positions = np.array(predicted_positions)

print("Predicted Positions:", predicted_positions)


torch.Size([1, 1, 6, 3])
tensor(0.0339, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([1, 1, 6, 3])
tensor(0.0396, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([1, 1, 6, 3])
tensor(0.0135, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([1, 1, 6, 3])
tensor(0.0467, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([1, 1, 6, 3])
tensor(0.0280, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([1, 1, 6, 3])
tensor(0.0144, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([1, 1, 6, 3])
tensor(0.0545, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([1, 1, 6, 3])
tensor(0.0617, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([1, 1, 6, 3])
tensor(0.0530, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([1, 1, 6, 3])
tensor(0.0487, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([1, 1, 6, 3])
tensor(0.0139, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([1, 1, 6, 3])
tensor(0.0208, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([1, 1

In [10]:
updated_connections = [
        (0, 1), (1, 2),  # Right arm
        (3, 4), (4, 5),  # Left arm
        (2, 3),  # Connection between arms
    ]

# Animation function
def update_plot(pred):
    # First dataset processing
    data = hf.reverse_normalization(predicted_positions[pred], medians_pos, iqrs_pos)

    if len(data.shape) == 1:
        print("Data is a scalar or has unexpected shape.")
        return np.zeros((1, 3)), [], np.zeros((1, 3)), []
    else:
        valid_keypoints = ~np.isnan(data[:, :3]).any(axis=1)
        filtered_data = data[valid_keypoints]

        # Create mapping from old indices to new indices after NaN removal
        index_mapping = {old_index: new_index for new_index, old_index in enumerate(np.where(valid_keypoints)[0])}
        # Create new connections for the first dataset
        new_connections = [(index_mapping[start], index_mapping[end])
                           for start, end in updated_connections
                           if start in index_mapping and end in index_mapping]

    # Second dataset processing
    data_y = hf.reverse_normalization(Y_pos[0][pred], medians_pos, iqrs_pos)

    if len(data_y.shape) == 1:
        print("Data_y is a scalar or has unexpected shape.")
        return np.zeros((1, 3)), [], np.zeros((1, 3)), []
    else:
        valid_keypoints_y = ~np.isnan(data_y[:, :3]).any(axis=1)
        filtered_data_y = data_y[valid_keypoints_y]

        # Create mapping from old indices to new indices for the second dataset
        index_mapping_y = {old_index: new_index for new_index, old_index in enumerate(np.where(valid_keypoints_y)[0])}
        # Create new connections for the second dataset
        new_connections_y = [(index_mapping_y[start], index_mapping_y[end])
                             for start, end in updated_connections
                             if start in index_mapping_y and end in index_mapping_y]

    return filtered_data, new_connections, filtered_data_y, new_connections_y

# Plot configuration
ipv.figure()

# Initialize scatter and plot objects with some initial data
scatter = ipv.scatter([0], [0], [0], color='blue', marker='sphere', size=2)
scatter_y = ipv.scatter([0], [0], [0], color='green', marker='sphere', size=2)
lines = [ipv.plot([0, 0], [0, 0], [0, 0], color='red') for _ in range(len(updated_connections))]
lines_y = [ipv.plot([0, 0], [0, 0], [0, 0], color='lime') for _ in range(len(updated_connections))]

# Set axis limits
ipv.xlim(-1, 1)
ipv.ylim(-1, 1)
ipv.zlim(-1, 1)

def animate(pred):
    filtered_data, new_connections, filtered_data_y, new_connections_y = update_plot(pred)
    
    if len(filtered_data) == 0 or len(filtered_data_y) == 0:
        return
    
    # Update scatter data
    scatter.x = filtered_data[:, 0]
    scatter.y = filtered_data[:, 1]
    scatter.z = filtered_data[:, 2]
    
    scatter_y.x = filtered_data_y[:, 0]
    scatter_y.y = filtered_data_y[:, 1]
    scatter_y.z = filtered_data_y[:, 2]
    
    # Update lines data
    for line, (start, end) in zip(lines, new_connections):
        line.x = filtered_data[[start, end], 0]
        line.y = filtered_data[[start, end], 1]
        line.z = filtered_data[[start, end], 2]
    
    for line_y, (start, end) in zip(lines_y, new_connections_y):
        line_y.x = filtered_data_y[[start, end], 0]
        line_y.y = filtered_data_y[[start, end], 1]
        line_y.z = filtered_data_y[[start, end], 2]

# Define the range for pred
pred_range = range(len(predicted_positions))  # Adjust this range according to your data

# Create slider
slider = widgets.IntSlider(min=0, max=len(pred_range) - 1, step=1, description='Frame')

# Update plot when slider value changes
widgets.interactive(animate, pred=slider)

# Display the slider and plot
display(slider)
ipv.view(azimuth=0, elevation=-90)
ipv.show()


IntSlider(value=0, description='Frame', max=19)

Container(figure=Figure(box_center=[0.5, 0.5, 0.5], box_size=[1.0, 1.0, 1.0], camera=PerspectiveCamera(fov=45.â€¦