# Motion Capture System Controller

This Jupyter Notebook is the **complete** control panel for the ArUco marker-based motion capture system. From here, you can:

1.  **Load Credentials:** Automatically load credentials from `server/bootfs/pi_settings.conf`.
2.  **Manage Servers:** Remotely reboot all servers at once.
3.  **Perform Distortion Calibration:** Run the headless camera distortion calibration routine on one Pi at a time (selected via `TARGET_HOSTNAME`).
4.  **Run Live 3D Tracking:** Launch a real-time 3D visualization of the tracked markers and camera poses.

---

## 1. Configuration and Imports

**Instructions:**
1.  Before running, ensure you have created a `pi_settings.conf` file inside the `server/bootfs/` directory from the template and filled in your credentials.
2.  Run this cell to import all necessary libraries and load your configuration.

In [None]:
# --- Imports and Setup ---
import os
import sys
import threading
import time
import numpy as np
import cv2
import socket
import json
import logging
import paramiko
import re
from scp import SCPClient
from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QVBoxLayout, QStatusBar
from PyQt5.QtCore import QThread, QObject, pyqtSignal, pyqtSlot
import pyqtgraph.opengl as gl
from pyqtgraph.opengl import GLMeshItem
from pyqtgraph.Qt import QtGui

# Add project root to path to allow importing config
if '.' not in sys.path:
    sys.path.insert(0, '.')
import config

# --- Logging Setup ---
# Log records are written both to a file inside LOG_DIR and to the console stream.
LOG_DIR = 'client_logs'
os.makedirs(LOG_DIR, exist_ok=True)
_LOG_HANDLERS = [
    logging.FileHandler(os.path.join(LOG_DIR, 'motion_capture.log')),
    logging.StreamHandler(),
]
logging.basicConfig(
    handlers=_LOG_HANDLERS,
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# --- Load SSH Credentials from pi_settings.conf ---
# Both values stay None until the settings file has been read successfully.
SSH_USERNAME = None
SSH_PASSWORD = None
def parse_conf_file(filepath):
    """Parse a shell-style ``KEY=value`` settings file into a dict.

    Blank lines and lines whose first non-blank character is ``#`` are
    skipped. A value may be double-quoted, single-quoted, or unquoted (an
    unquoted value must not contain whitespace, quotes or ``#``). Lines that
    do not fit this shape are ignored.

    Args:
        filepath: Path of the settings file to read.

    Returns:
        A dict mapping each key to its string value (quotes removed), or
        ``None`` if the file does not exist.
    """
    # Alternatives, in order: "double quoted", 'single quoted', bare value.
    setting_pattern = re.compile(r'''^(\w+)=(?:"(.*?)"|'(.*?)'|([^\s"'#]*))$''')
    creds = {}
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                # Test the stripped line so indented comments are skipped too.
                if not line or line.startswith('#'):
                    continue
                match = setting_pattern.match(line)
                if not match:
                    continue
                key = match.group(1)
                # Exactly one of the three value groups took part in the match.
                creds[key] = next(
                    (group for group in match.groups()[1:] if group is not None), ''
                )
    except FileNotFoundError:
        return None
    return creds

# Read the settings file and publish the SSH credentials it contains.
conf_path = os.path.join('server', 'bootfs', 'pi_settings.conf')
pi_config = parse_conf_file(conf_path)
if not pi_config or 'SSH_USERNAME' not in pi_config or 'SSH_PASSWORD' not in pi_config:
    # File missing, empty, or lacking one of the two required keys.
    logging.error(f"Could not load credentials. Please create '{conf_path}' from the template.")
else:
    SSH_USERNAME, SSH_PASSWORD = pi_config['SSH_USERNAME'], pi_config['SSH_PASSWORD']
    logging.info(f"Successfully loaded credentials from {conf_path}.")
    # The template placeholder counts as "loaded" but is flagged to the user.
    if SSH_USERNAME == "your_pi_username":
        logging.warning("Using default username. Please update 'pi_settings.conf' with your credentials.")

logging.info("Motion capture notebook initialized.")

---

## 2. Remote Server Management

Use this section to reboot all Raspberry Pi servers simultaneously.

In [None]:
def ssh_command(hostname, username, password, command, timeout=10):
    """Executes a command on a remote host via SSH.

    Every outcome is reported through ``logging``; nothing is returned and no
    exception escapes, so the function can be used directly as a thread
    target.

    Args:
        hostname: Host to connect to.
        username: SSH user name.
        password: SSH password.
        command: Command line to execute remotely.
        timeout: Seconds allowed for the connection attempt and for blocking
            reads on the command's output. Without it a command that never
            finishes would block the calling thread forever.
    """
    client = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy accepts unknown host keys without
    # verification. Kept so first-time connections keep working, but consider
    # loading known host keys and rejecting unknown ones on untrusted networks.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        logging.info(f"[{hostname}] Connecting...")
        client.connect(hostname, username=username, password=password, timeout=timeout)
        logging.info(f"[{hostname}] Connected. Sending command: '{command}'")
        _, _, stderr = client.exec_command(command, timeout=timeout)
        # errors='replace' keeps odd bytes in stderr from masking the real message.
        error = stderr.read().decode(errors='replace')
        # A reboot drops the session; that particular message is expected.
        if error and "closed by remote host" not in error:
            logging.error(f"[{hostname}] Error: {error.strip()}")
        else:
            logging.info(f"[{hostname}] Command sent successfully.")
    except Exception as e:
        # Boundary handler: a failure on one host must not abort the others.
        logging.error(f"[{hostname}] Failed to connect or execute command: {e}")
    finally:
        client.close()

def remote_server_reboot():
    """Send the reboot command to every configured server in parallel.

    Does nothing except log an error when the SSH credentials are missing or
    still carry the template user name. Otherwise one thread per host in
    ``config.SERVER_HOSTS`` runs ``ssh_command``; the call returns once all of
    them have finished.
    """
    credentials_ready = (
        bool(SSH_USERNAME)
        and bool(SSH_PASSWORD)
        and SSH_USERNAME != "your_pi_username"
    )
    if not credentials_ready:
        logging.error("SSH credentials not configured. Please set them in 'server/bootfs/pi_settings.conf' and restart the notebook.")
        return

    reboot_command = "sudo reboot"
    logging.info(f"Preparing to send 'reboot' command to all servers...")
    workers = [
        threading.Thread(
            target=ssh_command,
            args=(host, SSH_USERNAME, SSH_PASSWORD, reboot_command),
        )
        for host in config.SERVER_HOSTS
    ]
    for worker in workers:
        worker.start()
    # Block until every host has been handled.
    for worker in workers:
        worker.join()
    logging.info("Remote reboot script finished.")

**Run the cell below to reboot all servers.**

In [None]:
# To reboot all servers: sends "sudo reboot" to every host in config.SERVER_HOSTS
# and blocks until each SSH attempt has finished.
remote_server_reboot()

---

## 3. Headless Camera Distortion Calibration

This section allows you to perform the one-time camera distortion calibration for a single Raspberry Pi entirely over the network.

**Instructions:**
1.  Set the `TARGET_HOSTNAME` to the Pi you want to calibrate.
2.  Run the setup cell.
3.  Position the chessboard pattern in front of the target camera.
4.  Run the "Capture Image" cell repeatedly, repositioning the chessboard each time. Aim for at least 15 successful captures.
5.  Once you have enough images, run the "Perform Calibration & Download" cell.

In [None]:
# 1. Set the hostname of the Pi you want to calibrate
# (the capture and calibration cells pass it to CalibrationController)
TARGET_HOSTNAME = "pi-mocap-1.local"

In [None]:
# 2. Run this cell to define the CalibrationController class
class CalibrationController:
    """Manages an SSH connection to a Pi for calibration tasks.

    Typical use: ``connect()``, then ``capture_image()`` repeatedly or
    ``perform_calibration()`` followed by ``download_calibration_file()``,
    and finally ``close()``. Every operation reports problems through
    ``logging`` and returns ``False`` on failure.
    """

    def __init__(self, hostname):
        """Remember the target host and the module-level SSH credentials.

        Args:
            hostname: Host name of the Pi to calibrate.

        Raises:
            ValueError: If the SSH credentials have not been loaded.
        """
        self.hostname = hostname
        if not SSH_USERNAME or not SSH_PASSWORD:
            raise ValueError("SSH Credentials not loaded from pi_settings.conf")
        self.username = SSH_USERNAME
        self.password = SSH_PASSWORD
        self.ssh_client = None
        self.repo_dir = None

    def connect(self):
        """Open the SSH session and locate the project checkout on the Pi.

        Returns:
            True when connected and the repository directory was found.
            On any failure the SSH session is closed again (callers only call
            ``close()`` after a successful connect) and False is returned.
        """
        logging.info(f"--> Connecting to {self.hostname} for calibration...")
        try:
            self.ssh_client = paramiko.SSHClient()
            # NOTE(review): AutoAddPolicy accepts unknown host keys without
            # verification; acceptable on a trusted LAN only.
            self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            self.ssh_client.connect(self.hostname, username=self.username, password=self.password, timeout=15)
            logging.info("--> Connection successful.")
            _, stdout, _ = self.ssh_client.exec_command("find /home -type d -name 'PnP-ArUco-marker-tracking' 2>/dev/null | head -n 1")
            self.repo_dir = stdout.read().decode().strip()
        except Exception as e:
            logging.error(f"Error connecting to {self.hostname}: {e}")
            self._release_client()
            return False
        if not self.repo_dir:
            logging.error("Could not find 'PnP-ArUco-marker-tracking' directory on the remote host.")
            self._release_client()
            return False
        logging.info(f"--> Remote repository found at: {self.repo_dir}")
        return True

    def run_remote_command(self, command, timeout=30):
        """Run ``command`` on the Pi and log its output.

        Args:
            command: Shell command line to execute.
            timeout: Seconds allowed for blocking reads on the command output.

        Returns:
            True if the command exited with status 0; False if it failed,
            the session broke or timed out, or there is no open connection.
        """
        if not self.ssh_client:
            return False
        logging.info(f"--> Executing: {command}")
        try:
            _, stdout, stderr = self.ssh_client.exec_command(command, timeout=timeout)
            # Drain the output before asking for the exit status: waiting for
            # the status first can hang when the command prints a lot.
            out = stdout.read().decode(errors='replace').strip()
            err = stderr.read().decode(errors='replace').strip()
            exit_status = stdout.channel.recv_exit_status()
        except (paramiko.SSHException, OSError) as e:
            # OSError also covers socket timeouts.
            logging.error(f"[{self.hostname}] Remote command failed: {e}")
            return False
        if out:
            logging.info(f"[{self.hostname} STDOUT]:\n{out}")
        if err:
            logging.error(f"[{self.hostname} STDERR]:\n{err}")
        return exit_status == 0

    def capture_image(self):
        """Ask the remote calibration script to capture one image."""
        return self._run_calibration_script("--capture", "--host", self.hostname)

    def perform_calibration(self):
        """Run the remote calibration script in calibrate mode (60 s timeout)."""
        return self._run_calibration_script("--calibrate", timeout=60)

    def download_calibration_file(self):
        """Copy the calibration result from the Pi to the local machine.

        The file named by ``config.DISTORTION_DATA_FILE`` is fetched from the
        remote repository directory and stored under the same relative path
        locally.

        Returns:
            True on success, False if not connected or the transfer failed.
        """
        if not self.ssh_client or not self.repo_dir:
            logging.error("Not connected; cannot download the calibration file.")
            return False
        remote_path = self._remote_path(config.DISTORTION_DATA_FILE)
        local_path = config.DISTORTION_DATA_FILE
        logging.info(f"--> Attempting to download '{remote_path}' to local '{local_path}'...")
        try:
            with SCPClient(self.ssh_client.get_transport()) as scp:
                scp.get(remote_path, local_path)
            logging.info(f"--> Success! Calibration file saved to '{os.path.abspath(local_path)}'")
            return True
        except Exception as e:
            logging.error(f"Error downloading file: {e}")
            return False

    def close(self):
        """Close the SSH session if one is open; safe to call repeatedly."""
        if self._release_client():
            logging.info("--> Calibration connection closed.")

    def _release_client(self):
        """Close and forget the SSH client; return True if there was one."""
        client, self.ssh_client = self.ssh_client, None
        if client is None:
            return False
        client.close()
        return True

    def _remote_path(self, relative_path):
        """Join ``relative_path`` onto the remote repository directory."""
        # The remote side is a POSIX system: always join with '/', even when
        # this notebook runs on Windows.
        import posixpath
        return posixpath.join(self.repo_dir, relative_path)

    def _run_calibration_script(self, *script_args, timeout=30):
        """Run server/distortion_calibration.py remotely with ``script_args``."""
        import shlex
        if not self.repo_dir:
            logging.error("Remote repository location unknown; call connect() first.")
            return False
        script_path = self._remote_path("server/distortion_calibration.py")
        # Quote every word so paths or host names with special characters
        # survive the remote shell; plain words are left untouched.
        words = [shlex.quote(str(word)) for word in (script_path, *script_args)]
        command = " ".join(["python3", *words])
        return self.run_remote_command(command, timeout=timeout)

**3. Capture Image** (Run this cell multiple times)

In [None]:
# Open a fresh SSH session to the target Pi, trigger one capture, then close the session.
try:
    # The constructor raises ValueError when the SSH credentials were not loaded.
    controller = CalibrationController(TARGET_HOSTNAME)
    if controller.connect():
        try:
            controller.capture_image()
        finally:
            # Release the SSH connection even if the capture raises.
            controller.close()
except ValueError as e:
    logging.error(e)

**4. Perform Calibration & Download** (Run this cell once you have enough images)

In [None]:
# Run the calibration on the target Pi and, if it succeeds, download the result file.
try:
    # The constructor raises ValueError when the SSH credentials were not loaded.
    controller = CalibrationController(TARGET_HOSTNAME)
    if controller.connect():
        try:
            if controller.perform_calibration():
                # NOTE(review): presumably gives the Pi a moment to finish
                # writing the result file before it is fetched - confirm.
                time.sleep(1)
                controller.download_calibration_file()
        finally:
            # Release the SSH connection even if a step above raises.
            controller.close()
except ValueError as e:
    logging.error(e)

---

## 4. Live 3D Tracking & Visualization

This section launches the main application for tracking markers in 3D. It connects to all servers, performs continuous camera pose estimation, and displays the results in a real-time 3D plot.

**Instructions:**
1.  Ensure the distortion calibration file (the path set in `config.DISTORTION_DATA_FILE`, e.g. `distortion_calibration.json`) exists in the project root (by running the calibration step above).
2.  Run the cell below. A separate PyQt5 window will open with the 3D visualization.
3.  To stop the application, simply close the visualization window.

In [None]:
# --- Global Data Structures for Visualization ---
# Shared between the CameraClient threads, the ProcessingWorker and the GUI.
# All access goes through data_lock.
live_marker_data = {}  # marker id -> {host: position tuple last reported by that host}
camera_poses = {}  # host -> {'rvec': ..., 'tvec': ..., 'time': timestamp of the solvePnP result}
tracked_points_3d = {}  # marker id -> triangulated 3D position
data_lock = threading.Lock()

# --- Class Definitions for Visualization ---
class CameraClient(threading.Thread):
    """Manages the connection and data flow for a single Pi server.

    The thread connects to ``host:port``, reads newline-delimited JSON
    messages and stores each reported marker position in the shared
    ``live_marker_data`` dict under this client's host name. It reconnects
    until ``stop()`` is called.
    """

    def __init__(self, host, port):
        """Create a (not yet started) daemon thread for ``host:port``."""
        super().__init__()
        self.host, self.port = host, port
        self.sock, self.is_connected = None, False
        self.stop_event = threading.Event()
        self.daemon = True

    def run(self):
        """Connect, consume data and reconnect until stopped."""
        while not self.stop_event.is_set():
            retry_later = False
            try:
                logging.info(f"[{self.host}] Attempting to connect...")
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                self.sock = sock
                # The 5 s timeout applies to the connect and to every later read.
                sock.settimeout(5)
                sock.connect((self.host, self.port))
                self.is_connected = True
                logging.info(f"[{self.host}] Connection successful.")
                self.listen_for_data()
            except (ConnectionRefusedError, socket.gaierror, socket.timeout) as e:
                logging.warning(f"[{self.host}] Connection failed: {e}. Retrying in 5s.")
                retry_later = True
            except Exception as e:
                # stop() closes the socket under our feet; that is not an error.
                if not self.stop_event.is_set():
                    logging.error(f"[{self.host}] Unhandled exception: {e}")
                retry_later = True
            finally:
                self.is_connected = False
                # Never carry a socket over to the next attempt.
                self._close_socket()
            if retry_later:
                # Unlike time.sleep, this returns at once when stop() is called.
                self.stop_event.wait(5)

    def listen_for_data(self):
        """Read messages line by line until EOF, a socket error or stop()."""
        with self.sock.makefile('r', encoding='utf-8', errors='replace') as reader:
            while not self.stop_event.is_set():
                try:
                    line = reader.readline()
                    if not line:
                        break
                    self._process_server_message(line.strip())
                except OSError:
                    # Includes read timeouts and connection resets.
                    break
        self.is_connected = False
        # Drop this camera's observations so stale positions are not reused
        # for pose estimation while it is offline.
        self._forget_observations()
        logging.info(f"[{self.host}] Disconnected.")

    def _process_server_message(self, message):
        """Store the marker positions contained in one JSON message.

        The message is expected to decode to an iterable of mappings with
        ``'id'`` and ``'pos'`` entries. Anything else is ignored so that a
        single bad line does not tear the connection down.
        """
        try:
            marker_positions = json.loads(message)
            observations = {
                marker['id']: tuple(marker['pos']) for marker in marker_positions
            }
        except json.JSONDecodeError:
            if message:
                logging.debug(f"[{self.host}] Ignoring non-JSON message.")
            return
        except (KeyError, TypeError) as e:
            logging.debug(f"[{self.host}] Ignoring malformed marker message: {e!r}")
            return
        # NOTE(review): markers absent from a message keep their previous
        # position for this host. If every message lists all currently
        # visible markers, the host's old entries should be cleared here.
        with data_lock:
            for marker_id, position in observations.items():
                live_marker_data.setdefault(marker_id, {})[self.host] = position

    def _forget_observations(self):
        """Remove every observation this host contributed to the shared data."""
        with data_lock:
            for marker_id in list(live_marker_data):
                per_host = live_marker_data[marker_id]
                per_host.pop(self.host, None)
                if not per_host:
                    del live_marker_data[marker_id]

    def _close_socket(self):
        """Shut down and close the current socket, if any."""
        sock, self.sock = self.sock, None
        if sock is None:
            return
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            # Not connected (or already closed); closing below is still fine.
            pass
        sock.close()

    def stop(self):
        """Ask the thread to finish and unblock any pending socket call."""
        self.stop_event.set()
        self._close_socket()

class ProcessingWorker(QObject):
    """Worker that runs PnP and triangulation in a separate thread.

    ``run()`` loops until ``stop()`` is called. Each pass re-estimates the
    pose of every camera from the reference markers, triangulates all other
    markers and then emits ``new_data``.
    """
    new_data = pyqtSignal()

    def __init__(self):
        super().__init__()
        self.stop_event = threading.Event()
        self.camera_matrix, self.dist_coeffs = None, None
        self.pnp_object_points, self.pnp_marker_ids = None, None

    def run(self):
        """Processing loop; returns immediately if calibration data is unusable."""
        if not self.load_calibration_data():
            return
        while not self.stop_event.is_set():
            self.update_camera_poses()
            self.triangulate_tracked_points()
            self.new_data.emit()
            # Interruptible pause: stop() ends the wait right away.
            self.stop_event.wait(0.05)

    def load_calibration_data(self):
        """Load the camera intrinsics and the PnP reference markers.

        Returns:
            True on success; False (after logging) when the calibration file
            is missing, is not valid JSON or lacks an expected entry.
        """
        try:
            with open(config.DISTORTION_DATA_FILE, 'r') as f:
                calib_data = json.load(f)
            # Force float64 so integer-valued JSON numbers are accepted by OpenCV.
            camera_matrix = np.array(calib_data['camera_matrix'], dtype=np.float64)
            dist_coeffs = np.array(calib_data['distortion_coefficients'], dtype=np.float64)
        except FileNotFoundError:
            logging.error(f"'{config.DISTORTION_DATA_FILE}' not found. Run calibration first.")
            return False
        except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
            logging.error(f"'{config.DISTORTION_DATA_FILE}' is not a valid calibration file: {e!r}")
            return False
        self.camera_matrix, self.dist_coeffs = camera_matrix, dist_coeffs
        self.pnp_marker_ids = list(config.PNP_MARKER_WORLD_COORDINATES.keys())
        self.pnp_object_points = np.array([config.PNP_MARKER_WORLD_COORDINATES[i] for i in self.pnp_marker_ids], dtype=np.float32)
        logging.info("Distortion calibration and PnP data loaded.")
        return True

    def update_camera_poses(self):
        """Run solvePnP for every host that currently sees >= 4 reference markers.

        Successful results are stored in ``camera_poses[host]`` together with
        the time of the estimate.
        """
        with data_lock:
            current_data = {mid: obs.copy() for mid, obs in live_marker_data.items()}
        for host in config.SERVER_HOSTS:
            image_points, object_points = [], []
            for marker_id in self.pnp_marker_ids:
                if marker_id in current_data and host in current_data[marker_id]:
                    image_points.append(current_data[marker_id][host])
                    object_points.append(config.PNP_MARKER_WORLD_COORDINATES[marker_id])
            if len(image_points) < 4:
                continue
            try:
                success, rvec, tvec = cv2.solvePnP(
                    np.array(object_points, dtype=np.float32),
                    np.array(image_points, dtype=np.float32),
                    self.camera_matrix,
                    self.dist_coeffs,
                )
            except cv2.error as e:
                # OpenCV rejects some point sets; skip this camera for this
                # pass instead of letting the exception end the worker loop.
                logging.debug(f"[{host}] solvePnP failed: {e}")
                continue
            if success:
                with data_lock:
                    camera_poses[host] = {'rvec': rvec, 'tvec': tvec, 'time': time.time()}

    def triangulate_tracked_points(self):
        """Triangulate every non-reference marker seen by >= 2 posed cameras.

        Only poses younger than 2 seconds are used. The result replaces the
        global ``tracked_points_3d`` (empty when fewer than two cameras have
        a valid pose).
        """
        global tracked_points_3d
        now = time.time()
        # Hold the lock only while copying; the maths below runs unlocked so
        # the network threads are not blocked.
        with data_lock:
            valid_poses = {h: p for h, p in camera_poses.items() if now - p['time'] < 2.0}
            enough_cameras = len(valid_poses) >= 2
            current_data = (
                {mid: obs.copy() for mid, obs in live_marker_data.items()}
                if enough_cameras else {}
            )
        current_tracked = {}
        reference_ids = set(self.pnp_marker_ids)
        for marker_id, observations in current_data.items():
            if marker_id in reference_ids:
                continue
            valid_observations = {h: p for h, p in observations.items() if h in valid_poses}
            if len(valid_observations) < 2:
                continue
            pos_3d = self.triangulate_point(valid_observations, valid_poses)
            if pos_3d is not None:
                current_tracked[marker_id] = pos_3d
        with data_lock:
            tracked_points_3d = current_tracked

    def triangulate_point(self, observations, poses):
        """Triangulate one marker from the first two observing cameras.

        Args:
            observations: Mapping host -> 2D image point (at least two entries).
            poses: Mapping host -> pose dict with ``'rvec'`` and ``'tvec'``.

        Returns:
            The 3D position as a flat array, or None when the homogeneous
            scale of the result is zero.
        """
        proj_matrices, points_2d = [], []
        # Only the first two observations are used, even if more exist.
        for host, point_2d in list(observations.items())[:2]:
            pose = poses[host]
            R, _ = cv2.Rodrigues(pose['rvec'])
            extrinsic_matrix = np.hstack((R, pose['tvec']))
            # Projection matrix K [R | t]. The raw image points are used
            # as-is; dist_coeffs is not applied in this step.
            proj_matrices.append(self.camera_matrix @ extrinsic_matrix)
            points_2d.append(np.array(point_2d, dtype=np.float32))
        points1, points2 = points_2d[0].reshape(2, 1), points_2d[1].reshape(2, 1)
        points_4d_hom = cv2.triangulatePoints(proj_matrices[0], proj_matrices[1], points1, points2)
        if points_4d_hom[3] == 0:
            return None
        return (points_4d_hom[:3] / points_4d_hom[3]).flatten()

    def stop(self):
        """Request the processing loop to finish."""
        self.stop_event.set()

class VisualizationWindow(QMainWindow):
    """The main application window with the 3D plot.

    Shows the reference markers (red), the triangulated markers (blue) and
    one pyramid mesh per camera host. A camera mesh is only visible while
    that camera has a pose younger than 2 seconds.
    """
    def __init__(self):
        super().__init__()
        self.setWindowTitle("Real-Time 3D Motion Capture Viewer")
        self.setGeometry(100, 100, 1200, 800)
        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget)
        self.view = gl.GLViewWidget()
        layout.addWidget(self.view)
        self.view.setCameraPosition(distance=200, elevation=30, azimuth=45)
        grid = gl.GLGridItem()
        grid.scale(20, 20, 1)
        self.view.addItem(grid)
        ref_coords = np.array(list(config.PNP_MARKER_WORLD_COORDINATES.values()))
        self.ref_points = gl.GLScatterPlotItem(pos=ref_coords, color=(1, 0, 0, 1), size=10, pxMode=True)
        self.view.addItem(self.ref_points)
        self.live_points = gl.GLScatterPlotItem(pos=np.empty((0,3)), color=(0, 0, 1, 1), size=15, pxMode=True)
        self.view.addItem(self.live_points)
        self.camera_meshes, self.camera_mesh_template = {}, self.create_camera_pyramid_mesh()
        for host in config.SERVER_HOSTS:
            mesh_item = GLMeshItem(meshdata=self.camera_mesh_template, smooth=False, drawEdges=True, edgeColor=(1,1,0,1), shader='balloon')
            # Hidden until a pose is available; otherwise every camera would
            # be drawn at the origin before its first pose estimate.
            mesh_item.setVisible(False)
            self.camera_meshes[host] = mesh_item
            self.view.addItem(mesh_item)
        # NOTE: this attribute shadows QMainWindow.statusBar(); the name is
        # kept so existing references to it keep working.
        self.statusBar = QStatusBar()
        self.setStatusBar(self.statusBar)

    def create_camera_pyramid_mesh(self):
        """Build the pyramid mesh (apex at the origin) used for every camera."""
        verts = np.array([[0,0,0], [-5,-4,10], [5,-4,10], [5,4,10], [-5,4,10]])
        faces = np.array([[0,1,2], [0,2,3], [0,3,4], [0,4,1], [1,2,3], [1,3,4]])
        return gl.MeshData(vertexes=verts, faces=faces)

    @pyqtSlot()
    def update_plot(self):
        """Refresh the markers, camera meshes and status bar from the shared state."""
        # Copy the shared state under the lock, then draw without holding it
        # so the network and processing threads are not blocked by rendering.
        with data_lock:
            points = list(tracked_points_3d.values())
            poses = dict(camera_poses)
        if points:
            self.live_points.setData(pos=np.array(points))
        else:
            self.live_points.setData(pos=np.empty((0,3)))
        now = time.time()
        calibrated_cams = 0
        for host, pose in poses.items():
            mesh = self.camera_meshes[host]
            if now - pose['time'] >= 2.0:
                mesh.setVisible(False)
                continue
            calibrated_cams += 1
            R, _ = cv2.Rodrigues(pose['rvec'])
            transform = np.eye(4); transform[:3, :3] = R; transform[:3, 3] = pose['tvec'].flatten()
            flip_yz = np.array([[1,0,0,0], [0,-1,0,0], [0,0,-1,0], [0,0,0,1]])
            # NOTE(review): rvec/tvec are used exactly as returned by
            # cv2.solvePnP, and the matrix is transposed before being passed
            # to QMatrix4x4's 16-value constructor. Confirm that this gives
            # the intended camera placement in the world frame.
            final_transform = QtGui.QMatrix4x4(*(flip_yz @ transform).T.flatten())
            mesh.setTransform(final_transform)
            mesh.setVisible(True)
        self.statusBar.showMessage(f"Tracking {len(points)} markers | Calibrated Cameras: {calibrated_cams}/{len(config.SERVER_HOSTS)}")

    def closeEvent(self, event):
        """Log that the user closed the window and accept the event."""
        logging.info("Visualization window closed by user.")
        event.accept()

# --- Main Execution Logic ---
def run_visualization():
    """Start the camera clients, the processing worker and the 3D viewer.

    Blocks until the Qt event loop exits (normally when the window is
    closed). The worker thread and all client threads are then stopped, also
    when start-up fails part-way through.
    """
    app = QApplication.instance() or QApplication(sys.argv)

    clients = [CameraClient(host, config.NETWORK_PORT) for host in config.SERVER_HOSTS]
    processing_thread = QThread()
    processing_worker = ProcessingWorker()
    processing_worker.moveToThread(processing_thread)

    def cleanup():
        logging.info("Application cleanup initiated.")
        processing_worker.stop()
        processing_thread.quit(); processing_thread.wait()
        # Signal every client first so they wind down in parallel.
        for client in clients:
            client.stop()
        for client in clients:
            # A client that was never started cannot be joined.
            if client.is_alive():
                client.join()

    # try/finally instead of an aboutToQuit hook: cleanup runs exactly once
    # per call, even if building the window raises, and no slots pile up on a
    # QApplication that is reused when the cell is run again.
    try:
        for client in clients:
            client.start()

        window = VisualizationWindow()

        processing_thread.started.connect(processing_worker.run)
        processing_worker.new_data.connect(window.update_plot)

        processing_thread.start()
        window.show()
        app.exec_()
    finally:
        cleanup()

# Run the visualization (this call blocks until the Qt event loop exits)
run_visualization()