Merge pull request #33 from johndpope/feat/32-cycle-consistency
DRAFT - Feat/32 cycle consistency
johndpope committed Jun 4, 2024
2 parents ace2981 + 9eb3a11 commit ff9cf22
Showing 26 changed files with 6,022 additions and 271 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -6,3 +6,8 @@ models/gaze_model_pytorch_vgg16_prl_mpii_allsubjects1.model
 *.dat
 *.pth
 # output_images/*.*
+*.png
+junk/-2KGPYEFnsU_11_nobg.mp4
+*.png
+*.png
+*.npz
108 changes: 108 additions & 0 deletions EmoDataset.md
@@ -0,0 +1,108 @@
## EMODataset Class Summary

### Overview
The `EMODataset` class is a PyTorch dataset for processing and augmenting video frames. It can remove backgrounds, warp and crop faces, and cache processed frames for efficient saving and loading. The class is designed to handle large video datasets and includes methods that streamline the preprocessing pipeline.

### Dependencies
The class relies on the following libraries:
- `moviepy.editor`: Video editing and processing.
- `PIL.Image`: Image processing.
- `torch`: PyTorch for tensor operations and model support.
- `torchvision.transforms`: Image transformations.
- `decord`: Efficient video reading.
- `rembg`: Background removal.
- `face_recognition`: Face detection.
- `skimage.transform`: Image warping.
- `cv2`: Video writing with OpenCV.
- `numpy`: Array operations.
- `io`, `os`, `json`, `pathlib.Path`, `subprocess`: standard-library modules for file handling and I/O.
- `tqdm`: progress visualization (third-party).

### Initialization
The `__init__` method sets up the dataset with various parameters:
- `use_gpu`, `sample_rate`, `n_sample_frames`, `width`, `height`, `img_scale`, `img_ratio`, `video_dir`, `drop_ratio`, `json_file`, `stage`, `transform`, `remove_background`, `use_greenscreen`, `apply_crop_warping`
- Loads video metadata from the provided JSON file.
- Initializes decord for video reading with PyTorch tensor output.
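
A minimal sketch of the constructor, assuming the `"clips"` key layout in the metadata JSON and the default values shown here (both illustrative, not taken from the source):

```python
import json

import decord
from torch.utils.data import Dataset


class EMODataset(Dataset):
    def __init__(self, use_gpu, sample_rate, n_sample_frames, width, height,
                 img_scale, video_dir, json_file, transform,
                 img_ratio=(0.9, 1.0), drop_ratio=0.1, stage="stage1",
                 remove_background=True, use_greenscreen=False,
                 apply_crop_warping=False):
        self.sample_rate = sample_rate
        self.video_dir = video_dir
        self.transform = transform
        self.remove_background = remove_background
        self.use_greenscreen = use_greenscreen
        self.apply_crop_warping = apply_crop_warping
        # (remaining parameters would be stored the same way)
        # Load per-video metadata; the "clips" layout is an assumption.
        with open(json_file, "r") as f:
            self.metadata = json.load(f)
        self.video_ids = list(self.metadata["clips"].keys())
        # Have decord return PyTorch tensors instead of its own NDArray type.
        decord.bridge.set_bridge("torch")
```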

### Methods

#### `__len__`
Returns the length of the dataset, determined by the number of video IDs.
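
Given the `video_ids` list built in `__init__`, this is plausibly a one-liner:

```python
def __len__(self):
    return len(self.video_ids)  # one dataset entry per video ID
```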

#### `warp_and_crop_face`
Processes an image tensor to detect, warp, and crop the face region:
- Converts tensor to PIL image.
- Removes background.
- Detects face locations.
- Crops the face region.
- Optionally applies thin-plate-spline warping.
- Converts the processed image back to a tensor and returns it.
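
A hedged sketch of these steps as a method of the class; the padding value and the pass-through behaviour when no face is found are assumptions, and the thin-plate-spline warp is only indicated by a comment:

```python
import face_recognition
import numpy as np
import torch
from torchvision.transforms import functional as TF


def warp_and_crop_face(self, image_tensor, pad=10):
    pil_image = TF.to_pil_image(image_tensor)   # (C, H, W) tensor -> PIL image
    pil_image = self.remove_bg(pil_image)       # strip the background first
    np_image = np.array(pil_image.convert("RGB"))
    faces = face_recognition.face_locations(np_image)
    if not faces:
        return image_tensor                     # no face found: pass through
    top, right, bottom, left = faces[0]
    h, w = np_image.shape[:2]
    crop = np_image[max(top - pad, 0):min(bottom + pad, h),
                    max(left - pad, 0):min(right + pad, w)]
    # When apply_crop_warping is set, a thin-plate-spline warp from
    # skimage.transform would be applied to `crop` here (omitted).
    return torch.from_numpy(crop).permute(2, 0, 1).float() / 255.0
```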

#### `load_and_process_video`
Loads and processes video frames:
- Checks whether a processed tensor file already exists; if so, loads the cached tensors.
- If not, processes video frames, applies augmentation, and saves frames as PNG images and tensors.
- Saves processed tensors as compressed numpy arrays for efficient loading.
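
A sketch of the caching flow; the cache location (`self.output_dir`) and the single `frames` array inside the `.npz` archive are assumptions, and the PNG-per-frame dump described above is omitted:

```python
import os

import numpy as np
import torch


def load_and_process_video(self, video_path, video_id):
    tensor_path = os.path.join(self.output_dir, f"{video_id}.npz")  # assumed path
    if os.path.exists(tensor_path):
        with np.load(tensor_path) as data:          # fast path: reuse the cache
            return torch.from_numpy(data["frames"])
    frames = self.process_video_frames(video_path)  # decode + augment each frame
    stacked = torch.stack(frames)                   # (T, C, H, W)
    np.savez_compressed(tensor_path, frames=stacked.numpy())
    return stacked
```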

#### `augmentation`
Applies transformations and optional background removal to the provided images:
- Supports both single images and lists of images.
- Returns transformed tensors.
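
A sketch of the single-image/list duality, assuming frames arrive as PIL images:

```python
import torch


def augmentation(self, images, transform):
    if isinstance(images, list):
        if self.remove_background:
            images = [self.remove_bg(img) for img in images]
        return torch.stack([transform(img) for img in images])  # (T, C, H, W)
    if self.remove_background:
        images = self.remove_bg(images)
    return transform(images)                                    # (C, H, W)
```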

#### `remove_bg`
Removes the background from the provided image using `rembg`:
- Optionally applies a green screen background.
- Converts image to RGB format and returns it.
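
A sketch built on `rembg.remove`, which accepts and returns PIL images; the solid-green backdrop colour is an assumption:

```python
from PIL import Image
from rembg import remove


def remove_bg(self, image):
    fg = remove(image)  # RGBA cut-out of the foreground subject
    if self.use_greenscreen:
        bg = Image.new("RGBA", fg.size, (0, 255, 0, 255))  # assumed green backdrop
        fg = Image.alpha_composite(bg, fg)
    return fg.convert("RGB")
```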

#### `save_video`
Saves a list of frames as a video file:
- Uses OpenCV to write frames to a video file.
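
A sketch with OpenCV's `VideoWriter`; the `mp4v` codec and frame rate are assumptions:

```python
import cv2
import numpy as np


def save_video(self, frames, output_path, fps=25):
    w, h = frames[0].size  # PIL images report size as (width, height)
    writer = cv2.VideoWriter(output_path,
                             cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    for frame in frames:
        writer.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    writer.release()
```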

#### `process_video`
Processes all frames of a video:
- Uses the `process_video_frames` method to process frames.

#### `process_video_frames`
Processes frames of a video using decord:
- Reads frames using decord and applies augmentation.
- Returns processed frames.
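
A sketch of the decord read loop; with the torch bridge set in `__init__`, `vr[idx]` is already a `uint8` PyTorch tensor:

```python
from decord import VideoReader, cpu
from PIL import Image


def process_video_frames(self, video_path):
    vr = VideoReader(video_path, ctx=cpu(0))
    frames = []
    for idx in range(0, len(vr), self.sample_rate):  # subsample at sample_rate
        frame = Image.fromarray(vr[idx].numpy())     # H x W x 3 uint8 -> PIL
        frames.append(self.augmentation(frame, self.transform))
    return frames
```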

#### `__getitem__`
Returns a sample from the dataset:
- Loads and processes source and driving videos.
- Returns a dictionary containing video IDs and frames.
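
A sketch of the returned sample; how the driving clip is paired with the source is not stated above, so the random pairing and file layout here are purely illustrative:

```python
import os
import random


def __getitem__(self, index):
    source_id = self.video_ids[index]
    driving_id = random.choice(self.video_ids)  # assumed pairing strategy
    source = self.load_and_process_video(
        os.path.join(self.video_dir, f"{source_id}.mp4"), source_id)
    driving = self.load_and_process_video(
        os.path.join(self.video_dir, f"{driving_id}.mp4"), driving_id)
    return {"source_id": source_id, "driving_id": driving_id,
            "source_frames": source, "driving_frames": driving}
```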

### Usage
To use the `EMODataset` class:
1. Initialize the dataset with appropriate parameters.
2. Use PyTorch DataLoader to iterate over the dataset and retrieve samples.
3. Process the frames as needed for training or inference in a machine learning model.

### Example
```python
from torchvision import transforms

# Assuming the class lives in EmoDataset.py alongside this document.
from EmoDataset import EMODataset

transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

dataset = EMODataset(
    use_gpu=False,
    sample_rate=5,
    n_sample_frames=16,
    width=512,
    height=512,
    img_scale=(0.9, 1.0),
    video_dir="path/to/videos",
    json_file="path/to/metadata.json",
    transform=transform,
    remove_background=True,
    use_greenscreen=False,
    apply_crop_warping=True,
)

for sample in dataset:
    print(sample)
```

This class provides a comprehensive preprocessing pipeline for video data, making it well suited to training deep learning models on large video datasets.
