# Sport video classification evaluation
Reach out to lukasgeiger@google.com for any questions. 

In this notebook, we evaluate Gemini's ability to classify play types in video.
Our goal is to have Gemini output start_time, end_time, and play_type (touchdown, kickoff, rush, pass, etc.) for each play in a given section of video.

Experiment setup: 
- Use Intersection over Union (IoU, also known as the Jaccard index) to measure the overlap between the ground-truth start/end timestamps and the LLM's start/end timestamps. The IoU threshold was set to 0.3 for all experiments. Read more here: https://en.wikipedia.org/wiki/Jaccard_index
- Run the tests on a human-labeled 10-minute video.
- For every prediction that meets the IoU threshold, compare the ground-truth play_type to the LLM's play_type. Record true positives, false negatives, and false positives. We do not track true negatives because they are not meaningful for this use case (would every second without a play count as one?).
- Make 5 separate API calls to Gemini to get a reasonable sample size.
- Calculate accuracy, precision, and recall from TP, FN, and FP. Because there are no true negatives, accuracy is modified to TP / (TP + FP + FN).
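As a quick illustration of the IoU check, here is a minimal, self-contained sketch. The helper names (mmss_to_seconds, interval_iou) are illustrative only; the notebook's own helpers are defined further down. The intervals come from one of the basic-prompt matches reported below: a predicted pass at 05:00 to 05:07 against the ground-truth pass at 04:56 to 05:08.

```python
def mmss_to_seconds(mmss: str) -> int:
    """Convert an MM:SS timestamp to seconds."""
    minutes, seconds = map(int, mmss.split(":"))
    return minutes * 60 + seconds


def interval_iou(pred: tuple, truth: tuple) -> float:
    """IoU (Jaccard index) of two (start, end) intervals given as MM:SS strings."""
    s1, e1 = map(mmss_to_seconds, pred)
    s2, e2 = map(mmss_to_seconds, truth)
    intersection = max(0, min(e1, e2) - max(s1, s2))
    union = (e1 - s1) + (e2 - s2) - intersection
    return intersection / union if union > 0 else 0.0


# Predicted 05:00-05:07 vs. ground truth 04:56-05:08:
# intersection = 7 s (05:00-05:07), union = 7 + 12 - 7 = 12 s
iou = interval_iou(("05:00", "05:07"), ("04:56", "05:08"))
print(f"IoU = {iou:.4f}")    # IoU = 0.5833
print("match:", iou >= 0.3)  # match: True
```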

We ran multiple experiments using prompt engineering and video modification (oversampling), all with Gemini 1.5 Pro 002. Here are some of the things we tried:
1. Basic Prompt
2. Adding NFL rulebook to context 
3. Two sequential Gemini calls: first to extract timestamps, then to ID plays
4. Adding descriptions of each play type 
5. Slow down video 
6. Speed up video 
7. Sample video at 1 fps 
8. Oversample video at 1 fps. Oversampling means first sampling the video at 1 fps and then duplicating each frame, which increases the amount of time the model sees each frame.
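The oversampling idea in item 8 can be sketched in terms of frame indices. This is a minimal illustration, not the video pipeline used in these experiments (which is not shown here); the function names and the 30 fps source rate are assumptions. It also highlights a side effect worth checking: if the duplicated frames are encoded at the sampling rate (1 fps), the clip becomes `repeat` times longer, so model timestamps would need to be scaled back before comparing them with the ground truth.

```python
def oversample_frame_indices(num_source_frames: int, source_fps: float,
                             sample_fps: float = 1.0, repeat: int = 2) -> list[int]:
    """Pick one source frame every 1/sample_fps seconds, then repeat each pick `repeat` times."""
    step = source_fps / sample_fps                  # source frames between consecutive samples
    num_samples = int(num_source_frames // step)    # whole samples that fit in the clip
    sampled = [int(i * step) for i in range(num_samples)]
    return [idx for idx in sampled for _ in range(repeat)]


def oversampled_to_source_seconds(t: float, repeat: int = 2) -> float:
    """Map a timestamp in the oversampled clip back to the source clip.

    Assumption: the oversampled frames are encoded at sample_fps, so each
    source second occupies `repeat` seconds of the output clip.
    """
    return t / repeat


# 3 seconds of a hypothetical 30 fps source: frames 0, 30, 60, each shown twice
print(oversample_frame_indices(90, source_fps=30))  # [0, 0, 30, 30, 60, 60]
# 10:00 in the lengthened clip corresponds to 05:00 in the source
print(oversampled_to_source_seconds(600))           # 300.0
```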

### Setup

In [4]:
!pip install -q --upgrade google-cloud-aiplatform openpyxl ffmpeg-python

In [1]:
# Define ground truth for the 10-minute segment

# ground_truth_old contains contiguous segments: each start timestamp equals the previous segment's end timestamp.
# ground_truth (below) is the refined version used for evaluation. It does not cover the video contiguously, because games contain replays and the ideal output excludes them.

# Human labeled

ground_truth_old = [
{"end_time": "00:05", "play_type": "non_play", "start_time": "00:00"},
{"end_time": "00:31", "play_type": "non_play", "start_time": "00:05"},
{"end_time": "00:42", "play_type": "non_play", "start_time": "00:31"},
{"end_time": "00:46", "play_type": "non_play", "start_time": "00:42"},
{"end_time": "01:16", "play_type": "non_play", "start_time": "00:46"},
{"end_time": "01:46", "play_type": "non_play", "start_time": "01:16"},
{"end_time": "02:16", "play_type": "non_play", "start_time": "01:46"},
{"end_time": "02:29", "play_type": "non_play", "start_time": "02:16"},
{"end_time": "02:55", "play_type": "non_play", "start_time": "02:29"},
{"end_time": "03:32", "play_type": "non_play", "start_time": "02:55"},
{"end_time": "03:53", "play_type": "kickoff", "start_time": "03:32"},
{"end_time": "03:58", "play_type": "kickoff_return", "start_time": "03:53"},
{"end_time": "04:56", "play_type": "penalty", "start_time": "03:58"},
{"end_time": "05:34", "play_type": "pass", "start_time": "04:56"},
{"end_time": "06:15", "play_type": "pass", "start_time": "05:34"},
{"end_time": "06:48", "play_type": "pass", "start_time": "06:15"},
{"end_time": "07:20", "play_type": "penalty", "start_time": "06:48"},
{"end_time": "07:29", "play_type": "punt", "start_time": "07:20"},
{"end_time": "07:35", "play_type": "punt_return", "start_time": "07:29"},
{"end_time": "07:59", "play_type": "penalty", "start_time": "07:35"},
{"end_time": "08:49", "play_type": "non_play", "start_time": "07:59"},
{"end_time": "09:30", "play_type": "handoff", "start_time": "08:49"},
{"end_time": "10:10", "play_type": "pass", "start_time": "09:30"},
{"end_time": "10:37", "play_type": "pass", "start_time": "10:10"}
]

ground_truth = [
{"end_time": "00:05", "play_type": "non_play", "start_time": "00:00"},
{"end_time": "00:31", "play_type": "non_play", "start_time": "00:05"},
{"end_time": "00:42", "play_type": "non_play", "start_time": "00:31"},
{"end_time": "00:46", "play_type": "non_play", "start_time": "00:42"},
{"end_time": "01:16", "play_type": "non_play", "start_time": "00:46"},
{"end_time": "01:46", "play_type": "non_play", "start_time": "01:16"},
{"end_time": "02:16", "play_type": "non_play", "start_time": "01:46"},
{"end_time": "02:29", "play_type": "non_play", "start_time": "02:16"},
{"end_time": "02:55", "play_type": "non_play", "start_time": "02:29"},
{"end_time": "03:32", "play_type": "non_play", "start_time": "02:55"},
{"end_time": "03:53", "play_type": "kickoff", "start_time": "03:32"},
{"end_time": "03:58", "play_type": "kickoff_return", "start_time": "03:53"},
{"end_time": "04:22", "play_type": "penalty", "start_time": "04:10"},
{"end_time": "05:08", "play_type": "pass", "start_time": "04:56"},
{"end_time": "05:43", "play_type": "pass", "start_time": "05:34"},
{"end_time": "06:26", "play_type": "pass", "start_time": "06:09"},
{"end_time": "06:56", "play_type": "penalty", "start_time": "06:48"},
{"end_time": "07:29", "play_type": "punt", "start_time": "07:20"},
{"end_time": "07:35", "play_type": "punt_return", "start_time": "07:29"},
{"end_time": "07:59", "play_type": "penalty", "start_time": "07:48"},
{"end_time": "08:49", "play_type": "non_play", "start_time": "07:59"},
{"end_time": "08:59", "play_type": "handoff", "start_time": "08:47"}, # technically handoff is more specific but could also be considered a rush
{"end_time": "09:37", "play_type": "pass", "start_time": "09:25"},
{"end_time": "10:19", "play_type": "pass", "start_time": "10:10"}
]
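Because the ideal output excludes replays, ground_truth does not tile the video contiguously, so it can be worth checking where consecutive segments overlap or leave gaps before scoring against it. A minimal sketch follows; the helper names are illustrative, and the excerpt is copied from ground_truth above. Calling find_overlaps_and_gaps(ground_truth) on the full list works the same way.

```python
def mmss_to_seconds(mmss: str) -> int:
    minutes, seconds = map(int, mmss.split(":"))
    return minutes * 60 + seconds


def find_overlaps_and_gaps(segments: list[dict]):
    """Compare each segment's end with the next segment's start.

    Assumes segments are sorted by start_time. Returns two lists of
    (previous end_time, next start_time, seconds) tuples.
    """
    overlaps, gaps = [], []
    for prev, nxt in zip(segments, segments[1:]):
        prev_end = mmss_to_seconds(prev["end_time"])
        next_start = mmss_to_seconds(nxt["start_time"])
        if next_start < prev_end:
            overlaps.append((prev["end_time"], nxt["start_time"], prev_end - next_start))
        elif next_start > prev_end:
            gaps.append((prev["end_time"], nxt["start_time"], next_start - prev_end))
    return overlaps, gaps


# Excerpt of ground_truth (07:48 to 09:37)
excerpt = [
    {"start_time": "07:48", "end_time": "07:59", "play_type": "penalty"},
    {"start_time": "07:59", "end_time": "08:49", "play_type": "non_play"},
    {"start_time": "08:47", "end_time": "08:59", "play_type": "handoff"},
    {"start_time": "09:25", "end_time": "09:37", "play_type": "pass"},
]
overlaps, gaps = find_overlaps_and_gaps(excerpt)
print(overlaps)  # [('08:49', '08:47', 2)]
print(gaps)      # [('08:59', '09:25', 26)]
```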

### Helper functions for evaluation tests

In [2]:
IOU_SCORE = 0.3
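To get a feel for how lenient a threshold of 0.3 is, consider a prediction with the same length L as the ground-truth play, shifted by s seconds (s < L). The intersection is L - s and the union is L + s, so IoU = (L - s) / (L + s), and the prediction still matches while s <= L(1 - t) / (1 + t) for threshold t. The sketch below evaluates this for a 12-second play (the length of the 04:56 to 05:08 pass in ground_truth); the function names are illustrative.

```python
def max_shift_seconds(play_length: float, threshold: float) -> float:
    """Largest shift s of a same-length prediction that keeps IoU >= threshold.

    From (L - s) / (L + s) >= t  =>  s <= L * (1 - t) / (1 + t).
    """
    return play_length * (1 - threshold) / (1 + threshold)


def shifted_iou(play_length: float, shift: float) -> float:
    """IoU of [0, L] and [s, L + s] for 0 <= s <= L."""
    return (play_length - shift) / (play_length + shift)


# A threshold of 0.3 tolerates a shift of about 6.46 s on a 12 s play
for t in (0.3, 0.5, 0.7):
    s = max_shift_seconds(12, t)
    print(f"threshold {t}: shift up to {s:.2f} s (IoU at that shift = {shifted_iou(12, s):.2f})")
```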

In [3]:
import pandas as pd
import json
import datetime

def time_to_seconds(time_str):
    """Converts a time string in the format MM:SS to seconds."""
    if not time_str:
        return 0
    minutes, seconds = map(int, time_str.split(':'))
    return minutes * 60 + seconds

def calculate_iou(start_time1, end_time1, start_time2, end_time2):
    """Calculates the Intersection over Union (IoU) of two time intervals."""
    # IoU = (Intersection of the two intervals) / (Union of the two intervals)
    # An IoU of 0 means there's no overlap at all.
    # An IoU of 1 means the two intervals are identical.

    start1 = time_to_seconds(start_time1)
    end1 = time_to_seconds(end_time1)
    start2 = time_to_seconds(start_time2)
    end2 = time_to_seconds(end_time2)

    # Determine the intersection
    intersection_start = max(start1, start2)
    intersection_end = min(end1, end2)

    if intersection_end <= intersection_start:
        return 0.0  # No overlap

    intersection = intersection_end - intersection_start

    # Determine the union
    union = (end1 - start1) + (end2 - start2) - intersection

    return intersection / union if union > 0 else 0.0

def evaluate_play_extraction(llm_response, grounded_truth):
    """
    Evaluates the performance of an LLM in extracting plays from American football videos.

    Args:
        llm_response: A list of dictionaries, where each dictionary represents a play
                      extracted by the LLM and contains 'start_time', 'end_time', and 'play_type'.
        grounded_truth: A list of dictionaries, where each dictionary represents the
                       ground truth play and contains 'start_time', 'end_time', and 'play_type'.

    Returns:
        A dictionary containing the accuracy, precision, recall, and tables of correctly
        and incorrectly identified plays.

    Note:
        True Positive:
        - An LLM-predicted play has an IoU (Intersection over Union) greater than or equal to IOU_SCORE with a ground truth play.
        - The play_type predicted by the LLM matches the play_type of the corresponding ground truth play.
        False Positive:
        - An LLM-predicted play has an IoU greater than or equal to IOU_SCORE with a ground truth play, but the play_type is incorrect.
        - An LLM-predicted play has no corresponding ground truth play with an IoU greater than or equal to IOU_SCORE (meaning it's a spurious prediction).
        False Negative:
        - A ground truth play that was not the best match (with an IoU greater than or equal to IOU_SCORE) of any LLM-predicted play (meaning the LLM missed a play). A match with the wrong play_type is already counted as a false positive, so it does not also count as a false negative.
        True Negative:
        - Doesn't make sense for the use case. What would you count as a true negative? Every second that isn't identified?
    """
    # Exclude non_play types
    llm_response_filtered = [play for play in llm_response if play['play_type'] != "non_play"]
    grounded_truth_filtered = [play for play in grounded_truth if play['play_type'] != "non_play"]

    # Ignore True Negatives because they don't make sense for a video stream
    # i.e. do you count all the seconds that aren't identified as true negatives?
    true_positives = 0
    false_positives = 0
    false_negatives = 0

    correct_plays = []
    incorrect_plays = []

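    # One flag per scoreable (non-non_play) ground truth play, indexed like grounded_truth_filtered
    matched_gt = [False] * len(grounded_truth_filtered)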

    for llm_play in llm_response_filtered:
        best_iou = 0
        best_gt_match = -1

        for i, gt_play in enumerate(grounded_truth_filtered):
            iou = calculate_iou(llm_play['start_time'], llm_play['end_time'], gt_play['start_time'], gt_play['end_time'])

            if iou > best_iou:
                best_iou = iou
                best_gt_match = i

        if best_iou >= IOU_SCORE:
            if llm_play['play_type'] == grounded_truth_filtered[best_gt_match]['play_type']:
                true_positives += 1
                correct_plays.append({
                    'LLM Start Time': llm_play['start_time'],
                    'LLM End Time': llm_play['end_time'],
                    'LLM Play Type': llm_play['play_type'],
                    'GT Start Time': grounded_truth_filtered[best_gt_match]['start_time'],
                    'GT End Time': grounded_truth_filtered[best_gt_match]['end_time'],
                    'GT Play Type': grounded_truth_filtered[best_gt_match]['play_type'],
                    'IoU': best_iou
                })
            else:
                false_positives += 1
                incorrect_plays.append({
                    'LLM Start Time': llm_play['start_time'],
                    'LLM End Time': llm_play['end_time'],
                    'LLM Play Type': llm_play['play_type'],
                    'GT Start Time': grounded_truth_filtered[best_gt_match]['start_time'],
                    'GT End Time': grounded_truth_filtered[best_gt_match]['end_time'],
                    'GT Play Type': grounded_truth_filtered[best_gt_match]['play_type'],
                    'IoU': best_iou
                })
            matched_gt[best_gt_match] = True

        elif best_gt_match != -1:
            false_positives +=1
            incorrect_plays.append({
                    'LLM Start Time': llm_play['start_time'],
                    'LLM End Time': llm_play['end_time'],
                    'LLM Play Type': llm_play['play_type'],
                    'GT Start Time': grounded_truth_filtered[best_gt_match]['start_time'],
                    'GT End Time': grounded_truth_filtered[best_gt_match]['end_time'],
                    'GT Play Type': grounded_truth_filtered[best_gt_match]['play_type'],
                    'IoU': best_iou
                })
        else:
            false_positives +=1
            incorrect_plays.append({
                    'LLM Start Time': llm_play['start_time'],
                    'LLM End Time': llm_play['end_time'],
                    'LLM Play Type': llm_play['play_type'],
                    'GT Start Time': "N/A",
                    'GT End Time': "N/A",
                    'GT Play Type': "N/A",
                    'IoU': 0
                })

    for i in range(len(grounded_truth_filtered)):
        if not matched_gt[i]:
            false_negatives += 1
            incorrect_plays.append({
                'LLM Start Time': 'N/A',
                'LLM End Time': 'N/A',
                'LLM Play Type': 'N/A',
                'GT Start Time': grounded_truth_filtered[i]['start_time'],
                'GT End Time': grounded_truth_filtered[i]['end_time'],
                'GT Play Type': grounded_truth_filtered[i]['play_type'],
                'IoU': 0
            })
    # Accuracy: TP / (TP + FP + FN)
    accuracy = true_positives / (true_positives + false_positives + false_negatives) if (true_positives + false_positives + false_negatives) > 0 else 0.0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0.0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0.0

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'true_positives': true_positives,
        'false_positives': false_positives,
        'false_negatives': false_negatives,
        'correct_plays': pd.DataFrame(correct_plays),
        'incorrect_plays': pd.DataFrame(incorrect_plays)
    }

def aggregate_results(all_results):
    """
    Aggregates the results from multiple evaluations, including the DataFrames.

    Args:
        all_results: A list of dictionaries, where each dictionary is the output
                     from the evaluate_play_extraction function.

    Returns:
        A dictionary containing the average accuracy, precision, recall,
        total true positives, false positives, false negatives,
        and concatenated DataFrames of correct and incorrect plays.
    """

    num_runs = len(all_results)
    total_accuracy = 0
    total_precision = 0
    total_recall = 0
    total_true_positives = 0
    total_false_positives = 0
    total_false_negatives = 0

    all_correct_plays = []
    all_incorrect_plays = []

    for result in all_results:
        total_accuracy += result['accuracy']
        total_precision += result['precision']
        total_recall += result['recall']
        total_true_positives += result['true_positives']
        total_false_positives += result['false_positives']
        total_false_negatives += result['false_negatives']
        all_correct_plays.append(result['correct_plays'])
        all_incorrect_plays.append(result['incorrect_plays'])

    # Concatenate all the correct and incorrect plays DataFrames
    aggregated_correct_plays = pd.concat(all_correct_plays, ignore_index=True)
    aggregated_incorrect_plays = pd.concat(all_incorrect_plays, ignore_index=True)

    return {
        'average_accuracy': total_accuracy / num_runs,
        'average_precision': total_precision / num_runs,
        'average_recall': total_recall / num_runs,
        'total_true_positives': total_true_positives,
        'total_false_positives': total_false_positives,
        'total_false_negatives': total_false_negatives,
        'aggregated_correct_plays': aggregated_correct_plays,
        'aggregated_incorrect_plays': aggregated_incorrect_plays
    }

def save_results_to_csv(aggregated_results, filename_prefix="aggregate_results"):
    """
    Saves the aggregated results to a CSV file with a timestamp in the filename.

    Args:
        aggregate_results: The output dictionary from the aggregate_results function.
        filename_prefix: The prefix of the CSV file name.
    """

    # Get the current timestamp
    now = datetime.datetime.now()
    timestamp_str = now.strftime("%Y%m%d_%H%M%S")

    # Construct the filename with the timestamp
    filename = f"{filename_prefix}_{timestamp_str}.xlsx"

    # Create a DataFrame for the summary statistics
    summary_data = {
        'Metric': ['Average Accuracy', 'Average Precision', 'Average Recall',
                   'Total True Positives', 'Total False Positives', 'Total False Negatives'],
        'Value': [aggregated_results['average_accuracy'], aggregated_results['average_precision'],
                  aggregated_results['average_recall'], aggregated_results['total_true_positives'],
                  aggregated_results['total_false_positives'], aggregated_results['total_false_negatives']]
    }
    summary_df = pd.DataFrame(summary_data)

    # Write each DataFrame to its own sheet in the Excel workbook
    with pd.ExcelWriter(filename) as writer:
        summary_df.to_excel(writer, sheet_name='Summary', index=False)
        aggregated_results['aggregated_correct_plays'].to_excel(writer, sheet_name='Correct Plays', index=False)
        aggregated_results['aggregated_incorrect_plays'].to_excel(writer, sheet_name='Incorrect Plays', index=False)
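One property of evaluate_play_extraction worth knowing: each LLM prediction is matched to its best ground-truth play independently, so several predictions can be counted as true positives against the same ground-truth play. The NFL rulebook run below shows this, with 04:55 to 05:01 and 05:01 to 05:07 both credited against the 04:56 to 05:08 pass. If a stricter score is wanted, greedy one-to-one matching is a common alternative. The sketch below is that alternative, not the method used for the results in this notebook; the names are illustrative, and like the notebook it counts a label mismatch on a matched segment as a false positive only.

```python
def mmss_to_seconds(mmss: str) -> int:
    minutes, seconds = map(int, mmss.split(":"))
    return minutes * 60 + seconds


def play_iou(a: dict, b: dict) -> float:
    """IoU of two plays' [start_time, end_time] intervals."""
    s1, e1 = mmss_to_seconds(a["start_time"]), mmss_to_seconds(a["end_time"])
    s2, e2 = mmss_to_seconds(b["start_time"]), mmss_to_seconds(b["end_time"])
    intersection = max(0, min(e1, e2) - max(s1, s2))
    union = (e1 - s1) + (e2 - s2) - intersection
    return intersection / union if union > 0 else 0.0


def match_one_to_one(predictions: list[dict], truth: list[dict], threshold: float = 0.3) -> dict:
    """Greedy one-to-one matching: highest-IoU pairs first, each play used at most once."""
    preds = [p for p in predictions if p["play_type"] != "non_play"]
    gts = [g for g in truth if g["play_type"] != "non_play"]

    # Every (prediction, ground truth) pair, best IoU first
    candidates = sorted(
        ((play_iou(p, g), i, j) for i, p in enumerate(preds) for j, g in enumerate(gts)),
        reverse=True,
    )
    used_preds, used_gts = set(), set()
    tp = fp = 0
    for score, i, j in candidates:
        if score < threshold:
            break  # all remaining pairs are below the threshold
        if i in used_preds or j in used_gts:
            continue
        used_preds.add(i)
        used_gts.add(j)
        if preds[i]["play_type"] == gts[j]["play_type"]:
            tp += 1
        else:
            fp += 1  # right segment, wrong label
    fp += len(preds) - len(used_preds)  # predictions left unmatched
    fn = len(gts) - len(used_gts)       # ground truth plays left unmatched
    return {"tp": tp, "fp": fp, "fn": fn}


truth = [{"start_time": "04:56", "end_time": "05:08", "play_type": "pass"}]
predictions = [
    {"start_time": "04:55", "end_time": "05:01", "play_type": "pass"},  # IoU 5/13
    {"start_time": "05:01", "end_time": "05:07", "play_type": "pass"},  # IoU 6/12
]
print(match_one_to_one(predictions, truth))  # {'tp': 1, 'fp': 1, 'fn': 0}
```

Under the notebook's rule, the same two predictions would both count as true positives.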

### Basic prompt - Gemini 1.5 Pro

In [4]:
import base64
import vertexai
from vertexai.generative_models import GenerativeModel, Part, SafetySetting

# Basic test

def extract_key_moments_basic():
    vertexai.init(project="<project-id>", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="<gcs-video-file-path>/LIVE_97161119_TEN_MINUTES.mp4",
    )

    textsi_1 = """
    Your job is to extract segments from the sport video.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    If you are unsure of a segment's play_type, choose 'unsure' as play_type.
    You are allowed to skip over timestamps if they are non_plays.

    **Non-Play Segments:**
      - Non-play segments are advertisements breaks, promotional content, replays, and anything that is not a play.

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.
        - **play_type:** The type of play, which should be one of the following:
          ['rush', 'pass', 'punt', 'handoff', 'field_goal', 'penalty', 'extra_point', 'kickoff', 'kickoff_return', 'punt_return', 'conversion', 'non_play', 'touchdown', 'unsure']

    **Example Output Format:**
    [
      {
        "start_time": "01:15",
        "end_time": "01:45",
        "play_type": "rush"
      },
      {
        "start_time": "02:30",
        "end_time": "03:00",
        "play_type": "pass"
      },
      {
        "start_time": "10:20",
        "end_time": "10:40",
        "play_type": "non_play"
      }
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items":{
            "type": "OBJECT",
            "properties": {
                "start_time": {
                  "type": "STRING",
                },
                "end_time": {
                  "type": "STRING",
                },
                "play_type": {
                  "type": "STRING",
                }
              },
              "required": [
                "start_time",
                "end_time",
                "play_type"
              ],
        }
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1]
    )

    response = model.generate_content(
        [video1, """Classify plays in the video."""],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response
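The response_schema above only constrains each field to be a string, so nothing forces play_type to come from the list in the prompt or the timestamps to be valid MM:SS values, and a malformed timestamp such as "1:02:03" would make time_to_seconds raise. A minimal sketch of a pre-scoring filter follows; the function names are illustrative, and the allowed labels are copied from the prompt.

```python
import json
import re

# Labels copied from the play_type list in the system prompt
ALLOWED_PLAY_TYPES = {
    "rush", "pass", "punt", "handoff", "field_goal", "penalty", "extra_point",
    "kickoff", "kickoff_return", "punt_return", "conversion", "non_play",
    "touchdown", "unsure",
}
MMSS = re.compile(r"\d{1,2}:[0-5]\d")


def mmss_to_seconds(mmss: str) -> int:
    minutes, seconds = map(int, mmss.split(":"))
    return minutes * 60 + seconds


def is_valid_play(play) -> bool:
    """True if the play has MM:SS timestamps with start < end and a known play_type."""
    if not isinstance(play, dict):
        return False
    start, end = str(play.get("start_time", "")), str(play.get("end_time", ""))
    if not (MMSS.fullmatch(start) and MMSS.fullmatch(end)):
        return False
    if play.get("play_type") not in ALLOWED_PLAY_TYPES:
        return False
    return mmss_to_seconds(start) < mmss_to_seconds(end)


def split_valid_plays(response_text: str):
    """Parse the model's JSON text and separate valid plays from rejected ones."""
    valid, rejected = [], []
    for play in json.loads(response_text):
        (valid if is_valid_play(play) else rejected).append(play)
    return valid, rejected


sample = json.dumps([
    {"start_time": "05:00", "end_time": "05:07", "play_type": "pass"},
    {"start_time": "05:07", "end_time": "05:00", "play_type": "pass"},      # end before start
    {"start_time": "1:02:03", "end_time": "1:02:09", "play_type": "rush"},  # not MM:SS
    {"start_time": "06:00", "end_time": "06:05", "play_type": "sack"},      # label not in list
])
valid, rejected = split_valid_plays(sample)
print(len(valid), len(rejected))  # 1 3
```

In the notebook, response.text could be passed to split_valid_plays and only the valid list scored; whether rejected plays should instead count as false positives is a design choice.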

In [5]:
# Run the test 5 times and average the results.
# Experiment 1: basic prompt.

result_list = []

for i in range(5):
  response = extract_key_moments_basic()
  result = evaluate_play_extraction(json.loads(response.text), ground_truth)
  result_list.append(result)

aggregated_results = aggregate_results(result_list)

# Print the aggregated results
print(f"Average Accuracy: {aggregated_results['average_accuracy']:.4f}")
print(f"Average Precision: {aggregated_results['average_precision']:.4f}")
print(f"Average Recall: {aggregated_results['average_recall']:.4f}")
print(f"Total True Positives: {aggregated_results['total_true_positives']}")
print(f"Total False Positives: {aggregated_results['total_false_positives']}")
print(f"Total False Negatives: {aggregated_results['total_false_negatives']}")

# Access the aggregated DataFrames
print("\nAggregated Correct Plays:")
print(aggregated_results['aggregated_correct_plays'])
print("\nAggregated Incorrect Plays:")
print(aggregated_results['aggregated_incorrect_plays'])

save_results_to_csv(aggregated_results=aggregated_results, filename_prefix="basic-1.5-pro-002")

Average Accuracy: 0.4312
Average Precision: 0.6109
Average Recall: 0.5915
Total True Positives: 39
Total False Positives: 25
Total False Negatives: 26

Aggregated Correct Plays:
   LLM Start Time LLM End Time   LLM Play Type GT Start Time GT End Time  \
0           03:48        03:59  kickoff_return         03:53       03:58   
1           05:00        05:07            pass         04:56       05:08   
2           06:17        06:24            pass         06:09       06:26   
3           07:17        07:24            punt         07:20       07:29   
4           08:51        08:58         handoff         08:47       08:59   
5           09:29        09:36            pass         09:25       09:37   
6           10:09        10:17            pass         10:10       10:19   
7           03:48        03:59  kickoff_return         03:53       03:58   
8           05:00        05:07            pass         04:56       05:08   
9           06:17        06:24            pass         06:09  
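The summary above averages accuracy, precision, and recall across the five runs (each run's ratio is computed first, then averaged). Pooling the counts from all five runs before dividing gives slightly different numbers. A small sketch using the basic-prompt totals printed above (39 TP, 25 FP, 26 FN); the function name is illustrative:

```python
def pooled_metrics(tp: int, fp: int, fn: int) -> dict:
    """Accuracy (TP / (TP + FP + FN), no true negatives), precision, and recall from raw counts."""
    total = tp + fp + fn
    return {
        "accuracy": tp / total if total else 0.0,
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }


for name, value in pooled_metrics(tp=39, fp=25, fn=26).items():
    print(f"Pooled {name}: {value:.4f}")
# Pooled accuracy: 0.4333
# Pooled precision: 0.6094
# Pooled recall: 0.6000
```

Compared with the per-run averages (0.4312, 0.6109, 0.5915), the difference here is small; it can grow when runs return very different numbers of predictions.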

### NFL rule book 

In [28]:
# Experiment 2: same prompt, but with the NFL rulebook attached

def extract_key_moments_nfl_playbook():
    vertexai.init(project="<project-id>", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="<gcs-video-file-path>/LIVE_97161119_TEN_MINUTES.mp4",
    )
    rule_book = Part.from_uri(
        uri="<gcs-video-file-path>/2024-nfl-rulebook.pdf",
        mime_type="application/pdf",
    )

    textsi_1 = """
    Your job is to extract segments from the sport video.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    If you are unsure of a segment's play_type, choose 'unsure' as play_type.
    You are allowed to skip over timestamps if they are non_plays.

    Use the attached rule book to understand what play_types mean.

    **Non-Play Segments:**
      - Non-play segments are advertisements breaks, promotional content, replays, and anything that is not a play.

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.
        - **play_type:** The type of play as per the external data, which should be one of the following:
          ['rush', 'pass', 'punt', 'handoff', 'field_goal', 'penalty', 'extra_point', 'kickoff', 'kickoff_return', 'punt_return', 'conversion', 'non_play', 'touchdown', 'unsure']

    **Example Output Format:**
    [
      {
        "start_time": "01:15",
        "end_time": "01:45",
        "play_type": "rush"
      },
      {
        "start_time": "02:30",
        "end_time": "03:00",
        "play_type": "pass"
      },
      {
        "start_time": "10:20",
        "end_time": "10:40",
        "play_type": "non_play"
      }
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items":{
            "type": "OBJECT",
            "properties": {
                "start_time": {
                  "type": "STRING",
                },
                "end_time": {
                  "type": "STRING",
                },
                "play_type": {
                  "type": "STRING",
                }
              },
              "required": [
                "start_time",
                "end_time",
                "play_type"
              ],
        }
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1]
    )

    response = model.generate_content(
        [video1, rule_book, """Extract moments from the video."""],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response

result_list = []

for i in range(5):
  response = extract_key_moments_nfl_playbook()
  result = evaluate_play_extraction(json.loads(response.text), ground_truth)
  result_list.append(result)

aggregated_results = aggregate_results(result_list)

# Print the aggregated results
print(f"Average Accuracy: {aggregated_results['average_accuracy']:.4f}")
print(f"Average Precision: {aggregated_results['average_precision']:.4f}")
print(f"Average Recall: {aggregated_results['average_recall']:.4f}")
print(f"Total True Positives: {aggregated_results['total_true_positives']}")
print(f"Total False Positives: {aggregated_results['total_false_positives']}")
print(f"Total False Negatives: {aggregated_results['total_false_negatives']}")

# Access the aggregated DataFrames
print("\nAggregated Correct Plays:")
print(aggregated_results['aggregated_correct_plays'])
print("\nAggregated Incorrect Plays:")
print(aggregated_results['aggregated_incorrect_plays'])

save_results_to_csv(aggregated_results=aggregated_results, filename_prefix="nfl_playbook")


Average Accuracy: 0.3734
Average Precision: 0.5061
Average Recall: 0.6691
Total True Positives: 38
Total False Positives: 40
Total False Negatives: 19

Aggregated Correct Plays:
   LLM Start Time LLM End Time   LLM Play Type GT Start Time GT End Time  \
0           03:50        03:57  kickoff_return         03:53       03:58   
1           04:55        05:01            pass         04:56       05:08   
2           05:01        05:07            pass         04:56       05:08   
3           06:18        06:24            pass         06:09       06:26   
4           06:47        06:53         penalty         06:48       06:56   
5           07:29        07:37     punt_return         07:29       07:35   
6           09:30        09:36            pass         09:25       09:37   
7           10:09        10:17            pass         10:10       10:19   
8           03:48        03:59  kickoff_return         03:53       03:58   
9           05:00        05:07            pass         04:56  

### Multi-step prompt

In [30]:
# Experiment 3: multi-step prompt. First extract timestamps, then classify the play in each segment.

def extract_timestamps():
    vertexai.init(project="<project-id>", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="<gcs-video-file-path>/LIVE_97161119_TEN_MINUTES.mp4",
    )

    textsi_1 = """
    Your job is to split the video into start and end timestamps by plays.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    A play is when the ball is in play.
    If the video contains non play time just skip the times stamps.
    Replays are considered non play time.

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.

    **Example Output Format:**
    [
      {
        "start_time": "01:15",
        "end_time": "01:45"
      },
      {
        "start_time": "02:30",
        "end_time": "03:00"
      },
      {
        "start_time": "10:20",
        "end_time": "10:40"
      }
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items":{
            "type": "OBJECT",
            "properties": {
                "start_time": {
                  "type": "STRING",
                },
                "end_time": {
                  "type": "STRING",
                }
              },
              "required": [
                "start_time",
                "end_time"
              ],
        }
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1]
    )

    response = model.generate_content(
        [video1, """Extract timestamps from the video."""],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response

def identify_plays(timestamps_json):
    vertexai.init(project="<project-id>", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="<gcs-video-file-path>/LIVE_97161119_TEN_MINUTES.mp4",
    )

    textsi_1 = f"""
    Your job is to assign play_types from the sport video and timestamp file.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    If you are unsure of a segment's play_type, choose 'unsure' as play_type.

    Do not modify the given timestamps.

    **Timestamps:**
    {timestamps_json}

    **Non-Play Segments:**
      - Non-play segments are advertisements breaks, promotional content, replays, and anything that is not a play.

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.
        - **play_type:** The type of play which should be one of the following:
          ['rush', 'pass', 'punt', 'handoff', 'field_goal', 'penalty', 'extra_point', 'kickoff', 'kickoff_return', 'punt_return', 'conversion', 'non_play', 'touchdown', 'unsure']

    **Example Output Format:**
    [
      {{
        "start_time": "01:15",
        "end_time": "01:45",
        "play_type": "rush"
      }},
      {{
        "start_time": "02:30",
        "end_time": "03:00",
        "play_type": "pass"
      }},
      {{
        "start_time": "10:20",
        "end_time": "10:40",
        "play_type": "non_play"
      }}
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items":{
            "type": "OBJECT",
            "properties": {
                "start_time": {
                  "type": "STRING",
                },
                "end_time": {
                  "type": "STRING",
                },
                "play_type": {
                  "type": "STRING",
                }
              },
              "required": [
                "start_time",
                "end_time",
                "play_type"
              ],
        }
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1]
    )

    response = model.generate_content(
        [video1, """Extract play types from the video."""],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response

result_list = []

for i in range(5):
  timestamps = extract_timestamps()
  timestamps_json = json.loads(timestamps.text)
  print(timestamps_json)
  response = identify_plays(timestamps_json)
  result = evaluate_play_extraction(json.loads(response.text), ground_truth)
  result_list.append(result)
  print(result)

aggregated_results = aggregate_results(result_list)

# Print the aggregated results
print(f"Average Accuracy: {aggregated_results['average_accuracy']:.4f}")
print(f"Average Precision: {aggregated_results['average_precision']:.4f}")
print(f"Average Recall: {aggregated_results['average_recall']:.4f}")
print(f"Total True Positives: {aggregated_results['total_true_positives']}")
print(f"Total False Positives: {aggregated_results['total_false_positives']}")
print(f"Total False Negatives: {aggregated_results['total_false_negatives']}")

# Access the aggregated DataFrames
print("\nAggregated Correct Plays:")
print(aggregated_results['aggregated_correct_plays'])
print("\nAggregated Incorrect Plays:")
print(aggregated_results['aggregated_incorrect_plays'])

save_results_to_csv(aggregated_results=aggregated_results, filename_prefix="extract_timestamps_then_id_plays")


[{'end_time': '04:05', 'start_time': '03:49'}, {'end_time': '05:07', 'start_time': '05:00'}, {'end_time': '05:18', 'start_time': '05:10'}, {'end_time': '05:43', 'start_time': '05:35'}, {'end_time': '06:24', 'start_time': '06:17'}, {'end_time': '07:16', 'start_time': '07:08'}, {'end_time': '07:37', 'start_time': '07:29'}, {'end_time': '07:46', 'start_time': '07:37'}, {'end_time': '08:17', 'start_time': '08:08'}, {'end_time': '08:20', 'start_time': '08:17'}, {'end_time': '08:58', 'start_time': '08:44'}, {'end_time': '09:16', 'start_time': '09:00'}, {'end_time': '09:36', 'start_time': '09:22'}, {'end_time': '09:47', 'start_time': '09:36'}, {'end_time': '10:00', 'start_time': '10:00'}, {'end_time': '10:19', 'start_time': '10:10'}, {'end_time': '10:21', 'start_time': '10:19'}, {'end_time': '10:30', 'start_time': '10:21'}]
{'accuracy': 0.3333333333333333, 'precision': 0.4666666666666667, 'recall': 0.5833333333333334, 'true_positives': 7, 'false_positives': 8, 'false_negatives': 5, 'correct_p
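The first call's output above contains a zero-length segment (`'start_time': '10:00', 'end_time': '10:00'`). Also, `identify_plays` places the parsed list directly into its f-string, which inserts the Python `repr` (single-quoted keys) rather than JSON. Below is a minimal sketch of a cleanup step that could run between the two calls. It assumes that zero-length or reversed segments should be dropped and that the prompt should receive JSON text. `prepare_timestamps_for_prompt` is a hypothetical helper, not part of the notebook.

```python
import json
from typing import Dict, List


def to_seconds(timestamp: str) -> int:
    """Convert an MM:SS timestamp to whole seconds."""
    minutes, seconds = map(int, timestamp.split(":"))
    return minutes * 60 + seconds


def prepare_timestamps_for_prompt(segments: List[Dict[str, str]]) -> str:
    """Hypothetical cleanup between the two Gemini calls.

    Drops segments whose end is not after their start (e.g. '10:00'-'10:00'),
    sorts the remaining segments by start time, and serializes them as JSON
    text so the prompt receives double-quoted JSON instead of a Python repr.
    """
    valid = [
        {"start_time": s["start_time"], "end_time": s["end_time"]}
        for s in segments
        if to_seconds(s["end_time"]) > to_seconds(s["start_time"])
    ]
    valid.sort(key=lambda s: to_seconds(s["start_time"]))
    return json.dumps(valid, indent=2)


# Usage with a few segments copied from the step-1 output above.
step1_output = [
    {"end_time": "09:47", "start_time": "09:36"},
    {"end_time": "10:00", "start_time": "10:00"},
    {"end_time": "10:19", "start_time": "10:10"},
]
print(prepare_timestamps_for_prompt(step1_output))
```

The returned string could be passed as `timestamps_json` to `identify_plays` in place of the parsed list. Whether this changes the scores has not been tested here.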

### Describe each play type

In [38]:
# Describe each of the play types

def extract_key_moments_plus_description():
    vertexai.init(project="<project-id>", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="<gcs-video-file-path>/LIVE_97161119_TEN_MINUTES.mp4",
    )

    textsi_1 = """
    Your job is to extract segments from the sport video.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    If you are unsure of a segment's play_type, choose 'unsure' as play_type.
    You are allowed to skip over timestamps if they are non_plays.

    **Play_types and their descriptions:**
      - rush: A running play where a player (usually the running back) carries the ball forward to gain yardage.
      - pass: The quarterback throws the ball forward to a receiver downfield.
      - handoff: The quarterback hands the ball off to another player (usually the running back) to start a running play.
      - field_goal: A kick attempt where the team tries to score 3 points by kicking the ball through the upright goal posts.
      - punt: When a team kicks the ball to the opposing team to give up possession (usually on fourth down).
      - extra_point: A kick attempt after a touchdown to score 1 additional point.
      - kickoff: The kick that starts each half and follows every touchdown and successful field goal.
      - kickoff_return: The receiving team attempts to advance the ball after a kickoff.
      - punt_return: The receiving team attempts to advance the ball after a punt.
      - penalty: An infraction of the rules, resulting in a yardage penalty against the offending team. Look for a yellow flag and a referee in a striped shirt announcing the penalty.
      - conversion: An attempt to score 2 points after a touchdown by running or passing the ball into the end zone.
      - non_play: This refers to moments when the ball is not in play, such as timeouts, commercial breaks, between quarters, and replays.
      - touchdown: A scoring play worth 6 points, achieved by carrying or catching the ball in the opponent's end zone.
      - unsure: This would be used when the play type cannot be confidently determined.

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.
        - **play_type:** The type of play as per the explanations above, which should be one of the following:
          ['rush', 'pass', 'punt', 'handoff', 'field_goal', 'penalty', 'extra_point', 'kickoff', 'kickoff_return', 'punt_return', 'conversion', 'non_play', 'touchdown', 'unsure']

    **Example Output Format:**
    [
      {
        "start_time": "01:15",
        "end_time": "01:45",
        "play_type": "rush"
      },
      {
        "start_time": "02:30",
        "end_time": "03:00",
        "play_type": "pass"
      },
      {
        "start_time": "10:20",
        "end_time": "10:40",
        "play_type": "non_play"
      }
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items":{
            "type": "OBJECT",
            "properties": {
                "start_time": {
                  "type": "STRING",
                },
                "end_time": {
                  "type": "STRING",
                },
                "play_type": {
                  "type": "STRING",
                }
              },
              "required": [
                "start_time",
                "end_time",
                "play_type"
              ],
        }
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1]
    )

    response = model.generate_content(
        [video1, """Extract plays from the video."""],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response

result_list = []

for i in range(5):
  response = extract_key_moments_plus_description()
  result = evaluate_play_extraction(json.loads(response.text), ground_truth)
  result_list.append(result)

aggregated_results = aggregate_results(result_list)

# Print the aggregated results
print(f"Average Accuracy: {aggregated_results['average_accuracy']:.4f}")
print(f"Average Precision: {aggregated_results['average_precision']:.4f}")
print(f"Average Recall: {aggregated_results['average_recall']:.4f}")
print(f"Total True Positives: {aggregated_results['total_true_positives']}")
print(f"Total False Positives: {aggregated_results['total_false_positives']}")
print(f"Total False Negatives: {aggregated_results['total_false_negatives']}")

# Access the aggregated DataFrames
print("\nAggregated Correct Plays:")
print(aggregated_results['aggregated_correct_plays'])
print("\nAggregated Incorrect Plays:")
print(aggregated_results['aggregated_incorrect_plays'])

save_results_to_csv(aggregated_results=aggregated_results,filename_prefix="extract_key_moments_plus_description")


KeyboardInterrupt: 
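The prompts list fourteen allowed `play_type` values, but `response_schema` only constrains the field to `STRING`. Nothing therefore forces the model to stay inside that list or to emit well-formed `MM:SS` timecodes. Below is a minimal sketch of a local check that could run on `json.loads(response.text)` before `evaluate_play_extraction`. `ALLOWED_PLAY_TYPES` mirrors the prompt's list, and `find_invalid_segments` is a hypothetical helper.

```python
import re
from typing import Dict, List

# Mirrors the play_type list given in the prompts above.
ALLOWED_PLAY_TYPES = {
    "rush", "pass", "punt", "handoff", "field_goal", "penalty",
    "extra_point", "kickoff", "kickoff_return", "punt_return",
    "conversion", "non_play", "touchdown", "unsure",
}

# Two-digit minutes and two-digit seconds in the range 00-59.
TIMECODE = re.compile(r"^\d{2}:[0-5]\d$")


def find_invalid_segments(segments: List[Dict[str, str]]) -> List[str]:
    """Return one human-readable problem description per violation."""
    problems = []
    for i, seg in enumerate(segments):
        if seg.get("play_type") not in ALLOWED_PLAY_TYPES:
            problems.append(f"segment {i}: unknown play_type {seg.get('play_type')!r}")
        for key in ("start_time", "end_time"):
            value = str(seg.get(key, ""))
            if not TIMECODE.match(value):
                problems.append(f"segment {i}: bad {key} {value!r}")
    return problems


# Usage: one valid segment and one with a made-up play type and a short timecode.
sample = [
    {"start_time": "04:56", "end_time": "05:08", "play_type": "pass"},
    {"start_time": "7:17", "end_time": "07:29", "play_type": "sack"},
]
for problem in find_invalid_segments(sample):
    print(problem)
```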

### Slow down video

In [None]:
# Slow down video (0.5x)
!ffmpeg -i /Users/lukasgeiger/Desktop/VertexGenAISamples/private/utils/videos/LIVE_97161119_TEN_MINUTES.mp4 -filter_complex "[0:v]setpts=2*PTS[v];[0:a]atempo=0.5[a]" -map "[v]" -map "[a]" /Users/lukasgeiger/Desktop/VertexGenAISamples/private/utils/videos/LIVE_97161119_TEN_MINUTES_slowed_05x.mp4

In [12]:
# Have to adjust ground truth because video is now twice as long
# Double each ground truth timestamp.

def convert_to_seconds(timestamp):
  """Converts a timestamp in MM:SS format to seconds."""
  minutes, seconds = map(int, timestamp.split(':'))
  return minutes * 60 + seconds

def convert_to_timestamp(seconds):
  """Converts seconds to a timestamp in MM:SS format."""
  minutes = seconds // 60
  seconds = seconds % 60
  return f"{minutes:02d}:{seconds:02d}"

ground_truth_slowed = []

for event in ground_truth:
  start_time_seconds = convert_to_seconds(event['start_time'])
  end_time_seconds = convert_to_seconds(event['end_time'])

  new_start_time_seconds = start_time_seconds * 2
  new_end_time_seconds = end_time_seconds * 2

  ground_truth_slowed.append({
      "start_time": convert_to_timestamp(new_start_time_seconds),
      "end_time": convert_to_timestamp(new_end_time_seconds),
      "play_type": event['play_type']
  })

print(ground_truth_slowed)

[{'start_time': '00:00', 'end_time': '00:10', 'play_type': 'non_play'}, {'start_time': '00:10', 'end_time': '01:02', 'play_type': 'non_play'}, {'start_time': '01:02', 'end_time': '01:24', 'play_type': 'non_play'}, {'start_time': '01:24', 'end_time': '01:32', 'play_type': 'non_play'}, {'start_time': '01:32', 'end_time': '02:32', 'play_type': 'non_play'}, {'start_time': '02:32', 'end_time': '03:32', 'play_type': 'non_play'}, {'start_time': '03:32', 'end_time': '04:32', 'play_type': 'non_play'}, {'start_time': '04:32', 'end_time': '04:58', 'play_type': 'non_play'}, {'start_time': '04:58', 'end_time': '05:50', 'play_type': 'non_play'}, {'start_time': '05:50', 'end_time': '07:04', 'play_type': 'non_play'}, {'start_time': '07:04', 'end_time': '07:46', 'play_type': 'kickoff'}, {'start_time': '07:46', 'end_time': '07:56', 'play_type': 'kickoff_return'}, {'start_time': '08:20', 'end_time': '08:44', 'play_type': 'penalty'}, {'start_time': '09:52', 'end_time': '10:16', 'play_type': 'pass'}, {'sta

In [22]:
# Generate LLM response for slow down video.

import base64
import vertexai
from vertexai.generative_models import GenerativeModel, Part, SafetySetting

def extract_key_moments_slowed():
    vertexai.init(project="<project-id>", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="<gcs-video-file-path>/LIVE_97161119_TEN_MINUTES_slowed_05x.mp4",
    )

    textsi_1 = """
    Your job is to extract segments from the sport video.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    If you are unsure of a segment's play_type, choose 'unsure' as play_type.
    You are allowed to skip over timestamps if they are non_plays.

    **Non-Play Segments:**
      - Non-play segments are advertisement breaks, promotional content, replays, and anything that is not a play.

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.
        - **play_type:** The type of play, which should be one of the following:
          ['rush', 'pass', 'punt', 'handoff', 'field_goal', 'penalty', 'extra_point', 'kickoff', 'kickoff_return', 'punt_return', 'conversion', 'non_play', 'touchdown', 'unsure']

    **Example Output Format:**
    [
      {
        "start_time": "01:15",
        "end_time": "01:45",
        "play_type": "rush"
      },
      {
        "start_time": "02:30",
        "end_time": "03:00",
        "play_type": "pass"
      },
      {
        "start_time": "10:20",
        "end_time": "10:40",
        "play_type": "non_play"
      }
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items":{
            "type": "OBJECT",
            "properties": {
                "start_time": {
                  "type": "STRING",
                },
                "end_time": {
                  "type": "STRING",
                },
                "play_type": {
                  "type": "STRING",
                }
              },
              "required": [
                "start_time",
                "end_time",
                "play_type"
              ],
        }
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1]
    )

    response = model.generate_content(
        [video1, """Extract moments from the video."""],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response

result_list = []

for i in range(5):
  response = extract_key_moments_slowed()
  result = evaluate_play_extraction(json.loads(response.text), ground_truth_slowed)
  result_list.append(result)

aggregated_results = aggregate_results(result_list)

# Print the aggregated results
print(f"Average Accuracy: {aggregated_results['average_accuracy']:.4f}")
print(f"Average Precision: {aggregated_results['average_precision']:.4f}")
print(f"Average Recall: {aggregated_results['average_recall']:.4f}")
print(f"Total True Positives: {aggregated_results['total_true_positives']}")
print(f"Total False Positives: {aggregated_results['total_false_positives']}")
print(f"Total False Negatives: {aggregated_results['total_false_negatives']}")

# Access the aggregated DataFrames
print("\nAggregated Correct Plays:")
print(aggregated_results['aggregated_correct_plays'])
print("\nAggregated Incorrect Plays:")
print(aggregated_results['aggregated_incorrect_plays'])

save_results_to_csv(aggregated_results=aggregated_results,filename_prefix="slowed")

Average Accuracy: 0.1583
Average Precision: 0.1606
Average Recall: 0.9278
Total True Positives: 34
Total False Positives: 180
Total False Negatives: 3

Aggregated Correct Plays:
   LLM Start Time LLM End Time LLM Play Type GT Start Time GT End Time  \
0           07:26        07:40       kickoff         07:04       07:46   
1           10:01        10:23          pass         09:52       10:16   
2           11:05        11:20          pass         11:08       11:26   
3           12:13        12:33          pass         12:18       12:52   
4           12:33        12:50          pass         12:18       12:52   
5           18:50        19:13          pass         18:50       19:14   
6           20:02        20:31          pass         20:20       20:38   
7           09:45        10:14          pass         09:52       10:16   
8           12:13        12:33          pass         12:18       12:52   
9           12:33        12:50          pass         12:18       12:52   
10      
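Rows 3 and 4 above pair two different LLM segments with the same ground-truth segment (GT 12:18 to 12:52), and both are listed as correct plays. This suggests one ground-truth play can be credited more than once. The sketch below assumes that `evaluate_play_extraction` (defined earlier in the notebook, not shown here) uses the standard one-dimensional interval IoU. It is a minimal version of that computation on `MM:SS` timecodes and shows that both rows clear the 0.3 threshold.

```python
def to_seconds(timestamp: str) -> int:
    """Convert an MM:SS timestamp to whole seconds."""
    minutes, seconds = map(int, timestamp.split(":"))
    return minutes * 60 + seconds


def interval_iou(pred_start: str, pred_end: str, gt_start: str, gt_end: str) -> float:
    """Intersection over union of two time intervals given as MM:SS strings."""
    p0, p1 = to_seconds(pred_start), to_seconds(pred_end)
    g0, g1 = to_seconds(gt_start), to_seconds(gt_end)
    # Overlap length, clamped at zero for disjoint intervals.
    intersection = max(0, min(p1, g1) - max(p0, g0))
    union = (p1 - p0) + (g1 - g0) - intersection
    return intersection / union if union > 0 else 0.0


# Rows 3 and 4 of the table above, both matched to GT 12:18-12:52.
print(round(interval_iou("12:13", "12:33", "12:18", "12:52"), 4))  # 15 s overlap / 39 s union
print(round(interval_iou("12:33", "12:50", "12:18", "12:52"), 4))  # 17 s overlap / 34 s union
```

If one-to-one matching is intended, a greedy or Hungarian assignment over these IoU scores would prevent a single ground-truth segment from being counted twice.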

### Speed up video

In [None]:
# Speed up video (2x)
!ffmpeg -i /Users/lukasgeiger/Desktop/VertexGenAISamples/private/utils/videos/LIVE_97161119_TEN_MINUTES.mp4 -filter_complex "[0:v]setpts=0.5*PTS[v];[0:a]atempo=2[a]" -map "[v]" -map "[a]" /Users/lukasgeiger/Desktop/VertexGenAISamples/private/utils/videos/LIVE_97161119_TEN_MINUTES_sped_up_2x.mp4

In [23]:
# The sped-up video is half as long, so halve every ground-truth timestamp (floor division drops any half second).

ground_truth_sped_up = []

for event in ground_truth:
  start_time_seconds = convert_to_seconds(event['start_time'])
  end_time_seconds = convert_to_seconds(event['end_time'])

  new_start_time_seconds = start_time_seconds // 2 # Use // for integer division
  new_end_time_seconds = end_time_seconds // 2 # Use // for integer division

  ground_truth_sped_up.append({
      "start_time": convert_to_timestamp(new_start_time_seconds),
      "end_time": convert_to_timestamp(new_end_time_seconds),
      "play_type": event['play_type']
  })

print(ground_truth_sped_up)

[{'start_time': '00:00', 'end_time': '00:02', 'play_type': 'non_play'}, {'start_time': '00:02', 'end_time': '00:15', 'play_type': 'non_play'}, {'start_time': '00:15', 'end_time': '00:21', 'play_type': 'non_play'}, {'start_time': '00:21', 'end_time': '00:23', 'play_type': 'non_play'}, {'start_time': '00:23', 'end_time': '00:38', 'play_type': 'non_play'}, {'start_time': '00:38', 'end_time': '00:53', 'play_type': 'non_play'}, {'start_time': '00:53', 'end_time': '01:08', 'play_type': 'non_play'}, {'start_time': '01:08', 'end_time': '01:14', 'play_type': 'non_play'}, {'start_time': '01:14', 'end_time': '01:27', 'play_type': 'non_play'}, {'start_time': '01:27', 'end_time': '01:46', 'play_type': 'non_play'}, {'start_time': '01:46', 'end_time': '01:56', 'play_type': 'kickoff'}, {'start_time': '01:56', 'end_time': '01:59', 'play_type': 'kickoff_return'}, {'start_time': '02:05', 'end_time': '02:11', 'play_type': 'penalty'}, {'start_time': '02:28', 'end_time': '02:34', 'play_type': 'pass'}, {'sta
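The slowed and sped-up cells differ only in the factor applied to each timestamp (`* 2` versus `// 2`). Below is a minimal sketch of a single hypothetical helper, `scale_ground_truth`, that handles both cases. It floors fractional seconds, so a factor of 0.5 reproduces the `// 2` results above, and it formats output as `MM:SS` like `convert_to_timestamp`.

```python
import math
from typing import Dict, List


def scale_ground_truth(events: List[Dict[str, str]], factor: float) -> List[Dict[str, str]]:
    """Rescale MM:SS ground-truth timestamps for a video played at 1/factor speed.

    factor=2 matches a 0.5x (slowed) video; factor=0.5 matches a 2x (sped-up)
    video. Fractional seconds are floored, mirroring the `//` used above.
    """
    def to_seconds(ts: str) -> int:
        minutes, seconds = map(int, ts.split(":"))
        return minutes * 60 + seconds

    def to_timestamp(total: int) -> str:
        return f"{total // 60:02d}:{total % 60:02d}"

    return [
        {
            "start_time": to_timestamp(math.floor(to_seconds(e["start_time"]) * factor)),
            "end_time": to_timestamp(math.floor(to_seconds(e["end_time"]) * factor)),
            "play_type": e["play_type"],
        }
        for e in events
    ]


# Usage: the second ground-truth segment (00:05 to 00:31), as implied by both printed outputs.
sample = [{"start_time": "00:05", "end_time": "00:31", "play_type": "non_play"}]
print(scale_ground_truth(sample, 2))    # corresponds to ground_truth_slowed[1]
print(scale_ground_truth(sample, 0.5))  # corresponds to ground_truth_sped_up[1]
```

Flooring shortens some sped-up segments by up to half a second. Rounding to the nearest second is an alternative, but it would no longer match the `ground_truth_sped_up` values printed above.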

In [40]:
# Generate response for sped up video

import base64
import vertexai
from vertexai.generative_models import GenerativeModel, Part, SafetySetting

# Sped-up (2x) video test

def extract_key_moments_sped_up():
    vertexai.init(project="<project-id>", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="<gcs-video-file-path>/LIVE_97161119_TEN_MINUTES_sped_up_2x.mp4",
    )

    textsi_1 = """
    Your job is to extract segments from the sport video.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    If you are unsure of a segment's play_type, choose 'unsure' as play_type.
    You are allowed to skip over timestamps if they are non_plays.

    **Non-Play Segments:**
      - Non-play segments are advertisement breaks, promotional content, replays, and anything that is not a play.

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.
        - **play_type:** The type of play, which should be one of the following:
          ['rush', 'pass', 'punt', 'handoff', 'field_goal', 'penalty', 'extra_point', 'kickoff', 'kickoff_return', 'punt_return', 'conversion', 'non_play', 'touchdown', 'unsure']

    **Example Output Format:**
    [
      {
        "start_time": "01:15",
        "end_time": "01:45",
        "play_type": "rush"
      },
      {
        "start_time": "02:30",
        "end_time": "03:00",
        "play_type": "pass"
      },
      {
        "start_time": "10:20",
        "end_time": "10:40",
        "play_type": "non_play"
      }
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items":{
            "type": "OBJECT",
            "properties": {
                "start_time": {
                  "type": "STRING",
                },
                "end_time": {
                  "type": "STRING",
                },
                "play_type": {
                  "type": "STRING",
                }
              },
              "required": [
                "start_time",
                "end_time",
                "play_type"
              ],
        }
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1]
    )

    response = model.generate_content(
        [video1, """Extract moments from the video."""],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response

result_list = []

for i in range(5):
  response = extract_key_moments_sped_up()
  result = evaluate_play_extraction(json.loads(response.text), ground_truth_sped_up)
  result_list.append(result)

aggregated_results = aggregate_results(result_list)

# Print the aggregated results
print(f"Average Accuracy: {aggregated_results['average_accuracy']:.4f}")
print(f"Average Precision: {aggregated_results['average_precision']:.4f}")
print(f"Average Recall: {aggregated_results['average_recall']:.4f}")
print(f"Total True Positives: {aggregated_results['total_true_positives']}")
print(f"Total False Positives: {aggregated_results['total_false_positives']}")
print(f"Total False Negatives: {aggregated_results['total_false_negatives']}")

# Access the aggregated DataFrames
print("\nAggregated Correct Plays:")
print(aggregated_results['aggregated_correct_plays'])
print("\nAggregated Incorrect Plays:")
print(aggregated_results['aggregated_incorrect_plays'])

save_results_to_csv(aggregated_results=aggregated_results,filename_prefix="sped_up")

KeyboardInterrupt: 
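Every extraction cell rebuilds the same `response_schema` and `generation_config` dictionaries. Below is a minimal sketch of a hypothetical factory that returns those two plain dictionaries so that experiments can share them. It only builds dicts, so the `vertexai` calls in the cells stay unchanged.

```python
from typing import Any, Dict, Tuple


def build_play_config(
    fields: Tuple[str, ...] = ("start_time", "end_time", "play_type"),
    max_output_tokens: int = 8192,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (response_schema, generation_config) matching the cells above.

    `fields` lists the STRING properties each array item must contain, so the
    same factory could also build a schema with only start and end times.
    """
    response_schema = {
        "type": "ARRAY",
        "items": {
            "type": "OBJECT",
            "properties": {name: {"type": "STRING"} for name in fields},
            "required": list(fields),
        },
    }
    generation_config = {
        "max_output_tokens": max_output_tokens,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }
    return response_schema, generation_config


# Usage: the returned generation_config is passed to generate_content as before.
response_schema, generation_config = build_play_config()
print(generation_config["response_schema"]["items"]["required"])
```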

### 1 fps video

In [None]:
# Alternative to slowing down: resample the video to 1 frame per second.
!ffmpeg -i /Users/lukasgeiger/Desktop/VertexGenAISamples/private/utils/videos/LIVE_97161119_TEN_MINUTES.mp4 -vf "fps=1" -c:v libx264 -crf 18 /Users/lukasgeiger/Desktop/VertexGenAISamples/private/utils/videos/LIVE_97161119_TEN_MINUTES_1fps.mp4

In [9]:
# 1fps

import base64
import vertexai
from vertexai.generative_models import GenerativeModel, Part, SafetySetting

# 1 fps video test

def extract_key_moments_1fps():
    vertexai.init(project="<project-id>", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="<gcs-video-file-path>/LIVE_97161119_TEN_MINUTES_1fps.mp4",
    )

    textsi_1 = """
    Your job is to extract segments from the sport video.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    If you are unsure of a segment's play_type, choose 'unsure' as play_type.
    You are allowed to skip over timestamps if they are non_plays.

    **Non-Play Segments:**
      - Non-play segments are advertisement breaks, promotional content, replays, and anything that is not a play.

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.
        - **play_type:** The type of play, which should be one of the following:
          ['rush', 'pass', 'punt', 'handoff', 'field_goal', 'penalty', 'extra_point', 'kickoff', 'kickoff_return', 'punt_return', 'conversion', 'non_play', 'touchdown', 'unsure']

    **Example Output Format:**
    [
      {
        "start_time": "01:15",
        "end_time": "01:45",
        "play_type": "rush"
      },
      {
        "start_time": "02:30",
        "end_time": "03:00",
        "play_type": "pass"
      },
      {
        "start_time": "10:20",
        "end_time": "10:40",
        "play_type": "non_play"
      }
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items":{
            "type": "OBJECT",
            "properties": {
                "start_time": {
                  "type": "STRING",
                },
                "end_time": {
                  "type": "STRING",
                },
                "play_type": {
                  "type": "STRING",
                }
              },
              "required": [
                "start_time",
                "end_time",
                "play_type"
              ],
        }
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1],
    )

    response = model.generate_content(
        ["""Extract moments from the video.""", video1],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response

result_list = []

for i in range(5):
  response = extract_key_moments_1fps()
  result = evaluate_play_extraction(json.loads(response.text), ground_truth)
  result_list.append(result)

aggregated_results = aggregate_results(result_list)

# Print the aggregated results
print(f"Average Accuracy: {aggregated_results['average_accuracy']:.4f}")
print(f"Average Precision: {aggregated_results['average_precision']:.4f}")
print(f"Average Recall: {aggregated_results['average_recall']:.4f}")
print(f"Total True Positives: {aggregated_results['total_true_positives']}")
print(f"Total False Positives: {aggregated_results['total_false_positives']}")
print(f"Total False Negatives: {aggregated_results['total_false_negatives']}")

# Access the aggregated DataFrames
print("\nAggregated Correct Plays:")
print(aggregated_results['aggregated_correct_plays'])
print("\nAggregated Incorrect Plays:")
print(aggregated_results['aggregated_incorrect_plays'])

save_results_to_csv(aggregated_results=aggregated_results,filename_prefix="1fps")

Average Accuracy: 0.3295
Average Precision: 0.4327
Average Recall: 0.6024
Total True Positives: 33
Total False Positives: 50
Total False Negatives: 22

Aggregated Correct Plays:
   LLM Start Time LLM End Time   LLM Play Type GT Start Time GT End Time  \
0           03:50        04:06  kickoff_return         03:53       03:58   
1           04:56        05:01            pass         04:56       05:08   
2           05:01        05:16            pass         04:56       05:08   
3           06:07        06:25            pass         06:09       06:26   
4           09:22        09:47            pass         09:25       09:37   
5           10:01        10:18            pass         10:10       10:19   
6           04:55        05:01            pass         04:56       05:08   
7           07:17        07:29            punt         07:20       07:29   
8           10:00        10:18            pass         10:10       10:19   
9           04:06        04:21         penalty         04:10  
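The headline numbers appear to be averages of per-run metrics (the keys are `average_precision` and `average_recall`), so they do not equal precision and recall recomputed from the pooled totals. Below is a minimal sketch using the standard definitions, precision = TP / (TP + FP) and recall = TP / (TP + FN). It reproduces the single run printed in the two-step experiment (7 TP, 8 FP, 5 FN) and shows the pooled values for this 1 fps experiment. The modified accuracy is not covered because its definition is in `evaluate_play_extraction`, which is not shown here.

```python
from typing import Tuple


def precision_recall(tp: int, fp: int, fn: int) -> Tuple[float, float]:
    """Standard precision and recall; returns 0.0 when a denominator is zero."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Single two-step run: matches the printed precision and recall for that run.
print([round(v, 4) for v in precision_recall(7, 8, 5)])
# Pooled 1 fps totals: differs from the per-run averages reported above.
print([round(v, 4) for v in precision_recall(33, 50, 22)])
```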

### Oversample 1fps video 

In [None]:
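# Alternative to slowing down: convert to 1 fps, then double each frame's timestamp (setpts=2*PTS) so every frame is held for 2 seconds.
# atempo=0.5 slows the audio to match, so the output is twice as long as the original.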
!ffmpeg -i /Users/lukasgeiger/Desktop/VertexGenAISamples/private/utils/videos/LIVE_97161119_TEN_MINUTES.mp4 -vf "fps=1,setpts=2*PTS" -af "atempo=0.5" -c:v libx264 -crf 18 /Users/lukasgeiger/Desktop/VertexGenAISamples/private/utils/videos/LIVE_97161119_TEN_MINUTES_1fps_oversampled.mp4
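The four `ffmpeg` cells differ only in their filters and all hard-code local paths. Below is a minimal sketch of a hypothetical helper that builds the same argument list from parameters. It assumes `pts_factor` stays between 0.5 and 2, so the single `atempo` value stays inside the range that older ffmpeg releases accept. The helper always sets `-c:v libx264 -crf 18` as the 1 fps cells do; the slow-down and speed-up cells used ffmpeg's defaults instead.

```python
from typing import List, Optional


def build_ffmpeg_args(
    src: str,
    dst: str,
    pts_factor: float = 1.0,
    fps: Optional[int] = None,
    crf: int = 18,
) -> List[str]:
    """Build an ffmpeg argv list for resampling and/or retiming a video.

    pts_factor multiplies presentation timestamps: 2 plays at half speed,
    0.5 at double speed. Audio is retimed with atempo=1/pts_factor.
    """
    if not 0.5 <= pts_factor <= 2.0:
        raise ValueError("pts_factor must be between 0.5 and 2.0")

    video_filters = []
    if fps is not None:
        video_filters.append(f"fps={fps}")
    if pts_factor != 1.0:
        video_filters.append(f"setpts={pts_factor:g}*PTS")

    args = ["ffmpeg", "-i", src]
    if video_filters:
        args += ["-vf", ",".join(video_filters)]
    if pts_factor != 1.0:
        args += ["-af", f"atempo={1 / pts_factor:g}"]
    args += ["-c:v", "libx264", "-crf", str(crf), dst]
    return args


# Usage: the oversampled 1 fps cell above, with placeholder paths.
args = build_ffmpeg_args("input.mp4", "output_1fps_oversampled.mp4", pts_factor=2, fps=1)
print(" ".join(args))
```

To execute the command, pass the list to `subprocess.run(args, check=True)`. Passing a list avoids shell quoting of the filter strings.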

In [15]:
# 1fps oversampled

import json  # used below to parse the model's JSON response
import vertexai
from vertexai.generative_models import GenerativeModel, Part, SafetySetting

# Basic prompt, run against the 1fps oversampled video

def extract_key_moments_1fps_oversampled():
    vertexai.init(project="<project-id>", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="<gcs-video-file-path>/LIVE_97161119_TEN_MINUTES_1fps_oversampled.mp4",
    )

    textsi_1 = """
    Your job is to extract segments from the sport video.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    If you are unsure of a segment's play_type, choose 'unsure' as play_type.
    You are allowed to skip over timestamps if they are non_plays.

    **Non-Play Segments:**
      - Non-play segments are advertisement breaks, promotional content, replays, and anything that is not a play.

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.
        - **play_type:** The type of play, which should be one of the following:
          ['rush', 'pass', 'punt', 'handoff', 'field_goal', 'penalty', 'extra_point', 'kickoff', 'kickoff_return', 'punt_return', 'conversion', 'non_play', 'touchdown', 'unsure']

    **Example Output Format:**
    [
      {
        "start_time": "01:15",
        "end_time": "01:45",
        "play_type": "rush"
      },
      {
        "start_time": "02:30",
        "end_time": "03:00",
        "play_type": "pass"
      },
      {
        "start_time": "10:20",
        "end_time": "10:40",
        "play_type": "non_play"
      }
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items": {
            "type": "OBJECT",
            "properties": {
                "start_time": {"type": "STRING"},
                "end_time": {"type": "STRING"},
                "play_type": {"type": "STRING"},
            },
            "required": ["start_time", "end_time", "play_type"],
        },
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1],
    )

    response = model.generate_content(
        ["""Extract moments from the video.""", video1],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response

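# Run the extraction 5 times to get a reasonable sample size.
# The oversampled video is twice as long as the original (setpts=2*PTS), so
# predictions are scored against ground_truth_slowed, which uses the doubled timeline.
result_list = []

for _ in range(5):
    response = extract_key_moments_1fps_oversampled()
    result = evaluate_play_extraction(json.loads(response.text), ground_truth_slowed)
    result_list.append(result)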

aggregated_results = aggregate_results(result_list)

# Print the aggregated results
print(f"Average Accuracy: {aggregated_results['average_accuracy']:.4f}")
print(f"Average Precision: {aggregated_results['average_precision']:.4f}")
print(f"Average Recall: {aggregated_results['average_recall']:.4f}")
print(f"Total True Positives: {aggregated_results['total_true_positives']}")
print(f"Total False Positives: {aggregated_results['total_false_positives']}")
print(f"Total False Negatives: {aggregated_results['total_false_negatives']}")

# Access the aggregated DataFrames
print("\nAggregated Correct Plays:")
print(aggregated_results['aggregated_correct_plays'])
print("\nAggregated Incorrect Plays:")
print(aggregated_results['aggregated_incorrect_plays'])

save_results_to_csv(aggregated_results=aggregated_results, filename_prefix="1fps_oversampled")

Average Accuracy: 0.1223
Average Precision: 0.1227
Average Recall: 0.9750
Total True Positives: 34
Total False Positives: 244
Total False Negatives: 1

Aggregated Correct Plays:
   LLM Start Time LLM End Time LLM Play Type GT Start Time GT End Time  \
0           08:27        08:47       penalty         08:20       08:44   
1           11:06        11:24          pass         11:08       11:26   
2           13:26        13:47       penalty         13:36       13:52   
3           15:34        15:57       penalty         15:36       15:58   
4           18:43        19:01          pass         18:50       19:14   
5           19:01        19:13          pass         18:50       19:14   
6           20:25        20:36          pass         20:20       20:38   
7           09:51        10:13          pass         09:52       10:16   
8           11:06        11:23          pass         11:08       11:26   
9           13:26        13:47       penalty         13:36       13:52   
10      
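The summary numbers show the trade-off that oversampling caused. Recall climbed to 0.975 while precision collapsed because of 244 false positives. The sketch below recomputes the metrics from the pooled totals. It assumes that the "slightly modified" accuracy is TP / (TP + FP + FN), which is the usual (TP + TN) / total with TN dropped. `play_metrics` is a hypothetical helper, and the notebook's own formula may differ.

```python
def play_metrics(tp: int, fp: int, fn: int) -> dict:
    """Precision, recall, and a TN-free accuracy computed from raw counts.

    ASSUMPTION: accuracy = TP / (TP + FP + FN), since true negatives are not tracked.
    """
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    accuracy = tp / (tp + fp + fn) if tp + fp + fn else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall}


# Pooled totals reported above for the 5 runs on the 1fps oversampled video
for name, value in play_metrics(tp=34, fp=244, fn=1).items():
    print(f"{name}: {value:.4f}")
```

This prints accuracy 0.1219, precision 0.1223 and recall 0.9714. These pooled values differ slightly from the reported figures (0.1223, 0.1227, 0.9750) because the notebook averages the metrics per run instead of pooling the counts.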

# Conclusion

None of the variations improved on the baseline accuracy of roughly 40-50%.

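Interestingly, the more we added to the prompt, the worse the model performed.
Video modification likewise failed to help and in fact degraded model performance significantly. On the 1fps oversampled video, recall rose to 0.975, but the model produced 244 false positives across the 5 runs and average accuracy fell to 0.1223.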