## Dependencies

In [1]:
import os
os.chdir('../')
%pwd

'c:\\Projects\\python\\echoframe'

In [2]:
import os
import shutil
import pandas as pd
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset, DataLoader

# ----------------------------------------
# 1) Class to split data into train/val/test
# ----------------------------------------

class DatasetSplitter:
    """
    Copy raw videos into <target_dir>/train, <target_dir>/val and <target_dir>/test
    according to FileList.csv. Each row in FileList.csv is assumed to have:
        - 'FileName': e.g. '0X100009310A3BD7FC' (the CSV may omit the extension;
          '.avi' is appended unless the name already ends in '.avi' or '.mp4')
        - 'Split': 'train', 'val' or 'test' (case and surrounding whitespace ignored)
    """

    # Split folder names created under target_dir.
    SPLITS = ('train', 'val', 'test')
    # Extensions accepted as-is; anything else gets '.avi' appended.
    VIDEO_EXTENSIONS = ('.avi', '.mp4')

    def __init__(self,
                 filelist_csv='./data/FileList.csv',
                 source_dir='./data/raw_videos',
                 target_dir='./data'):
        """
        Args:
            filelist_csv: path to FileList.csv
            source_dir: directory where the raw video files currently reside
            target_dir: directory under which train/val/test subfolders are created

        Raises:
            FileNotFoundError: if filelist_csv does not exist.
            ValueError: if the CSV lacks a 'FileName' or 'Split' column.
        """
        self.filelist_csv = filelist_csv
        self.source_dir = source_dir
        self.target_dir = target_dir

        if not os.path.exists(self.filelist_csv):
            raise FileNotFoundError(f"FileList CSV not found: {self.filelist_csv}")
        self.df = pd.read_csv(self.filelist_csv)

        for col in ['FileName', 'Split']:
            if col not in self.df.columns:
                raise ValueError(f"Missing required column '{col}' in {self.filelist_csv}")

    def split_and_store(self, overwrite=True):
        """
        Create the train/val/test folders and copy each listed video into its folder.

        Rows with an unknown split label or a missing source file are reported
        with a print and skipped.

        Args:
            overwrite: when True (default), every file is copied even if it already
                exists at the destination. When False, existing destination files
                are left untouched, which makes re-running the step cheap.
        """
        # exist_ok avoids the check-then-create race and keeps re-runs quiet.
        for sp in self.SPLITS:
            os.makedirs(os.path.join(self.target_dir, sp), exist_ok=True)

        for file_value, split_value in zip(self.df['FileName'], self.df['Split']):
            filename_raw = str(file_value)  # e.g. '0X100009310A3BD7FC'
            # strip() tolerates stray spaces around labels in hand-edited CSVs.
            split = str(split_value).strip().lower()
            if split not in self.SPLITS:
                print(f"Skipping file {filename_raw} with unknown split {split}")
                continue

            # Only append '.avi' when the name has no recognised video extension,
            # so a listed 'x.mp4' is not turned into 'x.mp4.avi'.
            if filename_raw.lower().endswith(self.VIDEO_EXTENSIONS):
                filename_video = filename_raw
            else:
                filename_video = filename_raw + '.avi'

            src_path = os.path.join(self.source_dir, filename_video)
            dst_path = os.path.join(self.target_dir, split, filename_video)

            if not os.path.exists(src_path):
                print(f"Source file not found: {src_path}, skipping.")
                continue
            if not overwrite and os.path.exists(dst_path):
                continue

            shutil.copy2(src_path, dst_path)

        print("Data splitting & storing complete.")


# ----------------------------------------
# 2) Dataset class for frames + binary mask
# ----------------------------------------

class EchoVolumeDataset(Dataset):
    """
    Frame-level segmentation dataset: one sample per (video file, frame index).

    Reads frames from videos stored under <data_dir>/<split> and builds a binary
    mask for each frame from VolumeTracings.csv.

    NOTE:
      - In FileList.csv, the 'FileName' may be without '.avi'.
      - In VolumeTracings.csv, the 'FileName' includes '.avi'.
      - After splitting/copying, the video files in <data_dir>/<split> end with '.avi'.
      - Keys are unified by stripping extensions (os.path.splitext).
      - Tracing coordinates are assumed to be in the video's native pixel grid.
        The mask is therefore rasterised at native resolution and resized to the
        same target size as the frame, so the two stay aligned for any `resize`.
      - A frame without a tracing gets an all-zero mask. That means "not
        annotated", not "no structure visible"; pass traced_frames_only=True to
        index only annotated frames.
    """

    # Column names that identify the chord layout of VolumeTracings.csv.
    _CHORD_COLUMNS = ('X1', 'Y1', 'X2', 'Y2')

    def __init__(self,
                 split='train',
                 data_dir='./data',
                 volume_csv='./data/VolumeTracings.csv',
                 resize=(112, 112),
                 mean=(0.0, 0.0, 0.0),
                 std=(1.0, 1.0, 1.0),
                 traced_frames_only=False):
        """
        Args:
            split: 'train', 'val', or 'test'
            data_dir: base directory holding subdirectories 'train', 'val', 'test'
            volume_csv: path to VolumeTracings.csv
            resize: (height, width) of the returned frame and mask
            mean, std: per-channel values used to normalize frames scaled to [0, 1]
            traced_frames_only: if True, index only frames that have a tracing in
                volume_csv; if False (default), index every frame of every video.

        Raises:
            FileNotFoundError: if <data_dir>/<split> or volume_csv does not exist.
        """
        super().__init__()
        self.split = split
        self.data_dir = data_dir
        self.volume_csv = volume_csv
        self.resize = resize
        self.mean = mean
        self.std = std
        self.traced_frames_only = traced_frames_only

        # 1) Gather the .avi / .mp4 files under data_dir/split
        self.video_dir = os.path.join(self.data_dir, self.split)
        if not os.path.exists(self.video_dir):
            raise FileNotFoundError(f"Directory not found: {self.video_dir}")
        self.video_files = sorted(
            f for f in os.listdir(self.video_dir)
            if f.lower().endswith(('.avi', '.mp4'))
        )

        # 2) Read the tracings and index polygons by (base_name, frame_idx)
        if not os.path.exists(self.volume_csv):
            raise FileNotFoundError(f"VolumeTracings CSV not found: {self.volume_csv}")
        self.tracings_df = pd.read_csv(self.volume_csv)
        self.polygons_dict = self._build_polygons_dict()

        # 3) Build the flat sample index: each item is (video_file, frame_idx)
        traced_frames = {}
        if self.traced_frames_only:
            for base_name, frame_num in self.polygons_dict:
                traced_frames.setdefault(base_name, set()).add(frame_num)

        self.index_map = []
        for vid in self.video_files:
            frame_count = self._get_num_frames(os.path.join(self.video_dir, vid))
            if self.traced_frames_only:
                base_name = os.path.splitext(vid)[0]
                # Drop tracings that point past the reported end of the video.
                frames = sorted(f for f in traced_frames.get(base_name, ())
                                if 0 <= f < frame_count)
            else:
                frames = range(frame_count)
            # Example: ("0X100009310A3BD7FC.avi", 0), ...
            self.index_map.extend((vid, frame_idx) for frame_idx in frames)

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        """
        Return (frame, mask) as float32 tensors of shape (3, H, W) and (1, H, W),
        where (H, W) = self.resize. Mask values are 0 or 1.
        """
        video_name, frame_idx = self.index_map[idx]
        video_path = os.path.join(self.video_dir, video_name)
        out_h, out_w = self.resize

        raw = self._read_raw_frame(video_path, frame_idx)
        if raw is None:
            # Unreadable frame: return a black image with a blank mask instead of
            # failing the whole batch. No image means no meaningful label either.
            frame = np.zeros((out_h, out_w, 3), dtype=np.float32)
            mask = np.zeros((out_h, out_w), dtype=np.uint8)
        else:
            # Rasterise on the native grid (where the tracing coordinates live),
            # then resize frame and mask to the same (width, height) dsize.
            mask = self._create_mask(video_name, frame_idx, raw.shape[:2])
            frame = cv2.resize(raw, (out_w, out_h), interpolation=cv2.INTER_AREA)
            # Nearest-neighbour keeps the mask strictly binary.
            mask = cv2.resize(mask, (out_w, out_h), interpolation=cv2.INTER_NEAREST)

        # Scale to [0, 1], normalize per channel, then (H, W, 3) -> (3, H, W)
        frame = frame.astype(np.float32) / 255.0
        mean = np.asarray(self.mean, dtype=np.float32)
        std = np.asarray(self.std, dtype=np.float32)
        frame = (frame - mean) / std
        frame = np.ascontiguousarray(frame.transpose(2, 0, 1))

        # Mask => (1, H, W)
        mask = mask[np.newaxis].astype(np.float32)

        return torch.from_numpy(frame), torch.from_numpy(mask)

    # ---------------------------------------------
    # Utility methods
    # ---------------------------------------------

    @staticmethod
    def _base_key(file_name, frame_num):
        """Dictionary key shared by the tracing parser and _create_mask."""
        return (os.path.splitext(str(file_name))[0], int(frame_num))

    def _build_polygons_dict(self):
        """
        Parse self.tracings_df into {(base_name, frame_number): [polygon, ...]},
        where each polygon is a list of (x, y) float points.

        Two layouts are recognised by column name:
          * Chord layout (columns X1, Y1, X2, Y2 present, as in EchoNet-Dynamic's
            VolumeTracings.csv): each row is one segment, and all rows of one
            (FileName, Frame) group together describe a single outline.
          * Point layout (anything else): each row is one polygon whose vertices
            are consecutive (x, y) column pairs among the columns other than
            'FileName' and 'Frame', read until the first missing value.

        The extension is stripped from FileName so keys match the video files in
        the split directory whether or not they carry '.avi'.
        """
        df = self.tracings_df
        if all(col in df.columns for col in self._CHORD_COLUMNS):
            return self._polygons_from_chords(df)
        return self._polygons_from_points(df)

    def _polygons_from_chords(self, df):
        """Build one closed outline per (FileName, Frame) group from chord rows."""
        cols = list(self._CHORD_COLUMNS)
        polygons = {}
        for (file_name, frame_num), group in df.groupby(['FileName', 'Frame'], sort=False):
            # Row 0 of each group is skipped; the outline walks down the (X1, Y1)
            # ends of the remaining rows and back up their (X2, Y2) ends. This
            # mirrors the EchoNet-Dynamic reference loader, where row 0 is the
            # long axis and the remaining rows are the chords across it.
            segs = group[cols].to_numpy(dtype=np.float64)[1:]
            xs = np.concatenate([segs[:, 0], segs[::-1, 2]])
            ys = np.concatenate([segs[:, 1], segs[::-1, 3]])
            polygons.setdefault(self._base_key(file_name, frame_num), []).append(
                list(zip(xs.tolist(), ys.tolist())))
        return polygons

    def _polygons_from_points(self, df):
        """Build one polygon per row from (x, y) column pairs."""
        coord_cols = [c for c in df.columns if c not in ('FileName', 'Frame')]
        # An unpaired trailing column cannot form a vertex; ignore it instead of
        # indexing past the end of the column list.
        coord_cols = coord_cols[:len(coord_cols) - len(coord_cols) % 2]
        values = df[coord_cols].to_numpy(dtype=np.float64)

        polygons = {}
        for file_name, frame_num, row_vals in zip(df['FileName'], df['Frame'], values):
            coords = []
            for x_val, y_val in row_vals.reshape(-1, 2):
                if np.isnan(x_val) or np.isnan(y_val):
                    break
                coords.append((float(x_val), float(y_val)))
            polygons.setdefault(self._base_key(file_name, frame_num), []).append(coords)
        return polygons

    def _get_num_frames(self, video_path):
        """Frame count reported by the container metadata (0 if unopenable)."""
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return 0
            return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            cap.release()

    def _read_raw_frame(self, video_path, frame_idx):
        """Return frame `frame_idx` as a native-size RGB uint8 array, or None."""
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return None
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
        finally:
            cap.release()
        if not ret or frame is None:
            return None
        # OpenCV decodes to BGR
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    def _read_frame(self, video_path, frame_idx):
        """Return frame `frame_idx` as RGB resized to self.resize (H, W), or None."""
        frame = self._read_raw_frame(video_path, frame_idx)
        if frame is None:
            return None
        out_h, out_w = self.resize
        # cv2.resize takes dsize as (width, height)
        return cv2.resize(frame, (out_w, out_h), interpolation=cv2.INTER_AREA)

    def _create_mask(self, video_name, frame_idx, hw_shape):
        """
        Rasterise the tracing(s) for (video_name, frame_idx) into a uint8 mask of
        shape hw_shape with values 0/1; all zeros when the frame has no tracing.

        'video_name' is something like '0X100009310A3BD7FC.avi'; the extension is
        stripped to form the dictionary key. Vertices are rounded to the nearest
        pixel, and outlines with fewer than three vertices (no enclosed area)
        are skipped.
        """
        mask = np.zeros(hw_shape, dtype=np.uint8)
        for polygon_coords in self.polygons_dict.get(self._base_key(video_name, frame_idx), []):
            if len(polygon_coords) < 3:
                continue
            pts = np.rint(np.asarray(polygon_coords, dtype=np.float64)).astype(np.int32)
            cv2.fillPoly(mask, [pts.reshape((-1, 1, 2))], 1)
        return mask

# ----------------------------------------
# 3) Similar approach for test data
# ----------------------------------------
# If your test data also has filenames that differ in extension, 
# the code below still works because we unify them with os.path.splitext.

class EchoVolumeTestDataset(EchoVolumeDataset):
    """EchoVolumeDataset fixed to the 'test' split; every other option is forwarded."""

    def __init__(self, data_dir='./data', volume_csv='./data/VolumeTracings.csv',
                 resize=(112,112), mean=(0,0,0), std=(1,1,1)):
        forwarded = dict(
            data_dir=data_dir,
            volume_csv=volume_csv,
            resize=resize,
            mean=mean,
            std=std,
        )
        super().__init__(split='test', **forwarded)



In [None]:

# Locations inside the raw EchoNet-Dynamic download. Forward slashes work on
# Windows as well, and avoid invalid escape sequences such as '\d' that the
# previous backslash literals contained (a SyntaxWarning on Python 3.12+).
ECHONET_DIR = './data/EchoNet-Dynamic/EchoNet-Dynamic'
FILELIST_CSV = os.path.join(ECHONET_DIR, 'FileList.csv')
VOLUME_CSV = os.path.join(ECHONET_DIR, 'VolumeTracings.csv')
RAW_VIDEO_DIR = os.path.join(ECHONET_DIR, 'Videos')  # original .avi files

# 1) Split the dataset according to FileList.csv.
#    Only needs to run once; set RUN_SPLIT = True to (re)create ./data/<split>.
RUN_SPLIT = False
if RUN_SPLIT:
    splitter = DatasetSplitter(
        filelist_csv=FILELIST_CSV,
        source_dir=RAW_VIDEO_DIR,
        target_dir='./data',
    )
    splitter.split_and_store()

# 2) Create a dataset for 'train'
train_dataset = EchoVolumeDataset(
    split='train',
    data_dir='./data',
    volume_csv=VOLUME_CSV,
    resize=(112, 112),
    mean=(0.0, 0.0, 0.0),
    std=(1.0, 1.0, 1.0),
)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

print("Train dataset size:", len(train_dataset))
# Inspect a single batch
for frames, masks in train_loader:
    # frames shape: (B, 3, H, W)
    # masks shape:  (B, 1, H, W)
    print("Frames batch shape:", frames.shape)
    print("Masks batch shape:", masks.shape)
    break

# 3) Test dataset
test_dataset = EchoVolumeTestDataset(
    data_dir='./data',
    volume_csv=VOLUME_CSV,
    resize=(112, 112),
    mean=(0.0, 0.0, 0.0),
    std=(1.0, 1.0, 1.0),
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
print("Test dataset size:", len(test_dataset))

# Example iteration over test data
for frames, masks in test_loader:
    # frames shape: (1, 3, H, W)
    # masks shape:  (1, 1, H, W)
    print("Test frame shape:", frames.shape)
    print("Test mask shape:", masks.shape)
    # Typically you'd run inference on these frames
    break


Data splitting & storing complete.
Train dataset size: 1315340
Frames batch shape: torch.Size([2, 3, 112, 112])
Masks batch shape: torch.Size([2, 1, 112, 112])
Test dataset size: 226460
Test frame shape: torch.Size([1, 3, 112, 112])
Test mask shape: torch.Size([1, 1, 112, 112])
