In [1]:
from typing import Any, Dict, List

import os
import re
import json

import pandas as pd

from copy import deepcopy
from collections import defaultdict

In [2]:
graph = {}
with open(f'../../data/mlb/pbp/computes/team_event_graph.json', 'r', encoding='UTF8') as pbp_input:
    graph = json.load(pbp_input)

In [3]:
def compute_likelihoods(graph):
    graph_copy = deepcopy(graph)
    
    for team in graph_copy.keys():
        for out in graph_copy[team].keys():
            for state in graph_copy[team][out].keys():
                runs = graph_copy[team][out][state]['runs']
                total = sum(graph_copy[team][out][state]['types'].values())
                graph_copy[team][out][state] = 0.0 if runs == 0 else round(float(runs) / total, 3)
    
    return graph_copy

likelihood_graph = compute_likelihoods(graph)
likelihood_graph['MIN']

{'0': {'---': 0.028,
  '--3': 0.385,
  '-2-': 0.119,
  '-23': 0.565,
  '1--': 0.107,
  '1-3': 0.759,
  '12-': 0.198,
  '123': 0.95},
 '1': {'---': 0.025,
  '--3': 0.447,
  '-2-': 0.169,
  '-23': 0.688,
  '1--': 0.116,
  '1-3': 0.585,
  '12-': 0.288,
  '123': 0.765},
 '2': {'---': 0.029,
  '--3': 0.186,
  '-2-': 0.146,
  '-23': 0.328,
  '1--': 0.095,
  '1-3': 0.386,
  '12-': 0.214,
  '123': 0.63}}

In [4]:
def flatten_graph(graph):
    records = []
    for team in graph.keys():
        for outs in graph[team].keys():
            row = {
                'team': team,
                'outs': outs,
            }

            row.update(
                graph[team][outs]
            )

            records.append(row)

    return records

def flatten_full_graph(graph):
    records = []
    for team in graph.keys():
        for outs in graph[team].keys():
            for bases in graph[team][outs].keys():
                row = {
                    'team': team,
                    'outs': outs,
                    'bases': bases
                }

                row.update(
                    graph[team][outs][bases]['types']
                )

                records.append(row)

    return records

In [5]:
column_renames = [
    ('Bunt Groundout', 'Groundout'),
    ('Bunt Lineout', 'Lineout'),
    ('Foul Popfly', 'Flyball'),
    ('Foul Bunt Popfly', 'Flyball'),
    ('Bunt Popfly', 'Flyball'),
    ('Popfly', 'Flyball'),
    ('Ground-rule Double', 'Double'),
    ('Intentional Walk', 'Walk'),
    ('Inside-the-park Home Run', 'Home Run'),
]

In [6]:
df = pd.DataFrame(flatten_full_graph(graph))

for f, t in column_renames:
    df[t] += df[f]
    df = df.drop(columns=[f])

df = df.sort_values(['team', 'outs', 'bases'])
df

Unnamed: 0,team,outs,bases,Double,Error,Flyball,Groundout,Hit By Pitch,Home Run,Lineout,...,Double Play,Interference,Steals,Picked off,Bunt Interference,Interference by Batter,Out Advancing,Wild Pitch,Balk,Foul Interference
0,ARI,0,---,,15.0,,338.0,18.0,,79.0,...,,,,,,,,,,
1,ARI,0,--3,,,,,,,,...,,,,,,,,,,
2,ARI,0,-2-,,2.0,,26.0,,,,...,,,,,,,,,,
3,ARI,0,-23,,1.0,,,1.0,,,...,,,,,,,,,,
4,ARI,0,1--,,1.0,,61.0,3.0,,,...,26.0,2.0,2.0,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
715,WSN,2,-23,,,,,1.0,,,...,,,,,,,,,,
716,WSN,2,1--,,5.0,,,4.0,,,...,,2.0,,2.0,,,,,,
717,WSN,2,1-3,,,,,,,,...,,,,,,,,,,
718,WSN,2,12-,,2.0,,,1.0,,,...,,,,,,,,,,


In [7]:
df.columns

Index(['team', 'outs', 'bases', 'Double', 'Error', 'Flyball', 'Groundout',
       'Hit By Pitch', 'Home Run', 'Lineout', 'Single', 'Strikeout', 'Triple',
       'Walk', 'Fielder's Choice', 'Caught Stealing', 'Double Play',
       'Interference', 'Steals', 'Picked off', 'Bunt Interference',
       'Interference by Batter', 'Out Advancing', 'Wild Pitch', 'Balk',
       'Foul Interference'],
      dtype='object')

In [8]:
df = pd.DataFrame(flatten_graph(likelihood_graph)).sort_values(['team', 'outs']).fillna('-')
df

Unnamed: 0,team,outs,---,--3,-2-,-23,1--,1-3,12-,123
0,ARI,0,0.026,0.350,0.123,0.583,0.071,0.615,0.197,0.793
1,ARI,1,0.039,0.442,0.140,0.727,0.078,0.609,0.207,1.055
2,ARI,2,0.035,0.189,0.191,0.544,0.079,0.172,0.323,0.500
3,ATL,0,0.042,0.500,0.183,0.435,0.099,0.652,0.224,1.250
4,ATL,1,0.032,0.442,0.235,0.588,0.096,0.623,0.358,0.903
...,...,...,...,...,...,...,...,...,...,...
85,TOR,1,0.033,0.529,0.134,0.612,0.085,0.656,0.335,0.647
86,TOR,2,0.026,0.190,0.189,0.356,0.079,0.225,0.368,0.853
87,WSN,0,0.024,0.500,0.116,0.250,0.053,0.556,0.188,0.615
88,WSN,1,0.025,0.417,0.112,0.508,0.056,0.609,0.172,0.714
