# Transmission Shift Point Calculator

Adjust vehicle parameters and shift RPM to visualize gear speeds and RPM behavior.  
**Run all cells** (Cell → Run All), then use the controls below.

In [8]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import ipywidgets as widgets
from IPython.display import display, clear_output

In [None]:
# Engine speed (RPM) at which the first gear's plotted segment begins.
MIN_RPM = 1_000

def speed_from_rpm(rpm, gear_ratio, final_drive, tire_dia):
    """Convert engine RPM to road speed (mph) in a single gear.

    Args:
        rpm: Engine speed; a scalar or a NumPy array.
        gear_ratio: Ratio of the selected gear.
        final_drive: Final-drive ratio.
        tire_dia: Tire diameter (the notebook's input field labels it in inches).

    Returns:
        Speed in mph, with the same shape as ``rpm``.
    """
    # 1056 = 63360 / 60: inches per mile over minutes per hour, which turns
    # inches-per-minute at the tire into miles per hour.
    rolled_distance = rpm * np.pi * tire_dia
    overall_reduction = gear_ratio * final_drive * 1056
    return rolled_distance / overall_reduction

def rpm_from_speed(speed, gear_ratio, final_drive, tire_dia):
    """Convert road speed (mph) to engine RPM in a single gear.

    Args:
        speed: Road speed in mph; a scalar or a NumPy array.
        gear_ratio: Ratio of the selected gear.
        final_drive: Final-drive ratio.
        tire_dia: Tire diameter (the notebook's input field labels it in inches).

    Returns:
        Engine RPM, with the same shape as ``speed``.
    """
    # 1056 = 63360 / 60: inches per mile over minutes per hour.
    driveline_turns = speed * gear_ratio * final_drive * 1056
    tire_circumference = np.pi * tire_dia
    return driveline_turns / tire_circumference

def calculate_shift_line(gear_ratios, final_drive, tire_dia, shift_rpm, max_rpm, max_speed=0):
    """Build the connected shift line data.

    The line starts at ``MIN_RPM`` in first gear, climbs to the shift RPM,
    drops to the RPM the next gear turns at that same road speed, and repeats.
    The last gear runs up to ``max_rpm``.

    Args:
        gear_ratios: Gear ratios ordered from first gear upward.
        final_drive: Final-drive ratio.
        tire_dia: Tire diameter, passed through to the speed/RPM helpers.
        shift_rpm: Requested upshift RPM. It is clamped to ``max_rpm`` so a
            shift can never happen above the rev limit.
        max_rpm: Engine rev limit.
        max_speed: Optional road-speed cap in mph; ``0`` (or any non-positive
            value) disables it. The line stops in whichever gear reaches the
            cap, and no higher gear is added.

    Returns:
        segments: list of (speeds, rpms, gear_index) for each gear
        shift_points: list of (speed, rpm_before, rpm_after, gear_from, gear_to)
    """
    segments = []
    shift_points = []

    # Shifting above the rev limit is impossible; without this clamp the
    # segments would end at max_rpm while no shift was ever recorded, leaving
    # every gear restarting from MIN_RPM.
    upshift_rpm = min(shift_rpm, max_rpm)
    speed_cap_active = max_speed > 0
    last_index = len(gear_ratios) - 1
    start_rpm = MIN_RPM

    for i, ratio in enumerate(gear_ratios):
        is_last = i == last_index
        end_rpm = max_rpm if is_last else upshift_rpm

        # Road speed is continuous across a shift, so the cap falls inside at
        # most one gear. Check every gear, not just the last one, so the line
        # never runs past the cap in a lower gear.
        reached_cap = False
        if speed_cap_active:
            rpm_at_limit = rpm_from_speed(max_speed, ratio, final_drive, tire_dia)
            if rpm_at_limit <= end_rpm:
                end_rpm = rpm_at_limit
                reached_cap = True

        if start_rpm < end_rpm:
            rpms = np.linspace(start_rpm, end_rpm, 100)
            speeds = speed_from_rpm(rpms, ratio, final_drive, tire_dia)
            segments.append((speeds, rpms, i))

        if reached_cap or is_last:
            break

        speed_at_shift = speed_from_rpm(upshift_rpm, ratio, final_drive, tire_dia)
        # Same road speed in the next gear: RPM scales with the ratio change.
        rpm_after = upshift_rpm * gear_ratios[i + 1] / ratio
        shift_points.append((speed_at_shift, upshift_rpm, rpm_after, i + 1, i + 2))
        start_rpm = rpm_after

    return segments, shift_points

In [None]:
# --- Widget Definitions ---
# Shared description width and field width so the input rows line up.
style = {'description_width': '90px'}
narrow = widgets.Layout(width='180px')

tire_w = widgets.FloatText(value=35.0, description='Tire Dia (in):', step=0.5, style=style, layout=narrow)
fd_w = widgets.FloatText(value=4.56, description='Final Drive:', step=0.01, style=style, layout=narrow)
max_rpm_w = widgets.IntText(value=5500, description='Max RPM:', step=100, style=style, layout=narrow)
hwy_w = widgets.FloatText(value=70.0, description='Hwy Speed:', step=1.0, style=style, layout=narrow)
# BoundedIntText, not IntText: plain IntText has no min/max traits, so the
# 2-10 bounds were silently ignored and the field could show unsupported counts.
num_gears_w = widgets.BoundedIntText(value=5, min=2, max=10, description='# Gears:', style=style, layout=narrow)
pb_min_w = widgets.IntText(value=2500, description='PB Min RPM:', step=100, style=style, layout=narrow)
pb_max_w = widgets.IntText(value=4500, description='PB Max RPM:', step=100, style=style, layout=narrow)
# A Max Speed of 0 means "no speed cap".
max_spd_w = widgets.FloatText(value=0, description='Max Speed:', step=1.0, style=style, layout=narrow)

# continuous_update=False: redraw only when the slider is released.
shift_w = widgets.IntSlider(
    value=3500, min=1500, max=5500, step=100,
    description='Shift RPM:', continuous_update=False,
    style=style,
    layout=widgets.Layout(width='95%')
)

# Keep the Shift RPM slider's upper bound in step with the Max RPM field
def _sync_max(change):
    """Apply a new Max RPM value as the shift slider's maximum.

    The maximum is never set lower than one step above the slider's minimum,
    and the slider value is pulled down if it would exceed the new maximum.

    Args:
        change: Change payload from ``observe``; only ``change['new']`` is read.
    """
    lowest_allowed = shift_w.min + shift_w.step
    requested = change['new']
    upper = requested if requested > lowest_allowed else lowest_allowed
    shift_w.max = upper
    if shift_w.value > upper:
        shift_w.value = upper
max_rpm_w.observe(_sync_max, names='value')

# --- Dynamic Gear Widgets ---
# One FloatText per gear lives in this list; it starts empty and is filled
# by rebuild_gears().
gear_widgets = []
gear_box = widgets.HBox(layout=widgets.Layout(gap='10px'))
# Starting ratio for each gear position, indexed from first gear.
defaults = [
    3.72, 2.20, 1.50, 1.00, 0.79,
    0.63, 0.51, 0.42, 0.35, 0.29,
]

# --- Output Areas ---
# `out` hosts the chart; the HTML widgets hold the one-line summaries.
out = widgets.Output()
ratio_out, hwy_out, max_spd_out = widgets.HTML(), widgets.HTML(), widgets.HTML()

# --- Update Function ---
def update(change=None):
    """Recompute the shift line and redraw the summaries and the chart.

    Reads the current value of every input widget, refreshes the three HTML
    summary panels (ratio steps, highway RPM, top speed), and redraws the
    speed-vs-RPM figure inside ``out``.

    Args:
        change: Change payload passed by ``observe``. It is ignored, so the
            function can also be called directly with no argument.

    Returns:
        None. Returns early, leaving the previous output in place, when there
        are no gear widgets or when any gear ratio, the final drive, the tire
        diameter, or the max RPM is not positive.
    """
    gear_ratios = [w.value for w in gear_widgets]
    fd = fd_w.value
    td = tire_w.value
    mr = max_rpm_w.value
    sr = shift_w.value
    hwy_speed = hwy_w.value
    pb_min = pb_min_w.value
    pb_max = pb_max_w.value
    max_speed = max_spd_w.value

    # Non-positive values would divide by zero or draw nonsense; they also
    # occur transiently while a field is being edited, so just skip the redraw.
    if not gear_ratios or any(r <= 0 for r in gear_ratios) or fd <= 0 or td <= 0 or mr <= 0:
        return

    # Ratio step display
    ratio_text = ' \u2502 '.join(
        f'{i+1}\u2192{i+2}: {gear_ratios[i] / gear_ratios[i+1]:.3f}'
        for i in range(len(gear_ratios) - 1)
    )
    ratio_out.value = (
        f"<div style='font-family:monospace; font-size:13px; padding:8px; "
        f"background:#f8f8f8; border-radius:4px;'>"
        f"<b>Ratio Steps:</b> {ratio_text}</div>"
    )

    # Highway RPM display (top gear)
    hwy_rpm = rpm_from_speed(hwy_speed, gear_ratios[-1], fd, td)
    hwy_out.value = (
        f"<div style='font-family:monospace; font-size:13px; padding:8px; "
        f"background:#e3f2fd; border-radius:4px; margin-top:4px;'>"
        f"<b>Highway RPM:</b> {hwy_speed:.0f} mph in top gear = <b>{hwy_rpm:.0f} RPM</b></div>"
    )

    segments, shift_points = calculate_shift_line(gear_ratios, fd, td, sr, mr, max_speed=max_speed)

    # The last sample of the last segment is the highest speed on the line.
    top_speed = segments[-1][0][-1] if segments else None

    # Top speed display
    if top_speed is not None:
        # Only call it "limited" when the speed cap is what ends the line.
        # A cap set above what the rev limit allows does not limit anything.
        # The tiny tolerance absorbs the speed -> RPM -> speed round trip.
        speed_capped = max_speed > 0 and top_speed >= max_speed * (1 - 1e-9)
        limit_text = " (limited)" if speed_capped else ""
        max_spd_out.value = (
            f"<div style='font-family:monospace; font-size:13px; padding:8px; "
            f"background:#ffebee; border-radius:4px; margin-top:4px;'>"
            f"<b>Top Speed:</b> {top_speed:.1f} mph{limit_text}</div>"
        )
    else:
        max_spd_out.value = ""

    with out:
        # wait=True swaps the old figure for the new one without a blank flash.
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(14, 7))

        # Gear band highlighting
        band_colors = ['#c8e6c9', '#fff9c4', '#bbdefb', '#f8bbd0', '#e1bee7']
        for speeds, rpms, gi in segments:
            ax.axvspan(speeds[0], speeds[-1], alpha=0.15, color=band_colors[gi % len(band_colors)], zorder=0)
            mid_speed = (speeds[0] + speeds[-1]) / 2
            ax.text(mid_speed, mr * 1.05, f'Gear {gi + 1}', ha='center', va='bottom',
                    fontsize=9, fontweight='bold', color='#666', alpha=0.7)

        # Connected shift line with power band coloring.
        # Each short piece is coloured by its midpoint RPM so the colour can
        # change part-way through a gear.
        for speeds, rpms, gi in segments:
            for j in range(len(speeds) - 1):
                rpm_mid = (rpms[j] + rpms[j + 1]) / 2
                if pb_min <= rpm_mid <= pb_max:
                    c = '#388e3c'   # green — in power band
                else:
                    c = '#f57c00'   # orange — outside power band
                ax.plot(speeds[j:j+2], rpms[j:j+2], color=c, linewidth=3,
                        solid_capstyle='round', zorder=3)

        # Shift point drops and annotations
        for speed, rpm_hi, rpm_lo, gf, gt in shift_points:
            ax.plot([speed, speed], [rpm_hi, rpm_lo], '--', color='#757575',
                    linewidth=1.5, alpha=0.7, zorder=2)
            # Filled marker = RPM before the shift, hollow marker = RPM after.
            ax.plot(speed, rpm_hi, 'o', color='black', markersize=7, zorder=4)
            ax.plot(speed, rpm_lo, 'o', markerfacecolor='white',
                    markeredgecolor='black', markeredgewidth=1.5, markersize=7, zorder=4)
            mid = (rpm_hi + rpm_lo) / 2
            ax.annotate(
                f'{gf}\u2192{gt}: {speed:.1f} mph\n(drops to {rpm_lo:.0f} RPM)',
                xy=(speed, mid), xytext=(14, 0), textcoords='offset points',
                fontsize=8.5, ha='left', va='center',
                bbox=dict(boxstyle='round,pad=0.3', fc='#fff8e1', ec='#bcaaa4', alpha=0.9),
                arrowprops=dict(arrowstyle='->', color='#9e9e9e', lw=0.8),
                zorder=5
            )

        # Highway speed indicator
        if hwy_speed > 0:
            ax.axvline(x=hwy_speed, color='#1565c0', ls='-.', lw=1.5, alpha=0.6, zorder=1)
            ax.plot(hwy_speed, hwy_rpm, 's', color='#1565c0', markersize=8, zorder=5)
            ax.annotate(f'{hwy_speed:.0f} mph\n{hwy_rpm:.0f} RPM',
                        xy=(hwy_speed, hwy_rpm), xytext=(-14, 12), textcoords='offset points',
                        fontsize=9, ha='right', color='#1565c0', fontweight='bold',
                        bbox=dict(boxstyle='round,pad=0.3', fc='#e3f2fd', ec='#1565c0', alpha=0.9),
                        zorder=5)

        # Max speed indicator
        if top_speed is not None:
            ax.axvline(x=top_speed, color='#d32f2f', ls=':', lw=1.5, alpha=0.6, zorder=1)
            ax.annotate(f'Top: {top_speed:.0f} mph', xy=(top_speed, mr * 0.5),
                        xytext=(-12, 0), textcoords='offset points',
                        fontsize=10, ha='right', color='#d32f2f', fontweight='bold', zorder=5)

        # Formatting
        ax.set_xlabel('Speed (mph)', fontsize=12, fontweight='bold')
        ax.set_ylabel('Engine RPM', fontsize=12, fontweight='bold')
        ax.set_title(f'{len(gear_ratios)}-Speed Transmission Shift Points', fontsize=14, fontweight='bold', pad=15)
        ax.grid(True, alpha=0.2)
        # Headroom above max RPM leaves space for the "Gear N" labels.
        ax.set_ylim(0, mr * 1.12)
        if top_speed is not None:
            ax.set_xlim(0, top_speed * 1.05)
        else:
            ax.set_xlim(0, None)

        # Legend
        handles = [
            Line2D([0], [0], color='#388e3c', lw=3, label=f'Power band ({pb_min}\u2013{pb_max} RPM)'),
            Line2D([0], [0], color='#f57c00', lw=3, label='Outside power band'),
            Line2D([0], [0], color='#757575', ls='--', lw=1.5, label='RPM drop at shift'),
            Line2D([0], [0], color='#1565c0', ls='-.', lw=1.5, label='Highway cruise'),
        ]
        ax.legend(handles=handles, loc='upper left', fontsize=9, framealpha=0.9)

        plt.tight_layout()
        plt.show()
        # Close the figure so repeated redraws do not accumulate open figures.
        plt.close(fig)

# --- Rebuild Gear Widgets ---
def rebuild_gears(change=None):
    """Resize the row of gear-ratio inputs to the requested gear count.

    Existing inputs are kept, so their values and observers survive. Surplus
    inputs are detached and closed; missing ones are created from ``defaults``.
    The chart is redrawn once at the end.

    Args:
        change: Change payload passed by ``observe``. It is ignored, so the
            function can also be called directly for the initial build.
    """
    # Clamp to the supported range in case the field holds an out-of-range value.
    n = max(2, min(10, num_gears_w.value))

    # Detach the inputs that are no longer needed (highest gears first).
    surplus = []
    while len(gear_widgets) > n:
        w = gear_widgets.pop()
        w.unobserve(update, names='value')
        surplus.append(w)

    # Add inputs for any newly requested gears.
    for i in range(len(gear_widgets), n):
        val = defaults[i] if i < len(defaults) else defaults[-1] * 0.8
        w = widgets.FloatText(value=val, description=f'Gear {i+1}:', step=0.01, style=style, layout=narrow)
        w.observe(update, names='value')
        gear_widgets.append(w)

    gear_box.children = tuple(gear_widgets)

    # Close discarded widgets only after they have left the box. Widgets that
    # are dropped without close() stay registered, so each rebuild would leak.
    for w in surplus:
        w.close()

    update()

# --- Wire Up Observers ---
# Every plain input redraws the chart; the gear count rebuilds the ratio row.
for w in (tire_w, fd_w, max_rpm_w, shift_w, hwy_w, pb_min_w, pb_max_w, max_spd_w):
    w.observe(update, names='value')
num_gears_w.observe(rebuild_gears, names='value')

# --- Layout ---
# Two rows of vehicle inputs.
row1a = widgets.HBox(
    [tire_w, fd_w, max_rpm_w, hwy_w, num_gears_w],
    layout=widgets.Layout(gap='10px'),
)
row1b = widgets.HBox(
    [pb_min_w, pb_max_w, max_spd_w],
    layout=widgets.Layout(gap='10px'),
)

# Top-to-bottom order of the control panel: setup, ratios, slider, summaries, chart.
panel_children = [
    widgets.HTML("<h3 style='margin:0 0 5px'>Vehicle Setup</h3>"),
    row1a,
    row1b,
    widgets.HTML("<h3 style='margin:10px 0 5px'>Gear Ratios</h3>"),
    gear_box,
    widgets.HTML("<h3 style='margin:10px 0 5px'>Shift Point</h3>"),
    shift_w,
    ratio_out,
    hwy_out,
    max_spd_out,
    out,
]
ui = widgets.VBox(panel_children, layout=widgets.Layout(padding='10px'))

display(ui)

# Initial build of gear widgets and render
rebuild_gears()