In [1]:
# Imports
import scipy.io
import numpy as np
import torch
import torch.nn as nn

In [2]:
# Load the .mat file
data = scipy.io.loadmat('data/condsForSimJ2moMuscles.mat')

# Extract condsForSim struct
conds_for_sim = data['condsForSim']

In [3]:
# Initialize lists to store data for all conditions
go_envelope_all = []
plan_all = []
muscle_all = []

# Get the number of conditions (rows) and delay durations (columns)
num_conditions, num_delays = conds_for_sim.shape

# Loop through each condition and extract data
for i in range(num_conditions):  # 27 conditions
    go_envelope_condition = []
    plan_condition = []
    muscle_condition = []

    for j in range(num_delays):  # 8 delay durations
        condition = conds_for_sim[i, j]

        go_envelope = condition['goEnvelope']
        plan = condition['plan']
        muscle = condition['muscle']

        go_envelope_condition.append(go_envelope)
        plan_condition.append(plan)
        muscle_condition.append(muscle)

    # Stack data for each condition
    go_envelope_all.append(torch.tensor(go_envelope_condition, dtype=torch.float32))
    plan_all.append(torch.tensor(plan_condition, dtype=torch.float32))
    muscle_all.append(torch.tensor(muscle_condition, dtype=torch.float32))

# Stack data for all conditions
go_envelope_tensor = torch.stack(go_envelope_all)
plan_tensor = torch.stack(plan_all)
muscle_tensor = torch.stack(muscle_all)

# Reshape to merge the first two dimensions
go_envelope_tensor = go_envelope_tensor.reshape(-1, *go_envelope_tensor.shape[2:])
plan_tensor = plan_tensor.reshape(-1, *plan_tensor.shape[2:])
muscle_tensor = muscle_tensor.reshape(-1, *muscle_tensor.shape[2:])

# Print shapes
print(f"Go Envelope Tensor Shape: {go_envelope_tensor.shape}")
print(f"Plan Tensor Shape: {plan_tensor.shape}")
print(f"Muscle Tensor Shape: {muscle_tensor.shape}")

  go_envelope_all.append(torch.tensor(go_envelope_condition, dtype=torch.float32))


Go Envelope Tensor Shape: torch.Size([216, 296, 1])
Plan Tensor Shape: torch.Size([216, 296, 15])
Muscle Tensor Shape: torch.Size([216, 296, 8])


In [4]:
# Normalization and standardization Functions
def normalize(tensor):
    # Scale data to the range [0, 1]
    min_val = tensor.min(dim=0, keepdim=True)[0]
    max_val = tensor.max(dim=0, keepdim=True)[0]
    normalized_tensor = (tensor - min_val) / (max_val - min_val)
    return normalized_tensor

def standardize(tensor):
    # Standardize data to have mean=0 and std=1
    mean = tensor.mean(dim=0, keepdim=True)
    std = tensor.std(dim=0, keepdim=True)
    standardized_tensor = (tensor - mean) / std
    return standardized_tensor

# Apply normalization and standardization
normalized_go_envelope = normalize(go_envelope_tensor)
standardized_go_envelope = standardize(go_envelope_tensor)

normalized_plan = normalize(plan_tensor)
standardized_plan = standardize(plan_tensor)

normalized_muscle = normalize(muscle_tensor)
standardized_muscle = standardize(muscle_tensor)

# Print shapes and some sample data for verification
print(f"Normalized Go Envelope Tensor Shape: {normalized_go_envelope.shape}")
print(f"Standardized Go Envelope Tensor Shape: {standardized_go_envelope.shape}")

print(f"Normalized Plan Tensor Shape: {normalized_plan.shape}")
print(f"Standardized Plan Tensor Shape: {standardized_plan.shape}")

print(f"Normalized Muscle Tensor Shape: {normalized_muscle.shape}")
print(f"Standardized Muscle Tensor Shape: {standardized_muscle.shape}")

# Example to print some values
print("Sample normalized go envelope values:", normalized_go_envelope[0])
print("Sample standardized go envelope values:", standardized_go_envelope[0])


Normalized Go Envelope Tensor Shape: torch.Size([216, 296, 1])
Standardized Go Envelope Tensor Shape: torch.Size([216, 296, 1])
Normalized Plan Tensor Shape: torch.Size([216, 296, 15])
Standardized Plan Tensor Shape: torch.Size([216, 296, 15])
Normalized Muscle Tensor Shape: torch.Size([216, 296, 8])
Standardized Muscle Tensor Shape: torch.Size([216, 296, 8])
Sample normalized go envelope values: tensor([[nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],
        [nan],


In [5]:
# Define a RNN model

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)
        out = self.fc(out)
        return out

# Assuming the sizes from data
input_size = 16 # Calculated based on input shapes
hidden_size = 64 
output_size = 8 # Based on the output shape

model = SimpleRNN(input_size, hidden_size, output_size)

In [6]:
# Adjust the shape of go_envelope_tensor
go_envelope_tensor_adjusted = go_envelope_tensor.squeeze(-1)  # Removes the last dimension

# Check dimensions after squeezing
if go_envelope_tensor_adjusted.dim() == plan_tensor.dim() - 1:
    # Add an extra dimension to go_envelope_tensor_adjusted to match plan_tensor
    go_envelope_tensor_adjusted = go_envelope_tensor_adjusted.unsqueeze(-1)

    # Now concatenate along the last dimension
    input_tensor = torch.cat((go_envelope_tensor_adjusted, plan_tensor), dim=-1)
else:
    raise RuntimeError("Dimension mismatch after adjustment")

In [7]:
# Split data into training and testing sets
train_size = int(0.8 * input_tensor.size(0))
train_input = input_tensor[:train_size]
test_input = input_tensor[train_size:]
train_target = muscle_tensor[:train_size]
test_target = muscle_tensor[train_size:]

# Verify the sizes
print(f"Train Input Size: {train_input.size()}")
print(f"Test Input Size: {test_input.size()}")
print(f"Train Target Size: {train_target.size()}")
print(f"Test Target Size: {test_target.size()}")

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Train Input Size: torch.Size([172, 296, 16])
Test Input Size: torch.Size([44, 296, 16])
Train Target Size: torch.Size([172, 296, 8])
Test Target Size: torch.Size([44, 296, 8])


In [8]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 500

# Training loop
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    output = model(train_input)
    loss = criterion(output, train_target)
    loss.backward()
    optimizer.step()

    # Evaluation on test data
    model.eval()
    with torch.no_grad():
        test_output = model(test_input)
        test_loss = criterion(test_output, test_target)

    # Print loss every few epochs for monitoring
    if epoch % 10 == 0:
        print(f'Epoch [{epoch}/{num_epochs}], Loss: {loss.item()}, Test Loss: {test_loss.item()}')

# Save the model
torch.save(model.state_dict(), 'simple_model_checkpoint.pth')


Epoch [0/500], Loss: 0.07742229849100113, Test Loss: 0.0727684423327446
Epoch [10/500], Loss: 0.021188165992498398, Test Loss: 0.019497007131576538
Epoch [20/500], Loss: 0.019121600314974785, Test Loss: 0.018268633633852005
Epoch [30/500], Loss: 0.017821267247200012, Test Loss: 0.017074184492230415
Epoch [40/500], Loss: 0.017489111050963402, Test Loss: 0.01645728014409542
Epoch [50/500], Loss: 0.017208021134138107, Test Loss: 0.016325168311595917
Epoch [60/500], Loss: 0.01704731211066246, Test Loss: 0.015920618548989296
Epoch [70/500], Loss: 0.016925005242228508, Test Loss: 0.01593983732163906
Epoch [80/500], Loss: 0.016836464405059814, Test Loss: 0.015746857970952988
Epoch [90/500], Loss: 0.016749737784266472, Test Loss: 0.01566593162715435
Epoch [100/500], Loss: 0.016630981117486954, Test Loss: 0.01548491045832634
Epoch [110/500], Loss: 0.016384156420826912, Test Loss: 0.015110716223716736
Epoch [120/500], Loss: 0.014861536212265491, Test Loss: 0.012818137183785439
Epoch [130/500], L