# Demo: Arbitrary Piecewise Functions
## Example: x² + 2 from 2017–2019, then 10·cos(x) + 12 from 2019–2022, where x = C / 10¹⁷ (compute C in FLOPs)


In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Year→Compute mapping: compute starts at C0 FLOPs in year y0 and grows by a
# factor of _growth_per_year every _growth_period_years, i.e.
# C(year) = C0 * exp(k * (year - y0)).
y0, C0 = 2013, 1e16
_growth_per_year = 4.1
_growth_period_years = 1.0
k = np.log(_growth_per_year) / _growth_period_years

def year_to_compute(year):
    """Return the compute (FLOPs) reached in ``year`` on the exponential trajectory.

    ``year`` may be a scalar or a NumPy array (the mapping is element-wise).
    Reads the module-level anchor constants ``y0``, ``C0`` and ``k``.
    """
    years_since_anchor = year - y0
    growth = np.exp(k * years_since_anchor)
    return C0 * growth

# Arbitrary piecewise segments as (start_year, end_year, loss_fn, legend_label).
# Each loss_fn receives compute C in FLOPs and works in x = C / 1e17.
def _quadratic_loss(C):
    x = C / 1e17
    return x ** 2 + 2


def _cosine_loss(C):
    # NOTE(review): between 2019 and 2022, x = C / 1e17 grows from roughly 475
    # to roughly 3.3e4 -- thousands of cosine periods -- so the 200 samples
    # drawn per segment in the plotting loop alias badly. If a visible wave is
    # intended, a slower argument (e.g. np.log10(C)) is probably wanted;
    # confirm intent.
    return 10 * np.cos(C / 1e17) + 12


segments = [
    (2017, 2019, _quadratic_loss, "Quadratic x²"),
    (2019, 2022, _cosine_loss, "Cosine"),
]

# Optional reference curve for comparison.
def baseline_func(C):
    """Return the baseline loss 35.5 * C**-0.064 + 1.8 (element-wise for arrays)."""
    exponent = -0.064
    return 35.5 * C ** exponent + 1.8

# Plot settings: the year window shown on the x-axis and the line colours.
plot_year_start = 2017.0
plot_year_end = 2022.0
segment_color, baseline_color, transition_color = "#5ec8ff", "black", "blue"


In [1]:
# Create the plot: a log-log Axes whose x-range is the compute reached at the
# two ends of the plotted year window.
compute_min, compute_max = map(year_to_compute, (plot_year_start, plot_year_end))

fig, ax = plt.subplots(figsize=(10, 5.5))
for set_scale in (ax.set_xscale, ax.set_yscale):
    set_scale('log')
ax.set_xlim(left=compute_min, right=compute_max)
ax.set_xlabel("Compute (FLOPs)")
ax.set_ylabel("Loss")

# Optional reference curve, sampled at 1200 log-spaced compute values that
# span the x-axis.
log_lo, log_hi = np.log10(compute_min), np.log10(compute_max)
C_grid = np.logspace(log_lo, log_hi, 1200)
baseline_style = dict(color=baseline_color, linewidth=2, label='Baseline', alpha=0.5)
ax.plot(C_grid, baseline_func(C_grid), **baseline_style)

# Draw each segment over the part of its year span that lies inside the plotted
# window; segments that fall entirely outside the window are not drawn.
for start_year, end_year, loss_fn, seg_label in segments:
    lo = max(start_year, plot_year_start)
    hi = min(end_year, plot_year_end)
    if hi > lo:
        seg_years = np.linspace(lo, hi, 200)
        seg_compute = year_to_compute(seg_years)
        ax.plot(seg_compute, loss_fn(seg_compute),
                color=segment_color, linewidth=2.5, label=seg_label)

# Mark every boundary between consecutive segments: a dotted vertical line at
# the compute value of the earlier segment's end year, spanning the two segment
# values there, annotated with both segment labels. Boundaries whose year lies
# outside the plotted window are skipped so nothing is drawn off-axis.
# With fewer than two segments, zip() yields nothing and no marker is drawn.
for prev_seg, next_seg in zip(segments, segments[1:]):
    _, prev_end, prev_func, prev_label = prev_seg
    _, _, next_func, next_label = next_seg
    if not (plot_year_start <= prev_end <= plot_year_end):
        continue
    C_t = year_to_compute(prev_end)
    L1 = prev_func(C_t)
    L2 = next_func(C_t)
    ax.plot([C_t, C_t], [min(L1, L2), max(L1, L2)],
            linestyle=":", linewidth=2, color=transition_color)
    # Geometric mean = visual midpoint between L1 and L2 on the log-scaled y-axis.
    y_mid = (L1 * L2) ** 0.5
    ax.text(C_t * 1.05, y_mid, f"Transition:\n{prev_label} → {next_label}",
            color=transition_color, va='center', fontweight='bold')

# Add year labels on a secondary (top) x-axis that uses the same compute scale.
ax_top = ax.twiny()
ax_top.set_xscale('log')
ax_top.set_xlim(ax.get_xlim())
years = np.arange(int(np.floor(plot_year_start)), int(np.ceil(plot_year_end)) + 1)
tick_positions = year_to_compute(years)
# Select in-window years by comparing the exact year values. Comparing compute
# values against compute_min/compute_max could drop an edge tick if floating
# round-off put its position a hair outside the range. This is equivalent as
# long as year_to_compute is increasing (k > 0 for the 4.1x/year trajectory).
mask = (years >= plot_year_start) & (years <= plot_year_end)
ax_top.set_xticks(tick_positions[mask])
ax_top.set_xticklabels([str(y) for y in years[mask]])
# Log scaling adds unlabeled minor ticks between powers of ten, which carry no
# meaning on a year axis. Only ax_top's x-axis is changed: twiny() makes its
# y-axis shared with ax, so touching y here would also alter the main Axes.
ax_top.xaxis.set_minor_locator(plt.NullLocator())
ax_top.set_xlabel("Year")

# Finish the figure, write it to disk, and display it.
ax.legend(loc='best')
ax.grid(True, which="both", linestyle=":", linewidth=0.5, alpha=0.6)
plt.tight_layout()

out_path = "Figures/demo_arbitrary_functions.png"
# savefig raises FileNotFoundError when the target directory is missing, so
# create it first (a no-op if it already exists).
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
plt.savefig(out_path, dpi=200, bbox_inches="tight")
print(f"Saved to: {out_path}")
plt.show()


NameError: name 'year_to_compute' is not defined