In [1]:
%matplotlib widget

# Graphing
import matplotlib.pyplot as plt
import numpy as np
from abc import ABC, abstractmethod
from scipy.stats import norm
from typing import Literal, Callable

# Game and Graph
import ipywidgets as widgets
from IPython.display import display, clear_output
import json
import os
import threading
import time
import random


In [2]:
INDEPENDENT_VARIABLES = ["strike_price", "underlying_price", "time", "interest_rate", "volatility"] # Independent variables for the Black-Scholes model
DEPENDENT_VARIABLES = ["option_price", "delta", "gamma", "theta", "vega", "rho", "vanna", "charm", "vomma"] # Greeks

N = norm.cdf
phi = norm.pdf
at_the_forward_pricing = False  # Whether the option is at-the-forward (S = K * exp(-r * T))

def d1(S, K, T, r, sigma):
    return (np.log(S/K) + (r + 0.5 * sigma**2) * T) / (sigma * np.sqrt(T))

def d2(S, K, T, r, sigma):
    return d1(S, K, T, r, sigma) - sigma * np.sqrt(T)

def bs_call(S, K, T, r, sigma):
    return S * N(d1(S, K, T, r, sigma)) - K * np.exp(-r*T)* N(d2(S, K, T, r, sigma))

def bs_put(S, K, T, r, sigma):
    return K*np.exp(-r*T)*N(-d2(S, K, T, r, sigma)) - S*N(-d1(S, K, T, r, sigma))

def delta_call(S, K, T, r, sigma):
    return N(d1(S, K, T, r, sigma))

def delta_put(S, K, T, r, sigma):
    return N(d1(S, K, T, r, sigma)) - 1

def gamma(S, K, T, r, sigma):
    return phi(d1(S, K, T, r, sigma)) / (S * sigma * np.sqrt(T))

def theta_call(S, K, T, r, sigma):
    term1 = - (S * phi(d1(S, K, T, r, sigma)) * sigma) / (2 * np.sqrt(T))
    term2 = - r * K * np.exp(-r*T) * N(d2(S, K, T, r, sigma))
    return term1 + term2

def theta_put(S, K, T, r, sigma):
    term1 = - (S * phi(d1(S, K, T, r, sigma)) * sigma) / (2 * np.sqrt(T))
    term2 = r * K * np.exp(-r*T) * N(-d2(S, K, T, r, sigma))
    return term1 + term2

def vega(S, K, T, r, sigma):
    return S * phi(d1(S, K, T, r, sigma)) * np.sqrt(T)

def rho_call(S, K, T, r, sigma):
    return K * T * np.exp(-r*T) * N(d2(S, K, T, r, sigma))

def rho_put(S, K, T, r, sigma):
    return -K * T * np.exp(-r*T) * N(-d2(S, K, T, r, sigma))

def vanna(S, K, T, r, sigma):
    return phi(d1(S, K, T, r, sigma)) * d2(S, K, T, r, sigma) / sigma

def charm_call(S, K, T, r, sigma):
    return -phi(d1(S, K, T, r, sigma)) * (2 * r * T - d2(S, K, T, r, sigma) * sigma * np.sqrt(T)) / (2 * T * sigma * np.sqrt(T))

def charm_put(S, K, T, r, sigma):
    return -phi(d1(S, K, T, r, sigma)) * (2 * r * T + d2(S, K, T, r, sigma) * sigma * np.sqrt(T)) / (2 * T * sigma * np.sqrt(T))

def vomma(S, K, T, r, sigma):
    return vega(S, K, T, r, sigma) * d1(S, K, T, r, sigma) * d2(S, K, T, r, sigma) / sigma

def greek_factory(dv, option_type: Literal["call", "put"]) -> Callable:
    if dv == "option_price":
        if option_type == "call":
            return bs_call
        elif option_type == "put":
            return bs_put
        else:
            raise ValueError("option_type must be 'call' or 'put'")
    elif dv == "delta":
        if option_type == "call":
            return delta_call
        elif option_type == "put":
            return delta_put
        else:
            raise ValueError("option_type must be 'call' or 'put'")
    elif dv == "gamma":
        return gamma
    elif dv == "theta":
        if option_type == "call":
            return theta_call
        elif option_type == "put":
            return theta_put
        else:
            raise ValueError("option_type must be 'call' or 'put'")
    elif dv == "vega":
        if option_type == "call":
            return vega
        elif option_type == "put":
            return vega
        else:
            raise ValueError("option_type must be 'call' or 'put'")
    elif dv == "rho":
        if option_type == "call":
            return rho_call
        elif option_type == "put":
            return rho_put
        else:
            raise ValueError("option_type must be 'call' or 'put'")
    elif dv == "vanna":
        return vanna
    elif dv == "charm":
        if option_type == "call":
            return charm_call
        elif option_type == "put":
            return charm_put
        else:
            raise ValueError("option_type must be 'call' or 'put'")
    elif dv == "vomma":
        return vomma
    else:
        raise ValueError(f"Unsupported dependent variable: {dv}")

In [3]:

class Experiment(ABC):
    """
    Abstract base class for experiments.
    """
    IV: str
    DV: str
    equation: Callable
    x_values: np.ndarray
    y_values: np.ndarray

    option_type: Literal["call", "put"]

    at_the_forward = at_the_forward_pricing  # Whether the option is at-the-forward (S = K * exp(-r * T))

    K = [85, 100, 115]  # Strike prices for different experiments
    T = 1.0  # Time to expiration in years
    r = 0.05  # Interest rate
    sigma = 0.3  # Volatility
    S = 100 * np.exp(-r * T) if at_the_forward else 100  # Underlying asset price

    legend: np.ndarray

    def __init__(self, option_type: Literal["call", "put"],  dependent_variable: str, independent_variable: str):
        """
        Initialize the experiment with the given independent and dependent variables.
        """

        # Validate independent and dependent variables
        if independent_variable not in INDEPENDENT_VARIABLES:
            raise ValueError(f"Invalid independent variable: {independent_variable}. Must be one of {INDEPENDENT_VARIABLES}.")
        if dependent_variable not in DEPENDENT_VARIABLES:
            raise ValueError(f"Invalid dependent variable: {dependent_variable}. Must be one of {DEPENDENT_VARIABLES}.")
        
        self.IV = independent_variable
        self.DV = dependent_variable
        self.option_type = option_type
        self.equation = greek_factory(dependent_variable, option_type)
        self.legend = ["ITM", "ATM", "OTM"] if option_type == "call" else ["OTM", "ATM", "ITM"]
        if independent_variable == "strike_price":
            self.legend = ["Option Price"]

    @abstractmethod
    def solve(self):
        """
        Solve the experiment with the given independent and dependent variables. Store the results in x_values and y_values.
        """
        pass

    def get_curve(self):
        """
        Return x and a single y curve.
        """
        return self.x_values, self.y_values[1] if len(self.y_values) > 1 else self.y_values[0]

    def plot_results(self):
        """
        Plot the results of the experiment.
        """
        plt.figure(figsize=(10, 6))
        
        for i in range(len(self.y_values)):
            plt.plot(self.x_values, self.y_values[i], label=self.legend[i])
    
        iv_label = self.IV.replace('_', ' ').title()
        dv_label = self.DV.replace('_', ' ').capitalize()
    
        plt.xlabel(iv_label)
        plt.ylabel(dv_label)
    
        # Determine if it's a multi-leg or single-leg plot
        is_multi = len(self.y_values) > 1
        opt_label = str(self.option_type).capitalize()
        title_base = f"{dv_label} vs {iv_label}"
        if is_multi:
            title = f"{title_base} for {opt_label} Combo"
        else:
            title = f"{title_base} for {opt_label} Option"
    
        plt.title(title)
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()


In [4]:
class StrikePriceExperiment(Experiment): # Vary strike price to see how it affects the Greeks
    def __init__(self, option_type: Literal["call", "put"], dependent_variable: str, independent_variable: str = "strike_price"):
        super().__init__(option_type, dependent_variable, independent_variable)
        self.x_values = np.linspace(1, 200, 199)  # Strike prices from 50 to 150
        self.y_values = []

    def solve(self):
        self.y_values.append([])
        for K in self.x_values:
            dependent_variable_sample = self.equation(self.S, K, self.T, self.r, self.sigma)
            self.y_values[-1].append(dependent_variable_sample)

class UnderlyingPriceExperiment(Experiment): # Vary underlying price to see how it affects the Greeks
    def __init__(self, option_type: Literal["call", "put"], dependent_variable: str, independent_variable: str = "underlying_price"):
        super().__init__(option_type, dependent_variable, independent_variable)
        self.x_values = np.linspace(1, 200, 199)  # Underlying prices from 50 to 150
        self.y_values = []

    def solve(self):
        for K in self.K:
            self.y_values.append([])
            for S in self.x_values:
                dependent_variable_sample = self.equation(S, K, self.T, self.r, self.sigma)
                self.y_values[-1].append(dependent_variable_sample)

class TimeExperiment(Experiment): # Vary time to see how it affects the Greeks
    def __init__(self, option_type: Literal["call", "put"], dependent_variable, independent_variable: str = "time"):
        super().__init__(option_type, dependent_variable, independent_variable)
        self.x_values = np.linspace(0.01, 1, 100)  # Time from 0.01 to 1 year
        self.y_values = []

    def solve(self):
        for K in self.K:
            self.y_values.append([])
            for T in self.x_values:
                dependent_variable_sample = self.equation(self.S, K, T, self.r, self.sigma)
                self.y_values[-1].append(dependent_variable_sample)

class InterestRateExperiment(Experiment): # Vary interest rate to see how it affects the Greeks
    def __init__(self, option_type: Literal["call", "put"], dependent_variable, independent_variable: str = "interest_rate"):
        super().__init__(option_type, dependent_variable, independent_variable)
        self.x_values = np.linspace(0.01, 1, 100)  # Interest rates from 0.01 to 0.1
        self.y_values = []

    def solve(self):
        for K in self.K:
            self.y_values.append([])
            for r in self.x_values:
                dependent_variable_sample = self.equation(self.S, K, self.T, r, self.sigma)
                self.y_values[-1].append(dependent_variable_sample)

class VolatilityExperiment(Experiment): # Vary volatility to see how it affects the Greeks
    def __init__(self, option_type: Literal["call", "put"], dependent_variable, independent_variable: str = "volatility"):
        super().__init__(option_type, dependent_variable, independent_variable)
        self.x_values = np.linspace(0.01, 1, 100)  # Volatility from 0.01 to 1
        self.y_values = []

    def solve(self):
        for K in self.K:
            self.y_values.append([])
            for sigma in self.x_values:
                dependent_variable_sample = self.equation(self.S, K, self.T, self.r, sigma)
                self.y_values[-1].append(dependent_variable_sample)


In [5]:
#############################################
#         INTERFACE FOR EXPERIMENTS         #
#############################################

# Static label to show the setting
forward_label = widgets.Label(
    value=f"at_the_forward_pricing: {at_the_forward_pricing}, current ATF value: 100"
)

# Dropdowns
option_type_widget = widgets.Dropdown(
    options=["call", "put"],
    value="call",
    description="Option Type:"
)

iv_widget = widgets.Dropdown(
    options=INDEPENDENT_VARIABLES,
    value="strike_price",
    description="Independent:"
)

dv_widget = widgets.Dropdown(
    options=DEPENDENT_VARIABLES,
    value="option_price",
    description="Dependent:"
)

generate_button = widgets.Button(
    description="Generate Plot",
    button_style='success'
)

output = widgets.Output()

# Function to run when button is clicked
def on_generate_click(b):
    with output:
        clear_output()
        
        option_type = option_type_widget.value
        iv = iv_widget.value
        dv = dv_widget.value
        
        # Choose correct experiment class
        experiment_class = {
            "strike_price": StrikePriceExperiment,
            "underlying_price": UnderlyingPriceExperiment,
            "time": TimeExperiment,
            "interest_rate": InterestRateExperiment,
            "volatility": VolatilityExperiment
        }[iv]
        
        try:
            experiment = experiment_class(option_type = option_type, dependent_variable=dv)
            experiment.solve()
            experiment.plot_results()
        except Exception as e:
            print(f"Error: {e}")

generate_button.on_click(on_generate_click)

# Display widgets
display(widgets.VBox([
    forward_label,
    option_type_widget,
    iv_widget,
    dv_widget,
    generate_button,
    output
]))


VBox(children=(Label(value='at_the_forward_pricing: False, current ATF value: 100'), Dropdown(description='Opt…

In [7]:
import os, json, random, threading, time
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# --- Stats setup ---
stats_file = "greek_game_stats.json"
if not os.path.exists(stats_file):
    stats = {"total_attempted": 0, "total_correct": 0, "combo_stats": {}}
    for iv in INDEPENDENT_VARIABLES:
        for dv in DEPENDENT_VARIABLES:
            stats["combo_stats"][f"{iv}->{dv}"] = {"correct": 0, "total": 0}
    with open(stats_file, "w") as f:
        json.dump(stats, f, indent=2)

with open(stats_file, "r") as f:
    stats = json.load(f)

# --- Game state ---
user_points = []
answer_x, answer_y = None, None
current_iv, current_dv, current_ot = None, None, None
instrument_label, strike_display = "", ""

# --- Widgets & outputs ---
plot_output = widgets.Output()
details_output = widgets.Output()
feedback_output = widgets.Output()
timer_label = widgets.Label(value="⏱️ Time remaining: 60s")

# Buttons for gameplay
submit_button = widgets.Button(description="Submit", button_style="primary")
next_button = widgets.Button(description="Next", button_style="info")

# Countdown control
countdown_thread = None
countdown_cancelled_flag = threading.Event()

# --- Define combos once ---
combo_defs = {
    'Straddle'             : [('call', 100,  1), ('put', 100,  1)],
    'Strangle'             : [('put', 80,  1), ('call', 120,  1)],
    'Vertical Call Spread' : [('call', 80,  1), ('call', 100, -1)],
    'Put Spread'           : [('put', 100, 1), ('put', 80, -1)],
    'Call Spread 1x2'      : [('call', 80,  1), ('call', 100, -2)],
    'Put Spread 1x2'       : [('put', 100, 1), ('put', 80, -2)],
    'Call Spread 1x3'      : [('call', 80,  1), ('call', 100, -3)],
    'Put Spread 1x3'       : [('put', 100, 1), ('put', 80, -3)],
    'Call Tree'            : [('call', 80,  1), ('call', 100, -2), ('call', 120, 1)],
    'Put Tree'             : [('put', 100, 1), ('put', 90, -2), ('put', 80, 1)],
    'Call Fly'             : [('call', 80,  1), ('call', 100, -2), ('call', 120, 1)],
    'Put Fly'              : [('put', 120, 1), ('put', 100, -2), ('put', 80, 1)],
    'Condor'               : [('call', 80,  1), ('call', 90,  -1), ('call', 110, -1), ('call', 120, 1)],
    'Iron Fly'             : [('call', 100, -1), ('put', 100, -1), ('call', 120, 1), ('put', 80, 1)],
    'Iron Condor'          : [('put', 80, 1), ('put', 90, -1), ('call', 110, -1), ('call', 120, 1)],
    'Fence'                : [('put',  80, 1), ('put',  100, -1), ('call', 100, -1), ('call', 120, 1)],
    'Risk Reversal'        : [('put',  80, -1), ('call', 120, 1)],
    'Box'                  : [('call', 80,  1),  ('call', 120, -1), ('put',  80, -1), ('put', 120, 1)],
}

# --- Settings panel ---
iv_selector      = widgets.SelectMultiple(
    options=INDEPENDENT_VARIABLES,
    value=tuple(INDEPENDENT_VARIABLES),
    description='IVs:'
)
dv_selector      = widgets.SelectMultiple(
    options=DEPENDENT_VARIABLES,
    value=tuple(DEPENDENT_VARIABLES),
    description='DVs:'
)
include_singles  = widgets.Checkbox(
    value=True,
    description='Include regular calls/puts'
)
combo_selector   = widgets.SelectMultiple(
    options=list(combo_defs.keys()),
    value=list(combo_defs.keys()),
    description='Combos:'
)
settings_panel   = widgets.VBox([
    widgets.Label('--- Settings Panel ---'),
    iv_selector,
    dv_selector,
    include_singles,
    combo_selector
])
display(settings_panel)

# --- Utility functions ---
def show_experiment_details():
    with details_output:
        clear_output()
        experiment_constants = {
            "strike_price": experiment.K,
            "underlying_price": round(experiment.S, 2),
            "volatility": experiment.sigma,
            "interest_rate": experiment.r,
            "time": experiment.T,
        }
        print("📋 Fixed Parameters:")
        print(f"• IV: {current_iv}")
        print(f"• DV: {current_dv}")
        print(f"• Option Type: {current_ot.capitalize()}")
        for key, value in experiment_constants.items():
            if key != current_iv:
                label = key.replace('_',' ').title()
                val = value if isinstance(value, list) else round(value, 4)
                print(f"• {label}: {val}")
        total   = stats.get("total_attempted", 0)
        correct = stats.get("total_correct", 0)
        pct     = (correct / total * 100) if total else 0
        print(f"\n📈 Your Stats: {correct}/{total} ({pct:.1f}%)")


def update_stats(iv, dv, correct):
    stats["total_attempted"] = stats.get("total_attempted", 0) + 1
    if correct:
        stats["total_correct"] = stats.get("total_correct", 0) + 1
    key = f"{iv}->{dv}"
    combo = stats["combo_stats"].get(key)
    if combo:
        combo["total"]   += 1
        if correct:
            combo["correct"] += 1
    with open(stats_file, 'w') as f:
        json.dump(stats, f, indent=2)


def reset_plot():
    global user_points
    user_points = []
    plot_output.clear_output()

# --- Main functions ---
def new_question():
    plt.close('all')
    global current_iv, current_dv, current_ot, answer_x, answer_y
    global instrument_label, strike_display, experiment

    reset_plot()
    user_points.clear()

    # Filter IVs and DVs from settings
    ivs = list(iv_selector.value)
    dvs = list(dv_selector.value)
    if not ivs or not dvs:
        return
    current_iv = random.choice(ivs)
    current_dv = random.choice(dvs)
    current_ot = random.choice(['call', 'put'])

    # Build instruments based on settings
    instruments = []
    if include_singles.value:
        for ot in ('call', 'put'):
            for K in [80, 100, 120]:
                instruments.append({
                    'label': f"{ot.capitalize()} {K}",
                    'leg_defs': [(ot, K, 1)]
                })
    for name in combo_selector.value:
        instruments.append({
            'label': name,
            'leg_defs': combo_defs[name]
        })
    if not instruments:
        return

    choice = random.choice(instruments)
    instrument_label = choice['label']
    leg_defs = choice['leg_defs']
    strike_display = "(" + "/".join(str(K) for (_ot, K, _qty) in leg_defs) + ")"

    # Solve each leg and weight by quantity
    exp_map = {
        "strike_price": StrikePriceExperiment,
        "underlying_price": UnderlyingPriceExperiment,
        "time": TimeExperiment,
        "interest_rate": InterestRateExperiment,
        "volatility": VolatilityExperiment
    }
    experiment_class = exp_map[current_iv]
    curves = []
    for leg_ot, leg_strike, leg_qty in leg_defs:
        experiment = experiment_class(option_type=leg_ot, dependent_variable=current_dv)
        experiment.K = [leg_strike]
        experiment.solve()
        x, y = experiment.get_curve()
        curves.append((x, leg_qty * np.array(y)))

    # Sum curves
    answer_x = curves[0][0]
    answer_y = np.sum([yy for (_xx, yy) in curves], axis=0)

    show_experiment_details()

    # Plot and interactive draw
    with plot_output:
        fig, ax = plt.subplots(figsize=(8, 4))
        ax.set_title(
            f"{instrument_label} {strike_display}: "
            f"{current_dv.capitalize()} vs {current_iv.replace('_', ' ').capitalize()}"
        )
        ax.set_xlabel(current_iv.replace('_', ' ').title())
        ax.set_ylabel(current_dv.capitalize())
        ax.set_xlim(answer_x.min(), answer_x.max())
        pad = (answer_y.max() - answer_y.min()) * 0.1
        ax.set_ylim(answer_y.min() - pad, answer_y.max() + pad)
        ax.grid(True)

        line, = ax.plot([], [], 'b-', linewidth=2)
        drawing = {'active': False}

        def on_press(event):
            if event.inaxes != ax:
                return
            drawing['active'] = True
            user_points.clear()
            user_points.append((event.xdata, event.ydata))
            line.set_data(*zip(*user_points))
            fig.canvas.draw_idle()

        def on_move(event):
            if drawing['active'] and event.xdata and event.ydata:
                user_points.append((event.xdata, event.ydata))
                line.set_data(*zip(*user_points))
                fig.canvas.draw_idle()

        def on_release(event):
            drawing['active'] = False

        fig.canvas.mpl_connect('button_press_event', on_press)
        fig.canvas.mpl_connect('motion_notify_event', on_move)
        fig.canvas.mpl_connect('button_release_event', on_release)
        plt.show()

    start_countdown()


def evaluate_guess(b):
    if not user_points:
        with feedback_output:
            clear_output()
            print("⚠️ Please draw a curve first.")
        return
    user_x, user_y = zip(*user_points)
    user_interp = np.interp(answer_x, user_x, user_y, left=0, right=0)
    mse = np.mean((user_interp - answer_y) ** 2)
    y_range = np.max(answer_y) - np.min(answer_y) + 0.2
    rmse = mse / ((y_range ** 2) + 1e-8)
    correct = rmse < 0.1  # looser threshold

    update_stats(current_iv, current_dv, correct)

    with feedback_output:
        clear_output()
        print(f"{'✅ Correct!' if correct else '❌ Try again.'} Normalized RMSE = {rmse:.4f}")
        print("Explain below or click Next.")

    # Overlay correct curve
    with plot_output:
        clear_output()
        fig, ax = plt.subplots(figsize=(8, 4))
        ax.set_title(
            f"{instrument_label} {strike_display}: "
            f"{current_dv.capitalize()} vs {current_iv.replace('_', ' ').capitalize()}"
        )
        ax.set_xlabel(current_iv.replace('_', ' ').title())
        ax.set_ylabel(current_dv.capitalize())
        ax.set_xlim(answer_x.min(), answer_x.max())
        pad = (answer_y.max() - answer_y.min()) * 0.1
        ax.set_ylim(min(answer_y.min(), min(user_y)) - pad, max(answer_y.max(), max(user_y)) + pad)
        ax.grid(True)
        ax.plot(answer_x, answer_y, 'r--', label='Correct Answer')
        ax.plot(user_x, user_y, 'b-', label='Your Guess')
        ax.legend()
        plt.show()

# --- Hookups ---
submit_button.on_click(evaluate_guess)
next_button.on_click(lambda b: new_question())


def start_countdown():
    global countdown_thread
    timer_label.value = "⏱️ Time remaining: 60s"
    countdown_cancelled_flag.set()
    if countdown_thread and countdown_thread.is_alive():
        countdown_thread.join(timeout=1)
    countdown_cancelled_flag.clear()
    def countdown():
        for remaining in range(60, -1, -1):
            if countdown_cancelled_flag.is_set():
                return
            timer_label.value = f"⏱️ Time remaining: {remaining}s"
            time.sleep(1)
    countdown_thread = threading.Thread(target=countdown)
    countdown_thread.start()

# --- Display UI ---
display(widgets.HBox([
    widgets.VBox([widgets.HBox([submit_button, next_button]), timer_label, plot_output, details_output, feedback_output]),
    widgets.VBox([settings_panel])
]))

# Start first question
new_question()


VBox(children=(Label(value='--- Settings Panel ---'), SelectMultiple(description='IVs:', index=(0, 1, 2, 3, 4)…

HBox(children=(VBox(children=(HBox(children=(Button(button_style='primary', description='Submit', style=Button…