### Keypoint Augmentation

In [8]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
import os
import re
from pathlib import Path
from tqdm import tqdm

def get_keypoint_coords(row):
    """Helper function to reshape coordinates into (N, K, 2) format"""
    coords = []
    for i in range(133):
        x = row[f'x{i}']
        y = row[f'y{i}']
        coords.append([x, y])
    return np.array(coords)

def transform_frame(row, flip=False, angle=0, scale=1.0):
    """
    Apply sequence of transformations to a single frame.
    """
    transformed = row.copy()
    coords = get_keypoint_coords(row)
    
    # Convert bbox to 4 corner points
    bbox_corners = np.array([
        [row['xmin'], row['ymin']],  # top-left
        [row['xmin'], row['ymax']],  # bottom-left
        [row['xmax'], row['ymax']],  # bottom-right
        [row['xmax'], row['ymin']]   # top-right
    ])
    
    # Get center points
    center = np.mean(coords, axis=0)
    bbox_center = np.mean(bbox_corners, axis=0)
    
    # Center coordinates
    centered_coords = coords - center
    centered_bbox = bbox_corners - bbox_center
    
    # 1. Apply flip
    if flip:
        centered_coords[:, 0] = -centered_coords[:, 0]
        centered_bbox[:, 0] = -centered_bbox[:, 0]
    
    # 2. Apply rotation
    if angle != 0:
        theta = np.radians(angle)
        rot_matrix = np.array([[np.cos(theta), -np.sin(theta)],
                             [np.sin(theta), np.cos(theta)]])
        centered_coords = np.dot(centered_coords, rot_matrix.T)
        centered_bbox = np.dot(centered_bbox, rot_matrix.T)
    
    # 3. Apply scaling
    if scale != 1.0:
        centered_coords *= scale
        centered_bbox *= scale
    
    # Move back to original position
    final_coords = centered_coords + center
    final_bbox = centered_bbox + bbox_center
    
    # Update keypoint coordinates
    for i in range(133):
        transformed[f'x{i}'] = final_coords[i, 0]
        transformed[f'y{i}'] = final_coords[i, 1]
    
    # Update bbox coordinates - find min/max from transformed corners
    transformed['xmin'] = np.min(final_bbox[:, 0])
    transformed['ymin'] = np.min(final_bbox[:, 1])
    transformed['xmax'] = np.max(final_bbox[:, 0])
    transformed['ymax'] = np.max(final_bbox[:, 1])
    
    return transformed

def augment_keypoints(df, flip=True, angle=15, scale=.7):
    """
    Augment keypoint data by applying consistent transformations to all frames.
    """
    augmented = df.copy()
    
    # Apply same transformations to each frame
    for idx, row in df.iterrows():
        augmented.iloc[idx] = transform_frame(row, flip=flip, angle=angle, scale=scale)
    
    return augmented

def plot_keypoints(orig_df, aug_df, n_samples=3):
    """
    Plot random frames from original and augmented data side by side.
    """
    # Select random frame indices
    random_frames = random.sample(range(len(orig_df)), n_samples)
    
    # Create subplots
    fig = plt.figure(figsize=(12, 4*n_samples))
    
    for i, frame_idx in enumerate(random_frames):
        # Plot original
        ax = plt.subplot(n_samples, 2, 2*i + 1)
        frame_data = orig_df.iloc[frame_idx]
        
        # Plot keypoints
        x_coords = [frame_data[f'x{k}'] for k in range(133)]
        y_coords = [frame_data[f'y{k}'] for k in range(133)]
        ax.scatter(x_coords, y_coords, c='blue', alpha=0.5, s=10)
        
        # Plot bounding box
        bbox = plt.Rectangle((frame_data['xmin'], frame_data['ymin']),
                           frame_data['xmax'] - frame_data['xmin'],
                           frame_data['ymax'] - frame_data['ymin'],
                           fill=False, color='red')
        ax.add_patch(bbox)
        ax.set_aspect('equal')
        ax.set_title(f'Original Frame {frame_idx}')
        
        # Set consistent limits
        all_x = x_coords + [frame_data['xmin'], frame_data['xmax']]
        all_y = y_coords + [frame_data['ymin'], frame_data['ymax']]
        margin = 0.1 * (max(all_x) - min(all_x))
        xlim = [min(all_x) - margin, max(all_x) + margin]
        ylim = [min(all_y) - margin, max(all_y) + margin]
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        
        # Plot augmented
        ax = plt.subplot(n_samples, 2, 2*i + 2)
        frame_data = aug_df.iloc[frame_idx]
        
        # Plot keypoints
        x_coords = [frame_data[f'x{k}'] for k in range(133)]
        y_coords = [frame_data[f'y{k}'] for k in range(133)]
        ax.scatter(x_coords, y_coords, c='blue', alpha=0.5, s=10)
        
        # Plot bounding box
        bbox = plt.Rectangle((frame_data['xmin'], frame_data['ymin']),
                           frame_data['xmax'] - frame_data['xmin'],
                           frame_data['ymax'] - frame_data['ymin'],
                           fill=False, color='red')
        ax.add_patch(bbox)
        ax.set_aspect('equal')
        ax.set_title(f'Augmented Frame {frame_idx}\nflip + rotate(15°) + scale(0.7)')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
    
    plt.tight_layout()
    plt.show()

def augment_and_save(input_file):
    """
    Read input CSV, perform augmentation, and save the results.
    """
    # Read the CSV file
    df = pd.read_csv(input_file)
    
    # Generate output filename by adding _aug before the extension
    output_file = input_file.rsplit('.', 1)[0] + '_aug.' + input_file.rsplit('.', 1)[1]
    
    # Apply augmentations
    augmented_df = augment_keypoints(df)
    
    # Save augmented data
    augmented_df.to_csv(output_file, index=False)
    
    print(f"Original samples: {len(df)}")
    print(f"Augmented samples: {len(augmented_df)}")
    print(f"Augmented data saved to: {output_file}")
    
    return df, augmented_df

def process_folder(folder_path):
    """
    Process all files in the folder matching the pattern XXXXXXXX_dbx.csv
    Skip already augmented files and any other files
    """
    folder = Path(folder_path)
    pattern = r'^[A-Za-z0-9]{8}_dbx\.csv$'
    
    # Statistics
    processed = 0
    skipped = 0
    errors = 0
    
    # Get list of eligible files first
    csv_files = [f for f in folder.glob('*.csv') 
                 if re.match(pattern, f.name) and 
                 not (f.parent / f"{f.stem}_aug.csv").exists()]
    
    # Use tqdm for the main processing loop
    with tqdm(total=len(csv_files), desc="Processing files", unit="file") as pbar:
        for file_path in csv_files:
            try:
                # Process the file
                original_df = pd.read_csv(str(file_path))
                augmented_df = augment_keypoints(original_df)
                
                # Generate output filename
                output_file = str(file_path.parent / f"{file_path.stem}_aug.csv")
                
                # Save augmented data
                augmented_df.to_csv(output_file, index=False)
                
                # Update statistics
                processed += 1
                
                # Update progress bar description with current file
                pbar.set_postfix_str(f"Current: {file_path.name}")
                pbar.update(1)
                
            except Exception as e:
                errors += 1
                tqdm.write(f"\nError processing {file_path.name}: {str(e)}")
    
    # Print summary after the progress bar is complete
    print("\nProcessing Summary:")
    print(f"Files processed: {processed}")
    print(f"Files skipped: {skipped}")
    print(f"Errors encountered: {errors}")

if __name__ == "__main__":
    # Use current directory as default, or specify your folder path
    folder_path = "data/pose"  
    process_folder(folder_path)

# Try single file
'''if __name__ == "__main__":
    input_file = "data/pose/7963TF00_dbx.csv"  # replace with your input file
    original_df, augmented_df = augment_and_save(input_file)
    plot_keypoints(original_df, augmented_df, n_samples=3)'''

Processing files: 100%|████████████████████████████████████| 73/73 [32:28<00:00, 26.69s/file, Current: 7972OT00_dbx.csv]


Processing Summary:
Files processed: 73
Files skipped: 0
Errors encountered: 0





'if __name__ == "__main__":\n    input_file = "data/pose/7963TF00_dbx.csv"  # replace with your input file\n    original_df, augmented_df = augment_and_save(input_file)\n    plot_keypoints(original_df, augmented_df, n_samples=3)'