1. Install the Vertex AI SDK: Open a terminal window and enter the command below. You can also [install it in a virtualenv](https://googleapis.dev/python/aiplatform/latest/index.html)

In [1]:
!pip install --upgrade google-cloud-aiplatform openpyxl



In [16]:
# Define ground truth from 10 minute segment

# ground_truth_old contains overlapping segments, meaning each start timestamp overlaps with the previous end timestamp.
# Ideal output does overlap because there are replays in games and the ideal output does not contain replays.
ground_truth_old = [
{"end_time": "00:05", "play_type": "non_play", "start_time": "00:00"},
{"end_time": "00:31", "play_type": "non_play", "start_time": "00:05"},
{"end_time": "00:42", "play_type": "non_play", "start_time": "00:31"},
{"end_time": "00:46", "play_type": "non_play", "start_time": "00:42"},
{"end_time": "01:16", "play_type": "non_play", "start_time": "00:46"},
{"end_time": "01:46", "play_type": "non_play", "start_time": "01:16"},
{"end_time": "02:16", "play_type": "non_play", "start_time": "01:46"},
{"end_time": "02:29", "play_type": "non_play", "start_time": "02:16"},
{"end_time": "02:55", "play_type": "non_play", "start_time": "02:29"},
{"end_time": "03:32", "play_type": "non_play", "start_time": "02:55"},
{"end_time": "03:53", "play_type": "kickoff", "start_time": "03:32"},
{"end_time": "03:58", "play_type": "kickoff_return", "start_time": "03:53"},
{"end_time": "04:56", "play_type": "penalty", "start_time": "03:58"},
{"end_time": "05:34", "play_type": "pass", "start_time": "04:56"},
{"end_time": "06:15", "play_type": "pass", "start_time": "05:34"},
{"end_time": "06:48", "play_type": "pass", "start_time": "06:15"},
{"end_time": "07:20", "play_type": "penalty", "start_time": "06:48"},
{"end_time": "07:29", "play_type": "punt", "start_time": "07:20"},
{"end_time": "07:35", "play_type": "punt_return", "start_time": "07:29"},
{"end_time": "07:59", "play_type": "penalty", "start_time": "07:35"},
{"end_time": "08:49", "play_type": "non_play", "start_time": "07:59"},
{"end_time": "09:30", "play_type": "handoff", "start_time": "08:49"},
{"end_time": "10:10", "play_type": "pass", "start_time": "09:30"},
{"end_time": "10:37", "play_type": "pass", "start_time": "10:10"}
]

ground_truth = [
{"end_time": "00:05", "play_type": "non_play", "start_time": "00:00"},
{"end_time": "00:31", "play_type": "non_play", "start_time": "00:05"},
{"end_time": "00:42", "play_type": "non_play", "start_time": "00:31"},
{"end_time": "00:46", "play_type": "non_play", "start_time": "00:42"},
{"end_time": "01:16", "play_type": "non_play", "start_time": "00:46"},
{"end_time": "01:46", "play_type": "non_play", "start_time": "01:16"},
{"end_time": "02:16", "play_type": "non_play", "start_time": "01:46"},
{"end_time": "02:29", "play_type": "non_play", "start_time": "02:16"},
{"end_time": "02:55", "play_type": "non_play", "start_time": "02:29"},
{"end_time": "03:32", "play_type": "non_play", "start_time": "02:55"},
{"end_time": "03:53", "play_type": "kickoff", "start_time": "03:32"},
{"end_time": "03:58", "play_type": "kickoff_return", "start_time": "03:53"},
{"end_time": "04:22", "play_type": "penalty", "start_time": "04:10"},
{"end_time": "05:08", "play_type": "pass", "start_time": "04:56"},
{"end_time": "05:43", "play_type": "pass", "start_time": "05:34"},
{"end_time": "06:26", "play_type": "pass", "start_time": "06:09"},
{"end_time": "06:56", "play_type": "penalty", "start_time": "06:48"},
{"end_time": "07:29", "play_type": "punt", "start_time": "07:20"},
{"end_time": "07:35", "play_type": "punt_return", "start_time": "07:29"},
{"end_time": "07:59", "play_type": "penalty", "start_time": "07:48"},
{"end_time": "08:49", "play_type": "non_play", "start_time": "07:59"},
{"end_time": "08:59", "play_type": "handoff", "start_time": "08:47"}, # technically handoff is more specific but customer gt lists rush as well
{"end_time": "09:37", "play_type": "pass", "start_time": "09:25"},
{"end_time": "10:19", "play_type": "pass", "start_time": "10:10"}
]

In [22]:
IOU_SCORE = 0.3

In [21]:
import base64
import vertexai
from vertexai.generative_models import GenerativeModel, Part, SafetySetting

# Basic test

def extract_key_moments_basic():
    vertexai.init(project="cloud-llm-preview1", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="gs://lukasgeiger-fubo-demo-data/LIVE_97161119_TEN_MINUTES.mp4",
    )

    textsi_1 = """
    Your job is to extract segments from the sport video.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    If you are unsure of a segment's play_type, choose 'unsure' as play_type.
    You are allowed to skip over timestamps if they are non_plays.

    **Non-Play Segments:**
      - Non-play segments are advertisements breaks, promotional content, replays, and anything that is not a play.

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.
        - **play_type:** The type of play, which should be one of the following:
          ['rush', 'pass', 'punt', 'handoff', 'field_goal', 'penalty', 'extra_point', 'kickoff', 'kickoff_return', 'punt_return', 'conversion', 'non_play', 'touchdown', 'unsure']

    **Example Output Format:**
    [
      {
        "start_time": "01:15",
        "end_time": "01:45",
        "play_type": "rush"
      },
      {
        "start_time": "02:30",
        "end_time": "03:00",
        "play_type": "pass"
      },
      {
        "start_time": "10:20",
        "end_time": "10:40",
        "play_type": "non_play"
      }
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items":{
            "type": "OBJECT",
            "properties": {
                "start_time": {
                  "type": "STRING",
                },
                "end_time": {
                  "type": "STRING",
                },
                "play_type": {
                  "type": "STRING",
                }
              },
              "required": [
                "start_time",
                "end_time",
                "play_type"
              ],
        }
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1]
    )

    response = model.generate_content(
        [video1, """Extract moments from the video."""],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response

In [25]:
import pandas as pd
import json, datetime

def time_to_seconds(time_str):
    """Converts a time string in the format MM:SS to seconds."""
    if not time_str:
        return 0
    minutes, seconds = map(int, time_str.split(':'))
    return minutes * 60 + seconds

def calculate_iou(start_time1, end_time1, start_time2, end_time2):
    """Calculates the Intersection over Union (IoU) of two time intervals."""
    # IoU = (Intersection of the two intervals) / (Union of the two intervals)
    # An IoU of 0 means there's no overlap at all.
    # An IoU of 1 means the two intervals are identical.

    start1 = time_to_seconds(start_time1)
    end1 = time_to_seconds(end_time1)
    start2 = time_to_seconds(start_time2)
    end2 = time_to_seconds(end_time2)

    # Determine the intersection
    intersection_start = max(start1, start2)
    intersection_end = min(end1, end2)

    if intersection_end <= intersection_start:
        return 0.0  # No overlap

    intersection = intersection_end - intersection_start

    # Determine the union
    union = (end1 - start1) + (end2 - start2) - intersection

    return intersection / union if union > 0 else 0.0

def evaluate_play_extraction(llm_response, grounded_truth):
    """
    Evaluates the performance of an LLM in extracting plays from American football videos.

    Args:
        llm_response: A list of dictionaries, where each dictionary represents a play
                      extracted by the LLM and contains 'start_time', 'end_time', and 'play_type'.
        grounded_truth: A list of dictionaries, where each dictionary represents the
                       ground truth play and contains 'start_time', 'end_time', and 'play_type'.

    Returns:
        A dictionary containing the accuracy, precision, recall, and tables of correctly
        and incorrectly identified plays.

    Note:
        Can be an incorrect play for two reasons:
          1. IOU is out of bounds
          2. play_type was incorrectly indentified
    """

    llm_response_filtered = [play for play in llm_response if play['play_type'] != "non_play"]
    grounded_truth_filtered = [play for play in grounded_truth if play['play_type'] != "non_play"]


    true_positives = 0
    false_positives = 0
    false_negatives = 0

    correct_plays = []
    incorrect_plays = []

    matched_gt = [False] * len(grounded_truth)

    for llm_play in llm_response_filtered:
        best_iou = 0
        best_gt_match = -1

        for i, gt_play in enumerate(grounded_truth_filtered):
            iou = calculate_iou(llm_play['start_time'], llm_play['end_time'], gt_play['start_time'], gt_play['end_time'])

            if iou > best_iou:
                best_iou = iou
                best_gt_match = i

        if best_gt_match != -1 and best_iou >= IOU_SCORE:
            if llm_play['play_type'] == grounded_truth_filtered[best_gt_match]['play_type']:
                true_positives += 1
                correct_plays.append({
                    'LLM Start Time': llm_play['start_time'],
                    'LLM End Time': llm_play['end_time'],
                    'LLM Play Type': llm_play['play_type'],
                    'GT Start Time': grounded_truth_filtered[best_gt_match]['start_time'],
                    'GT End Time': grounded_truth_filtered[best_gt_match]['end_time'],
                    'GT Play Type': grounded_truth_filtered[best_gt_match]['play_type'],
                    'IoU': best_iou
                })
            else:
                false_positives += 1
                incorrect_plays.append({
                    'LLM Start Time': llm_play['start_time'],
                    'LLM End Time': llm_play['end_time'],
                    'LLM Play Type': llm_play['play_type'],
                    'GT Start Time': grounded_truth_filtered[best_gt_match]['start_time'],
                    'GT End Time': grounded_truth_filtered[best_gt_match]['end_time'],
                    'GT Play Type': grounded_truth_filtered[best_gt_match]['play_type'],
                    'IoU': best_iou
                })
            matched_gt[best_gt_match] = True

        else:
            false_positives +=1
            incorrect_plays.append({
                    'LLM Start Time': llm_play['start_time'],
                    'LLM End Time': llm_play['end_time'],
                    'LLM Play Type': llm_play['play_type'],
                    'GT Start Time': 'N/A',
                    'GT End Time': 'N/A',
                    'GT Play Type': 'N/A',
                    'IoU': 0
                })

    for i in range(len(grounded_truth_filtered)):
      if not matched_gt[i]:
        false_negatives += 1
        incorrect_plays.append({
            'LLM Start Time': 'N/A',
            'LLM End Time': 'N/A',
            'LLM Play Type': 'N/A',
            'GT Start Time': grounded_truth_filtered[i]['start_time'],
            'GT End Time': grounded_truth_filtered[i]['end_time'],
            'GT Play Type': grounded_truth_filtered[i]['play_type'],
            'IoU': 0
        })

    accuracy = (true_positives) / (len(llm_response_filtered) + len(grounded_truth_filtered) - true_positives) if (len(llm_response_filtered) + len(grounded_truth_filtered) - true_positives) > 0 else 0.0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0.0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0.0

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'true_positives': true_positives,
        'false_positives': false_positives,
        'false_negatives': false_negatives,
        'correct_plays': pd.DataFrame(correct_plays),
        'incorrect_plays': pd.DataFrame(incorrect_plays)
    }

def aggregate_results(all_results):
    """
    Aggregates the results from multiple evaluations, including the DataFrames.

    Args:
        all_results: A list of dictionaries, where each dictionary is the output
                     from the evaluate_play_extraction function.

    Returns:
        A dictionary containing the average accuracy, precision, recall,
        total true positives, false positives, false negatives,
        and concatenated DataFrames of correct and incorrect plays.
    """

    num_runs = len(all_results)
    total_accuracy = 0
    total_precision = 0
    total_recall = 0
    total_true_positives = 0
    total_false_positives = 0
    total_false_negatives = 0

    all_correct_plays = []
    all_incorrect_plays = []

    for result in all_results:
        total_accuracy += result['accuracy']
        total_precision += result['precision']
        total_recall += result['recall']
        total_true_positives += result['true_positives']
        total_false_positives += result['false_positives']
        total_false_negatives += result['false_negatives']
        all_correct_plays.append(result['correct_plays'])
        all_incorrect_plays.append(result['incorrect_plays'])

    # Concatenate all the correct and incorrect plays DataFrames
    aggregated_correct_plays = pd.concat(all_correct_plays, ignore_index=True)
    aggregated_incorrect_plays = pd.concat(all_incorrect_plays, ignore_index=True)

    return {
        'average_accuracy': total_accuracy / num_runs,
        'average_precision': total_precision / num_runs,
        'average_recall': total_recall / num_runs,
        'total_true_positives': total_true_positives,
        'total_false_positives': total_false_positives,
        'total_false_negatives': total_false_negatives,
        'aggregated_correct_plays': aggregated_correct_plays,
        'aggregated_incorrect_plays': aggregated_incorrect_plays
    }

def save_results_to_csv(aggregated_results, filename_prefix="aggregate_results"):
    """
    Saves the aggregated results to a CSV file with a timestamp in the filename.

    Args:
        aggregate_results: The output dictionary from the aggregate_results function.
        filename_prefix: The prefix of the CSV file name.
    """

    # Get the current timestamp
    now = datetime.datetime.now()
    timestamp_str = now.strftime("%Y%m%d_%H%M%S")

    # Construct the filename with the timestamp
    filename = f"{filename_prefix}_{timestamp_str}.xlsx"

    # Create a DataFrame for the summary statistics
    summary_data = {
        'Metric': ['Average Accuracy', 'Average Precision', 'Average Recall',
                   'Total True Positives', 'Total False Positives', 'Total False Negatives'],
        'Value': [aggregated_results['average_accuracy'], aggregated_results['average_precision'],
                  aggregated_results['average_recall'], aggregated_results['total_true_positives'],
                  aggregated_results['total_false_positives'], aggregated_results['total_false_negatives']]
    }
    summary_df = pd.DataFrame(summary_data)

    # Write the DataFrames to the CSV file
    with pd.ExcelWriter(filename) as writer:
        summary_df.to_excel(writer, sheet_name='Summary', index=False)
        aggregated_results['aggregated_correct_plays'].to_excel(writer, sheet_name='Correct Plays', index=False)
        aggregated_results['aggregated_incorrect_plays'].to_excel(writer, sheet_name='Incorrect Plays', index=False)

In [29]:
# Conduct test multiple (5) times and average results
# First will be plain prompt

result_list = []

for i in range(5):
  response = extract_key_moments_basic()
  result = evaluate_play_extraction(json.loads(response.text), ground_truth)
  result_list.append(result)

aggregated_results = aggregate_results(result_list)

# Print the aggregated results
print(f"Average Accuracy: {aggregated_results['average_accuracy']:.4f}")
print(f"Average Precision: {aggregated_results['average_precision']:.4f}")
print(f"Average Recall: {aggregated_results['average_recall']:.4f}")
print(f"Total True Positives: {aggregated_results['total_true_positives']}")
print(f"Total False Positives: {aggregated_results['total_false_positives']}")
print(f"Total False Negatives: {aggregated_results['total_false_negatives']}")

# Access the aggregated DataFrames
print("\nAggregated Correct Plays:")
print(aggregated_results['aggregated_correct_plays'])
print("\nAggregated Incorrect Plays:")
print(aggregated_results['aggregated_incorrect_plays'])

save_results_to_csv(aggregated_results=aggregated_results, filename_prefix="basic")

Average Accuracy: 0.4913
Average Precision: 0.7031
Average Recall: 0.6788
Total True Positives: 40
Total False Positives: 17
Total False Negatives: 19

Aggregated Correct Plays:
   LLM Start Time LLM End Time   LLM Play Type GT Start Time GT End Time  \
0           03:49        03:59  kickoff_return         03:53       03:58   
1           05:00        05:07            pass         04:56       05:08   
2           05:35        05:45            pass         05:34       05:43   
3           06:17        06:24            pass         06:09       06:26   
4           06:47        07:00         penalty         06:48       06:56   
5           07:24        07:37     punt_return         07:29       07:35   
6           09:30        09:36            pass         09:25       09:37   
7           10:09        10:18            pass         10:10       10:19   
8           03:49        04:05  kickoff_return         03:53       03:58   
9           05:00        05:07            pass         04:56  

In [28]:
# Second will be same prompt but with NFL rule book attached

def extract_key_moments_nfl_playbook():
    vertexai.init(project="cloud-llm-preview1", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="gs://lukasgeiger-fubo-demo-data/LIVE_97161119_TEN_MINUTES.mp4",
    )
    rule_book = Part.from_uri(
          uri="gs://lukasgeiger-fubo-demo-data/2024-nfl-rulebook.pdf",
          mime_type="application/pdf",
    )

    textsi_1 = """
    Your job is to extract segments from the sport video.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    If you are unsure of a segment's play_type, choose 'unsure' as play_type.
    You are allowed to skip over timestamps if they are non_plays.

    Use the attached rule book to understand what play_types mean.

    **Non-Play Segments:**
      - Non-play segments are advertisements breaks, promotional content, replays, and anything that is not a play.

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.
        - **play_type:** The type of play as per the external data, which should be one of the following:
          ['rush', 'pass', 'punt', 'handoff', 'field_goal', 'penalty', 'extra_point', 'kickoff', 'kickoff_return', 'punt_return', 'conversion', 'non_play', 'touchdown', 'unsure']

    **Example Output Format:**
    [
      {
        "start_time": "01:15",
        "end_time": "01:45",
        "play_type": "rush"
      },
      {
        "start_time": "02:30",
        "end_time": "03:00",
        "play_type": "pass"
      },
      {
        "start_time": "10:20",
        "end_time": "10:40",
        "play_type": "non_play"
      }
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items":{
            "type": "OBJECT",
            "properties": {
                "start_time": {
                  "type": "STRING",
                },
                "end_time": {
                  "type": "STRING",
                },
                "play_type": {
                  "type": "STRING",
                }
              },
              "required": [
                "start_time",
                "end_time",
                "play_type"
              ],
        }
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1]
    )

    response = model.generate_content(
        [video1, rule_book, """Extract moments from the video."""],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response

result_list = []

for i in range(5):
  response = extract_key_moments_nfl_playbook()
  result = evaluate_play_extraction(json.loads(response.text), ground_truth)
  result_list.append(result)

aggregated_results = aggregate_results(result_list)

# Print the aggregated results
print(f"Average Accuracy: {aggregated_results['average_accuracy']:.4f}")
print(f"Average Precision: {aggregated_results['average_precision']:.4f}")
print(f"Average Recall: {aggregated_results['average_recall']:.4f}")
print(f"Total True Positives: {aggregated_results['total_true_positives']}")
print(f"Total False Positives: {aggregated_results['total_false_positives']}")
print(f"Total False Negatives: {aggregated_results['total_false_negatives']}")

# Access the aggregated DataFrames
print("\nAggregated Correct Plays:")
print(aggregated_results['aggregated_correct_plays'])
print("\nAggregated Incorrect Plays:")
print(aggregated_results['aggregated_incorrect_plays'])

save_results_to_csv(aggregated_results=aggregated_results, filename_prefix="nfl_playbook")


Average Accuracy: 0.3734
Average Precision: 0.5061
Average Recall: 0.6691
Total True Positives: 38
Total False Positives: 40
Total False Negatives: 19

Aggregated Correct Plays:
   LLM Start Time LLM End Time   LLM Play Type GT Start Time GT End Time  \
0           03:50        03:57  kickoff_return         03:53       03:58   
1           04:55        05:01            pass         04:56       05:08   
2           05:01        05:07            pass         04:56       05:08   
3           06:18        06:24            pass         06:09       06:26   
4           06:47        06:53         penalty         06:48       06:56   
5           07:29        07:37     punt_return         07:29       07:35   
6           09:30        09:36            pass         09:25       09:37   
7           10:09        10:17            pass         10:10       10:19   
8           03:48        03:59  kickoff_return         03:53       03:58   
9           05:00        05:07            pass         04:56  

In [30]:
# Third, create multi-step prompt. First extract time stamps then analyze plays.

def extract_timestamps():
    vertexai.init(project="cloud-llm-preview1", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="gs://lukasgeiger-fubo-demo-data/LIVE_97161119_TEN_MINUTES.mp4",
    )

    textsi_1 = """
    Your job is to split the video into start and end timestamps by plays.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    A play is when the ball is in play.
    If the video contains non play time just skip the times stamps.
    Replays are considered non play time.

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.

    **Example Output Format:**
    [
      {
        "start_time": "01:15",
        "end_time": "01:45"
      },
      {
        "start_time": "02:30",
        "end_time": "03:00"
      },
      {
        "start_time": "10:20",
        "end_time": "10:40"
      }
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items":{
            "type": "OBJECT",
            "properties": {
                "start_time": {
                  "type": "STRING",
                },
                "end_time": {
                  "type": "STRING",
                }
              },
              "required": [
                "start_time",
                "end_time"
              ],
        }
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1]
    )

    response = model.generate_content(
        [video1, """Extract timestamps from the video."""],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response

def identify_plays(timestamps_json):
    vertexai.init(project="cloud-llm-preview1", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="gs://lukasgeiger-fubo-demo-data/LIVE_97161119_TEN_MINUTES.mp4",
    )

    textsi_1 = f"""
    Your job is to assign play_types from the sport video and timestamp file.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    If you are unsure of a segment's play_type, choose 'unsure' as play_type.

    Do not modify the given timestamps.

    **Timestamps:**
    {timestamps_json}

    **Non-Play Segments:**
      - Non-play segments are advertisements breaks, promotional content, replays, and anything that is not a play.

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.
        - **play_type:** The type of play which should be one of the following:
          ['rush', 'pass', 'punt', 'handoff', 'field_goal', 'penalty', 'extra_point', 'kickoff', 'kickoff_return', 'punt_return', 'conversion', 'non_play', 'touchdown', 'unsure']

    **Example Output Format:**
    [
      {{
        "start_time": "01:15",
        "end_time": "01:45",
        "play_type": "rush"
      }},
      {{
        "start_time": "02:30",
        "end_time": "03:00",
        "play_type": "pass"
      }},
      {{
        "start_time": "10:20",
        "end_time": "10:40",
        "play_type": "non_play"
      }}
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items":{
            "type": "OBJECT",
            "properties": {
                "start_time": {
                  "type": "STRING",
                },
                "end_time": {
                  "type": "STRING",
                },
                "play_type": {
                  "type": "STRING",
                }
              },
              "required": [
                "start_time",
                "end_time",
                "play_type"
              ],
        }
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1]
    )

    response = model.generate_content(
        [video1, """Extract play types from the video."""],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response

result_list = []

for i in range(5):
  timestamps = extract_timestamps()
  timestamps_json = json.loads(timestamps.text)
  print(timestamps_json)
  response = identify_plays(timestamps_json)
  result = evaluate_play_extraction(json.loads(response.text), ground_truth)
  result_list.append(result)
  print(result)

aggregated_results = aggregate_results(result_list)

# Print the aggregated results
print(f"Average Accuracy: {aggregated_results['average_accuracy']:.4f}")
print(f"Average Precision: {aggregated_results['average_precision']:.4f}")
print(f"Average Recall: {aggregated_results['average_recall']:.4f}")
print(f"Total True Positives: {aggregated_results['total_true_positives']}")
print(f"Total False Positives: {aggregated_results['total_false_positives']}")
print(f"Total False Negatives: {aggregated_results['total_false_negatives']}")

# Access the aggregated DataFrames
print("\nAggregated Correct Plays:")
print(aggregated_results['aggregated_correct_plays'])
print("\nAggregated Incorrect Plays:")
print(aggregated_results['aggregated_incorrect_plays'])

save_results_to_csv(aggregated_results=aggregated_results, filename_prefix="extract_timestampes_then_id_plays")


[{'end_time': '04:05', 'start_time': '03:49'}, {'end_time': '05:07', 'start_time': '05:00'}, {'end_time': '05:18', 'start_time': '05:10'}, {'end_time': '05:43', 'start_time': '05:35'}, {'end_time': '06:24', 'start_time': '06:17'}, {'end_time': '07:16', 'start_time': '07:08'}, {'end_time': '07:37', 'start_time': '07:29'}, {'end_time': '07:46', 'start_time': '07:37'}, {'end_time': '08:17', 'start_time': '08:08'}, {'end_time': '08:20', 'start_time': '08:17'}, {'end_time': '08:58', 'start_time': '08:44'}, {'end_time': '09:16', 'start_time': '09:00'}, {'end_time': '09:36', 'start_time': '09:22'}, {'end_time': '09:47', 'start_time': '09:36'}, {'end_time': '10:00', 'start_time': '10:00'}, {'end_time': '10:19', 'start_time': '10:10'}, {'end_time': '10:21', 'start_time': '10:19'}, {'end_time': '10:30', 'start_time': '10:21'}]
{'accuracy': 0.3333333333333333, 'precision': 0.4666666666666667, 'recall': 0.5833333333333334, 'true_positives': 7, 'false_positives': 8, 'false_negatives': 5, 'correct_p

In [31]:
# Describe each of the play types

def extract_key_moments_plus_description():
    vertexai.init(project="cloud-llm-preview1", location="us-central1")

    video1 = Part.from_uri(
        mime_type="video/mp4",
        uri="gs://lukasgeiger-fubo-demo-data/LIVE_97161119_TEN_MINUTES.mp4",
    )

    textsi_1 = """
    Your job is to extract segments from the sport video.
    Analyze the entire video.
    Be as precise and accurate as possible.

    Sport Type: American Football

    If you are unsure of a segment's play_type, choose 'unsure' as play_type.
    You are allowed to skip over timestamps if they are non_plays.

    **Play_types and their decriptions:**
      - rush: A running play where a player (usually the running back) carries the ball forward to gain yardage.
      - pass: The quarterback throws the ball forward to a receiver downfield.
      - handoff: The quarterback hands the ball off to another player (usually the running back) to start a running play.
      - field_goal: A kick attempt where the team tries to score 3 points by kicking the ball through the upright goal posts.
      - punt: When a team kicks the ball to the opposing team to give up possession (usually on fourth down).
      - extra_point: A kick attempt after a touchdown to score 1 additional point.
      - kickoff: The kick that starts each half and follows every touchdown and successful field goal.
      - kickoff_return: The receiving team attempts to advance the ball after a kickoff.
      - punt_return: The receiving team attempts to advance the ball after a punt.
      - penalty: An infraction of the rules, resulting in a yardage penalty against the offending team. Look for a yellow flag and a referee in a striped shirt announcing the penalty.
      - conversion: An attempt to score 2 points after a touchdown by running or passing the ball into the end zone.
      - non_play: This refers to moments when the ball is not in play, such as timeouts, commercial breaks, between quarters, and replays.
      - touchdown: A scoring play worth 6 points, achieved by carrying or catching the ball in the opponent's end zone.
      - unsure: This would be used when the play type cannot be confidently determined

    **Format the Output:**
    The output should be a list of JSON objects where each object represents a key moment, with the following fields:
        - **start_time:** Timecode (MM:SS format) indicating when the event begins in the video.
        - **end_time:** Timecode (MM:SS format) indicating when the event ends in the video.
        - **play_type:** The type of play as per the explanations above, which should be one of the following:
          ['rush', 'pass', 'punt', 'handoff', 'field_goal', 'penalty', 'extra_point', 'kickoff', 'kickoff_return', 'punt_return', 'conversion', 'non_play', 'touchdown', 'unsure']

    **Example Output Format:**
    [
      {
        "start_time": "01:15",
        "end_time": "01:45",
        "play_type": "rush"
      },
      {
        "start_time": "02:30",
        "end_time": "03:00",
        "play_type": "pass"
      },
      {
        "start_time": "10:20",
        "end_time": "10:40",
        "play_type": "non_play"
      }
    ]"""

    response_schema = {
        "type": "ARRAY",
        "items":{
            "type": "OBJECT",
            "properties": {
                "start_time": {
                  "type": "STRING",
                },
                "end_time": {
                  "type": "STRING",
                },
                "play_type": {
                  "type": "STRING",
                }
              },
              "required": [
                "start_time",
                "end_time",
                "play_type"
              ],
        }
    }

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
        "response_mime_type": "application/json",
        "response_schema": response_schema,
    }

    safety_settings = [
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=SafetySetting.HarmBlockThreshold.OFF
        ),
    ]

    model = GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=[textsi_1]
    )

    response = model.generate_content(
        [video1, """Extract plays from the video."""],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    return response

result_list = []

for i in range(5):
  response = extract_key_moments_plus_description()
  result = evaluate_play_extraction(json.loads(response.text), ground_truth)
  result_list.append(result)

aggregated_results = aggregate_results(result_list)

# Print the aggregated results
print(f"Average Accuracy: {aggregated_results['average_accuracy']:.4f}")
print(f"Average Precision: {aggregated_results['average_precision']:.4f}")
print(f"Average Recall: {aggregated_results['average_recall']:.4f}")
print(f"Total True Positives: {aggregated_results['total_true_positives']}")
print(f"Total False Positives: {aggregated_results['total_false_positives']}")
print(f"Total False Negatives: {aggregated_results['total_false_negatives']}")

# Access the aggregated DataFrames
print("\nAggregated Correct Plays:")
print(aggregated_results['aggregated_correct_plays'])
print("\nAggregated Incorrect Plays:")
print(aggregated_results['aggregated_incorrect_plays'])

save_results_to_csv(aggregated_results=aggregated_results,filename_prefix="extract_key_moments_plus_description")


Average Accuracy: 0.4101
Average Precision: 0.5095
Average Recall: 0.7038
Total True Positives: 43
Total False Positives: 43
Total False Negatives: 18

Aggregated Correct Plays:
   LLM Start Time LLM End Time   LLM Play Type GT Start Time GT End Time  \
0           03:50        04:05  kickoff_return         03:53       03:58   
1           04:05        04:22         penalty         04:10       04:22   
2           04:56        05:00            pass         04:56       05:08   
3           05:00        05:07            pass         04:56       05:08   
4           06:18        06:24            pass         06:09       06:26   
5           06:47        06:53         penalty         06:48       06:56   
6           07:23        07:37     punt_return         07:29       07:35   
7           07:37        07:55         penalty         07:48       07:59   
8           09:22        09:36            pass         09:25       09:37   
9           10:01        10:18            pass         10:10  

In [None]:
# Add confidence thresholds
