In [None]:
"""
QA fine-tune → user enters one sentence → parse {tax, year, prstp, a2base, elasmu, gA1, gsigma1, pback}
→ run DICE → export Excel
"""

import importlib.util, subprocess, sys

def ensure(pkgs, import_names=None):
    """Pip-install every distribution in `pkgs` whose module cannot be found.

    Args:
        pkgs: pip distribution names (e.g. "numpy", "scikit-learn").
        import_names: optional mapping from a distribution name to the module
            name it is imported as. It extends a small built-in table for
            distributions whose import name differs from the pip name
            ("scikit-learn" is imported as "sklearn"). Without such an alias,
            `find_spec` never finds the package and it is reinstalled on every run.

    Missing distributions are installed in a single pip call using the current
    interpreter. A failing install raises subprocess.CalledProcessError.
    """
    aliases = {"scikit-learn": "sklearn"}
    if import_names:
        aliases.update(import_names)
    missing = [p for p in pkgs if importlib.util.find_spec(aliases.get(p, p)) is None]
    if missing:
        subprocess.check_call([sys.executable, "-m", "pip", "install", *missing])

# Third-party packages this notebook needs (pip distribution names).
# `torch` is imported directly below, and transformers' `Trainer` requires
# `accelerate` when running on PyTorch, so both are listed explicitly.
ensure([
    "torch",
    "accelerate",
    "transformers",
    "datasets",
    "fuzzysearch",
    "numpy",
    "pandas",
    "scikit-learn",
    "openpyxl",
    "scipy",
])

import torch

from transformers import (
    DistilBertTokenizerFast,
    DistilBertForQuestionAnswering,
    TrainingArguments,
    Trainer,
    pipeline,
)

from fuzzysearch import find_near_matches
from datasets import Dataset as DS

import os, random, re, subprocess
import numpy as np
import pandas as pd
from math import exp, log
from sklearn.model_selection import train_test_split
from scipy.optimize import minimize, Bounds, LinearConstraint


# ------------------------- Training config -------------------------
FAST_MODE = True         # smaller optimizer iteration budget in DICEOptimizerIPM.optimize
EPOCHS = 3
N_SAMPLES = 2000         # default number of synthetic sentences produced by gen()
BATCH_SIZE = 32
MAX_LEN = 128            # tokenizer max_length for question + context
USE_FP16 = torch.cuda.is_available()
NUM_WORKERS = 2

# Small smoothness penalty for numerical stability (on ΔMIU²)
SMOOTH_LAMBDA = 1e-4

# Enforce MIU to be non-decreasing (recommended True to better match official path)
ENFORCE_MONO_MIU = True

# ------------------------- Device -------------------------
# The HF pipeline takes an integer device index: 0 = first GPU, -1 = CPU.
if torch.cuda.is_available():
    device, device_index = "cuda", 0
else:
    device, device_index = "cpu", -1
print("Using device:", device)

# ------------------------- Parameters to extract -------------------------
PARAMS = ["tax", "year", "prstp", "a2base", "elasmu", "gA1", "gsigma1", "pback"]

# Extraction patterns per parameter, tried in order; group 1 captures the number.
# The five scalar parameters share one template: an explicit "<name> [to|is|=] <num>"
# form plus a descriptive-phrase form ("<phrase> ... <num>").
REGEXES = {
    "tax": [
        r"(?:tax|price|fee|charge)\s*(?:of|@|=)?\s*\$?\s*(\d{1,4}(?:,\d{3})*(?:\.\d+)?)\b",
        r"\$(\d{1,4}(?:,\d{3})*(?:\.\d+)?)\s*(?:per\s*(?:ton|t|tonne))?\b",
    ],
    "year": [
        r"\b(?:year|in|by|target|starting)\s*'?(\d{4})\b",
        r"\b(20[2-9]\d|2100)\b",
    ],
    **{
        name: [
            rf"\b{name}\b\s*(?:to|is|=)?\s*(-?\d*\.?\d+)",
            rf"{phrase}.*?(-?\d*\.?\d+)",
        ]
        for name, phrase in (
            ("prstp", "pure time pref"),
            ("a2base", "damage quad coeff"),
            ("elasmu", "util elasticity"),
            ("gA1", "tfp growth init"),
            ("gsigma1", "sigma decline init"),
        )
    },
    "pback": [
        r"\bpback\b\s*(?:=|to|is|at)?\s*\$?\s*(\d{2,4}(?:\.\d+)?)\b",
        r"backstop.*?\$?\s*(\d{2,4}(?:\.\d+)?)\b",
    ],
}

# Accepted (inclusive) value ranges used by input validation.
RANGES = {
    "tax":     (5, 1000),          # Carbon price $/tCO2
    "year":    (2020, 2100),       # Policy target year
    "prstp":   (0.0, 0.03),        # Pure rate of social time preference
    "elasmu":  (0.5, 2.5),         # Elasticity of marginal utility
    "a2base":  (0.0005, 0.008),    # Damage quadratic term
    "gA1":     (0.04, 0.12),       # Initial TFP growth per 5y
    "gsigma1": (-0.03, -0.003),    # Initial decarbonization rate per year
    "pback":   (150, 1500),        # Backstop cost in 2019$ per tCO2
}

# Human-readable help text shown in the input guide.
PARAM_DESCRIPTIONS = {
    "tax":     "Carbon price level in $/tCO2 for the specified target year.",
    "year":    "Target year for the carbon price (calendar year).",
    "prstp":   "Pure rate of social time preference (annual).",
    "a2base":  "Quadratic damage coefficient in the damage function.",
    "elasmu":  "Elasticity of marginal utility of consumption.",
    "gA1":     "Initial growth rate of TFP (per 5 years).",
    "gsigma1": "Initial growth of decarbonization (per year, negative).",
    "pback":   "Backstop technology cost (reference level, 2019$ per tCO2).",
}

# ------------------------- Synthetic data for QA training -------------------------
def gen(n=N_SAMPLES, seed=42, typo_rate=0.005, filler_rate=0.01):
    """Generate `n` synthetic policy sentences together with their true values.

    Each sentence reads "In <year>, set a carbon tax of $<tax>/ton." followed by
    "Also set <k> to <v>, ..." for every other key in PARAMS, in shuffled order.
    Tax is a multiple of 5 in [10, 200], year is in [2025, 2035], and the other
    values are drawn uniformly from RANGES and formatted with 6 decimals.
    Light noise is then applied: each character is swapped, deleted or replaced
    with probability `typo_rate`, and with probability `filler_rate` one filler
    word is inserted.

    Returns a DataFrame with a "query" column plus one string column per key.

    A private random.Random(seed) produces the same stream random.seed(seed)
    would, without reseeding the process-wide `random` module.
    """
    rng = random.Random(seed)
    rows = []
    keys_extra = [k for k in PARAMS if k not in ("tax", "year")]
    for _ in range(n):
        vals = {}
        for k in PARAMS:
            if k == "tax":
                vals[k] = str(rng.choice(range(10, 201, 5)))
            elif k == "year":
                vals[k] = str(rng.choice(range(2025, 2036)))
            else:
                lo, hi = RANGES[k]
                vals[k] = f"{rng.uniform(lo, hi):.6f}"
        main = f"In {vals['year']}, set a carbon tax of ${vals['tax']}/ton."
        extras = ", ".join(f"{k} to {vals[k]}" for k in rng.sample(keys_extra, len(keys_extra)))
        txt = main + " Also set " + extras + "."
        # Light noise: per-character typo injection
        chars = list(txt)
        i = 0
        while i < len(chars):
            if rng.random() < typo_rate:
                op = rng.choice(["swap", "del", "rep"])
                if op == "swap" and i + 1 < len(chars):
                    chars[i], chars[i + 1] = chars[i + 1], chars[i]
                    i += 1  # skip the swapped partner
                elif op == "del":
                    chars.pop(i)
                    i -= 1  # re-examine the character that slid into position i
                else:
                    # "rep", or a "swap" requested on the last character
                    chars[i] = rng.choice("abcdefghijklmnopqrstuvwxyz")
            i += 1
        if rng.random() < filler_rate:
            ins = rng.choice([" like", " um", " well"])
            pos = rng.randint(0, len(chars))
            chars.insert(pos, ins)
        rows.append({"query": "".join(chars), **vals})
    return pd.DataFrame(rows)

# Build the synthetic corpus and hold out 30% of it for validation.
df = gen()
df_train, df_val = train_test_split(df, random_state=42, test_size=0.3)

def make_qa_records(df, keys=PARAMS):
    """Convert synthetic rows into SQuAD-style QA records.

    For every row and every key in `keys`, one record is emitted: the question
    is "What is the <key>?", the context is the row's query, and the answer is
    the true value string with its character offset in the context.

    The answer is located with a number-boundary exact search first. This
    prevents a short value like "20" from matching inside "2025". If that
    fails because typo noise altered the value, the closest fuzzy match
    (smallest edit distance, then leftmost) is used. The allowed distance is
    scaled to the value length, because a 2-character pattern with distance 2
    matches almost anywhere. Keys with no usable match are skipped.

    Returns a `datasets.Dataset` built from the records.
    """
    recs = []
    for idx, r in df.iterrows():
        context = r["query"]
        for k in keys:
            orig = str(r[k])
            exact = re.search(r"(?<![\d.\-])" + re.escape(orig) + r"(?!\.?\d)", context)
            if exact:
                start = exact.start()
            else:
                max_dist = min(2, (len(orig) - 1) // 2)
                if max_dist < 1:
                    continue  # too short for a meaningful fuzzy match
                matches = find_near_matches(orig, context, max_l_dist=max_dist)
                if not matches:
                    continue
                start = min(matches, key=lambda mt: (mt.dist, mt.start)).start
            recs.append({
                "id": f"{idx}-{k}",
                "context": context,
                "question": f"What is the {k}?",
                "answers": {"text": [orig], "answer_start": [start]},
            })
    return DS.from_list(recs)

# SQuAD-style QA records for each split, then the DistilBERT tokenizer.
train_raw, eval_raw = make_qa_records(df_train), make_qa_records(df_val)

tok = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

def prepare_features(examples):
    """Tokenize a batch of QA records and attach answer start/end token labels.

    Intended for `Dataset.map(batched=True)`: `examples` is a dict of columns
    ("question", "context", "answers") as built by make_qa_records. Each
    question/context pair is tokenized with the global `tok`. Contexts that do
    not fit in MAX_LEN are split into overlapping windows (stride 50), so one
    example can produce several feature rows.

    For each row, start_positions/end_positions are the token indices covering
    the answer's characters. They are (0, 0) when the record has no answer or
    the answer is not fully inside that row's context window.
    """
    enc = tok(
        examples["question"], examples["context"],
        truncation="only_second", max_length=MAX_LEN, stride=50,
        return_overflowing_tokens=True, return_offsets_mapping=True,
        padding="max_length",
    )
    # Feature row -> index of the source example (rows repeat when windows overflow).
    sample_map = enc.pop("overflow_to_sample_mapping")
    # Per-token (char_start, char_end) offsets; popped so they are not kept as feature columns.
    offset_map = enc.pop("offset_mapping")
    starts, ends = [], []
    for i, offsets in enumerate(offset_map):
        si = sample_map[i]
        ans = examples["answers"][si]
        if not ans["answer_start"]:
            # No labelled answer: point both labels at token 0.
            starts.append(0); ends.append(0); continue
        # Answer character span [sc, ec) within the context string.
        sc = ans["answer_start"][0]; ec = sc + len(ans["text"][0])
        # sequence id 1 marks context tokens (the context is the second text passed to tok).
        seq_ids = enc.sequence_ids(i)
        # c0/c1: first and last token index belonging to the context.
        c0 = seq_ids.index(1); c1 = len(seq_ids)-1 - seq_ids[::-1].index(1)
        if not (offsets[c0][0] <= sc < offsets[c1][1] and offsets[c0][0] < ec <= offsets[c1][1]):
            # Answer not fully inside this window's context tokens.
            starts.append(0); ends.append(0)
        else:
            s, e = c0, c1
            # Walk past tokens starting at/before sc; the token before s is the start token.
            while s <= c1 and offsets[s][0] <= sc: s += 1
            # Walk back past tokens ending at/after ec; the token after e is the end token.
            while e >= c0 and offsets[e][1] >= ec: e -= 1
            starts.append(s-1); ends.append(e+1)
    enc["start_positions"] = starts
    enc["end_positions"]   = ends
    return enc

# Tokenize the QA records. Older `datasets` versions whose `map` rejects
# `num_proc` fall back to single-process mapping.
try:
    train_ds = train_raw.map(prepare_features, batched=True, remove_columns=train_raw.column_names, num_proc=NUM_WORKERS)
    val_ds   = eval_raw.map(prepare_features,  batched=True, remove_columns=eval_raw.column_names,  num_proc=NUM_WORKERS)
except TypeError:
    train_ds = train_raw.map(prepare_features, batched=True, remove_columns=train_raw.column_names)
    val_ds   = eval_raw.map(prepare_features,  batched=True, remove_columns=eval_raw.column_names)

model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")
if device == 'cuda':
    model = model.to('cuda')

# Training arguments are driven by the EPOCHS / BATCH_SIZE config constants,
# so the "[Training]" log line reports what is actually run.
args = TrainingArguments(
    output_dir="qa_ft_out",
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_steps=100,
    save_total_limit=1,
    report_to=['none'],
    use_cpu=(device == 'cpu'),   # replaces deprecated `no_cuda`
    fp16=USE_FP16,
    dataloader_num_workers=NUM_WORKERS,
    disable_tqdm=True,
    save_strategy="no",          # no checkpoints needed; the model is used in-memory
)

# val_ds is attached so `trainer.evaluate()` can be called; no eval strategy is
# configured, so training itself is unaffected.
trainer = Trainer(model=model, args=args, train_dataset=train_ds, eval_dataset=val_ds, processing_class=tok)
print(f"\n[Training] samples={len(train_ds)} epochs={EPOCHS}")
trainer.train()

qa_pipe = pipeline("question-answering", model=model, tokenizer=tok, device=device_index)
num_pat = re.compile(r"-?\d+(?:\.\d+)?")

def clean_number(s):
    """Normalize a captured numeric string, or return None if it is not a plain number.

    Surrounding whitespace, thousands separators and the symbols $ % ° are
    removed first. A bare leading decimal point is then given a zero
    (".5" -> "0.5", "-.5" -> "-0.5"), so inputs like "$.5" normalize the same
    way as ".5".

    Returns the cleaned string when it fully matches an optionally negative
    decimal number, else None (also for empty/None input).
    """
    if not s:
        return None
    s = re.sub(r"[,$%°]", "", s.strip()).strip()
    if s.startswith("."):
        s = "0" + s
    elif s.startswith("-."):
        s = "-0" + s[1:]
    return s if re.fullmatch(r"-?\d*\.?\d+", s) else None

def extract(q: str):
    """Parse the model parameters mentioned in a free-text policy sentence.

    Pass 1 (rules), for each key in PARAMS, stops at the first hit among:
      - the patterns in REGEXES;
      - a generic "<key> ... <number>" pattern;
      - for year, a bare 2020-2100 year; for tax, a bare "$<number>".
    Pass 2 (QA model) handles keys still missing. It asks the fine-tuned
    pipeline "What is the <key>?" and accepts the answer only when the score
    is >= 0.7, the text is a plain number, and its character span does not
    overlap a span already assigned to another parameter. This last check
    stops the model from reusing e.g. the year or tax for an unmentioned key.

    Returns a dict with every key of PARAMS mapped to a cleaned numeric string
    (see clean_number) or None.
    """
    found = {}
    claimed = []  # (start, end) character spans already attributed to some parameter

    # Pass 1: deterministic patterns.
    for k in PARAMS:
        m = None
        for pat in REGEXES.get(k, []):
            m = re.search(pat, q, flags=re.I)
            if m:
                break
        if not m:
            m = re.search(rf"\b{k}\b[^\d\-\.]{{0,16}}(-?\d*\.?\d+)", q, flags=re.I)
        if not m and k == "year":
            m = re.search(r"\b(20[2-9]\d|2100)\b", q)
        if not m and k == "tax":
            m = re.search(r"\$\s*([0-9]{1,4}(?:\.\d+)?)", q)
        if m:
            found[k] = m.group(1)
            claimed.append(m.span(1))

    # Pass 2: QA fallback for whatever the rules did not find.
    for k in PARAMS:
        if found.get(k):
            continue
        pred = qa_pipe(question=f"What is the {k}?", context=q)
        ans = pred.get("answer", "").strip()
        if pred.get("score", 0) < 0.7 or not num_pat.fullmatch(ans):
            continue
        start, end = pred.get("start"), pred.get("end")
        if start is not None and end is not None:
            if any(start < c_end and c_start < end for c_start, c_end in claimed):
                continue  # this number already belongs to another parameter
            claimed.append((start, end))
        found[k] = ans

    return {k: (clean_number(found[k]) if found.get(k) else None) for k in PARAMS}

# ------------------------- DICE core  -------------------------
class DICEModel:
    """Python port of the DICE-2023 model equations, evaluated one period at a time.

    Periods t = 1..81 are 5-year steps starting in 2020. Exogenous paths are
    filled by initialize_dynamic_parameters() and initialize_nonco2_parameters().
    Endogenous quantities are computed by update_model(t), which the caller
    runs in order after setting MIU[t], S[t] and K[t]. Results are stored in
    per-period dicts keyed by t.
    """

    def __init__(self):
        self.timesteps = range(1, 82)
        self.gama = 0.300; self.pop1 = 7752.9; self.popadj = 0.145; self.popasym = 10825.0
        self.dk = 0.1; self.q1 = 135.7; self.AL1 = 5.84; self.gA1 = 0.066; self.delA = 0.0015
        self.gsigma1 = -0.015; self.delgsig = 0.96; self.asymgsig = -0.005
        self.e1 = 37.56; self.miu1 = 0.05; self.fosslim = 6000; self.CumEmiss0 = 633.5
        self.a1 = 0; self.a2base = 0.003467; self.a3 = 2.00
        self.expcost2 = 2.6; self.pback = 515.0; self.gback = -0.012
        self.cprice1 = 6; self.gcprice = 0.025
        self.limmiu2070 = 1.0; self.limmiu2120 = 1.1; self.limmiu2200 = 1.05; self.limmiu2300 = 1.0; self.delmiumax = 0.12
        self.betaclim = 0.5; self.elasmu = 0.95; self.prstp = 0.001; self.pi = 0.05; self.k0 = 295
        self.siggc1 = 0.01; self.tstep = 5; self.SRF = 1000000
        self.scale1 = 0.00891061; self.scale2 = -6275.91
        self.eland0 = 5.9; self.deland = 0.1
        self.F_Misc2020 = -0.054; self.F_Misc2100 = 0.265
        self.F_GHGabate2020 = 0.518; self.F_GHGabate2100 = 0.957
        self.ECO2eGHGB2020 = 9.96; self.ECO2eGHGB2100 = 15.5
        self.emissrat2020 = 1.40; self.emissrat2100 = 1.21
        self.Fcoef1 = 0.00955; self.Fcoef2 = 0.861
        self.emshare0 = 0.2173; self.emshare1 = 0.224; self.emshare2 = 0.2824; self.emshare3 = 0.2763
        self.tau0 = 1000000; self.tau1 = 394.4; self.tau2 = 36.53; self.tau3 = 4.304
        self.teq1 = 0.324; self.teq2 = 0.44; self.d1 = 236; self.d2 = 4.07
        self.irf0 = 32.4; self.irC = 0.019; self.irT = 4.165; self.fco22x = 3.93
        self.mat0 = 886.5128014; self.res00 = 150.093; self.res10 = 102.698; self.res20 = 39.534; self.res30 = 6.1865
        self.mateq = 588; self.tbox10 = 0.1477; self.tbox20 = 1.099454; self.tatm0 = 1.24715
        # containers
        self.L = {1:self.pop1}; self.sig1 = self.e1/(self.q1*(1-self.miu1)); self.sigma = {1:self.sig1}
        self.aL = {1:self.AL1}; self.gA = {}; self.gsig = {}; self.pbacktime = {}; self.cpricebase = {}
        self.varpcc = {}; self.rprecaut = {}; self.RR1 = {}; self.RR = {}; self.miuup = {}
        self.eland = {}; self.CO2E_GHGabateB = {}; self.F_Misc = {}; self.emissrat = {}; self.sigmatot = {}; self.COST1TOT = {}
        self.MAT = {1:self.mat0}; self.TATM = {1:self.tatm0}
        self.RES0 = {1:self.res00}; self.RES1 = {1:self.res10}; self.RES2 = {1:self.res20}; self.RES3 = {1:self.res30}
        self.TBOX1 = {1:self.tbox10}; self.TBOX2 = {1:self.tbox20}; self.alpha = {}; self.IRFt = {}
        self.F_GHGabate = {1:self.F_GHGabate2020}; self.CCATOT = {1:self.CumEmiss0}
        self.ECO2 = {}; self.EIND = {}; self.ECO2E = {}; self.CACC = {}; self.FORC = {}
        self.DAMFRAC = {}; self.DAMAGES = {}; self.ABATECOST = {}; self.MCABATE = {}; self.CPRICE = {}
        self.YGROSS = {}; self.YNET = {}; self.Y = {}; self.C = {}; self.CPC = {}; self.I = {}; self.S = {}
        self.K = {1:self.k0}; self.MIU = {}; self.RFACTLONG = {1:1000000}; self.RLONG = {}; self.RSHORT = {}
        self.PERIODU = {}; self.TOTPERIODU = {}; self.UTILITY = 0

    def initialize_dynamic_parameters(self):
        """Fill the exogenous per-period paths (population, TFP, sigma, discounting, MIU caps).

        Must run after any scalar parameter overrides (prstp, elasmu, gA1, ...),
        because every path is derived from the current attribute values.
        """
        self.rartp = exp(self.prstp + self.betaclim * self.pi) - 1
        # Independent of t, so computed once (the optimizer reads it as its savings-rate start).
        self.optlrsav = (self.dk + 0.004) / (self.dk + 0.004 * self.elasmu + self.rartp) * self.gama
        for t in self.timesteps:
            self.gA[t] = self.gA1 * exp(-self.delA * 5 * (t - 1))
            self.gsig[t] = min(self.gsigma1 * (self.delgsig ** (t - 1)), self.asymgsig)
            self.L[t] = self.pop1 if t == 1 else self.L[t-1] * ((self.popasym / self.L[t-1]) ** self.popadj)
            # Period t grows at the previous period's rate: aL(t+1) = aL(t) / (1 - gA(t)).
            self.aL[t] = self.AL1 if t == 1 else self.aL[t-1] / (1 - self.gA[t-1])
            # Likewise sigma(t+1) = sigma(t) * exp(5 * gsig(t)), so gsigma1 drives the first step.
            self.sigma[t] = self.e1/(self.q1*(1-self.miu1)) if t==1 else self.sigma[t-1]*exp(5*self.gsig[t-1])
            self.pbacktime[t] = self.calculate_pbacktime(t)
            self.cpricebase[t] = self.cprice1 * ((1 + self.gcprice) ** (5 * (t - 1)))
            self.varpcc[t] = min(self.siggc1 ** 2 * 5 * (t - 1), self.siggc1 ** 2 * 5 * 47)
            self.rprecaut[t] = -0.5 * self.varpcc[t] * self.elasmu ** 2
            self.RR1[t] = 1 / ((1 + self.rartp) ** (self.tstep * (t - 1)))
            self.RR[t] = self.RR1[t] * (1 + self.rprecaut[t]) ** (-self.tstep * (t - 1))
            self.miuup[t] = self.calculate_miuup(t)

    def initialize_nonco2_parameters(self):
        """Fill land-use emissions, non-CO2 forcing/emission paths and abatement cost coefficients.

        Requires sigma and pbacktime from initialize_dynamic_parameters().
        Values are interpolated linearly over periods 1..16 and held flat afterwards.
        """
        for t in self.timesteps:
            self.eland[t] = self.eland0 * ((1 - self.deland) ** (t - 1))
            self.CO2E_GHGabateB[t] = self.calculate_CO2E_GHGabateB(t)
            self.F_Misc[t] = self.F_Misc2020 + ((self.F_Misc2100 - self.F_Misc2020) / 16) * (t - 1) if t <= 16 else self.F_Misc2100
            self.emissrat[t] = self.emissrat2020 + ((self.emissrat2100 - self.emissrat2020) / 16) * (t - 1) if t <= 16 else self.emissrat2100
            self.sigmatot[t] = self.sigma[t] * self.emissrat[t]
            self.COST1TOT[t] = self.pbacktime[t] * self.sigmatot[t] / self.expcost2 / 1000

    def calculate_pbacktime(self, t):
        """Backstop price in period t: declines 1%/yr until period 7, then 0.1%/yr."""
        if t <= 7:  return self.pback * exp(-5 * 0.01 * (t - 7))
        else:       return self.pback * exp(-5 * 0.001 * (t - 7))

    def calculate_CO2E_GHGabateB(self, t):
        """Abatable non-CO2 baseline emissions, interpolated to 2100 (period 17) then flat."""
        if t <= 16: return self.ECO2eGHGB2020 + ((self.ECO2eGHGB2100 - self.ECO2eGHGB2020) / 16) * (t - 1)
        else:       return self.ECO2eGHGB2100

    def calculate_miuup(self, t):
        """Upper bound on the emission-control rate MIU in period t."""
        if t == 1: return 0.05
        elif t == 2: return 0.10
        elif t > 57: return self.limmiu2300
        elif t > 37: return self.limmiu2200
        elif t > 20: return self.limmiu2120
        elif t > 11: return self.limmiu2070
        elif t > 8:  return 0.85 + 0.05 * (t - 8)
        else:        return self.delmiumax * (t - 1)

    def update_model(self, t):
        """Compute all endogenous quantities for period t.

        The caller must have set MIU[t], S[t] and K[t], and (for t > 1) already
        run update_model(t - 1).
        """
        # Diagnostic alpha[t] based on previous IRF

        if t == 1:
            self.alpha[t] = 1.0
        else:
            alpha_t = (self.irf0 + self.irC * self.CACC[t-1] + self.irT * self.TATM[t-1]) / self.irf0
            self.alpha[t] = max(alpha_t, 0.1)

        # Economics and emissions

        self.YGROSS[t] = max((self.aL[t] * ((self.L[t] / 1000) ** (1 - self.gama))) * (self.K[t] ** self.gama), 0)
        self.ECO2[t]  = (self.sigma[t] * self.YGROSS[t] + self.eland[t]) * (1 - self.MIU[t])
        self.EIND[t]  = (self.sigma[t] * self.YGROSS[t]) * (1 - self.MIU[t])
        self.ECO2E[t] = (self.sigma[t] * self.YGROSS[t] + self.eland[t] + self.CO2E_GHGabateB[t]) * (1 - self.MIU[t])

        # Carbon boxes

        if t == 1:
            self.RES0[t] = self.res00; self.RES1[t] = self.res10; self.RES2[t] = self.res20; self.RES3[t] = self.res30
        else:
            self.RES0[t] = (self.emshare0 * self.tau0 * self.alpha[t] * (self.ECO2[t] / 3.667)) * (1 - exp(-self.tstep / (self.tau0 * self.alpha[t]))) + self.RES0[t - 1] * exp(-self.tstep / (self.tau0 * self.alpha[t]))
            self.RES1[t] = (self.emshare1 * self.tau1 * self.alpha[t] * (self.ECO2[t] / 3.667)) * (1 - exp(-self.tstep / (self.tau1 * self.alpha[t]))) + self.RES1[t - 1] * exp(-self.tstep / (self.tau1 * self.alpha[t]))
            self.RES2[t] = (self.emshare2 * self.tau2 * self.alpha[t] * (self.ECO2[t] / 3.667)) * (1 - exp(-self.tstep / (self.tau2 * self.alpha[t]))) + self.RES2[t - 1] * exp(-self.tstep / (self.tau2 * self.alpha[t]))
            self.RES3[t] = (self.emshare3 * self.tau3 * self.alpha[t] * (self.ECO2[t] / 3.667)) * (1 - exp(-self.tstep / (self.tau3 * self.alpha[t]))) + self.RES3[t - 1] * exp(-self.tstep / (self.tau3 * self.alpha[t]))

        # Atmospheric concentration and cumulative emissions

        if t == 1: self.MAT[t] = self.mat0
        else:      self.MAT[t] = max(self.mateq + self.RES0[t] + self.RES1[t] + self.RES2[t] + self.RES3[t], 10)

        if t < max(self.timesteps): self.CCATOT[t + 1] = self.CCATOT[t] + self.ECO2[t] * (5 / 3.666)
        self.CACC[t] = self.CCATOT[t] - (self.MAT[t] - self.mateq)

        # Non-CO2 forcing and total forcing

        if t < max(self.timesteps):
            self.F_GHGabate[t + 1] = self.Fcoef2 * self.F_GHGabate[t] + self.Fcoef1 * self.CO2E_GHGabateB[t] * (1 - self.MIU[t])
        self.FORC[t] = self.fco22x * ((log(self.MAT[t] / self.mateq)) / log(2)) + self.F_Misc[t] + self.F_GHGabate[t]

        # Temperature boxes

        if t == 1:
            self.TBOX1[t] = self.tbox10; self.TBOX2[t] = self.tbox20; self.TATM[t] = self.tatm0
        else:
            self.TBOX1[t] = self.TBOX1[t-1] * exp(-self.tstep / self.d1) + self.teq1 * self.FORC[t] * (1 - exp(-self.tstep / self.d1))
            self.TBOX2[t] = self.TBOX2[t-1] * exp(-self.tstep / self.d2) + self.teq2 * self.FORC[t] * (1 - exp(-self.tstep / self.d2))
            self.TATM[t]  = min(max(self.TBOX1[t] + self.TBOX2[t], 0), 20)

        # Damages, net output, consumption, and price

        self.DAMFRAC[t] = (self.a1 * self.TATM[t]) + (self.a2base * (self.TATM[t] ** self.a3))
        self.DAMAGES[t] = self.YGROSS[t] * self.DAMFRAC[t]
        self.YNET[t]    = max(self.YGROSS[t] * (1 - self.DAMFRAC[t]), 0)
        self.ABATECOST[t]= self.YGROSS[t] * self.COST1TOT[t] * (self.MIU[t] ** self.expcost2)
        self.Y[t] = max(self.YNET[t] - self.ABATECOST[t], 0)
        self.I[t] = max(self.S[t] * self.Y[t], 0)
        self.C[t] = max(self.Y[t] - self.I[t], 2)
        self.CPC[t] = max(1000 * self.C[t] / self.L[t], 0.01)
        self.MCABATE[t] = self.pbacktime[t] * (self.MIU[t] ** (self.expcost2 - 1))
        self.CPRICE[t]  = self.MCABATE[t]

        # Discount factors and interest rates
        if t == 1:
            self.RFACTLONG[t] = 1000000; self.RLONG[t] = -log(self.RFACTLONG[t] / self.SRF) / (5 * 1); self.RSHORT[t] = 0.0
        else:
            self.RFACTLONG[t] = max(self.SRF * (self.CPC[t] / self.CPC[1]) ** (-self.elasmu) * self.RR[t], 0.0001)
            # Average annual rate from period 1 to t spans 5*(t-1) years.
            self.RLONG[t]     = -log(self.RFACTLONG[t] / self.SRF) / (5 * (t - 1))
            self.RSHORT[t]    = -log(self.RFACTLONG[t] / self.RFACTLONG[t-1]) / 5.0

        # Diagnostics
        self.IRFt[t]      = max(self.irf0 + self.irC * self.CACC[t] + self.irT * self.TATM[t], 0)
        self.PERIODU[t]   = ((self.C[t] * 1000 / self.L[t]) ** (1 - self.elasmu) - 1) / (1 - self.elasmu) - 1
        self.TOTPERIODU[t]= self.PERIODU[t] * self.L[t] * self.RR[t]

    def compute_utility(self):
        """Return (and store in UTILITY) the scaled discounted sum of period utilities."""
        self.UTILITY = self.tstep * self.scale1 * sum(self.TOTPERIODU[t] for t in self.timesteps) + self.scale2
        return self.UTILITY


def _clamp(x, lo, hi):
    """Limit x to [lo, hi]: values above hi become hi, then values not above lo become lo."""
    bounded = x if x < hi else hi
    return bounded if bounded > lo else lo

class DICEOptimizerIPM:
    """Choose MIU (emission-control) and S (savings) paths that maximise DICE utility.

    Wraps an already initialised DICEModel. The decision vector is
    [MIU_1..MIU_T, S_1..S_T], solved with SciPy's 'trust-constr' method.
    """

    def __init__(self, dice_model):
        self.model = dice_model
        self.T = len(dice_model.timesteps)

    def _simulate(self, MIU_values, S_values, add_smooth_penalty=True):
        """Run the model forward for the given control paths and return utility.

        MIU is clamped to [0, 1] and to model.miuup[t]. When the model has
        enforce_price set, required_miu[t] overrides it. S is clamped to [0, 1]
        for t <= 37 and fixed at 0.28 afterwards. All model state dicts are
        overwritten as a side effect. With add_smooth_penalty, SMOOTH_LAMBDA
        times the sum of squared MIU increments is subtracted from utility.
        """
        m = self.model
        keys = ['ECO2','EIND','ECO2E','RES0','RES1','RES2','RES3','MAT','CCATOT','CACC',
                'F_GHGabate','FORC','TBOX1','TBOX2','TATM','DAMFRAC','DAMAGES','YGROSS',
                'YNET','Y','I','C','CPC','MCABATE','CPRICE','PERIODU','TOTPERIODU',
                'RFACTLONG','RLONG','RSHORT','MIU','S','alpha','K','IRFt']
        for k in keys:
            if k not in m.__dict__: setattr(m,k,{})
        m.K[1] = m.k0

        miu_path = []
        for ti, t in enumerate(m.timesteps, start=1):
            miu_t = _clamp(MIU_values[ti-1], 0.0, 1.0)
            miu_t = min(miu_t, m.miuup[t])
            if getattr(m, 'enforce_price', False) and t in getattr(m, 'required_miu', {}):
                miu_t = m.required_miu[t]
            m.MIU[t] = miu_t
            miu_path.append(miu_t)

            m.S[t] = _clamp(S_values[ti-1], 0.0, 1.0) if t <= 37 else 0.28

            if t > 1:
                m.K[t] = (1.0 - m.dk) ** m.tstep * m.K[t-1] + m.tstep * m.I[t-1]
                m.K[t] = max(m.K[t], 1.0)

            m.update_model(t)

        utility = m.compute_utility()

        if add_smooth_penalty and SMOOTH_LAMBDA > 0:
            diffs = np.diff(np.array(miu_path))
            penalty = float(np.sum(diffs * diffs))
            utility -= SMOOTH_LAMBDA * penalty

        return utility

    def objective_function(self, decision_variables):
        """Negated utility (for minimisation) of the stacked [MIU, S] vector."""
        T = self.T
        MIU_values = decision_variables[:T]
        S_values   = decision_variables[T:2*T]
        return -self._simulate(MIU_values, S_values)

    def optimize(self):
        """Solve for the optimal paths and leave the model state at the solution.

        Returns the SciPy OptimizeResult. Before returning, the model is
        re-simulated at res.x. Otherwise its dicts would hold whatever point
        trust-constr evaluated last, typically a finite-difference probe rather
        than the optimum.
        """
        T = self.T
        m = self.model
        x0_miu = np.array([min(0.6 * m.miuup[t], 0.9) for t in m.timesteps], dtype=float)
        if ENFORCE_MONO_MIU:
            # miuup dips after 2120, so 0.6*miuup is not monotone; start from a feasible point.
            x0_miu = np.maximum.accumulate(x0_miu)
        s0 = min(max(getattr(m, "optlrsav", 0.25), 0.05), 0.5)
        x0_S   = np.array([s0 if t <= 37 else 0.28 for t in m.timesteps], dtype=float)
        x0 = np.concatenate([x0_miu, x0_S])

        lb = np.zeros(2*T, dtype=float)
        ub = np.ones(2*T, dtype=float)
        bounds = Bounds(lb, ub)

        constraints = []
        if ENFORCE_MONO_MIU:
            # Linear constraint: MIU[t] - MIU[t-1] >= 0  (t=2..T)
            A = np.zeros((T-1, 2*T), dtype=float)
            for k in range(1, T):
                A[k-1, k]   =  1.0
                A[k-1, k-1] = -1.0
            mono_con = LinearConstraint(A, lb=np.zeros(T-1), ub=np.full(T-1, np.inf))
            constraints.append(mono_con)

        res = minimize(self.objective_function, x0, method='trust-constr',
                       bounds=bounds, constraints=constraints,
                       options={'maxiter': 300 if FAST_MODE else 500, 'verbose': 0})
        # Sync the model's state dicts with the returned optimum before callers read them.
        self._simulate(res.x[:T], res.x[T:2*T])
        return res

# ------------------------- Helpers: anchors + overrides -------------------------

def year_to_period(y: int) -> int:
    """Map a calendar year to its 1-based 5-year DICE period (2020-2024 -> 1, 2025-2029 -> 2, ...)."""
    years_since_start = y - 2020
    return 1 + years_since_start // 5

def set_price_anchors_forward(model, anchors, growths=None, enforce=False):
    """Write anchored carbon-price paths into `model` and optionally enforce them.

    anchors: sequence of (year, price) pairs. Each anchor's price is placed at
    its DICE period and grown at annual rate g for every later period, up to
    the period before the next anchor (or the last period). g is
    model.gcprice when `growths` is None, the number itself when it is an
    int/float, or else growths[i] for the i-th anchor (the last entry is reused).

    The path overwrites the matching entries of model.cpricebase.

    With enforce=True, each target price is converted to the MIU that yields
    it (price = pbacktime * MIU**(expcost2 - 1)), clipped to [0, miuup]. The
    results are stored in model.required_miu and model.enforce_price is set,
    with a warning printed if any target saturates the cap. Otherwise both are
    reset. A status line is printed either way.
    """
    def growth_for(idx):
        if growths is None:
            return model.gcprice
        if isinstance(growths, (int, float)):
            return float(growths)
        return float(growths[min(idx, len(growths) - 1)])

    periods = [year_to_period(year) for year, _ in anchors]
    prices = [float(price) for _, price in anchors]

    path = {}
    for idx, (t_anchor, p_anchor) in enumerate(zip(periods, prices)):
        path[t_anchor] = p_anchor
        rate = growth_for(idx)
        if idx + 1 < len(periods):
            last = periods[idx + 1] - 1
        else:
            last = max(model.timesteps)
        for t in range(t_anchor + 1, last + 1):
            path[t] = p_anchor * ((1 + rate) ** (5 * (t - t_anchor)))

    model.cpricebase.update(path)

    if not enforce:
        model.required_miu = {}
        model.enforce_price = False
        print(f"✓ Anchored (baseline): {anchors}")
        return

    required = {}
    saturated = False
    for t, target in path.items():
        if not (model.pbacktime[t] > 0 and model.expcost2 > 1):
            continue
        miu = (target / model.pbacktime[t]) ** (1.0 / (model.expcost2 - 1.0))
        miu = min(max(0.0, miu), model.miuup[t])
        if miu >= model.miuup[t] - 1e-9:
            saturated = True
        required[t] = miu
    model.required_miu = required
    model.enforce_price = True
    print(f"✓ Anchored (ENFORCED): {anchors}")
    if saturated:
        print("⚠ Some target prices saturate at the MIU upper bound; realized CPRICE may be lower.")

def parse_enforce_flag_and_clean_query(raw: str):
    """Detect the TRUE/policy-mode switch in `raw` and strip flag tokens from it.

    Enforcement is on when the text contains, case-insensitively, one of
    `enforce=true`, `policy=true`, `mode=policy` or `true mode`. Whitespace
    around `=` (and between `true` and `mode`) is tolerated, so detection
    accepts the same spellings the cleaner removes. For example,
    `enforce = true` is no longer stripped while being silently ignored.

    Returns (enforce, cleaned), where `cleaned` is `raw` with every
    `enforce|policy|mode = true|policy` token removed.
    """
    enforce = re.search(
        r"(?:enforce|policy)\s*=\s*true|mode\s*=\s*policy|true\s+mode", raw, flags=re.I
    ) is not None
    cleaned = re.sub(r"(enforce|policy|mode)\s*=\s*(true|policy)", "", raw, flags=re.I)
    return enforce, cleaned

def run_dice_with_anchors_and_overrides(anchors, overrides=None, growths=None, enforce=False):
    """Build a DICE model, apply overrides/anchors, optimize, and export results to Excel.

    Args:
        anchors: list of (year, price) carbon-price anchors (may be empty).
        overrides: optional {attribute: value} scalar parameter overrides.
            None values and the "tax"/"year" keys are skipped. A value that
            cannot be converted to float is reported and ignored.
        growths: annual growth for the anchored price path (see
            set_price_anchors_forward).
        enforce: when True, anchored prices are imposed through the required
            MIU path. When False, the anchors only set cpricebase, which is
            exported but not used by the optimization.

    Returns the name of the written .xlsx file. Rows before the first anchor
    year are dropped when anchors are given.
    """
    m = DICEModel()

    # Scalar overrides (must precede initialization, which derives all paths from them)

    if overrides:
        for k, v in overrides.items():
            if v is None or k in ("tax", "year"):
                continue
            try:
                setattr(m, k, float(v))
            except (TypeError, ValueError) as err:
                print(f"⚠ Ignoring override {k}={v!r}: {err}")
    m.initialize_dynamic_parameters()
    m.initialize_nonco2_parameters()
    if anchors:
        set_price_anchors_forward(m, anchors, growths=growths, enforce=enforce)
    res = DICEOptimizerIPM(m).optimize()
    print("✓ DICE solved." if getattr(res, "success", True) else "⚠ DICE may not converge.")

    variables = [
        "MAT","TATM","RES0","RES1","RES2","RES3","TBOX1","TBOX2","alpha","IRFt","F_GHGabate",
        "CCATOT","CACC","FORC","DAMFRAC","ABATECOST","YGROSS","YNET","Y","C","CPC","I","S","K",
        "MIU","PERIODU","TOTPERIODU","CPRICE","UTILITY","cpricebase","pbacktime"
    ]

    # One column per variable, outer-joined on period; scalars (UTILITY) land on Time 1.
    df_all = pd.DataFrame()
    for name in variables:
        d = getattr(m, name)
        if isinstance(d, dict):
            tmp = pd.DataFrame(list(d.items()), columns=["Time", name])
        else:
            tmp = pd.DataFrame({name:[d]}); tmp["Time"] = 1
        df_all = tmp if df_all.empty else pd.merge(df_all, tmp, on="Time", how="outer")
    df_all["Year"] = df_all["Time"].apply(lambda t: 2020 + 5*(int(t)-1))
    if anchors:
        df_all = df_all[df_all["Year"] >= anchors[0][0]].copy()
    tag = "_".join([f"{y}-{int(p)}" for y,p in anchors]) if anchors else "no_anchor"
    tag += ("_ENF" if enforce else "_BASE")
    filename = f"DICE_Model_Combined_Results_Anchors_{tag}.xlsx"
    df_all.to_excel(filename, index=False)
    print(f"✓ Exported: {filename}")
    for y,p in (anchors or []):
        t = year_to_period(y)
        actual = getattr(m,"CPRICE",{}).get(t, np.nan)
        print(f"Check {y}: target=${p:.1f}/t, actual=${actual:.1f}/t")
    return filename

# ------------------------- Friendly validation & popup helpers -------------------------
def show_popup(message: str):
    """Show `message` in a Tk warning dialog, falling back to printing it.

    If tkinter is unavailable, no display exists, or the dialog fails, the
    message is printed as a "[Notice]" line instead. Once the hidden Tk root
    window is created, it is always destroyed, even when the dialog raises.
    """
    try:
        import tkinter as tk
        from tkinter import messagebox
        root = tk.Tk()
    except Exception:
        print("\n[Notice] " + message + "\n")
        return
    try:
        root.withdraw()
        messagebox.showwarning("Input Needed", message)
    except Exception:
        print("\n[Notice] " + message + "\n")
    finally:
        try:
            root.destroy()
        except Exception:  # best-effort cleanup; the message has already been delivered
            pass

def print_parameter_intro():
    """Print the input guide: each parameter with its description and valid range, the run modes, and example sentences."""
    lines = ["\n================ DICE Input Guide ================"]
    for name in PARAMS:
        low, high = RANGES[name]
        lines.append(f"- {name}: {PARAM_DESCRIPTIONS[name]}")
        lines.append(f"  Valid range: [{low}, {high}]")
    lines.extend([
        "\nModes:",
        "  • TRUE mode (ENF): add 'enforce=true' (recommended for policy-anchored runs).",
        "  • BASE mode: omit 'enforce=true' (pure optimization with MIU upper bounds).",
        "\nExamples:",
        "  In 2030, set a carbon tax of $50/ton, prstp=0.002, a2base 0.0035, elasmu 0.95, gA1=0.06, gsigma1=-0.01, pback=400. enforce=true",
        "  In 2050, set a carbon price of $100 per ton; prstp=0.001; a2base=0.0035; enforce=true",
        "==================================================\n",
    ])
    for line in lines:
        print(line)

def validate_extracted_params(extracted: dict):
    """
    Check that tax and year are present and that every supplied value is in range.

    tax and year are required (a missing or empty value fails immediately).
    Any other key from RANGES is optional, but if present it must parse as a
    float inside its inclusive range.
    Returns (ok, message); on failure the message lists what needs fixing.
    """
    missing = [name for name in ("tax", "year") if not extracted.get(name)]
    if missing:
        return False, f"Please include both tax and year. Missing: {', '.join(missing)}."

    problems = []
    for name, (low, high) in RANGES.items():
        raw = extracted.get(name)
        if raw is None:
            continue
        try:
            number = float(raw)
        except Exception:
            problems.append(f"{name} has an invalid numeric value.")
            continue
        if not low <= number <= high:
            problems.append(f"{name}={number} is out of range [{low}, {high}]")

    if not problems:
        return True, "OK"
    return False, "Input values out of expected ranges:\n  - " + "\n  - ".join(problems)

# ------------------------- Main: prompt user safely -------------------------
if __name__ == "__main__":
    try:
        print("\n[READY] QA fine-tuning complete.")
        print_parameter_intro()

        while True:
            user_text = input("> Please enter your policy sentence (include enforce=true for TRUE mode):\n> ").strip()
            if not user_text:
                show_popup("No input detected. Please enter a policy sentence following the guide.")
                continue

            enforce_flag, q_clean = parse_enforce_flag_and_clean_query(user_text)
            ex = extract(q_clean)
            print("Parsed parameters:", ex)

            ok, msg = validate_extracted_params(ex)
            if not ok:
                show_popup(msg + " Please try again.")
                continue

            # tax/year are present and in range at this point; build the single anchor.
            try:
                anchors = [(int(float(ex["year"])), float(ex["tax"]))]
            except (TypeError, ValueError, OverflowError):
                show_popup("Could not read 'tax' or 'year' as numbers. Please try again.")
                continue

            overrides = {k: (float(ex[k]) if ex.get(k) is not None else None)
                         for k in ["prstp","a2base","elasmu","gA1","gsigma1","pback"]}

            # Run DICE
            run_dice_with_anchors_and_overrides(anchors, overrides=overrides, enforce=enforce_flag)
            break

    except KeyboardInterrupt:
        print("\nSession cancelled by user.")
    except EOFError:
        # stdin closed / non-interactive run: nothing more can be read.
        print("\nNo input available (stdin closed); session ended.")
    except Exception as err:
        # Friendly message instead of a raw traceback, but keep the cause visible.
        show_popup(
            "Something went wrong while processing the request. Please check your input and try again."
            f"\n(Details: {type(err).__name__}: {err})"
        )


Using device: cuda


Map (num_proc=2):   0%|          | 0/11195 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/4798 [00:00<?, ? examples/s]

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



[Training] samples=11195 epochs=3
{'loss': 2.1886, 'grad_norm': 2.615128517150879, 'learning_rate': 4.714285714285714e-05, 'epoch': 0.2857142857142857}
{'loss': 0.3151, 'grad_norm': 3.5073740482330322, 'learning_rate': 4.5026455026455024e-05, 'epoch': 0.5714285714285714}
{'loss': 0.3107, 'grad_norm': 4.202085018157959, 'learning_rate': 3.973544973544974e-05, 'epoch': 0.8571428571428571}
{'loss': 0.2716, 'grad_norm': 5.259110450744629, 'learning_rate': 3.444444444444445e-05, 'epoch': 1.1428571428571428}
{'loss': 0.1847, 'grad_norm': 0.7418156266212463, 'learning_rate': 2.9153439153439156e-05, 'epoch': 1.4285714285714286}
{'loss': 0.1744, 'grad_norm': 3.27234148979187, 'learning_rate': 2.3862433862433864e-05, 'epoch': 1.7142857142857144}
{'loss': 0.15, 'grad_norm': 0.8587740659713745, 'learning_rate': 1.8571428571428572e-05, 'epoch': 2.0}
{'loss': 0.1256, 'grad_norm': 2.8313965797424316, 'learning_rate': 1.3280423280423282e-05, 'epoch': 2.2857142857142856}
{'loss': 0.1114, 'grad_norm': 

Device set to use cuda:0


{'train_runtime': 30.1627, 'train_samples_per_second': 1113.461, 'train_steps_per_second': 34.811, 'train_loss': 0.38099321002051945, 'epoch': 3.0}

[READY] QA fine-tuning complete.

- tax: Carbon price level in $/tCO2 for the specified target year.
  Valid range: [5, 1000]
- year: Target year for the carbon price (calendar year).
  Valid range: [2020, 2100]
- prstp: Pure rate of social time preference (annual).
  Valid range: [0.0, 0.03]
- a2base: Quadratic damage coefficient in the damage function.
  Valid range: [0.0005, 0.008]
- elasmu: Elasticity of marginal utility of consumption.
  Valid range: [0.5, 2.5]
- gA1: Initial growth rate of TFP (per 5 years).
  Valid range: [0.04, 0.12]
- gsigma1: Initial growth of decarbonization (per year, negative).
  Valid range: [-0.03, -0.003]
- pback: Backstop technology cost (reference level, 2019$ per tCO2).
  Valid range: [150, 1500]

Modes:
  • TRUE mode (ENF): add 'enforce=true' (recommended for policy-anchored runs).
  • BASE mode: omit '

  self.H.update(self.x - self.x_prev, self.g - self.g_prev)


⚠ DICE may not converge.
✓ Exported: DICE_Model_Combined_Results_Anchors_2020-6_BASE.xlsx
Check 2020: target=$6.0/t, actual=$5.8/t
