Try out Picklebot, an AI umpire trained to call balls and strikes! This interactive demo lets you pick a pitch and see Picklebot's call on real MLB game footage.

To use it, simply select a pitch from the dropdown and run the code cell. Picklebot will analyze the video clip and display the ball/strike call.

Under the hood, Picklebot is a deep learning model trained on over 50,000 labeled pitches from MLB games. It adapts efficient mobile architectures: MobileNetV3, reconfigured to process video clips rather than still images, and MoViNet, a family of networks designed for video on mobile devices. Applied to pitch classification, these architectures call balls and strikes with 80% accuracy.
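As a rough illustration of what "reconfigured to process video" involves: a 2D convolution in an image network slides its kernel over height and width, while a 3D convolution also slides over time (frames). The sketch below applies PyTorch's documented convolution output-size formula to each dimension. The clip length, frame size, kernel, stride, and padding values are illustrative assumptions, not Picklebot's actual configuration.

```python
def conv_output_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of a convolution along one dimension (PyTorch convention)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_output_shape(shape, kernel, stride, padding):
    """Apply the 1D formula independently to every convolved dimension."""
    return tuple(
        conv_output_size(s, k, st, p)
        for s, k, st, p in zip(shape, kernel, stride, padding)
    )


# 2D conv on a single 224x224 frame: 3x3 kernel, stride 2, padding 1
image_out = conv_output_shape((224, 224), (3, 3), (2, 2), (1, 1))

# 3D conv on a 16-frame clip: the kernel also spans 3 frames in time,
# with stride 1 in time so the frame count is preserved
video_out = conv_output_shape((16, 224, 224), (3, 3, 3), (1, 2, 2), (1, 1, 1))

print(image_out)  # (112, 112)
print(video_out)  # (16, 112, 112)
```

The spatial dimensions shrink exactly as in the image network; the extra time axis is what lets the model see the ball's motion across frames.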

The training dataset was created by downloading archival MLB footage and manually labeling each pitch as a ball or a strike.

This demo lets you evaluate Picklebot's pitch-calling abilities first-hand. Check out the code and model details to learn more about how AI can replicate human umpire decisions!

In [None]:
#collapse

!git clone -q https://github.com/hbfreed/Picklebot
!pip install -q av mediapy  # av is the PyAV backend that torchvision.io.read_video needs

import sys

import mediapy
import torch
from IPython.display import HTML
from torchvision.io import read_video

# make the cloned repo importable so mobilenet.py (and anything it imports) resolves
sys.path.append('Picklebot')
from mobilenet import MobileNetSmall3D

#@title Run this cell (press the 'play' button) to run Picklebot for your choice of video

def classify_pitch(scores):
    # the model outputs one score per class: index 0 = Ball, index 1 = Strike
    labels = ['Ball', 'Strike']
    index = torch.argmax(scores, dim=-1).item()
    if index >= len(labels):
        raise ValueError(f"that's odd, unexpected class index {index}")
    return labels[index]


ball_video = 'Picklebot/demo_files/clip_7765.mp4'
strike_video = 'Picklebot/demo_files/clip_53102.mp4'

pitch_choice = 'Ball' #@param ['Ball', 'Strike']

choice_map = {
    'Ball':ball_video,
    'Strike':strike_video,
}

pitch = choice_map[pitch_choice]

video = mediapy.read_video(pitch)
mediapy.show_video(video,width=600)

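# read_video returns frames as a uint8 tensor of shape (T, H, W, C);
# reorder to (C, T, H, W), add a batch dimension, and scale pixels to [0, 1]
frames = read_video(pitch, pts_unit='sec')[0]
pitch_tensor = frames.permute(3, 0, 1, 2).unsqueeze(0) / 255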
# load the model and its trained weights (on CPU); the repo was cloned into ./Picklebot
model = MobileNetSmall3D()
model.load_state_dict(torch.load('Picklebot/models/mobilenet_small.pth', map_location=torch.device('cpu')))

# run the model in inference mode (no gradient tracking needed)
model.eval()
with torch.no_grad():
    output = model(pitch_tensor)
call = classify_pitch(output)

HTML(f"<h1>{call}!</h1>")
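`classify_pitch` takes an argmax over the model's two outputs, so it reports only the winning class. If those outputs are unnormalized logits (an assumption; the notebook does not say), a softmax turns them into probabilities, which would let the demo also report how confident the call is. A pure-Python sketch; `call_with_confidence` and the example scores are hypothetical, and only the Ball = 0, Strike = 1 ordering comes from the notebook.

```python
import math

LABELS = ['Ball', 'Strike']  # index 0 = Ball, 1 = Strike, matching classify_pitch


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def call_with_confidence(logits):
    """Return the predicted call and its softmax probability (hypothetical helper)."""
    probs = softmax(logits)
    index = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[index], probs[index]


call, confidence = call_with_confidence([0.3, 1.8])
print(f"{call} ({confidence:.0%} confident)")  # Strike (82% confident)
```

In the notebook itself, the equivalent would be `torch.softmax(output, dim=-1)` applied to the model's output tensor. The argmax, and therefore the call, is the same either way; softmax only adds a calibrated-looking score on top.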