In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING']="1"
os.environ['TORCH_USE_CUDA_DSA'] = "1"

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import shutil
import json
import pandas as pd
import cv2
from glob import glob
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict

import time
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.cuda.amp.autocast.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*torch.meshgrid.*")

import sys
sys.path.append('/home/lea/trampo/MODELS_2D3D/mmpose')
sys.path.append('/home/lea/trampo/metrics')

from mmpose.apis import init_model as init_pose_estimator
from mmpose.utils import adapt_mmdet_pipeline

try:
    from mmdet.apis import inference_detector, init_detector
    has_mmdet = True
except (ImportError, ModuleNotFoundError):
    has_mmdet = False

from utils import predict_multiview_with_grad, find_best_triangulation, project_points, find_triangulation, show_keypoints_on_im, resize_and_pad_keep_aspect

  from pkg_resources import DistributionNotFound, get_distribution


### Set models parameters

In [2]:
# Set training parameters
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cuda' # FOR TESTING

LEARNING_RATE = 1e-5
BATCH_SIZE = 16
NUM_WORKERS = 4
N = 8 # number of cameras
K = 17 # number of keypoints

# Init detector
det_config = '/home/lea/trampo/MODELS_2D3D/mmpose/demo/mmdetection_cfg/rtmdet_m_640-8xb32_coco-person.py'
det_checkpoint = 'https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth'
#det_config = "/home/lea/trampo/MODELS_2D3D/rtmpose/RTMPose/rtmdet_nano_320-8xb32_coco-person.py" 
#det_checkpoint = "/home/lea/trampo/MODELS_2D3D/rtmpose/RTMPose/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth"
det_model= init_detector(det_config, det_checkpoint, device=device)
det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)

# Init pose model
pose_config = '/home/lea/trampo/MODELS_2D3D/mmpose/configs/body_2d_keypoint/rtmpose/body8/rtmpose-m_8xb256-420e_body8-256x192.py'
pose_checkpoint = 'https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-body7_pt-body7_420e-256x192-e48f03d0_20230504.pth'
#pose_config = "/home/lea/trampo/MODELS_2D3D/rtmpose/RTMPose/rtmpose-l_8xb256-420e_coco-256x192.py"  
#pose_checkpoint = "/home/lea/trampo/MODELS_2D3D/rtmpose/RTMPose/rtmpose-l_simcc-aic-coco_pt-aic-coco_420e-256x192-f016ffe0_20230126.pth"
pose_model = init_pose_estimator(pose_config, pose_checkpoint, device=device)
pose_model.train()

# Unfreeze all parameters
for param in pose_model.parameters():
    param.requires_grad = True
    
optimizer = torch.optim.Adam(pose_model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

  _bootstrap._exec(spec, module)


Loads checkpoint by http backend from path: https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth
Loads checkpoint by http backend from path: https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-body7_pt-body7_420e-256x192-e48f03d0_20230504.pth


### Prepare data

In [3]:
root_dir = '/mnt/D494C4CF94C4B4F0/Trampoline_avril2025/Images_trampo_avril2025/20250429'
data_dir = '/home/lea/trampo/MODELS_2D3D/finetuning_multiview/dataset'

sequences = set([str(f).split('-')[0] for f in os.listdir(root_dir)])

sequences = sorted([seq for seq in sequences if seq[0] in ['1', '2']])

cameras = ['Camera1_M11139', 'Camera2_M11140', 'Camera3_M11141', 'Camera4_M11458',
           'Camera5_M11459', 'Camera6_M11461', 'Camera7_M11462', 'Camera8_M11463']

K = np.load('calib/K.npz')['arr_0']
Ks = torch.tensor(K, dtype=torch.float32, device=device)

D = np.load('calib/D.npz')['arr_0']

""" for seq in sequences:
    if not os.path.isdir(os.path.join(data_dir, seq)):
        os.makedirs(os.path.join(data_dir, seq))
    for cam in cameras:
        if not os.path.isdir(os.path.join(data_dir, seq, cam)):
            dest_dir = os.path.join(data_dir, seq, cam)
            source_dir = os.path.join(root_dir, seq+'-'+cam)
            shutil.copytree(source_dir, dest_dir) """

" for seq in sequences:\n    if not os.path.isdir(os.path.join(data_dir, seq)):\n        os.makedirs(os.path.join(data_dir, seq))\n    for cam in cameras:\n        if not os.path.isdir(os.path.join(data_dir, seq, cam)):\n            dest_dir = os.path.join(data_dir, seq, cam)\n            source_dir = os.path.join(root_dir, seq+'-'+cam)\n            shutil.copytree(source_dir, dest_dir) "

In [4]:
# inventaire athlète/frame/caméra/seq
error_thresh = 500
person_dist_thresh = 500  # distance threshold to discard mismatched persons

dataset_path = '/home/lea/trampo/MODELS_2D3D/finetuning_multiview/dataset'
detections = []

""" for seq in os.listdir(dataset_path):
    print(seq)
    person_id = int(str(seq).split('_')[0])

    session = str(seq).split('-')[0].split('_')[2]
    calib_path = os.path.join('calib', f'WorldTCam_{session}.npz')
    world_T_cam = np.load(calib_path)['arr_0']
    projMat = np.stack([np.linalg.inv(mat) for mat in world_T_cam])
    Ts = torch.tensor(projMat, dtype=torch.float32, device=device).unsqueeze(0)
    
    for frame in tqdm(sorted(list(os.listdir(os.path.join(dataset_path, seq, 'Camera1_M11139'))))):
        frame_nb = int(str(frame).split('_')[1].split('.')[0])

        images = []
        for cam in os.listdir(os.path.join(dataset_path, seq)):
            cam_idx = int(str(cam)[6]) - 1
            img_path = os.path.join(dataset_path, seq, cam, frame)
            img = cv2.imread(img_path)
            img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
            images.append(img)
        images = torch.stack(images).unsqueeze(0)

        with torch.no_grad():
            keypoints = predict_multiview_with_grad(det_model, pose_model, images, bbox_thr=0.3, pose_batch_size=8, training=False)
            error, preds_2d, points_3d = find_triangulation(keypoints, Ks, Ts, error_thresh)
        
        if torch.isnan(points_3d).all():
            continue

        # Reproject and add cameras to cams_on if error is below error_thresh
        Rt = Ts[:, :, :3, :]
        P_all = Ks @ Rt
        reproj, valid_mask = project_points(points_3d, P_all)

        dist = torch.norm(preds_2d - reproj, dim=-1).view(8, 17)  # Euclidean distance
        keep_mask = (dist < person_dist_thresh).all(dim=1)
        
        if keep_mask.sum() == 0:
            continue
        cams_on_all = keep_mask
        
        for cam_i, cam_on in enumerate(cams_on_all.squeeze()):
            if cam_on:
                #print((seq, cam_i, frame_nb, person_id))
                detections.append((seq, cam_i, frame_nb, person_id))

df = pd.DataFrame(detections)
df.columns = ["seq", "cam", "frame", "person"]
df.to_pickle("detections.pkl")

print(df) """

' for seq in os.listdir(dataset_path):\n    print(seq)\n    person_id = int(str(seq).split(\'_\')[0])\n\n    session = str(seq).split(\'-\')[0].split(\'_\')[2]\n    calib_path = os.path.join(\'calib\', f\'WorldTCam_{session}.npz\')\n    world_T_cam = np.load(calib_path)[\'arr_0\']\n    projMat = np.stack([np.linalg.inv(mat) for mat in world_T_cam])\n    Ts = torch.tensor(projMat, dtype=torch.float32, device=device).unsqueeze(0)\n    \n    for frame in tqdm(sorted(list(os.listdir(os.path.join(dataset_path, seq, \'Camera1_M11139\'))))):\n        frame_nb = int(str(frame).split(\'_\')[1].split(\'.\')[0])\n\n        images = []\n        for cam in os.listdir(os.path.join(dataset_path, seq)):\n            cam_idx = int(str(cam)[6]) - 1\n            img_path = os.path.join(dataset_path, seq, cam, frame)\n            img = cv2.imread(img_path)\n            img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0\n            images.append(img)\n        images = torch.stack(images).unsquee

In [5]:
def create_cropped_dataset(detector, root_dir, bbox_thr, save_root="datasets/cropped_multiview", model_input_size=(256, 192)):
    os.makedirs(save_root, exist_ok=True)
    detector.eval()

    for seq in tqdm(sorted(os.listdir(root_dir)), desc="Generating cropped dataset"):
        seq_path = os.path.join(save_root, seq)
        os.makedirs(seq_path, exist_ok=True)

        # Skip seq if already processed
        if os.path.isfile(os.path.join(seq_path, f"{seq}_meta.pt")):
            continue

        meta_seq = []

        for cam_dir in sorted(os.listdir(os.path.join(root_dir, seq))):
            cam_path = os.path.join(seq_path, cam_dir)
            os.makedirs(cam_path, exist_ok=True)

            for frame in sorted(os.listdir(os.path.join(root_dir, seq, cam_dir))):
                img = cv2.imread(os.path.join(root_dir, seq, cam_dir, frame))

                det_result = inference_detector(detector, img)
                pred = det_result.pred_instances.cpu().numpy()
                bboxes = pred.bboxes[np.logical_and(pred.labels == 0, pred.scores > bbox_thr)]

                for pid, bbox in enumerate(bboxes):
                    x1, y1, x2, y2 = map(int, bbox)
                    if x2 <= x1 or y2 <= y1:
                        continue

                    crop = img[y1:y2, x1:x2]
                    crop_resized, scale, pads = resize_and_pad_keep_aspect(crop, model_input_size)

                    crop_path = os.path.join(cam_path, frame.split('.')[0] + f"_person{pid:02d}.jpg")
                    cv2.imwrite(crop_path, crop_resized)

                    meta_seq.append({
                        "camera": cam_dir,
                        "frame": frame.split('.')[0],
                        "person": pid,
                        "origin": [x1, y1],
                        "scale": scale,
                        "pads": pads,
                        "crop_path": os.path.relpath(crop_path, save_root)
                    })
        
        torch.save(meta_seq, os.path.join(seq_path, f"{seq}_meta.pt"))

# Create cropped dataset
save_path = '/mnt/D494C4CF94C4B4F0/Trampoline_avril2025/dataset_finetune2d/cropped'
#create_cropped_dataset(det_model, root_dir, bbox_thr=0.3, save_root=save_path, model_input_size=(256, 192))

### Detections stats on dataset

In [6]:
""" # Load your detections
df = pd.read_pickle("detections.pkl")

print("Total detections:", len(df))
print("Unique sequences:", df["seq"].nunique())

# --- Base stats per sequence ---
stats = df.groupby("seq").agg(
    n_frames=("frame", "nunique"),
    n_cameras=("cam", "nunique"),
    n_persons=("person", "nunique"),
    total_detections=("frame", "count"),
).reset_index()

# --- Camera coverage per sequence ---
# Count detections per seq, cam, frame
coverage = df.groupby(["seq", "cam", "frame"]).size().reset_index(name="count")

# For each (seq, cam), count frames with detections
coverage = coverage.groupby(["seq", "cam"]).agg(frames_with_detections=("frame", "nunique")).reset_index()

# Merge with total frames per sequence to get % coverage
coverage = coverage.merge(stats[["seq", "n_frames"]], on="seq", how="left")
coverage["coverage_ratio"] = 100 * coverage["frames_with_detections"] / coverage["n_frames"]

# Pivot for readability (each camera in its own column)
coverage_pivot = coverage.pivot(index="seq", columns="cam", values="coverage_ratio").fillna(0)
coverage_pivot.columns = [f"cam{int(c)+1}_coverage(%)" for c in coverage_pivot.columns]

# Merge coverage with global stats
stats_full = stats.merge(coverage_pivot, on="seq", how="left")

# --- Results ---
print("\n=== Detection Stats per Sequence ===")
print(stats_full.round(2))

# Save if useful
stats_full.to_csv("detections_stats_with_coverage.csv", index=False)
 """

' # Load your detections\ndf = pd.read_pickle("detections.pkl")\n\nprint("Total detections:", len(df))\nprint("Unique sequences:", df["seq"].nunique())\n\n# --- Base stats per sequence ---\nstats = df.groupby("seq").agg(\n    n_frames=("frame", "nunique"),\n    n_cameras=("cam", "nunique"),\n    n_persons=("person", "nunique"),\n    total_detections=("frame", "count"),\n).reset_index()\n\n# --- Camera coverage per sequence ---\n# Count detections per seq, cam, frame\ncoverage = df.groupby(["seq", "cam", "frame"]).size().reset_index(name="count")\n\n# For each (seq, cam), count frames with detections\ncoverage = coverage.groupby(["seq", "cam"]).agg(frames_with_detections=("frame", "nunique")).reset_index()\n\n# Merge with total frames per sequence to get % coverage\ncoverage = coverage.merge(stats[["seq", "n_frames"]], on="seq", how="left")\ncoverage["coverage_ratio"] = 100 * coverage["frames_with_detections"] / coverage["n_frames"]\n\n# Pivot for readability (each camera in its own col

### Create dataset + dataloader

In [7]:
class CroppedMultiViewDataset(Dataset):
    def __init__(self, cropped_dir, raw_dir, K, downsample=1):
        self.cropped_dir = cropped_dir
        self.raw_dir = raw_dir
        self.downsample = max(1, downsample)  # avoid division by zero

        sequences = sorted([d for d in os.listdir(cropped_dir) if os.path.isdir(os.path.join(cropped_dir, d))])
        self.sequences = sorted(sequences)
        
        self.sequence_data = []  # (seq_name, frame_names, calibration)
        self.index_map = []      # global index -> (seq_idx, frame_idx)

        for seq_idx, seq_name in enumerate(self.sequences):
            seq_path = os.path.join(cropped_dir, seq_name)

            # Calibration
            session = seq_name.split('-')[0].split('_')[2]
            calib_path = os.path.join('calib', f'WorldTCam_{session}.npz')

            world_T_cam = np.load(calib_path)['arr_0']
            projMat = np.stack([np.linalg.inv(mat) for mat in world_T_cam])
            Ts = torch.tensor(projMat, dtype=torch.float32)
            Ks = torch.tensor(K, dtype=torch.float32)

            # Cameras
            cam_dirs = sorted([d for d in os.listdir(seq_path) if d.startswith("Cam")])
            cam_dirs = [os.path.join(seq_path, d) for d in cam_dirs]

            # Collect all unique frame base names across all cameras
            frame_names = sorted({
                (f.split('_p')[0], f.split('_')[2])  # (frame, person_id)
                for cam_dir in cam_dirs
                for f in os.listdir(cam_dir)
                if f.endswith('.jpg')
            })
            frame_names = frame_names[::self.downsample]  # downsample by taking every Nth frame

            frame_to_persons = defaultdict(list)
            for cam_dir in cam_dirs:
                for f in os.listdir(cam_dir):
                    if f.endswith('.jpg'):
                        frame_name = f.split('_p')[0]
                        frame_nb = int(f.split('_')[1])  # get 0001 from 'frame_0001_p01.jpg'
                        if frame_nb % downsample == 0:
                            person_id = f.split('_')[2]  # get 1 from 'p01.jpg'
                            if person_id not in frame_to_persons[frame_name]:
                                frame_to_persons[frame_name].append(person_id)
            
            frame_names = sorted(frame_to_persons.keys())
            frame_to_persons = {k: frame_to_persons[k] for k in frame_names}

            self.sequence_data.append({
                "name": seq_name,
                "cam_dirs": cam_dirs,
                "frame_names": frame_to_persons,
                "Ks": Ks,
                "Ts": Ts
            })

            for frame_idx in frame_names:
                self.index_map.append((seq_idx, frame_idx))

        self.num_views = len(self.sequence_data[0]["cam_dirs"])

        self.meta_files = sorted([
            os.path.join(cropped_dir, seq_dir, f)
            for seq_dir in sequences
            for f in os.listdir(os.path.join(cropped_dir, seq_dir))
            if f.endswith('_meta.pt')
        ])

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        seq_idx, frame_name = self.index_map[idx]
        seq_info = self.sequence_data[seq_idx]

        Ks, Ts = seq_info["Ks"], seq_info["Ts"]
        person_ids = seq_info["frame_names"][frame_name]
        frame = frame_name
        seq_name = seq_info["name"]

        N_max = len(person_ids)
        cam_dirs = seq_info["cam_dirs"]
        V = len(cam_dirs)
        crop_images = torch.full(
            (V, N_max, 3, 256, 192),
            float("nan"),
            dtype=torch.float32
        )
        images = torch.full(
            (V, 3, 1080, 1920),
            float("nan"),
            dtype=torch.float32
        )

        metas_all = [[None for _ in range(N_max)] for _ in range(V)]

        for v, cam_dir in enumerate(seq_info["cam_dirs"]):
            valid_i = 0

            img_path = os.path.join(cam_dir.replace('cropped', 'raw'), frame+'.png')
            img = cv2.imread(img_path)
            img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
            images[v] = img
            
            for person_id in person_ids:
                pid = int(person_id.split('person')[-1].split('.')[0])
                crop_path = os.path.join(cam_dir, frame+'_'+person_id)

                if not os.path.isfile(crop_path):
                    continue

                crop = cv2.imread(crop_path)
                crop = torch.from_numpy(crop).permute(2, 0, 1).float() / 255.0
                crop_images[v, valid_i] = crop

                frame_idx = int(str(frame).split('_')[1])

                meta_path = self.meta_files[seq_idx]
                meta_data = torch.load(meta_path)
                meta_frame = next(
                    (m for m in meta_data if m['camera'] == os.path.basename(cam_dir) and m['frame'] == f'{frame_name}.png' and m['person'] == pid),
                    None
                )
                metas_all[v][valid_i] = meta_frame
                valid_i += 1

        return {
            "images": images,  # (V,C,H,W)
            "crops": crop_images,  # (V,C,H,W)
            "metas": metas_all,
            "Ks": Ks,
            "Ts": Ts,
            "seq_name": seq_name,
            "frame_idx": frame_idx,
        }

cropped_dir = '/mnt/D494C4CF94C4B4F0/Trampoline_avril2025/dataset_finetune2d/cropped'
raw_dir = '/mnt/D494C4CF94C4B4F0/Trampoline_avril2025/dataset_finetune2d/raw'
dataset = CroppedMultiViewDataset(cropped_dir, raw_dir, K=K, downsample=5)

print('len =', len(dataset))  # total number of frames across all sequences
sample = dataset[1]
print(sample["crops"].shape)  # (V,C,H,W)
print(sample['metas'])


len = 1094
torch.Size([8, 3, 3, 256, 192])
[[{'camera': 'Camera1_M11139', 'frame': 'frame_00005.png', 'person': 0, 'origin': [114, 0], 'scale': 0.5680473372781065, 'pads': (0, 85), 'crop_path': '1_partie_0429_000/Camera1_M11139/frame_00005_person00.jpg'}, None, None], [{'camera': 'Camera2_M11140', 'frame': 'frame_00005.png', 'person': 0, 'origin': [501, 622], 'scale': 1.136094674556213, 'pads': (0, 87), 'crop_path': '1_partie_0429_000/Camera2_M11140/frame_00005_person00.jpg'}, {'camera': 'Camera2_M11140', 'frame': 'frame_00005.png', 'person': 2, 'origin': [523, 324], 'scale': 1.4545454545454546, 'pads': (0, 88), 'crop_path': '1_partie_0429_000/Camera2_M11140/frame_00005_person02.jpg'}, {'camera': 'Camera2_M11140', 'frame': 'frame_00005.png', 'person': 1, 'origin': [484, 367], 'scale': 2.3132530120481927, 'pads': (0, 10), 'crop_path': '1_partie_0429_000/Camera2_M11140/frame_00005_person01.jpg'}], [{'camera': 'Camera3_M11141', 'frame': 'frame_00005.png', 'person': 0, 'origin': [0, 371], 

In [8]:
def collate_padded(batch):
    """
    Collate function for multiview crops with variable N (persons per frame).
    Pads missing detections with NaN.
    Vectorized for images and crops.
    """
    # === Extract items ===
    images_list = [item["images"] for item in batch]  # (V, C, H, W)
    crops_list  = [item["crops"]  for item in batch]  # (V, N_i, C, H, W)
    Ks_list     = [item["Ks"]     for item in batch]
    Ts_list     = [item["Ts"]     for item in batch]
    metas_list  = [item["metas"]  for item in batch]
    seq_names   = [item["seq_name"] for item in batch]
    frame_idxs  = [item["frame_idx"] for item in batch]

    B = len(batch)
    V = crops_list[0].shape[0]

    # === Pad original images (vectorized) ===
    padded_orig_images = torch.stack(images_list, dim=0)  # (B, V, C, H, W)

    # === Pad crops (efficient, minimal loops) ===
    N_max = max(c.shape[1] for c in crops_list)
    C, H, W = crops_list[0].shape[2:]
    padded_crops = torch.full((B, V, N_max, C, H, W), float("nan"), dtype=torch.float32)
    for i, c in enumerate(crops_list):
        padded_crops[i, :, :c.shape[1]] = c

    # === Pad metas ===
    padded_metas = [
        [[*(metas_b[v]), *([None] * (N_max - len(metas_b[v])))] for v in range(V)]
        for metas_b in metas_list
    ]

    # === Stack Ks and Ts ===
    Ks = torch.stack(Ks_list)
    Ts = torch.stack(Ts_list)

    return {
        "images": padded_orig_images,  # (B, V, C, H, W)
        "crops": padded_crops,         # (B, V, N, C, H, W)
        "metas": padded_metas,
        "Ks": Ks,
        "Ts": Ts,
        "seq_name": seq_names,
        "frame_idx": frame_idxs
    }

# Dataloader
train_size = int(0.8 * len(dataset))
val_size   = int(0.1 * len(dataset))
test_size  = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)  # reproducible split
)

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, persistent_workers=True, collate_fn=collate_padded)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, persistent_workers=True, collate_fn=collate_padded)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_padded)

""" for batch in train_loader:
    for el in range(BATCH_SIZE):
        print('\n--- Element', el, '---')
        print(batch["crops"].shape)  # (B, V, C, H, W)
        print(batch["seq_name"][el])
        print(batch["frame_idx"][el])
        #im = batch["crops"][el]
        #print(im.shape)
        #plt.imshow(im[0][0].permute(1, 2, 0).numpy())
        #plt.show()

    break """

' for batch in train_loader:\n    for el in range(BATCH_SIZE):\n        print(\'\n--- Element\', el, \'---\')\n        print(batch["crops"].shape)  # (B, V, C, H, W)\n        print(batch["seq_name"][el])\n        print(batch["frame_idx"][el])\n        #im = batch["crops"][el]\n        #print(im.shape)\n        #plt.imshow(im[0][0].permute(1, 2, 0).numpy())\n        #plt.show()\n\n    break '

### Bone length loss

In [9]:
seg_lengths = [(1, 290, 230, 440, 330),
               (2, 285, 260, 520, 360),
               (3, 310, 230, 520, 370),
               (4, 285, 250, 520, 350),
               (5, 340, 270, 540, 420),
               (6, 390, 270, 530, 370),
               (7, 300, 240, 450, 350)]

map_seg_to_keypoints = [(2, 3), (3, 4), (9, 10), (10, 11), (5, 6), (6, 7), (12, 13), (13, 14)]

df = pd.DataFrame(seg_lengths)
df.columns = ["id", "shoulder-elbow", "elbow-wrist", "hip-knee", "knee-ankle"]
df.to_pickle("seg_lengths.pkl")
print(df)

def get_loss_bones(seg_lengths, keypoints, map):
    loss = 0
    for side in range(2):
        for i, l in enumerate(seg_lengths):
            pair = map[side][i]
            loss += abs(l - torch.linalg.norm(keypoints[pair[0]] - keypoints[pair[1]]))
    return loss/8

def bone_length_loss(keypoints, seg_lengths, map_seg_to_keypoints, device=None):
    """
    keypoints: Tensor [B, N, D]  (D=2 or 3)
    seg_lengths: list of per-bone target lengths (float)
    map_seg_to_keypoints: list of tuples [(i,j), (k,l), ...]
    """
    # Convert to tensors
    map_seg = torch.tensor(map_seg_to_keypoints, device=device)
    target_lengths = seg_lengths #, device=keypoints.device, dtype=torch.float32)

    # Extract coordinates for bone endpoints
    kp1 = keypoints[:, map_seg[:, 0], :]  # [B, n_bones, D]
    kp2 = keypoints[:, map_seg[:, 1], :]  # [B, n_bones, D]

    # Compute bone lengths per sample
    bone_lengths = torch.linalg.norm(kp1 - kp2, dim=-1)  # [B, n_bones]

    # Compute loss (L2 difference)
    loss = torch.nanmean(torch.abs(bone_lengths - target_lengths))
    return loss


   id  shoulder-elbow  elbow-wrist  hip-knee  knee-ankle
0   1             290          230       440         330
1   2             285          260       520         360
2   3             310          230       520         370
3   4             285          250       520         350
4   5             340          270       540         420
5   6             390          270       530         370
6   7             300          240       450         350


### GT from Pose2Sim 8 cam triangulation ViT

In [10]:
from data_analysis import extract_coordinates

vit = ['Nose', 'REye', 'LEye', 'REar', 'LEar', 'RShoulder', 'LShoulder', 'RElbow', 'LElbow', 'RWrist', 'LWrist',
       'RHip', 'LHip', 'RKnee', 'LKnee', 'RAnkle', 'LAnkle']
pose2sim_vit = ['RHip', 'RKnee', 'RAnkle', 'LHip', 'LKnee', 'LAnkle', 'Nose',
                'RShoulder', 'RElbow', 'RWrist', 'LShoulder', 'LElbow', 'LWrist']
common_indices_vit = [j for j in vit if j in pose2sim_vit]
matching_pose2sim_vit = [pose2sim_vit.index(j) for j in common_indices_vit]
matching_vit = [vit.index(j) for j in common_indices_vit]

def get_gt_coords(seqs, frames, rotate=True):
   # define Rotation for Pose2Sim to World
   R = np.array([
      [0, 0, 1],
      [-1, 0, 0],
      [0, 1, 0]
   ])

   gt_coords_all = []

   for seq, frame in zip(seqs, frames):
      path = os.path.join(
         '/home/lea/trampo/MODELS_2D3D/Pose2Sim/pose-3d-vit-multi',seq, 'GT_8cam', f'{seq}.trc')
      gt_coords, gt_frames = extract_coordinates(path)

      frame_val = frame.item() if torch.is_tensor(frame) else frame
      mask = gt_frames == frame_val

      if not np.any(mask):
         # frame not found → create a NaN array of same shape as expected
         n_joints = 13
         nan_coords = np.full((n_joints, 3), np.nan, dtype=np.float32)
         gt_coords_all.append(nan_coords)
         continue

      # frame found
      gt_coords = gt_coords[mask, matching_pose2sim_vit, :]

      if rotate:
         gt_coords = (R @ gt_coords.T).T

      gt_coords_all.append(gt_coords)

   return gt_coords_all


### Train + validate

In [11]:
run_name = 'GTViT_03'
# Create a log directory (TensorBoard will read from this)
writer = SummaryWriter(log_dir=f"runs/{run_name}")
print("TensorBoard log dir:", writer.log_dir)

error_thresh = 500
person_dist_thresh = 500
torch.set_printoptions(precision=4, sci_mode=False)

os.makedirs(f"checkpoints/{run_name}", exist_ok=True)
os.makedirs(f"viz/{run_name}", exist_ok=True)

num_epochs = 5
global_step = 0
lambda_bones = 0.1
best_val_loss = float('inf')

seg_lengths = pd.read_pickle("seg_lengths.pkl")

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    pose_model.train()

    for step, batch in enumerate(train_loader):
        start_total = time.perf_counter()
        optimizer.zero_grad(set_to_none=True)

        # --- 0. Data transfer ---
        t0 = time.perf_counter()
        images, crops, metas, Ks, Ts, seq, frames = batch.values()
        images = images.to(device, non_blocking=True)
        crops = crops.to(device, non_blocking=True)
        Ks = Ks.to(device, non_blocking=True)
        Ts = Ts.to(device, non_blocking=True)
        id = 1
        seg_l_id = seg_lengths.loc[seg_lengths['id'] == id].iloc[0, 1:].to_numpy()
        seg_l_tensor = torch.tensor(seg_l_id, dtype=torch.float32, device=device)
        seg_l_tensor = torch.cat([seg_l_tensor, seg_l_tensor])

        # --- 1. Get GT from Pose2Sim ---
        gt_coords = get_gt_coords(seq, frames, rotate=True)
        points_3d = torch.tensor(np.array(gt_coords), dtype=torch.float32, device=device)
        if torch.isnan(points_3d).all():
            del gt_coords, points_3d
            torch.cuda.empty_cache()
            continue

        # --- 2. Predict 2D keypoints ---
        keypoints = predict_multiview_with_grad(pose_model, crops, metas, device=device, training=True)
        
        # --- 3. Triangulate ---
        error, preds_2d, _, _ = find_best_triangulation(keypoints, Ks, Ts, error_thresh)
        if torch.isnan(preds_2d).all():
            del crops, Ks, Ts, keypoints, points_3d, preds_2d
            torch.cuda.empty_cache()
            continue
        
        # --- 4. Reprojection ---
        Rt = Ts[:, :, :3, :]
        P_all = Ks @ Rt
        reproj, valid_mask = project_points(points_3d, P_all)

        # --- 5. Filtering ---
        preds_valid = preds_2d[..., matching_vit, :][valid_mask]
        reproj_valid = reproj[valid_mask]
        dist = torch.norm(preds_valid - reproj_valid, dim=-1)

        keep_mask = dist < person_dist_thresh
        if keep_mask.sum() == 0:
            del images, Ks, Ts, detections, keypoints, preds_2d, points_3d, reproj
            torch.cuda.empty_cache()
            continue
        preds_valid = preds_valid[keep_mask]
        reproj_valid = reproj_valid[keep_mask]
        #print('pred:', preds_valid[0:5])
        #print('reproj:', reproj_valid[0:5])

        # --- 5. Losses ---
        loss = torch.sqrt(torch.nn.functional.mse_loss(preds_valid, reproj_valid))
        #loss_bones = bone_length_loss(points_3d, seg_l_tensor, map_seg_to_keypoints)
        #loss = loss_reproj + lambda_bones * loss_bones

        # --- 6. Backprop ---
        loss.backward()
        optimizer.step()

        if global_step % 1 == 0:
            #writer.add_scalar("Train/Loss_reproj", loss_reproj.item(), global_step)
            #writer.add_scalar("Train/Loss_bones", loss_bones.item(), global_step)
            writer.add_scalar("Train/Loss", loss.item(), global_step)
            writer.add_scalar("Train/Triangulation_error", torch.nanmean(error).item(), global_step)
        global_step += 1

        if step % 20 == 0:
            images_np = images.cpu().detach().numpy()[0]
            show_keypoints_on_im(images_np, preds_2d, reproj, f'viz/{run_name}/epoch{epoch+1}_it{step}', show=False)

        # Memory cleanup after step
        del Ks, Ts, keypoints, preds_2d, reproj, points_3d, preds_valid, reproj_valid, dist, keep_mask, loss
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

        # --- Timing summary ---
        """ total_time = t8 - start_total
        print(
            f"[Step {step}] "
            f"Data: {t1 - t0:.3f}s | "
            f"Predict: {t3 - t2:.3f}s | "
            f"Triangulate: {t4 - t3:.3f}s | "
            f"Reproject: {t5 - t4:.3f}s | "
            f"Filter: {t6 - t5:.3f}s | "
            f"Loss: {t7 - t6:.3f}s | "
            f"Backprop: {t8 - t7:.3f}s | "
            f"Total: {total_time:.3f}s"
        ) """

     # --- VALIDATION ---
    pose_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_batch in val_loader:
            images, crops, Ks, Ts, seq, frames, detections = val_batch.values()
            images = images.to(device, non_blocking=True)
            crops = crops.to(device, non_blocking=True)
            Ks = Ks.to(device, non_blocking=True)
            Ts = Ts.to(device, non_blocking=True)
            detections = detections.to(device, non_blocking=True)
            
            # --- 1. Predict 2D keypoints for detected views ---
            keypoints = predict_multiview_with_grad(pose_model, crops, metas, device=device, training=True)
            
            # --- 2. Triangulate (batched) ---
            error, preds_2d, points_3d, cams_on = find_best_triangulation(keypoints, Ks, Ts, error_thresh)
            if torch.isnan(points_3d).all():
                del images, Ks, Ts, detections, keypoints, points_3d
                torch.cuda.empty_cache()
                continue

            # --- 3. Reproject 3D back into each view ---
            Rt = Ts[:, :, :3, :]
            P_all = Ks @ Rt
            reproj, valid_mask = project_points(points_3d, P_all)
            preds_valid = preds_2d[valid_mask]
            reproj_valid = reproj[valid_mask]

            # --- 4. Remove mismatched persons (dist > 100) ---
            dist = torch.norm(preds_valid - reproj_valid, dim=-1)  # Euclidean distance
            keep_mask = dist < person_dist_thresh
            if keep_mask.sum() == 0:
                del images, Ks, Ts, detections, keypoints, preds_2d, reproj, points_3d
                torch.cuda.empty_cache()
                continue
            preds_valid = preds_valid[keep_mask]
            reproj_valid = reproj_valid[keep_mask]

            val_loss_batch = torch.sqrt(torch.nn.functional.mse_loss(preds_valid, reproj_valid))
            val_loss += val_loss_batch.item()

            # Memory cleanup after batch
            del images, Ks, Ts, detections, keypoints, preds_2d, reproj, points_3d, preds_valid, reproj_valid, dist, keep_mask, val_loss_batch
            torch.cuda.empty_cache()

    val_loss /= len(val_loader)
    writer.add_scalar("Val/Loss", val_loss, epoch)

    # --- Save checkpoint (optional: only best) ---
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': pose_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, f"checkpoints/{run_name}/best_model.pth")

    # Save checkpoint every epoch
    ckpt_path = f"checkpoints/{run_name}/epoch_{epoch+1}.pth"
    torch.save({ 
        'epoch': epoch + 1,
        'model_state_dict': pose_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': val_loss,
    }, ckpt_path)
    print(f"✅ Saved checkpoint: {ckpt_path}")

    # Memory cleanup after epoch
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

writer.close()

# TODO IDEA : per-camera loss for view-specific pose estimator

TensorBoard log dir: runs/GTViT_03

Epoch 1/5


KeyboardInterrupt: 

### Test

In [None]:
checkpoint = torch.load("checkpoints/GTViT_02/best_model.pth", map_location=device)
pose_model.load_state_dict(checkpoint['model_state_dict'])
pose_model.eval()

test_loss = 0.0
with torch.no_grad():
    for batch in test_loader:
        images, Ks, Ts, seq, frames, detections = batch.values()
        images = images.to(device, non_blocking=True)
        Ks = Ks.to(device, non_blocking=True)
        Ts = Ts.to(device, non_blocking=True)
        detections = detections.to(device, non_blocking=True)

        # --- 1. Predict 2D keypoints for detected views ---
        keypoints = predict_multiview_with_grad(
            det_model, pose_model, images, device=device,
            bbox_thr=0.3, training=True, detections=detections)
        
        # --- 2. Triangulate (batched) ---
        error, preds_2d, points_3d, cams_on = find_best_triangulation(keypoints, Ks, Ts, error_thresh)
        if torch.isnan(points_3d).all():
            continue

        # --- 3. Reproject 3D back into each view ---
        Rt = Ts[:, :, :3, :]
        P_all = Ks @ Rt
        reproj, valid_mask = project_points(points_3d, P_all)
        preds_valid = preds_2d[valid_mask]
        reproj_valid = reproj[valid_mask]

        # --- 4. Remove mismatched persons (dist > 100) ---
        dist = torch.norm(preds_valid - reproj_valid, dim=-1)  # Euclidean distance
        keep_mask = dist < person_dist_thresh
        if keep_mask.sum() == 0:
            continue
        preds_valid = preds_valid[keep_mask]
        reproj_valid = reproj_valid[keep_mask]

        loss_batch = torch.sqrt(torch.nn.functional.mse_loss(preds_valid, reproj_valid))
        test_loss += loss_batch.item()

test_loss /= len(test_loader)
print(f"✅ Final test RMSE: {test_loss:.4f}")


✅ Final test RMSE: 47.7418


### Show keypoints on images

In [None]:
b = 0

images_np = images.cpu().detach().numpy()[b]
pts_d = preds_2d.cpu().detach().numpy()[b]
pts_r = reproj.cpu().detach().numpy()[b]

print(images_np.shape, pts_d.shape, pts_r.shape)

show_keypoints_on_im(images_np, pts_d, pts_r, 'viz/epoch1', show=True)

(8, 3, 1080, 1920) (8, 17, 2) (8, 17, 2)


AttributeError: 'numpy.ndarray' object has no attribute 'cpu'

## Test detectors

In [None]:
import cv2
import matplotlib.patches as patches

def plot_bbox_on_image(image, bboxes):
    fig, ax = plt.subplots()
    ax.imshow(image)  # show image
    for bbox in bboxes:
        x, y = bbox[0], bbox[1]
        width = bbox[2] - bbox[0]
        height = bbox[3] - bbox[1]

        # Create a Rectangle patch
        rect = patches.Rectangle(
            (x, y), width, height,
            linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
    plt.show()

In [None]:
model = YOLO("yolo11l.pt")

path = '/home/lea/trampo/MODELS_2D3D/finetuning_multiview/dataset/1_partie_0429_003/Camera2_M11140'
im_files = os.listdir(path)

det = 0
for i, im_file in tqdm(enumerate(im_files)):
    im_name = os.path.join(path, im_file)
    results = model(im_name, classes=[0], conf=0.3, verbose=False)
    bboxes = results[0].boxes.xyxy
    det += len(bboxes)

    if i % 100 == 0:
        plot_bbox_on_image(cv2.imread(im_name), bboxes.cpu().numpy())
    """ 
    if len(bboxes) > 0:
        print(bboxes)

    for result in results:  # peut contenir plusieurs images
        boxes = result.boxes  # toutes les bbox
        for box in boxes:
            cls = int(box.cls[0])          # classe prédite (0 = person dans COCO)
            conf = float(box.conf[0])      # score de confiance
            xyxy = box.xyxy[0].tolist()    # coordonnées [x1, y1, x2, y2] """  

print(det)

## Versions

In [None]:
import mmdet, mmpose, mmengine, mmcv

print("Torch CUDA:", torch.version.cuda)
print("Torch device:", torch.cuda.get_device_name(0))
x = torch.ones(1, device='cuda')
print("CUDA test OK:", x.device)

print("MMCV:", mmcv.__version__)
print("MMDetection:", mmdet.__version__)
print("MMPose:", mmpose.__version__)
print("MMEngine:", mmengine.__version__)

#print(type(img_np), img_np.shape if hasattr(img_np, "shape") else None)
print(det_model.__class__)
print(det_model.cfg.test_dataloader.dataset.pipeline)