# Jane Street Puzzle 2025 October: Robot Baseball

https://www.janestreet.com/puzzles/current-puzzle/

The Artificial Automaton Athletics Association (Quad-A) is at it again: to compete with postseason baseball, they are developing a Robot Baseball competition. Games are composed of a series of independent at-bats in which the batter is trying to maximize expected score and the pitcher is trying to minimize expected score.

An at-bat is a series of pitches with a running count of balls and strikes, both starting at zero. For each pitch, the pitcher decides whether to throw a ball or strike, and the batter decides whether to wait or swing; these decisions are made secretly and simultaneously. The results of these choices are as follows.

- If the pitcher throws a ball and the batter waits, the count of balls is incremented by 1.
- If the pitcher throws a strike and the batter waits, the count of strikes is incremented by 1.
- If the pitcher throws a ball and the batter swings, the count of strikes is incremented by 1.
- If the pitcher throws a strike and the batter swings, with probability p the batter hits a home run and with probability 1-p the count of strikes is incremented by 1.

An at-bat ends when either:

- The count of balls reaches 4, in which case the batter receives 1 point.
- The count of strikes reaches 3, in which case the batter receives 0 points.
- The batter hits a home run, in which case the batter receives 4 points.

By varying the size of the strike zone, Quad-A can adjust the value p, the probability a pitched strike that is swung at results in a home run. They have found that viewers are most excited by at-bats that reach a full count, that is, the at-bats that reach the state of three balls and two strikes. Let q be the probability of at-bats reaching full count; q is dependent on p. Assume the batter and pitcher are both using optimal mixed strategies and Quad-A has chosen the p that maximizes q. Find this q, the maximal probability at-bats reach full count, to ten decimal places.
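The rules above can be encoded as a transition function for a single pitch. This is a minimal sketch; the helper name `pitch_outcomes` and its return format (a list of `(probability, result)` pairs, where a result is either a new count or a final score) are hypothetical and not part of the puzzle.

```python
def pitch_outcomes(balls, strikes, pitch, action, p):
    """Hypothetical helper: results of one pitch under the Robot Baseball rules.

    Returns a list of (probability, result) pairs, where result is either
    ('count', balls, strikes) for a continuing at-bat or ('score', points)
    for a finished one.
    """
    if pitch not in ('ball', 'strike') or action not in ('wait', 'swing'):
        raise ValueError("pitch must be 'ball' or 'strike'; action 'wait' or 'swing'")

    def advance(b, s):
        # Check whether the new count ends the at-bat
        if b == 4:
            return ('score', 1)  # walk
        if s == 3:
            return ('score', 0)  # strikeout
        return ('count', b, s)

    if pitch == 'ball' and action == 'wait':
        return [(1.0, advance(balls + 1, strikes))]
    if pitch == 'strike' and action == 'swing':
        return [(p, ('score', 4)),                     # home run
                (1 - p, advance(balls, strikes + 1))]  # miss
    # A ball swung at, or a strike taken, adds a strike.
    return [(1.0, advance(balls, strikes + 1))]


print(pitch_outcomes(3, 2, 'strike', 'swing', 0.25))
# [(0.25, ('score', 4)), (0.75, ('score', 0))]
```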

## Analysis

### Optimal Mixed Strategies

https://www3.nd.edu/~apilking/math10170/information/Lectures/16%20Optimal%20Mixed%20Strategy.pdf

A mixed strategy is one in which a player chooses among their possible moves according to a probability distribution. For two players A and B who each choose move 0 or 1, the expected outcome is:

$$
\begin{aligned}
    E &= P(A=0) P(B=0) E(A=0,B=0) \\
      &+ P(A=0) P(B=1) E(A=0,B=1) \\
      &+ P(A=1) P(B=0) E(A=1,B=0) \\
      &+ P(A=1) P(B=1) E(A=1,B=1) \\
      &= P(A=0) P(B=0) E(A=0,B=0) \\
      &+ P(A=0) \left( 1 - P(B=0) \right) E(A=0,B=1) \\
      &+ \left( 1 - P(A=0) \right) P(B=0) E(A=1,B=0) \\
      &+ \left( 1 - P(A=0) \right) \left( 1 - P(B=0) \right) E(A=1,B=1)
\end{aligned}
$$
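As a small numerical illustration of the formula above (the payoff matrix is made up for the example), the sketch below evaluates $E$ and confirms that, for a fixed $P(B=0)$, $E$ is linear in $P(A=0)$: mixing 50/50 gives exactly the average of A's two pure-strategy outcomes. This linearity is why the derivative condition below does not depend on $P(A=0)$ itself.

```python
def expected_outcome(pA0, pB0, payoff):
    """E for a 2x2 game, where payoff[a][b] is E(A=a, B=b)."""
    pA = (pA0, 1 - pA0)  # (P(A=0), P(A=1))
    pB = (pB0, 1 - pB0)  # (P(B=0), P(B=1))
    return sum(pA[a] * pB[b] * payoff[a][b] for a in (0, 1) for b in (0, 1))


payoff = [[1.0, 0.0],   # hypothetical payoffs: E(A=0,B=0), E(A=0,B=1)
          [0.0, 0.9]]   #                       E(A=1,B=0), E(A=1,B=1)

pB0 = 0.3
mixed = expected_outcome(0.5, pB0, payoff)    # A mixes 50/50
pure_A0 = expected_outcome(1.0, pB0, payoff)  # A always plays 0
pure_A1 = expected_outcome(0.0, pB0, payoff)  # A always plays 1
print(mixed, (pure_A0 + pure_A1) / 2)
```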

Note that a pure strategy is a special case of a mixed strategy in which one move has probability 1.


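A pair of mixed strategies is optimal (a Nash equilibrium, i.e. a fixed point) when neither player can improve their expected outcome by changing only their own strategy. For a fixed $P(B=0)$, $E$ is linear in $P(A=0)$, so player A is willing to mix only if both of A's moves give the same expected outcome; equivalently, the derivative of $E$ with respect to $P(A=0)$ must vanish: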

$$
\begin{aligned}
    \frac{\partial E}{\partial P(A=0)} &= P(B=0) E(A=0,B=0) \\
                                       &+ \left( 1 - P(B=0) \right) E(A=0,B=1) \\
                                       &- P(B=0) E(A=1,B=0) \\
                                       &- \left( 1 - P(B=0) \right) E(A=1,B=1) \\
                                       &= 0
\end{aligned}
$$

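Solving for $P(B=0)$ shows that A's indifference condition determines B's mixing probability:

$$
P(B=0) = \frac{E(A=1,B=1) - E(A=0,B=1)}{E(A=0,B=0) - E(A=0,B=1) - E(A=1,B=0) + E(A=1,B=1)}
$$

The same argument with the roles swapped (differentiating with respect to $P(B=0)$) gives

$$
P(A=0) = \frac{E(A=1,B=1) - E(A=1,B=0)}{E(A=0,B=0) - E(A=0,B=1) - E(A=1,B=0) + E(A=1,B=1)}
$$

These formulas assume both players genuinely mix; if the game has a pure-strategy saddle point, the optimal strategies are pure and the formulas need not apply.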
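The sketch below applies the formula above for $P(B=0)$, together with its mirror image for $P(A=0)$, to a hypothetical payoff matrix, then checks the indifference property: against the other player's equilibrium mix, both pure moves give the same expected outcome. It assumes the game has no pure-strategy saddle point.

```python
def mixed_equilibrium(payoff):
    """Fully mixed equilibrium of a 2x2 zero-sum game.

    payoff[a][b] is E(A=a, B=b); returns (P(A=0), P(B=0)).
    Assumes there is no pure-strategy saddle point.
    """
    (e00, e01), (e10, e11) = payoff
    denom = e00 - e01 - e10 + e11
    if denom == 0:
        raise ValueError("degenerate game: no unique fully mixed equilibrium")
    pB0 = (e11 - e01) / denom  # from A's indifference condition
    pA0 = (e11 - e10) / denom  # from B's indifference condition
    return pA0, pB0


payoff = [[1.0, 0.0],
          [0.0, 0.9]]  # hypothetical payoffs; no saddle point
pA0, pB0 = mixed_equilibrium(payoff)
(e00, e01), (e10, e11) = payoff

# A's expected outcome from each pure move, against B's equilibrium mix
value_A0 = pB0 * e00 + (1 - pB0) * e01
value_A1 = pB0 * e10 + (1 - pB0) * e11
# B's expected outcome from each pure move, against A's equilibrium mix
value_B0 = pA0 * e00 + (1 - pA0) * e10
value_B1 = pA0 * e01 + (1 - pA0) * e11
print(pA0, pB0, value_A0, value_A1)
```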


### Defining the Robot Baseball Expected Outcome Recursively

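Let $E(b,s)$ be the expected score under optimal play from a count of $b$ balls and $s$ strikes, with terminal values $E(4,s) = 1$ (walk) and $E(b,3) = 0$ (strikeout). Writing $P_{\mathrm{ball}}$ for the probability that the pitcher throws a ball and $P_{\mathrm{wait}}$ for the probability that the batter waits, the four combinations of choices give the recursion: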

$$
\begin{aligned}
    E(b,s) &= P_{\mathrm{ball}} P_{\mathrm{wait}} E(b+1,s) \\
     &+ P_{\mathrm{ball}} \left( 1 - P_{\mathrm{wait}} \right) E(b,s+1) \\
     &+ \left( 1 - P_{\mathrm{ball}} \right) P_{\mathrm{wait}} E(b,s+1) \\
     &+ \left( 1 - P_{\mathrm{ball}} \right) \left( 1 - P_{\mathrm{wait}} \right) (4p + (1-p)E(b,s+1))
\end{aligned}
$$

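Substituting these payoffs into the indifference formula, with the batter as player A (move 0 = wait) and the pitcher as player B (move 0 = ball), gives

$$
P_{\mathrm{ball}} = \frac{p (4 - E(b,s+1))}{E(b+1,s) - (1 + p) E(b,s+1) + 4p}
$$

Repeating the calculation with the roles of the players swapped gives an expression for $P_{\mathrm{wait}}$ with the same numerator and denominator, so $P_{\mathrm{wait}} = P_{\mathrm{ball}}$ at every count.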
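A minimal sketch of one backward-induction step, assuming the child values $E(b+1,s)$ and $E(b,s+1)$ are already known. It computes $P_{\mathrm{ball}}$ from the closed form $p(4 - E(b,s+1)) / (E(b+1,s) - (1+p)E(b,s+1) + 4p)$, computes $P_{\mathrm{wait}}$ from the generic 2x2 indifference formula, and returns both with the resulting $E(b,s)$. The function name `solve_count` is hypothetical; the example uses full count, where $E(4,2) = 1$ and $E(3,3) = 0$, with an arbitrary $p = 0.25$.

```python
def solve_count(E_more_balls, E_more_strikes, p):
    """One step of backward induction (hypothetical helper).

    E_more_balls = E(b+1, s), E_more_strikes = E(b, s+1).
    Returns (P_ball, P_wait, E(b, s)) at the mixed equilibrium.
    """
    # Payoffs for each (pitch, action) pair
    e_ball_wait = E_more_balls
    e_ball_swing = E_more_strikes
    e_strike_wait = E_more_strikes
    e_strike_swing = 4 * p + (1 - p) * E_more_strikes

    # Closed form for the pitcher's mix
    P_ball = p * (4 - E_more_strikes) / (E_more_balls - (1 + p) * E_more_strikes + 4 * p)
    # Generic indifference formula for the batter's mix
    denom = e_ball_wait - e_strike_wait - e_ball_swing + e_strike_swing
    P_wait = (e_strike_swing - e_ball_swing) / denom

    E = (P_ball * P_wait * e_ball_wait
         + P_ball * (1 - P_wait) * e_ball_swing
         + (1 - P_ball) * P_wait * e_strike_wait
         + (1 - P_ball) * (1 - P_wait) * e_strike_swing)
    return P_ball, P_wait, E


P_ball, P_wait, E_full = solve_count(E_more_balls=1.0, E_more_strikes=0.0, p=0.25)
print(P_ball, P_wait, E_full)  # 0.5 0.5 0.5
```

At full count the closed form reduces to $P_{\mathrm{ball}} = P_{\mathrm{wait}} = 4p/(1+4p)$, which is $0.5$ for $p = 0.25$.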


## Implementation

### Function Definition

We define a function that takes $p$ and returns $q$. It first solves for the optimal mixed strategies at every count by backward induction (recursing from the terminal counts), then propagates the probability of reaching each count forward from $(0,0)$, and finally returns the probability of reaching the full-count state $(3,2)$, i.e. $q$.

```python
import functools
from collections import deque
```

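```python
# Terminal counts: 4 balls is a walk, 3 strikes is a strikeout.
bmax = 4
smax = 3

def q(p):
    """Probability that an at-bat reaches full count (3 balls, 2 strikes)
    when both players use optimal mixed strategies, for home-run probability p."""
    strat_cache = dict()  # (b, s) -> (P_ball, P_wait)
    Es = dict()           # (b, s) -> expected score under optimal play

    @functools.cache
    def solve_strats(b=0, s=0):
        # Terminal states
        if s == smax:
            Es[(b, s)] = 0
            return 0
        if b == bmax:
            Es[(b, s)] = 1
            return 1

        # Payoff for each (pitch, batter action) pair
        E_ball_wait = solve_strats(b=b + 1, s=s)
        E_next_strike = solve_strats(b=b, s=s + 1)
        E_ball_swing = E_next_strike
        E_strike_wait = E_next_strike
        E_strike_swing = 4 * p + (1 - p) * E_next_strike

        # Pitcher's mix, from the batter's indifference condition
        P_ball_numerator = E_strike_swing - E_strike_wait
        P_ball_denominator = E_ball_wait - E_strike_wait - E_ball_swing + E_strike_swing
        if P_ball_denominator == 0:
            print(f'ISSUE: P_ball denominator is zero at {(b, s)}, p={p}')
        P_ball = P_ball_numerator / P_ball_denominator

        # Batter's mix, from the pitcher's indifference condition
        P_wait_numerator = E_strike_swing - E_ball_swing
        P_wait_denominator = E_ball_wait - E_ball_swing - E_strike_wait + E_strike_swing
        if P_wait_denominator == 0:
            print(f'ISSUE: P_wait denominator is zero at {(b, s)}, p={p}')
        P_wait = P_wait_numerator / P_wait_denominator

        E_opt = (P_ball * P_wait * E_ball_wait
                 + P_ball * (1 - P_wait) * E_ball_swing
                 + (1 - P_ball) * P_wait * E_strike_wait
                 + (1 - P_ball) * (1 - P_wait) * E_strike_swing)

        strat_cache[(b, s)] = (P_ball, P_wait)
        Es[(b, s)] = E_opt
        return E_opt

    solve_strats()

    # Propagate reach probabilities forward from (0, 0) in breadth-first order.
    # Every transition adds exactly one ball or one strike, so all predecessors
    # of a state are processed before the state itself.
    frontier = deque([(0, 0)])
    Ps = {(0, 0): 1}

    while frontier:
        b, s = frontier.pop()
        P = Ps[(b, s)]
        if (b, s) not in strat_cache:  # terminal state
            continue
        P_ball, P_wait = strat_cache[(b, s)]

        # Ball taken: one more ball
        if (b + 1, s) not in Ps:
            Ps[(b + 1, s)] = 0
            frontier.appendleft((b + 1, s))
        Ps[(b + 1, s)] += P_ball * P_wait * P

        # Ball swung at, strike taken, or strike swung at and missed: one more strike.
        # The remaining probability, (1 - P_ball) * (1 - P_wait) * p, is a home run.
        if (b, s + 1) not in Ps:
            Ps[(b, s + 1)] = 0
            frontier.appendleft((b, s + 1))
        Ps[(b, s + 1)] += (P_ball * (1 - P_wait)
                           + (1 - P_ball) * P_wait
                           + (1 - P_ball) * (1 - P_wait) * (1 - p)) * P

    return Ps[(bmax - 1, smax - 1)]
```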
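As a sanity check on the forward propagation, the probabilities of the three ways an at-bat can end (walk, strikeout, home run) should sum to 1. The sketch below is a compact, self-contained re-implementation, not the code above: it uses the closed-form mix $r = p(4 - Y)/(X - (1+p)Y + 4p)$ for both players (with $X = E(b+1,s)$, $Y = E(b,s+1)$, since both indifference conditions give the same expression), and it visits counts in order of $b + s$ instead of using a queue. The function name `reach_probabilities` is hypothetical, and it assumes $0 < p \le 1$.

```python
from functools import lru_cache


def reach_probabilities(p, bmax=4, smax=3):
    """Hypothetical helper: probability of visiting each count, plus the
    probability of each way the at-bat can end, for 0 < p <= 1."""

    @lru_cache(maxsize=None)
    def value(b, s):
        # Expected score from count (b, s) under optimal play
        if s == smax:
            return 0.0
        if b == bmax:
            return 1.0
        return strategy(b, s)[1]

    @lru_cache(maxsize=None)
    def strategy(b, s):
        X, Y = value(b + 1, s), value(b, s + 1)
        r = p * (4 - Y) / (X - (1 + p) * Y + 4 * p)  # P_ball (= P_wait)
        E = r * r * X + 2 * r * (1 - r) * Y + (1 - r) ** 2 * (4 * p + (1 - p) * Y)
        return r, E

    reach = {(0, 0): 1.0}
    outcomes = {"walk": 0.0, "strikeout": 0.0, "home run": 0.0}
    # Non-terminal counts sorted by b + s, so each count's predecessors
    # are finished before it is expanded.
    counts = sorted(((b, s) for b in range(bmax) for s in range(smax)), key=sum)
    for b, s in counts:
        P = reach.get((b, s), 0.0)
        r = strategy(b, s)[0]
        to_ball = P * r * r
        to_strike = P * (2 * r * (1 - r) + (1 - r) ** 2 * (1 - p))
        outcomes["home run"] += P * (1 - r) ** 2 * p
        if b + 1 == bmax:
            outcomes["walk"] += to_ball
        else:
            reach[(b + 1, s)] = reach.get((b + 1, s), 0.0) + to_ball
        if s + 1 == smax:
            outcomes["strikeout"] += to_strike
        else:
            reach[(b, s + 1)] = reach.get((b, s + 1), 0.0) + to_strike
    return reach, outcomes


reach, outcomes = reach_probabilities(0.25)
print(reach[(3, 2)], sum(outcomes.values()))
```

For the same $p$, `reach[(3, 2)]` should agree with `q(p)` up to rounding.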

### Solving the Optimization Problem

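We find the $p$ that maximizes $q$ by minimizing $-q(p)$ over $[0, 1]$ with SciPy's bounded scalar minimizer:

- https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html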

```python
from scipy.optimize import minimize_scalar

# Maximize q(p) by minimizing -q(p) on [0, 1].
res = minimize_scalar(lambda x: -q(x), bounds=(0, 1), method='bounded',
                      options={'xatol': 1e-11})

print(f"p = {res.x}, q = {-res.fun:.10f}")
```

### Solver-Based Alternate Implementation

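This version replaces the closed-form mixing probabilities with nested numerical min-max searches over $P_{\mathrm{ball}}$ and $P_{\mathrm{wait}}$. Besides being slower, it does not converge well no matter how small `xatol` is set (SciPy's bounded method offers only an absolute tolerance on $x$, not a tolerance on the function value). A likely reason is that the bounded method's stopping rule also contains a relative term of roughly $\sqrt{\varepsilon_{\mathrm{mach}}} \, \lvert x \rvert \approx 1.5 \times 10^{-8} \lvert x \rvert$, so values of `xatol` below that have no effect. Errors of that size in the computed strategies feed directly into the reach probabilities, and hence into $q$, which is likely too coarse for ten decimal places.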
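To illustrate why the nested search is unnecessary, the sketch below looks at the full-count state only (walk worth 1, strikeout worth 0; $p = 0.25$ is an arbitrary choice). For a fixed $P_{\mathrm{ball}}$, $E$ is linear in $P_{\mathrm{wait}}$, so the batter's best response is always a pure move, and the pitcher's objective is the upper envelope of two lines. Its minimum sits at a kink where the lines cross, which matches the closed form $P_{\mathrm{ball}} = p(4 - E(b,s+1)) / (E(b+1,s) - (1+p)E(b,s+1) + 4p)$. A kink also means an error in $P_{\mathrm{ball}}$ changes the objective to first order, unlike a smooth minimum.

```python
p = 0.25
E_ball_wait, E_ball_swing, E_strike_wait = 1.0, 0.0, 0.0  # full-count payoffs
E_strike_swing = 4 * p  # home run with probability p, otherwise strikeout (0)

def E(P_ball, P_wait):
    return (P_ball * P_wait * E_ball_wait
            + P_ball * (1 - P_wait) * E_ball_swing
            + (1 - P_ball) * P_wait * E_strike_wait
            + (1 - P_ball) * (1 - P_wait) * E_strike_swing)

def pitcher_objective(P_ball):
    # E is linear in P_wait, so the batter's best response is a pure move.
    return max(E(P_ball, 0.0), E(P_ball, 1.0))

# Batter always waits:  E = P_ball * E_ball_wait  + (1 - P_ball) * E_strike_wait
# Batter always swings: E = P_ball * E_ball_swing + (1 - P_ball) * E_strike_swing
# The two lines cross where they are equal:
slope_wait = E_ball_wait - E_strike_wait
slope_swing = E_ball_swing - E_strike_swing
P_ball_star = (E_strike_swing - E_strike_wait) / (slope_wait - slope_swing)

closed_form = p * (4 - 0.0) / (1.0 - (1 + p) * 0.0 + 4 * p)
print(P_ball_star, closed_form, pitcher_objective(P_ball_star))  # 0.5 0.5 0.5
```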

```python
bmax = 4
smax = 3
xatol = 1e-20  # shared by every nested search (SciPy floors it at about sqrt(eps) * |x|)

def q_opt(p):
    """Like q(p), but finds each count's mixed strategies with nested
    numerical min-max searches instead of the closed-form formulas."""
    strat_cache = dict()
    Es = dict()

    def bounded_min(f):
        # Minimize f on [0, 1] with the same method and tolerance everywhere.
        return minimize_scalar(f, bounds=(0, 1), method='bounded',
                               options={'xatol': xatol})

    @functools.cache
    def solve_strats(b=0, s=0):
        if s == smax:
            Es[(b, s)] = 0
            return 0
        if b == bmax:
            Es[(b, s)] = 1
            return 1
        E_ball_wait = solve_strats(b=b + 1, s=s)
        E_next_strike = solve_strats(b=b, s=s + 1)
        E_ball_swing = E_next_strike
        E_strike_wait = E_next_strike
        E_strike_swing = 4 * p + (1 - p) * E_next_strike

        def E(P_ball, P_wait):
            return (P_ball * P_wait * E_ball_wait
                    + P_ball * (1 - P_wait) * E_ball_swing
                    + (1 - P_ball) * P_wait * E_strike_wait
                    + (1 - P_ball) * (1 - P_wait) * E_strike_swing)

        # Pitcher: minimize the batter's best-response value over P_ball.
        def E_pitcher_min(P_ball):
            return -bounded_min(lambda x: -E(P_ball, x)).fun
        P_ball = bounded_min(E_pitcher_min).x

        # Batter: maximize the pitcher's best-response value over P_wait.
        def E_batter_max(P_wait):
            return bounded_min(lambda x: E(x, P_wait)).fun
        P_wait = bounded_min(lambda x: -E_batter_max(x)).x

        E_opt = E(P_ball, P_wait)
        strat_cache[(b, s)] = (P_ball, P_wait)
        Es[(b, s)] = E_opt
        return E_opt

    solve_strats()

    # Propagate reach probabilities forward from (0, 0), exactly as in q(p).
    frontier = deque([(0, 0)])
    Ps = {(0, 0): 1}

    while frontier:
        b, s = frontier.pop()
        P = Ps[(b, s)]
        if (b, s) not in strat_cache:  # terminal state
            continue
        P_ball, P_wait = strat_cache[(b, s)]

        if (b + 1, s) not in Ps:
            Ps[(b + 1, s)] = 0
            frontier.appendleft((b + 1, s))
        Ps[(b + 1, s)] += P_ball * P_wait * P

        if (b, s + 1) not in Ps:
            Ps[(b, s + 1)] = 0
            frontier.appendleft((b, s + 1))
        Ps[(b, s + 1)] += (P_ball * (1 - P_wait)
                           + (1 - P_ball) * P_wait
                           + (1 - P_ball) * (1 - P_wait) * (1 - p)) * P

    return Ps[(bmax - 1, smax - 1)]
```

```python
res = minimize_scalar(lambda x: -q_opt(x), bounds=(0, 1), method='bounded',
                      options={'xatol': 1e-16})
print(f"p = {res.x}, q = {-res.fun:.10f}")
```

## Visualization

### q(p)

```python
import matplotlib.pyplot as plt
import numpy as np

# p = 0 is degenerate (some indifference denominators are zero), so start just above it.
x = np.linspace(0.01, 1, 100)

plt.plot(x, [q(ix) for ix in x])
plt.xlabel("p")
plt.ylabel("q")
plt.show()
```

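### E(P_ball, P_wait) at Full Count

At full count, taking a ball ends the at-bat in a walk (1 point) and every other outcome except a home run ends it in a strikeout (0 points), so the last pitch is a single-round game.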

```python
p = 0.22697323584760093

# Full-count payoffs: a walk is worth 1 and a strikeout 0
E_strike = 0
E_ball = 1

E_ball_wait = E_ball
E_ball_swing = E_strike
E_strike_wait = E_strike
E_strike_swing = 4 * p + (1 - p) * E_strike

# 2-D grid of mixed strategies
P_ball, P_wait = np.meshgrid(np.linspace(0, 1, 20), np.linspace(0, 1, 20))

E = P_ball * P_wait * E_ball_wait \
    + P_ball * (1 - P_wait) * E_ball_swing \
    + (1 - P_ball) * P_wait * E_strike_wait \
    + (1 - P_ball) * (1 - P_wait) * E_strike_swing

# Plot the surface.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

ax.plot_surface(P_ball, P_wait, E)

ax.set_xlabel("$P_{ball}$")
ax.set_ylabel("$P_{wait}$")
ax.set_zlabel("$E$")

plt.show()
```
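Because $E$ is bilinear in $P_{\mathrm{ball}}$ and $P_{\mathrm{wait}}$ with a nonzero cross term, the surface above is a saddle. A minimal numerical check, assuming the same full-count payoffs and the same hard-coded $p$: at the closed-form equilibrium $P_{\mathrm{ball}}^* = P_{\mathrm{wait}}^* = 4p/(1+4p)$, the expected score does not depend on the other player's probability, and both flat lines sit at the game value $4p/(1+4p)$.

```python
import numpy as np

p = 0.22697323584760093  # same hard-coded value as the plotting cell
E_ball_wait, E_ball_swing, E_strike_wait = 1.0, 0.0, 0.0
E_strike_swing = 4 * p

def E(P_ball, P_wait):
    return (P_ball * P_wait * E_ball_wait
            + P_ball * (1 - P_wait) * E_ball_swing
            + (1 - P_ball) * P_wait * E_strike_wait
            + (1 - P_ball) * (1 - P_wait) * E_strike_swing)

r_star = 4 * p / (1 + 4 * p)  # P_ball* = P_wait* at full count
value = 4 * p / (1 + 4 * p)   # game value at full count

grid = np.linspace(0, 1, 21)
along_wait = E(r_star, grid)  # pitcher fixed at equilibrium, batter varies
along_ball = E(grid, r_star)  # batter fixed at equilibrium, pitcher varies
print(along_wait.max() - along_wait.min(), along_ball.max() - along_ball.min())
```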