# Annolid Behavior Video Classification on SlowFast Tutorial

This Colab contains a tutorial on how to perform behavior video classification using the Annolid library and a SlowFast model.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/healthonrails/annolid/blob/main/docs/tutorials/Annolid_behavior_video_classification_on_slowfast.ipynb)
# SlowFast

*Author: FAIR PyTorchVideo*

Modified from: https://pytorch.org/hub/facebookresearch_pytorchvideo_slowfast/

**Fine-tuning a pretrained SlowFast network on a custom behavior dataset**

#### Install required packages


In [None]:
# Install fvcore (presumably required by the PyTorchVideo SlowFast model loaded via torch.hub below — confirm).
!pip install fvcore

In [None]:
!pip install pyav

In [None]:
import torch
import torch.nn as nn
from torch.hub import load

# Fetch PyTorchVideo's SlowFast-R50 definition from torch.hub and build it
# with its pretrained (Kinetics) weights.
model = load(
    repo_or_dir="facebookresearch/pytorchvideo",
    model="slowfast_r50",
    pretrained=True,
)

Import remaining functions:

In [None]:
import json
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample,
)

#### Setup

Set the model to eval mode and move it to the desired device.

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise
# (a hard-coded "cuda" fails on CPU-only runtimes).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Evaluation mode: dropout disabled, BatchNorm uses its running statistics.
model = model.eval()
model = model.to(device)

In [None]:
!gdown --id 15fnvK0KS9rQdoAB1yK2kw5WFVws_BlW5

In [None]:
# Extract the downloaded archive; presumably it contains the `behaivor_videos/`
# folder (CSV annotations + videos) read by the cells below — confirm.
!unzip behavior_videos.zip

In [None]:
import pandas as pd

video_annotation = "behaivor_videos/amygdala control/CaMKII 1-19 L 2-19-21-Phase 3.csv"
df_anno = pd.read_csv(video_annotation)
# Display only the rows labelled "Grooming" in this sample annotation file.
df_anno.loc[df_anno["Behavior"] == "Grooming"]

In [None]:
# Distinct behavior labels from the sample CSV, in order of first appearance.
behaviors = df_anno["Behavior"].drop_duplicates().tolist()

In [None]:
# Labels that may be missing from the sample CSV above, plus the catch-all
# "Others" class used for unannotated gaps. Only labels not already present
# are appended, so re-running this cell (or a CSV that already contains one
# of them) does not create duplicate classes. Duplicates would shift
# num_classes and leave output units that list.index() never selects.
extra_behaviors = [
    "FP Mouse Mounting Stimulus Mouse",
    "FP Mouse Snigging Excretions",
    "FP Mouse Tail Rattling",
    "Others",
]
behaviors = behaviors + [b for b in extra_behaviors if b not in behaviors]

In [None]:
# Display the class list; a label's position in this list becomes its class index.
behaviors

In [None]:
import os
import random
import csv
import itertools
from collections import defaultdict
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from moviepy.editor import VideoFileClip

# Constants for file paths and split ratio
# NOTE(review): "behaivor_videos" matches the folder used when reading the
# sample CSV above (presumably the folder name inside the zip); keep both in sync.
BASE_FOLDER = "behaivor_videos"
OUTPUT_VIDEO_FOLDER = "behavior_video_clips"  # extracted .mp4 clips are written here
TRAIN_JSONL_PATH = "train_video_annotations.jsonl"
TEST_JSONL_PATH = "test_video_annotations.jsonl"
TRAIN_SPLIT_RATIO = 0.97  # fraction of each behavior's clips assigned to training

# Parameters for SlowFast data preparation
# NOTE(review): the two constants below are not referenced by any code in this
# notebook; the slow/fast split is done by the transforms (num_frames // 4).
SLOW_FPS_REDUCTION_FACTOR = 4  # Factor by which slow pathway FPS is reduced
TARGET_FPS = 30  # Assuming original video FPS is around 30, adjust if needed

# Ensure output folder exists
os.makedirs(OUTPUT_VIDEO_FOLDER, exist_ok=True)


def extract_video_segment(video_file_path, start_time, end_time, output_path):
    """Write the ``start_time``..``end_time`` span (seconds) of a video to ``output_path``.

    Thin wrapper around moviepy's ``ffmpeg_extract_subclip``.
    """
    span = (start_time, end_time)
    ffmpeg_extract_subclip(video_file_path, *span, targetname=output_path)


def create_annotation_entry(
    behavior,
    video_segment_path_slow,
    video_segment_path_fast,
    prompt="<video> What is the behavior in the video?",
):
    """Build one JSONL annotation record for a clip.

    The record holds the prompt under "query", the behavior label under
    "response", and the clip paths as ``[slow_path, fast_path]`` under "videos".
    """
    clip_paths = [video_segment_path_slow, video_segment_path_fast]
    return dict(query=prompt, response=behavior, videos=clip_paths)


def parse_behavior_events(csv_path):
    """Parse the "state start" / "state stop" rows of a behavior CSV.

    Reads the "Recording time", "Behavior" and "Event" columns. Rows whose
    Event is neither "state start" nor "state stop" are ignored.

    Returns:
        ``(start_events, stop_events)``: two lists of
        ``{"time": float, "behavior": str}`` dicts in file order.
    """
    events_by_kind = {"state start": [], "state stop": []}
    # newline="" is what the csv module requires so quoted fields that
    # contain line breaks are parsed correctly.
    with open(csv_path, "r", newline="") as csv_file:
        for row in csv.DictReader(csv_file):
            bucket = events_by_kind.get(row["Event"])
            if bucket is None:
                continue  # not a state event
            bucket.append(
                {"time": float(row["Recording time"]), "behavior": row["Behavior"]}
            )
    return events_by_kind["state start"], events_by_kind["state stop"]


def find_gaps(start_events, stop_events, video_duration, gap_duration):
    """Find unannotated intervals of at least ``gap_duration`` seconds.

    Start events are walked in order. Each one is paired with the first later
    stop event of the same behavior. Every stretch between the end of the
    annotated coverage so far and the next start, plus the tail up to
    ``video_duration``, is recorded if it is at least ``gap_duration`` long.
    These intervals are later sampled as "Others".

    Args:
        start_events: ``{"time", "behavior"}`` dicts, as returned by
            ``parse_behavior_events``.
        stop_events: Matching stop dicts. This list is NOT modified; pairing
            works on a private copy.
        video_duration: Length of the video in seconds.
        gap_duration: Minimum gap length in seconds.

    Returns:
        List of ``(gap_start, gap_end)`` tuples.
    """
    unmatched_stops = list(stop_events)  # private copy: don't consume the caller's list
    covered_until = 0
    gaps = []

    for start_event in start_events:
        start_time = start_event["time"]
        if start_time - covered_until >= gap_duration:
            gaps.append((covered_until, start_time))

        matching_stop = next(
            (
                s
                for s in unmatched_stops
                if s["behavior"] == start_event["behavior"] and s["time"] > start_time
            ),
            None,
        )
        if matching_stop is not None:
            unmatched_stops.remove(matching_stop)
            # max(): a short behavior nested inside a longer one must not pull
            # the coverage end backwards (that would report a "gap" inside
            # an annotated segment).
            covered_until = max(covered_until, matching_stop["time"])

    if video_duration - covered_until >= gap_duration:
        gaps.append((covered_until, video_duration))

    return gaps


def sample_limited_segments_from_gaps(
    gaps, video_file_path, video_name, gap_duration, max_count, behavior_label="Others"
):
    """Cut up to ``max_count`` fixed-length "Others" clips out of ``gaps``.

    Each gap is tiled from its start with back-to-back windows of
    ``gap_duration`` seconds. For every window the same span is extracted
    twice, once as the "slow" file and once as the "fast" file, and an
    annotation entry labelled ``behavior_label`` is collected. Sampling stops
    once ``max_count`` entries exist.
    """
    entries = []
    for gap_start, gap_end in gaps:
        windows_that_fit = int((gap_end - gap_start) // gap_duration)
        budget = max_count - len(entries)
        for k in range(min(windows_that_fit, budget)):
            seg_start = gap_start + k * gap_duration
            seg_end = seg_start + gap_duration

            # Slow and fast files cover the same temporal window.
            slow_path = f"{OUTPUT_VIDEO_FOLDER}/{video_name}_other_slow_{seg_start}-{seg_end}.mp4"
            fast_path = f"{OUTPUT_VIDEO_FOLDER}/{video_name}_other_fast_{seg_start}-{seg_end}.mp4"
            for target in (slow_path, fast_path):
                extract_video_segment(video_file_path, seg_start, seg_end, target)

            entries.append(
                create_annotation_entry(behavior_label, slow_path, fast_path)
            )
        if len(entries) >= max_count:
            break
    return entries


def process_video_file(csv_path, video_file_path, gap_duration=5):
    """Extract labeled clips and a capped number of "Others" clips from one video.

    Each "state start" event is paired with the first later "state stop" of
    the same behavior. That span is written twice (slow and fast files) and an
    annotation entry is created for it. Unannotated gaps of at least
    ``gap_duration`` seconds are then sampled as "Others". The number of
    "Others" clips is capped at the count of the most frequent labeled
    behavior in this video.

    Args:
        csv_path: Behavior annotation CSV for the video.
        video_file_path: The video file the CSV annotates.
        gap_duration: Length in seconds of each "Others" clip.

    Returns:
        Labeled entries followed by "Others" entries.
    """
    start_events, stop_events = parse_behavior_events(csv_path)
    video_name = os.path.splitext(os.path.basename(video_file_path))[0].replace(
        " ", "_"
    )
    labeled_entries = []
    behavior_counts = defaultdict(int)

    # Pair on a copy. find_gaps() below needs the complete stop list to know
    # where each behavior ends. Previously the depleted list was passed, so
    # find_gaps could not pair any start with a stop and produced overlapping
    # (0, start) "gaps" that covered labeled behaviors.
    unmatched_stops = list(stop_events)

    for start_event in start_events:
        start_time = start_event["time"]
        behavior = start_event["behavior"]
        matching_stop = next(
            (
                stop
                for stop in unmatched_stops
                if stop["behavior"] == behavior and stop["time"] > start_time
            ),
            None,
        )
        if matching_stop is None:
            continue  # start without a stop: no complete segment to extract

        end_time = matching_stop["time"]
        unmatched_stops.remove(matching_stop)

        # Create paths for slow and fast segments (same temporal span)
        behavior_slug = behavior.replace(" ", "_")
        segment_path_slow = f"{OUTPUT_VIDEO_FOLDER}/{video_name}_{behavior_slug}_slow_{start_time}-{end_time}.mp4"
        segment_path_fast = f"{OUTPUT_VIDEO_FOLDER}/{video_name}_{behavior_slug}_fast_{start_time}-{end_time}.mp4"

        extract_video_segment(video_file_path, start_time, end_time, segment_path_slow)
        extract_video_segment(video_file_path, start_time, end_time, segment_path_fast)

        labeled_entries.append(
            create_annotation_entry(behavior, segment_path_slow, segment_path_fast)
        )
        behavior_counts[behavior] += 1

    # Keep "Others" from dominating: at most as many as the largest labeled class.
    max_behavior_count = max(behavior_counts.values(), default=0)

    with VideoFileClip(video_file_path) as video:
        video_duration = video.duration
    # Full, unconsumed stop list (see note above).
    gaps = find_gaps(start_events, stop_events, video_duration, gap_duration)
    other_entries = sample_limited_segments_from_gaps(
        gaps, video_file_path, video_name, gap_duration, max_behavior_count
    )

    return labeled_entries + other_entries


def _interleave_groups(groups):
    """Round-robin merge lists: first item of every group, then the second, and so on."""
    fill = object()  # private sentinel so no real entry is mistaken for padding
    return [
        item
        for round_items in itertools.zip_longest(*groups, fillvalue=fill)
        for item in round_items
        if item is not fill
    ]


def _write_jsonl(path, rows):
    """Write ``rows`` to ``path`` as one JSON object per line."""
    with open(path, "w") as f:
        f.writelines(json.dumps(row) + "\n" for row in rows)


def stratified_interleaved_split_and_save_annotations(
    entries, train_path, test_path, train_ratio=0.97, seed=None
):
    """Split annotations per behavior into train/test, interleave, and save as JSONL.

    Entries are grouped by their "response" (behavior) label and each group is
    shuffled. The first ``int(len(group) * train_ratio)`` entries of a group go
    to train and the rest to test. Each split is then interleaved round-robin
    across behaviors, in order of first appearance, so classes alternate in
    the output files.

    Args:
        entries: Annotation dicts with a "response" key.
        train_path: Output JSONL path for the training split.
        test_path: Output JSONL path for the test split.
        train_ratio: Fraction of each behavior's entries used for training.
        seed: Optional seed for a private ``random.Random`` to make the split
            reproducible. When ``None`` (default) the global ``random`` module
            is used, as before.

    Returns:
        ``(train_entries, test_entries)`` exactly as written to disk.
    """
    shuffle = random.shuffle if seed is None else random.Random(seed).shuffle

    behavior_groups = defaultdict(list)
    for entry in entries:
        behavior_groups[entry["response"]].append(entry)

    train_groups, test_groups = [], []
    for group in behavior_groups.values():
        shuffle(group)
        split_index = int(len(group) * train_ratio)
        train_groups.append(group[:split_index])
        test_groups.append(group[split_index:])

    train_entries = _interleave_groups(train_groups)
    test_entries = _interleave_groups(test_groups)

    _write_jsonl(train_path, train_entries)
    _write_jsonl(test_path, test_entries)
    return train_entries, test_entries


def process_dataset():
    """Build stratified, interleaved train/test JSONL files from BASE_FOLDER.

    Every ``<BASE_FOLDER>/<subdir>/<name>.csv`` that has a sibling
    ``<name>.mpg`` video is processed with ``process_video_file``. CSVs
    without a matching video are reported and skipped. The collected entries
    are split and written to TRAIN_JSONL_PATH / TEST_JSONL_PATH.
    """
    all_entries = []
    # sorted(): os.listdir order is arbitrary, so sort for a reproducible entry order.
    for subdir in sorted(os.listdir(BASE_FOLDER)):
        subdir_path = os.path.join(BASE_FOLDER, subdir)
        if not os.path.isdir(subdir_path):
            continue
        for file_name in sorted(os.listdir(subdir_path)):
            if not file_name.endswith(".csv"):
                continue
            csv_path = os.path.join(subdir_path, file_name)
            # Swap only the extension; str.replace would also rewrite any
            # other ".csv" occurrence elsewhere in the path.
            video_file_path = os.path.splitext(csv_path)[0] + ".mpg"
            if not os.path.exists(video_file_path):
                print(f"Skipping {csv_path}: video not found at {video_file_path}")
                continue
            all_entries.extend(process_video_file(csv_path, video_file_path))

    stratified_interleaved_split_and_save_annotations(
        all_entries, TRAIN_JSONL_PATH, TEST_JSONL_PATH, TRAIN_SPLIT_RATIO
    )
    print(
        "Conversion complete. Stratified interleaved training and testing datasets created for SlowFast."
    )


# In a notebook kernel __name__ is "__main__", so running this cell builds the
# clip dataset and the train/test JSONL files.
if __name__ == "__main__":
    process_dataset()

In [None]:
# Create an id to label name mapping
behaviors_id_to_classname = {}
for i, v in enumerate(behaviors):
    behaviors_id_to_classname[i] = v

In [None]:
# Display the index -> behavior-name mapping used to decode predictions.
behaviors_id_to_classname

In [None]:
# Number of classes in the fine-tuning dataset (one output per behavior label).
num_classes = len(behaviors)

# Replace the pretrained (Kinetics) classification layer: the last block's
# `proj` linear layer is swapped for one with `num_classes` outputs and the
# same number of input features. The new layer is placed on the model's
# current device, because the model may already have been moved to the GPU.
model_device = next(model.parameters()).device
classifier_head = model.blocks[-1]
classifier_head.proj = nn.Linear(classifier_head.proj.in_features, num_classes).to(
    model_device
)

#### Define the dataset, input transforms, and data loaders

In [None]:
import torch.nn as nn
from torch.hub import load
from torch.utils.data import Dataset, DataLoader
import os
import torch.optim as optim

In [None]:
def collate_fn(batch):
    """Collate ``([slow, fast], label)`` samples into a batched SlowFast input.

    ``None`` samples (items the dataset failed to load) are dropped.

    Returns:
        ``([slow_batch, fast_batch], labels)``, where each pathway tensor
        stacks the per-sample tensors along a new leading batch dimension.
        If every sample was dropped, returns ``(None, None)``. A pair is
        returned instead of a bare ``None`` so callers can always write
        ``inputs, labels = batch`` and then skip when ``inputs is None``.
    """
    samples = [item for item in batch if item is not None]
    if not samples:
        return None, None

    labels_batch = torch.stack([item[1] for item in samples])
    # Each item[0] is [slow_tensor, fast_tensor]; stack each pathway separately.
    slow_pathway_batch = torch.stack([item[0][0] for item in samples])
    fast_pathway_batch = torch.stack([item[0][1] for item in samples])

    return [slow_pathway_batch, fast_pathway_batch], labels_batch

In [None]:
class PackPathway(torch.nn.Module):
    """
    Split a clip into the ``[slow, fast]`` input list expected by SlowFast.

    The fast pathway is the input clip unchanged. The slow pathway keeps
    ``T // alpha`` frames, sampled uniformly with ``torch.linspace`` along
    dim 1 (the time axis of a ``(C, T, H, W)`` clip). At least one frame is
    kept for non-empty clips shorter than ``alpha``.

    Args:
        alpha: Frame-rate ratio between the fast and slow pathways. The
            default of 4 matches the original hard-coded value.
    """

    def __init__(self, alpha: int = 4):
        super().__init__()
        if alpha < 1:
            raise ValueError(f"alpha must be >= 1, got {alpha}")
        self.alpha = alpha

    def forward(self, frames: torch.Tensor):
        fast_pathway = frames
        num_frames = frames.shape[1]
        # max(1, ...) avoids an empty slow pathway for very short clips.
        num_slow = max(1, num_frames // self.alpha) if num_frames else 0
        # Build the index tensor on the input's device so index_select also
        # works on GPU tensors.
        slow_index = torch.linspace(
            0, num_frames - 1, num_slow, device=frames.device
        ).long()
        slow_pathway = torch.index_select(frames, 1, slow_index)
        return [slow_pathway, fast_pathway]


class SlowFastDataset(Dataset):
    """Dataset yielding ``([slow_clip, fast_clip], label)`` from a JSONL annotation file.

    Each JSONL record is expected to have "response" (the behavior label)
    and "videos" (``[slow_path, fast_path]``), as written by
    ``create_annotation_entry``.

    Args:
        annotations_file: Path to the JSONL annotation file.
        transform: Optional dict with "video_slow" and "video_fast" callables
            applied to the decoded clip tensors.
        clip_duration: Seconds decoded from the start of each clip.
        class_names: List of class names. The label is the index of the
            record's "response" in this list.

    ``__getitem__`` returns ``None`` for items whose videos cannot be opened,
    decoded or transformed. Pair it with a collate function that drops ``None``.
    """

    def __init__(
        self, annotations_file, transform=None, clip_duration=None, class_names=None
    ):
        # Read everything up front inside `with` so the file handle is closed.
        # Blank lines are skipped.
        with open(annotations_file, "r") as annotation_lines:
            self.annotations = [
                json.loads(line) for line in annotation_lines if line.strip()
            ]
        self.transform = transform
        self.clip_duration = clip_duration
        self.class_names = class_names  # label name -> index via list position

    def __len__(self):
        return len(self.annotations)

    @staticmethod
    def _close_video(video):
        """Release a decoder from EncodedVideo.from_path if it exposes close()."""
        close = getattr(video, "close", None)
        if callable(close):
            close()

    def __getitem__(self, idx):
        """Return ``([slow_clip, fast_clip], label)`` for item ``idx``, or ``None`` on failure."""
        annotation = self.annotations[idx]
        video_path_slow = annotation["videos"][0]
        video_path_fast = annotation["videos"][1]
        behavior_label = annotation["response"]

        # A label missing from class_names raises ValueError. Without
        # class_names the raw label is passed to torch.tensor, which cannot
        # build a tensor from a string, so class_names is effectively required.
        label = (
            self.class_names.index(behavior_label)
            if self.class_names
            else behavior_label
        )
        label = torch.tensor(label)

        video_slow = video_fast = None
        try:
            try:
                video_slow = EncodedVideo.from_path(video_path_slow)
                video_fast = EncodedVideo.from_path(video_path_fast)
            except Exception as e:
                print(f"Error loading video: {e}")
                return None  # Skip this item

            start_sec = 0
            end_sec = self.clip_duration
            clip_slow_data = video_slow.get_clip(start_sec=start_sec, end_sec=end_sec)
            clip_fast_data = video_fast.get_clip(start_sec=start_sec, end_sec=end_sec)
        finally:
            # get_clip has already returned its tensors, so release the
            # decoders (and their file handles) instead of leaking one pair
            # per item.
            for video in (video_slow, video_fast):
                if video is not None:
                    self._close_video(video)

        # Check the "video" entry too: the decoder can return a dict whose
        # "video" is None when no frames fall inside the requested window.
        if (
            clip_slow_data is None
            or clip_fast_data is None
            or clip_slow_data.get("video") is None
            or clip_fast_data.get("video") is None
        ):
            print(f"Error extracting clips: {video_path_slow} or {video_path_fast}")
            return None  # Skip this item

        if not self.transform:
            return [clip_slow_data["video"], clip_fast_data["video"]], label

        try:
            frames_slow = self.transform["video_slow"](clip_slow_data["video"])
            frames_fast = self.transform["video_fast"](clip_fast_data["video"])
        except Exception as e:
            print(f"Error in transforms: {e}")
            return None  # Skip this item

        return [frames_slow, frames_fast], label


# Define your transforms
side_size = 256
mean = [0.45, 0.45, 0.45]
std = [0.225, 0.225, 0.225]
crop_size = 256
num_frames = 32
sampling_rate = 2
frames_per_second = 30
# Seconds decoded per clip: num_frames frames taken every `sampling_rate`
# frames at `frames_per_second` fps.
clip_duration = (num_frames * sampling_rate) / frames_per_second
# Fast/slow frame ratio: the slow pathway gets num_frames // slowfast_alpha frames.
slowfast_alpha = 4


def make_video_transform(num_output_frames):
    """Return the clip preprocessing pipeline that yields ``num_output_frames`` frames.

    Steps: uniform temporal subsampling, division by 255, mean/std
    normalization, resizing the short side to ``side_size``, and a
    ``crop_size`` x ``crop_size`` center crop.
    """
    return Compose(
        [
            UniformTemporalSubsample(num_output_frames),
            Lambda(lambda x: x / 255.0),
            NormalizeVideo(mean, std),
            ShortSideScale(size=side_size),
            CenterCropVideo(crop_size=(crop_size, crop_size)),
        ]
    )


transform_slow = make_video_transform(num_frames // slowfast_alpha)
transform_fast = make_video_transform(num_frames)

train_transform = {"video_slow": transform_slow, "video_fast": transform_fast}
val_transform = {"video_slow": transform_slow, "video_fast": transform_fast}

# Class order (and therefore label indices) comes from `behaviors`.
class_names = behaviors

# Instantiate the datasets from the JSONL files written by process_dataset()
# (the same path constants are used there).
train_dataset = SlowFastDataset(
    annotations_file=TRAIN_JSONL_PATH,
    transform=train_transform,
    clip_duration=clip_duration,
    class_names=class_names,
)

val_dataset = SlowFastDataset(
    annotations_file=TEST_JSONL_PATH,
    transform=val_transform,
    clip_duration=clip_duration,
    class_names=class_names,
)

In [None]:
# Create data loaders
batch_size = 4
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)

In [None]:
# Define loss function and optimizer
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        if inputs is None:  # Handle cases where collate_fn returns None
            continue

        # Move the list of input tensors to the device
        inputs = [inp.to(device) for inp in inputs]
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 10 == 0:
            print(
                f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {running_loss / 10:.4f}"
            )
            running_loss = 0.0

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = [inp.to(device) for inp in inputs]
            labels = labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Validation Accuracy: {100 * correct / total:.2f}%"
    )

In [None]:
print("Finished Training")
# Persist only the weights (state_dict). To reuse them, rebuild the model with
# the same replaced classification layer before calling load_state_dict.
torch.save(obj=model.state_dict(), f="fine_tuned_slowfast_model.pth")

In [None]:
# Inference preprocessing: the same steps as the fast-pathway training
# transform, followed by PackPathway, which derives the slow pathway and
# returns the [slow, fast] list.
inference_steps = [
    UniformTemporalSubsample(num_frames),
    Lambda(lambda x: x / 255.0),
    NormalizeVideo(mean, std),
    ShortSideScale(size=side_size),
    CenterCropVideo(crop_size),
    PackPathway(),
]
transform = ApplyTransformToKey(key="video", transform=Compose(inference_steps))

# The duration of the input clip is also specific to the model.
clip_duration = num_frames * sampling_rate / frames_per_second

#### Run Inference

Set the path to the video you want to classify (this notebook does not download one).

In [None]:
# Path of the video to classify; replace with your own file (not downloaded by this notebook).
video_path = "my_example_long_video.mp4"

Load the video and transform it to the input format required by the model.

In [None]:
# Window to classify; move start_sec to where the behavior of interest occurs.
start_sec = 0
end_sec = start_sec + clip_duration

# Open the video and decode just that window
video = EncodedVideo.from_path(video_path)
video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)

# Preprocess into the [slow, fast] pathway list
video_data = transform(video_data)

# Add a leading batch dimension of size 1 and move each pathway to the device
inputs = [pathway.to(device).unsqueeze(0) for pathway in video_data["video"]]

#### Get Predictions

In [None]:
# Classify the clip. eval() + no_grad(): this is inference only, so dropout
# is disabled and no autograd graph (and its memory) is built.
model.eval()
with torch.no_grad():
    scores = model(inputs)

# Convert the raw scores to class probabilities
post_act = torch.nn.Softmax(dim=1)
preds = post_act(scores)

# Top-k classes; k is capped at the number of classes so topk() cannot fail
# on datasets with fewer than 5 behaviors.
top_k = min(5, preds.shape[1])
pred_classes = preds.topk(k=top_k).indices[0]

# Map the predicted class indices to behavior names
pred_class_names = [behaviors_id_to_classname[int(i)] for i in pred_classes]
print(f"Top {top_k} predicted labels: {', '.join(pred_class_names)}")

### Model Description
SlowFast model architectures are based on [1] with pretrained weights using the 8x8 setting
on the Kinetics dataset.

| arch | depth | frame length x sample rate | top 1 | top 5 | Flops (G) | Params (M) |
| --------------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- |
| SlowFast | R50   | 8x8                        | 76.94 | 92.69 | 65.71     | 34.57      |
| SlowFast | R101  | 8x8                        | 77.90 | 93.27 | 127.20    | 62.83      |


### References
[1] Christoph Feichtenhofer et al, "SlowFast Networks for Video Recognition"
https://arxiv.org/pdf/1812.03982.pdf