In [1]:
import sys
sys.path.insert(0, '../')

import torch
import torch.nn as nn
from functions import get_loader, plot_histories, plot_history

from ae_functions import get_latent_dataset
from ae_models import simpleCAE

from mlp_models import simpleMLP
from mlp_functions import train_mlp, validate_mlp

In [2]:
DEVICE = ""
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    print("CUDA is available. Using GPU...")
else:
    DEVICE = torch.device("cpu")
    print("CUDA is not available. Using CPU...")

BASE_PATH = 'C:/Users/Frank/OneDrive/Bureaublad/ARC/deep-multimodal-learning/fusion'

# Define the tool names and actions
TOOL_NAMES = ['hook', 'ruler', 'spatula', 'sshot']
ACTIONS = ['left_to_right', 'pull', 'push', 'right_to_left']

# All available object names
OBJECTS = ['0_woodenCube', '1_pearToy', '2_yogurtYellowbottle', '3_cowToy', '4_tennisBallYellowGreen',
            '5_blackCoinbag', '6_lemonSodaCan', '7_peperoneGreenToy', '8_boxEgg','9_pumpkinToy',
            '10_tomatoCan', '11_boxMilk', '12_containerNuts', '13_cornCob', '14_yellowFruitToy',
            '15_bottleNailPolisher', '16_boxRealSense', '17_clampOrange', '18_greenRectangleToy', '19_ketchupToy']

sensor_color = "color"
sensor_left = "icub_left"
sensor_right = "icub_right"
sensor_depth = "depthcolormap"

BATCH_SIZE = 8
NUM_EPOCHS = 5
LR_RATE = 1e-3

train_loader = get_loader(BASE_PATH, OBJECTS, TOOL_NAMES, ACTIONS, sensor_color, "training", batch_size=BATCH_SIZE)
val_loader = get_loader(BASE_PATH, OBJECTS, TOOL_NAMES, ACTIONS, sensor_color, "validation", batch_size=BATCH_SIZE)
test_loader = get_loader(BASE_PATH, OBJECTS, TOOL_NAMES, ACTIONS, sensor_color, "testing", batch_size=BATCH_SIZE)

CUDA is available. Using GPU...


In [4]:
model_path = "C:/Users/Frank/OneDrive/Bureaublad/ARC/deep-multimodal-learning/weights_ae/"
weight_name = "simple/simple_cae_ne5_b8_color.pth"
trained_cae = simpleCAE().to(DEVICE)
trained_cae.load_state_dict(torch.load(model_path+weight_name))

# Extract features from the train and validation sets
train_dataset = get_latent_dataset(trained_cae, test_loader, label=1, add_noise=False, is_depth=False, device=DEVICE)
val_dataset = get_latent_dataset(trained_cae, val_loader, label=1, add_noise=False, is_depth=False, device=DEVICE)

input_dim = train_dataset[:][0].size(1)
output_dim = 4 

# Create DataLoaders for the extracted features
mlp_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
mlp_val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Initialize
mlp = simpleMLP(input_dim, output_dim).to(DEVICE)
mlp_optimizer = torch.optim.Adam(mlp.parameters(), lr=LR_RATE)
mlp_lossfunction = nn.CrossEntropyLoss() 

In [5]:
trained_mlp = train_mlp(mlp, mlp_lossfunction, mlp_optimizer, mlp_train_loader, NUM_EPOCHS, 1, DEVICE)
validate_mlp(trained_mlp, mlp_lossfunction, mlp_val_loader, 1, DEVICE)

Epoch [1/5], Train Loss: 5.7222
Training Accuracy: 0.3531
Training Classification Report:
              precision    recall  f1-score   support

           0       0.31      0.33      0.32       160
           1       0.39      0.38      0.38       160
           2       0.35      0.35      0.35       160
           3       0.36      0.36      0.36       160

    accuracy                           0.35       640
   macro avg       0.35      0.35      0.35       640
weighted avg       0.35      0.35      0.35       640

Epoch [2/5], Train Loss: 0.9226
Training Accuracy: 0.6484
Training Classification Report:
              precision    recall  f1-score   support

           0       0.72      0.72      0.72       160
           1       0.71      0.69      0.70       160
           2       0.58      0.57      0.58       160
           3       0.59      0.61      0.60       160

    accuracy                           0.65       640
   macro avg       0.65      0.65      0.65       640
weigh

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
