In [1]:
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.models import resnet18, ResNet18_Weights
from sklearn.model_selection import KFold
import math
from torchvision.models import resnet50, ResNet50_Weights



class Frame:
    def __init__(self, data_type, room_label, sequence, file_name, color_image_path, pose):
        self.data_type = data_type
        self.room_label = room_label
        self.sequence = sequence
        self.file_name = file_name
        self.color_image_path = color_image_path
        self.pose = pose


def parse_pose_file(pose_file_path):
    with open(pose_file_path, 'r') as file:
        pose = np.array([list(map(float, line.strip().split())) for line in file]).flatten()
    return pose


def create_frame_objects(data_path, room_name, data_type):
    frames = []
    for seq_folder in os.listdir(data_path):
        seq_path = os.path.join(data_path, seq_folder)
        if os.path.isdir(seq_path):
            print(f"Processing sequence: {seq_folder} in {room_name} ({data_type})")
            for frame_file in os.listdir(seq_path):
                if frame_file.endswith('.color.png'):
                    frame_name = frame_file.split('.')[0]
                    color_image_path = os.path.join(seq_path, f"{frame_name}.color.png")
                    pose_file_path = os.path.join(seq_path, f"{frame_name}.pose.txt")
                    if os.path.exists(color_image_path) and os.path.exists(pose_file_path):
                        pose = parse_pose_file(pose_file_path)
                        frame = Frame(data_type, room_name, seq_folder, frame_name, color_image_path, pose)
                        frames.append(frame)
    return frames


def create_data_structure(data_folder):
    local_train_data = []
    local_test_data = []
    room_names = ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs']
    for room_name in room_names:
        train_path = os.path.join(data_folder, room_name, 'train')
        test_path = os.path.join(data_folder, room_name, 'test')
        local_train_data.extend(create_frame_objects(train_path, room_name, 'train'))
        local_test_data.extend(create_frame_objects(test_path, room_name, 'test'))
    return local_train_data, local_test_data


def create_data_structure_for_each_scene(data_folder, room_name):
    train_path = os.path.join(data_folder, room_name, 'train')
    test_path = os.path.join(data_folder, room_name, 'test')
    local_train_data = create_frame_objects(train_path, room_name, 'train')
    local_test_data = create_frame_objects(test_path, room_name, 'test')
    return local_train_data, local_test_data


your_path_to_data_folder = 'data'
# train_data, test_data = create_data_structure(your_path_to_data_folder)


class CustomDataset(Dataset):
    def __init__(self, frames, transform=None):
        self.frames = frames
        self.transform = transform

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        frame = self.frames[idx]
        image = Image.open(frame.color_image_path).convert('RGB')
        pose_matrix = np.array(frame.pose, dtype=np.float32).reshape(4, 4)
        translation = pose_matrix[:3, 3]
        rotation = pose_matrix[:3, :3]

        if self.transform:
            image = self.transform(image)
        return image, torch.from_numpy(translation), torch.from_numpy(rotation.flatten())


transformations = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PoseModel(nn.Module):
    def __init__(self):
        super(PoseModel, self).__init__()
        weights = ResNet50_Weights.DEFAULT
        self.backbone = resnet50(weights=weights)
        self.fc_translation = nn.Linear(self.backbone.fc.in_features, 3)
        self.fc_rotation = nn.Linear(self.backbone.fc.in_features, 9)
        self.backbone.fc = nn.Identity()

    def forward(self, x):
        features = self.backbone(x)
        translation = self.fc_translation(features)
        rotation = self.fc_rotation(features)
        return translation, rotation


# pose_model = PoseModel().to(device)
# optimizer = optim.SGD(pose_model.parameters(), lr=0.001, momentum=0.9)  # try adam optimizer
criterion = nn.MSELoss()


def rotation_matrix_to_angle_axis(rotation_matrices):
    """Convert a batch of rotation matrices to angle-axis vectors."""
    # Calculate the trace of each 3x3 rotation matrix in the batch
    traces = torch.einsum('bii->b', rotation_matrices)  # Sum over the diagonal elements in each matrix in the batch
    cos_thetas = (traces - 1) / 2.0
    cos_thetas = torch.clamp(cos_thetas, -1, 1)  # Numerical errors might make cos(theta) slightly out of its range
    thetas = torch.acos(cos_thetas)  # Angles

    # Initialize angle-axis vectors
    angle_axes = torch.zeros_like(rotation_matrices[:, :, 0])

    # Compute sin(theta) for normalization
    sin_thetas = torch.sin(thetas)

    # Find indices where theta is not too small (to avoid division by zero)
    valid = sin_thetas > 1e-5

    # For valid indices where theta is not too small, calculate angle-axis vectors
    angle_axes[valid] = torch.stack([
        rotation_matrices[valid, 2, 1] - rotation_matrices[valid, 1, 2],
        rotation_matrices[valid, 0, 2] - rotation_matrices[valid, 2, 0],
        rotation_matrices[valid, 1, 0] - rotation_matrices[valid, 0, 1]
    ], dim=1) / (2 * sin_thetas[valid].unsqueeze(1)) * thetas[valid].unsqueeze(1)

    return angle_axes


def rotation_error(pred_rot, gt_rot):
    """Calculate the angular distance between two rotation matrices."""
    pred_rot_matrix = pred_rot.view(-1, 3, 3)
    gt_rot_matrix = gt_rot.view(-1, 3, 3)
    r_diff = torch.matmul(pred_rot_matrix, gt_rot_matrix.transpose(1, 2))  # Relative rotation
    angle_axis = rotation_matrix_to_angle_axis(r_diff)
    return torch.norm(angle_axis, dim=1)  # Returns the magnitude of the angle-axis vector


def calculate_translation_error(pred, target):
    return torch.norm(pred - target, dim=1).mean()


scenes = ['chess', 'fire']

for scene in scenes:
    pose_model = PoseModel().to(device)  # model re-initializes for each scene to train separately for each scene
    optimizer = optim.SGD(pose_model.parameters(), lr=0.001, momentum=0.9)

    train_data, test_data = create_data_structure_for_each_scene(your_path_to_data_folder, scene)

    train_dataset = CustomDataset(train_data, transform=transformations)
    test_dataset = CustomDataset(test_data, transform=transformations)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    best_fold = 0
    best_model_state = None

    n_splits = 5
    num_epochs = 15
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    indices = range(len(train_loader.dataset))  # warning is disregarded, since the code works correct

    for fold, (train_idx, val_idx) in enumerate(kfold.split(indices)):
        print("-" * 100)
        print(f"FOLD {fold} for {scene}")
        print("-" * 100)

        train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
        validation_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)

        current_train_loader = torch.utils.data.DataLoader(
            train_loader.dataset, batch_size=64, sampler=train_subsampler)
        current_validation_loader = torch.utils.data.DataLoader(
            train_loader.dataset, batch_size=64, sampler=validation_subsampler)

        best_loss = np.inf
        best_epoch = -1
        for epoch in tqdm(range(num_epochs), desc=f"Epochs Progress for {scene}"):
            pose_model.train()
            total_loss = 0.0
            total_translation_error = 0.0
            total_rotation_error = 0.0

            # Training loop
            for images, translations, rotations in tqdm(current_train_loader,
                                                        desc=f"Training Epoch {epoch + 1} for {scene}", leave=False):
                images = images.to(device)
                translations = translations.to(device)
                rotations = rotations.view(-1, 3, 3).to(device)

                optimizer.zero_grad()
                trans_pred, rot_pred = pose_model(images)
                rot_pred = rot_pred.view(-1, 3, 3)

                loss_translation = criterion(trans_pred, translations)
                loss_rotation = criterion(rot_pred.view(-1, 9), rotations.view(-1, 9))
                loss = loss_translation + loss_rotation
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                translation_error = calculate_translation_error(trans_pred, translations)
                total_translation_error += translation_error.item()
                rotation_error_batch = rotation_error(rot_pred, rotations).mean().item()
                total_rotation_error += rotation_error_batch

            # Validation loop
            pose_model.eval()
            validation_loss = 0.0
            with torch.no_grad():
                for images, translations, rotations in tqdm(current_validation_loader,
                                                            desc=f"Validating Epoch {epoch + 1}",
                                                            leave=False):
                    images = images.to(device)
                    translations = translations.to(device)
                    rotations = rotations.view(-1, 3, 3).to(device)

                    trans_pred, rot_pred = pose_model(images)
                    rot_pred = rot_pred.view(-1, 3, 3)

                    loss_translation = criterion(trans_pred, translations)
                    loss_rotation = criterion(rot_pred.view(-1, 9), rotations.view(-1, 9))
                    loss = loss_translation + loss_rotation
                    validation_loss += loss.item()

            validation_loss /= len(current_validation_loader)

            if validation_loss < best_loss:
                best_loss = validation_loss
                best_model_state = pose_model.state_dict()  # save model state not the model itself
                best_fold = fold
                print()
                print(f"New best model found for {scene} in fold {fold} with validation loss {best_loss:.4f}")

            average_loss = total_loss / len(train_loader)
            average_rotation_error = total_rotation_error / len(train_loader)
            average_translation_error = total_translation_error / len(train_loader)
            print(f"For {scene} in Epoch {epoch + 1}, Validation loss: {validation_loss:.4f}")
            print(f' For {scene} in Epoch {epoch + 1}, Average Loss: {average_loss:.4f}, Average Translation Error: '
                  f'{average_translation_error:.4f}, Average Rotation Error (radians): {average_rotation_error:.4f}')

        print("-" * 50)
        print(f"Training complete for {scene} in Fold {fold}.")
        print(f"(Best Epoch for Fold {fold}: {best_epoch + 1} with Validation Loss: {best_loss:.4f}")

    # after all folds save the best model path
    if best_model_state:
        torch.save(best_model_state, f'best_pose_model_{scene}.pth')
        print(f"Best model for {scene} from Fold {best_fold} saved with loss {best_loss:.4f}")

    # To use the best model
    pose_model.load_state_dict(torch.load(f'best_pose_model_{scene}.pth'))
    pose_model.eval()

    total_translation_error = 0.0
    total_rotation_error = 0.0
    count = 0

    # No gradient needed for test part
    with torch.no_grad():
        for images, translations, rotations in test_loader:
            images = images.to(device)
            translations = translations.to(device)
            rotations = rotations.view(-1, 3, 3).to(device)  # 3x3 rotation matrices

            trans_pred, rot_pred = pose_model(images)
            rot_pred = rot_pred.view(-1, 3, 3)

            # Calculate errors
            translation_error = calculate_translation_error(trans_pred, translations)
            rotation_error_batch = rotation_error(rot_pred, rotations).mean().item()

            total_translation_error += translation_error.item()
            total_rotation_error += rotation_error_batch
            count += 1

        # Calculate average errors
        average_translation_error = total_translation_error / count
        average_rotation_error = total_rotation_error / count
        average_rotation_error_in_degrees = average_rotation_error * (180 / math.pi)

        print(f"Performance of the best model for {scene} on the test data was from Fold {best_fold}:")
        print(f"Average Translation Error for {scene}: {average_translation_error:.4f} meters")
        print(f"Average Rotation Error (degree) for {scene}: {average_rotation_error_in_degrees:.4f}")
        print("-" * 50)


Processing sequence: seq-01 in chess (train)
Processing sequence: seq-02 in chess (train)
Processing sequence: seq-04 in chess (train)
Processing sequence: seq-06 in chess (train)
Processing sequence: seq-03 in chess (test)
Processing sequence: seq-05 in chess (test)
----------------------------------------------------------------------------------------------------
FOLD 0 for chess
----------------------------------------------------------------------------------------------------


Epochs Progress for chess:   7%|▋         | 1/15 [08:16<1:55:51, 496.54s/it]


New best model found for chess in fold 0 with validation loss 0.1614
For chess in Epoch 1, Validation loss: 0.1614
 For chess in Epoch 1, Average Loss: 0.1480, Average Translation Error: 0.3020, Average Rotation Error (radians): 0.1452


Epochs Progress for chess:  13%|█▎        | 2/15 [16:33<1:47:39, 496.92s/it]


New best model found for chess in fold 0 with validation loss 0.0904
For chess in Epoch 2, Validation loss: 0.0904
 For chess in Epoch 2, Average Loss: 0.0502, Average Translation Error: 0.1594, Average Rotation Error (radians): 0.1409


Epochs Progress for chess:  20%|██        | 3/15 [24:49<1:39:19, 496.62s/it]


New best model found for chess in fold 0 with validation loss 0.0622
For chess in Epoch 3, Validation loss: 0.0622
 For chess in Epoch 3, Average Loss: 0.0306, Average Translation Error: 0.1176, Average Rotation Error (radians): 0.1086


Epochs Progress for chess:  27%|██▋       | 4/15 [33:06<1:31:01, 496.46s/it]


New best model found for chess in fold 0 with validation loss 0.0488
For chess in Epoch 4, Validation loss: 0.0488
 For chess in Epoch 4, Average Loss: 0.0225, Average Translation Error: 0.0998, Average Rotation Error (radians): 0.0880


Epochs Progress for chess:  33%|███▎      | 5/15 [41:22<1:22:45, 496.55s/it]


New best model found for chess in fold 0 with validation loss 0.0411
For chess in Epoch 5, Validation loss: 0.0411
 For chess in Epoch 5, Average Loss: 0.0176, Average Translation Error: 0.0883, Average Rotation Error (radians): 0.0732


Epochs Progress for chess:  40%|████      | 6/15 [49:40<1:14:32, 496.95s/it]


New best model found for chess in fold 0 with validation loss 0.0365
For chess in Epoch 6, Validation loss: 0.0365
 For chess in Epoch 6, Average Loss: 0.0148, Average Translation Error: 0.0805, Average Rotation Error (radians): 0.0636


Epochs Progress for chess:  47%|████▋     | 7/15 [58:03<1:06:30, 498.75s/it]


New best model found for chess in fold 0 with validation loss 0.0331
For chess in Epoch 7, Validation loss: 0.0331
 For chess in Epoch 7, Average Loss: 0.0132, Average Translation Error: 0.0757, Average Rotation Error (radians): 0.0576


Epochs Progress for chess:  53%|█████▎    | 8/15 [1:06:21<58:09, 498.52s/it]


New best model found for chess in fold 0 with validation loss 0.0306
For chess in Epoch 8, Validation loss: 0.0306
 For chess in Epoch 8, Average Loss: 0.0120, Average Translation Error: 0.0727, Average Rotation Error (radians): 0.0535


Epochs Progress for chess:  60%|██████    | 9/15 [1:14:37<49:47, 497.94s/it]


New best model found for chess in fold 0 with validation loss 0.0286
For chess in Epoch 9, Validation loss: 0.0286
 For chess in Epoch 9, Average Loss: 0.0106, Average Translation Error: 0.0679, Average Rotation Error (radians): 0.0499


Epochs Progress for chess:  67%|██████▋   | 10/15 [1:22:56<41:30, 498.13s/it]


New best model found for chess in fold 0 with validation loss 0.0266
For chess in Epoch 10, Validation loss: 0.0266
 For chess in Epoch 10, Average Loss: 0.0099, Average Translation Error: 0.0657, Average Rotation Error (radians): 0.0469


Epochs Progress for chess:  73%|███████▎  | 11/15 [1:31:14<33:12, 498.05s/it]


New best model found for chess in fold 0 with validation loss 0.0258
For chess in Epoch 11, Validation loss: 0.0258
 For chess in Epoch 11, Average Loss: 0.0094, Average Translation Error: 0.0655, Average Rotation Error (radians): 0.0441


Epochs Progress for chess:  80%|████████  | 12/15 [1:39:32<24:54, 498.26s/it]


New best model found for chess in fold 0 with validation loss 0.0240
For chess in Epoch 12, Validation loss: 0.0240
 For chess in Epoch 12, Average Loss: 0.0087, Average Translation Error: 0.0622, Average Rotation Error (radians): 0.0426


Epochs Progress for chess:  87%|████████▋ | 13/15 [1:47:49<16:35, 497.78s/it]


New best model found for chess in fold 0 with validation loss 0.0235
For chess in Epoch 13, Validation loss: 0.0235
 For chess in Epoch 13, Average Loss: 0.0082, Average Translation Error: 0.0603, Average Rotation Error (radians): 0.0407


Epochs Progress for chess:  93%|█████████▎| 14/15 [1:56:06<08:17, 497.51s/it]


New best model found for chess in fold 0 with validation loss 0.0225
For chess in Epoch 14, Validation loss: 0.0225
 For chess in Epoch 14, Average Loss: 0.0074, Average Translation Error: 0.0558, Average Rotation Error (radians): 0.0382


Epochs Progress for chess: 100%|██████████| 15/15 [2:04:24<00:00, 497.60s/it]



New best model found for chess in fold 0 with validation loss 0.0216
For chess in Epoch 15, Validation loss: 0.0216
 For chess in Epoch 15, Average Loss: 0.0073, Average Translation Error: 0.0565, Average Rotation Error (radians): 0.0379
--------------------------------------------------
Training complete for chess in Fold 0.
(Best Epoch for Fold 0: 0 with Validation Loss: 0.0216
----------------------------------------------------------------------------------------------------
FOLD 1 for chess
----------------------------------------------------------------------------------------------------


Epochs Progress for chess:   7%|▋         | 1/15 [08:27<1:58:21, 507.26s/it]


New best model found for chess in fold 1 with validation loss 0.0131
For chess in Epoch 1, Validation loss: 0.0131
 For chess in Epoch 1, Average Loss: 0.0077, Average Translation Error: 0.0595, Average Rotation Error (radians): 0.0378


Epochs Progress for chess:  13%|█▎        | 2/15 [16:42<1:48:25, 500.44s/it]


New best model found for chess in fold 1 with validation loss 0.0129
For chess in Epoch 2, Validation loss: 0.0129
 For chess in Epoch 2, Average Loss: 0.0073, Average Translation Error: 0.0579, Average Rotation Error (radians): 0.0348


Epochs Progress for chess:  20%|██        | 3/15 [25:01<1:39:53, 499.50s/it]


New best model found for chess in fold 1 with validation loss 0.0127
For chess in Epoch 3, Validation loss: 0.0127
 For chess in Epoch 3, Average Loss: 0.0069, Average Translation Error: 0.0557, Average Rotation Error (radians): 0.0354


Epochs Progress for chess:  27%|██▋       | 4/15 [33:21<1:31:35, 499.62s/it]


New best model found for chess in fold 1 with validation loss 0.0120
For chess in Epoch 4, Validation loss: 0.0120
 For chess in Epoch 4, Average Loss: 0.0065, Average Translation Error: 0.0531, Average Rotation Error (radians): 0.0340


Epochs Progress for chess:  33%|███▎      | 5/15 [41:41<1:23:17, 499.78s/it]

For chess in Epoch 5, Validation loss: 0.0121
 For chess in Epoch 5, Average Loss: 0.0067, Average Translation Error: 0.0557, Average Rotation Error (radians): 0.0335


Epochs Progress for chess:  40%|████      | 6/15 [49:59<1:14:52, 499.15s/it]


New best model found for chess in fold 1 with validation loss 0.0119
For chess in Epoch 6, Validation loss: 0.0119
 For chess in Epoch 6, Average Loss: 0.0061, Average Translation Error: 0.0523, Average Rotation Error (radians): 0.0325


Epochs Progress for chess:  47%|████▋     | 7/15 [58:19<1:06:35, 499.42s/it]


New best model found for chess in fold 1 with validation loss 0.0118
For chess in Epoch 7, Validation loss: 0.0118
 For chess in Epoch 7, Average Loss: 0.0060, Average Translation Error: 0.0520, Average Rotation Error (radians): 0.0312


Epochs Progress for chess:  53%|█████▎    | 8/15 [1:06:37<58:13, 499.04s/it]


New best model found for chess in fold 1 with validation loss 0.0113
For chess in Epoch 8, Validation loss: 0.0113
 For chess in Epoch 8, Average Loss: 0.0058, Average Translation Error: 0.0513, Average Rotation Error (radians): 0.0306


Epochs Progress for chess:  60%|██████    | 9/15 [1:14:54<49:50, 498.46s/it]


New best model found for chess in fold 1 with validation loss 0.0113
For chess in Epoch 9, Validation loss: 0.0113
 For chess in Epoch 9, Average Loss: 0.0056, Average Translation Error: 0.0507, Average Rotation Error (radians): 0.0296


Epochs Progress for chess:  67%|██████▋   | 10/15 [1:23:15<41:36, 499.38s/it]


New best model found for chess in fold 1 with validation loss 0.0109
For chess in Epoch 10, Validation loss: 0.0109
 For chess in Epoch 10, Average Loss: 0.0052, Average Translation Error: 0.0481, Average Rotation Error (radians): 0.0293


Epochs Progress for chess:  73%|███████▎  | 11/15 [1:31:34<33:16, 499.19s/it]

For chess in Epoch 11, Validation loss: 0.0113
 For chess in Epoch 11, Average Loss: 0.0054, Average Translation Error: 0.0499, Average Rotation Error (radians): 0.0290


Epochs Progress for chess:  80%|████████  | 12/15 [1:39:53<24:57, 499.13s/it]


New best model found for chess in fold 1 with validation loss 0.0107
For chess in Epoch 12, Validation loss: 0.0107
 For chess in Epoch 12, Average Loss: 0.0051, Average Translation Error: 0.0485, Average Rotation Error (radians): 0.0286


Epochs Progress for chess:  87%|████████▋ | 13/15 [1:48:12<16:37, 498.92s/it]


New best model found for chess in fold 1 with validation loss 0.0106
For chess in Epoch 13, Validation loss: 0.0106
 For chess in Epoch 13, Average Loss: 0.0047, Average Translation Error: 0.0460, Average Rotation Error (radians): 0.0280


Epochs Progress for chess:  93%|█████████▎| 14/15 [1:59:24<09:11, 551.18s/it]


New best model found for chess in fold 1 with validation loss 0.0105
For chess in Epoch 14, Validation loss: 0.0105
 For chess in Epoch 14, Average Loss: 0.0049, Average Translation Error: 0.0475, Average Rotation Error (radians): 0.0274


Epochs Progress for chess: 100%|██████████| 15/15 [2:09:55<00:00, 519.68s/it]



New best model found for chess in fold 1 with validation loss 0.0103
For chess in Epoch 15, Validation loss: 0.0103
 For chess in Epoch 15, Average Loss: 0.0048, Average Translation Error: 0.0481, Average Rotation Error (radians): 0.0270
--------------------------------------------------
Training complete for chess in Fold 1.
(Best Epoch for Fold 1: 0 with Validation Loss: 0.0103
----------------------------------------------------------------------------------------------------
FOLD 2 for chess
----------------------------------------------------------------------------------------------------


Epochs Progress for chess:   7%|▋         | 1/15 [10:59<2:33:57, 659.80s/it]


New best model found for chess in fold 2 with validation loss 0.0076
For chess in Epoch 1, Validation loss: 0.0076
 For chess in Epoch 1, Average Loss: 0.0050, Average Translation Error: 0.0492, Average Rotation Error (radians): 0.0279


Epochs Progress for chess:  13%|█▎        | 2/15 [21:43<2:20:55, 650.39s/it]


New best model found for chess in fold 2 with validation loss 0.0074
For chess in Epoch 2, Validation loss: 0.0074
 For chess in Epoch 2, Average Loss: 0.0047, Average Translation Error: 0.0463, Average Rotation Error (radians): 0.0268


Epochs Progress for chess:  20%|██        | 3/15 [32:20<2:08:48, 644.03s/it]

For chess in Epoch 3, Validation loss: 0.0074
 For chess in Epoch 3, Average Loss: 0.0046, Average Translation Error: 0.0458, Average Rotation Error (radians): 0.0265


Epochs Progress for chess:  27%|██▋       | 4/15 [43:23<1:59:30, 651.88s/it]


New best model found for chess in fold 2 with validation loss 0.0072
For chess in Epoch 4, Validation loss: 0.0072
 For chess in Epoch 4, Average Loss: 0.0045, Average Translation Error: 0.0466, Average Rotation Error (radians): 0.0252


Epochs Progress for chess:  33%|███▎      | 5/15 [53:23<1:45:31, 633.17s/it]

For chess in Epoch 5, Validation loss: 0.0072
 For chess in Epoch 5, Average Loss: 0.0045, Average Translation Error: 0.0462, Average Rotation Error (radians): 0.0256


Epochs Progress for chess:  40%|████      | 6/15 [1:03:14<1:32:48, 618.69s/it]

For chess in Epoch 6, Validation loss: 0.0075
 For chess in Epoch 6, Average Loss: 0.0045, Average Translation Error: 0.0464, Average Rotation Error (radians): 0.0257


Epochs Progress for chess:  47%|████▋     | 7/15 [1:12:56<1:20:54, 606.76s/it]

For chess in Epoch 7, Validation loss: 0.0073
 For chess in Epoch 7, Average Loss: 0.0044, Average Translation Error: 0.0458, Average Rotation Error (radians): 0.0250


Epochs Progress for chess:  53%|█████▎    | 8/15 [1:22:37<1:09:49, 598.43s/it]


New best model found for chess in fold 2 with validation loss 0.0070
For chess in Epoch 8, Validation loss: 0.0070
 For chess in Epoch 8, Average Loss: 0.0043, Average Translation Error: 0.0450, Average Rotation Error (radians): 0.0244


Epochs Progress for chess:  60%|██████    | 9/15 [1:32:16<59:14, 592.45s/it]  

For chess in Epoch 9, Validation loss: 0.0071
 For chess in Epoch 9, Average Loss: 0.0040, Average Translation Error: 0.0431, Average Rotation Error (radians): 0.0245


Epochs Progress for chess:  67%|██████▋   | 10/15 [1:41:50<48:54, 586.84s/it]

For chess in Epoch 10, Validation loss: 0.0070
 For chess in Epoch 10, Average Loss: 0.0041, Average Translation Error: 0.0447, Average Rotation Error (radians): 0.0235


Epochs Progress for chess:  73%|███████▎  | 11/15 [1:51:29<38:57, 584.26s/it]

For chess in Epoch 11, Validation loss: 0.0071
 For chess in Epoch 11, Average Loss: 0.0039, Average Translation Error: 0.0429, Average Rotation Error (radians): 0.0235


Epochs Progress for chess:  80%|████████  | 12/15 [2:00:24<28:27, 569.25s/it]


New best model found for chess in fold 2 with validation loss 0.0068
For chess in Epoch 12, Validation loss: 0.0068
 For chess in Epoch 12, Average Loss: 0.0038, Average Translation Error: 0.0425, Average Rotation Error (radians): 0.0229


Epochs Progress for chess:  87%|████████▋ | 13/15 [2:09:00<18:26, 553.34s/it]

For chess in Epoch 13, Validation loss: 0.0068
 For chess in Epoch 13, Average Loss: 0.0036, Average Translation Error: 0.0405, Average Rotation Error (radians): 0.0221


Epochs Progress for chess:  93%|█████████▎| 14/15 [2:17:37<09:02, 542.30s/it]

For chess in Epoch 14, Validation loss: 0.0070
 For chess in Epoch 14, Average Loss: 0.0036, Average Translation Error: 0.0408, Average Rotation Error (radians): 0.0226


Epochs Progress for chess: 100%|██████████| 15/15 [2:26:15<00:00, 585.03s/it]



New best model found for chess in fold 2 with validation loss 0.0067
For chess in Epoch 15, Validation loss: 0.0067
 For chess in Epoch 15, Average Loss: 0.0035, Average Translation Error: 0.0400, Average Rotation Error (radians): 0.0229
--------------------------------------------------
Training complete for chess in Fold 2.
(Best Epoch for Fold 2: 0 with Validation Loss: 0.0067
----------------------------------------------------------------------------------------------------
FOLD 3 for chess
----------------------------------------------------------------------------------------------------


Epochs Progress for chess:   7%|▋         | 1/15 [08:37<2:00:40, 517.14s/it]


New best model found for chess in fold 3 with validation loss 0.0054
For chess in Epoch 1, Validation loss: 0.0054
 For chess in Epoch 1, Average Loss: 0.0038, Average Translation Error: 0.0420, Average Rotation Error (radians): 0.0224


Epochs Progress for chess:  13%|█▎        | 2/15 [17:14<1:52:07, 517.50s/it]


New best model found for chess in fold 3 with validation loss 0.0052
For chess in Epoch 2, Validation loss: 0.0052
 For chess in Epoch 2, Average Loss: 0.0036, Average Translation Error: 0.0411, Average Rotation Error (radians): 0.0222


Epochs Progress for chess:  20%|██        | 3/15 [25:51<1:43:22, 516.91s/it]

For chess in Epoch 3, Validation loss: 0.0053
 For chess in Epoch 3, Average Loss: 0.0035, Average Translation Error: 0.0406, Average Rotation Error (radians): 0.0223


Epochs Progress for chess:  27%|██▋       | 4/15 [34:27<1:34:45, 516.84s/it]

For chess in Epoch 4, Validation loss: 0.0054
 For chess in Epoch 4, Average Loss: 0.0035, Average Translation Error: 0.0400, Average Rotation Error (radians): 0.0226


Epochs Progress for chess:  33%|███▎      | 5/15 [43:04<1:26:08, 516.87s/it]

For chess in Epoch 5, Validation loss: 0.0055
 For chess in Epoch 5, Average Loss: 0.0035, Average Translation Error: 0.0406, Average Rotation Error (radians): 0.0218


Epochs Progress for chess:  40%|████      | 6/15 [51:42<1:17:34, 517.19s/it]

For chess in Epoch 6, Validation loss: 0.0053
 For chess in Epoch 6, Average Loss: 0.0032, Average Translation Error: 0.0381, Average Rotation Error (radians): 0.0211


Epochs Progress for chess:  47%|████▋     | 7/15 [1:00:19<1:08:56, 517.11s/it]

For chess in Epoch 7, Validation loss: 0.0053
 For chess in Epoch 7, Average Loss: 0.0034, Average Translation Error: 0.0393, Average Rotation Error (radians): 0.0213


Epochs Progress for chess:  53%|█████▎    | 8/15 [1:08:57<1:00:21, 517.30s/it]

For chess in Epoch 8, Validation loss: 0.0054
 For chess in Epoch 8, Average Loss: 0.0032, Average Translation Error: 0.0385, Average Rotation Error (radians): 0.0214


Epochs Progress for chess:  60%|██████    | 9/15 [1:17:34<51:43, 517.32s/it]  


New best model found for chess in fold 3 with validation loss 0.0052
For chess in Epoch 9, Validation loss: 0.0052
 For chess in Epoch 9, Average Loss: 0.0033, Average Translation Error: 0.0388, Average Rotation Error (radians): 0.0209


Epochs Progress for chess:  67%|██████▋   | 10/15 [1:26:14<43:10, 518.05s/it]


New best model found for chess in fold 3 with validation loss 0.0052
For chess in Epoch 10, Validation loss: 0.0052
 For chess in Epoch 10, Average Loss: 0.0033, Average Translation Error: 0.0393, Average Rotation Error (radians): 0.0213


Epochs Progress for chess:  73%|███████▎  | 11/15 [1:34:52<34:32, 518.25s/it]


New best model found for chess in fold 3 with validation loss 0.0052
For chess in Epoch 11, Validation loss: 0.0052
 For chess in Epoch 11, Average Loss: 0.0032, Average Translation Error: 0.0388, Average Rotation Error (radians): 0.0212


Epochs Progress for chess:  80%|████████  | 12/15 [1:43:30<25:53, 517.95s/it]

For chess in Epoch 12, Validation loss: 0.0052
 For chess in Epoch 12, Average Loss: 0.0031, Average Translation Error: 0.0383, Average Rotation Error (radians): 0.0203


Epochs Progress for chess:  87%|████████▋ | 13/15 [1:52:05<17:14, 517.25s/it]


New best model found for chess in fold 3 with validation loss 0.0051
For chess in Epoch 13, Validation loss: 0.0051
 For chess in Epoch 13, Average Loss: 0.0030, Average Translation Error: 0.0370, Average Rotation Error (radians): 0.0198


Epochs Progress for chess:  93%|█████████▎| 14/15 [2:00:44<08:37, 517.77s/it]


New best model found for chess in fold 3 with validation loss 0.0051
For chess in Epoch 14, Validation loss: 0.0051
 For chess in Epoch 14, Average Loss: 0.0029, Average Translation Error: 0.0364, Average Rotation Error (radians): 0.0198


Epochs Progress for chess: 100%|██████████| 15/15 [2:09:23<00:00, 517.55s/it]


For chess in Epoch 15, Validation loss: 0.0052
 For chess in Epoch 15, Average Loss: 0.0030, Average Translation Error: 0.0374, Average Rotation Error (radians): 0.0209
--------------------------------------------------
Training complete for chess in Fold 3.
(Best Epoch for Fold 3: 0 with Validation Loss: 0.0051
----------------------------------------------------------------------------------------------------
FOLD 4 for chess
----------------------------------------------------------------------------------------------------


Epochs Progress for chess:   7%|▋         | 1/15 [08:37<2:00:39, 517.09s/it]


New best model found for chess in fold 4 with validation loss 0.0040
For chess in Epoch 1, Validation loss: 0.0040
 For chess in Epoch 1, Average Loss: 0.0030, Average Translation Error: 0.0368, Average Rotation Error (radians): 0.0195


Epochs Progress for chess:  13%|█▎        | 2/15 [17:14<1:52:01, 517.03s/it]


New best model found for chess in fold 4 with validation loss 0.0040
For chess in Epoch 2, Validation loss: 0.0040
 For chess in Epoch 2, Average Loss: 0.0028, Average Translation Error: 0.0354, Average Rotation Error (radians): 0.0195


Epochs Progress for chess:  20%|██        | 3/15 [25:50<1:43:22, 516.88s/it]

For chess in Epoch 3, Validation loss: 0.0040
 For chess in Epoch 3, Average Loss: 0.0030, Average Translation Error: 0.0377, Average Rotation Error (radians): 0.0195


Epochs Progress for chess:  27%|██▋       | 4/15 [34:31<1:35:00, 518.22s/it]

For chess in Epoch 4, Validation loss: 0.0040
 For chess in Epoch 4, Average Loss: 0.0029, Average Translation Error: 0.0366, Average Rotation Error (radians): 0.0188


Epochs Progress for chess:  33%|███▎      | 5/15 [43:14<1:26:39, 519.97s/it]

For chess in Epoch 5, Validation loss: 0.0041
 For chess in Epoch 5, Average Loss: 0.0029, Average Translation Error: 0.0370, Average Rotation Error (radians): 0.0192


Epochs Progress for chess:  40%|████      | 6/15 [51:49<1:17:46, 518.45s/it]


New best model found for chess in fold 4 with validation loss 0.0039
For chess in Epoch 6, Validation loss: 0.0039
 For chess in Epoch 6, Average Loss: 0.0028, Average Translation Error: 0.0361, Average Rotation Error (radians): 0.0190


Epochs Progress for chess:  47%|████▋     | 7/15 [1:00:26<1:09:03, 517.96s/it]

For chess in Epoch 7, Validation loss: 0.0042
 For chess in Epoch 7, Average Loss: 0.0028, Average Translation Error: 0.0361, Average Rotation Error (radians): 0.0194


Epochs Progress for chess:  53%|█████▎    | 8/15 [1:09:03<1:00:24, 517.78s/it]

For chess in Epoch 8, Validation loss: 0.0041
 For chess in Epoch 8, Average Loss: 0.0028, Average Translation Error: 0.0365, Average Rotation Error (radians): 0.0196


Epochs Progress for chess:  60%|██████    | 9/15 [1:17:41<51:46, 517.79s/it]  

For chess in Epoch 9, Validation loss: 0.0039
 For chess in Epoch 9, Average Loss: 0.0027, Average Translation Error: 0.0360, Average Rotation Error (radians): 0.0194


Epochs Progress for chess:  67%|██████▋   | 10/15 [1:26:19<43:09, 517.89s/it]


New best model found for chess in fold 4 with validation loss 0.0039
For chess in Epoch 10, Validation loss: 0.0039
 For chess in Epoch 10, Average Loss: 0.0027, Average Translation Error: 0.0365, Average Rotation Error (radians): 0.0188


Epochs Progress for chess:  73%|███████▎  | 11/15 [1:34:58<34:32, 518.01s/it]

For chess in Epoch 11, Validation loss: 0.0041
 For chess in Epoch 11, Average Loss: 0.0027, Average Translation Error: 0.0356, Average Rotation Error (radians): 0.0180


Epochs Progress for chess:  80%|████████  | 12/15 [1:43:37<25:55, 518.35s/it]

For chess in Epoch 12, Validation loss: 0.0043
 For chess in Epoch 12, Average Loss: 0.0027, Average Translation Error: 0.0363, Average Rotation Error (radians): 0.0185


Epochs Progress for chess:  87%|████████▋ | 13/15 [1:52:13<17:15, 517.60s/it]

For chess in Epoch 13, Validation loss: 0.0041
 For chess in Epoch 13, Average Loss: 0.0026, Average Translation Error: 0.0345, Average Rotation Error (radians): 0.0180


Epochs Progress for chess:  93%|█████████▎| 14/15 [2:00:49<08:37, 517.19s/it]


New best model found for chess in fold 4 with validation loss 0.0039
For chess in Epoch 14, Validation loss: 0.0039
 For chess in Epoch 14, Average Loss: 0.0025, Average Translation Error: 0.0339, Average Rotation Error (radians): 0.0183


Epochs Progress for chess: 100%|██████████| 15/15 [2:09:26<00:00, 517.75s/it]


For chess in Epoch 15, Validation loss: 0.0039
 For chess in Epoch 15, Average Loss: 0.0025, Average Translation Error: 0.0338, Average Rotation Error (radians): 0.0180
--------------------------------------------------
Training complete for chess in Fold 4.
(Best Epoch for Fold 4: 0 with Validation Loss: 0.0039
Best model for chess from Fold 4 saved with loss 0.0039
Performance of the best model for chess on the test data was from Fold 4:
Average Translation Error for chess: 0.2632 meters
Average Rotation Error (degree) for chess: 7.4473
--------------------------------------------------
Processing sequence: seq-01 in fire (train)
Processing sequence: seq-02 in fire (train)
Processing sequence: seq-03 in fire (test)
Processing sequence: seq-04 in fire (test)
----------------------------------------------------------------------------------------------------
FOLD 0 for fire
----------------------------------------------------------------------------------------------------


Epochs Progress for fire:   7%|▋         | 1/15 [02:07<29:49, 127.84s/it]


New best model found for fire in fold 0 with validation loss 0.3811
For fire in Epoch 1, Validation loss: 0.3811
 For fire in Epoch 1, Average Loss: 0.2931, Average Translation Error: 0.4224, Average Rotation Error (radians): 0.1180


Epochs Progress for fire:  13%|█▎        | 2/15 [04:17<27:59, 129.18s/it]


New best model found for fire in fold 0 with validation loss 0.1514
For fire in Epoch 2, Validation loss: 0.1514
 For fire in Epoch 2, Average Loss: 0.1134, Average Translation Error: 0.2534, Average Rotation Error (radians): 0.1604


Epochs Progress for fire:  20%|██        | 3/15 [06:28<25:55, 129.62s/it]


New best model found for fire in fold 0 with validation loss 0.1274
For fire in Epoch 3, Validation loss: 0.1274
 For fire in Epoch 3, Average Loss: 0.0533, Average Translation Error: 0.1623, Average Rotation Error (radians): 0.1462


Epochs Progress for fire:  27%|██▋       | 4/15 [08:38<23:48, 129.89s/it]


New best model found for fire in fold 0 with validation loss 0.0832
For fire in Epoch 4, Validation loss: 0.0832
 For fire in Epoch 4, Average Loss: 0.0368, Average Translation Error: 0.1205, Average Rotation Error (radians): 0.1148


Epochs Progress for fire:  33%|███▎      | 5/15 [10:48<21:40, 130.03s/it]


New best model found for fire in fold 0 with validation loss 0.0586
For fire in Epoch 5, Validation loss: 0.0586
 For fire in Epoch 5, Average Loss: 0.0260, Average Translation Error: 0.0990, Average Rotation Error (radians): 0.1024


Epochs Progress for fire:  40%|████      | 6/15 [12:59<19:32, 130.24s/it]


New best model found for fire in fold 0 with validation loss 0.0476
For fire in Epoch 6, Validation loss: 0.0476
 For fire in Epoch 6, Average Loss: 0.0200, Average Translation Error: 0.0866, Average Rotation Error (radians): 0.0903


Epochs Progress for fire:  47%|████▋     | 7/15 [15:10<17:23, 130.39s/it]

For fire in Epoch 7, Validation loss: 0.0477
 For fire in Epoch 7, Average Loss: 0.0159, Average Translation Error: 0.0776, Average Rotation Error (radians): 0.0776


Epochs Progress for fire:  53%|█████▎    | 8/15 [17:20<15:11, 130.26s/it]


New best model found for fire in fold 0 with validation loss 0.0340
For fire in Epoch 8, Validation loss: 0.0340
 For fire in Epoch 8, Average Loss: 0.0135, Average Translation Error: 0.0738, Average Rotation Error (radians): 0.0641


Epochs Progress for fire:  60%|██████    | 9/15 [19:30<13:02, 130.36s/it]


New best model found for fire in fold 0 with validation loss 0.0314
For fire in Epoch 9, Validation loss: 0.0314
 For fire in Epoch 9, Average Loss: 0.0114, Average Translation Error: 0.0672, Average Rotation Error (radians): 0.0574


Epochs Progress for fire:  67%|██████▋   | 10/15 [21:41<10:52, 130.44s/it]


New best model found for fire in fold 0 with validation loss 0.0295
For fire in Epoch 10, Validation loss: 0.0295
 For fire in Epoch 10, Average Loss: 0.0102, Average Translation Error: 0.0651, Average Rotation Error (radians): 0.0503


Epochs Progress for fire:  73%|███████▎  | 11/15 [23:51<08:41, 130.28s/it]


New best model found for fire in fold 0 with validation loss 0.0258
For fire in Epoch 11, Validation loss: 0.0258
 For fire in Epoch 11, Average Loss: 0.0096, Average Translation Error: 0.0642, Average Rotation Error (radians): 0.0451


Epochs Progress for fire:  80%|████████  | 12/15 [26:01<06:30, 130.30s/it]

For fire in Epoch 12, Validation loss: 0.0268
 For fire in Epoch 12, Average Loss: 0.0080, Average Translation Error: 0.0576, Average Rotation Error (radians): 0.0426


Epochs Progress for fire:  87%|████████▋ | 13/15 [28:11<04:20, 130.26s/it]


New best model found for fire in fold 0 with validation loss 0.0230
For fire in Epoch 13, Validation loss: 0.0230
 For fire in Epoch 13, Average Loss: 0.0073, Average Translation Error: 0.0548, Average Rotation Error (radians): 0.0373


Epochs Progress for fire:  93%|█████████▎| 14/15 [30:21<02:10, 130.10s/it]


New best model found for fire in fold 0 with validation loss 0.0222
For fire in Epoch 14, Validation loss: 0.0222
 For fire in Epoch 14, Average Loss: 0.0074, Average Translation Error: 0.0569, Average Rotation Error (radians): 0.0366


Epochs Progress for fire: 100%|██████████| 15/15 [32:30<00:00, 130.06s/it]



New best model found for fire in fold 0 with validation loss 0.0212
For fire in Epoch 15, Validation loss: 0.0212
 For fire in Epoch 15, Average Loss: 0.0068, Average Translation Error: 0.0538, Average Rotation Error (radians): 0.0328
--------------------------------------------------
Training complete for fire in Fold 0.
(Best Epoch for Fold 0: 0 with Validation Loss: 0.0212
----------------------------------------------------------------------------------------------------
FOLD 1 for fire
----------------------------------------------------------------------------------------------------


Epochs Progress for fire:   7%|▋         | 1/15 [02:09<30:18, 129.90s/it]


New best model found for fire in fold 1 with validation loss 0.0121
For fire in Epoch 1, Validation loss: 0.0121
 For fire in Epoch 1, Average Loss: 0.0076, Average Translation Error: 0.0597, Average Rotation Error (radians): 0.0346


Epochs Progress for fire:  13%|█▎        | 2/15 [04:19<28:09, 129.98s/it]


New best model found for fire in fold 1 with validation loss 0.0120
For fire in Epoch 2, Validation loss: 0.0120
 For fire in Epoch 2, Average Loss: 0.0070, Average Translation Error: 0.0558, Average Rotation Error (radians): 0.0348


Epochs Progress for fire:  20%|██        | 3/15 [06:29<25:59, 129.97s/it]


New best model found for fire in fold 1 with validation loss 0.0109
For fire in Epoch 3, Validation loss: 0.0109
 For fire in Epoch 3, Average Loss: 0.0059, Average Translation Error: 0.0509, Average Rotation Error (radians): 0.0328


Epochs Progress for fire:  27%|██▋       | 4/15 [08:40<23:53, 130.31s/it]


New best model found for fire in fold 1 with validation loss 0.0103
For fire in Epoch 4, Validation loss: 0.0103
 For fire in Epoch 4, Average Loss: 0.0063, Average Translation Error: 0.0537, Average Rotation Error (radians): 0.0307


Epochs Progress for fire:  33%|███▎      | 5/15 [10:51<21:43, 130.30s/it]


New best model found for fire in fold 1 with validation loss 0.0102
For fire in Epoch 5, Validation loss: 0.0102
 For fire in Epoch 5, Average Loss: 0.0059, Average Translation Error: 0.0511, Average Rotation Error (radians): 0.0323


Epochs Progress for fire:  40%|████      | 6/15 [13:01<19:33, 130.38s/it]

For fire in Epoch 6, Validation loss: 0.0104
 For fire in Epoch 6, Average Loss: 0.0058, Average Translation Error: 0.0510, Average Rotation Error (radians): 0.0304


Epochs Progress for fire:  47%|████▋     | 7/15 [15:12<17:23, 130.43s/it]

For fire in Epoch 7, Validation loss: 0.0105
 For fire in Epoch 7, Average Loss: 0.0052, Average Translation Error: 0.0470, Average Rotation Error (radians): 0.0264


Epochs Progress for fire:  53%|█████▎    | 8/15 [17:22<15:13, 130.50s/it]

For fire in Epoch 8, Validation loss: 0.0104
 For fire in Epoch 8, Average Loss: 0.0057, Average Translation Error: 0.0525, Average Rotation Error (radians): 0.0281


Epochs Progress for fire:  60%|██████    | 9/15 [19:32<13:02, 130.38s/it]


New best model found for fire in fold 1 with validation loss 0.0097
For fire in Epoch 9, Validation loss: 0.0097
 For fire in Epoch 9, Average Loss: 0.0050, Average Translation Error: 0.0479, Average Rotation Error (radians): 0.0251


Epochs Progress for fire:  67%|██████▋   | 10/15 [21:43<10:52, 130.45s/it]

For fire in Epoch 10, Validation loss: 0.0099
 For fire in Epoch 10, Average Loss: 0.0044, Average Translation Error: 0.0439, Average Rotation Error (radians): 0.0254


Epochs Progress for fire:  73%|███████▎  | 11/15 [23:53<08:41, 130.38s/it]


New best model found for fire in fold 1 with validation loss 0.0092
For fire in Epoch 11, Validation loss: 0.0092
 For fire in Epoch 11, Average Loss: 0.0055, Average Translation Error: 0.0517, Average Rotation Error (radians): 0.0263


Epochs Progress for fire:  80%|████████  | 12/15 [26:03<06:30, 130.29s/it]


New best model found for fire in fold 1 with validation loss 0.0088
For fire in Epoch 12, Validation loss: 0.0088
 For fire in Epoch 12, Average Loss: 0.0049, Average Translation Error: 0.0488, Average Rotation Error (radians): 0.0252


Epochs Progress for fire:  87%|████████▋ | 13/15 [28:14<04:20, 130.36s/it]


New best model found for fire in fold 1 with validation loss 0.0088
For fire in Epoch 13, Validation loss: 0.0088
 For fire in Epoch 13, Average Loss: 0.0042, Average Translation Error: 0.0442, Average Rotation Error (radians): 0.0226


Epochs Progress for fire:  93%|█████████▎| 14/15 [30:24<02:10, 130.33s/it]

For fire in Epoch 14, Validation loss: 0.0096
 For fire in Epoch 14, Average Loss: 0.0046, Average Translation Error: 0.0476, Average Rotation Error (radians): 0.0234


Epochs Progress for fire: 100%|██████████| 15/15 [32:35<00:00, 130.37s/it]


For fire in Epoch 15, Validation loss: 0.0091
 For fire in Epoch 15, Average Loss: 0.0043, Average Translation Error: 0.0445, Average Rotation Error (radians): 0.0235
--------------------------------------------------
Training complete for fire in Fold 1.
(Best Epoch for Fold 1: 0 with Validation Loss: 0.0088
----------------------------------------------------------------------------------------------------
FOLD 2 for fire
----------------------------------------------------------------------------------------------------


Epochs Progress for fire:   7%|▋         | 1/15 [02:09<30:16, 129.74s/it]


New best model found for fire in fold 2 with validation loss 0.0065
For fire in Epoch 1, Validation loss: 0.0065
 For fire in Epoch 1, Average Loss: 0.0041, Average Translation Error: 0.0441, Average Rotation Error (radians): 0.0224


Epochs Progress for fire:  13%|█▎        | 2/15 [04:19<28:08, 129.85s/it]


New best model found for fire in fold 2 with validation loss 0.0059
For fire in Epoch 2, Validation loss: 0.0059
 For fire in Epoch 2, Average Loss: 0.0045, Average Translation Error: 0.0469, Average Rotation Error (radians): 0.0217


Epochs Progress for fire:  20%|██        | 3/15 [06:30<26:01, 130.15s/it]

For fire in Epoch 3, Validation loss: 0.0064
 For fire in Epoch 3, Average Loss: 0.0041, Average Translation Error: 0.0449, Average Rotation Error (radians): 0.0238


Epochs Progress for fire:  27%|██▋       | 4/15 [08:40<23:52, 130.24s/it]

For fire in Epoch 4, Validation loss: 0.0064
 For fire in Epoch 4, Average Loss: 0.0038, Average Translation Error: 0.0420, Average Rotation Error (radians): 0.0229


Epochs Progress for fire:  33%|███▎      | 5/15 [10:51<21:43, 130.32s/it]

For fire in Epoch 5, Validation loss: 0.0060
 For fire in Epoch 5, Average Loss: 0.0036, Average Translation Error: 0.0412, Average Rotation Error (radians): 0.0218


Epochs Progress for fire:  40%|████      | 6/15 [13:01<19:32, 130.26s/it]

For fire in Epoch 6, Validation loss: 0.0063
 For fire in Epoch 6, Average Loss: 0.0037, Average Translation Error: 0.0418, Average Rotation Error (radians): 0.0195


Epochs Progress for fire:  47%|████▋     | 7/15 [15:11<17:22, 130.31s/it]

For fire in Epoch 7, Validation loss: 0.0065
 For fire in Epoch 7, Average Loss: 0.0039, Average Translation Error: 0.0442, Average Rotation Error (radians): 0.0203


Epochs Progress for fire:  53%|█████▎    | 8/15 [17:22<15:13, 130.48s/it]

For fire in Epoch 8, Validation loss: 0.0062
 For fire in Epoch 8, Average Loss: 0.0037, Average Translation Error: 0.0416, Average Rotation Error (radians): 0.0222


Epochs Progress for fire:  60%|██████    | 9/15 [19:32<13:02, 130.48s/it]

For fire in Epoch 9, Validation loss: 0.0062
 For fire in Epoch 9, Average Loss: 0.0040, Average Translation Error: 0.0451, Average Rotation Error (radians): 0.0206


Epochs Progress for fire:  67%|██████▋   | 10/15 [21:43<10:52, 130.53s/it]

For fire in Epoch 10, Validation loss: 0.0064
 For fire in Epoch 10, Average Loss: 0.0039, Average Translation Error: 0.0440, Average Rotation Error (radians): 0.0218


Epochs Progress for fire:  73%|███████▎  | 11/15 [23:53<08:41, 130.37s/it]


New best model found for fire in fold 2 with validation loss 0.0058
For fire in Epoch 11, Validation loss: 0.0058
 For fire in Epoch 11, Average Loss: 0.0035, Average Translation Error: 0.0407, Average Rotation Error (radians): 0.0201


Epochs Progress for fire:  80%|████████  | 12/15 [26:03<06:30, 130.24s/it]


New best model found for fire in fold 2 with validation loss 0.0055
For fire in Epoch 12, Validation loss: 0.0055
 For fire in Epoch 12, Average Loss: 0.0031, Average Translation Error: 0.0377, Average Rotation Error (radians): 0.0193


Epochs Progress for fire:  87%|████████▋ | 13/15 [28:13<04:20, 130.22s/it]

For fire in Epoch 13, Validation loss: 0.0060
 For fire in Epoch 13, Average Loss: 0.0031, Average Translation Error: 0.0379, Average Rotation Error (radians): 0.0177


Epochs Progress for fire:  93%|█████████▎| 14/15 [30:24<02:10, 130.30s/it]

For fire in Epoch 14, Validation loss: 0.0058
 For fire in Epoch 14, Average Loss: 0.0035, Average Translation Error: 0.0411, Average Rotation Error (radians): 0.0208


Epochs Progress for fire: 100%|██████████| 15/15 [32:34<00:00, 130.28s/it]


For fire in Epoch 15, Validation loss: 0.0063
 For fire in Epoch 15, Average Loss: 0.0032, Average Translation Error: 0.0385, Average Rotation Error (radians): 0.0196
--------------------------------------------------
Training complete for fire in Fold 2.
(Best Epoch for Fold 2: 0 with Validation Loss: 0.0055
----------------------------------------------------------------------------------------------------
FOLD 3 for fire
----------------------------------------------------------------------------------------------------


Epochs Progress for fire:   7%|▋         | 1/15 [02:10<30:25, 130.40s/it]


New best model found for fire in fold 3 with validation loss 0.0046
For fire in Epoch 1, Validation loss: 0.0046
 For fire in Epoch 1, Average Loss: 0.0031, Average Translation Error: 0.0382, Average Rotation Error (radians): 0.0186


Epochs Progress for fire:  13%|█▎        | 2/15 [04:20<28:14, 130.32s/it]


New best model found for fire in fold 3 with validation loss 0.0045
For fire in Epoch 2, Validation loss: 0.0045
 For fire in Epoch 2, Average Loss: 0.0033, Average Translation Error: 0.0397, Average Rotation Error (radians): 0.0184


Epochs Progress for fire:  20%|██        | 3/15 [06:30<26:03, 130.30s/it]


New best model found for fire in fold 3 with validation loss 0.0042
For fire in Epoch 3, Validation loss: 0.0042
 For fire in Epoch 3, Average Loss: 0.0031, Average Translation Error: 0.0387, Average Rotation Error (radians): 0.0175


Epochs Progress for fire:  27%|██▋       | 4/15 [08:41<23:53, 130.27s/it]

For fire in Epoch 4, Validation loss: 0.0044
 For fire in Epoch 4, Average Loss: 0.0031, Average Translation Error: 0.0372, Average Rotation Error (radians): 0.0195


Epochs Progress for fire:  33%|███▎      | 5/15 [10:50<21:41, 130.10s/it]

For fire in Epoch 5, Validation loss: 0.0045
 For fire in Epoch 5, Average Loss: 0.0028, Average Translation Error: 0.0360, Average Rotation Error (radians): 0.0159


Epochs Progress for fire:  40%|████      | 6/15 [13:01<19:31, 130.21s/it]

For fire in Epoch 6, Validation loss: 0.0044
 For fire in Epoch 6, Average Loss: 0.0031, Average Translation Error: 0.0390, Average Rotation Error (radians): 0.0176


Epochs Progress for fire:  47%|████▋     | 7/15 [15:11<17:21, 130.21s/it]

For fire in Epoch 7, Validation loss: 0.0044
 For fire in Epoch 7, Average Loss: 0.0027, Average Translation Error: 0.0360, Average Rotation Error (radians): 0.0164


Epochs Progress for fire:  53%|█████▎    | 8/15 [17:22<15:12, 130.42s/it]

For fire in Epoch 8, Validation loss: 0.0046
 For fire in Epoch 8, Average Loss: 0.0031, Average Translation Error: 0.0393, Average Rotation Error (radians): 0.0178


Epochs Progress for fire:  60%|██████    | 9/15 [19:32<13:01, 130.21s/it]

For fire in Epoch 9, Validation loss: 0.0042
 For fire in Epoch 9, Average Loss: 0.0029, Average Translation Error: 0.0369, Average Rotation Error (radians): 0.0189


Epochs Progress for fire:  67%|██████▋   | 10/15 [21:42<10:51, 130.29s/it]

For fire in Epoch 10, Validation loss: 0.0045
 For fire in Epoch 10, Average Loss: 0.0026, Average Translation Error: 0.0343, Average Rotation Error (radians): 0.0166


Epochs Progress for fire:  73%|███████▎  | 11/15 [23:52<08:41, 130.29s/it]

For fire in Epoch 11, Validation loss: 0.0044
 For fire in Epoch 11, Average Loss: 0.0023, Average Translation Error: 0.0314, Average Rotation Error (radians): 0.0162


Epochs Progress for fire:  80%|████████  | 12/15 [26:03<06:31, 130.35s/it]

For fire in Epoch 12, Validation loss: 0.0042
 For fire in Epoch 12, Average Loss: 0.0025, Average Translation Error: 0.0329, Average Rotation Error (radians): 0.0194


Epochs Progress for fire:  87%|████████▋ | 13/15 [28:14<04:20, 130.49s/it]

For fire in Epoch 13, Validation loss: 0.0043
 For fire in Epoch 13, Average Loss: 0.0031, Average Translation Error: 0.0391, Average Rotation Error (radians): 0.0149


Epochs Progress for fire:  93%|█████████▎| 14/15 [30:24<02:10, 130.53s/it]


New best model found for fire in fold 3 with validation loss 0.0042
For fire in Epoch 14, Validation loss: 0.0042
 For fire in Epoch 14, Average Loss: 0.0026, Average Translation Error: 0.0359, Average Rotation Error (radians): 0.0163


Epochs Progress for fire: 100%|██████████| 15/15 [32:35<00:00, 130.35s/it]


For fire in Epoch 15, Validation loss: 0.0044
 For fire in Epoch 15, Average Loss: 0.0024, Average Translation Error: 0.0327, Average Rotation Error (radians): 0.0169
--------------------------------------------------
Training complete for fire in Fold 3.
(Best Epoch for Fold 3: 0 with Validation Loss: 0.0042
----------------------------------------------------------------------------------------------------
FOLD 4 for fire
----------------------------------------------------------------------------------------------------


Epochs Progress for fire:   7%|▋         | 1/15 [02:10<30:24, 130.33s/it]


New best model found for fire in fold 4 with validation loss 0.0032
For fire in Epoch 1, Validation loss: 0.0032
 For fire in Epoch 1, Average Loss: 0.0026, Average Translation Error: 0.0355, Average Rotation Error (radians): 0.0155


Epochs Progress for fire:  13%|█▎        | 2/15 [04:21<28:17, 130.59s/it]


New best model found for fire in fold 4 with validation loss 0.0031
For fire in Epoch 2, Validation loss: 0.0031
 For fire in Epoch 2, Average Loss: 0.0023, Average Translation Error: 0.0322, Average Rotation Error (radians): 0.0152


Epochs Progress for fire:  20%|██        | 3/15 [06:31<26:05, 130.45s/it]

For fire in Epoch 3, Validation loss: 0.0032
 For fire in Epoch 3, Average Loss: 0.0024, Average Translation Error: 0.0338, Average Rotation Error (radians): 0.0160


Epochs Progress for fire:  27%|██▋       | 4/15 [08:42<23:58, 130.75s/it]


New best model found for fire in fold 4 with validation loss 0.0031
For fire in Epoch 4, Validation loss: 0.0031
 For fire in Epoch 4, Average Loss: 0.0023, Average Translation Error: 0.0309, Average Rotation Error (radians): 0.0141


Epochs Progress for fire:  33%|███▎      | 5/15 [10:53<21:48, 130.84s/it]

For fire in Epoch 5, Validation loss: 0.0032
 For fire in Epoch 5, Average Loss: 0.0028, Average Translation Error: 0.0365, Average Rotation Error (radians): 0.0161


Epochs Progress for fire:  40%|████      | 6/15 [13:05<19:39, 131.05s/it]

For fire in Epoch 6, Validation loss: 0.0032
 For fire in Epoch 6, Average Loss: 0.0024, Average Translation Error: 0.0327, Average Rotation Error (radians): 0.0170


Epochs Progress for fire:  47%|████▋     | 7/15 [15:16<17:29, 131.25s/it]

For fire in Epoch 7, Validation loss: 0.0032
 For fire in Epoch 7, Average Loss: 0.0022, Average Translation Error: 0.0304, Average Rotation Error (radians): 0.0150


Epochs Progress for fire:  53%|█████▎    | 8/15 [17:28<15:19, 131.37s/it]


New best model found for fire in fold 4 with validation loss 0.0031
For fire in Epoch 8, Validation loss: 0.0031
 For fire in Epoch 8, Average Loss: 0.0023, Average Translation Error: 0.0329, Average Rotation Error (radians): 0.0153


Epochs Progress for fire:  60%|██████    | 9/15 [19:40<13:09, 131.57s/it]

For fire in Epoch 9, Validation loss: 0.0031
 For fire in Epoch 9, Average Loss: 0.0021, Average Translation Error: 0.0311, Average Rotation Error (radians): 0.0133


Epochs Progress for fire:  67%|██████▋   | 10/15 [21:50<10:55, 131.02s/it]

For fire in Epoch 10, Validation loss: 0.0032
 For fire in Epoch 10, Average Loss: 0.0019, Average Translation Error: 0.0279, Average Rotation Error (radians): 0.0156


Epochs Progress for fire:  73%|███████▎  | 11/15 [24:01<08:45, 131.25s/it]

For fire in Epoch 11, Validation loss: 0.0031
 For fire in Epoch 11, Average Loss: 0.0022, Average Translation Error: 0.0314, Average Rotation Error (radians): 0.0156


Epochs Progress for fire:  80%|████████  | 12/15 [26:14<06:34, 131.62s/it]

For fire in Epoch 12, Validation loss: 0.0032
 For fire in Epoch 12, Average Loss: 0.0023, Average Translation Error: 0.0328, Average Rotation Error (radians): 0.0155


Epochs Progress for fire:  87%|████████▋ | 13/15 [28:25<04:23, 131.52s/it]

For fire in Epoch 13, Validation loss: 0.0032
 For fire in Epoch 13, Average Loss: 0.0027, Average Translation Error: 0.0350, Average Rotation Error (radians): 0.0162


Epochs Progress for fire:  93%|█████████▎| 14/15 [30:36<02:11, 131.32s/it]


New best model found for fire in fold 4 with validation loss 0.0030
For fire in Epoch 14, Validation loss: 0.0030
 For fire in Epoch 14, Average Loss: 0.0025, Average Translation Error: 0.0350, Average Rotation Error (radians): 0.0171


Epochs Progress for fire: 100%|██████████| 15/15 [32:47<00:00, 131.17s/it]


For fire in Epoch 15, Validation loss: 0.0033
 For fire in Epoch 15, Average Loss: 0.0025, Average Translation Error: 0.0350, Average Rotation Error (radians): 0.0148
--------------------------------------------------
Training complete for fire in Fold 4.
(Best Epoch for Fold 4: 0 with Validation Loss: 0.0030
Best model for fire from Fold 4 saved with loss 0.0030
Performance of the best model for fire on the test data was from Fold 4:
Average Translation Error for fire: 0.3868 meters
Average Rotation Error (degree) for fire: 18.5901
--------------------------------------------------


In [1]:
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.models import resnet18, ResNet18_Weights
from sklearn.model_selection import KFold
import math
from torchvision.models import resnet50, ResNet50_Weights



class Frame:
    def __init__(self, data_type, room_label, sequence, file_name, color_image_path, pose):
        self.data_type = data_type
        self.room_label = room_label
        self.sequence = sequence
        self.file_name = file_name
        self.color_image_path = color_image_path
        self.pose = pose


def parse_pose_file(pose_file_path):
    with open(pose_file_path, 'r') as file:
        pose = np.array([list(map(float, line.strip().split())) for line in file]).flatten()
    return pose


def create_frame_objects(data_path, room_name, data_type):
    frames = []
    for seq_folder in os.listdir(data_path):
        seq_path = os.path.join(data_path, seq_folder)
        if os.path.isdir(seq_path):
            print(f"Processing sequence: {seq_folder} in {room_name} ({data_type})")
            for frame_file in os.listdir(seq_path):
                if frame_file.endswith('.color.png'):
                    frame_name = frame_file.split('.')[0]
                    color_image_path = os.path.join(seq_path, f"{frame_name}.color.png")
                    pose_file_path = os.path.join(seq_path, f"{frame_name}.pose.txt")
                    if os.path.exists(color_image_path) and os.path.exists(pose_file_path):
                        pose = parse_pose_file(pose_file_path)
                        frame = Frame(data_type, room_name, seq_folder, frame_name, color_image_path, pose)
                        frames.append(frame)
    return frames


def create_data_structure(data_folder):
    local_train_data = []
    local_test_data = []
    room_names = ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs']
    for room_name in room_names:
        train_path = os.path.join(data_folder, room_name, 'train')
        test_path = os.path.join(data_folder, room_name, 'test')
        local_train_data.extend(create_frame_objects(train_path, room_name, 'train'))
        local_test_data.extend(create_frame_objects(test_path, room_name, 'test'))
    return local_train_data, local_test_data


def create_data_structure_for_each_scene(data_folder, room_name):
    train_path = os.path.join(data_folder, room_name, 'train')
    test_path = os.path.join(data_folder, room_name, 'test')
    local_train_data = create_frame_objects(train_path, room_name, 'train')
    local_test_data = create_frame_objects(test_path, room_name, 'test')
    return local_train_data, local_test_data


your_path_to_data_folder = 'data'
# train_data, test_data = create_data_structure(your_path_to_data_folder)


class CustomDataset(Dataset):
    def __init__(self, frames, transform=None):
        self.frames = frames
        self.transform = transform

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        frame = self.frames[idx]
        image = Image.open(frame.color_image_path).convert('RGB')
        pose_matrix = np.array(frame.pose, dtype=np.float32).reshape(4, 4)
        translation = pose_matrix[:3, 3]
        rotation = pose_matrix[:3, :3]

        if self.transform:
            image = self.transform(image)
        return image, torch.from_numpy(translation), torch.from_numpy(rotation.flatten())


transformations = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PoseModel(nn.Module):
    def __init__(self):
        super(PoseModel, self).__init__()
        weights = ResNet50_Weights.DEFAULT
        self.backbone = resnet50(weights=weights)
        self.fc_translation = nn.Linear(self.backbone.fc.in_features, 3)
        self.fc_rotation = nn.Linear(self.backbone.fc.in_features, 9)
        self.backbone.fc = nn.Identity()

    def forward(self, x):
        features = self.backbone(x)
        translation = self.fc_translation(features)
        rotation = self.fc_rotation(features)
        return translation, rotation


# pose_model = PoseModel().to(device)
# optimizer = optim.SGD(pose_model.parameters(), lr=0.001, momentum=0.9)  # try adam optimizer
criterion = nn.MSELoss()


def rotation_matrix_to_angle_axis(rotation_matrices):
    """Convert a batch of rotation matrices to angle-axis vectors."""
    # Calculate the trace of each 3x3 rotation matrix in the batch
    traces = torch.einsum('bii->b', rotation_matrices)  # Sum over the diagonal elements in each matrix in the batch
    cos_thetas = (traces - 1) / 2.0
    cos_thetas = torch.clamp(cos_thetas, -1, 1)  # Numerical errors might make cos(theta) slightly out of its range
    thetas = torch.acos(cos_thetas)  # Angles

    # Initialize angle-axis vectors
    angle_axes = torch.zeros_like(rotation_matrices[:, :, 0])

    # Compute sin(theta) for normalization
    sin_thetas = torch.sin(thetas)

    # Find indices where theta is not too small (to avoid division by zero)
    valid = sin_thetas > 1e-5

    # For valid indices where theta is not too small, calculate angle-axis vectors
    angle_axes[valid] = torch.stack([
        rotation_matrices[valid, 2, 1] - rotation_matrices[valid, 1, 2],
        rotation_matrices[valid, 0, 2] - rotation_matrices[valid, 2, 0],
        rotation_matrices[valid, 1, 0] - rotation_matrices[valid, 0, 1]
    ], dim=1) / (2 * sin_thetas[valid].unsqueeze(1)) * thetas[valid].unsqueeze(1)

    return angle_axes


def rotation_error(pred_rot, gt_rot):
    """Calculate the angular distance between two rotation matrices."""
    pred_rot_matrix = pred_rot.view(-1, 3, 3)
    gt_rot_matrix = gt_rot.view(-1, 3, 3)
    r_diff = torch.matmul(pred_rot_matrix, gt_rot_matrix.transpose(1, 2))  # Relative rotation
    angle_axis = rotation_matrix_to_angle_axis(r_diff)
    return torch.norm(angle_axis, dim=1)  # Returns the magnitude of the angle-axis vector


def calculate_translation_error(pred, target):
    return torch.norm(pred - target, dim=1).mean()


scenes = [ 'heads', 'office']

for scene in scenes:
    pose_model = PoseModel().to(device)  # model re-initializes for each scene to train separately for each scene
    optimizer = optim.SGD(pose_model.parameters(), lr=0.001, momentum=0.9)

    train_data, test_data = create_data_structure_for_each_scene(your_path_to_data_folder, scene)

    train_dataset = CustomDataset(train_data, transform=transformations)
    test_dataset = CustomDataset(test_data, transform=transformations)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    best_fold = 0
    best_model_state = None

    n_splits = 5
    num_epochs = 15
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    indices = range(len(train_loader.dataset))  # warning is disregarded, since the code works correct

    for fold, (train_idx, val_idx) in enumerate(kfold.split(indices)):
        print("-" * 100)
        print(f"FOLD {fold} for {scene}")
        print("-" * 100)

        train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
        validation_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)

        current_train_loader = torch.utils.data.DataLoader(
            train_loader.dataset, batch_size=64, sampler=train_subsampler)
        current_validation_loader = torch.utils.data.DataLoader(
            train_loader.dataset, batch_size=64, sampler=validation_subsampler)

        best_loss = np.inf
        best_epoch = -1
        for epoch in tqdm(range(num_epochs), desc=f"Epochs Progress for {scene}"):
            pose_model.train()
            total_loss = 0.0
            total_translation_error = 0.0
            total_rotation_error = 0.0

            # Training loop
            for images, translations, rotations in tqdm(current_train_loader,
                                                        desc=f"Training Epoch {epoch + 1} for {scene}", leave=False):
                images = images.to(device)
                translations = translations.to(device)
                rotations = rotations.view(-1, 3, 3).to(device)

                optimizer.zero_grad()
                trans_pred, rot_pred = pose_model(images)
                rot_pred = rot_pred.view(-1, 3, 3)

                loss_translation = criterion(trans_pred, translations)
                loss_rotation = criterion(rot_pred.view(-1, 9), rotations.view(-1, 9))
                loss = loss_translation + loss_rotation
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                translation_error = calculate_translation_error(trans_pred, translations)
                total_translation_error += translation_error.item()
                rotation_error_batch = rotation_error(rot_pred, rotations).mean().item()
                total_rotation_error += rotation_error_batch

            # Validation loop
            pose_model.eval()
            validation_loss = 0.0
            with torch.no_grad():
                for images, translations, rotations in tqdm(current_validation_loader,
                                                            desc=f"Validating Epoch {epoch + 1}",
                                                            leave=False):
                    images = images.to(device)
                    translations = translations.to(device)
                    rotations = rotations.view(-1, 3, 3).to(device)

                    trans_pred, rot_pred = pose_model(images)
                    rot_pred = rot_pred.view(-1, 3, 3)

                    loss_translation = criterion(trans_pred, translations)
                    loss_rotation = criterion(rot_pred.view(-1, 9), rotations.view(-1, 9))
                    loss = loss_translation + loss_rotation
                    validation_loss += loss.item()

            validation_loss /= len(current_validation_loader)

            if validation_loss < best_loss:
                best_loss = validation_loss
                best_model_state = pose_model.state_dict()  # save model state not the model itself
                best_fold = fold
                print()
                print(f"New best model found for {scene} in fold {fold} with validation loss {best_loss:.4f}")

            average_loss = total_loss / len(train_loader)
            average_rotation_error = total_rotation_error / len(train_loader)
            average_translation_error = total_translation_error / len(train_loader)
            print(f"For {scene} in Epoch {epoch + 1}, Validation loss: {validation_loss:.4f}")
            print(f' For {scene} in Epoch {epoch + 1}, Average Loss: {average_loss:.4f}, Average Translation Error: '
                  f'{average_translation_error:.4f}, Average Rotation Error (radians): {average_rotation_error:.4f}')

        print("-" * 50)
        print(f"Training complete for {scene} in Fold {fold}.")
        print(f"(Best Epoch for Fold {fold}: {best_epoch + 1} with Validation Loss: {best_loss:.4f}")

    # after all folds save the best model path
    if best_model_state:
        torch.save(best_model_state, f'best_pose_model_{scene}.pth')
        print(f"Best model for {scene} from Fold {best_fold} saved with loss {best_loss:.4f}")

    # To use the best model
    pose_model.load_state_dict(torch.load(f'best_pose_model_{scene}.pth'))
    pose_model.eval()

    total_translation_error = 0.0
    total_rotation_error = 0.0
    count = 0

    # No gradient needed for test part
    with torch.no_grad():
        for images, translations, rotations in test_loader:
            images = images.to(device)
            translations = translations.to(device)
            rotations = rotations.view(-1, 3, 3).to(device)  # 3x3 rotation matrices

            trans_pred, rot_pred = pose_model(images)
            rot_pred = rot_pred.view(-1, 3, 3)

            # Calculate errors
            translation_error = calculate_translation_error(trans_pred, translations)
            rotation_error_batch = rotation_error(rot_pred, rotations).mean().item()

            total_translation_error += translation_error.item()
            total_rotation_error += rotation_error_batch
            count += 1

        # Calculate average errors
        average_translation_error = total_translation_error / count
        average_rotation_error = total_rotation_error / count
        average_rotation_error_in_degrees = average_rotation_error * (180 / math.pi)

        print(f"Performance of the best model for {scene} on the test data was from Fold {best_fold}:")
        print(f"Average Translation Error for {scene}: {average_translation_error:.4f} meters")
        print(f"Average Rotation Error (degree) for {scene}: {average_rotation_error_in_degrees:.4f}")
        print("-" * 50)


Processing sequence: seq-02 in heads (train)
Processing sequence: seq-01 in heads (test)
----------------------------------------------------------------------------------------------------
FOLD 0 for heads
----------------------------------------------------------------------------------------------------


Epochs Progress for heads:   7%|▋         | 1/15 [04:46<1:06:57, 286.99s/it]


New best model found for heads in fold 0 with validation loss 0.2088
For heads in Epoch 1, Validation loss: 0.2088
 For heads in Epoch 1, Average Loss: 0.1319, Average Translation Error: 0.1768, Average Rotation Error (radians): 0.0975


Epochs Progress for heads:  13%|█▎        | 2/15 [07:25<45:49, 211.46s/it]  


New best model found for heads in fold 0 with validation loss 0.0879
For heads in Epoch 2, Validation loss: 0.0879
 For heads in Epoch 2, Average Loss: 0.0575, Average Translation Error: 0.1416, Average Rotation Error (radians): 0.1182


Epochs Progress for heads:  20%|██        | 3/15 [10:01<37:10, 185.91s/it]


New best model found for heads in fold 0 with validation loss 0.0647
For heads in Epoch 3, Validation loss: 0.0647
 For heads in Epoch 3, Average Loss: 0.0321, Average Translation Error: 0.1099, Average Rotation Error (radians): 0.0924


Epochs Progress for heads:  27%|██▋       | 4/15 [12:35<31:46, 173.29s/it]


New best model found for heads in fold 0 with validation loss 0.0543
For heads in Epoch 4, Validation loss: 0.0543
 For heads in Epoch 4, Average Loss: 0.0249, Average Translation Error: 0.0894, Average Rotation Error (radians): 0.0839


Epochs Progress for heads:  33%|███▎      | 5/15 [15:07<27:38, 165.89s/it]


New best model found for heads in fold 0 with validation loss 0.0481
For heads in Epoch 5, Validation loss: 0.0481
 For heads in Epoch 5, Average Loss: 0.0197, Average Translation Error: 0.0784, Average Rotation Error (radians): 0.0990


Epochs Progress for heads:  40%|████      | 6/15 [17:39<24:10, 161.11s/it]


New best model found for heads in fold 0 with validation loss 0.0410
For heads in Epoch 6, Validation loss: 0.0410
 For heads in Epoch 6, Average Loss: 0.0168, Average Translation Error: 0.0708, Average Rotation Error (radians): 0.0949


Epochs Progress for heads:  47%|████▋     | 7/15 [20:11<21:05, 158.16s/it]


New best model found for heads in fold 0 with validation loss 0.0358
For heads in Epoch 7, Validation loss: 0.0358
 For heads in Epoch 7, Average Loss: 0.0150, Average Translation Error: 0.0662, Average Rotation Error (radians): 0.0867


Epochs Progress for heads:  53%|█████▎    | 8/15 [22:22<17:25, 149.40s/it]


New best model found for heads in fold 0 with validation loss 0.0325
For heads in Epoch 8, Validation loss: 0.0325
 For heads in Epoch 8, Average Loss: 0.0131, Average Translation Error: 0.0610, Average Rotation Error (radians): 0.0773


Epochs Progress for heads:  60%|██████    | 9/15 [24:28<14:13, 142.19s/it]


New best model found for heads in fold 0 with validation loss 0.0298
For heads in Epoch 9, Validation loss: 0.0298
 For heads in Epoch 9, Average Loss: 0.0116, Average Translation Error: 0.0572, Average Rotation Error (radians): 0.0729


Epochs Progress for heads:  67%|██████▋   | 10/15 [26:35<11:27, 137.42s/it]


New best model found for heads in fold 0 with validation loss 0.0285
For heads in Epoch 10, Validation loss: 0.0285
 For heads in Epoch 10, Average Loss: 0.0108, Average Translation Error: 0.0555, Average Rotation Error (radians): 0.0684


Epochs Progress for heads:  73%|███████▎  | 11/15 [28:42<08:57, 134.26s/it]


New best model found for heads in fold 0 with validation loss 0.0256
For heads in Epoch 11, Validation loss: 0.0256
 For heads in Epoch 11, Average Loss: 0.0098, Average Translation Error: 0.0521, Average Rotation Error (radians): 0.0630


Epochs Progress for heads:  80%|████████  | 12/15 [30:49<06:35, 131.93s/it]


New best model found for heads in fold 0 with validation loss 0.0234
For heads in Epoch 12, Validation loss: 0.0234
 For heads in Epoch 12, Average Loss: 0.0090, Average Translation Error: 0.0506, Average Rotation Error (radians): 0.0599


Epochs Progress for heads:  87%|████████▋ | 13/15 [32:56<04:21, 130.58s/it]


New best model found for heads in fold 0 with validation loss 0.0223
For heads in Epoch 13, Validation loss: 0.0223
 For heads in Epoch 13, Average Loss: 0.0084, Average Translation Error: 0.0487, Average Rotation Error (radians): 0.0572


Epochs Progress for heads:  93%|█████████▎| 14/15 [35:02<02:09, 129.04s/it]


New best model found for heads in fold 0 with validation loss 0.0209
For heads in Epoch 14, Validation loss: 0.0209
 For heads in Epoch 14, Average Loss: 0.0080, Average Translation Error: 0.0478, Average Rotation Error (radians): 0.0550


Epochs Progress for heads: 100%|██████████| 15/15 [37:07<00:00, 148.52s/it]



New best model found for heads in fold 0 with validation loss 0.0195
For heads in Epoch 15, Validation loss: 0.0195
 For heads in Epoch 15, Average Loss: 0.0070, Average Translation Error: 0.0447, Average Rotation Error (radians): 0.0480
--------------------------------------------------
Training complete for heads in Fold 0.
(Best Epoch for Fold 0: 0 with Validation Loss: 0.0195
----------------------------------------------------------------------------------------------------
FOLD 1 for heads
----------------------------------------------------------------------------------------------------


Epochs Progress for heads:   7%|▋         | 1/15 [02:06<29:25, 126.09s/it]


New best model found for heads in fold 1 with validation loss 0.0148
For heads in Epoch 1, Validation loss: 0.0148
 For heads in Epoch 1, Average Loss: 0.0074, Average Translation Error: 0.0488, Average Rotation Error (radians): 0.0465


Epochs Progress for heads:  13%|█▎        | 2/15 [04:11<27:14, 125.72s/it]


New best model found for heads in fold 1 with validation loss 0.0147
For heads in Epoch 2, Validation loss: 0.0147
 For heads in Epoch 2, Average Loss: 0.0067, Average Translation Error: 0.0463, Average Rotation Error (radians): 0.0440


Epochs Progress for heads:  20%|██        | 3/15 [06:17<25:09, 125.82s/it]


New best model found for heads in fold 1 with validation loss 0.0141
For heads in Epoch 3, Validation loss: 0.0141
 For heads in Epoch 3, Average Loss: 0.0064, Average Translation Error: 0.0452, Average Rotation Error (radians): 0.0429


Epochs Progress for heads:  27%|██▋       | 4/15 [08:23<23:03, 125.78s/it]

For heads in Epoch 4, Validation loss: 0.0144
 For heads in Epoch 4, Average Loss: 0.0061, Average Translation Error: 0.0437, Average Rotation Error (radians): 0.0414


Epochs Progress for heads:  33%|███▎      | 5/15 [10:30<21:02, 126.23s/it]


New best model found for heads in fold 1 with validation loss 0.0124
For heads in Epoch 5, Validation loss: 0.0124
 For heads in Epoch 5, Average Loss: 0.0057, Average Translation Error: 0.0424, Average Rotation Error (radians): 0.0395


Epochs Progress for heads:  40%|████      | 6/15 [12:36<18:55, 126.11s/it]


New best model found for heads in fold 1 with validation loss 0.0120
For heads in Epoch 6, Validation loss: 0.0120
 For heads in Epoch 6, Average Loss: 0.0054, Average Translation Error: 0.0417, Average Rotation Error (radians): 0.0374


Epochs Progress for heads:  47%|████▋     | 7/15 [14:42<16:50, 126.26s/it]


New best model found for heads in fold 1 with validation loss 0.0119
For heads in Epoch 7, Validation loss: 0.0119
 For heads in Epoch 7, Average Loss: 0.0051, Average Translation Error: 0.0405, Average Rotation Error (radians): 0.0357


Epochs Progress for heads:  53%|█████▎    | 8/15 [16:47<14:41, 125.94s/it]

For heads in Epoch 8, Validation loss: 0.0124
 For heads in Epoch 8, Average Loss: 0.0051, Average Translation Error: 0.0408, Average Rotation Error (radians): 0.0358


Epochs Progress for heads:  60%|██████    | 9/15 [18:53<12:35, 125.93s/it]


New best model found for heads in fold 1 with validation loss 0.0111
For heads in Epoch 9, Validation loss: 0.0111
 For heads in Epoch 9, Average Loss: 0.0050, Average Translation Error: 0.0415, Average Rotation Error (radians): 0.0337


Epochs Progress for heads:  67%|██████▋   | 10/15 [20:59<10:29, 125.90s/it]

For heads in Epoch 10, Validation loss: 0.0116
 For heads in Epoch 10, Average Loss: 0.0045, Average Translation Error: 0.0380, Average Rotation Error (radians): 0.0310


Epochs Progress for heads:  73%|███████▎  | 11/15 [23:06<08:24, 126.10s/it]


New best model found for heads in fold 1 with validation loss 0.0107
For heads in Epoch 11, Validation loss: 0.0107
 For heads in Epoch 11, Average Loss: 0.0047, Average Translation Error: 0.0400, Average Rotation Error (radians): 0.0324


Epochs Progress for heads:  80%|████████  | 12/15 [25:12<06:18, 126.04s/it]

For heads in Epoch 12, Validation loss: 0.0108
 For heads in Epoch 12, Average Loss: 0.0043, Average Translation Error: 0.0380, Average Rotation Error (radians): 0.0313


Epochs Progress for heads:  87%|████████▋ | 13/15 [27:17<04:11, 125.85s/it]

For heads in Epoch 13, Validation loss: 0.0112
 For heads in Epoch 13, Average Loss: 0.0040, Average Translation Error: 0.0364, Average Rotation Error (radians): 0.0279


Epochs Progress for heads:  93%|█████████▎| 14/15 [29:23<02:05, 125.82s/it]


New best model found for heads in fold 1 with validation loss 0.0107
For heads in Epoch 14, Validation loss: 0.0107
 For heads in Epoch 14, Average Loss: 0.0040, Average Translation Error: 0.0366, Average Rotation Error (radians): 0.0292


Epochs Progress for heads: 100%|██████████| 15/15 [31:28<00:00, 125.91s/it]



New best model found for heads in fold 1 with validation loss 0.0096
For heads in Epoch 15, Validation loss: 0.0096
 For heads in Epoch 15, Average Loss: 0.0039, Average Translation Error: 0.0355, Average Rotation Error (radians): 0.0293
--------------------------------------------------
Training complete for heads in Fold 1.
(Best Epoch for Fold 1: 0 with Validation Loss: 0.0096
----------------------------------------------------------------------------------------------------
FOLD 2 for heads
----------------------------------------------------------------------------------------------------


Epochs Progress for heads:   7%|▋         | 1/15 [02:05<29:11, 125.08s/it]


New best model found for heads in fold 2 with validation loss 0.0072
For heads in Epoch 1, Validation loss: 0.0072
 For heads in Epoch 1, Average Loss: 0.0043, Average Translation Error: 0.0370, Average Rotation Error (radians): 0.0316


Epochs Progress for heads:  13%|█▎        | 2/15 [04:10<27:10, 125.40s/it]


New best model found for heads in fold 2 with validation loss 0.0067
For heads in Epoch 2, Validation loss: 0.0067
 For heads in Epoch 2, Average Loss: 0.0040, Average Translation Error: 0.0365, Average Rotation Error (radians): 0.0284


Epochs Progress for heads:  20%|██        | 3/15 [06:15<25:04, 125.35s/it]

For heads in Epoch 3, Validation loss: 0.0070
 For heads in Epoch 3, Average Loss: 0.0038, Average Translation Error: 0.0351, Average Rotation Error (radians): 0.0259


Epochs Progress for heads:  27%|██▋       | 4/15 [08:21<22:58, 125.33s/it]

For heads in Epoch 4, Validation loss: 0.0069
 For heads in Epoch 4, Average Loss: 0.0036, Average Translation Error: 0.0346, Average Rotation Error (radians): 0.0264


Epochs Progress for heads:  33%|███▎      | 5/15 [10:26<20:53, 125.38s/it]


New best model found for heads in fold 2 with validation loss 0.0064
For heads in Epoch 5, Validation loss: 0.0064
 For heads in Epoch 5, Average Loss: 0.0035, Average Translation Error: 0.0337, Average Rotation Error (radians): 0.0272


Epochs Progress for heads:  40%|████      | 6/15 [12:32<18:48, 125.42s/it]


New best model found for heads in fold 2 with validation loss 0.0064
For heads in Epoch 6, Validation loss: 0.0064
 For heads in Epoch 6, Average Loss: 0.0036, Average Translation Error: 0.0346, Average Rotation Error (radians): 0.0271


Epochs Progress for heads:  47%|████▋     | 7/15 [14:37<16:43, 125.46s/it]

For heads in Epoch 7, Validation loss: 0.0065
 For heads in Epoch 7, Average Loss: 0.0037, Average Translation Error: 0.0346, Average Rotation Error (radians): 0.0275


Epochs Progress for heads:  53%|█████▎    | 8/15 [16:43<14:37, 125.39s/it]


New best model found for heads in fold 2 with validation loss 0.0063
For heads in Epoch 8, Validation loss: 0.0063
 For heads in Epoch 8, Average Loss: 0.0034, Average Translation Error: 0.0342, Average Rotation Error (radians): 0.0244


Epochs Progress for heads:  60%|██████    | 9/15 [18:48<12:32, 125.34s/it]

For heads in Epoch 9, Validation loss: 0.0066
 For heads in Epoch 9, Average Loss: 0.0033, Average Translation Error: 0.0322, Average Rotation Error (radians): 0.0244


Epochs Progress for heads:  67%|██████▋   | 10/15 [20:54<10:28, 125.70s/it]

For heads in Epoch 10, Validation loss: 0.0063
 For heads in Epoch 10, Average Loss: 0.0032, Average Translation Error: 0.0322, Average Rotation Error (radians): 0.0244


Epochs Progress for heads:  73%|███████▎  | 11/15 [23:00<08:23, 125.82s/it]

For heads in Epoch 11, Validation loss: 0.0066
 For heads in Epoch 11, Average Loss: 0.0032, Average Translation Error: 0.0330, Average Rotation Error (radians): 0.0237


Epochs Progress for heads:  80%|████████  | 12/15 [25:06<06:17, 125.78s/it]

For heads in Epoch 12, Validation loss: 0.0066
 For heads in Epoch 12, Average Loss: 0.0032, Average Translation Error: 0.0329, Average Rotation Error (radians): 0.0224


Epochs Progress for heads:  87%|████████▋ | 13/15 [27:11<04:11, 125.66s/it]

For heads in Epoch 13, Validation loss: 0.0068
 For heads in Epoch 13, Average Loss: 0.0031, Average Translation Error: 0.0318, Average Rotation Error (radians): 0.0241


Epochs Progress for heads:  93%|█████████▎| 14/15 [29:17<02:05, 125.59s/it]


New best model found for heads in fold 2 with validation loss 0.0059
For heads in Epoch 14, Validation loss: 0.0059
 For heads in Epoch 14, Average Loss: 0.0031, Average Translation Error: 0.0324, Average Rotation Error (radians): 0.0221


Epochs Progress for heads: 100%|██████████| 15/15 [31:23<00:00, 125.56s/it]


For heads in Epoch 15, Validation loss: 0.0062
 For heads in Epoch 15, Average Loss: 0.0029, Average Translation Error: 0.0304, Average Rotation Error (radians): 0.0218
--------------------------------------------------
Training complete for heads in Fold 2.
(Best Epoch for Fold 2: 0 with Validation Loss: 0.0059
----------------------------------------------------------------------------------------------------
FOLD 3 for heads
----------------------------------------------------------------------------------------------------


Epochs Progress for heads:   7%|▋         | 1/15 [02:06<29:24, 126.03s/it]


New best model found for heads in fold 3 with validation loss 0.0051
For heads in Epoch 1, Validation loss: 0.0051
 For heads in Epoch 1, Average Loss: 0.0030, Average Translation Error: 0.0315, Average Rotation Error (radians): 0.0212


Epochs Progress for heads:  13%|█▎        | 2/15 [04:11<27:17, 125.94s/it]


New best model found for heads in fold 3 with validation loss 0.0048
For heads in Epoch 2, Validation loss: 0.0048
 For heads in Epoch 2, Average Loss: 0.0030, Average Translation Error: 0.0324, Average Rotation Error (radians): 0.0212


Epochs Progress for heads:  20%|██        | 3/15 [06:17<25:08, 125.68s/it]

For heads in Epoch 3, Validation loss: 0.0051
 For heads in Epoch 3, Average Loss: 0.0029, Average Translation Error: 0.0314, Average Rotation Error (radians): 0.0221


Epochs Progress for heads:  27%|██▋       | 4/15 [08:22<23:01, 125.58s/it]

For heads in Epoch 4, Validation loss: 0.0048
 For heads in Epoch 4, Average Loss: 0.0028, Average Translation Error: 0.0306, Average Rotation Error (radians): 0.0217


Epochs Progress for heads:  33%|███▎      | 5/15 [10:27<20:54, 125.47s/it]

For heads in Epoch 5, Validation loss: 0.0048
 For heads in Epoch 5, Average Loss: 0.0030, Average Translation Error: 0.0314, Average Rotation Error (radians): 0.0222


Epochs Progress for heads:  40%|████      | 6/15 [12:34<18:51, 125.67s/it]

For heads in Epoch 6, Validation loss: 0.0053
 For heads in Epoch 6, Average Loss: 0.0028, Average Translation Error: 0.0317, Average Rotation Error (radians): 0.0193


Epochs Progress for heads:  47%|████▋     | 7/15 [14:39<16:45, 125.70s/it]

For heads in Epoch 7, Validation loss: 0.0048
 For heads in Epoch 7, Average Loss: 0.0025, Average Translation Error: 0.0290, Average Rotation Error (radians): 0.0196


Epochs Progress for heads:  53%|█████▎    | 8/15 [16:45<14:39, 125.64s/it]


New best model found for heads in fold 3 with validation loss 0.0047
For heads in Epoch 8, Validation loss: 0.0047
 For heads in Epoch 8, Average Loss: 0.0028, Average Translation Error: 0.0318, Average Rotation Error (radians): 0.0206


Epochs Progress for heads:  60%|██████    | 9/15 [18:51<12:34, 125.76s/it]

For heads in Epoch 9, Validation loss: 0.0050
 For heads in Epoch 9, Average Loss: 0.0026, Average Translation Error: 0.0278, Average Rotation Error (radians): 0.0203


Epochs Progress for heads:  67%|██████▋   | 10/15 [20:56<10:28, 125.70s/it]


New best model found for heads in fold 3 with validation loss 0.0044
For heads in Epoch 10, Validation loss: 0.0044
 For heads in Epoch 10, Average Loss: 0.0027, Average Translation Error: 0.0301, Average Rotation Error (radians): 0.0202


Epochs Progress for heads:  73%|███████▎  | 11/15 [23:03<08:23, 125.86s/it]

For heads in Epoch 11, Validation loss: 0.0048
 For heads in Epoch 11, Average Loss: 0.0025, Average Translation Error: 0.0291, Average Rotation Error (radians): 0.0201


Epochs Progress for heads:  80%|████████  | 12/15 [25:08<06:17, 125.67s/it]

For heads in Epoch 12, Validation loss: 0.0048
 For heads in Epoch 12, Average Loss: 0.0025, Average Translation Error: 0.0293, Average Rotation Error (radians): 0.0195


Epochs Progress for heads:  87%|████████▋ | 13/15 [27:14<04:11, 125.67s/it]

For heads in Epoch 13, Validation loss: 0.0044
 For heads in Epoch 13, Average Loss: 0.0024, Average Translation Error: 0.0280, Average Rotation Error (radians): 0.0208


Epochs Progress for heads:  93%|█████████▎| 14/15 [29:19<02:05, 125.64s/it]

For heads in Epoch 14, Validation loss: 0.0047
 For heads in Epoch 14, Average Loss: 0.0023, Average Translation Error: 0.0272, Average Rotation Error (radians): 0.0187


Epochs Progress for heads: 100%|██████████| 15/15 [31:25<00:00, 125.69s/it]



New best model found for heads in fold 3 with validation loss 0.0043
For heads in Epoch 15, Validation loss: 0.0043
 For heads in Epoch 15, Average Loss: 0.0024, Average Translation Error: 0.0284, Average Rotation Error (radians): 0.0200
--------------------------------------------------
Training complete for heads in Fold 3.
(Best Epoch for Fold 3: 0 with Validation Loss: 0.0043
----------------------------------------------------------------------------------------------------
FOLD 4 for heads
----------------------------------------------------------------------------------------------------


Epochs Progress for heads:   7%|▋         | 1/15 [02:05<29:16, 125.46s/it]


New best model found for heads in fold 4 with validation loss 0.0035
For heads in Epoch 1, Validation loss: 0.0035
 For heads in Epoch 1, Average Loss: 0.0027, Average Translation Error: 0.0302, Average Rotation Error (radians): 0.0196


Epochs Progress for heads:  13%|█▎        | 2/15 [04:10<27:08, 125.26s/it]

For heads in Epoch 2, Validation loss: 0.0036
 For heads in Epoch 2, Average Loss: 0.0022, Average Translation Error: 0.0262, Average Rotation Error (radians): 0.0185


Epochs Progress for heads:  20%|██        | 3/15 [06:15<25:02, 125.22s/it]

For heads in Epoch 3, Validation loss: 0.0035
 For heads in Epoch 3, Average Loss: 0.0025, Average Translation Error: 0.0285, Average Rotation Error (radians): 0.0191


Epochs Progress for heads:  27%|██▋       | 4/15 [08:20<22:57, 125.22s/it]

For heads in Epoch 4, Validation loss: 0.0035
 For heads in Epoch 4, Average Loss: 0.0024, Average Translation Error: 0.0277, Average Rotation Error (radians): 0.0185


Epochs Progress for heads:  33%|███▎      | 5/15 [10:27<20:58, 125.83s/it]


New best model found for heads in fold 4 with validation loss 0.0034
For heads in Epoch 5, Validation loss: 0.0034
 For heads in Epoch 5, Average Loss: 0.0021, Average Translation Error: 0.0259, Average Rotation Error (radians): 0.0172


Epochs Progress for heads:  40%|████      | 6/15 [12:33<18:52, 125.81s/it]


New best model found for heads in fold 4 with validation loss 0.0033
For heads in Epoch 6, Validation loss: 0.0033
 For heads in Epoch 6, Average Loss: 0.0025, Average Translation Error: 0.0302, Average Rotation Error (radians): 0.0182


Epochs Progress for heads:  47%|████▋     | 7/15 [14:38<16:45, 125.64s/it]

For heads in Epoch 7, Validation loss: 0.0036
 For heads in Epoch 7, Average Loss: 0.0022, Average Translation Error: 0.0271, Average Rotation Error (radians): 0.0164


Epochs Progress for heads:  53%|█████▎    | 8/15 [16:45<14:41, 125.96s/it]

For heads in Epoch 8, Validation loss: 0.0035
 For heads in Epoch 8, Average Loss: 0.0020, Average Translation Error: 0.0258, Average Rotation Error (radians): 0.0161


Epochs Progress for heads:  60%|██████    | 9/15 [18:51<12:35, 125.89s/it]


New best model found for heads in fold 4 with validation loss 0.0033
For heads in Epoch 9, Validation loss: 0.0033
 For heads in Epoch 9, Average Loss: 0.0022, Average Translation Error: 0.0274, Average Rotation Error (radians): 0.0180


Epochs Progress for heads:  67%|██████▋   | 10/15 [20:57<10:29, 125.92s/it]

For heads in Epoch 10, Validation loss: 0.0035
 For heads in Epoch 10, Average Loss: 0.0023, Average Translation Error: 0.0268, Average Rotation Error (radians): 0.0207


Epochs Progress for heads:  73%|███████▎  | 11/15 [23:02<08:23, 125.83s/it]

For heads in Epoch 11, Validation loss: 0.0034
 For heads in Epoch 11, Average Loss: 0.0021, Average Translation Error: 0.0262, Average Rotation Error (radians): 0.0168


Epochs Progress for heads:  80%|████████  | 12/15 [25:08<06:17, 125.71s/it]

For heads in Epoch 12, Validation loss: 0.0035
 For heads in Epoch 12, Average Loss: 0.0021, Average Translation Error: 0.0268, Average Rotation Error (radians): 0.0163


Epochs Progress for heads:  87%|████████▋ | 13/15 [27:14<04:11, 125.77s/it]

For heads in Epoch 13, Validation loss: 0.0035
 For heads in Epoch 13, Average Loss: 0.0021, Average Translation Error: 0.0263, Average Rotation Error (radians): 0.0167


Epochs Progress for heads:  93%|█████████▎| 14/15 [29:19<02:05, 125.74s/it]

For heads in Epoch 14, Validation loss: 0.0034
 For heads in Epoch 14, Average Loss: 0.0020, Average Translation Error: 0.0244, Average Rotation Error (radians): 0.0157


Epochs Progress for heads: 100%|██████████| 15/15 [31:25<00:00, 125.73s/it]


For heads in Epoch 15, Validation loss: 0.0033
 For heads in Epoch 15, Average Loss: 0.0021, Average Translation Error: 0.0268, Average Rotation Error (radians): 0.0181
--------------------------------------------------
Training complete for heads in Fold 4.
(Best Epoch for Fold 4: 0 with Validation Loss: 0.0033
Best model for heads from Fold 4 saved with loss 0.0033
Performance of the best model for heads on the test data was from Fold 4:
Average Translation Error for heads: 0.2727 meters
Average Rotation Error (degree) for heads: 14.2431
--------------------------------------------------
Processing sequence: seq-01 in office (train)
Processing sequence: seq-03 in office (train)
Processing sequence: seq-04 in office (train)
Processing sequence: seq-05 in office (train)
Processing sequence: seq-08 in office (train)
Processing sequence: seq-10 in office (train)
Processing sequence: seq-02 in office (test)
Processing sequence: seq-06 in office (test)
Processing sequence: seq-07 in office

Epochs Progress for office:   7%|▋         | 1/15 [12:49<2:59:30, 769.36s/it]


New best model found for office in fold 0 with validation loss 0.2250
For office in Epoch 1, Validation loss: 0.2250
 For office in Epoch 1, Average Loss: 0.1759, Average Translation Error: 0.3159, Average Rotation Error (radians): 0.1700


Epochs Progress for office:  13%|█▎        | 2/15 [25:43<2:47:20, 772.36s/it]


New best model found for office in fold 0 with validation loss 0.1231
For office in Epoch 2, Validation loss: 0.1231
 For office in Epoch 2, Average Loss: 0.0672, Average Translation Error: 0.1776, Average Rotation Error (radians): 0.1478


Epochs Progress for office:  20%|██        | 3/15 [38:35<2:34:27, 772.26s/it]


New best model found for office in fold 0 with validation loss 0.0806
For office in Epoch 3, Validation loss: 0.0806
 For office in Epoch 3, Average Loss: 0.0393, Average Translation Error: 0.1329, Average Rotation Error (radians): 0.1077


Epochs Progress for office:  27%|██▋       | 4/15 [51:30<2:21:44, 773.10s/it]


New best model found for office in fold 0 with validation loss 0.0624
For office in Epoch 4, Validation loss: 0.0624
 For office in Epoch 4, Average Loss: 0.0275, Average Translation Error: 0.1116, Average Rotation Error (radians): 0.0836


Epochs Progress for office:  33%|███▎      | 5/15 [1:04:24<2:08:53, 773.38s/it]


New best model found for office in fold 0 with validation loss 0.0509
For office in Epoch 5, Validation loss: 0.0509
 For office in Epoch 5, Average Loss: 0.0211, Average Translation Error: 0.0978, Average Rotation Error (radians): 0.0700


Epochs Progress for office:  40%|████      | 6/15 [1:17:18<1:56:03, 773.78s/it]


New best model found for office in fold 0 with validation loss 0.0436
For office in Epoch 6, Validation loss: 0.0436
 For office in Epoch 6, Average Loss: 0.0179, Average Translation Error: 0.0899, Average Rotation Error (radians): 0.0621


Epochs Progress for office:  47%|████▋     | 7/15 [1:30:12<1:43:10, 773.84s/it]


New best model found for office in fold 0 with validation loss 0.0401
For office in Epoch 7, Validation loss: 0.0401
 For office in Epoch 7, Average Loss: 0.0153, Average Translation Error: 0.0834, Average Rotation Error (radians): 0.0553


Epochs Progress for office:  53%|█████▎    | 8/15 [1:43:07<1:30:19, 774.21s/it]


New best model found for office in fold 0 with validation loss 0.0363
For office in Epoch 8, Validation loss: 0.0363
 For office in Epoch 8, Average Loss: 0.0136, Average Translation Error: 0.0779, Average Rotation Error (radians): 0.0504


Epochs Progress for office:  60%|██████    | 9/15 [1:56:01<1:17:24, 774.01s/it]


New best model found for office in fold 0 with validation loss 0.0327
For office in Epoch 9, Validation loss: 0.0327
 For office in Epoch 9, Average Loss: 0.0121, Average Translation Error: 0.0735, Average Rotation Error (radians): 0.0470


Epochs Progress for office:  67%|██████▋   | 10/15 [2:08:58<1:04:35, 775.01s/it]


New best model found for office in fold 0 with validation loss 0.0319
For office in Epoch 10, Validation loss: 0.0319
 For office in Epoch 10, Average Loss: 0.0112, Average Translation Error: 0.0715, Average Rotation Error (radians): 0.0443


Epochs Progress for office:  73%|███████▎  | 11/15 [2:21:58<51:46, 776.63s/it]  


New best model found for office in fold 0 with validation loss 0.0290
For office in Epoch 11, Validation loss: 0.0290
 For office in Epoch 11, Average Loss: 0.0103, Average Translation Error: 0.0688, Average Rotation Error (radians): 0.0413


Epochs Progress for office:  80%|████████  | 12/15 [2:34:49<38:44, 774.77s/it]


New best model found for office in fold 0 with validation loss 0.0284
For office in Epoch 12, Validation loss: 0.0284
 For office in Epoch 12, Average Loss: 0.0099, Average Translation Error: 0.0684, Average Rotation Error (radians): 0.0398


Epochs Progress for office:  87%|████████▋ | 13/15 [2:47:42<25:48, 774.37s/it]


New best model found for office in fold 0 with validation loss 0.0268
For office in Epoch 13, Validation loss: 0.0268
 For office in Epoch 13, Average Loss: 0.0090, Average Translation Error: 0.0645, Average Rotation Error (radians): 0.0377


Epochs Progress for office:  93%|█████████▎| 14/15 [3:00:28<12:51, 771.65s/it]

For office in Epoch 14, Validation loss: 0.0271
 For office in Epoch 14, Average Loss: 0.0085, Average Translation Error: 0.0621, Average Rotation Error (radians): 0.0365


Epochs Progress for office: 100%|██████████| 15/15 [3:13:20<00:00, 773.37s/it]



New best model found for office in fold 0 with validation loss 0.0248
For office in Epoch 15, Validation loss: 0.0248
 For office in Epoch 15, Average Loss: 0.0082, Average Translation Error: 0.0609, Average Rotation Error (radians): 0.0364
--------------------------------------------------
Training complete for office in Fold 0.
(Best Epoch for Fold 0: 0 with Validation Loss: 0.0248
----------------------------------------------------------------------------------------------------
FOLD 1 for office
----------------------------------------------------------------------------------------------------


Epochs Progress for office:   7%|▋         | 1/15 [12:47<2:59:07, 767.64s/it]


New best model found for office in fold 1 with validation loss 0.0170
For office in Epoch 1, Validation loss: 0.0170
 For office in Epoch 1, Average Loss: 0.0086, Average Translation Error: 0.0644, Average Rotation Error (radians): 0.0347


Epochs Progress for office:  13%|█▎        | 2/15 [25:32<2:45:55, 765.78s/it]


New best model found for office in fold 1 with validation loss 0.0158
For office in Epoch 2, Validation loss: 0.0158
 For office in Epoch 2, Average Loss: 0.0078, Average Translation Error: 0.0608, Average Rotation Error (radians): 0.0330


Epochs Progress for office:  20%|██        | 3/15 [38:17<2:33:08, 765.69s/it]


New best model found for office in fold 1 with validation loss 0.0145
For office in Epoch 3, Validation loss: 0.0145
 For office in Epoch 3, Average Loss: 0.0076, Average Translation Error: 0.0604, Average Rotation Error (radians): 0.0332


Epochs Progress for office:  27%|██▋       | 4/15 [51:05<2:20:32, 766.63s/it]

For office in Epoch 4, Validation loss: 0.0154
 For office in Epoch 4, Average Loss: 0.0073, Average Translation Error: 0.0596, Average Rotation Error (radians): 0.0312


Epochs Progress for office:  33%|███▎      | 5/15 [1:03:52<2:07:47, 766.78s/it]

For office in Epoch 5, Validation loss: 0.0150
 For office in Epoch 5, Average Loss: 0.0070, Average Translation Error: 0.0576, Average Rotation Error (radians): 0.0315


Epochs Progress for office:  40%|████      | 6/15 [1:16:42<1:55:10, 767.80s/it]

For office in Epoch 6, Validation loss: 0.0145
 For office in Epoch 6, Average Loss: 0.0069, Average Translation Error: 0.0576, Average Rotation Error (radians): 0.0308


Epochs Progress for office:  47%|████▋     | 7/15 [1:29:27<1:42:13, 766.74s/it]


New best model found for office in fold 1 with validation loss 0.0134
For office in Epoch 7, Validation loss: 0.0134
 For office in Epoch 7, Average Loss: 0.0064, Average Translation Error: 0.0549, Average Rotation Error (radians): 0.0291


Epochs Progress for office:  53%|█████▎    | 8/15 [1:42:14<1:29:29, 767.06s/it]


New best model found for office in fold 1 with validation loss 0.0133
For office in Epoch 8, Validation loss: 0.0133
 For office in Epoch 8, Average Loss: 0.0063, Average Translation Error: 0.0543, Average Rotation Error (radians): 0.0293


Epochs Progress for office:  60%|██████    | 9/15 [1:55:02<1:16:43, 767.31s/it]

For office in Epoch 9, Validation loss: 0.0145
 For office in Epoch 9, Average Loss: 0.0061, Average Translation Error: 0.0534, Average Rotation Error (radians): 0.0286


Epochs Progress for office:  67%|██████▋   | 10/15 [2:07:46<1:03:51, 766.35s/it]

For office in Epoch 10, Validation loss: 0.0143
 For office in Epoch 10, Average Loss: 0.0059, Average Translation Error: 0.0534, Average Rotation Error (radians): 0.0273


Epochs Progress for office:  73%|███████▎  | 11/15 [2:20:36<51:09, 767.27s/it]  


New best model found for office in fold 1 with validation loss 0.0125
For office in Epoch 11, Validation loss: 0.0125
 For office in Epoch 11, Average Loss: 0.0056, Average Translation Error: 0.0517, Average Rotation Error (radians): 0.0269


Epochs Progress for office:  80%|████████  | 12/15 [2:33:22<38:20, 766.85s/it]

For office in Epoch 12, Validation loss: 0.0131
 For office in Epoch 12, Average Loss: 0.0056, Average Translation Error: 0.0521, Average Rotation Error (radians): 0.0265


Epochs Progress for office:  87%|████████▋ | 13/15 [2:46:09<25:33, 766.95s/it]

For office in Epoch 13, Validation loss: 0.0140
 For office in Epoch 13, Average Loss: 0.0053, Average Translation Error: 0.0495, Average Rotation Error (radians): 0.0269


Epochs Progress for office:  93%|█████████▎| 14/15 [2:58:57<12:47, 767.18s/it]

For office in Epoch 14, Validation loss: 0.0138
 For office in Epoch 14, Average Loss: 0.0052, Average Translation Error: 0.0495, Average Rotation Error (radians): 0.0262


Epochs Progress for office: 100%|██████████| 15/15 [3:11:46<00:00, 767.10s/it]



New best model found for office in fold 1 with validation loss 0.0125
For office in Epoch 15, Validation loss: 0.0125
 For office in Epoch 15, Average Loss: 0.0050, Average Translation Error: 0.0490, Average Rotation Error (radians): 0.0248
--------------------------------------------------
Training complete for office in Fold 1.
(Best Epoch for Fold 1: 0 with Validation Loss: 0.0125
----------------------------------------------------------------------------------------------------
FOLD 2 for office
----------------------------------------------------------------------------------------------------


Epochs Progress for office:   7%|▋         | 1/15 [12:43<2:58:09, 763.54s/it]


New best model found for office in fold 2 with validation loss 0.0089
For office in Epoch 1, Validation loss: 0.0089
 For office in Epoch 1, Average Loss: 0.0052, Average Translation Error: 0.0505, Average Rotation Error (radians): 0.0260


Epochs Progress for office:  13%|█▎        | 2/15 [25:31<2:46:02, 766.38s/it]

For office in Epoch 2, Validation loss: 0.0094
 For office in Epoch 2, Average Loss: 0.0050, Average Translation Error: 0.0487, Average Rotation Error (radians): 0.0256


Epochs Progress for office:  20%|██        | 3/15 [38:16<2:33:04, 765.34s/it]


New best model found for office in fold 2 with validation loss 0.0086
For office in Epoch 3, Validation loss: 0.0086
 For office in Epoch 3, Average Loss: 0.0051, Average Translation Error: 0.0497, Average Rotation Error (radians): 0.0253


Epochs Progress for office:  27%|██▋       | 4/15 [51:03<2:20:28, 766.25s/it]

For office in Epoch 4, Validation loss: 0.0092
 For office in Epoch 4, Average Loss: 0.0050, Average Translation Error: 0.0492, Average Rotation Error (radians): 0.0246


Epochs Progress for office:  33%|███▎      | 5/15 [1:03:52<2:07:52, 767.23s/it]


New best model found for office in fold 2 with validation loss 0.0080
For office in Epoch 5, Validation loss: 0.0080
 For office in Epoch 5, Average Loss: 0.0047, Average Translation Error: 0.0472, Average Rotation Error (radians): 0.0239


Epochs Progress for office:  40%|████      | 6/15 [1:16:40<1:55:08, 767.59s/it]

For office in Epoch 6, Validation loss: 0.0092
 For office in Epoch 6, Average Loss: 0.0048, Average Translation Error: 0.0481, Average Rotation Error (radians): 0.0239


Epochs Progress for office:  47%|████▋     | 7/15 [1:29:29<1:42:22, 767.86s/it]

For office in Epoch 7, Validation loss: 0.0094
 For office in Epoch 7, Average Loss: 0.0047, Average Translation Error: 0.0482, Average Rotation Error (radians): 0.0238


Epochs Progress for office:  53%|█████▎    | 8/15 [1:42:19<1:29:39, 768.56s/it]

For office in Epoch 8, Validation loss: 0.0090
 For office in Epoch 8, Average Loss: 0.0047, Average Translation Error: 0.0471, Average Rotation Error (radians): 0.0234


Epochs Progress for office:  60%|██████    | 9/15 [1:55:07<1:16:49, 768.30s/it]

For office in Epoch 9, Validation loss: 0.0089
 For office in Epoch 9, Average Loss: 0.0046, Average Translation Error: 0.0475, Average Rotation Error (radians): 0.0235


Epochs Progress for office:  67%|██████▋   | 10/15 [2:07:53<1:03:58, 767.74s/it]

For office in Epoch 10, Validation loss: 0.0082
 For office in Epoch 10, Average Loss: 0.0044, Average Translation Error: 0.0463, Average Rotation Error (radians): 0.0232


Epochs Progress for office:  73%|███████▎  | 11/15 [2:20:43<51:13, 768.28s/it]  

For office in Epoch 11, Validation loss: 0.0085
 For office in Epoch 11, Average Loss: 0.0043, Average Translation Error: 0.0451, Average Rotation Error (radians): 0.0230


Epochs Progress for office:  80%|████████  | 12/15 [2:33:30<38:23, 767.89s/it]

For office in Epoch 12, Validation loss: 0.0083
 For office in Epoch 12, Average Loss: 0.0040, Average Translation Error: 0.0429, Average Rotation Error (radians): 0.0219


Epochs Progress for office:  87%|████████▋ | 13/15 [2:46:18<25:36, 768.05s/it]

For office in Epoch 13, Validation loss: 0.0086
 For office in Epoch 13, Average Loss: 0.0042, Average Translation Error: 0.0450, Average Rotation Error (radians): 0.0221


Epochs Progress for office:  93%|█████████▎| 14/15 [2:58:51<12:43, 763.35s/it]

For office in Epoch 14, Validation loss: 0.0093
 For office in Epoch 14, Average Loss: 0.0039, Average Translation Error: 0.0429, Average Rotation Error (radians): 0.0216


Epochs Progress for office: 100%|██████████| 15/15 [3:11:14<00:00, 764.95s/it]


For office in Epoch 15, Validation loss: 0.0098
 For office in Epoch 15, Average Loss: 0.0040, Average Translation Error: 0.0439, Average Rotation Error (radians): 0.0210
--------------------------------------------------
Training complete for office in Fold 2.
(Best Epoch for Fold 2: 0 with Validation Loss: 0.0080
----------------------------------------------------------------------------------------------------
FOLD 3 for office
----------------------------------------------------------------------------------------------------


Epochs Progress for office:   7%|▋         | 1/15 [12:24<2:53:45, 744.65s/it]


New best model found for office in fold 3 with validation loss 0.0055
For office in Epoch 1, Validation loss: 0.0055
 For office in Epoch 1, Average Loss: 0.0041, Average Translation Error: 0.0447, Average Rotation Error (radians): 0.0223


Epochs Progress for office:  13%|█▎        | 2/15 [24:50<2:41:26, 745.09s/it]

For office in Epoch 2, Validation loss: 0.0059
 For office in Epoch 2, Average Loss: 0.0040, Average Translation Error: 0.0438, Average Rotation Error (radians): 0.0221


Epochs Progress for office:  20%|██        | 3/15 [37:13<2:28:53, 744.49s/it]


New best model found for office in fold 3 with validation loss 0.0055
For office in Epoch 3, Validation loss: 0.0055
 For office in Epoch 3, Average Loss: 0.0039, Average Translation Error: 0.0424, Average Rotation Error (radians): 0.0213


Epochs Progress for office:  27%|██▋       | 4/15 [49:41<2:16:41, 745.61s/it]

For office in Epoch 4, Validation loss: 0.0059
 For office in Epoch 4, Average Loss: 0.0037, Average Translation Error: 0.0425, Average Rotation Error (radians): 0.0210


Epochs Progress for office:  33%|███▎      | 5/15 [1:02:02<2:04:02, 744.23s/it]


New best model found for office in fold 3 with validation loss 0.0054
For office in Epoch 5, Validation loss: 0.0054
 For office in Epoch 5, Average Loss: 0.0037, Average Translation Error: 0.0417, Average Rotation Error (radians): 0.0205


Epochs Progress for office:  40%|████      | 6/15 [1:14:28<1:51:41, 744.59s/it]

For office in Epoch 6, Validation loss: 0.0069
 For office in Epoch 6, Average Loss: 0.0037, Average Translation Error: 0.0421, Average Rotation Error (radians): 0.0213


Epochs Progress for office:  47%|████▋     | 7/15 [1:26:51<1:39:14, 744.31s/it]

For office in Epoch 7, Validation loss: 0.0059
 For office in Epoch 7, Average Loss: 0.0036, Average Translation Error: 0.0410, Average Rotation Error (radians): 0.0203


Epochs Progress for office:  53%|█████▎    | 8/15 [1:39:18<1:26:55, 745.14s/it]


New best model found for office in fold 3 with validation loss 0.0053
For office in Epoch 8, Validation loss: 0.0053
 For office in Epoch 8, Average Loss: 0.0035, Average Translation Error: 0.0407, Average Rotation Error (radians): 0.0201


Epochs Progress for office:  60%|██████    | 9/15 [1:51:42<1:14:27, 744.61s/it]

For office in Epoch 9, Validation loss: 0.0054
 For office in Epoch 9, Average Loss: 0.0035, Average Translation Error: 0.0410, Average Rotation Error (radians): 0.0203


Epochs Progress for office:  67%|██████▋   | 10/15 [2:04:09<1:02:06, 745.29s/it]

For office in Epoch 10, Validation loss: 0.0058
 For office in Epoch 10, Average Loss: 0.0036, Average Translation Error: 0.0419, Average Rotation Error (radians): 0.0199


Epochs Progress for office:  73%|███████▎  | 11/15 [2:16:33<49:40, 745.12s/it]  

For office in Epoch 11, Validation loss: 0.0058
 For office in Epoch 11, Average Loss: 0.0036, Average Translation Error: 0.0418, Average Rotation Error (radians): 0.0198


Epochs Progress for office:  80%|████████  | 12/15 [2:29:00<37:16, 745.61s/it]

For office in Epoch 12, Validation loss: 0.0057
 For office in Epoch 12, Average Loss: 0.0035, Average Translation Error: 0.0407, Average Rotation Error (radians): 0.0205


Epochs Progress for office:  87%|████████▋ | 13/15 [2:41:26<24:51, 745.66s/it]

For office in Epoch 13, Validation loss: 0.0054
 For office in Epoch 13, Average Loss: 0.0034, Average Translation Error: 0.0398, Average Rotation Error (radians): 0.0190


Epochs Progress for office:  93%|█████████▎| 14/15 [2:53:49<12:25, 745.01s/it]

For office in Epoch 14, Validation loss: 0.0054
 For office in Epoch 14, Average Loss: 0.0032, Average Translation Error: 0.0389, Average Rotation Error (radians): 0.0193


Epochs Progress for office: 100%|██████████| 15/15 [3:06:17<00:00, 745.16s/it]


For office in Epoch 15, Validation loss: 0.0059
 For office in Epoch 15, Average Loss: 0.0032, Average Translation Error: 0.0387, Average Rotation Error (radians): 0.0195
--------------------------------------------------
Training complete for office in Fold 3.
(Best Epoch for Fold 3: 0 with Validation Loss: 0.0053
----------------------------------------------------------------------------------------------------
FOLD 4 for office
----------------------------------------------------------------------------------------------------


Epochs Progress for office:   7%|▋         | 1/15 [12:22<2:53:14, 742.43s/it]


New best model found for office in fold 4 with validation loss 0.0040
For office in Epoch 1, Validation loss: 0.0040
 For office in Epoch 1, Average Loss: 0.0034, Average Translation Error: 0.0402, Average Rotation Error (radians): 0.0200


Epochs Progress for office:  13%|█▎        | 2/15 [24:49<2:41:25, 745.07s/it]

For office in Epoch 2, Validation loss: 0.0040
 For office in Epoch 2, Average Loss: 0.0032, Average Translation Error: 0.0388, Average Rotation Error (radians): 0.0188


Epochs Progress for office:  20%|██        | 3/15 [37:25<2:29:59, 749.92s/it]


New best model found for office in fold 4 with validation loss 0.0040
For office in Epoch 3, Validation loss: 0.0040
 For office in Epoch 3, Average Loss: 0.0033, Average Translation Error: 0.0400, Average Rotation Error (radians): 0.0190


Epochs Progress for office:  27%|██▋       | 4/15 [50:15<2:18:56, 757.87s/it]


New best model found for office in fold 4 with validation loss 0.0040
For office in Epoch 4, Validation loss: 0.0040
 For office in Epoch 4, Average Loss: 0.0032, Average Translation Error: 0.0388, Average Rotation Error (radians): 0.0189


Epochs Progress for office:  33%|███▎      | 5/15 [1:03:01<2:06:48, 760.87s/it]

For office in Epoch 5, Validation loss: 0.0040
 For office in Epoch 5, Average Loss: 0.0031, Average Translation Error: 0.0376, Average Rotation Error (radians): 0.0195


Epochs Progress for office:  40%|████      | 6/15 [1:15:54<1:54:45, 765.06s/it]

For office in Epoch 6, Validation loss: 0.0041
 For office in Epoch 6, Average Loss: 0.0032, Average Translation Error: 0.0395, Average Rotation Error (radians): 0.0185


Epochs Progress for office:  47%|████▋     | 7/15 [1:28:40<1:42:03, 765.38s/it]

For office in Epoch 7, Validation loss: 0.0040
 For office in Epoch 7, Average Loss: 0.0030, Average Translation Error: 0.0378, Average Rotation Error (radians): 0.0188


Epochs Progress for office:  53%|█████▎    | 8/15 [1:41:49<1:30:09, 772.78s/it]

For office in Epoch 8, Validation loss: 0.0040
 For office in Epoch 8, Average Loss: 0.0031, Average Translation Error: 0.0379, Average Rotation Error (radians): 0.0192


Epochs Progress for office:  60%|██████    | 9/15 [1:54:38<1:17:11, 771.84s/it]


New best model found for office in fold 4 with validation loss 0.0039
For office in Epoch 9, Validation loss: 0.0039
 For office in Epoch 9, Average Loss: 0.0030, Average Translation Error: 0.0364, Average Rotation Error (radians): 0.0191


Epochs Progress for office:  67%|██████▋   | 10/15 [2:07:28<1:04:15, 771.09s/it]

For office in Epoch 10, Validation loss: 0.0041
 For office in Epoch 10, Average Loss: 0.0031, Average Translation Error: 0.0385, Average Rotation Error (radians): 0.0187


Epochs Progress for office:  73%|███████▎  | 11/15 [2:20:16<51:20, 770.21s/it]  

For office in Epoch 11, Validation loss: 0.0043
 For office in Epoch 11, Average Loss: 0.0030, Average Translation Error: 0.0377, Average Rotation Error (radians): 0.0189


Epochs Progress for office:  80%|████████  | 12/15 [2:33:07<38:31, 770.44s/it]

For office in Epoch 12, Validation loss: 0.0040
 For office in Epoch 12, Average Loss: 0.0030, Average Translation Error: 0.0379, Average Rotation Error (radians): 0.0183


Epochs Progress for office:  87%|████████▋ | 13/15 [2:45:56<25:39, 769.91s/it]

For office in Epoch 13, Validation loss: 0.0041
 For office in Epoch 13, Average Loss: 0.0029, Average Translation Error: 0.0366, Average Rotation Error (radians): 0.0181


Epochs Progress for office:  93%|█████████▎| 14/15 [2:59:05<12:55, 775.76s/it]

For office in Epoch 14, Validation loss: 0.0041
 For office in Epoch 14, Average Loss: 0.0029, Average Translation Error: 0.0373, Average Rotation Error (radians): 0.0184


Epochs Progress for office: 100%|██████████| 15/15 [3:12:01<00:00, 768.08s/it]


For office in Epoch 15, Validation loss: 0.0040
 For office in Epoch 15, Average Loss: 0.0028, Average Translation Error: 0.0369, Average Rotation Error (radians): 0.0180
--------------------------------------------------
Training complete for office in Fold 4.
(Best Epoch for Fold 4: 0 with Validation Loss: 0.0039
Best model for office from Fold 4 saved with loss 0.0039
Performance of the best model for office on the test data was from Fold 4:
Average Translation Error for office: 0.3402 meters
Average Rotation Error (degree) for office: 8.8737
--------------------------------------------------


In [2]:
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.models import resnet18, ResNet18_Weights
from sklearn.model_selection import KFold
import math
from torchvision.models import resnet50, ResNet50_Weights



class Frame:
    def __init__(self, data_type, room_label, sequence, file_name, color_image_path, pose):
        self.data_type = data_type
        self.room_label = room_label
        self.sequence = sequence
        self.file_name = file_name
        self.color_image_path = color_image_path
        self.pose = pose


def parse_pose_file(pose_file_path):
    with open(pose_file_path, 'r') as file:
        pose = np.array([list(map(float, line.strip().split())) for line in file]).flatten()
    return pose


def create_frame_objects(data_path, room_name, data_type):
    frames = []
    for seq_folder in os.listdir(data_path):
        seq_path = os.path.join(data_path, seq_folder)
        if os.path.isdir(seq_path):
            print(f"Processing sequence: {seq_folder} in {room_name} ({data_type})")
            for frame_file in os.listdir(seq_path):
                if frame_file.endswith('.color.png'):
                    frame_name = frame_file.split('.')[0]
                    color_image_path = os.path.join(seq_path, f"{frame_name}.color.png")
                    pose_file_path = os.path.join(seq_path, f"{frame_name}.pose.txt")
                    if os.path.exists(color_image_path) and os.path.exists(pose_file_path):
                        pose = parse_pose_file(pose_file_path)
                        frame = Frame(data_type, room_name, seq_folder, frame_name, color_image_path, pose)
                        frames.append(frame)
    return frames


def create_data_structure(data_folder):
    local_train_data = []
    local_test_data = []
    room_names = ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs']
    for room_name in room_names:
        train_path = os.path.join(data_folder, room_name, 'train')
        test_path = os.path.join(data_folder, room_name, 'test')
        local_train_data.extend(create_frame_objects(train_path, room_name, 'train'))
        local_test_data.extend(create_frame_objects(test_path, room_name, 'test'))
    return local_train_data, local_test_data


def create_data_structure_for_each_scene(data_folder, room_name):
    train_path = os.path.join(data_folder, room_name, 'train')
    test_path = os.path.join(data_folder, room_name, 'test')
    local_train_data = create_frame_objects(train_path, room_name, 'train')
    local_test_data = create_frame_objects(test_path, room_name, 'test')
    return local_train_data, local_test_data


your_path_to_data_folder = 'data'
# train_data, test_data = create_data_structure(your_path_to_data_folder)


class CustomDataset(Dataset):
    def __init__(self, frames, transform=None):
        self.frames = frames
        self.transform = transform

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        frame = self.frames[idx]
        image = Image.open(frame.color_image_path).convert('RGB')
        pose_matrix = np.array(frame.pose, dtype=np.float32).reshape(4, 4)
        translation = pose_matrix[:3, 3]
        rotation = pose_matrix[:3, :3]

        if self.transform:
            image = self.transform(image)
        return image, torch.from_numpy(translation), torch.from_numpy(rotation.flatten())


transformations = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PoseModel(nn.Module):
    def __init__(self):
        super(PoseModel, self).__init__()
        weights = ResNet50_Weights.DEFAULT
        self.backbone = resnet50(weights=weights)
        self.fc_translation = nn.Linear(self.backbone.fc.in_features, 3)
        self.fc_rotation = nn.Linear(self.backbone.fc.in_features, 9)
        self.backbone.fc = nn.Identity()

    def forward(self, x):
        features = self.backbone(x)
        translation = self.fc_translation(features)
        rotation = self.fc_rotation(features)
        return translation, rotation


# pose_model = PoseModel().to(device)
# optimizer = optim.SGD(pose_model.parameters(), lr=0.001, momentum=0.9)  # try adam optimizer
criterion = nn.MSELoss()


def rotation_matrix_to_angle_axis(rotation_matrices):
    """Convert a batch of rotation matrices to angle-axis vectors."""
    # Calculate the trace of each 3x3 rotation matrix in the batch
    traces = torch.einsum('bii->b', rotation_matrices)  # Sum over the diagonal elements in each matrix in the batch
    cos_thetas = (traces - 1) / 2.0
    cos_thetas = torch.clamp(cos_thetas, -1, 1)  # Numerical errors might make cos(theta) slightly out of its range
    thetas = torch.acos(cos_thetas)  # Angles

    # Initialize angle-axis vectors
    angle_axes = torch.zeros_like(rotation_matrices[:, :, 0])

    # Compute sin(theta) for normalization
    sin_thetas = torch.sin(thetas)

    # Find indices where theta is not too small (to avoid division by zero)
    valid = sin_thetas > 1e-5

    # For valid indices where theta is not too small, calculate angle-axis vectors
    angle_axes[valid] = torch.stack([
        rotation_matrices[valid, 2, 1] - rotation_matrices[valid, 1, 2],
        rotation_matrices[valid, 0, 2] - rotation_matrices[valid, 2, 0],
        rotation_matrices[valid, 1, 0] - rotation_matrices[valid, 0, 1]
    ], dim=1) / (2 * sin_thetas[valid].unsqueeze(1)) * thetas[valid].unsqueeze(1)

    return angle_axes


def rotation_error(pred_rot, gt_rot):
    """Calculate the angular distance between two rotation matrices."""
    pred_rot_matrix = pred_rot.view(-1, 3, 3)
    gt_rot_matrix = gt_rot.view(-1, 3, 3)
    r_diff = torch.matmul(pred_rot_matrix, gt_rot_matrix.transpose(1, 2))  # Relative rotation
    angle_axis = rotation_matrix_to_angle_axis(r_diff)
    return torch.norm(angle_axis, dim=1)  # Returns the magnitude of the angle-axis vector


def calculate_translation_error(pred, target):
    return torch.norm(pred - target, dim=1).mean()


scenes = ['redkitchen']

for scene in scenes:
    pose_model = PoseModel().to(device)  # model re-initializes for each scene to train separately for each scene
    optimizer = optim.SGD(pose_model.parameters(), lr=0.001, momentum=0.9)

    train_data, test_data = create_data_structure_for_each_scene(your_path_to_data_folder, scene)

    train_dataset = CustomDataset(train_data, transform=transformations)
    test_dataset = CustomDataset(test_data, transform=transformations)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    best_fold = 0
    best_model_state = None

    n_splits = 5
    num_epochs = 15
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    indices = range(len(train_loader.dataset))  # warning is disregarded, since the code works correct

    for fold, (train_idx, val_idx) in enumerate(kfold.split(indices)):
        print("-" * 100)
        print(f"FOLD {fold} for {scene}")
        print("-" * 100)

        train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
        validation_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)

        current_train_loader = torch.utils.data.DataLoader(
            train_loader.dataset, batch_size=64, sampler=train_subsampler)
        current_validation_loader = torch.utils.data.DataLoader(
            train_loader.dataset, batch_size=64, sampler=validation_subsampler)

        best_loss = np.inf
        best_epoch = -1
        for epoch in tqdm(range(num_epochs), desc=f"Epochs Progress for {scene}"):
            pose_model.train()
            total_loss = 0.0
            total_translation_error = 0.0
            total_rotation_error = 0.0

            # Training loop
            for images, translations, rotations in tqdm(current_train_loader,
                                                        desc=f"Training Epoch {epoch + 1} for {scene}", leave=False):
                images = images.to(device)
                translations = translations.to(device)
                rotations = rotations.view(-1, 3, 3).to(device)

                optimizer.zero_grad()
                trans_pred, rot_pred = pose_model(images)
                rot_pred = rot_pred.view(-1, 3, 3)

                loss_translation = criterion(trans_pred, translations)
                loss_rotation = criterion(rot_pred.view(-1, 9), rotations.view(-1, 9))
                loss = loss_translation + loss_rotation
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                translation_error = calculate_translation_error(trans_pred, translations)
                total_translation_error += translation_error.item()
                rotation_error_batch = rotation_error(rot_pred, rotations).mean().item()
                total_rotation_error += rotation_error_batch

            # Validation loop
            pose_model.eval()
            validation_loss = 0.0
            with torch.no_grad():
                for images, translations, rotations in tqdm(current_validation_loader,
                                                            desc=f"Validating Epoch {epoch + 1}",
                                                            leave=False):
                    images = images.to(device)
                    translations = translations.to(device)
                    rotations = rotations.view(-1, 3, 3).to(device)

                    trans_pred, rot_pred = pose_model(images)
                    rot_pred = rot_pred.view(-1, 3, 3)

                    loss_translation = criterion(trans_pred, translations)
                    loss_rotation = criterion(rot_pred.view(-1, 9), rotations.view(-1, 9))
                    loss = loss_translation + loss_rotation
                    validation_loss += loss.item()

            validation_loss /= len(current_validation_loader)

            if validation_loss < best_loss:
                best_loss = validation_loss
                best_model_state = pose_model.state_dict()  # save model state not the model itself
                best_fold = fold
                print()
                print(f"New best model found for {scene} in fold {fold} with validation loss {best_loss:.4f}")

            average_loss = total_loss / len(train_loader)
            average_rotation_error = total_rotation_error / len(train_loader)
            average_translation_error = total_translation_error / len(train_loader)
            print(f"For {scene} in Epoch {epoch + 1}, Validation loss: {validation_loss:.4f}")
            print(f' For {scene} in Epoch {epoch + 1}, Average Loss: {average_loss:.4f}, Average Translation Error: '
                  f'{average_translation_error:.4f}, Average Rotation Error (radians): {average_rotation_error:.4f}')

        print("-" * 50)
        print(f"Training complete for {scene} in Fold {fold}.")
        print(f"(Best Epoch for Fold {fold}: {best_epoch + 1} with Validation Loss: {best_loss:.4f}")

    # after all folds save the best model path
    if best_model_state:
        torch.save(best_model_state, f'best_pose_model_{scene}.pth')
        print(f"Best model for {scene} from Fold {best_fold} saved with loss {best_loss:.4f}")

    # To use the best model
    pose_model.load_state_dict(torch.load(f'best_pose_model_{scene}.pth'))
    pose_model.eval()

    total_translation_error = 0.0
    total_rotation_error = 0.0
    count = 0

    # No gradient needed for test part
    with torch.no_grad():
        for images, translations, rotations in test_loader:
            images = images.to(device)
            translations = translations.to(device)
            rotations = rotations.view(-1, 3, 3).to(device)  # 3x3 rotation matrices

            trans_pred, rot_pred = pose_model(images)
            rot_pred = rot_pred.view(-1, 3, 3)

            # Calculate errors
            translation_error = calculate_translation_error(trans_pred, translations)
            rotation_error_batch = rotation_error(rot_pred, rotations).mean().item()

            total_translation_error += translation_error.item()
            total_rotation_error += rotation_error_batch
            count += 1

        # Calculate average errors
        average_translation_error = total_translation_error / count
        average_rotation_error = total_rotation_error / count
        average_rotation_error_in_degrees = average_rotation_error * (180 / math.pi)

        print(f"Performance of the best model for {scene} on the test data was from Fold {best_fold}:")
        print(f"Average Translation Error for {scene}: {average_translation_error:.4f} meters")
        print(f"Average Rotation Error (degree) for {scene}: {average_rotation_error_in_degrees:.4f}")
        print("-" * 50)


Processing sequence: seq-01 in redkitchen (train)
Processing sequence: seq-02 in redkitchen (train)
Processing sequence: seq-05 in redkitchen (train)
Processing sequence: seq-07 in redkitchen (train)
Processing sequence: seq-08 in redkitchen (train)
Processing sequence: seq-11 in redkitchen (train)
Processing sequence: seq-13 in redkitchen (train)
Processing sequence: seq-03 in redkitchen (test)
Processing sequence: seq-04 in redkitchen (test)
Processing sequence: seq-06 in redkitchen (test)
Processing sequence: seq-12 in redkitchen (test)
Processing sequence: seq-14 in redkitchen (test)
----------------------------------------------------------------------------------------------------
FOLD 0 for redkitchen
----------------------------------------------------------------------------------------------------


Epochs Progress for redkitchen:   7%|▋         | 1/15 [15:07<3:31:39, 907.11s/it]


New best model found for redkitchen in fold 0 with validation loss 0.2374
For redkitchen in Epoch 1, Validation loss: 0.2374
 For redkitchen in Epoch 1, Average Loss: 0.2121, Average Translation Error: 0.3795, Average Rotation Error (radians): 0.1765


Epochs Progress for redkitchen:  13%|█▎        | 2/15 [30:04<3:15:19, 901.51s/it]


New best model found for redkitchen in fold 0 with validation loss 0.1057
For redkitchen in Epoch 2, Validation loss: 0.1057
 For redkitchen in Epoch 2, Average Loss: 0.0642, Average Translation Error: 0.1775, Average Rotation Error (radians): 0.1408


Epochs Progress for redkitchen:  20%|██        | 3/15 [45:08<3:00:29, 902.49s/it]


New best model found for redkitchen in fold 0 with validation loss 0.0714
For redkitchen in Epoch 3, Validation loss: 0.0714
 For redkitchen in Epoch 3, Average Loss: 0.0346, Average Translation Error: 0.1274, Average Rotation Error (radians): 0.0965


Epochs Progress for redkitchen:  27%|██▋       | 4/15 [1:00:12<2:45:35, 903.24s/it]


New best model found for redkitchen in fold 0 with validation loss 0.0553
For redkitchen in Epoch 4, Validation loss: 0.0553
 For redkitchen in Epoch 4, Average Loss: 0.0249, Average Translation Error: 0.1104, Average Rotation Error (radians): 0.0733


Epochs Progress for redkitchen:  33%|███▎      | 5/15 [1:15:16<2:30:33, 903.32s/it]


New best model found for redkitchen in fold 0 with validation loss 0.0467
For redkitchen in Epoch 5, Validation loss: 0.0467
 For redkitchen in Epoch 5, Average Loss: 0.0194, Average Translation Error: 0.0973, Average Rotation Error (radians): 0.0631


Epochs Progress for redkitchen:  40%|████      | 6/15 [1:30:23<2:15:43, 904.82s/it]


New best model found for redkitchen in fold 0 with validation loss 0.0413
For redkitchen in Epoch 6, Validation loss: 0.0413
 For redkitchen in Epoch 6, Average Loss: 0.0167, Average Translation Error: 0.0909, Average Rotation Error (radians): 0.0560


Epochs Progress for redkitchen:  47%|████▋     | 7/15 [1:45:29<2:00:40, 905.10s/it]


New best model found for redkitchen in fold 0 with validation loss 0.0368
For redkitchen in Epoch 7, Validation loss: 0.0368
 For redkitchen in Epoch 7, Average Loss: 0.0146, Average Translation Error: 0.0855, Average Rotation Error (radians): 0.0510


Epochs Progress for redkitchen:  53%|█████▎    | 8/15 [2:00:30<1:45:27, 903.86s/it]


New best model found for redkitchen in fold 0 with validation loss 0.0343
For redkitchen in Epoch 8, Validation loss: 0.0343
 For redkitchen in Epoch 8, Average Loss: 0.0133, Average Translation Error: 0.0822, Average Rotation Error (radians): 0.0479


Epochs Progress for redkitchen:  60%|██████    | 9/15 [2:15:36<1:30:26, 904.38s/it]


New best model found for redkitchen in fold 0 with validation loss 0.0320
For redkitchen in Epoch 9, Validation loss: 0.0320
 For redkitchen in Epoch 9, Average Loss: 0.0120, Average Translation Error: 0.0776, Average Rotation Error (radians): 0.0448


Epochs Progress for redkitchen:  67%|██████▋   | 10/15 [2:30:43<1:15:26, 905.28s/it]


New best model found for redkitchen in fold 0 with validation loss 0.0299
For redkitchen in Epoch 10, Validation loss: 0.0299
 For redkitchen in Epoch 10, Average Loss: 0.0113, Average Translation Error: 0.0754, Average Rotation Error (radians): 0.0430


Epochs Progress for redkitchen:  73%|███████▎  | 11/15 [2:45:46<1:00:18, 904.64s/it]


New best model found for redkitchen in fold 0 with validation loss 0.0286
For redkitchen in Epoch 11, Validation loss: 0.0286
 For redkitchen in Epoch 11, Average Loss: 0.0103, Average Translation Error: 0.0722, Average Rotation Error (radians): 0.0403


Epochs Progress for redkitchen:  80%|████████  | 12/15 [3:00:49<45:11, 903.98s/it]  


New best model found for redkitchen in fold 0 with validation loss 0.0269
For redkitchen in Epoch 12, Validation loss: 0.0269
 For redkitchen in Epoch 12, Average Loss: 0.0097, Average Translation Error: 0.0699, Average Rotation Error (radians): 0.0393


Epochs Progress for redkitchen:  87%|████████▋ | 13/15 [3:15:58<30:11, 905.62s/it]


New best model found for redkitchen in fold 0 with validation loss 0.0259
For redkitchen in Epoch 13, Validation loss: 0.0259
 For redkitchen in Epoch 13, Average Loss: 0.0089, Average Translation Error: 0.0662, Average Rotation Error (radians): 0.0379


Epochs Progress for redkitchen:  93%|█████████▎| 14/15 [3:31:11<15:07, 907.70s/it]


New best model found for redkitchen in fold 0 with validation loss 0.0247
For redkitchen in Epoch 14, Validation loss: 0.0247
 For redkitchen in Epoch 14, Average Loss: 0.0087, Average Translation Error: 0.0662, Average Rotation Error (radians): 0.0373


Epochs Progress for redkitchen: 100%|██████████| 15/15 [3:46:20<00:00, 905.40s/it]



New best model found for redkitchen in fold 0 with validation loss 0.0239
For redkitchen in Epoch 15, Validation loss: 0.0239
 For redkitchen in Epoch 15, Average Loss: 0.0084, Average Translation Error: 0.0653, Average Rotation Error (radians): 0.0350
--------------------------------------------------
Training complete for redkitchen in Fold 0.
(Best Epoch for Fold 0: 0 with Validation Loss: 0.0239
----------------------------------------------------------------------------------------------------
FOLD 1 for redkitchen
----------------------------------------------------------------------------------------------------


Epochs Progress for redkitchen:   7%|▋         | 1/15 [15:08<3:32:03, 908.83s/it]


New best model found for redkitchen in fold 1 with validation loss 0.0135
For redkitchen in Epoch 1, Validation loss: 0.0135
 For redkitchen in Epoch 1, Average Loss: 0.0087, Average Translation Error: 0.0669, Average Rotation Error (radians): 0.0357


Epochs Progress for redkitchen:  13%|█▎        | 2/15 [30:12<3:16:18, 906.06s/it]


New best model found for redkitchen in fold 1 with validation loss 0.0133
For redkitchen in Epoch 2, Validation loss: 0.0133
 For redkitchen in Epoch 2, Average Loss: 0.0084, Average Translation Error: 0.0657, Average Rotation Error (radians): 0.0351


Epochs Progress for redkitchen:  20%|██        | 3/15 [45:27<3:01:58, 909.86s/it]


New best model found for redkitchen in fold 1 with validation loss 0.0130
For redkitchen in Epoch 3, Validation loss: 0.0130
 For redkitchen in Epoch 3, Average Loss: 0.0078, Average Translation Error: 0.0635, Average Rotation Error (radians): 0.0332


Epochs Progress for redkitchen:  27%|██▋       | 4/15 [1:00:22<2:45:44, 904.02s/it]

For redkitchen in Epoch 4, Validation loss: 0.0133
 For redkitchen in Epoch 4, Average Loss: 0.0079, Average Translation Error: 0.0639, Average Rotation Error (radians): 0.0333


Epochs Progress for redkitchen:  33%|███▎      | 5/15 [1:15:17<2:30:08, 900.88s/it]


New best model found for redkitchen in fold 1 with validation loss 0.0125
For redkitchen in Epoch 5, Validation loss: 0.0125
 For redkitchen in Epoch 5, Average Loss: 0.0075, Average Translation Error: 0.0621, Average Rotation Error (radians): 0.0322


Epochs Progress for redkitchen:  40%|████      | 6/15 [1:30:10<2:14:44, 898.29s/it]


New best model found for redkitchen in fold 1 with validation loss 0.0124
For redkitchen in Epoch 6, Validation loss: 0.0124
 For redkitchen in Epoch 6, Average Loss: 0.0070, Average Translation Error: 0.0597, Average Rotation Error (radians): 0.0306


Epochs Progress for redkitchen:  47%|████▋     | 7/15 [1:45:03<1:59:30, 896.29s/it]


New best model found for redkitchen in fold 1 with validation loss 0.0121
For redkitchen in Epoch 7, Validation loss: 0.0121
 For redkitchen in Epoch 7, Average Loss: 0.0067, Average Translation Error: 0.0587, Average Rotation Error (radians): 0.0302


Epochs Progress for redkitchen:  53%|█████▎    | 8/15 [2:00:52<1:46:32, 913.26s/it]


New best model found for redkitchen in fold 1 with validation loss 0.0117
For redkitchen in Epoch 8, Validation loss: 0.0117
 For redkitchen in Epoch 8, Average Loss: 0.0067, Average Translation Error: 0.0588, Average Rotation Error (radians): 0.0305


Epochs Progress for redkitchen:  60%|██████    | 9/15 [2:15:47<1:30:44, 907.35s/it]

For redkitchen in Epoch 9, Validation loss: 0.0120
 For redkitchen in Epoch 9, Average Loss: 0.0066, Average Translation Error: 0.0589, Average Rotation Error (radians): 0.0302


Epochs Progress for redkitchen:  67%|██████▋   | 10/15 [2:30:41<1:15:16, 903.26s/it]


New best model found for redkitchen in fold 1 with validation loss 0.0116
For redkitchen in Epoch 10, Validation loss: 0.0116
 For redkitchen in Epoch 10, Average Loss: 0.0061, Average Translation Error: 0.0557, Average Rotation Error (radians): 0.0287


Epochs Progress for redkitchen:  73%|███████▎  | 11/15 [2:45:36<1:00:03, 900.83s/it]


New best model found for redkitchen in fold 1 with validation loss 0.0112
For redkitchen in Epoch 11, Validation loss: 0.0112
 For redkitchen in Epoch 11, Average Loss: 0.0059, Average Translation Error: 0.0554, Average Rotation Error (radians): 0.0278


Epochs Progress for redkitchen:  80%|████████  | 12/15 [3:00:34<44:59, 899.93s/it]  


New best model found for redkitchen in fold 1 with validation loss 0.0111
For redkitchen in Epoch 12, Validation loss: 0.0111
 For redkitchen in Epoch 12, Average Loss: 0.0057, Average Translation Error: 0.0543, Average Rotation Error (radians): 0.0279


Epochs Progress for redkitchen:  87%|████████▋ | 13/15 [3:15:28<29:56, 898.16s/it]

For redkitchen in Epoch 13, Validation loss: 0.0112
 For redkitchen in Epoch 13, Average Loss: 0.0057, Average Translation Error: 0.0537, Average Rotation Error (radians): 0.0269


Epochs Progress for redkitchen:  93%|█████████▎| 14/15 [3:30:22<14:56, 896.82s/it]


New best model found for redkitchen in fold 1 with validation loss 0.0107
For redkitchen in Epoch 14, Validation loss: 0.0107
 For redkitchen in Epoch 14, Average Loss: 0.0057, Average Translation Error: 0.0542, Average Rotation Error (radians): 0.0268


Epochs Progress for redkitchen: 100%|██████████| 15/15 [3:45:18<00:00, 901.20s/it]


For redkitchen in Epoch 15, Validation loss: 0.0111
 For redkitchen in Epoch 15, Average Loss: 0.0055, Average Translation Error: 0.0541, Average Rotation Error (radians): 0.0258
--------------------------------------------------
Training complete for redkitchen in Fold 1.
(Best Epoch for Fold 1: 0 with Validation Loss: 0.0107
----------------------------------------------------------------------------------------------------
FOLD 2 for redkitchen
----------------------------------------------------------------------------------------------------


Epochs Progress for redkitchen:   7%|▋         | 1/15 [14:49<3:27:35, 889.67s/it]


New best model found for redkitchen in fold 2 with validation loss 0.0083
For redkitchen in Epoch 1, Validation loss: 0.0083
 For redkitchen in Epoch 1, Average Loss: 0.0056, Average Translation Error: 0.0546, Average Rotation Error (radians): 0.0264


Epochs Progress for redkitchen:  13%|█▎        | 2/15 [29:41<3:13:01, 890.86s/it]


New best model found for redkitchen in fold 2 with validation loss 0.0079
For redkitchen in Epoch 2, Validation loss: 0.0079
 For redkitchen in Epoch 2, Average Loss: 0.0058, Average Translation Error: 0.0560, Average Rotation Error (radians): 0.0268


Epochs Progress for redkitchen:  20%|██        | 3/15 [44:35<2:58:25, 892.16s/it]

For redkitchen in Epoch 3, Validation loss: 0.0082
 For redkitchen in Epoch 3, Average Loss: 0.0051, Average Translation Error: 0.0517, Average Rotation Error (radians): 0.0255


Epochs Progress for redkitchen:  27%|██▋       | 4/15 [59:36<2:44:15, 895.99s/it]


New best model found for redkitchen in fold 2 with validation loss 0.0075
For redkitchen in Epoch 4, Validation loss: 0.0075
 For redkitchen in Epoch 4, Average Loss: 0.0052, Average Translation Error: 0.0526, Average Rotation Error (radians): 0.0251


Epochs Progress for redkitchen:  33%|███▎      | 5/15 [1:14:36<2:29:32, 897.23s/it]


New best model found for redkitchen in fold 2 with validation loss 0.0075
For redkitchen in Epoch 5, Validation loss: 0.0075
 For redkitchen in Epoch 5, Average Loss: 0.0051, Average Translation Error: 0.0518, Average Rotation Error (radians): 0.0251


Epochs Progress for redkitchen:  40%|████      | 6/15 [1:31:19<2:19:57, 933.09s/it]


New best model found for redkitchen in fold 2 with validation loss 0.0073
For redkitchen in Epoch 6, Validation loss: 0.0073
 For redkitchen in Epoch 6, Average Loss: 0.0050, Average Translation Error: 0.0511, Average Rotation Error (radians): 0.0249


Epochs Progress for redkitchen:  47%|████▋     | 7/15 [1:46:18<2:02:55, 921.97s/it]

For redkitchen in Epoch 7, Validation loss: 0.0081
 For redkitchen in Epoch 7, Average Loss: 0.0047, Average Translation Error: 0.0491, Average Rotation Error (radians): 0.0246


Epochs Progress for redkitchen:  53%|█████▎    | 8/15 [2:01:16<1:46:42, 914.57s/it]

For redkitchen in Epoch 8, Validation loss: 0.0077
 For redkitchen in Epoch 8, Average Loss: 0.0048, Average Translation Error: 0.0495, Average Rotation Error (radians): 0.0246


Epochs Progress for redkitchen:  60%|██████    | 9/15 [2:16:07<1:30:42, 907.10s/it]

For redkitchen in Epoch 9, Validation loss: 0.0074
 For redkitchen in Epoch 9, Average Loss: 0.0048, Average Translation Error: 0.0504, Average Rotation Error (radians): 0.0241


Epochs Progress for redkitchen:  67%|██████▋   | 10/15 [2:30:56<1:15:06, 901.39s/it]

For redkitchen in Epoch 10, Validation loss: 0.0077
 For redkitchen in Epoch 10, Average Loss: 0.0048, Average Translation Error: 0.0504, Average Rotation Error (radians): 0.0241


Epochs Progress for redkitchen:  73%|███████▎  | 11/15 [2:45:44<59:49, 897.25s/it]  


New best model found for redkitchen in fold 2 with validation loss 0.0073
For redkitchen in Epoch 11, Validation loss: 0.0073
 For redkitchen in Epoch 11, Average Loss: 0.0044, Average Translation Error: 0.0472, Average Rotation Error (radians): 0.0242


Epochs Progress for redkitchen:  80%|████████  | 12/15 [3:00:35<44:46, 895.38s/it]


New best model found for redkitchen in fold 2 with validation loss 0.0070
For redkitchen in Epoch 12, Validation loss: 0.0070
 For redkitchen in Epoch 12, Average Loss: 0.0042, Average Translation Error: 0.0468, Average Rotation Error (radians): 0.0227


Epochs Progress for redkitchen:  87%|████████▋ | 13/15 [3:15:22<29:45, 892.89s/it]

For redkitchen in Epoch 13, Validation loss: 0.0072
 For redkitchen in Epoch 13, Average Loss: 0.0043, Average Translation Error: 0.0483, Average Rotation Error (radians): 0.0224


Epochs Progress for redkitchen:  93%|█████████▎| 14/15 [3:30:08<14:50, 890.93s/it]

For redkitchen in Epoch 14, Validation loss: 0.0071
 For redkitchen in Epoch 14, Average Loss: 0.0041, Average Translation Error: 0.0457, Average Rotation Error (radians): 0.0220


Epochs Progress for redkitchen: 100%|██████████| 15/15 [3:44:55<00:00, 899.67s/it]



New best model found for redkitchen in fold 2 with validation loss 0.0070
For redkitchen in Epoch 15, Validation loss: 0.0070
 For redkitchen in Epoch 15, Average Loss: 0.0042, Average Translation Error: 0.0472, Average Rotation Error (radians): 0.0223
--------------------------------------------------
Training complete for redkitchen in Fold 2.
(Best Epoch for Fold 2: 0 with Validation Loss: 0.0070
----------------------------------------------------------------------------------------------------
FOLD 3 for redkitchen
----------------------------------------------------------------------------------------------------


Epochs Progress for redkitchen:   7%|▋         | 1/15 [14:49<3:27:30, 889.35s/it]


New best model found for redkitchen in fold 3 with validation loss 0.0051
For redkitchen in Epoch 1, Validation loss: 0.0051
 For redkitchen in Epoch 1, Average Loss: 0.0044, Average Translation Error: 0.0482, Average Rotation Error (radians): 0.0217


Epochs Progress for redkitchen:  13%|█▎        | 2/15 [29:36<3:12:24, 888.03s/it]

For redkitchen in Epoch 2, Validation loss: 0.0052
 For redkitchen in Epoch 2, Average Loss: 0.0044, Average Translation Error: 0.0486, Average Rotation Error (radians): 0.0228


Epochs Progress for redkitchen:  20%|██        | 3/15 [44:17<2:56:55, 884.63s/it]

For redkitchen in Epoch 3, Validation loss: 0.0053
 For redkitchen in Epoch 3, Average Loss: 0.0040, Average Translation Error: 0.0457, Average Rotation Error (radians): 0.0223


Epochs Progress for redkitchen:  27%|██▋       | 4/15 [58:50<2:41:21, 880.16s/it]

For redkitchen in Epoch 4, Validation loss: 0.0053
 For redkitchen in Epoch 4, Average Loss: 0.0039, Average Translation Error: 0.0447, Average Rotation Error (radians): 0.0219


Epochs Progress for redkitchen:  33%|███▎      | 5/15 [1:13:23<2:26:16, 877.69s/it]


New best model found for redkitchen in fold 3 with validation loss 0.0051
For redkitchen in Epoch 5, Validation loss: 0.0051
 For redkitchen in Epoch 5, Average Loss: 0.0038, Average Translation Error: 0.0442, Average Rotation Error (radians): 0.0210


Epochs Progress for redkitchen:  40%|████      | 6/15 [1:27:57<2:11:27, 876.38s/it]

For redkitchen in Epoch 6, Validation loss: 0.0053
 For redkitchen in Epoch 6, Average Loss: 0.0040, Average Translation Error: 0.0463, Average Rotation Error (radians): 0.0211


Epochs Progress for redkitchen:  47%|████▋     | 7/15 [1:42:31<1:56:45, 875.64s/it]

For redkitchen in Epoch 7, Validation loss: 0.0056
 For redkitchen in Epoch 7, Average Loss: 0.0038, Average Translation Error: 0.0448, Average Rotation Error (radians): 0.0207


Epochs Progress for redkitchen:  53%|█████▎    | 8/15 [1:57:05<1:42:05, 875.11s/it]


New best model found for redkitchen in fold 3 with validation loss 0.0050
For redkitchen in Epoch 8, Validation loss: 0.0050
 For redkitchen in Epoch 8, Average Loss: 0.0037, Average Translation Error: 0.0440, Average Rotation Error (radians): 0.0212


Epochs Progress for redkitchen:  60%|██████    | 9/15 [2:11:39<1:27:27, 874.66s/it]

For redkitchen in Epoch 9, Validation loss: 0.0052
 For redkitchen in Epoch 9, Average Loss: 0.0037, Average Translation Error: 0.0441, Average Rotation Error (radians): 0.0212


Epochs Progress for redkitchen:  67%|██████▋   | 10/15 [2:26:14<1:12:53, 874.75s/it]


New best model found for redkitchen in fold 3 with validation loss 0.0050
For redkitchen in Epoch 10, Validation loss: 0.0050
 For redkitchen in Epoch 10, Average Loss: 0.0036, Average Translation Error: 0.0426, Average Rotation Error (radians): 0.0204


Epochs Progress for redkitchen:  73%|███████▎  | 11/15 [2:40:47<58:17, 874.43s/it]  

For redkitchen in Epoch 11, Validation loss: 0.0051
 For redkitchen in Epoch 11, Average Loss: 0.0035, Average Translation Error: 0.0429, Average Rotation Error (radians): 0.0200


Epochs Progress for redkitchen:  80%|████████  | 12/15 [2:55:23<43:44, 874.78s/it]


New best model found for redkitchen in fold 3 with validation loss 0.0049
For redkitchen in Epoch 12, Validation loss: 0.0049
 For redkitchen in Epoch 12, Average Loss: 0.0035, Average Translation Error: 0.0426, Average Rotation Error (radians): 0.0205


Epochs Progress for redkitchen:  87%|████████▋ | 13/15 [3:09:59<29:10, 875.20s/it]

For redkitchen in Epoch 13, Validation loss: 0.0050
 For redkitchen in Epoch 13, Average Loss: 0.0035, Average Translation Error: 0.0430, Average Rotation Error (radians): 0.0197


Epochs Progress for redkitchen:  93%|█████████▎| 14/15 [3:24:32<14:34, 874.45s/it]


New best model found for redkitchen in fold 3 with validation loss 0.0048
For redkitchen in Epoch 14, Validation loss: 0.0048
 For redkitchen in Epoch 14, Average Loss: 0.0036, Average Translation Error: 0.0442, Average Rotation Error (radians): 0.0203


Epochs Progress for redkitchen: 100%|██████████| 15/15 [3:39:09<00:00, 876.66s/it]


For redkitchen in Epoch 15, Validation loss: 0.0049
 For redkitchen in Epoch 15, Average Loss: 0.0036, Average Translation Error: 0.0440, Average Rotation Error (radians): 0.0203
--------------------------------------------------
Training complete for redkitchen in Fold 3.
(Best Epoch for Fold 3: 0 with Validation Loss: 0.0048
----------------------------------------------------------------------------------------------------
FOLD 4 for redkitchen
----------------------------------------------------------------------------------------------------


Epochs Progress for redkitchen:   7%|▋         | 1/15 [14:34<3:24:05, 874.66s/it]


New best model found for redkitchen in fold 4 with validation loss 0.0041
For redkitchen in Epoch 1, Validation loss: 0.0041
 For redkitchen in Epoch 1, Average Loss: 0.0037, Average Translation Error: 0.0445, Average Rotation Error (radians): 0.0202


Epochs Progress for redkitchen:  13%|█▎        | 2/15 [29:09<3:09:30, 874.66s/it]


New best model found for redkitchen in fold 4 with validation loss 0.0041
For redkitchen in Epoch 2, Validation loss: 0.0041
 For redkitchen in Epoch 2, Average Loss: 0.0034, Average Translation Error: 0.0420, Average Rotation Error (radians): 0.0194


Epochs Progress for redkitchen:  20%|██        | 3/15 [43:42<2:54:46, 873.89s/it]

For redkitchen in Epoch 3, Validation loss: 0.0042
 For redkitchen in Epoch 3, Average Loss: 0.0035, Average Translation Error: 0.0428, Average Rotation Error (radians): 0.0190


Epochs Progress for redkitchen:  27%|██▋       | 4/15 [58:15<2:40:08, 873.53s/it]

For redkitchen in Epoch 4, Validation loss: 0.0041
 For redkitchen in Epoch 4, Average Loss: 0.0034, Average Translation Error: 0.0419, Average Rotation Error (radians): 0.0197


Epochs Progress for redkitchen:  33%|███▎      | 5/15 [1:12:48<2:25:32, 873.27s/it]

For redkitchen in Epoch 5, Validation loss: 0.0042
 For redkitchen in Epoch 5, Average Loss: 0.0034, Average Translation Error: 0.0428, Average Rotation Error (radians): 0.0190


Epochs Progress for redkitchen:  40%|████      | 6/15 [1:27:22<2:11:03, 873.71s/it]


New best model found for redkitchen in fold 4 with validation loss 0.0040
For redkitchen in Epoch 6, Validation loss: 0.0040
 For redkitchen in Epoch 6, Average Loss: 0.0033, Average Translation Error: 0.0420, Average Rotation Error (radians): 0.0197


Epochs Progress for redkitchen:  47%|████▋     | 7/15 [1:41:57<1:56:31, 873.92s/it]


New best model found for redkitchen in fold 4 with validation loss 0.0040
For redkitchen in Epoch 7, Validation loss: 0.0040
 For redkitchen in Epoch 7, Average Loss: 0.0034, Average Translation Error: 0.0428, Average Rotation Error (radians): 0.0191


Epochs Progress for redkitchen:  53%|█████▎    | 8/15 [1:56:30<1:41:56, 873.76s/it]

For redkitchen in Epoch 8, Validation loss: 0.0041
 For redkitchen in Epoch 8, Average Loss: 0.0032, Average Translation Error: 0.0409, Average Rotation Error (radians): 0.0186


Epochs Progress for redkitchen:  60%|██████    | 9/15 [2:11:04<1:27:23, 873.97s/it]


New best model found for redkitchen in fold 4 with validation loss 0.0039
For redkitchen in Epoch 9, Validation loss: 0.0039
 For redkitchen in Epoch 9, Average Loss: 0.0032, Average Translation Error: 0.0410, Average Rotation Error (radians): 0.0186


Epochs Progress for redkitchen:  67%|██████▋   | 10/15 [2:25:38<1:12:48, 873.77s/it]


New best model found for redkitchen in fold 4 with validation loss 0.0039
For redkitchen in Epoch 10, Validation loss: 0.0039
 For redkitchen in Epoch 10, Average Loss: 0.0030, Average Translation Error: 0.0391, Average Rotation Error (radians): 0.0186


Epochs Progress for redkitchen:  73%|███████▎  | 11/15 [2:40:10<58:13, 873.31s/it]  


New best model found for redkitchen in fold 4 with validation loss 0.0039
For redkitchen in Epoch 11, Validation loss: 0.0039
 For redkitchen in Epoch 11, Average Loss: 0.0031, Average Translation Error: 0.0398, Average Rotation Error (radians): 0.0189


Epochs Progress for redkitchen:  80%|████████  | 12/15 [2:54:42<43:38, 872.94s/it]

For redkitchen in Epoch 12, Validation loss: 0.0041
 For redkitchen in Epoch 12, Average Loss: 0.0031, Average Translation Error: 0.0404, Average Rotation Error (radians): 0.0189


Epochs Progress for redkitchen:  87%|████████▋ | 13/15 [3:09:15<29:05, 872.98s/it]

For redkitchen in Epoch 13, Validation loss: 0.0040
 For redkitchen in Epoch 13, Average Loss: 0.0031, Average Translation Error: 0.0408, Average Rotation Error (radians): 0.0184


Epochs Progress for redkitchen:  93%|█████████▎| 14/15 [3:23:47<14:32, 872.67s/it]

For redkitchen in Epoch 14, Validation loss: 0.0041
 For redkitchen in Epoch 14, Average Loss: 0.0030, Average Translation Error: 0.0395, Average Rotation Error (radians): 0.0187


Epochs Progress for redkitchen: 100%|██████████| 15/15 [3:38:19<00:00, 873.32s/it]


For redkitchen in Epoch 15, Validation loss: 0.0040
 For redkitchen in Epoch 15, Average Loss: 0.0028, Average Translation Error: 0.0380, Average Rotation Error (radians): 0.0177
--------------------------------------------------
Training complete for redkitchen in Fold 4.
(Best Epoch for Fold 4: 0 with Validation Loss: 0.0039
Best model for redkitchen from Fold 4 saved with loss 0.0039
Performance of the best model for redkitchen on the test data was from Fold 4:
Average Translation Error for redkitchen: 0.4521 meters
Average Rotation Error (degree) for redkitchen: 10.2388
--------------------------------------------------


In [5]:
import os
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import torch.nn as nn

# Define Frame class
class Frame:
    def __init__(self, data_type, room_label, sequence, file_name, depth_image_path, color_image_path, pose):
        self.data_type = data_type
        self.room_label = room_label
        self.sequence = sequence
        self.file_name = file_name
        self.depth_image_path = depth_image_path
        self.color_image_path = color_image_path
        self.pose = pose

def parse_pose_file(pose_file_path):
    with open(pose_file_path, 'r') as file:
        pose = [list(map(float, line.strip().split())) for line in file]
    return pose

def create_frame_objects(data_path, room_name, data_type):
    frames = []
    for seq_folder in os.listdir(data_path):
        seq_path = os.path.join(data_path, seq_folder)
        if os.path.isdir(seq_path):
            for frame_file in os.listdir(seq_path):
                if frame_file.endswith('.color.png'):
                    frame_name = frame_file.split('.')[0]
                    depth_image_path = os.path.join(seq_path, f"{frame_name}.depth.png")
                    color_image_path = os.path.join(seq_path, f"{frame_name}.color.png")
                    pose_file_path = os.path.join(seq_path, f"{frame_name}.pose.txt")

                    if os.path.exists(depth_image_path) and os.path.exists(pose_file_path):
                        pose = parse_pose_file(pose_file_path)
                        frame = Frame(data_type, room_name, seq_folder, frame_name, depth_image_path, color_image_path, pose)
                        frames.append(frame)
    return frames

def create_data_structure(data_folder):
    train_data = []
    test_data = []
    room_names = ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs']
    
    for room_name in room_names:
        test_path = os.path.join(data_folder, room_name, 'test')
        train_path = os.path.join(data_folder, room_name, 'train')
        
        train_data.extend(create_frame_objects(train_path, room_name, 'train'))
        test_data.extend(create_frame_objects(test_path, room_name, 'test'))
        
    return train_data, test_data

# Path to data folder
your_path_to_data_folder = 'data'
train_data, test_data = create_data_structure(your_path_to_data_folder)

# Define Custom Dataset for AlexNet
class CustomDataset(Dataset):
    def __init__(self, frames, label_map, transform=None):
        self.frames = frames
        self.label_map = label_map
        self.transform = transform

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        frame = self.frames[idx]
        image = Image.open(frame.color_image_path).convert('RGB')
        label = self.label_map[frame.room_label]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Transformations for AlexNet
transformations = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

label_map = {label: index for index, label in enumerate(['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs'])}
num_classes = len(label_map) 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

test_dataset = CustomDataset(test_data, label_map, transform=transformations)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load the pre-trained AlexNet model
alexnet_model = models.alexnet(pretrained=False)
alexnet_model.classifier[6] = nn.Linear(alexnet_model.classifier[6].in_features, num_classes)
alexnet_model.load_state_dict(torch.load('best_model_fold_2.pth'))
alexnet_model.to(device)
alexnet_model.eval()

# Define PoseNet Model
class PoseNetModel(nn.Module):
    def __init__(self):
        super(PoseNetModel, self).__init__()
        self.backbone = models.resnet50(pretrained=True) # Do not load pretrained weights here
        self.fc_translation = nn.Linear(self.backbone.fc.in_features, 3)
        self.fc_rotation = nn.Linear(self.backbone.fc.in_features, 9)
        self.backbone.fc = nn.Identity()

    def forward(self, x):
        features = self.backbone(x)
        translation = self.fc_translation(features)
        rotation = self.fc_rotation(features)
        return translation, rotation

# Custom Dataset for PoseNet
class PoseNetDataset(Dataset):
    def __init__(self, frames, transform=None):
        self.frames = frames
        self.transform = transform

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        frame = self.frames[idx]
        image = Image.open(frame.color_image_path).convert('RGB')
        pose_matrix = np.array(frame.pose, dtype=np.float32).reshape(4, 4)
        translation = pose_matrix[:3, 3]
        rotation = pose_matrix[:3, :3]

        if self.transform:
            image = self.transform(image)
        return image, torch.from_numpy(translation), torch.from_numpy(rotation.flatten())

pose_transformations = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load and evaluate models
room_names = ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs']
pose_models = {room: PoseNetModel().to(device) for room in room_names}
for room in room_names:
    pose_models[room].load_state_dict(torch.load(f'best_pose_model_{room}.pth'))
    pose_models[room].eval()

# Define the criterion for PoseNet
criterion = nn.MSELoss()

def rotation_matrix_to_angle_axis(rotation_matrices):
    """Convert a batch of rotation matrices to angle-axis vectors."""
    # Calculate the trace of each 3x3 rotation matrix in the batch
    traces = torch.einsum('bii->b', rotation_matrices)  # Sum over the diagonal elements in each matrix in the batch
    cos_thetas = (traces - 1) / 2.0
    cos_thetas = torch.clamp(cos_thetas, -1, 1)  # Numerical errors might make cos(theta) slightly out of its range
    thetas = torch.acos(cos_thetas)  # Angles

    # Initialize angle-axis vectors
    angle_axes = torch.zeros_like(rotation_matrices[:, :, 0])

    # Compute sin(theta) for normalization
    sin_thetas = torch.sin(thetas)

    # Find indices where theta is not too small (to avoid division by zero)
    valid = sin_thetas > 1e-5

    # For valid indices where theta is not too small, calculate angle-axis vectors
    angle_axes[valid] = torch.stack([
        rotation_matrices[valid, 2, 1] - rotation_matrices[valid, 1, 2],
        rotation_matrices[valid, 0, 2] - rotation_matrices[valid, 2, 0],
        rotation_matrices[valid, 1, 0] - rotation_matrices[valid, 0, 1]
    ], dim=1) / (2 * sin_thetas[valid].unsqueeze(1)) * thetas[valid].unsqueeze(1)

    return angle_axes

# Function to calculate errors

def calculate_errors(pred_translations, gt_translations, pred_rotations, gt_rotations):
    translation_errors = torch.norm(pred_translations - gt_translations, dim=1)
    
    pred_rot_matrix = pred_rotations.view(-1, 3, 3)
    gt_rot_matrix = gt_rotations.view(-1, 3, 3)
    r_diff = torch.matmul(pred_rot_matrix, gt_rot_matrix.transpose(1, 2))  # Relative rotation
    angle_axis = rotation_matrix_to_angle_axis(r_diff)
    rotation_errors = torch.norm(angle_axis, dim=1)  # Returns the magnitude of the angle-axis vector

    return translation_errors.mean().item(), rotation_errors.mean().item()


    

# Evaluate models
total_translation_error = 0.0
total_rotation_error = 0.0
count = 0

with torch.no_grad():
    for images, _ in test_loader:
        images = images.to(device)
        alexnet_outputs = alexnet_model(images)
        _, predicted_rooms = torch.max(alexnet_outputs.data, 1)
        
        for i in range(images.size(0)):
            predicted_room = predicted_rooms[i].item()
            room_name = room_names[predicted_room]
            
            # Load the PoseNet model for the predicted room
            posenet_model = pose_models[room_name]
            
            # Predict localization using the PoseNet model
            image = images[i].unsqueeze(0)
            frame = test_data[count]
            test_dataset_posenet = PoseNetDataset([frame], transform=pose_transformations)
            test_loader_posenet = DataLoader(test_dataset_posenet, batch_size=1, shuffle=False)
            
            for img, gt_translation, gt_rotation in test_loader_posenet:
                img = img.to(device)
                gt_translation = gt_translation.to(device)
                gt_rotation = gt_rotation.to(device)
                
                pred_translation, pred_rotation = posenet_model(img)
                translation_error, rotation_error = calculate_errors(pred_translation, gt_translation, pred_rotation, gt_rotation)
                total_translation_error += translation_error
                total_rotation_error += rotation_error
                count += 1

average_translation_error = total_translation_error / count
average_rotation_error_radians = total_rotation_error / count
average_rotation_error_degrees = average_rotation_error_radians * (180.0 / np.pi)

print(f"Average Translation Error: {average_translation_error:.4f} meters")
print(f"Average Rotation Error: {average_rotation_error_degrees:.4f} degrees")

Average Translation Error: 0.4215 meters
Average Rotation Error: 11.5806 degrees
