In [20]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class MLP(nn.Module):
    """Plain multilayer perceptron: (Linear -> ReLU) per hidden layer, then a final Linear.

    Args:
        input_size: Number of input features.
        hidden_sizes: Widths of the hidden layers, in order. May be empty, in
            which case the model is a single ``Linear(input_size, output_size)``.
        output_size: Number of output features.

    The layers are stored in ``self.layers`` (an ``nn.Sequential``), so the
    ``state_dict`` keys are ``layers.<i>.weight`` / ``layers.<i>.bias`` with
    Linear layers at even indices and ReLUs in between. Checkpoints loaded via
    ``load_state_dict`` depend on this layout.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        layers = []
        in_features = input_size
        for width in hidden_sizes:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        # With no hidden layers this degrades to Linear(input_size, output_size)
        # instead of raising IndexError on hidden_sizes[-1].
        layers.append(nn.Linear(in_features, output_size))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``; its last dimension must equal ``input_size``."""
        return self.layers(x)



class IK_Network(nn.Module):
    """Two-stage network: a trainable MLP followed by a selectable second network.

    ``mlp1`` maps the inputs to a middle state of size ``middle_state_size``;
    ``second_net`` maps that middle state to ``output_size`` outputs.

    Args:
        input_size: Input feature count of ``mlp1``.
        s1_hudden_list: Hidden widths of ``mlp1`` (misspelled name kept so
            keyword callers keep working).
        s2_hidden_list: Hidden widths of the second network.
        middle_state_size: Output size of ``mlp1`` / input size of ``second_net``.
        output_size: Output size of ``second_net``.
        model_choice: ``"MLP"`` builds a fresh, trainable MLP. ``"PreMLP"``,
            ``"ResMLP"`` and ``"Jacket-MLP"`` build the corresponding network,
            load its weights from ``second_network_path`` and freeze it.
        second_network_path: Path to a saved ``state_dict``; required for every
            choice except ``"MLP"``.

    Raises:
        ValueError: if ``model_choice`` is unknown, or a pretrained choice is
            requested without ``second_network_path``.
    """

    # Choices whose second network is loaded from disk and frozen.
    _PRETRAINED_CHOICES = ("PreMLP", "ResMLP", "Jacket-MLP")
    _ALL_CHOICES = ("MLP",) + _PRETRAINED_CHOICES

    def __init__(self, input_size, s1_hudden_list, s2_hidden_list, middle_state_size, output_size, model_choice, second_network_path=None):
        super().__init__()
        # Validate first so a typo fails loudly here instead of surfacing later
        # as a missing ``second_net`` attribute inside forward().
        if model_choice not in self._ALL_CHOICES:
            raise ValueError(
                f"Unknown model_choice {model_choice!r}; expected one of {self._ALL_CHOICES}"
            )
        pretrained = model_choice in self._PRETRAINED_CHOICES
        if pretrained and second_network_path is None:
            raise ValueError(
                f"model_choice {model_choice!r} loads pretrained weights; "
                "second_network_path must be given"
            )

        # Stage 1 (always trainable): inputs -> middle state.
        self.mlp1 = MLP(input_size, s1_hudden_list, middle_state_size)

        # Stage 2: middle state -> outputs.
        # NOTE(review): ResMLP and Jacket_MLP are not defined in this notebook;
        # presumably they arrive via a star import — confirm. They are only
        # looked up when selected, so "MLP"/"PreMLP" work without them.
        if model_choice in ("MLP", "PreMLP"):
            self.second_net = MLP(middle_state_size, s2_hidden_list, output_size)
        elif model_choice == "ResMLP":
            self.second_net = ResMLP(middle_state_size, s2_hidden_list, output_size)
        else:  # "Jacket-MLP"
            self.second_net = Jacket_MLP(middle_state_size, s2_hidden_list, output_size)

        if pretrained:
            self._load_and_freeze_second_net(second_network_path)

    def _load_and_freeze_second_net(self, path):
        """Load the state_dict at ``path`` into ``second_net`` and disable its gradients."""
        # map_location="cpu": the module was just built on CPU, so this lets a
        # checkpoint saved from a GPU run load on a CPU-only machine.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")
        self.second_net.load_state_dict(state_dict)
        self.second_net.requires_grad_(False)

    def forward(self, x):
        """Return ``(output, middle_state)`` where ``middle_state = mlp1(x)``
        and ``output = second_net(middle_state)``."""
        middle_state = self.mlp1(x)
        output = self.second_net(middle_state)
        return output, middle_state

In [3]:
from data_loading import *
from utils import *
import sys

In [24]:

# Seed torch before anything random happens (weight initialisation below and,
# presumably, any torch-based shuffling inside data_loader) so that
# Restart-&-Run-All reproduces the same numbers.
SEED = 0
torch.manual_seed(SEED)

# Data
robot_choice = "3DoF-3R"
mode_choice = "IKFK"
test_size = 0.2
batch_size = 32
IKFK_train_loader, IKFK_test_loader, pos_shape, joints_shape = data_loader(robot_choice, mode_choice, test_size, batch_size)

# Model: stage 1 maps pos_shape -> joints_shape, stage 2 maps joints_shape -> pos_shape.
input_size = pos_shape
middle_state_size = joints_shape
output_size = pos_shape
s1_hidden_list = [64, 64, 64]
s2_hidden_list = [64, 128]

# Stage-2 network: "PreMLP" loads (and freezes) the weights at second_network_path.
second_model_choice = "PreMLP"
second_network_path = './model_weights/MLP_withoutJ_1.pth'

In [25]:
# Stage-1 MLP (trainable) followed by the stage-2 network chosen by second_model_choice.
my_network = IK_Network(
    input_size=input_size,
    s1_hudden_list=s1_hidden_list,
    s2_hidden_list=s2_hidden_list,
    middle_state_size=middle_state_size,
    output_size=output_size,
    model_choice=second_model_choice,
    second_network_path=second_network_path,
)

In [26]:
learning_rate = 0.0001
num_epochs = 100
criterion = nn.MSELoss()      # training loss
test_criterion = nn.L1Loss()  # evaluation metric (mean absolute error)

# Give the optimizer only parameters that can actually be trained: with a
# pretrained second network its weights have requires_grad=False.
trainable_params = [p for p in my_network.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=learning_rate)

In [27]:
def custom_loss(pose_pred, pose_true, middle_state, robot_choice, alpha=0.5, beta=0.5):
    """Weighted sum of two pose MSE terms.

    Args:
        pose_pred: Pose predicted by the network.
        pose_true: Ground-truth pose.
        middle_state: Middle-state prediction that ``reconstruct_pose`` turns
            back into a pose (presumably joints -> pose via forward kinematics;
            TODO confirm against utils).
        robot_choice: Robot identifier forwarded to ``reconstruct_pose``.
        alpha: Weight of ``MSE(pose_pred, pose_true)``.
        beta: Weight of ``MSE(reconstruct_pose(middle_state), pose_true)``.

    Returns:
        Scalar tensor ``alpha * pred_loss + beta * rec_loss``.
    """
    pred_loss = F.mse_loss(pose_pred, pose_true)
    # Reconstruction term: a data-fit MSE on the pose rebuilt from middle_state
    # (not an L2 weight regulariser).
    rec_pose = reconstruct_pose(middle_state, robot_choice)
    rec_loss = F.mse_loss(rec_pose, pose_true)
    return alpha * pred_loss + beta * rec_loss

In [28]:
# Training loop
for epoch in range(num_epochs):
    for batch in IKFK_train_loader:
        inputs = batch['data']
        labels = batch['targets'].squeeze()
        
        # Forward pass
        outputs, middle_state = my_network(inputs)
        rec_pose = reconstruct_pose(middle_state, robot_choice)
#         reg_mse_loss = F.mse_loss(rec_pose, labels)
        
        DH_loss = criterion(outputs,rec_pose)
        pose_loss = criterion(outputs,labels)
        
        total_loss = DH_loss + pose_loss
        
#         loss = custom_loss(outputs, labels, middle_state, robot_choice, alpha=1, beta=0.5)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
#         optimizer.zero_grad()
#         DH_loss.backward()
#         optimizer.step()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss.item():.4f}')

Epoch [1/100], Loss: 0.5828
Epoch [2/100], Loss: 0.2131
Epoch [3/100], Loss: 0.0705
Epoch [4/100], Loss: 0.0384
Epoch [5/100], Loss: 0.0265
Epoch [6/100], Loss: 0.0187
Epoch [7/100], Loss: 0.0138
Epoch [8/100], Loss: 0.0110
Epoch [9/100], Loss: 0.0091
Epoch [10/100], Loss: 0.0079
Epoch [11/100], Loss: 0.0072
Epoch [12/100], Loss: 0.0066
Epoch [13/100], Loss: 0.0061
Epoch [14/100], Loss: 0.0056
Epoch [15/100], Loss: 0.0053
Epoch [16/100], Loss: 0.0049
Epoch [17/100], Loss: 0.0046
Epoch [18/100], Loss: 0.0043
Epoch [19/100], Loss: 0.0041
Epoch [20/100], Loss: 0.0039
Epoch [21/100], Loss: 0.0037
Epoch [22/100], Loss: 0.0035
Epoch [23/100], Loss: 0.0033
Epoch [24/100], Loss: 0.0032
Epoch [25/100], Loss: 0.0030
Epoch [26/100], Loss: 0.0029
Epoch [27/100], Loss: 0.0027
Epoch [28/100], Loss: 0.0025
Epoch [29/100], Loss: 0.0024
Epoch [30/100], Loss: 0.0023
Epoch [31/100], Loss: 0.0022
Epoch [32/100], Loss: 0.0021
Epoch [33/100], Loss: 0.0021
Epoch [34/100], Loss: 0.0020
Epoch [35/100], Loss: 0

In [29]:

# Testing the model
my_network.eval()
correct = 0
total = 0
mean_loss_dh = []
mean_loss_n = []
with torch.no_grad():
    for batch in IKFK_test_loader:
        inputs = batch['data']
        labels = batch['targets'].squeeze()
        outputs, middle_state = my_network(inputs)
#         _, predicted = torch.max(outputs.data, 1)
        rec_pose = reconstruct_pose(middle_state, robot_choice)
        # print(labels.shape)
        # print(rec_pose.shape)
        loss_dh = test_criterion(rec_pose, labels)
        loss_n = test_criterion(outputs, labels)
        mean_loss_dh.append(loss_dh.item())
        mean_loss_n.append(loss_n.item())
#         correct += (predicted == labels).sum().item()
mean_loss_n = np.mean(np.array(mean_loss_n))
mean_loss_dh = np.mean(np.array(mean_loss_dh))
print(f'DH Test Error: {mean_loss_dh}')
print(f'Network Test Error: {mean_loss_n}')

DH Test Error: 0.02630066125106717
Network Test Error: 0.015906292782534682
