
Support for Video Features, for example How2Sign #1359

Open

kerolos opened this issue Jun 23, 2024 · 5 comments

Comments

@kerolos

kerolos commented Jun 23, 2024

Extend Lhotse to support video features for tasks such as sign language recognition (e.g., How2Sign) and human activity recognition. This enhancement will be useful for the Icefall platform.

Details

With the recent support for video in PR #1151, I am interested in developing a new recipe to handle video data and extract features using tools like MediaPipe.

Objectives

  1. Recipe Addition:

    • Add a new recipe that supports video data in the lhotse/recipes directory.
  2. Feature Extraction:

    • Extract features from the video data (e.g., with MediaPipe) and store them in a format that can be loaded for training.

Implementation Steps

  1. Create Manifest Files:

    • Recordings manifest (recordings.jsonl):

      {
        "id": "-fZc293MpJk_0-1-rgb_front",
        "sources": [
          {
            "type": "file",
            "channels": [0],
            "source": "/mnt/TB16/sign2text/dataset/How2Sign/clips/test_rgb_front_clips/raw_videos/-fZc293MpJk_0-1-rgb_front.mp4"
          }
        ],
        "sampling_rate": 24,
        "num_samples": 17,
        "duration": 6.53,
        "end": 6.53,
        "channel_ids": [0],
        "feature_path": "/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/original/raw_features/-fZc293MpJk_0-1-rgb_front.txt"
      }
    • Supervisions manifest (supervisions.jsonl):

      {
        "id": "-fZc293MpJk_0-1-rgb_front",
        "recording_id": "-fZc293MpJk_0-1-rgb_front",
        "start": 0.0,
        "end": 6.53,
        "duration": 6.53,
        "channel": 0,
        "text": "hi",
        "speaker": "-fZc293MpJk"
      }
  2. Feature Extraction Script:

    • Create a script compute_features_sign_language.py:
      import argparse
      import logging
      import os
      from pathlib import Path
      
      import torch
      import numpy as np
      from lhotse import CutSet, NumpyFilesWriter, load_manifest_lazy
      from tqdm import tqdm
      
      # Set the number of threads for torch to avoid performance issues
      torch.set_num_threads(1)
      torch.set_num_interop_threads(1)
      
      def get_args():
          parser = argparse.ArgumentParser(
              description="This script creates ssl features file for sign language dataset"
          )
          parser.add_argument(
              "--src-dir",
              type=str,
              help="Path to the data source",
          )
          parser.add_argument(
              "--output-dir",
              type=str,
              help="Output directory",
          )
          parser.add_argument(
              "--feature-dim",
              type=int,
              default=1662,
              help="Dimension of the feature vectors",
          )
          return parser.parse_args()
      
      def load_raw_features(feature_path, feature_dim):
          with open(feature_path, 'r') as f:
              raw_features = np.loadtxt(f)
          return raw_features.reshape(-1, feature_dim)
      
      def compute_sign_language_features(src_dir, output_dir, feature_dim):
          src_dir = Path(src_dir)
          output_dir = Path(output_dir)
          output_dir.mkdir(parents=True, exist_ok=True)
      
          recordings_manifest = load_manifest_lazy(src_dir / 'recordings.jsonl.gz')
          supervisions_manifest = load_manifest_lazy(src_dir / 'supervisions.jsonl.gz')
      
         # I am not sure how this part should be implemented:
          with tqdm(total=len(recordings_manifest)) as pbar:
              for recording in recordings_manifest:
                  feature_path = recording["feature_path"]
                  if os.path.exists(feature_path):
                      features = load_raw_features(feature_path, feature_dim)
                  else:
                      # Here we should include the actual feature extraction logic if needed
                      raise FileNotFoundError(f"Feature file not found: {feature_path}")
      
                  output_file = output_dir / f"{recording['id']}.npy"
                  np.save(output_file, features)
                  pbar.update(1)
      
      def main():
          logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s')
      
          args = get_args()
          src_dir = Path(args.src_dir)
          output_dir = Path(args.output_dir)
      
          compute_sign_language_features(src_dir, output_dir, args.feature_dim)
      
      if __name__ == "__main__":
          main()
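
For reference, a minimal sketch of how a recipe could build these manifests with lhotse objects instead of hand-written JSONL (paths, IDs, and output filenames are illustrative):

from lhotse import Recording, RecordingSet, SupervisionSegment, SupervisionSet

# for an mp4, lhotse auto-constructs a video recording manifest (duration, fps, ...)
recording = Recording.from_file(
    "/mnt/TB16/sign2text/dataset/How2Sign/clips/test_rgb_front_clips/raw_videos/-fZc293MpJk_0-1-rgb_front.mp4"
)

supervision = SupervisionSegment(
    id=recording.id,
    recording_id=recording.id,
    start=0.0,
    duration=recording.duration,
    channel=0,
    text="hi",
    speaker="-fZc293MpJk",
)

RecordingSet.from_recordings([recording]).to_file("recordings.jsonl.gz")
SupervisionSet.from_segments([supervision]).to_file("supervisions.jsonl.gz")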

Questions

  1. Is there a plan to add a recipe that supports video data in Lhotse?
  2. How can I start using customized features, for example, using MediaPipe Pose Estimation tools?
  3. What format should be used to save the extracted features (i.e., so that has_features is set), and how can they be loaded later for training?
  4. The frames per second is not always fixed; it varies between 24 fps and 50 fps. How can I deal with that?

I would appreciate any guidance or support on implementing this feature and utilizing it within the Icefall platform @pzelasko .

Thank you!

@pzelasko
Collaborator

Hi @kerolos, thanks for opening this discussion! I can help you get your video recipe set up.

Is there a plan to add a recipe that supports video data in Lhotse?

There is one AV recipe currently for GRID AV corpus: https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/grid.py
In general, lhotse recipes download and prepare the manifests for datasets, but actual training is out of lhotse's scope. You may want to set up a separate repository with your experiment's code that imports lhotse.

How can I start using customized features, for example, using MediaPipe Pose Estimation tools?

Once you create a recording, you can load the video, process it with some module, and save + attach as a custom field to the cut. For example:

from lhotse import Recording
from lhotse.features.io import NumpyHdf5Writer

video_recording = Recording.from_file("/path/to/-fZc293MpJk_0-1-rgb_front.mp4")  # lhotse will auto-construct video recording manifest
video_cut = video_recording.to_cut()

video_frames = video_cut.load_video()  # video frames is a uint8 np.array with shape (T, C, H, W) [or some other permutation, I don't remember off the top of my head]
video_features = compute_some_features(video_frames)  # video_features is np.array with arbitrary shape


# Option 1 -> save to some storage directly
# temporal_dim indicates which dimension in video_features shape corresponds to time; set accordingly.
# frame_shift is the time step between consecutive feature frames in seconds, i.e. 1 / fps here.
with NumpyHdf5Writer("video_features.h5") as writer:
    video_cut.video_features = writer.store_array(
        video_cut.id, video_features, frame_shift=1 / video_recording.video.fps, temporal_dim=0
    )

# Option 2 -> holds data in memory, write to some storage later (useful if you're going to use Lhotse Shar format):
video_cut = video_cut.attach_tensor(
    "video_features", video_features, frame_shift=1 / video_recording.video.fps, temporal_dim=0
)

If you save the final video_cut, you can then later load video_features with cut.load_video_features() and access the manifest via cut.video_features (special field and method are auto-added for custom fields registered via attach_tensor). You can compute many different features and attach all of them under different names.
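
For illustration, a rough sketch of what compute_some_features could look like with MediaPipe Holistic. The flattening into a 1662-dim vector (33 pose landmarks × 4 + 468 face × 3 + 2 × 21 hand landmarks × 3) and the (T, C, H, W) RGB frame layout are assumptions, not something lhotse prescribes:

import numpy as np
import mediapipe as mp

def compute_some_features(video_frames: np.ndarray) -> np.ndarray:
    """Sketch: per-frame MediaPipe Holistic keypoints, flattened to (T, 1662)."""
    def flatten(landmarks, n_points, dims):
        # missing detections are zero-filled so every frame has a fixed-size vector
        if landmarks is None:
            return np.zeros(n_points * dims, dtype=np.float32)
        vals = []
        for lm in landmarks.landmark:
            point = [lm.x, lm.y, lm.z]
            if dims == 4:
                point.append(lm.visibility)
            vals.extend(point)
        return np.asarray(vals, dtype=np.float32)

    features = []
    with mp.solutions.holistic.Holistic(static_image_mode=False) as holistic:
        for frame in video_frames:
            image = np.ascontiguousarray(frame.transpose(1, 2, 0))  # (C, H, W) -> (H, W, C) RGB
            results = holistic.process(image)
            features.append(np.concatenate([
                flatten(results.pose_landmarks, 33, 4),
                flatten(results.face_landmarks, 468, 3),
                flatten(results.left_hand_landmarks, 21, 3),
                flatten(results.right_hand_landmarks, 21, 3),
            ]))
    return np.stack(features)  # (T, 1662)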

What format should be used to save the extracted features (i.e., so that has_features is set), and how can they be loaded later for training?

I would use one of the numpy format writers in lhotse (e.g. NumpyHdf5Writer in the example above). Don't use lilcom unless you are sure it makes sense (it is a lossy format optimized for log-domain features). You may also want to explore the lhotse Shar format, which I think should work with video recordings (and definitely works with video features extracted as above). It is better optimized for I/O, which might help you process large video data in training.

That said, video features would likely require better compression for very large datasets, which is something we can explore later.
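
For completeness, a small sketch of the round trip under the setup above (the filename is illustrative): after storing/attaching the array, save the cut manifest and use the auto-generated accessor later.

from lhotse import CutSet

# after the attach/store step above:
CutSet.from_cuts([video_cut]).to_file("video_cuts.jsonl.gz")

# later, e.g. in the training data pipeline:
cuts = CutSet.from_file("video_cuts.jsonl.gz")
for cut in cuts:
    feats = cut.load_video_features()  # np.ndarray with the shape that was stored, e.g. (T, feature_dim)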

The frames per second is not always fixed; it varies between 24 fps and 50 fps. How can I deal with that?

You can access the fps via recording.video.fps or cut.video.fps. If you want to resample the video, you have two options: 1) load the whole recording (or a cut of a given duration) and downsample/resample it in Python; 2) leverage the torchaudio ffmpeg bindings to resample the video (you might need to check their tutorials to learn how to pass specific ffmpeg transform commands, and find a way to expose/add it in the AudioSource API; for reference, this code loads the video).
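
A minimal sketch of option 1, resampling already-loaded frames to a fixed target fps by nearest-frame index selection (function and variable names are illustrative; frames are assumed to have time as the first dimension):

import numpy as np

def resample_frames(video_frames: np.ndarray, src_fps: float, target_fps: float) -> np.ndarray:
    """Pick the nearest source frame for each target timestamp (no interpolation)."""
    num_src = video_frames.shape[0]
    duration = num_src / src_fps
    num_tgt = int(round(duration * target_fps))
    idx = np.round(np.arange(num_tgt) * src_fps / target_fps).astype(int)
    idx = np.clip(idx, 0, num_src - 1)
    return video_frames[idx]

# e.g. normalize everything to 25 fps before feature extraction:
# frames_25 = resample_frames(video_frames, src_fps=cut.video.fps, target_fps=25.0)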

Final comment: the Recording manifest doesn't support custom fields, so you'd be better off moving the feature_path key to the supervision as {..., "custom": {"feature_path": ...}}.
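
For illustration, the supervision entry from the first post would then look roughly like this:

{
  "id": "-fZc293MpJk_0-1-rgb_front",
  "recording_id": "-fZc293MpJk_0-1-rgb_front",
  "start": 0.0,
  "duration": 6.53,
  "channel": 0,
  "text": "hi",
  "speaker": "-fZc293MpJk",
  "custom": {"feature_path": "/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/original/raw_features/-fZc293MpJk_0-1-rgb_front.txt"}
}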

@kerolos
Author

kerolos commented Jul 13, 2024

Initially, I tried to load a video in mp4 format (this mp4 from the sign language dataset has no audio track):

recording = Recording.from_file("test_rgb_front_clips/raw_videos/_fZbAxSSbX4_0-5-rgb_front.mp4")

I got this error :

  File "/mnt/HD_8TB/training/_icefall_script_/SignRcg/local/_tmp_/readh54test.py", line 76, in <module>
    video_recording = Recording.from_file("/mnt/TB16/sign2text/dataset/How2Sign/clips/test_rgb_front_clips/raw_videos/_fZbAxSSbX4_0-5-rgb_front.mp4")  # lhotse will auto-construct video recording manifest
  File "/home/kerolos/anaconda3/envs/icefall-run/lib/python3.8/site-packages/lhotse/audio/recording.py", line 200, in from_file
    audio_info = info(
  File "/home/kerolos/anaconda3/envs/icefall-run/lib/python3.8/site-packages/lhotse/audio/backend.py", line 1494, in info
    return get_current_audio_backend().info(
  File "/home/kerolos/anaconda3/envs/icefall-run/lib/python3.8/site-packages/lhotse/audio/backend.py", line 750, in info
    raise AudioLoadingError(
lhotse.audio.utils.AudioLoadingError: Fetching info about audio from '/mnt/TB16/sign2text/dataset/How2Sign/clips/test_rgb_front_clips/raw_videos/_fZbAxSSbX4_0-5-rgb_front.mp4' failed. Details:
Exception #0 (<class 'lhotse.audio.backend.LibsndfileBackend'>): <class 'soundfile.LibsndfileError'>: Error opening '/mnt/TB16/sign2text/dataset/How2Sign/clips/test_rgb_front_clips/raw_videos/_fZbAxSSbX4_0-5-rgb_front.mp4': Format not recognised.
Exception #1 (<class 'lhotse.audio.backend.TorchaudioDefaultBackend'>): <class 'RuntimeError'>: Failed to fetch metadata from /mnt/TB16/sign2text/dataset/How2Sign/clips/test_rgb_front_clips/raw_videos/_fZbAxSSbX4_0-5-rgb_front.mp4
Set LHOTSE_AUDIO_LOADING_EXCEPTION_VERBOSE=1 environment variable for full stack traces.
  • How can I avoid this error from the Recording class?
  • How can I skip loading the audio in the load_video() function ("method that loads video + audio (and keeps them in sync duration-wise)")?

@kerolos
Author

kerolos commented Jul 14, 2024

I would like to use the Lhotse SHAR format to save SHAR files from the jsonl manifests. I have two options:

  1. Save cuts from recordings and supervisions manifests (RecordingSet.from_jsonl, and SupervisionSet.from_jsonl)
  2. Save video features in the Lhotse SHAR format (load_manifest_lazy)
  • Option 1: Saving Cuts from Recordings and Supervisions Manifests (from_jsonl):
from pathlib import Path

from lhotse import RecordingSet, SupervisionSet, CutSet

output_dir = Path("./data-shar")
src_dir = Path(".")  # directory containing the jsonl manifests
recordings_manifest = src_dir / 'recordings.jsonl'
supervisions_manifest = src_dir / 'supervisions.jsonl'

recordings = RecordingSet.from_jsonl(recordings_manifest)
supervisions = SupervisionSet.from_jsonl(supervisions_manifest)
cuts = CutSet.from_manifests(recordings, supervisions).trim_to_supervisions()

try:
    shards = cuts.to_shar(output_dir, fields={"recording": "mp4"}, shard_size=15)
except AssertionError as e:
    print(f"Error: {e}")

Error:

AssertionError: Unknown field type (got: 'mp4', we support only: wav, flac, mp3, opus, lilcom, numpy, jsonl)
  • Option 2: Saving Video Features in Lhotse SHAR Format:
from pathlib import Path

import cv2
import logging
import mediapipe as mp
from tqdm import tqdm

from lhotse import load_manifest_lazy, Recording
from lhotse.shar import ArrayTarWriter

output_dir = Path("./data-shar")
src_dir = Path(".")  # directory containing the jsonl manifests
recordings_manifest_path = src_dir / 'recordings.jsonl'
supervisions_manifest_path = src_dir / 'supervisions.jsonl'

recordings_manifest = load_manifest_lazy(recordings_manifest_path)
supervisions_manifest = load_manifest_lazy(supervisions_manifest_path)

tar_path = output_dir / "video_features.%06d.tar"

with ArrayTarWriter(tar_path, shard_size=15) as writer, tqdm(total=len(recordings_manifest)) as pbar, mp.solutions.holistic.Holistic(
        static_image_mode=False, model_complexity=0, min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
    for recording in recordings_manifest:
        try:
            video_recording = Recording.from_dict(recording.to_dict())
            video_cut = video_recording.to_cut()
            video_frames = video_cut.load_video()  # loaded for inspection; features below are extracted from the file directly
            video_path = video_recording.sources[0].source  # Get the video path from sources
            logging.info(f"Loading video frames from {video_path}")

            # Get FPS using OpenCV
            cap = cv2.VideoCapture(video_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            cap.release()

            # extract_features_from_video is my own MediaPipe-based helper (defined elsewhere)
            video_features = extract_features_from_video(video_path, holistic)
            if video_features is None:
                logging.error(f"Failed to load video frames for recording ID: {video_recording.id}, video path: {video_path}")
                continue

            # Attach features to video_cut
            video_cut = video_cut.attach_tensor("video_features", video_features, frame_shift=float(1.0 / fps), temporal_dim=0)

            # Store the features using ArrayTarWriter
            writer.write(video_cut.id, video_features, video_cut.video_features)

        except Exception as e:
            logging.error(f"Error processing recording ID {recording.id}: {e}")
        pbar.update(1)

I can save the features to video_features.000000.tar (inside this archive, each video has two files: -fZc293MpJk_0-1-rgb_front.json and -fZc293MpJk_0-1-rgb_front.npy),
and the JSON file looks like this:
{"array": {"storage_type": "shar", "storage_path": "", "storage_key": "", "shape": [17, 1662]}, "temporal_dim": 0, "frame_shift": 0.02, "start": 0}
Furthermore, I checked the npy file and it has the correct dimensions.

Hint: I have not compressed with "lilcom" in ArrayTarWriter, and I also have not saved
feature_shards = writer.output_paths

  • Questions
  1. How can I save the missing parts in Lhotse SHAR format for both options: cuts (e.g., cuts.000000.jsonl.gz) and recordings (e.g., recording.000000.tar)?
  2. Is there any plan to support mp4 for cuts.to_shar (lhotse/cut/set.py)?

I also want to be able to use the from_shar function and later train with a DataLoader using Lhotse Shar:

cuts_nodata = CutSet.from_shar(fields={"cuts": shards["cuts"]})
or
cuts = CutSet.from_shar(
    fields={
        "cuts": shards["cuts"],
        "recording": shards["recording"],
    },
)

In the tutorial (examples/04-lhotse-shar.ipynb), under "Implementation note: the use of IterableDataset", the fbank features are computed on the fly instead of being read from a field such as "fbank": feature_shards (.llc, or another array format such as .npy as in my case).

How should the code be modified so that it reads the existing features from the "feature_shards" shards when using the DynamicBucketingSampler, instead of extracting new ones from the recording shards (as in the "Implementation note: the use of IterableDataset" section)?

Thanks in advance @pzelasko

@pzelasko
Collaborator

  • How can I avoid this error from the Recording class?

Video loading depends on you having a recent version of pytorch, torchaudio, and a compatible ffmpeg version to load videos. Based on the call stack I think maybe you don't have this backend available. Try updating your torch/torchaudio and setting the env var export LHOTSE_AUDIO_BACKEND=FfmpegTorchaudioStreamerBackend to force the torchaudio backend for this.

  • How can I skip loading the audio in the load_video() function ("method that loads video + audio (and keeps them in sync duration-wise)")?

Try using with_audio=False arg for https://github.com/lhotse-speech/lhotse/blob/master/lhotse/audio/recording.py#L479C9-L479C33
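
Putting both suggestions together, something along these lines (a sketch only; the exact return value of load_video, frames only vs. a (video, audio) tuple, may depend on the lhotse version):

import os

# set the backend before using lhotse, so mp4 files without an audio stream
# are probed with the ffmpeg-based torchaudio streamer
os.environ["LHOTSE_AUDIO_BACKEND"] = "FfmpegTorchaudioStreamerBackend"

from lhotse import Recording

recording = Recording.from_file(
    "test_rgb_front_clips/raw_videos/_fZbAxSSbX4_0-5-rgb_front.mp4"
)
# skip the audio stream entirely when loading frames
video = recording.load_video(with_audio=False)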

@pzelasko
Collaborator

As for your other question:

Option 1: Saving Cuts from Recordings and Supervisions Manifests (from_jsonl):

Yeah, we'll need to add mp4 support for AudioTarWriter. I don't have the bandwidth for this right now, but I can help you get started. First, we'll need to add a save_audio function that can actually save both audio and video to the torchaudio ffmpeg streamer backend here. You can check the read_audio function and implement save_audio analogously to support the same set of features, which I hope would be straightforward. Then, we'll need to register the mp4 format in the shar writers, here and here. I think this is sufficient to get it working.
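
Not part of the original plan above, but a rough sketch of what the video-writing half of such a helper could look like on top of torchaudio's StreamWriter (the function name and argument handling are illustrative; wiring it into lhotse's backend and registering "mp4" in the shar writers is the part described above):

import torch
from torchaudio.io import StreamWriter

def save_video_mp4(path: str, video: torch.Tensor, fps: float) -> None:
    """Write a (T, C, H, W) uint8 RGB tensor to an mp4 file."""
    _, _, height, width = video.shape
    writer = StreamWriter(dst=path)
    writer.add_video_stream(frame_rate=fps, width=width, height=height, format="rgb24")
    with writer.open():
        writer.write_video_chunk(0, video)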

How can I save the missing parts in Lhotse SHAR format for both options: cuts (e.g., cuts.000000.jsonl.gz) and recordings (e.g., recording.000000.tar)?

You have two options. There is a high-level utility cuts.to_shar() which uses LazySharWriter that saves all the fields together (you have to specify which fields and which format, e.g. {"recording": "mp4", "video_features": "numpy"}). You can use it to write only the cuts first, and then use your existing option 2 code to save video features next to cuts. It will create a valid Shar directory. Exporting recordings to mp4 is discussed above.
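
i.e., roughly like this (assuming to_shar with no fields writes only the cut shards, and that the option-2 loop iterates the cuts in the same order with the same shard_size so the feature shards stay aligned with the cut shards):

# 1) write only the cut shards (cuts.000000.jsonl.gz, ...)
shards = cuts.to_shar(output_dir, fields={}, shard_size=15)

# 2) then re-run the option-2 loop with
#    ArrayTarWriter(output_dir / "video_features.%06d.tar", shard_size=15)
#    over the same cuts, in the same order, and keep feature_shards = writer.output_paths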

How should the code be modified so that it reads the existing features from the "feature_shards" shards when using the DynamicBucketingSampler, instead of extracting new ones from the recording shards (as in the "Implementation note: the use of IterableDataset" section)?

I think what you want, after following the suggestions above, is this:

cuts = CutSet.from_shar(
    fields={
        "cuts": shards["cuts"],
        "video_features": shards["video_features"],
    },
)
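
and then, roughly following the 04-lhotse-shar tutorial, a sketch of a dataloader that reads the precomputed features instead of extracting anything on the fly (the VideoFeatureDataset class and the max_duration value are illustrative):

import torch
from torch.utils.data import DataLoader
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper

class VideoFeatureDataset(torch.utils.data.Dataset):
    """Hypothetical dataset: loads the precomputed features attached to each cut in the mini-batch."""
    def __getitem__(self, cuts: CutSet) -> dict:
        feats = [torch.from_numpy(c.load_video_features()) for c in cuts]  # list of (T_i, feature_dim)
        return {"cuts": cuts, "video_features": feats}

cuts = CutSet.from_shar(
    fields={"cuts": shards["cuts"], "video_features": shards["video_features"]},
    shuffle_shards=True,
).repeat()

sampler = DynamicBucketingSampler(cuts, max_duration=100.0, shuffle=True)
dloader = DataLoader(
    IterableDatasetWrapper(dataset=VideoFeatureDataset(), sampler=sampler),
    batch_size=None,
    num_workers=2,
)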
