In [1]:
!pip install QuantLib

Collecting QuantLib
  Downloading quantlib-1.38-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.1 kB)
Downloading quantlib-1.38-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (20.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.0/20.0 MB[0m [31m33.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: QuantLib
Successfully installed QuantLib-1.38


In [2]:
# product_definitions.py
"""
Contains classes for defining the static properties of various financial products.
Handles date string parsing in from_dict methods and __init__.
ConvertibleBondStaticBase no longer holds dynamic market parameters like vol, div_yield, etc.
These will be passed to the pricer.
Ensures index_stub has a default if not provided.
"""
import QuantLib as ql
from datetime import date, datetime
import abc
import numpy as np

def _parse_date_input(date_input):
    """Helper to parse date input which could be date object or ISO string."""
    if isinstance(date_input, datetime):
        return date_input.date()
    if isinstance(date_input, date):
        return date_input
    if isinstance(date_input, str):
        try:
            return date.fromisoformat(date_input)
        except ValueError:
            try:
                # Handle potential datetime strings from JSON, including those with 'Z'
                dt_obj = datetime.fromisoformat(date_input.replace('Z', '+00:00').replace('z', '+00:00'))
                return dt_obj.date()
            except ValueError:
                raise ValueError(f"Invalid date string format: '{date_input}'. Expected YYYY-MM-DD or ISO datetime.")
    if date_input is None: # Allow None to pass through if a date is optional
        return None
    raise TypeError(f"Unsupported date input type: {type(date_input)}. Value: {date_input}")

def _serialize_date_list(date_list):
    if date_list is None: return None
    # Ensure elements are parsed before attempting isoformat, and handle None elements
    return [_parse_date_input(d).isoformat() if _parse_date_input(d) else None for d in date_list]


def _parse_date_list(date_input_list):
    if date_input_list is None: return [] # Default to empty list if None
    return [_parse_date_input(d_str) if d_str else None for d_str in date_input_list]


class ProductStaticBase(abc.ABC):
    def __init__(self, valuation_date):
        self.valuation_date_py: date = _parse_date_input(valuation_date)

    @classmethod
    @abc.abstractmethod
    def from_dict(cls, params: dict) -> 'ProductStaticBase':
        pass

    @abc.abstractmethod
    def to_dict(self) -> dict:
        pass


class QuantLibBondStaticBase(ProductStaticBase):
    def __init__(
        self,
        valuation_date, maturity_date, coupon_rate: float,
        face_value: float = 100.0, freq: int = 2,
        calendar = None, day_count = None,
        business_convention: int = None, settlement_days: int = 0,
        currency: str = "USD", index_stub: str = "GENERIC_IR"
    ):
        super().__init__(valuation_date)
        self.maturity_date_py: date = _parse_date_input(maturity_date)
        self.coupon_rate: float = float(coupon_rate)
        self.face_value: float = float(face_value)
        self.freq: int = int(freq)
        self.settlement_days: int = int(settlement_days)
        self.currency: str = currency
        self.index_stub: str = index_stub if index_stub and index_stub.strip() else "GENERIC_IR"

        # Calendar Handling
        if isinstance(calendar, str):
            if calendar.lower() == "target": self.calendar_ql = ql.TARGET()
            elif calendar.lower() == "us_federalreserve": self.calendar_ql = ql.UnitedStates(ql.UnitedStates.FederalReserve)
            # Add more string-to-QL calendar mappings as needed
            else: self.calendar_ql = ql.NullCalendar() # Fallback or raise error
        elif isinstance(calendar, ql.Calendar): self.calendar_ql = calendar
        else: # Default based on currency if not provided or not recognized string
            if self.currency.upper() == "USD": self.calendar_ql = ql.UnitedStates(ql.UnitedStates.FederalReserve)
            elif self.currency.upper() == "EUR": self.calendar_ql = ql.TARGET()
            else: self.calendar_ql = ql.TARGET() # General fallback

        # Day Count Handling
        if isinstance(day_count, str):
            if day_count.lower() == "actualactualisda": self.day_count_ql = ql.ActualActual(ql.ActualActual.ISDA)
            elif day_count.lower() == "actual360": self.day_count_ql = ql.Actual360()
            elif day_count.lower() == "thirty360": self.day_count_ql = ql.Thirty360(ql.Thirty360.USA)
            # Add more string-to-QL day counter mappings
            else: self.day_count_ql = ql.Actual365Fixed() # Fallback
        elif isinstance(day_count, ql.DayCounter): self.day_count_ql = day_count
        else: self.day_count_ql = ql.ActualActual(ql.ActualActual.ISDA) # Default

        self.business_convention_ql: int = business_convention if business_convention is not None else ql.Following

        self.ql_valuation_date: ql.Date = ql.Date(self.valuation_date_py.day, self.valuation_date_py.month, self.valuation_date_py.year)
        # Setting global eval date should be done carefully, ideally once per run or context by the calling script.
        # ql.Settings.instance().evaluationDate = self.ql_valuation_date
        self.ql_maturity_date: ql.Date = ql.Date(self.maturity_date_py.day, self.maturity_date_py.month, self.maturity_date_py.year)
        self.issue_date_ql: ql.Date = self.ql_valuation_date # Default issue date, can be overridden by subclasses

        months_in_period = int(12 / self.freq)
        self.schedule: ql.Schedule = ql.Schedule(
            self.issue_date_ql, self.ql_maturity_date,
            ql.Period(months_in_period, ql.Months), self.calendar_ql,
            self.business_convention_ql, self.business_convention_ql, # terminationDateBusinessConvention
            ql.DateGeneration.Forward, False) # endOfMonth

        # This basic FixedRateBond is a placeholder; subclasses like Callable/Convertible will define their specific QL bond object.
        # For a pure vanilla bond, this is appropriate.
        self.bond: ql.Bond = ql.FixedRateBond(
            self.settlement_days, self.face_value, self.schedule,
            [self.coupon_rate], self.day_count_ql,
            self.business_convention_ql, self.face_value)

    @classmethod
    def from_dict(cls, params: dict) -> 'QuantLibBondStaticBase':
        # __init__ will handle parsing of valuation_date and maturity_date
        return cls(
            valuation_date=params['valuation_date'],
            maturity_date=params['maturity_date'],
            coupon_rate=float(params['coupon_rate']),
            face_value=float(params.get('face_value', 100.0)),
            freq=int(params.get('freq', 2)),
            calendar=params.get('calendar'),
            day_count=params.get('day_count'),
            business_convention=params.get('business_convention'),
            settlement_days=int(params.get('settlement_days', 0)),
            currency=params.get('currency', "USD"),
            index_stub=params.get('index_stub', "GENERIC_IR")
        )

    def to_dict(self) -> dict:
        # Determine product_type more accurately if possible
        product_type = 'VanillaBond'
        if isinstance(self, CallableBondStaticBase): product_type = 'CallableBond'
        elif isinstance(self, ConvertibleBondStaticBase): product_type = 'ConvertibleBond'

        data = {
            'product_type': product_type,
            'valuation_date': self.valuation_date_py.isoformat(),
            'maturity_date': self.maturity_date_py.isoformat(),
            'coupon_rate': self.coupon_rate, 'face_value': self.face_value,
            'freq': self.freq, 'settlement_days': self.settlement_days,
            'currency': self.currency, 'index_stub': self.index_stub,
            'calendar': self.calendar_ql.name() if self.calendar_ql else None,
            'day_count': self.day_count_ql.name() if self.day_count_ql else None,
            # 'business_convention': self.business_convention_ql # int, might need mapping
        }
        return data


class CallableBondStaticBase(QuantLibBondStaticBase):
    def __init__(
        self, valuation_date, maturity_date, coupon_rate: float,
        call_dates: list, call_prices: list[float], face_value: float = 100.0,
        freq: int = 2, calendar = None, day_count = None,
        business_convention: int = None, settlement_days: int = 0,
        currency: str = "USD", index_stub: str = "GENERIC_IR" ):
        super().__init__(valuation_date, maturity_date, coupon_rate, face_value, freq,
                         calendar, day_count, business_convention, settlement_days, currency, index_stub)
        self.call_dates_py: list[date] = _parse_date_list(call_dates)
        self.call_prices_py: list[float] = [float(p) for p in call_prices]
        self.call_schedule: ql.CallabilitySchedule = ql.CallabilitySchedule()
        if self.call_dates_py:
            for cd_py, cp in zip(self.call_dates_py, self.call_prices_py):
                if cd_py is None: continue
                ql_cd = ql.Date(cd_py.day, cd_py.month, cd_py.year)
                call = ql.Callability(ql.BondPrice(cp, ql.BondPrice.Clean), ql.Callability.Call, ql_cd)
                self.call_schedule.push_back(call)
        self.bond: ql.CallableFixedRateBond = ql.CallableFixedRateBond(
            self.settlement_days, self.face_value, self.schedule, [self.coupon_rate],
            self.day_count_ql, self.business_convention_ql, self.face_value,
            self.issue_date_ql, self.call_schedule)

    @classmethod
    def from_dict(cls, params: dict) -> 'CallableBondStaticBase':
        return cls(
            valuation_date=params['valuation_date'], maturity_date=params['maturity_date'],
            coupon_rate=float(params['coupon_rate']),
            call_dates=params.get('call_dates', []), call_prices=params.get('call_prices',[]),
            face_value=float(params.get('face_value', 100.0)), freq=int(params.get('freq', 2)),
            calendar=params.get('calendar'), day_count=params.get('day_count'),
            business_convention=params.get('business_convention'),
            settlement_days=int(params.get('settlement_days', 0)),
            currency=params.get('currency', "USD"),
            index_stub=params.get('index_stub', "GENERIC_IR"))

    def to_dict(self) -> dict:
      base = super().to_dict(); base.update({
          'product_type': 'CallableBond', # Override product_type
          'call_dates': _serialize_date_list(self.call_dates_py), 'call_prices': self.call_prices_py
      }); return base


class ConvertibleBondStaticBase(QuantLibBondStaticBase):
    def __init__(
        self, valuation_date, issue_date, maturity_date, coupon_rate: float,
        conversion_ratio: float, face_value: float = 100.0, freq: int = 2,
        settlement_days: int = 0, calendar = None, day_count = None,
        business_convention: int = None, exercise_type: str = 'EuropeanAtMaturity',
        currency: str = "USD", index_stub: str = "GENERIC_IR", underlying_symbol: str = None ):
        super().__init__(valuation_date, maturity_date, coupon_rate, face_value, freq, calendar,
                         day_count, business_convention, settlement_days, currency, index_stub)
        self.issue_date_py: date = _parse_date_input(issue_date)
        self.issue_date_ql = ql.Date(self.issue_date_py.day, self.issue_date_py.month, self.issue_date_py.year)

        months_in_period = int(12 / self.freq)
        self.schedule = ql.Schedule(
            self.issue_date_ql, self.ql_maturity_date, ql.Period(months_in_period, ql.Months),
            self.calendar_ql, self.business_convention_ql, self.business_convention_ql,
            ql.DateGeneration.Forward, False)

        self.conversion_ratio: float = float(conversion_ratio)
        self.exercise_type_str: str = exercise_type
        self.underlying_symbol: str = underlying_symbol

        if self.exercise_type_str == 'EuropeanAtMaturity':
            self.exercise: ql.Exercise = ql.EuropeanExercise(self.ql_maturity_date)
        # Add AmericanExercise or BermudanExercise handling if needed
        # elif self.exercise_type_str == 'American':
        #     self.exercise = ql.AmericanExercise(self.issue_date_ql, self.ql_maturity_date) # Example
        else: raise ValueError(f"Unsupported exercise type: {self.exercise_type_str}")

        self.convertible_call_schedule: ql.CallabilitySchedule = ql.CallabilitySchedule() # Empty for non-callable CB

        self.bond: ql.ConvertibleFixedCouponBond = ql.ConvertibleFixedCouponBond(
            self.exercise, self.conversion_ratio, self.convertible_call_schedule,
            self.issue_date_ql, self.settlement_days, [self.coupon_rate],
            self.day_count_ql, self.schedule, self.face_value)

    @classmethod
    def from_dict(cls, params: dict) -> 'ConvertibleBondStaticBase':
        return cls(
            valuation_date=params['valuation_date'], issue_date=params['issue_date'],
            maturity_date=params['maturity_date'], coupon_rate=float(params['coupon_rate']),
            conversion_ratio=float(params['conversion_ratio']),
            face_value=float(params.get('face_value', 100.0)), freq=int(params.get('freq', 2)),
            settlement_days=int(params.get('settlement_days', 0)),
            calendar=params.get('calendar'), day_count=params.get('day_count'),
            business_convention=params.get('business_convention'),
            exercise_type=params.get('exercise_type', 'EuropeanAtMaturity'),
            currency=params.get('currency', "USD"),
            index_stub=params.get('index_stub', "GENERIC_IR"),
            underlying_symbol=params.get('underlying_symbol'))

    def to_dict(self) -> dict:
        base = super().to_dict(); base.update({
            'product_type': 'ConvertibleBond', 'issue_date': self.issue_date_py.isoformat(),
            'conversion_ratio': self.conversion_ratio, 'exercise_type': self.exercise_type_str,
            'underlying_symbol': self.underlying_symbol
        }); return base


class EuropeanOptionStatic(ProductStaticBase):
    def __init__(self, valuation_date, expiry_date, strike_price: float, option_type: str,
                 day_count_convention = None, currency: str = "USD", underlying_symbol: str = None):
        super().__init__(valuation_date)
        self.expiry_date_py: date = _parse_date_input(expiry_date)
        self.strike_price: float = float(strike_price)
        self.currency: str = currency
        self.underlying_symbol: str = underlying_symbol

        if option_type.lower() not in ['call', 'put']: raise ValueError("Option type must be 'call' or 'put'")
        self.option_type: str = option_type.lower()

        ql_valuation_date = ql.Date(self.valuation_date_py.day, self.valuation_date_py.month, self.valuation_date_py.year)
        # ql.Settings.instance().evaluationDate = ql_valuation_date # Set globally once
        ql_expiry_date = ql.Date(self.expiry_date_py.day, self.expiry_date_py.month, self.expiry_date_py.year)

        if isinstance(day_count_convention, str):
            if day_count_convention.lower() == "actual365fixed": self.day_count_convention_ql = ql.Actual365Fixed()
            elif day_count_convention.lower() == "actual360": self.day_count_convention_ql = ql.Actual360()
            else: self.day_count_convention_ql = ql.Actual365Fixed()
        elif isinstance(day_count_convention, ql.DayCounter): self.day_count_convention_ql = day_count_convention
        else: self.day_count_convention_ql = ql.Actual365Fixed()

        self.time_to_expiry: float = self.day_count_convention_ql.yearFraction(ql_valuation_date, ql_expiry_date)
        if self.time_to_expiry < 0: self.time_to_expiry = 0.0

    @classmethod
    def from_dict(cls, params: dict) -> 'EuropeanOptionStatic':
        return cls(
            valuation_date=params['valuation_date'], expiry_date=params['expiry_date'],
            strike_price=float(params['strike_price']), option_type=params['option_type'],
            day_count_convention=params.get('day_count_convention'),
            currency=params.get('currency', "USD"),
            underlying_symbol=params.get('underlying_symbol'))

    def to_dict(self) -> dict:
        return {
            'product_type': 'EuropeanOption',
            'valuation_date': self.valuation_date_py.isoformat(),
            'expiry_date': self.expiry_date_py.isoformat(),
            'strike_price': self.strike_price, 'option_type': self.option_type,
            'day_count_convention': self.day_count_convention_ql.name() if self.day_count_convention_ql else None,
            'currency': self.currency, 'underlying_symbol': self.underlying_symbol
        }


In [3]:
# quantlib_custom_serializer.py
import json

def custom_json_serializer(obj):
    """
    Custom serializer for json.dumps that handles ProductStaticBase instances.
    """
    # Check if the object has a to_dict method
    if hasattr(obj, 'to_dict') and callable(obj.to_dict):
        return obj.to_dict()
    # Convert datetime.date objects to ISO format string
    elif isinstance(obj, date):
        return obj.isoformat()
    # Add checks for other non-serializable types if necessary
    # For now, rely on default JSONEncoder for built-in types
    # If other non-serializable objects (like QuantLib objects) are *not*
    # handled by the to_dict methods, they will still cause errors.
    # Ensure your to_dict methods convert all custom/QL objects to serializable types.
    # For QuantLib objects, you might need to extract simple values (like strings, floats, ints).
    # For example, a ql.Date can be converted to a string or a Python date.
    # ql.Calendar, ql.DayCounter etc. are usually not serializable directly;
    # represent them by name (string) or relevant properties in to_dict.
    elif isinstance(obj, ql.Date):
         return date(obj.year(), obj.month(), obj.day()).isoformat()
    elif isinstance(obj, ql.Calendar):
        return obj.name() # Represent calendar by its name string
    elif isinstance(obj, ql.DayCounter):
        return obj.name() # Represent day counter by its name string
    elif isinstance(obj, ql.Exercise):
        # This depends on the Exercise type. For EuropeanExercise, you might serialize the expiry date.
        if isinstance(obj, ql.EuropeanExercise):
            return {'type': 'EuropeanExercise', 'expiry_date': date(obj.date().year(), obj.date().month(), obj.date().day()).isoformat()}
        else:
             # Handle other QuantLib types or raise a TypeError if not handled
            raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable without custom handling.")
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        # Let the default encoder handle other types
        return json.JSONEncoder.default(None, obj)

In [4]:
# pricers.py
"""
Contains pricer classes for different financial products.
Pricers take a static product definition and market data to calculate a price.
Convertible bond pricer now correctly handles dynamic vs. fixed market parameters.
Ensures is_convertible and is_callable attributes are correctly set.
"""
import QuantLib as ql
import numpy as np
from scipy.stats import norm
import abc

class PricerBase(abc.ABC):
    def __init__(self, product_static: ProductStaticBase):
        self.product_static: ProductStaticBase = product_static
    @abc.abstractmethod
    def price(self, **kwargs) -> np.ndarray: pass

class FastBondPricer(PricerBase):
    def __init__(self, bond_static: QuantLibBondStaticBase):
        super().__init__(bond_static)
        self._gen_cashflows()
    def _gen_cashflows(self):
        bond_def: QuantLibBondStaticBase = self.product_static
        ql_sched, dc, ql_val = bond_def.schedule, bond_def.day_count_ql, bond_def.ql_valuation_date
        coupon_amt = bond_def.face_value * bond_def.coupon_rate / bond_def.freq
        cf_d, cf_t, cf_a = [],[],[]
        for i in range(len(ql_sched)):
            d = ql_sched[i]
            if d <= ql_val: continue
            cf_d.append(d.to_date()); cf_t.append(dc.yearFraction(ql_val, d))
            curr_cf = coupon_amt
            if d == bond_def.ql_maturity_date or d == ql_sched[-1]: curr_cf += bond_def.face_value
            cf_a.append(curr_cf)
        self.cf_dates, self.cf_times, self.cf_amounts = cf_d, np.array(cf_t,float), np.array(cf_a,float)
        valid = self.cf_times > 1e-9
        self.cf_times, self.cf_amounts = self.cf_times[valid], self.cf_amounts[valid]
        self.cf_dates = [dt for i, dt in enumerate(self.cf_dates) if valid[i]]
    def price(self, pillar_times: np.ndarray, market_scenario_data: np.ndarray, **kwargs) -> np.ndarray:
        rates = market_scenario_data
        num_p = len(pillar_times)
        if rates.ndim == 2 and rates.shape[1] > num_p : rates = rates[:, :num_p]
        elif rates.ndim == 1 and len(rates) > num_p: rates = rates[:num_p]

        if not self.cf_times.size: return np.zeros(rates.shape[0]) if rates.ndim == 2 else np.array([0.0])
        if rates.ndim == 1:
            r = np.interp(self.cf_times, pillar_times, rates); dfs = np.exp(-r*self.cf_times)
            return np.array([float(self.cf_amounts.dot(dfs))])
        r_mat = np.array([np.interp(self.cf_times, pillar_times, sr) for sr in rates])
        dfs_mat = np.exp(-r_mat * self.cf_times[None, :]); return dfs_mat.dot(self.cf_amounts)

class QuantLibBondPricer(PricerBase):
    def __init__(self, bond_static: QuantLibBondStaticBase, method: str = 'discount',
                 grid_steps: int = 100, convertible_engine_steps: int = 100):
        if not isinstance(bond_static, QuantLibBondStaticBase):
            raise TypeError("Requires QuantLibBondStaticBase derivative.")
        super().__init__(bond_static)
        self.method, self.grid_steps, self.convertible_engine_steps = method.lower(), grid_steps, convertible_engine_steps
        self.is_callable = isinstance(bond_static, CallableBondStaticBase)
        self.is_convertible = isinstance(bond_static, ConvertibleBondStaticBase)


    def _make_term_structure(self, pillar_times: np.ndarray, rates_vec: np.ndarray) -> ql.YieldTermStructureHandle:
        p_times_np = np.asarray(pillar_times, float)
        base_d: ql.Date = self.product_static.ql_valuation_date
        dates, eff_rates = ql.DateVector(), list(rates_vec)
        if not p_times_np.size or (p_times_np.size > 0 and p_times_np[0] > 1e-6):
             dates.push_back(base_d); eff_rates.insert(0, rates_vec[0] if rates_vec.size > 0 else 0.0)
        for t_val in p_times_np: dates.push_back(base_d + ql.Period(int(round(t_val*365.0)), ql.Days))
        if len(eff_rates) != len(dates):
             if len(dates)==1 and not eff_rates: eff_rates.append(0.0)

        cal = self.product_static.calendar_ql if hasattr(self.product_static, 'calendar_ql') else ql.TARGET()
        dc = self.product_static.day_count_ql if hasattr(self.product_static, 'day_count_ql') else ql.ActualActual(ql.ActualActual.ISDA)

        zc = ql.ZeroCurve(dates, eff_rates, dc, cal, ql.Linear(), ql.Continuous, ql.Annual)
        zc.enableExtrapolation(); return ql.YieldTermStructureHandle(zc)

    @staticmethod
    def _price_vanilla_static(bond: ql.Bond, ts_handle: ql.YieldTermStructureHandle) -> float:
        eng = ql.DiscountingBondEngine(ts_handle); bond.setPricingEngine(eng); return bond.NPV()
    @staticmethod
    def _price_callable_static(bond: ql.CallableFixedRateBond, ts_handle: ql.YieldTermStructureHandle,
                               params: tuple, steps: int) -> float:
        if params is None: raise ValueError("G2 model parameters must be provided for callable bond pricing.")
        a,sig,b,eta,rho=params; model=ql.G2(ts_handle,a,sig,b,eta,rho)
        eng=ql.TreeCallableFixedRateBondEngine(model,steps); bond.setPricingEngine(eng); return bond.cleanPrice()
    @staticmethod
    def _price_convertible_static(
        bond: ql.ConvertibleFixedCouponBond, ts_handle: ql.YieldTermStructureHandle,
        static_def: ConvertibleBondStaticBase, eng_steps: int,
        s0: float, div_yield: float, eq_vol: float, credit_spread: float ) -> float:
        eval_d = ql.Settings.instance().evaluationDate
        s0_h = ql.QuoteHandle(ql.SimpleQuote(s0))
        day_count = static_def.day_count_ql
        calendar = static_def.calendar_ql
        div_h = ql.YieldTermStructureHandle(ql.FlatForward(eval_d, div_yield, day_count))
        vol_h = ql.BlackVolTermStructureHandle(ql.BlackConstantVol(eval_d, calendar, eq_vol, day_count))
        cs_h = ql.QuoteHandle(ql.SimpleQuote(credit_spread))
        proc = ql.BlackScholesMertonProcess(s0_h, div_h, ts_handle, vol_h)
        eng = ql.BinomialCRRConvertibleEngine(proc, eng_steps, cs_h); bond.setPricingEngine(eng)
        return bond.NPV()

    def price(self, pillar_times: np.ndarray, market_scenario_data: np.ndarray, **kwargs) -> np.ndarray:
        if market_scenario_data.ndim == 1: market_scenario_data = market_scenario_data.reshape(1, -1)
        prices = []
        is_g2_single = self.method=='g2' and (isinstance(kwargs.get('g2_params'), tuple) or (isinstance(kwargs.get('g2_params'), list) and kwargs.get('g2_params') and isinstance(kwargs.get('g2_params')[0], (float,int))))
        if self.method=='g2' and kwargs.get('g2_params') is not None and not is_g2_single and len(kwargs.get('g2_params')) != market_scenario_data.shape[0]:
            raise ValueError("List of g2_params must match number of scenarios for G2 method.")

        for i, scen_data_row in enumerate(market_scenario_data):
            current_g2_p = None
            if self.method == 'g2':
                 current_g2_p = kwargs.get('g2_params') if is_g2_single else (kwargs.get('g2_params')[i] if kwargs.get('g2_params') else None)
            prices.append(self._price_single_curve_logic(pillar_times, scen_data_row, current_g2_p, **kwargs))
        return np.array(prices)

    def _price_single_curve_logic(self,
                                  pillar_times_np: np.ndarray,
                                  market_data_for_scenario_row: np.ndarray,
                                  g2_p_for_this_scen=None, **other_fixed_params) -> float:
        if self.product_static and hasattr(self.product_static, 'ql_valuation_date'):
             ql.Settings.instance().evaluationDate = self.product_static.ql_valuation_date
        else:
             today = ql.Date().todaysDate()
             print(f"Warning: product_static.ql_valuation_date not found, using QL today's date: {today}")
             ql.Settings.instance().evaluationDate = today


        num_rate_pillars = len(pillar_times_np)
        rates_scen = market_data_for_scenario_row[:num_rate_pillars]
        ts_handle = self._make_term_structure(pillar_times_np, rates_scen)

        if self.is_convertible and self.method == 'convertible_binomial':
            if not isinstance(self.product_static, ConvertibleBondStaticBase):
                raise TypeError("Product is not ConvertibleBondStaticBase for CB pricing.")

            num_dynamic_non_rate_factors = len(market_data_for_scenario_row) - num_rate_pillars

            s0_val = other_fixed_params.get('s0_val')
            div_val = other_fixed_params.get('dividend_yield')
            eq_vol_val = other_fixed_params.get('equity_volatility')
            cs_val = other_fixed_params.get('credit_spread')

            current_idx_in_scen_data = num_rate_pillars

            if num_dynamic_non_rate_factors >= 1:
                s0_val = market_data_for_scenario_row[current_idx_in_scen_data]
                current_idx_in_scen_data +=1
            if num_dynamic_non_rate_factors >= 2:
                div_val = market_data_for_scenario_row[current_idx_in_scen_data]
                current_idx_in_scen_data +=1
            if num_dynamic_non_rate_factors >= 3:
                eq_vol_val = market_data_for_scenario_row[current_idx_in_scen_data]
                current_idx_in_scen_data +=1
            if num_dynamic_non_rate_factors >= 4:
                cs_val = market_data_for_scenario_row[current_idx_in_scen_data]

            if s0_val is None: raise ValueError("S0 value missing for convertible bond pricing.")
            if div_val is None: raise ValueError("Dividend yield missing for convertible bond pricing.")
            if eq_vol_val is None: raise ValueError("Equity volatility missing for convertible bond pricing.")
            if cs_val is None: raise ValueError("Credit spread missing for convertible bond pricing.")

            return self._price_convertible_static(
                self.product_static.bond, ts_handle, self.product_static,
                self.convertible_engine_steps, s0_val, div_val, eq_vol_val, cs_val)

        elif self.is_callable and self.method == 'g2':
            if g2_p_for_this_scen is None: raise ValueError("g2_params needed for G2 pricing.")
            if not isinstance(self.product_static.bond, ql.CallableFixedRateBond):
                 raise TypeError(f"Expected ql.CallableFixedRateBond, got {type(self.product_static.bond)}")
            return self._price_callable_static(self.product_static.bond, ts_handle, g2_p_for_this_scen, self.grid_steps)
        elif self.method == 'discount':
            if not isinstance(self.product_static.bond, ql.Bond):
                 raise TypeError(f"Expected ql.Bond, got {type(self.product_static.bond)}")
            return self._price_vanilla_static(self.product_static.bond, ts_handle)
        raise ValueError(f"Unsupported method '{self.method}' or product for QL BondPricer.")


class BlackScholesPricer(PricerBase):
    def __init__(self, option_static: EuropeanOptionStatic,
                 risk_free_rate: float, dividend_yield: float = 0.0):
        if not isinstance(option_static, EuropeanOptionStatic):
            raise TypeError("Requires EuropeanOptionStatic.")
        super().__init__(option_static)
        self.risk_free_rate, self.dividend_yield = risk_free_rate, dividend_yield
    def price(self, stock_price: np.ndarray, volatility: np.ndarray, **kwargs) -> np.ndarray:
        S, sig = np.asarray(stock_price), np.asarray(volatility)
        opt: EuropeanOptionStatic = self.product_static
        K, T, r, q,otype = opt.strike_price,opt.time_to_expiry,self.risk_free_rate,self.dividend_yield,opt.option_type
        is_scalar = (S.ndim==0 and sig.ndim==0)
        if S.ndim==0 and sig.ndim>0: S=np.full_like(sig,S)
        if sig.ndim==0 and S.ndim>0: sig=np.full_like(S,sig)
        price = np.zeros_like(S,float) if not is_scalar else 0.0
        if T<=1e-9: price = np.maximum(S-K if otype=='call' else K-S,0.0); return price
        sig_calc = np.maximum(sig,1e-16)
        with np.errstate(divide='ignore',invalid='ignore'):
            d1=(np.log(S/K)+(r-q+0.5*sig_calc**2)*T)/(sig_calc*np.sqrt(T))
            d2=d1-sig_calc*np.sqrt(T)
            d1=np.where(S<=0,-np.inf if otype=='call' else np.inf,d1)
            d2=np.where(S<=0,-np.inf if otype=='call' else np.inf,d2)
            d1=np.nan_to_num(d1,nan=-np.inf if otype=='call' else np.inf,posinf=np.inf,neginf=-np.inf)
            d2=np.nan_to_num(d2,nan=-np.inf if otype=='call' else np.inf,posinf=np.inf,neginf=-np.inf)
        price_bs = (S*np.exp(-q*T)*norm.cdf(d1)-K*np.exp(-r*T)*norm.cdf(d2)) if otype=='call' \
              else (K*np.exp(-r*T)*norm.cdf(-d2)-S*np.exp(-q*T)*norm.cdf(-d1))
        price_bs = np.nan_to_num(price_bs,nan=0.0)
        if np.any(sig<=1e-9):
            intr = np.maximum(S*np.exp(-q*T)-K*np.exp(-r*T) if otype=='call' else K*np.exp(-r*T)-S*np.exp(-q*T),0.0)
            if is_scalar and (sig<=1e-9): price=intr
            elif not is_scalar: price=np.where(sig<=1e-9,intr,price_bs)
            else: price=np.where(sig<=1e-9,intr,price_bs)
        else: price=price_bs
        return np.maximum(price,0.0)


In [5]:

# g2_model.py
"""
Contains the G2Calibrator class for calibrating the G2++ interest rate model.
"""
import QuantLib as ql

class G2Calibrator:
    """
    Calibrates a G2++ (two-factor Hull-White) interest rate model
    to market cap/floor quotes.

    The G2++ model has 5 parameters: a, sigma, b, eta, rho.
    - a: mean reversion speed of the first factor (x)
    - sigma: volatility of the first factor (x)
    - b: mean reversion speed of the second factor (y)
    - eta: volatility of the second factor (y)
    - rho: correlation between the Wiener processes driving x and y

    Args:
        ts_handle (ql.YieldTermStructureHandle): Handle to the initial yield term structure
                                                 to which the model will be fitted. This term
                                                 structure defines the initial discount curve.
        index (ql.IborIndex or ql.OvernightIndex): The interest rate index underlying the
                                                   calibration instruments (e.g., caps/floors).
                                                   This index is used by the CapHelper to
                                                   determine caplet cashflows and fixings.
    """
    def __init__(self, ts_handle: ql.YieldTermStructureHandle, index: ql.InterestRateIndex):
        self.ts_handle: ql.YieldTermStructureHandle = ts_handle
        self.index: ql.InterestRateIndex = index
        # Initialize a G2 model. The parameters (a, sigma, b, eta, rho) will be calibrated.
        # The model is initialized with the provided term structure.
        self.model: ql.G2 = ql.G2(self.ts_handle)

    def calibrate(
        self,
        periods: list[ql.Period],
        quotes: list[ql.QuoteHandle],
        optimization_method: ql.OptimizationMethod = None,
        end_criteria: ql.EndCriteria = None,
        engine_steps: int = 50
    ) -> tuple[float, float, float, float, float]:
        """
        Calibrates the G2 model parameters to market cap volatilities.

        Args:
            periods (list[ql.Period]): List of ql.Period objects for cap tenors
                                       (e.g., [ql.Period('1Y'), ql.Period('2Y'), ...]).
                                       There should be at least 5 instruments for 5 parameters.
            quotes (list[ql.QuoteHandle]): List of ql.QuoteHandle objects containing the market
                                           volatilities (or prices) for the corresponding periods.
                                           Assumes Normal volatility for caps.
            optimization_method (ql.OptimizationMethod, optional):
                                QuantLib optimization method (e.g., ql.LevenbergMarquardt).
                                Defaults to a standard LevenbergMarquardt.
            end_criteria (ql.EndCriteria, optional):
                                QuantLib end criteria for the optimization.
                                Defaults to a standard EndCriteria.
            engine_steps (int, optional): Number of time steps for the TreeCapFloorEngine
                                          used to price caps during calibration. Affects accuracy
                                          and speed of calibration. Defaults to 50.

        Returns:
            tuple[float, float, float, float, float]:
                A tuple containing the calibrated G2 model parameters: (a, sigma, b, eta, rho).

        Raises:
            ValueError: If the number of periods and quotes do not match.
            RuntimeError: If calibration fails in QuantLib (e.g., due to insufficient instruments
                          or optimization issues).
        """
        if len(periods) != len(quotes):
            raise ValueError("Length of periods and quotes must match for calibration.")
        if len(periods) < 5: # G2++ has 5 parameters
            # This is a common cause of calibration failure.
            print(f"Warning: Number of calibration instruments ({len(periods)}) is less than 5. "
                  "G2++ calibration might be unstable or fail due to under-specification.")

        helpers = []
        for period_obj, quote_handle in zip(periods, quotes):
            # Create a CapHelper for each market quote.
            # The CapHelper links the market quote (volatility) to the model price of the cap.
            # Parameters like cap frequency, day count should ideally match the market convention
            # of the provided quotes.
            helper = ql.CapHelper(
                period_obj,
                quote_handle,
                self.index,                     # IborIndex underlying the cap
                ql.Semiannual,                 # Cap/floor coupon frequency (e.g., Semiannual)
                self.index.dayCounter() if self.index.dayCounter() else ql.Actual360(), # Day count for caplets
                False,                         # Not used for caps (related to first fixing)
                self.ts_handle,                # Initial term structure for discounting
                ql.BlackCalibrationHelper.RelativePriceError, # Error type for calibration objective function
                ql.Normal,                     # Volatility type (e.g., Normal, ShiftedLognormal)
                0.0                            # Shift (if using ShiftedLognormal)
            )
            # The pricing engine for the helper uses the G2 model instance that is being calibrated.
            # engine_steps for the tree engine affects accuracy/speed of pricing each cap during calibration.
            engine = ql.TreeCapFloorEngine(self.model, engine_steps)
            helper.setPricingEngine(engine)
            helpers.append(helper)

        # Default optimization method and end criteria if not provided
        opt_method = optimization_method or ql.LevenbergMarquardt(1e-8, 1e-8, 1e-8)
        crit = end_criteria or ql.EndCriteria(
            maxIterations=10000,
            maxStationaryStateIterations=1000,
            rootEpsilon=1e-8,
            functionEpsilon=1e-8,
            gradientNormEpsilon=1e-8
        )

        # Perform the calibration. This modifies self.model in-place.
        self.model.calibrate(helpers, opt_method, crit)

        # Return the calibrated parameters from the model
        # params() returns [a, sigma, b, eta, rho]
        return self.model.params()


In [6]:
# scenario_generator.py
"""
Contains classes for generating market scenarios for different risk factors.
SimpleRandomScenarioGenerator can now generate scenarios for a targeted list of factors.
"""
import numpy as np
import abc

class ScenarioGeneratorBase(abc.ABC):
    """
    Abstract base class for scenario generators.
    """
    @abc.abstractmethod
    def generate_scenarios(self, num_scenarios: int, target_factor_names: list[str] = None) -> tuple[np.ndarray, list[str]]:
        """
        Generates market scenarios.

        Args:
            num_scenarios (int): The number of scenarios to generate.
            target_factor_names (list[str], optional): A specific list of factor names
                to generate scenarios for. If None, generates for all configured factors.

        Returns:
            tuple[np.ndarray, list[str]]:
                - A 2D NumPy array where rows are scenarios and columns are risk factors.
                - A list of strings representing the names of the risk factors (columns),
                  matching the order of target_factor_names if provided, otherwise self.factor_names_ordered.
        """
        pass

class SimpleRandomScenarioGenerator(ScenarioGeneratorBase):
    """
    Generates scenarios using simple random shocks around base values.
    Can generate for all configured factors or a targeted subset.
    """
    def __init__(self,
                 base_rates_map: dict[str, float] = None,
                 rate_factor_shock_std_dev_map: dict[str, float] = None,
                 base_s0_map: dict[str, float] = None,
                 s0_shock_config_map: dict[str, tuple[str, float]] = None,
                 base_vol_map: dict[str, float] = None,
                 vol_shock_config_map: dict[str, tuple[str, float]] = None,
                 default_rate_shock_std_dev: float = 0.0010,
                 default_s0_shock_config: tuple[str, float] = ('normal', 0.10),
                 default_vol_shock_config: tuple[str, float] = ('normal', 0.05),
                 random_seed: int = None):

        self.base_rates_map = base_rates_map if base_rates_map is not None else {}
        self.rate_factor_shock_std_dev_map = rate_factor_shock_std_dev_map if rate_factor_shock_std_dev_map is not None else {}
        self.base_s0_map = base_s0_map if base_s0_map is not None else {}
        self.s0_shock_config_map = s0_shock_config_map if s0_shock_config_map is not None else {}
        self.base_vol_map = base_vol_map if base_vol_map is not None else {}
        self.vol_shock_config_map = vol_shock_config_map if vol_shock_config_map is not None else {}

        self.default_rate_shock_std_dev = default_rate_shock_std_dev
        self.default_s0_shock_config = default_s0_shock_config
        self.default_vol_shock_config = default_vol_shock_config

        self.rng = np.random.default_rng(random_seed)

        # Stores all unique factor names this generator is configured for, in a defined order.
        self._configured_factor_names_ordered = sorted(list(
            set(self.base_rates_map.keys()) |
            set(self.base_s0_map.keys()) |
            set(self.base_vol_map.keys())
        ))
        if not self._configured_factor_names_ordered and (self.base_rates_map or self.base_s0_map or self.base_vol_map):
            raise ValueError("Could not determine ordered factor names. Ensure keys in maps are strings.")

    def generate_scenarios(self, num_scenarios: int, target_factor_names: list[str] = None) -> tuple[np.ndarray, list[str]]:
        """
        Generates scenarios. If target_factor_names is provided, only generates for those factors.
        Otherwise, generates for all configured factors.
        """
        if target_factor_names is None:
            factors_to_generate = self._configured_factor_names_ordered
        else:
            # Validate that all target_factor_names are known to this generator
            for name in target_factor_names:
                if not (name in self.base_rates_map or \
                        name in self.base_s0_map or \
                        name in self.base_vol_map):
                    raise ValueError(f"Target factor name '{name}' is not configured in this scenario generator. "
                                     f"Known factors: {self._configured_factor_names_ordered}")
            factors_to_generate = target_factor_names # Generate in the requested order

        if not factors_to_generate:
            return np.array([]).reshape(num_scenarios, 0), []

        all_scenario_columns = []

        for factor_name in factors_to_generate:
            factor_column = np.zeros(num_scenarios)

            if factor_name in self.base_rates_map:
                base_value = self.base_rates_map[factor_name]
                shock_std_dev = self.rate_factor_shock_std_dev_map.get(factor_name, self.default_rate_shock_std_dev)
                shocks = self.rng.normal(loc=0.0, scale=shock_std_dev, size=num_scenarios)
                factor_column = base_value + shocks

            elif factor_name in self.base_s0_map:
                base_value = self.base_s0_map[factor_name]
                shock_type, shock_param = self.s0_shock_config_map.get(factor_name, self.default_s0_shock_config)
                if shock_type.lower() == 'normal':
                    actual_std_dev = shock_param * base_value if shock_param <= 1.0 else shock_param
                    shocks = self.rng.normal(loc=0.0, scale=actual_std_dev, size=num_scenarios)
                    factor_column = base_value + shocks
                elif shock_type.lower() == 'uniform':
                    half_width = shock_param * base_value if shock_param <= 1.0 else shock_param
                    factor_column = self.rng.uniform(base_value - half_width, base_value + half_width, size=num_scenarios)
                else:
                    raise ValueError(f"Unsupported shock_type: {shock_type} for S0 factor {factor_name}")
                factor_column = np.maximum(factor_column, 1e-6)

            elif factor_name in self.base_vol_map:
                base_value = self.base_vol_map[factor_name]
                shock_type, shock_param = self.vol_shock_config_map.get(factor_name, self.default_vol_shock_config)
                if shock_type.lower() == 'normal':
                    actual_std_dev = shock_param * base_value if shock_param <= 1.0 else shock_param
                    shocks = self.rng.normal(loc=0.0, scale=actual_std_dev, size=num_scenarios)
                    factor_column = base_value + shocks
                elif shock_type.lower() == 'uniform':
                    half_width = shock_param * base_value if shock_param <= 1.0 else shock_param
                    factor_column = self.rng.uniform(base_value - half_width, base_value + half_width, size=num_scenarios)
                else:
                    raise ValueError(f"Unsupported shock_type: {shock_type} for Vol factor {factor_name}")
                factor_column = np.maximum(factor_column, 1e-6)

            else:
                # This case should ideally not be reached if target_factor_names are validated
                # or if using _configured_factor_names_ordered.
                raise ValueError(f"Factor name {factor_name} not found in any configuration map during generation.")

            all_scenario_columns.append(factor_column[:, np.newaxis])

        return np.hstack(all_scenario_columns), factors_to_generate # Return the names of factors generated


In [15]:
# tff_approximator.py
"""
Contains classes and functions for Tensor Functional Form (TFF) approximation.
TensorFunctionalFormCalibrate.__init__ is simplified to accept explicit TFF inputs.
Date handling in worker function uses isoformat.
"""
import numpy as np
from scipy.stats.qmc import LatinHypercube, Sobol, scale
from concurrent.futures import ProcessPoolExecutor
import QuantLib as ql
from datetime import date, datetime
import re
import os

def engineer_option_features(
    s0_values: np.ndarray, vol_values: np.ndarray, order: int = 2
) -> tuple[np.ndarray, list[str]]:
    s0, vol = np.asarray(s0_values), np.asarray(vol_values)
    if s0.shape != vol.shape or s0.ndim != 1: raise ValueError("Invalid s0/vol shapes.")
    features = [s0, vol]; feature_names = ['S0_eng', 'Vol_eng']
    if order >= 2:
        features.extend([s0**2, vol**2, s0 * vol])
        feature_names.extend(['S0^2_eng', 'Vol^2_eng', 'S0*Vol_eng'])
    if order >= 3:
        features.extend([s0**3, vol**3, (s0**2)*vol, s0*(vol**2)])
        feature_names.extend(['S0^3_eng', 'Vol^3_eng', 'S0^2*Vol_eng', 'S0*Vol^2_eng'])
    if order >= 4:
        features.extend([s0**4, vol**4, (s0**3)*vol, s0*(vol**3), (s0**2)*(vol**2)])
        feature_names.extend(['S0^4_eng', 'Vol^4_eng', 'S0^3*Vol_eng', 'S0*Vol^3_eng', 'S0^2*Vol^2_eng'])
    return np.vstack(features).T, feature_names

def normalize_features(
    features: np.ndarray, means: np.ndarray = None, stds: np.ndarray = None
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    if features.ndim != 2: raise ValueError("Features must be 2D array.")
    if means is None or stds is None:
        means, stds = np.mean(features, axis=0), np.std(features, axis=0)
        stds[stds < 1e-8] = 1.0
    return (features - means) / stds, means, stds

def _parse_numeric_pillars_from_factor_names(factor_names: list[str]) -> np.ndarray:
    parsed_pillars = []
    for name_str in factor_names:
        match = re.search(r'(\d+(\.\d+)?)(?=Y)', name_str)
        if not match: match = re.search(r'(\d+(\.\d+)?)', name_str)
        if match:
            try:
                if any(sub.upper() in name_str.upper() for sub in ["RATE", "IR", "CURVE", "YIELD"]) and "Y" in name_str.upper():
                    parsed_pillars.append(float(match.group(1)))
            except ValueError: pass
        else:
            try:
                if not any(equity_tag in name_str.upper() for equity_tag in ["S0", "VOL", "EQUITY", "STOCK", "DIVYIELD", "CS"]):
                    parsed_pillars.append(float(name_str))
            except ValueError: pass
    return np.array(sorted(list(set(parsed_pillars))), dtype=float) if parsed_pillars else np.array([], dtype=float)


class TensorFunctionalForm:
    def __init__(self, A: np.ndarray, b: np.ndarray, c: float):
        self.A, self.b, self.c = np.asarray(A,float), np.asarray(b,float), float(c)
        if self.A.ndim!=2 or self.A.shape[0]!=self.A.shape[1]: raise ValueError("A must be square.")
        if self.b.ndim!=1 or self.b.shape[0]!=self.A.shape[0]: raise ValueError("b dim must match A.")
        self.D: int = self.A.shape[0]
    def __call__(self, x: np.ndarray) -> np.ndarray:
        x_arr = np.asarray(x)
        if x_arr.ndim == 1:
            if x_arr.shape[0]!=self.D: raise ValueError(f"Input dim {x_arr.shape[0]} != model dim {self.D}.")
            return float(x_arr @ self.A @ x_arr + self.b @ x_arr + self.c)
        elif x_arr.ndim == 2:
            if x_arr.shape[1]!=self.D: raise ValueError(f"Input shape {x_arr.shape}, expected (N, {self.D}).")
            return np.sum((x_arr @ self.A) * x_arr, axis=1) + x_arr @ self.b + self.c
        raise ValueError(f"Input must be 1D or 2D, got ndim={x_arr.ndim}")
    def to_dict(self) -> dict: return {'A':self.A.tolist(),'b':self.b.tolist(),'c':self.c,'D':self.D}
    @classmethod
    def from_dict(cls, data:dict) -> 'TensorFunctionalForm':
        if not all(k in data for k in ['A','b','c']): raise ValueError("Missing keys in TFF data dict.")
        return cls(np.array(data['A'],float), np.array(data['b'],float), data['c'])


def _price_one_scenario_for_tff(worker_args: tuple) -> float:
    (product_static_params_dict, pricer_config_for_worker,
     factor_names_for_tff, single_market_scenario_data,
     valuation_date_for_worker_iso, price_kwargs_dict) = worker_args

    valuation_date_for_worker = _parse_date_input(valuation_date_for_worker_iso)
    ql_val_date = ql.Date(valuation_date_for_worker.day, valuation_date_for_worker.month, valuation_date_for_worker.year)
    ql.Settings.instance().evaluationDate = ql_val_date

    product_type = product_static_params_dict['product_type']
    actual_rate_pillars_for_worker = np.asarray(product_static_params_dict.get('actual_rate_pillars', []), dtype=float)

    current_static_params = product_static_params_dict.copy()
    current_static_params['valuation_date'] = valuation_date_for_worker

    if product_type == 'VanillaBond': product_static_obj = QuantLibBondStaticBase.from_dict(current_static_params)
    elif product_type == 'CallableBond': product_static_obj = CallableBondStaticBase.from_dict(current_static_params)
    elif product_type == 'ConvertibleBond': product_static_obj = ConvertibleBondStaticBase.from_dict(current_static_params)
    elif product_type == 'EuropeanOption': product_static_obj = EuropeanOptionStatic.from_dict(current_static_params)
    else: raise ValueError(f"Unknown product type for TFF worker: {product_type}")

    if product_type in ['VanillaBond', 'CallableBond', 'ConvertibleBond']:
        pricer_instance = QuantLibBondPricer(product_static_obj, **pricer_config_for_worker.get('bond_pricer_config',{}))
        market_data_for_ql_pricer = np.array([single_market_scenario_data])

        price_result_array = pricer_instance.price(
            pillar_times=actual_rate_pillars_for_worker,
            market_scenario_data=market_data_for_ql_pricer,
            **price_kwargs_dict
        )
        return price_result_array[0]
    elif product_type == 'EuropeanOption':
        bs_config = pricer_config_for_worker.get('bs_pricer_config', {})
        if 'risk_free_rate' not in bs_config:
            raise ValueError("'risk_free_rate' missing in bs_pricer_config for EuropeanOption TFF worker.")
        pricer_instance = BlackScholesPricer(product_static_obj, bs_config['risk_free_rate'], bs_config.get('dividend_yield',0.0))
        if len(single_market_scenario_data) == 2: # S0 and Vol
            return pricer_instance.price(stock_price=single_market_scenario_data[0], volatility=single_market_scenario_data[1])
        else: raise ValueError(f"Option TFF expects 2 inputs (S0, Vol), got {len(single_market_scenario_data)}")
    raise ValueError(f"Pricer path failed for product type: {product_type}")


class TensorFunctionalFormCalibrate:
    def __init__(
        self,
        pricer_template: PricerBase,
        tff_input_raw_factor_names: list[str],
        tff_input_raw_base_values: np.ndarray,
        product_static_params_for_worker: dict, # Renamed from _override
        pricer_config_for_worker: dict,         # Renamed from _override
        actual_rate_pillars: np.ndarray = None  # Renamed from _override
    ):
        self.pricer_template = pricer_template
        self.product_static: ProductStaticBase = pricer_template.product_static

        self.tff_input_raw_factor_names = tff_input_raw_factor_names
        self.tff_input_raw_base_values = tff_input_raw_base_values

        self.product_static_params_for_worker = product_static_params_for_worker
        self.pricer_config_for_worker = pricer_config_for_worker
        self.actual_rate_pillars = actual_rate_pillars if actual_rate_pillars is not None else np.array([])

        val_date_from_params = self.product_static_params_for_worker.get('valuation_date')
        if isinstance(val_date_from_params, str):
            self.valuation_date_for_ql_settings_in_worker = date.fromisoformat(val_date_from_params)
        elif isinstance(val_date_from_params, date):
            self.valuation_date_for_ql_settings_in_worker = val_date_from_params
        elif self.product_static and hasattr(self.product_static, 'valuation_date_py'):
             self.valuation_date_for_ql_settings_in_worker = self.product_static.valuation_date_py
        else:
            raise TypeError("valuation_date in product_static_params_for_worker must be an ISO string or date object, or available on product_static.")


        self.product_type_str = self.product_static_params_for_worker.get('product_type')
        if not self.product_type_str:
            if isinstance(self.product_static, EuropeanOptionStatic): self.product_type_str = 'EuropeanOption'
            elif isinstance(self.product_static, CallableBondStaticBase): self.product_type_str = 'CallableBond'
            elif isinstance(self.product_static, ConvertibleBondStaticBase): self.product_type_str = 'ConvertibleBond'
            elif isinstance(self.product_static, QuantLibBondStaticBase): self.product_type_str = 'VanillaBond'
            else: raise TypeError(f"Cannot determine product_type_str for TFFCalibrate from {type(self.product_static)}")

        if self.product_type_str in ['VanillaBond', 'CallableBond', 'ConvertibleBond']:
            # Ensure actual_rate_pillars is correctly set up from override
            if 'actual_rate_pillars' not in self.product_static_params_for_worker or \
               not isinstance(self.product_static_params_for_worker['actual_rate_pillars'], list):
                if self.actual_rate_pillars.size > 0 : # Use the passed override
                     self.product_static_params_for_worker['actual_rate_pillars'] = self.actual_rate_pillars.tolist()
                else:
                    print(f"Warning: 'actual_rate_pillars' not in worker_params or override for {self.product_type_str}. Parsing from TFF input names.")
                    self.actual_rate_pillars = _parse_numeric_pillars_from_factor_names(self.tff_input_raw_factor_names)
                    self.product_static_params_for_worker['actual_rate_pillars'] = self.actual_rate_pillars.tolist()
            else: # Already a list in worker_params
                self.actual_rate_pillars = np.array(self.product_static_params_for_worker['actual_rate_pillars'])

        if not self.tff_input_raw_factor_names or self.tff_input_raw_base_values.size == 0:
            raise RuntimeError(f"TFF input factors/base values not set for {self.product_type_str}")
        if len(self.tff_input_raw_factor_names) != len(self.tff_input_raw_base_values):
            raise RuntimeError(f"Mismatch TFF factor names ({len(self.tff_input_raw_factor_names)}) and base values ({len(self.tff_input_raw_base_values)}).")


    def sample_and_fit(
        self, full_market_scenarios_for_tff_factors: np.ndarray,
        n_train: int = 50, n_test: int = 20,
        random_seed: int = 0, sampling_method: str = 'sobol', parallel_workers: int = None,
        option_feature_order: int = 0, **price_kwargs
    ) -> tuple[TensorFunctionalForm, np.ndarray, np.ndarray, float, dict]:

        rng_np = np.random.default_rng(random_seed)
        num_tff_factors = len(self.tff_input_raw_factor_names)

        if full_market_scenarios_for_tff_factors.ndim != 2 or \
           full_market_scenarios_for_tff_factors.shape[1] != num_tff_factors:
            raise ValueError(f"Shape error for scenarios. Expected (N, {num_tff_factors}), got {full_market_scenarios_for_tff_factors.shape}. Factors: {self.tff_input_raw_factor_names}")

        domain_min, domain_max = np.min(full_market_scenarios_for_tff_factors, axis=0), np.max(full_market_scenarios_for_tff_factors, axis=0)

        train_tff_inputs_raw = None
        if sampling_method == 'sobol':
            sampler = Sobol(d=num_tff_factors, scramble=True, seed=random_seed)
            train_tff_inputs_raw = scale(sampler.random(n=n_train), domain_min, domain_max)
        elif sampling_method == 'lhs':
            sampler_lhs = LatinHypercube(d=num_tff_factors, centered=True, seed=random_seed)
            train_tff_inputs_raw = scale(sampler_lhs.random(n=n_train), domain_min, domain_max)
        elif sampling_method == 'uniform':
            train_tff_inputs_raw = rng_np.uniform(low=domain_min, high=domain_max, size=(n_train, num_tff_factors))
        else: raise ValueError(f"Unknown sampling: {sampling_method}.")

        worker_args_list = [(self.product_static_params_for_worker, self.pricer_config_for_worker,
             self.tff_input_raw_factor_names, train_tff_inputs_raw[i],
             self.valuation_date_for_ql_settings_in_worker.isoformat(),
             price_kwargs) for i in range(n_train)]

        if parallel_workers is not False and parallel_workers != 0 and n_train > 1:
            actual_workers = parallel_workers if parallel_workers else (os.cpu_count() if os.cpu_count() else 1)
            #print(f"   Generating {n_train} training prices in parallel (workers={actual_workers})...")
            with ProcessPoolExecutor(max_workers=actual_workers) as executor:
                train_prices = np.array(list(executor.map(_price_one_scenario_for_tff, worker_args_list)))
        else:
            #print(f"   Generating {n_train} training prices sequentially...")
            train_prices = np.array([_price_one_scenario_for_tff(args) for args in worker_args_list])

        if train_prices.ndim == 0 and n_train == 1: train_prices = np.array([train_prices])
        if train_prices.shape[0] != n_train: raise ValueError(f"Shape of train_prices {train_prices.shape} != n_train {n_train}")

        tff_inputs_for_fitting = train_tff_inputs_raw
        normalization_params = {'means': None, 'stds': None, 'engineered_feature_names': self.tff_input_raw_factor_names, 'is_engineered': False}

        if self.product_type_str == 'EuropeanOption' and option_feature_order > 0:
            #print(f"   Engineering features for option TFF (order={option_feature_order})...")
            if train_tff_inputs_raw.shape[1] != 2: raise ValueError(f"Option FE expects 2 raw inputs, got {train_tff_inputs_raw.shape[1]}")
            engineered_features_train, eng_names = engineer_option_features(train_tff_inputs_raw[:,0], train_tff_inputs_raw[:,1], order=option_feature_order)
            tff_inputs_for_fitting, means, stds = normalize_features(engineered_features_train)
            normalization_params = {'means':means.tolist(), 'stds':stds.tolist(), 'engineered_feature_names':eng_names, 'is_engineered':True}

        D_eff = tff_inputs_for_fitting.shape[1]
        X_train = np.hstack([np.array([np.outer(s,s).flatten() for s in tff_inputs_for_fitting]), tff_inputs_for_fitting, np.ones((n_train,1))])
        if np.any(np.isnan(X_train)) or np.any(np.isinf(X_train)): raise ValueError("NaN/Inf in X_train.")
        if np.any(np.isnan(train_prices)) or np.any(np.isinf(train_prices)): raise ValueError("NaN/Inf in train_prices.")
        try: coeffs,_,_,_ = np.linalg.lstsq(X_train, train_prices, rcond=None)
        except np.linalg.LinAlgError as e: raise np.linalg.LinAlgError(f"Lstsq failed: {e}.")

        A_flat = coeffs[:D_eff*D_eff]
        A_mat = A_flat.reshape(D_eff,D_eff); A_sym = 0.5*(A_mat+A_mat.T)
        b_vec, c_s = coeffs[D_eff*D_eff : D_eff*D_eff+D_eff], coeffs[-1]
        fitted_tff = TensorFunctionalForm(A_sym, b_vec, c_s)

        test_idx = rng_np.choice(full_market_scenarios_for_tff_factors.shape[0], size=n_test, replace=False)
        test_tff_inputs_raw = full_market_scenarios_for_tff_factors[test_idx]
        test_worker_args = [(self.product_static_params_for_worker, self.pricer_config_for_worker,
             self.tff_input_raw_factor_names, test_tff_inputs_raw[i],
             self.valuation_date_for_ql_settings_in_worker.isoformat(), price_kwargs) for i in range(n_test)]
        #print(f"   Generating {n_test} test prices sequentially...")
        test_true_prices = np.array([_price_one_scenario_for_tff(args) for args in test_worker_args])

        test_inputs_eval = test_tff_inputs_raw
        if self.product_type_str == 'EuropeanOption' and normalization_params.get('is_engineered', False):
            if test_tff_inputs_raw.shape[1]!=2: raise ValueError(f"Test option TFF inputs expect 2 cols, got {test_tff_inputs_raw.shape[1]}")
            eng_feat_test,_ = engineer_option_features(test_tff_inputs_raw[:,0], test_tff_inputs_raw[:,1], order=option_feature_order)
            np_means = np.array(normalization_params['means']) if normalization_params.get('means') is not None else None
            np_stds = np.array(normalization_params['stds']) if normalization_params.get('stds') is not None else None
            test_inputs_eval,_,_ = normalize_features(eng_feat_test, np_means, np_stds)

        test_pred_prices = fitted_tff(test_inputs_eval)
        if test_true_prices.ndim==0 and n_test==1: test_true_prices = np.array([test_true_prices])
        if test_pred_prices.ndim==0 and n_test==1: test_pred_prices = np.array([test_pred_prices])
        if test_true_prices.shape != test_pred_prices.shape: raise ValueError(f"Shape mismatch test prices: true {test_true_prices.shape}, pred {test_pred_prices.shape}")

        rmse = np.sqrt(np.mean((test_true_prices - test_pred_prices)**2))
        return fitted_tff, test_tff_inputs_raw, test_true_prices, rmse, normalization_params


In [16]:
# utils.py
"""
Contains utility functions for the FastRiskDemo, such as generating
collections of bond definitions for portfolio testing.
"""
import numpy as np
from datetime import date
from dateutil.relativedelta import relativedelta

def generate_bond_collections(
    num_bonds: int,
    valuation_date: date = date(2025, 1, 1),
    face_value: float = 100.0,
    seed: int = 0,
    conv_params: dict = None # Parameters for generating convertibles
) -> tuple[list[QuantLibBondStaticBase], list[CallableBondStaticBase], list[ConvertibleBondStaticBase]]:
    """
    Generates collections of random vanilla, callable, and convertible bond static definitions.

    Args:
        num_bonds (int): Number of bond instances of each type to generate.
        valuation_date (date, optional): Common valuation date for all bonds.
                                         Defaults to date(2025, 1, 1).
        face_value (float, optional): Default face value for bonds. Defaults to 100.0.
        seed (int, optional): Random seed for reproducibility. Defaults to 0.
        conv_params (dict, optional): Dictionary of parameters for generating convertible bonds.
                                      Keys can include 'conversion_ratio', 'dividend_yield',
                                      'equity_volatility', 'initial_stock_price',
                                      'credit_spread_value', 'exercise_type'.
                                      Defaults are used if not provided.

    Returns:
        tuple[list, list, list]:
            A tuple containing:
            - A list of QuantLibBondStaticBase instances (vanilla bonds).
            - A list of CallableBondStaticBase instances (callable bonds).
            - A list of ConvertibleBondStaticBase instances (convertible bonds).
    """
    rng = np.random.default_rng(seed)
    vanilla_bonds = []
    callable_bonds = []
    convertible_bonds = []

    # Default parameters for convertible bond generation
    default_conv_params = {
        'conversion_ratio': 20.0,        # Example: 20 shares per bond of 100 face value
        'dividend_yield': 0.01,         # 1% continuous dividend yield
        'equity_volatility': 0.25,      # 25% annual volatility
        'initial_stock_price': 100.0,   # Base stock price
        'credit_spread_value': 0.015,   # 150 bps credit spread
        'exercise_type': 'EuropeanAtMaturity'
    }
    # Update defaults with any user-provided conv_params
    current_conv_params = default_conv_params.copy()
    if conv_params is not None:
        current_conv_params.update(conv_params)

    for i in range(num_bonds):
        # Common random parameters for each bond in the iteration
        years_to_maturity = int(rng.integers(3, 11)) # Ensure bonds have some life, min 3 years
        maturity_d = valuation_date + relativedelta(years=years_to_maturity)

        # Use different coupon ranges for different bond types for more realism
        van_call_coupon = float(rng.uniform(0.02, 0.06)) # 2% to 6% for vanilla/callable
        conv_coupon = float(rng.uniform(0.01, 0.04)) # 1% to 4% for convertibles (often lower)

        coupon_freq = int(rng.choice([1, 2]))    # Annual or Semi-annual

        # --- Vanilla Bond ---
        # Parameters for QuantLibBondStaticBase.from_dict
        vanilla_params_dict = {
            'valuation_date': valuation_date,
            'maturity_date': maturity_d,
            'coupon_rate': van_call_coupon,
            'face_value': face_value,
            'freq': coupon_freq,
            'settlement_days': 0
        }
        vanilla_bonds.append(QuantLibBondStaticBase.from_dict(vanilla_params_dict))

        # --- Callable Bond ---
        call_dates_list_py = []
        call_prices_list_py = []
        if years_to_maturity > 2: # Ensure some non-call period for callable bonds
            # Generate 1 to 3 call dates for variety
            num_calls = int(rng.integers(1, min(4, years_to_maturity - 1)))
            possible_call_years_offsets = list(range(1, years_to_maturity)) # Call can be from year 1 to year T-1

            if num_calls > 0 and len(possible_call_years_offsets) >= num_calls:
                chosen_call_years_offsets = sorted(rng.choice(possible_call_years_offsets, size=num_calls, replace=False))
                for year_offset in chosen_call_years_offsets:
                    call_d = valuation_date + relativedelta(years=int(year_offset))
                    if call_d < maturity_d: # Ensure call date is before maturity
                        call_dates_list_py.append(call_d)
                        # Call price typically at par or a slight premium
                        call_prices_list_py.append(float(face_value + rng.uniform(0.0, 3.0)))

        if call_dates_list_py and call_prices_list_py: # Only if valid call dates were generated
            callable_params_dict = {
                'valuation_date': valuation_date, 'maturity_date': maturity_d,
                'coupon_rate': van_call_coupon, 'face_value': face_value,
                'freq': coupon_freq, 'settlement_days': 0,
                'call_dates': call_dates_list_py, 'call_prices': call_prices_list_py
            }
            callable_bonds.append(CallableBondStaticBase.from_dict(callable_params_dict))

        # --- Convertible Bond ---
        # Issue date for convertible, can be different from valuation date.
        issue_offset_years = int(rng.integers(0, min(3, years_to_maturity - 1)))
        issue_date_convertible = valuation_date - relativedelta(years=issue_offset_years)
        # Ensure issue date is not after valuation date for this setup
        if issue_date_convertible > valuation_date:
            issue_date_convertible = valuation_date

        # Randomize some convertible parameters slightly for variety around the provided base
        s0_rand = current_conv_params['initial_stock_price'] * rng.uniform(0.8, 1.2)
        vol_rand = current_conv_params['equity_volatility'] * rng.uniform(0.7, 1.3)
        div_rand = max(0, current_conv_params['dividend_yield'] * rng.uniform(0.5, 1.5))
        cs_rand = max(0.001, current_conv_params['credit_spread_value'] * rng.uniform(0.5, 2.0))
        cr_rand = current_conv_params['conversion_ratio'] * rng.uniform(0.9, 1.1)

        convertible_params_dict = {
            'valuation_date': valuation_date,
            'issue_date': issue_date_convertible,
            'maturity_date': maturity_d,
            'coupon_rate': conv_coupon,
            'conversion_ratio': cr_rand,
            'dividend_yield': div_rand,
            'equity_volatility': vol_rand,
            'initial_stock_price': s0_rand,
            'credit_spread_value': cs_rand,
            'face_value': face_value,
            'freq': coupon_freq,
            'settlement_days': 0,
            'exercise_type': current_conv_params['exercise_type']
        }
        convertible_bonds.append(ConvertibleBondStaticBase.from_dict(convertible_params_dict))

    return vanilla_bonds, callable_bonds, convertible_bonds


In [17]:
# portfolio.py
"""
Contains classes for defining and analyzing portfolios of financial instruments.
The Portfolio class allows pricing using either TFF models (retrieved from a cache)
or full pricers.
"""
import numpy as np
import abc

class PortfolioBase(abc.ABC):
    """
    Abstract base class for a portfolio of financial instruments.
    """
    def __init__(self):
        # Stores details for each *position* in the portfolio.
        # Multiple positions can refer to the same underlying instrument_id if TFFs are cached.
        self.positions: list[dict] = []

    @abc.abstractmethod
    def add_position(self, *args, **kwargs):
        """Adds a position (instrument holding) to the portfolio."""
        pass

    @abc.abstractmethod
    def price_portfolio(self,
                        raw_market_scenarios: np.ndarray,
                        scenario_factor_names: list[str],
                        portfolio_rate_pillar_times: np.ndarray = None # For bond pricers
                        ) -> np.ndarray:
        """
        Prices all instruments in the portfolio for given market scenarios.

        Args:
            raw_market_scenarios (np.ndarray): 2D array of market scenarios
                                               (N_scenarios, N_total_market_factors).
            scenario_factor_names (list[str]): Names of the columns in raw_market_scenarios.
            portfolio_rate_pillar_times (np.ndarray, optional): 1D array of rate pillar times.
                                                               Required if portfolio contains bonds
                                                               priced with full QuantLibBondPricer.
        Returns:
            np.ndarray: 1D array of aggregated portfolio prices for each scenario (N_scenarios,).
        """
        pass

class Portfolio(PortfolioBase):
    """
    A portfolio where each instrument can be priced using either a pre-fitted
    Tensor Functional Form (TFF) model (retrieved from a cache) or its original full pricer.
    """
    def __init__(self):
        super().__init__()
        self.tff_model_cache: dict = {}
        # Cache structure:
        # { instrument_id: {
        #     'tff_model': TensorFunctionalForm_object,
        #     'raw_tff_input_names': list_of_names, # Raw factors TFF was trained on
        #     'normalization_params': dict_of_norm_params, # Includes engineered_feature_names
        #     'option_feature_order': int
        #   }, ...
        # }

    def to_dict(self) -> dict:
        """
        Returns a dictionary representation of the portfolio.

        Returns:
            dict: Dictionary representation of the portfolio.
        """
        return {
            'positions': [p.to_dict() if hasattr(p, 'to_dict') else p
                for p in self.positions],
            'tff_model_cache': self.tff_model_cache
        }

    def cache_tff_model(self,
                        instrument_id: str,
                        tff_model: TensorFunctionalForm,
                        raw_tff_input_names: list[str],
                        normalization_params: dict,
                        option_feature_order: int = 0):
        """
        Explicitly caches a fitted TFF model and its associated parameters.

        Args:
            instrument_id (str): Unique ID for the instrument type this TFF represents.
            tff_model (TensorFunctionalForm): The pre-fitted TFF model.
            raw_tff_input_names (list[str]): Names of raw market factors TFF is based on.
            normalization_params (dict): Normalization params from TFF calibration.
                                         Expected keys: 'means', 'stds', 'engineered_feature_names', 'is_engineered'.
            option_feature_order (int, optional): Order of feature engineering if option TFF.
        """
        if not instrument_id:
            raise ValueError("instrument_id must be provided for caching TFF model.")
        if not isinstance(tff_model, TensorFunctionalForm):
            raise TypeError("tff_model must be an instance of TensorFunctionalForm.")
        if raw_tff_input_names is None or not isinstance(raw_tff_input_names, list):
            raise ValueError("raw_tff_input_names (list of strings) must be provided for caching TFF model.")
        if normalization_params is None or not isinstance(normalization_params, dict):
            raise ValueError("normalization_params (dict) must be provided for caching TFF model.")

        self.tff_model_cache[instrument_id] = {
            'tff_model': tff_model,
            'raw_tff_input_names': raw_tff_input_names,
            'normalization_params': normalization_params,
            'option_feature_order': option_feature_order
        }

    def from_dict(self, portfolio_dict: dict):
        """
        Loads a portfolio from a dictionary representation. Similar to cache_tff_model

        Args:
            portfolio_dict (dict): Dictionary representation of the portfolio.
        """



    def add_position(self,
                       instrument_id: str,
                       product_static: ProductStaticBase,
                       num_holdings: int = 1,
                       pricing_engine_type: str = 'tff',
                       direct_tff_config: dict = None,
                       full_pricer_instance: PricerBase = None,
                       full_pricer_kwargs: dict = None):
        """
        Adds a position (an instrument holding) to the portfolio.
        If pricing_engine_type is 'tff', it retrieves the TFF model from the cache
        using the instrument_id. The TFF model must have been cached previously using `cache_tff_model`.

        Args:
            instrument_id (str): Unique ID for this instrument type. Used for TFF caching/lookup.
            product_static (ProductStaticBase): Static definition of the product for this position.
            num_holdings (int, optional): Number of units. Defaults to 1.
            pricing_engine_type (str, optional): 'tff' or 'full'. Defaults to 'tff'.
            full_pricer_instance (PricerBase, optional): A full pricer instance (required if type is 'full').
            full_pricer_kwargs (dict, optional): Keyword arguments for the full pricer's price method.
        """
        if not isinstance(num_holdings, int) or num_holdings <= 0:
                raise ValueError("num_holdings must be a positive integer.")
        if not instrument_id:
            raise ValueError("instrument_id must be provided for the position.")

        position_detail = {
            'instrument_id': instrument_id,
            'product_static': product_static,
            'num_holdings': num_holdings,
            'engine_type': pricing_engine_type.lower()
        }

        if position_detail['engine_type'] == 'tff':
            if direct_tff_config is not None:
                # A TFF model and its configuration are being provided directly as a dictionary
                if not isinstance(direct_tff_config, dict):
                    raise TypeError("direct_tff_config must be a dictionary.")

                model_dict = direct_tff_config.get('model_dict')
                raw_names = direct_tff_config.get('raw_input_names')
                norm_params = direct_tff_config.get('normalization_params')

                # Default option_feature_order to 0 if not explicitly in config
                opt_order = direct_tff_config.get('option_feature_order', 0)

                if model_dict is None or not isinstance(model_dict, dict):
                    raise ValueError("direct_tff_config is missing 'model_dict' or it's not a dictionary.")
                if raw_names is None or not isinstance(raw_names, list):
                    raise ValueError("direct_tff_config is missing 'raw_input_names' or it's not a list.")
                if norm_params is None or not isinstance(norm_params, dict):
                    raise ValueError("direct_tff_config is missing 'normalization_params' or it's not a dictionary.")

                # Deserialize the TFF model itself from model_dict
                try:
                    tff_model_instance = TensorFunctionalForm.from_dict(model_dict)
                except Exception as e:
                    raise ValueError(f"Failed to deserialize TFF model from direct_tff_config['model_dict']: {e}")

                position_detail['pricer_engine'] = tff_model_instance
                position_detail['raw_tff_input_names'] = raw_names
                position_detail['normalization_params'] = norm_params
                position_detail['option_feature_order'] = opt_order

            else:
                # No direct TFF config provided, so use the cache via instrument_id
                if instrument_id not in self.tff_model_cache:
                    raise ValueError(
                        f"TFF model for instrument_id '{instrument_id}' not found in cache. "
                        "Either fit and cache it first, or provide its configuration via 'direct_tff_config'."
                    )

                cached_data = self.tff_model_cache[instrument_id]
                position_detail['pricer_engine'] = cached_data['tff_model']
                position_detail['raw_tff_input_names'] = cached_data['raw_tff_input_names']
                position_detail['normalization_params'] = cached_data['normalization_params']
                position_detail['option_feature_order'] = cached_data['option_feature_order']

        elif position_detail['engine_type'] == 'full':
            if not isinstance(full_pricer_instance, PricerBase):
                raise TypeError("full_pricer_instance must be an instance of PricerBase if engine_type is 'full'.")
            position_detail['pricer_engine'] = full_pricer_instance
            position_detail['full_pricer_kwargs'] = full_pricer_kwargs or {}
        else:
            raise ValueError(f"Unsupported pricing_engine_type: {pricing_engine_type}. Choose 'tff' or 'full'.")

        self.positions.append(position_detail)

    def load_portfolio_from_specs(self, portfolio_specs: list[dict]):
        """
        Loads multiple positions into the portfolio from a list of specifications.

        Each specification in the list is a dictionary that should conform to
        the parameters expected by the `add_position` method.

        Args:
            portfolio_specs (list[dict]): A list of dictionaries, where each dictionary
                                          defines a position to be added.
                                          Expected keys in each dict:
                                            'instrument_id' (str)
                                            'product_static_object' (ProductStaticBase)
                                            'num_holdings' (int)
                                            'pricing_engine_type' (str: 'tff' or 'full')
                                            Optional for 'tff' type:
                                              'direct_tff_config' (dict, containing 'model_dict',
                                                                   'raw_input_names', 'normalization_params',
                                                                   'option_feature_order')
                                            Optional for 'full' type:
                                              'full_pricer_instance' (PricerBase)
                                              'full_pricer_kwargs' (dict)
        """
        if not isinstance(portfolio_specs, list):
            raise TypeError("portfolio_specs must be a list of dictionaries.")

        for i, item_spec in enumerate(portfolio_specs):
            if not isinstance(item_spec, dict):
                raise TypeError(f"Each item in portfolio_specs must be a dictionary. Found type {type(item_spec)} at index {i}.")

            try:
                self.add_position(
                    instrument_id=item_spec['instrument_id'],
                    product_static=item_spec['product_static_object'],
                    num_holdings=item_spec.get('num_holdings', 1), # Default if not specified
                    pricing_engine_type=item_spec.get('pricing_engine_type', 'tff'), # Default

                    direct_tff_config=item_spec.get('direct_tff_config'), # Will be None if key doesn't exist

                    full_pricer_instance=item_spec.get('full_pricer_instance'),
                    full_pricer_kwargs=item_spec.get('full_pricer_kwargs')
                )
            except KeyError as e:
                raise ValueError(f"Missing required key {e} in portfolio_spec at index {i}: {item_spec}")
            except Exception as e:
                raise RuntimeError(f"Error adding position from spec at index {i} ('{item_spec.get('instrument_id', 'Unknown ID')}'): {e}")

        print(f"Successfully loaded {len(self.positions)} positions into the portfolio from specifications.")

    def price_portfolio(self,
                        raw_market_scenarios: np.ndarray,
                        scenario_factor_names: list[str],
                        portfolio_rate_pillar_times: np.ndarray = None
                        ) -> np.ndarray:
        """
        Prices all positions in the portfolio for given market scenarios.

        Args:
            raw_market_scenarios (np.ndarray): 2D array of raw market scenarios
                                               (N_scenarios, N_total_market_factors).
            scenario_factor_names (list[str]): Names of the columns in raw_market_scenarios.
            portfolio_rate_pillar_times (np.ndarray, optional): 1D array of rate pillar times (numeric tenors).
                                                               Required if portfolio contains bonds
                                                               priced with full QuantLibBondPricer or FastBondPricer.

        Returns:
            np.ndarray: 1D array of aggregated portfolio prices for each scenario (N_scenarios,).
        """
        if not self.positions:
            return np.array([])

        num_scenarios = raw_market_scenarios.shape[0]
        portfolio_prices_per_scenario = np.zeros(num_scenarios, dtype=float)

        for position_detail in self.positions:
            pricer_engine = position_detail['pricer_engine']
            engine_type = position_detail['engine_type']
            num_holdings = position_detail['num_holdings']

            instrument_prices_this_instrument = np.zeros(num_scenarios, dtype=float)

            if engine_type == 'tff':
                tff_model: TensorFunctionalForm = pricer_engine
                # These are the names of the RAW factors the TFF was originally trained on.
                raw_tff_input_factor_names_for_this_tff = position_detail['raw_tff_input_names']
                norm_params = position_detail['normalization_params']
                opt_feat_order = position_detail['option_feature_order']

                # Select the relevant columns from the global raw_market_scenarios
                try:
                    indices_of_raw_factors_in_global_scenarios = [scenario_factor_names.index(name) for name in raw_tff_input_factor_names_for_this_tff]
                except ValueError as e:
                    raise ValueError(f"A TFF input name in {raw_tff_input_factor_names_for_this_tff} not found in scenario_factor_names {scenario_factor_names} for instrument_id '{position_detail['instrument_id']}'. Error: {e}")

                current_raw_inputs_for_tff = raw_market_scenarios[:, indices_of_raw_factors_in_global_scenarios]

                inputs_for_tff_evaluation = current_raw_inputs_for_tff # Default

                # opt_feat_order should be available from position_detail if it's an option TFF
                opt_feat_order = position_detail.get('option_feature_order', 0)

                if isinstance(position_detail['product_static'], EuropeanOptionStatic) and \
                   norm_params.get('is_engineered', False): # Check if TFF was trained with engineered features

                    s0_factor_actual_name_port = None
                    vol_factor_actual_name_port = None
                    # These indices are relative to raw_tff_input_factor_names_for_this_tff
                    # and thus to the columns of current_raw_inputs_for_tff
                    s0_idx_in_tff_inputs = -1
                    vol_idx_in_tff_inputs = -1

                    # For an option TFF, raw_tff_input_factor_names_for_this_tff is expected to be [full_s0_name, full_vol_name]
                    if len(raw_tff_input_factor_names_for_this_tff) == 2:
                        for i, name in enumerate(raw_tff_input_factor_names_for_this_tff):
                            if name.upper().endswith("_S0"): # Using suffix matching
                                s0_factor_actual_name_port = name
                                s0_idx_in_tff_inputs = i
                            elif name.upper().endswith("_VOLATILITY") or name.upper().endswith("_VOL"): # Using suffix matching
                                vol_factor_actual_name_port = name
                                vol_idx_in_tff_inputs = i

                        if s0_idx_in_tff_inputs == -1 or vol_idx_in_tff_inputs == -1 or s0_idx_in_tff_inputs == vol_idx_in_tff_inputs:
                            raise ValueError(
                                f"Portfolio pricing: Could not identify distinct S0 and Volatility factors for option "
                                f"'{position_detail['instrument_id']}' from its TFF input names: "
                                f"{raw_tff_input_factor_names_for_this_tff}."
                            )
                    else:
                        raise ValueError(
                            f"Portfolio pricing: Option TFF for '{position_detail['instrument_id']}' with engineered features "
                            f"expects 2 raw input factors (S0, Vol), but "
                            f"TFF was trained on {len(raw_tff_input_factor_names_for_this_tff)} names: "
                            f"{raw_tff_input_factor_names_for_this_tff}."
                        )

                    # current_raw_inputs_for_tff columns are ordered as per raw_tff_input_factor_names_for_this_tff
                    s0_scenarios_raw = current_raw_inputs_for_tff[:, s0_idx_in_tff_inputs]
                    vol_scenarios_raw = current_raw_inputs_for_tff[:, vol_idx_in_tff_inputs]

                    engineered_features_test, _ = engineer_option_features( # Changed name for clarity
                        s0_scenarios_raw, vol_scenarios_raw, order=opt_feat_order
                    )
                    # Use normalization_params['means'] and ['stds'] from the cached TFF model data
                    inputs_for_tff_evaluation, _, _ = normalize_features(
                        engineered_features_test, # Changed name
                        norm_params['means'],
                        norm_params['stds']
                    )

                instrument_prices_this_instrument = tff_model(inputs_for_tff_evaluation)

            elif engine_type == 'full':
                full_pricer: PricerBase = pricer_engine
                pricer_kwargs = position_detail.get('full_pricer_kwargs', {})

                if isinstance(full_pricer, (QuantLibBondPricer, FastBondPricer)):
                    if portfolio_rate_pillar_times is None:
                        raise ValueError("portfolio_rate_pillar_times must be provided for full QL/Fast bond pricing in portfolio.")

                    num_rate_pillars = len(portfolio_rate_pillar_times)
                    # Select rate columns from raw_market_scenarios
                    rate_indices = []
                    try:
                        # Attempt to find by name if portfolio_rate_pillar_times contains names (should be numeric tenors)
                        # For simplicity, assume scenario_factor_names aligns with rate_pillar_names for the rate part
                        rate_indices = [scenario_factor_names.index(name) for name in portfolio_rate_pillar_times if isinstance(name, str) and name in scenario_factor_names]
                        if len(rate_indices) != num_rate_pillars: # Fallback if names don't match
                             rate_indices = list(range(num_rate_pillars))
                    except (ValueError, IndexError, TypeError):
                        # Fallback if portfolio_rate_pillar_times are not strings or not found
                        rate_indices = list(range(num_rate_pillars))

                    market_data_for_bond_pricer = raw_market_scenarios[:, rate_indices]

                    if isinstance(full_pricer, QuantLibBondPricer) and full_pricer.is_convertible and 'S0' in scenario_factor_names:
                        s0_col_idx = scenario_factor_names.index('S0')
                        market_data_for_bond_pricer = np.hstack((
                            market_data_for_bond_pricer,
                            raw_market_scenarios[:, s0_col_idx, np.newaxis]
                        ))

                    instrument_prices_this_instrument = full_pricer.price(
                        pillar_times=portfolio_rate_pillar_times, # These are the numeric tenors
                        market_scenario_data=market_data_for_bond_pricer,
                        **pricer_kwargs
                    )
                elif isinstance(full_pricer, BlackScholesPricer):
                    s0_idx = scenario_factor_names.index('S0')
                    vol_idx = scenario_factor_names.index('Volatility')
                    s0_scens = raw_market_scenarios[:, s0_idx]
                    vol_scens = raw_market_scenarios[:, vol_idx]
                    instrument_prices_this_instrument = full_pricer.price(
                        stock_price=s0_scens,
                        volatility=vol_scens,
                        **pricer_kwargs
                    )
                else:
                    raise TypeError(f"Unsupported full pricer type in portfolio: {type(full_pricer)} for instrument_id '{position_detail['instrument_id']}'")
            else:
                raise ValueError(f"Unknown engine_type: {engine_type} for instrument_id '{position_detail['instrument_id']}'")

            # Ensure consistent shapes for aggregation
            if instrument_prices_this_instrument.ndim == 0:
                instrument_prices_this_instrument = np.full(num_scenarios, float(instrument_prices_this_instrument))
            elif len(instrument_prices_this_instrument) != num_scenarios:
                 instrument_prices_this_instrument = instrument_prices_this_instrument.flatten()
                 if len(instrument_prices_this_instrument) != num_scenarios:
                    raise ValueError(f"Price array shape mismatch for instrument '{position_detail['instrument_id']}'. Expected ({num_scenarios},), got {instrument_prices_this_instrument.shape}")

            portfolio_prices_per_scenario += instrument_prices_this_instrument * num_holdings

        return portfolio_prices_per_scenario




In [29]:
# workflow_manager.py
"""
Contains classes to manage the workflow of instrument processing (TFF calibration),
portfolio construction, and portfolio analytics (e.g., VaR).
Uses targeted scenario generation for TFF fitting.
Correctly handles convertible bond TFF factor configurations and pricer parameters.
Ensures TFFCalibrate is called with its simplified constructor.
"""
import json
from datetime import date, datetime
import numpy as np
import QuantLib as ql
import time
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm

# --- JSON Serialization Helpers ---
def portfolio_json_serializer(obj):
    if isinstance(obj, (datetime, date)): return obj.isoformat()
    if isinstance(obj, np.ndarray): return obj.tolist()
    if hasattr(obj, 'to_dict') and callable(obj.to_dict): return obj.to_dict()
    if isinstance(obj, ProductStaticBase): return obj.to_dict()
    if isinstance(obj, ql.Date): return date(obj.year(), obj.month(), obj.day()).isoformat()
    if isinstance(obj, ql.Calendar): return obj.name()
    if isinstance(obj, ql.DayCounter): return obj.name()
    if isinstance(obj, (np.float32, np.float64)): return float(obj)
    if isinstance(obj, (np.int32, np.int64)): return int(obj)
    if isinstance(obj, dict):
        return {k: portfolio_json_serializer(v) for k, v in obj.items()}
    raise TypeError(f"Object of type {obj.__class__.__name__} ({obj}) is not JSON serializable by custom serializer")

def reconstruct_product_static(product_dict: dict) -> ProductStaticBase:
    product_type = product_dict.get('product_type')
    if not product_type:
        raise ValueError("Product dictionary must contain a 'product_type' field.")
    if product_type == 'VanillaBond':
        return QuantLibBondStaticBase.from_dict(product_dict)
    elif product_type == 'CallableBond':
        return CallableBondStaticBase.from_dict(product_dict)
    elif product_type == 'ConvertibleBond':
        return ConvertibleBondStaticBase.from_dict(product_dict)
    elif product_type == 'EuropeanOption':
        return EuropeanOptionStatic.from_dict(product_dict)
    else:
        raise ValueError(f"Unknown product_type for reconstruction: {product_type}")

def generate_portfolio_specs_for_serialization(
    holdings_data: list[dict],
    model_registry: dict,
    instrument_definitions_data_for_pricer_params: list[dict] = None
    ) -> list[dict]:
    portfolio_specs_for_json = []
    if instrument_definitions_data_for_pricer_params is None:
        instrument_definitions_data_for_pricer_params = []

    for holding in holdings_data:
        instrument_id = holding.get("instrument_id")
        if not instrument_id:
            print(f"   Skipping holding due to missing instrument_id: {holding}")
            continue

        if instrument_id in model_registry and not model_registry[instrument_id].get('error'):
            entry = model_registry[instrument_id]
            spec_item = {
                "client_id": holding.get("client_id"),
                "instrument_id": instrument_id,
                "num_holdings": holding.get("num_holdings"),
                "pricing_engine_type": entry["pricing_method"].lower(),
                "product_static_object": entry["product_static_dict"]
            }
            if entry["pricing_method"] == 'TFF' and 'tff_model_dict' in entry:
                spec_item["direct_tff_config"] = {
                    "model_dict": entry["tff_model_dict"],
                    "raw_input_names": entry["tff_raw_input_names"],
                    "normalization_params": entry["tff_normalization_params"],
                    "option_feature_order": entry.get("tff_option_feature_order", 0)
                }
                if 'tff_fixed_pricer_params' in entry:
                    spec_item['pricer_params'] = entry['tff_fixed_pricer_params']

            if entry["pricing_method"] == 'FULL':
                if 'pricer_params' in entry:
                     spec_item['pricer_params'] = entry['pricer_params']
                else:
                    original_instrument_spec = next(
                        (item for item in instrument_definitions_data_for_pricer_params
                         if item.get("instrument_id") == instrument_id),
                        None
                    )
                    if original_instrument_spec and 'pricer_params' in original_instrument_spec:
                         spec_item['pricer_params'] = original_instrument_spec['pricer_params']

            portfolio_specs_for_json.append(spec_item)
        else:
            print(f"   Skipping instrument '{instrument_id}' for JSON spec generation: not in valid model_registry or had an error.")
    return portfolio_specs_for_json


# --- TFF Configuration Factory (Defined BEFORE InstrumentProcessor) ---
class TFFConfigurationFactory:
    def __init__(self, scenario_generator: SimpleRandomScenarioGenerator,
                 default_numeric_rate_tenors: np.ndarray):
        self.scenario_generator = scenario_generator
        self.default_numeric_rate_tenors = default_numeric_rate_tenors

    def _get_base_value(self, factor_name: str) -> float:
        """Helper to get base value from scenario generator's maps."""
        for m_map_name in ['base_rates_map', 'base_s0_map', 'base_vol_map']:
            m_map = getattr(self.scenario_generator, m_map_name, {})
            if factor_name in m_map:
                return m_map[factor_name]
        if hasattr(self.scenario_generator, 'base_s0_map') and factor_name in self.scenario_generator.base_s0_map:
             return self.scenario_generator.base_s0_map[factor_name]
        raise ValueError(f"Base value for TFF factor '{factor_name}' not found in scenario_generator's configured base maps.")

    def create_config(self, product_static: ProductStaticBase,
                      tff_behavior_params: dict = None,
                      instrument_pricer_params: dict = None) -> dict: # Added instrument_pricer_params
        if tff_behavior_params is None: tff_behavior_params = {}
        if instrument_pricer_params is None: instrument_pricer_params = {} # For BS rfr/div

        raw_names = []
        raw_base_values = []
        fixed_params_for_training = {}
        opt_feature_order = 0
        pricer_cfg_for_worker = {} # For worker pricer reconstruction

        if isinstance(product_static, EuropeanOptionStatic):
            if not product_static.underlying_symbol or not product_static.currency:
                raise ValueError("EuropeanOptionStatic needs 'underlying_symbol' and 'currency'.")
            s0_fn = f"{product_static.currency}_{product_static.underlying_symbol}_S0"
            vol_fn = f"{product_static.currency}_{product_static.underlying_symbol}_VOL"
            raw_names = [s0_fn, vol_fn]
            raw_base_values = [self._get_base_value(s0_fn), self._get_base_value(vol_fn)]
            opt_feature_order = tff_behavior_params.get('option_feature_order', 0)
            # BS pricer constructor takes r and q. These are from instrument_pricer_params
            pricer_cfg_for_worker['bs_pricer_config'] = {
                'risk_free_rate': instrument_pricer_params.get('bs_risk_free_rate'),
                'dividend_yield': instrument_pricer_params.get('bs_dividend_yield', 0.0)
            }
            if pricer_cfg_for_worker['bs_pricer_config']['risk_free_rate'] is None:
                raise ValueError("Missing 'bs_risk_free_rate' in pricer_params for EuropeanOption.")


        elif isinstance(product_static, QuantLibBondStaticBase):
            if not product_static.currency or not product_static.index_stub:
                raise ValueError("Bond product needs 'currency' and a non-empty 'index_stub'.")
            if self.default_numeric_rate_tenors is None or self.default_numeric_rate_tenors.size == 0:
                raise ValueError("default_numeric_rate_tenors needed for Bond TFF setup.")

            rate_factor_names = [f"{product_static.currency}_{product_static.index_stub}_{t:.2f}Y" for t in self.default_numeric_rate_tenors]
            base_rate_vals = [self._get_base_value(name) for name in rate_factor_names]
            raw_names.extend(rate_factor_names)
            raw_base_values.extend(base_rate_vals)

            pricer_cfg_for_worker['bond_pricer_config'] = { # Default for vanilla
                'method': 'discount', 'grid_steps': 100, 'convertible_engine_steps': 100
            }

            if isinstance(product_static, CallableBondStaticBase):
                pricer_cfg_for_worker['bond_pricer_config']['method'] = 'g2'
                pricer_cfg_for_worker['bond_pricer_config']['grid_steps'] = instrument_pricer_params.get('g2_grid_steps', 32)
                if instrument_pricer_params.get('g2_params'): # G2 params are fixed for TFF training
                    fixed_params_for_training['g2_params'] = instrument_pricer_params['g2_params']
                # If not in instrument_pricer_params, InstrumentProcessor will use its default_g2_params

            elif isinstance(product_static, ConvertibleBondStaticBase):
                pricer_cfg_for_worker['bond_pricer_config']['method'] = 'convertible_binomial'
                pricer_cfg_for_worker['bond_pricer_config']['convertible_engine_steps'] = instrument_pricer_params.get('conv_engine_steps', 50)

                if not product_static.underlying_symbol: raise ValueError("Convertible needs 'underlying_symbol'.")

                s0_fn_cb = f"{product_static.currency}_{product_static.underlying_symbol}_S0"
                if s0_fn_cb not in raw_names: # Ensure S0 is always a factor for CB TFFs
                    raw_names.append(s0_fn_cb)
                    raw_base_values.append(self._get_base_value(s0_fn_cb))

                conv_all_dynamic = tff_behavior_params.get('convertible_tff_market_inputs_as_factors', False)

                if conv_all_dynamic:
                    div_fn = f"{product_static.currency}_{product_static.underlying_symbol}_DIVYIELD"
                    vol_fn = f"{product_static.currency}_{product_static.underlying_symbol}_EQVOL"
                    cs_fn = f"{product_static.currency}_{product_static.underlying_symbol}_CS"
                    raw_names.extend([div_fn, vol_fn, cs_fn])
                    raw_base_values.extend([self._get_base_value(div_fn), self._get_base_value(vol_fn), self._get_base_value(cs_fn)])
                else: # S0 and Rates are dynamic, others are fixed
                    # fixed_cb_params should come from instrument_spec['pricer_params']
                    fixed_cb_p = tff_behavior_params.get('fixed_cb_params', {})
                    fixed_params_for_training['dividend_yield'] = fixed_cb_p.get('dividend_yield')
                    fixed_params_for_training['equity_volatility'] = fixed_cb_p.get('equity_volatility')
                    fixed_params_for_training['credit_spread'] = fixed_cb_p.get('credit_spread')
                    # s0_val is NOT part of fixed_params_for_training here as S0 is a dynamic TFF factor
                    if any(v is None for k,v in fixed_params_for_training.items() if k in ['dividend_yield', 'equity_volatility', 'credit_spread']):
                        raise ValueError(f"Missing fixed CB params (div,eq_vol,cs) when S0 is dynamic but others fixed. Got: {fixed_cb_p}")
        else:
            raise TypeError(f"Unsupported product type for TFF Configuration: {type(product_static)}")

        return {
            "tff_input_raw_factor_names": raw_names,
            "tff_input_raw_base_values": np.array(raw_base_values),
            "fixed_pricer_params_for_tff_training": fixed_params_for_training,
            "option_feature_order": opt_feature_order,
            "pricer_config_for_worker": pricer_cfg_for_worker, # For reconstructing pricer in worker
            "actual_rate_pillars": _parse_numeric_pillars_from_factor_names(raw_names) if isinstance(product_static, QuantLibBondStaticBase) else np.array([])
        }


# --- Workflow Classes ---
class InstrumentProcessor:
    def __init__(self, scenario_generator: SimpleRandomScenarioGenerator,
                 global_valuation_date: date,
                 default_numeric_rate_tenors: np.ndarray = None,
                 default_g2_params = None,
                 default_bs_risk_free_rate: float = 0.025,
                 default_bs_dividend_yield: float = 0.0,
                 parallel_workers_tff: int = None,
                 n_scenarios_for_tff_domain: int = 1000
                 ):
        self.scenario_generator = scenario_generator
        self.global_valuation_date = global_valuation_date
        self.default_numeric_rate_tenors = default_numeric_rate_tenors
        self.default_g2_params = default_g2_params
        self.default_bs_risk_free_rate = default_bs_risk_free_rate
        self.default_bs_dividend_yield = default_bs_dividend_yield
        self.num_instrument_processing_workers = parallel_workers_tff if parallel_workers_tff else 0
        self.n_scenarios_for_tff_domain = n_scenarios_for_tff_domain

        self.model_registry = {}
        self.tff_config_factory = TFFConfigurationFactory(
            scenario_generator=self.scenario_generator,
            default_numeric_rate_tenors=self.default_numeric_rate_tenors
        )
        self.tff_config_factory.default_g2_params = default_g2_params # Pass G2 default to factory if needed

    def _create_pricer_template(self, product_static: ProductStaticBase, instrument_spec: dict):
        pricer_params = instrument_spec.get('pricer_params', {})
        if isinstance(product_static, EuropeanOptionStatic):
            rfr = pricer_params.get('bs_risk_free_rate', self.default_bs_risk_free_rate)
            div = pricer_params.get('bs_dividend_yield', self.default_bs_dividend_yield)
            return BlackScholesPricer(product_static, risk_free_rate=rfr, dividend_yield=div)
        elif isinstance(product_static, CallableBondStaticBase):
            grid_steps = pricer_params.get('g2_grid_steps', 32)
            return QuantLibBondPricer(product_static, method='g2', grid_steps=grid_steps)
        elif isinstance(product_static, ConvertibleBondStaticBase):
            engine_steps = pricer_params.get('conv_engine_steps', 50)
            return QuantLibBondPricer(product_static, method='convertible_binomial', convertible_engine_steps=engine_steps)
        elif isinstance(product_static, QuantLibBondStaticBase):
            if instrument_spec.get('pricer_type_preference', 'QuantLib').upper() == 'FAST':
                 return FastBondPricer(product_static)
            return QuantLibBondPricer(product_static, method='discount')
        else:
            raise ValueError(f"Unsupported product type for pricer template: {type(product_static)}")

    def _get_scenario_slice(self, all_scenarios, all_factor_names, target_factor_names_for_tff):
        if not target_factor_names_for_tff:
            return np.array([]).reshape(all_scenarios.shape[0],0)
        try:
            global_indices_map = {name: i for i, name in enumerate(all_factor_names)}
            ordered_indices = [global_indices_map[name] for name in target_factor_names_for_tff]
            return all_scenarios[:, ordered_indices]
        except KeyError as e:
            missing_factor = str(e).strip("'")
            raise ValueError(
                f"Error slicing scenarios: Factor name '{missing_factor}' required by TFF "
                f"not found in generated scenario factor names. "
                f"Required by TFF: {target_factor_names_for_tff}, "
                f"Available from generator: {all_factor_names}."
            )
        except Exception as e:
            raise RuntimeError(f"General error during scenario slicing for TFF factors {target_factor_names_for_tff}: {e}")

    def _process_single_instrument_spec(self, args_tuple):
        instrument_spec, global_market_scenarios, global_factor_names, ql_val_date_iso = args_tuple

        val_d_worker = date.fromisoformat(ql_val_date_iso)
        ql.Settings.instance().evaluationDate = ql.Date(val_d_worker.day, val_d_worker.month, val_d_worker.year)

        instrument_id = instrument_spec.get('instrument_id')
        product_type_str = instrument_spec.get('product_type')
        params = instrument_spec.get('params', {})
        pricing_preference = instrument_spec.get('pricing_preference', 'FULL').upper()

        registry_entry = {'instrument_id': instrument_id, 'pricing_method': pricing_preference}
        if 'valuation_date' not in params: params['valuation_date'] = self.global_valuation_date
        if 'product_type' not in params: params['product_type'] = product_type_str

        try:
            product_static_object = reconstruct_product_static(params)
            registry_entry['product_static_dict'] = product_static_object.to_dict()
        except Exception as e:
            print(f"    ERROR creating product static for {instrument_id} in worker: {e}")
            registry_entry.update({'error': str(e), 'pricing_method': 'ERROR'})
            return instrument_id, registry_entry

        pricer_template = self._create_pricer_template(product_static_object, instrument_spec)
        if 'pricer_params' in instrument_spec: registry_entry['pricer_params'] = instrument_spec['pricer_params']

        if pricing_preference == 'TFF':
            tff_config_from_spec = instrument_spec.get('tff_config', {})
            factory_behavior_params = tff_config_from_spec.copy()
            # Pass pricer_params from instrument_spec to factory for fixed_cb_params or fixed_bs_params
            factory_behavior_params['fixed_cb_params'] = instrument_spec.get('pricer_params', {})
            factory_behavior_params['fixed_bs_params'] = instrument_spec.get('pricer_params', {})


            try:
                #print(f"    Calibrating TFF for {instrument_id} (in worker)...")
                tff_inputs = self.tff_config_factory.create_config(
                    product_static=product_static_object,
                    tff_behavior_params=factory_behavior_params,
                    instrument_pricer_params=instrument_spec.get('pricer_params', {}) # Pass for BS rfr/div
                )

                tff_sample_fit_parallel_workers = False

                tff_calibrator = TensorFunctionalFormCalibrate(
                    pricer_template=pricer_template,
                    tff_input_raw_factor_names=tff_inputs["tff_input_raw_factor_names"],
                    tff_input_raw_base_values=tff_inputs["tff_input_raw_base_values"],
                    product_static_params_for_worker=product_static_object.to_dict(), # Use 'product_static_params_for_worker'
                    pricer_config_for_worker=tff_inputs["pricer_config_for_worker"], # Use 'pricer_config_for_worker'
                    actual_rate_pillars=tff_inputs["actual_rate_pillars"] # Use 'actual_rate_pillars'
                )

                scenarios_for_this_tff = self._get_scenario_slice(global_market_scenarios, global_factor_names, tff_inputs["tff_input_raw_factor_names"])
                if scenarios_for_this_tff.size == 0 and tff_inputs["tff_input_raw_factor_names"]:
                     raise ValueError("Empty scenario slice for TFF fitting.")

                s_t = time.time()
                model_tff, _, _, rmse_tff, norm_params_tff = tff_calibrator.sample_and_fit(
                    full_market_scenarios_for_tff_factors=scenarios_for_this_tff,
                    n_train=tff_config_from_spec.get('n_train', 64),
                    n_test=tff_config_from_spec.get('n_test', 10),
                    random_seed=instrument_spec.get('seed', 42),
                    parallel_workers=tff_sample_fit_parallel_workers,
                    option_feature_order=tff_inputs["option_feature_order"],
                    **tff_inputs["fixed_pricer_params_for_tff_training"]
                )
                fit_time = time.time() - s_t
                if model_tff and norm_params_tff:
                    registry_entry.update({
                        'tff_model_dict': model_tff.to_dict(),
                        'tff_raw_input_names': tff_inputs["tff_input_raw_factor_names"],
                        'tff_normalization_params': norm_params_tff,
                        'tff_option_feature_order': tff_inputs["option_feature_order"],
                        'tff_rmse': rmse_tff, 'tff_fit_time_seconds': fit_time
                    })
                    if tff_inputs["fixed_pricer_params_for_tff_training"]:
                        registry_entry['tff_fixed_pricer_params'] = tff_inputs["fixed_pricer_params_for_tff_training"]
                    #print(f"      TFF for {instrument_id} fitted. RMSE: {rmse_tff:.6f}, Time: {fit_time:.2f}s")
                else: raise RuntimeError("TFF fitting returned None for model or norm_params.")
            except Exception as e:
                print(f"    ERROR during TFF calibration for {instrument_id} in worker: {e}")
                registry_entry.update({'pricing_method': 'FULL', 'error_tff_calibration': str(e)})

        return instrument_id, registry_entry


    def process_instruments(self, instrument_definitions: list[dict],
                            global_market_scenarios: np.ndarray,
                            global_factor_names: list[str]):
        print(f"Processing {len(instrument_definitions)} instrument definitions...")
        worker_args_list = [(spec, global_market_scenarios, global_factor_names, self.global_valuation_date.isoformat()) for spec in instrument_definitions]

        if self.num_instrument_processing_workers > 0 and len(instrument_definitions) > 1:
            print(f"  Processing instruments in parallel (workers={self.num_instrument_processing_workers})...")
            with ProcessPoolExecutor(max_workers=self.num_instrument_processing_workers) as executor:
                futures = [executor.submit(self._process_single_instrument_spec, args) for args in worker_args_list]
                for future in tqdm(as_completed(futures), total=len(instrument_definitions), desc="Processing Instruments"):
                    try:
                        instrument_id, registry_entry = future.result()
                        if instrument_id: self.model_registry[instrument_id] = registry_entry
                    except Exception as e: print(f"    ERROR processing an instrument in parallel (future result): {e}")
        else:
            print("  Processing instruments sequentially...")
            for args_tuple in tqdm(worker_args_list, total=len(instrument_definitions), desc="Processing Instruments"):
                try:
                    instrument_id, registry_entry = self._process_single_instrument_spec(args_tuple)
                    if instrument_id: self.model_registry[instrument_id] = registry_entry
                except Exception as e: print(f"    CRITICAL ERROR processing instrument spec {args_tuple[0].get('instrument_id', 'Unknown')}: {e}")
        print("Finished processing instrument definitions.")
        return self.model_registry

    def save_model_registry(self, filepath: str):
        print(f"Saving model registry to {filepath}...")
        try:
            with open(filepath, 'w') as f: json.dump(self.model_registry, f, indent=4, default=portfolio_json_serializer)
            print("  Model registry saved successfully.")
        except Exception as e: print(f"  ERROR saving model registry: {e}")

    @classmethod
    def load_model_registry(cls, filepath: str) -> dict:
        print(f"Loading model registry from {filepath}...")
        try:
            with open(filepath, 'r') as f: registry = json.load(f)
            print("  Model registry loaded successfully."); return registry
        except Exception as e: print(f"  ERROR loading model registry: {e}"); return {}


class PortfolioBuilder:
    def __init__(self, model_registry: dict = None):
        self.model_registry = model_registry if model_registry is not None else {}
        self.uncalculated_instruments = []

    def build_portfolios_from_specs(self, portfolio_specs_list: list[dict],
                                       global_valuation_date: date,
                                       default_g2_params=None,
                                       default_bs_rfr=0.025, default_bs_div=0.0
                                       ) -> dict[str, Portfolio]:
        print(f"Building portfolios from {len(portfolio_specs_list)} detailed specifications...")
        portfolios = {}
        self.uncalculated_instruments = []

        for spec_idx, spec in enumerate(portfolio_specs_list):
            client_id = spec.get('client_id')
            instrument_id = spec.get('instrument_id')
            num_holdings = spec.get('num_holdings')
            product_static_dict_from_spec = spec.get('product_static_object')
            pricing_method_from_spec = spec.get('pricing_engine_type', 'full').lower()
            direct_tff_config_from_spec = spec.get('direct_tff_config')
            pricer_params_from_spec = spec.get('pricer_params', {})

            if not client_id or not instrument_id or num_holdings is None or product_static_dict_from_spec is None:
                print(f"  Skipping spec at index {spec_idx}: missing essential fields.")
                continue

            try:
                if 'valuation_date' not in product_static_dict_from_spec:
                    product_static_dict_from_spec['valuation_date'] = global_valuation_date
                product_static_object = reconstruct_product_static(product_static_dict_from_spec)
            except Exception as e:
                print(f"  ERROR reconstructing product static for '{instrument_id}' from spec: {e}. Skipping.")
                if instrument_id not in self.uncalculated_instruments: self.uncalculated_instruments.append(instrument_id)
                continue

            if client_id not in portfolios: portfolios[client_id] = Portfolio()
            portfolio_instance = portfolios[client_id]

            final_pricing_method = pricing_method_from_spec
            final_direct_tff_config = direct_tff_config_from_spec
            final_full_pricer_instance = None
            final_pricer_kwargs = {}
            if pricer_params_from_spec: final_pricer_kwargs.update(pricer_params_from_spec)

            if final_pricing_method == 'tff':
                if not final_direct_tff_config:
                    if instrument_id in self.model_registry and self.model_registry[instrument_id].get('pricing_method', '').upper() == 'TFF':
                        entry = self.model_registry[instrument_id]
                        if all(k in entry for k in ['tff_model_dict', 'tff_raw_input_names', 'tff_normalization_params']):
                            final_direct_tff_config = {
                                'model_dict': entry['tff_model_dict'],
                                'raw_input_names': entry['tff_raw_input_names'],
                                'normalization_params': entry['tff_normalization_params'],
                                'option_feature_order': entry.get('tff_option_feature_order', 0)
                            }
                            if 'tff_fixed_pricer_params' in entry:
                                final_pricer_kwargs.update(entry['tff_fixed_pricer_params'])
                        else:
                            print(f"  WARNING: TFF data incomplete for '{instrument_id}' in registry. Fallback to FULL.")
                            final_pricing_method = 'full'
                    else:
                        print(f"  WARNING: TFF spec for '{instrument_id}' missing direct_tff_config and not found as TFF in registry. Fallback to FULL.")
                        final_pricing_method = 'full'
                elif isinstance(product_static_object, ConvertibleBondStaticBase) and final_direct_tff_config:
                     final_pricer_kwargs.update(pricer_params_from_spec)


            if final_pricing_method == 'full':
                current_pricer_params = final_pricer_kwargs
                try:
                    if isinstance(product_static_object, EuropeanOptionStatic):
                        rfr = current_pricer_params.get('bs_risk_free_rate', default_bs_rfr)
                        div = current_pricer_params.get('bs_dividend_yield', default_bs_div)
                        final_full_pricer_instance = BlackScholesPricer(product_static_object, rfr, div)
                    elif isinstance(product_static_object, CallableBondStaticBase):
                        grid_steps = current_pricer_params.get('g2_grid_steps', 32)
                        final_full_pricer_instance = QuantLibBondPricer(product_static_object, method='g2', grid_steps=grid_steps)
                        if current_pricer_params.get('g2_params', default_g2_params):
                             final_pricer_kwargs['g2_params'] = current_pricer_params.get('g2_params', default_g2_params)
                    elif isinstance(product_static_object, ConvertibleBondStaticBase):
                        engine_steps = current_pricer_params.get('conv_engine_steps', 50)
                        final_full_pricer_instance = QuantLibBondPricer(product_static_object, method='convertible_binomial', convertible_engine_steps=engine_steps)
                        cb_full_kwargs_needed = {
                            's0_val': current_pricer_params.get('s0_val', current_pricer_params.get('initial_stock_price')),
                            'dividend_yield': current_pricer_params.get('dividend_yield'),
                            'equity_volatility': current_pricer_params.get('equity_volatility'),
                            'credit_spread': current_pricer_params.get('credit_spread')
                        }
                        if any(val is None for val in cb_full_kwargs_needed.values()):
                            raise ValueError(f"Missing required pricer_params for FULL CB pricing of {instrument_id}.")
                        final_pricer_kwargs.update(cb_full_kwargs_needed)
                    elif isinstance(product_static_object, QuantLibBondStaticBase):
                        final_full_pricer_instance = QuantLibBondPricer(product_static_object, method='discount')
                    else: raise ValueError("Unknown product type for full pricer reconstruction.")
                except Exception as e_pricer:
                    print(f"  WARNING: Cannot create full pricer for '{instrument_id}': {e_pricer}. Skipping.")
                    if instrument_id not in self.uncalculated_instruments: self.uncalculated_instruments.append(instrument_id)
                    continue

            try:
                portfolio_instance.add_position(
                    instrument_id=instrument_id, product_static=product_static_object,
                    num_holdings=num_holdings, pricing_engine_type=final_pricing_method,
                    direct_tff_config=final_direct_tff_config if final_pricing_method == 'tff' else None,
                    full_pricer_instance=final_full_pricer_instance if final_pricing_method == 'full' else None,
                    full_pricer_kwargs=final_pricer_kwargs
                )
            except Exception as e:
                print(f"  ERROR adding position '{instrument_id}' to portfolio for '{client_id}': {e}")
                if instrument_id not in self.uncalculated_instruments: self.uncalculated_instruments.append(instrument_id)

        if self.uncalculated_instruments:
            print(f"  Summary: Uncalculated instruments during build_portfolios_from_specs: {self.uncalculated_instruments}")
        print(f"Finished building {len(portfolios)} portfolios from detailed specs.")
        return portfolios


class PortfolioAnalytics:
    def __init__(self,
                 client_portfolios: dict[str, Portfolio],
                 global_market_scenarios: np.ndarray,
                 global_factor_names: list[str],
                 numeric_rate_tenors: np.ndarray,
                 scenario_generator_for_base_values: SimpleRandomScenarioGenerator
                 ):
        self.client_portfolios = client_portfolios
        self.global_market_scenarios = global_market_scenarios
        self.global_factor_names = global_factor_names
        self.numeric_rate_tenors = numeric_rate_tenors
        self.scenario_generator_for_base_values = scenario_generator_for_base_values
        self.results = {}

    def calculate_base_portfolio_values(self) -> dict[str, float]:
        base_values = {}
        base_value_scenario_list = []
        sg_for_base = self.scenario_generator_for_base_values
        for factor_name in self.global_factor_names:
            val_found = False
            for current_map_name in ['base_rates_map', 'base_s0_map', 'base_vol_map']:
                current_map = getattr(sg_for_base, current_map_name, {})
                if factor_name in current_map:
                    base_value_scenario_list.append(current_map[factor_name])
                    val_found = True
                    break
            if not val_found:
                if hasattr(sg_for_base, 'base_s0_map') and factor_name in sg_for_base.base_s0_map:
                     base_value_scenario_list.append(sg_for_base.base_s0_map[factor_name]); val_found = True
                elif hasattr(sg_for_base, 'base_vol_map') and factor_name in sg_for_base.base_vol_map:
                     base_value_scenario_list.append(sg_for_base.base_vol_map[factor_name]); val_found = True
            if not val_found:
                print(f"Warning: Factor '{factor_name}' for base value not found in generator base maps. Using 0.0.")
                base_value_scenario_list.append(0.0)
        base_value_scenario_np = np.array([base_value_scenario_list])

        for client_id, portfolio_obj in self.client_portfolios.items():
            if portfolio_obj.positions:
                try:
                    base_val = portfolio_obj.price_portfolio(
                        raw_market_scenarios=base_value_scenario_np,
                        scenario_factor_names=self.global_factor_names,
                        portfolio_rate_pillar_times=self.numeric_rate_tenors
                    )[0]
                    base_values[client_id] = base_val
                except Exception as e:
                    print(f"  ERROR calculating base value for portfolio {client_id}: {e}")
                    base_values[client_id] = np.nan
            else:
                base_values[client_id] = 0.0
        return base_values

    def run_var_analysis(self, var_percentiles: list[float] = None):
        if var_percentiles is None:
            var_percentiles = [1.0, 5.0]

        print(f"Running VaR Analysis for percentiles: {[f'{(100-p):.0f}%' for p in var_percentiles]}")

        base_portfolio_values = self.calculate_base_portfolio_values()
        self.results = {}

        for client_id, portfolio_obj in self.client_portfolios.items():
            client_results = {'base_value': base_portfolio_values.get(client_id, np.nan)}
            if portfolio_obj.positions and not np.isnan(client_results['base_value']):
                print(f"  Analyzing portfolio for {client_id}...")
                try:
                    portfolio_values_scenarios = portfolio_obj.price_portfolio(
                        raw_market_scenarios=self.global_market_scenarios,
                        scenario_factor_names=self.global_factor_names,
                        portfolio_rate_pillar_times=self.numeric_rate_tenors
                    )

                    client_results['mean_scenario_value'] = np.mean(portfolio_values_scenarios)
                    client_results['std_dev_scenario_value'] = np.std(portfolio_values_scenarios)
                    pnl_distribution = portfolio_values_scenarios - client_results['base_value']
                    client_results['pnl_distribution_mean'] = np.mean(pnl_distribution)
                    client_results['pnl_distribution_std_dev'] = np.std(pnl_distribution)

                    vars_calculated = {}
                    for p in var_percentiles:
                        var_value = np.percentile(pnl_distribution, p)
                        vars_calculated[f"var_{(100-p):.0f}pct"] = -var_value
                    client_results['var_values'] = vars_calculated

                    print(f"    Client {client_id}: Base Value={client_results['base_value']:.2f}, "
                          f"Mean Scen. Value={client_results['mean_scenario_value']:.2f}, "
                          f"VaRs: {vars_calculated}")

                except Exception as e:
                    print(f"    ERROR during VaR analysis for portfolio {client_id}: {e}")
                    client_results['error_var_analysis'] = str(e)
            else:
                if not portfolio_obj.positions: print(f"  Portfolio for {client_id} is empty, skipping VaR.")
                else: print(f"  Base value for portfolio {client_id} could not be calculated (was NaN), skipping VaR.")
                client_results['var_values'] = {f"var_{(100-p):.0f}pct": np.nan for p in var_percentiles}
            self.results[client_id] = client_results
        return self.results


In [30]:
# main_demonstration.py
"""
Main script demonstrating the new workflow using InstrumentProcessor, PortfolioBuilder,
and PortfolioAnalytics.
"""
import QuantLib as ql
import numpy as np
from datetime import date, datetime
from dateutil.relativedelta import relativedelta
import time
import os
import json

def run_demonstration(
    enable_parallel_tff_fitting: bool = True,
    use_hardcoded_g2_params: bool = True
    ):
    print(f"--- FastRiskDemo with Workflow Manager (Parallel TFF: {enable_parallel_tff_fitting}) ---")

    # --- Global Setup ---
    val_d = date(2025, 5, 18)
    numeric_rate_tenors = np.array([0.25, 0.5, 1.0, 2.0, 3.0, 5.0, 7.0, 10.0])
    DEMO_CURRENCY = "USD"
    DEMO_RATE_INDEX_STUB = "IR"
    OPT_UNDERLYING_SYMBOL = "DEMO_OPT_STOCK"
    CONV_UNDERLYING_SYMBOL = "DEMO_CONV_STOCK"

    default_g2_p = (0.01, 0.003, 0.015, 0.006, -0.75) if use_hardcoded_g2_params else None

    base_demo_currency_rates_values = np.array([0.020, 0.021, 0.022, 0.025, 0.027, 0.030, 0.032, 0.033])
    base_rates_map = {f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_{t:.2f}Y": base_demo_currency_rates_values[i] for i, t in enumerate(numeric_rate_tenors)}

    opt_s0_factor = f"{DEMO_CURRENCY}_{OPT_UNDERLYING_SYMBOL}_S0"
    opt_vol_factor = f"{DEMO_CURRENCY}_{OPT_UNDERLYING_SYMBOL}_VOL"
    conv_s0_factor = f"{DEMO_CURRENCY}_{CONV_UNDERLYING_SYMBOL}_S0"
    conv_vol_factor = f"{DEMO_CURRENCY}_{CONV_UNDERLYING_SYMBOL}_EQVOL"
    conv_div_factor = f"{DEMO_CURRENCY}_{CONV_UNDERLYING_SYMBOL}_DIVYIELD"
    conv_cs_factor = f"{DEMO_CURRENCY}_{CONV_UNDERLYING_SYMBOL}_CS"

    base_s0_map = {opt_s0_factor: 100.0, conv_s0_factor: 100.0}
    base_vol_map = {opt_vol_factor: 0.25, conv_vol_factor: 0.25}

    base_other_factors_map_for_conv = {
        conv_div_factor: 0.01,
        conv_cs_factor: 0.015
    }
    temp_s0_map_for_gen = {**base_s0_map, **base_other_factors_map_for_conv}
    temp_vol_map_for_gen = {**base_vol_map}

    scenario_gen_global = SimpleRandomScenarioGenerator(
        base_rates_map=base_rates_map,
        base_s0_map=temp_s0_map_for_gen,
        base_vol_map=temp_vol_map_for_gen,
        random_seed=42
    )
    N_GLOBAL_SCENARIOS = 100
    global_market_scenarios, global_factor_names = scenario_gen_global.generate_scenarios(N_GLOBAL_SCENARIOS)
    print(f"Generated {N_GLOBAL_SCENARIOS} global market scenarios with factors: {global_factor_names}")


    # --- Step 1: Instrument Processing and TFF Calibration ---
    print("\\n--- Step 1: Instrument Processing & TFF Calibration ---")

    default_conv_fixed_pricer_params = {
        's0_val': 100.0,
        'dividend_yield': 0.01,
        'equity_volatility': 0.25,
        'credit_spread': 0.015
    }

    instrument_definitions_data = [
        {
            "instrument_id": f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_VANILLA_5Y",
            "product_type": "VanillaBond", "pricing_preference": "TFF",
            "params": { "valuation_date": val_d, "maturity_date": (val_d + relativedelta(years=5)),
                        "coupon_rate": 0.03, "face_value": 100.0, "currency": DEMO_CURRENCY,
                        "index_stub": DEMO_RATE_INDEX_STUB, "freq": 2, "settlement_days": 0 },
            "tff_config": {"n_train": 64, "n_test": 10}
        },
        {
            "instrument_id": f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_CALLABLE_5Y_G2",
            "product_type": "CallableBond", "pricing_preference": "TFF",
            "params": { "valuation_date": val_d, "maturity_date": (val_d + relativedelta(years=5)),
                        "coupon_rate": 0.032, "face_value": 100.0, "currency": DEMO_CURRENCY,
                        "index_stub": DEMO_RATE_INDEX_STUB, "freq": 2,
                        "call_dates": [(val_d + relativedelta(years=y)) for y in [2,3,4]],
                        "call_prices": [102.0, 101.0, 100.0]},
            "pricer_params": {"g2_params": default_g2_p},
            "tff_config": {"n_train": 128, "n_test": 10}
        },
        {
            "instrument_id": f"{DEMO_CURRENCY}_{CONV_UNDERLYING_SYMBOL}_CONV_BOND_5Y_S0_DYNAMIC",
            "product_type": "ConvertibleBond", "pricing_preference": "TFF",
            "params": {
                'valuation_date': val_d, 'issue_date': (val_d - relativedelta(months=6)),
                'maturity_date': (val_d + relativedelta(years=5, months=-6)), 'coupon_rate': 0.02,
                'conversion_ratio': 20.0, 'face_value': 100.0, 'currency': DEMO_CURRENCY,
                'index_stub': DEMO_RATE_INDEX_STUB, 'underlying_symbol': CONV_UNDERLYING_SYMBOL, 'freq': 2
            },
            "pricer_params": default_conv_fixed_pricer_params,
            "tff_config": {"n_train": 128, "n_test": 10,
                           "convertible_tff_market_inputs_as_factors": False
                          }
        },
        {
            "instrument_id": f"{DEMO_CURRENCY}_{CONV_UNDERLYING_SYMBOL}_CONV_BOND_5Y_ALL_DYNAMIC",
            "product_type": "ConvertibleBond", "pricing_preference": "TFF",
            "params": {
                'valuation_date': val_d, 'issue_date': (val_d - relativedelta(months=6)),
                'maturity_date': (val_d + relativedelta(years=5, months=-6)), 'coupon_rate': 0.02,
                'conversion_ratio': 20.0, 'face_value': 100.0, 'currency': DEMO_CURRENCY,
                'index_stub': DEMO_RATE_INDEX_STUB, 'underlying_symbol': CONV_UNDERLYING_SYMBOL, 'freq': 2
            },
            "tff_config": {"n_train": 128, "n_test": 10,
                           "convertible_tff_market_inputs_as_factors": True
                          }
        },
        {
            "instrument_id": f"{DEMO_CURRENCY}_{OPT_UNDERLYING_SYMBOL}_EURO_CALL_1Y_STRIKE105_ORD2",
            "product_type": "EuropeanOption", "pricing_preference": "TFF",
            "params": { 'valuation_date': val_d, 'expiry_date': (val_d + relativedelta(years=1)),
                        'strike_price': 105.0, 'option_type': 'call',
                        'currency': DEMO_CURRENCY, 'underlying_symbol': OPT_UNDERLYING_SYMBOL,
            }, "pricer_params": { 'bs_risk_free_rate': 0.025, 'bs_dividend_yield': 0.01 },
            "tff_config": {"n_train": 128, "n_test": 10, "option_feature_order": 2}
        },
        { "instrument_id": f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_VANILLA_10Y_FULL",
          "product_type": "VanillaBond", "pricing_preference": "FULL",
          "params": { "valuation_date": val_d, "maturity_date": (val_d + relativedelta(years=10)),
                      "coupon_rate": 0.035, "face_value": 100.0, "currency": DEMO_CURRENCY,
                      "index_stub": DEMO_RATE_INDEX_STUB, "freq": 2 }}
    ]

    instrument_processor = InstrumentProcessor(
        scenario_generator=scenario_gen_global, global_valuation_date=val_d,
        default_numeric_rate_tenors=numeric_rate_tenors, default_g2_params=default_g2_p,
        default_bs_risk_free_rate=0.025, default_bs_dividend_yield=0.01,
        parallel_workers_tff=os.cpu_count() if enable_parallel_tff_fitting else False,
        n_scenarios_for_tff_domain=500 )

    model_registry = instrument_processor.process_instruments(
        instrument_definitions_data, global_market_scenarios, global_factor_names )

    instrument_processor.save_model_registry("model_registry.json")

    # --- Step 2: Portfolio Construction ---
    print("\\n--- Step 2: Portfolio Construction ---")
    holdings_data = [
        {"client_id": "ClientA", "instrument_id": f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_VANILLA_5Y", "num_holdings": 100},
        {"client_id": "ClientA", "instrument_id": f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_CALLABLE_5Y_G2", "num_holdings": 50},
        {"client_id": "ClientB", "instrument_id": f"{DEMO_CURRENCY}_{CONV_UNDERLYING_SYMBOL}_CONV_BOND_5Y_S0_DYNAMIC", "num_holdings": 35},
        {"client_id": "ClientB", "instrument_id": f"{DEMO_CURRENCY}_{CONV_UNDERLYING_SYMBOL}_CONV_BOND_5Y_ALL_DYNAMIC", "num_holdings": 40},
        {"client_id": "ClientB", "instrument_id": f"{DEMO_CURRENCY}_{OPT_UNDERLYING_SYMBOL}_EURO_CALL_1Y_STRIKE105_ORD2", "num_holdings": 200},
        {"client_id": "ClientA", "instrument_id": f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_VANILLA_10Y_FULL", "num_holdings": 80},
        {"client_id": "ClientA", "instrument_id": "MISSING_INSTRUMENT_ID_EXAMPLE", "num_holdings": 10}
    ]

    initial_portfolio_specs = generate_portfolio_specs_for_serialization(
        holdings_data=holdings_data, model_registry=model_registry,
        instrument_definitions_data_for_pricer_params=instrument_definitions_data )

    portfolio_builder_initial = PortfolioBuilder(model_registry)
    client_portfolios = portfolio_builder_initial.build_portfolios_from_specs(
        portfolio_specs_list=initial_portfolio_specs, global_valuation_date=val_d,
        default_g2_params=default_g2_p, default_bs_rfr=0.025, default_bs_div=0.01 )

    if portfolio_builder_initial.uncalculated_instruments:
        print(f"  WARNING (Initial Build): Uncalculated instruments: {portfolio_builder_initial.uncalculated_instruments}")

    # --- Step 3: Portfolio Pricing / VaR Calculation (using PortfolioAnalytics) ---
    print("\\n--- Step 3: Portfolio Pricing / VaR (using PortfolioAnalytics) ---")
    if client_portfolios:
        portfolio_analyzer = PortfolioAnalytics(
            client_portfolios=client_portfolios, global_market_scenarios=global_market_scenarios,
            global_factor_names=global_factor_names, numeric_rate_tenors=numeric_rate_tenors,
            scenario_generator_for_base_values=scenario_gen_global )
        var_results = portfolio_analyzer.run_var_analysis(var_percentiles=[1.0, 5.0])
    else: print("  No client portfolios were built, skipping VaR analysis.")

    # --- Step 4: JSON Serialization/Deserialization Demo for Portfolio Specs ---
    print("\\n--- Step 4: Portfolio JSON Serialization/Deserialization Demo ---")
    portfolio_json_string = None
    if initial_portfolio_specs:
        try:
            portfolio_json_string = json.dumps(initial_portfolio_specs, indent=4, default=portfolio_json_serializer)
            print("\\n--- Portfolio Specifications (JSON Serialized) ---")
            if initial_portfolio_specs:
                 print("Sample of first item in JSON specs:"); print(json.dumps(initial_portfolio_specs[0], indent=4, default=portfolio_json_serializer))
        except Exception as e: print(f"   ERROR serializing portfolio specs to JSON: {e}"); portfolio_json_string = None
    else: print("\\n--- No valid portfolio specifications to serialize to JSON ---")

    if portfolio_json_string:
        print("\\n--- Loading and Pricing Portfolio from JSON String ---")
        try:
            loaded_portfolio_specs_from_str = json.loads(portfolio_json_string)
            portfolio_builder_from_json = PortfolioBuilder(model_registry)
            client_portfolios_from_json = portfolio_builder_from_json.build_portfolios_from_specs(
                portfolio_specs_list=loaded_portfolio_specs_from_str, global_valuation_date=val_d,
                default_g2_params=default_g2_p, default_bs_rfr=0.025, default_bs_div=0.01 )
            if portfolio_builder_from_json.uncalculated_instruments:
                 print(f"  WARNING (JSON Load): Uncalculated instruments: {portfolio_builder_from_json.uncalculated_instruments}")
            if client_portfolios_from_json:
                reloaded_portfolio_analyzer = PortfolioAnalytics(
                    client_portfolios=client_portfolios_from_json, global_market_scenarios=global_market_scenarios,
                    global_factor_names=global_factor_names, numeric_rate_tenors=numeric_rate_tenors,
                    scenario_generator_for_base_values=scenario_gen_global )
                print("  Results for reloaded portfolio from JSON:")
                reloaded_var_results = reloaded_portfolio_analyzer.run_var_analysis(var_percentiles=[1.0, 5.0])
            else: print("  No client portfolios were built from JSON specs.")
        except Exception as e:
            print(f"   ERROR loading or pricing portfolio from JSON string: {e}")
            import traceback; traceback.print_exc()

    print("\\n--- End of Demonstration ---")

if __name__ == "__main__":
    try:
        print(f"QuantLib version: {ql.__version__}")
        run_demonstration( enable_parallel_tff_fitting=False)
    except NameError as e:
        if any(cn in str(e) for cn in ['ProductStaticBase','InstrumentProcessor','PortfolioBuilder','PortfolioAnalytics']):
            print(f"ERROR: Class not defined. Ensure all notebook cells for class definitions are executed. Details: {e}")
        elif 'QuantLib' in str(e) or 'ql' in str(e): print("ERROR: QuantLib not found/imported.")
        else: print(f"A NameError: {e}"); import traceback; traceback.print_exc()
    except Exception as e:
        print(f"An unexpected error: {e}"); import traceback; traceback.print_exc()


QuantLib version: 1.38
--- FastRiskDemo with Workflow Manager (Parallel TFF: False) ---
Generated 100 global market scenarios with factors: ['USD_DEMO_CONV_STOCK_CS', 'USD_DEMO_CONV_STOCK_DIVYIELD', 'USD_DEMO_CONV_STOCK_EQVOL', 'USD_DEMO_CONV_STOCK_S0', 'USD_DEMO_OPT_STOCK_S0', 'USD_DEMO_OPT_STOCK_VOL', 'USD_IR_0.25Y', 'USD_IR_0.50Y', 'USD_IR_1.00Y', 'USD_IR_10.00Y', 'USD_IR_2.00Y', 'USD_IR_3.00Y', 'USD_IR_5.00Y', 'USD_IR_7.00Y']
\n--- Step 1: Instrument Processing & TFF Calibration ---
Processing 6 instrument definitions...
  Processing instruments sequentially...


Processing Instruments: 100%|██████████| 6/6 [00:02<00:00,  2.63it/s]

Finished processing instrument definitions.
Saving model registry to model_registry.json...
  Model registry saved successfully.
\n--- Step 2: Portfolio Construction ---
   Skipping instrument 'MISSING_INSTRUMENT_ID_EXAMPLE' for JSON spec generation: not in valid model_registry or had an error.
Building portfolios from 6 detailed specifications...
Finished building 2 portfolios from detailed specs.
\n--- Step 3: Portfolio Pricing / VaR (using PortfolioAnalytics) ---
Running VaR Analysis for percentiles: ['99%', '95%']
  Analyzing portfolio for ClientA...
    Client ClientA: Base Value=22835.59, Mean Scen. Value=22837.98, VaRs: {'var_99pct': np.float64(176.43897244048952), 'var_95pct': np.float64(142.09450213629026)}
  Analyzing portfolio for ClientB...
    Client ClientB: Base Value=145649.74, Mean Scen. Value=147226.85, VaRs: {'var_99pct': np.float64(32472.24956112121), 'var_95pct': np.float64(22810.359783920892)}
\n--- Step 4: Portfolio JSON Serialization/Deserialization Demo ---
\n-




In [31]:
# tff_performance_demo.py
"""
Demonstrates and compares the performance of Tensor Functional Form (TFF) models
against full pricers for various financial instruments.
Includes TFF calibration time in the performance summary.
Aligns with simplified TensorFunctionalFormCalibrate constructor.
"""
import QuantLib as ql
import numpy as np
from datetime import date
from dateutil.relativedelta import relativedelta
import time
import os

def run_performance_demonstration(
    n_benchmark_scenarios: int = 128, # Defaulting to your last used value
    n_tff_train_samples: int = 32,    # Defaulting to your last used value
    n_tff_test_samples: int = 8,      # Defaulting to your last used value
    enable_parallel_tff_fitting: bool = False, # Defaulting to your last used value
    use_hardcoded_g2_params: bool = True,
    option_feature_eng_order: int = 3 # Defaulting to your last used value
    ):
    """
    Runs the performance comparison between TFF and full pricers.
    """
    print(f"--- TFF Performance Demonstration ---")
    print(f"Benchmarking with {n_benchmark_scenarios} scenarios.")
    print(f"TFF training with {n_tff_train_samples} samples, testing with {n_tff_test_samples} samples.")

    # --- Helper function to get scenario slices ---
    def get_scenario_slice(all_scenarios, all_factor_names, target_factor_names_for_tff):
        if not target_factor_names_for_tff:
            return np.array([]).reshape(all_scenarios.shape[0],0)
        try:
            global_indices_map = {name: i for i, name in enumerate(all_factor_names)}
            ordered_indices = [global_indices_map[name] for name in target_factor_names_for_tff]
            return all_scenarios[:, ordered_indices]
        except KeyError as e:
            missing_factor = str(e).strip("'")
            raise ValueError(
                f"Error slicing scenarios: Factor name '{missing_factor}' required by TFF "
                f"not found in generated scenario factor names. "
                f"Required by TFF: {target_factor_names_for_tff}, "
                f"Available from generator: {all_factor_names}."
            )
        except Exception as e:
            raise RuntimeError(f"General error during scenario slicing for TFF factors {target_factor_names_for_tff}: {e}")

    # --- Global Setup ---
    val_d = date(2025, 5, 18)
    numeric_rate_tenors = np.array([0.25, 0.5, 1.0, 2.0, 3.0, 5.0, 7.0, 10.0])
    DEMO_CURRENCY = "USD"
    DEMO_RATE_INDEX_STUB = "IR"
    OPT_UNDERLYING_SYMBOL = "DEMO_OPT_STOCK"
    CONV_UNDERLYING_SYMBOL = "DEMO_CONV_STOCK"

    default_g2_p = (0.01, 0.003, 0.015, 0.006, -0.75) if use_hardcoded_g2_params else None
    if not use_hardcoded_g2_params and default_g2_p is None:
        print("Warning: G2 calibration is not performed, and no hardcoded G2 params provided. Callable bond TFF might fail.")

    base_demo_currency_rates_values = np.array([0.020, 0.021, 0.022, 0.025, 0.027, 0.030, 0.032, 0.033])
    base_rates_map = {f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_{t:.2f}Y": base_demo_currency_rates_values[i] for i, t in enumerate(numeric_rate_tenors)}

    opt_s0_factor = f"{DEMO_CURRENCY}_{OPT_UNDERLYING_SYMBOL}_S0"
    opt_vol_factor = f"{DEMO_CURRENCY}_{OPT_UNDERLYING_SYMBOL}_VOL"
    conv_s0_factor = f"{DEMO_CURRENCY}_{CONV_UNDERLYING_SYMBOL}_S0"
    conv_vol_factor = f"{DEMO_CURRENCY}_{CONV_UNDERLYING_SYMBOL}_EQVOL"
    conv_div_factor = f"{DEMO_CURRENCY}_{CONV_UNDERLYING_SYMBOL}_DIVYIELD"
    conv_cs_factor = f"{DEMO_CURRENCY}_{CONV_UNDERLYING_SYMBOL}_CS"

    base_s0_map = {opt_s0_factor: 100.0, conv_s0_factor: 100.0}
    base_vol_map = {opt_vol_factor: 0.25, conv_vol_factor: 0.25}
    base_other_factors_map_for_conv = {conv_div_factor: 0.01, conv_cs_factor: 0.015}
    temp_s0_map_for_gen = {**base_s0_map, **base_other_factors_map_for_conv}
    temp_vol_map_for_gen = {**base_vol_map}

    scenario_gen_global = SimpleRandomScenarioGenerator(
        base_rates_map=base_rates_map, base_s0_map=temp_s0_map_for_gen,
        base_vol_map=temp_vol_map_for_gen, random_seed=42)

    benchmark_scenarios, benchmark_factor_names = scenario_gen_global.generate_scenarios(n_benchmark_scenarios)
    print(f"Generated {n_benchmark_scenarios} benchmark scenarios with {len(benchmark_factor_names)} factors.")

    ql.Settings.instance().evaluationDate = ql.Date(val_d.day, val_d.month, val_d.year)
    parallel_workers = os.cpu_count() if enable_parallel_tff_fitting else False
    if enable_parallel_tff_fitting and parallel_workers is None: parallel_workers = 1


    # --- Instrument Definitions and Pricer Templates ---
    vanilla_static = QuantLibBondStaticBase.from_dict({
        'valuation_date': val_d, 'maturity_date': (val_d + relativedelta(years=5)),
        'coupon_rate': 0.03, 'face_value': 100.0, 'currency': DEMO_CURRENCY,
        'index_stub': DEMO_RATE_INDEX_STUB, 'freq': 2})
    ql_pricer_vanilla = QuantLibBondPricer(vanilla_static, method='discount')

    callable_static = CallableBondStaticBase.from_dict({
        'valuation_date': val_d, 'maturity_date': (val_d + relativedelta(years=5)),
        'coupon_rate': 0.032, 'face_value': 100.0, 'currency': DEMO_CURRENCY,
        'index_stub': DEMO_RATE_INDEX_STUB, 'freq': 2,
        'call_dates': [(val_d + relativedelta(years=y)) for y in [2,3,4]],
        'call_prices': [102.0, 101.0, 100.0]})
    ql_pricer_callable = QuantLibBondPricer(callable_static, method='g2', grid_steps=32)

    conv_static_s0_dynamic = ConvertibleBondStaticBase.from_dict({
        'valuation_date': val_d, 'issue_date': (val_d - relativedelta(months=6)),
        'maturity_date': (val_d + relativedelta(years=5, months=-6)), 'coupon_rate': 0.02,
        'conversion_ratio': 20.0, 'face_value': 100.0, 'currency': DEMO_CURRENCY,
        'index_stub': DEMO_RATE_INDEX_STUB, 'underlying_symbol': CONV_UNDERLYING_SYMBOL, 'freq': 2})
    ql_pricer_conv_s0_dynamic = QuantLibBondPricer(conv_static_s0_dynamic, method='convertible_binomial', convertible_engine_steps=50)
    conv_s0_dynamic_fixed_params = {'dividend_yield': 0.01, 'equity_volatility': 0.25, 'credit_spread': 0.015, 's0_val': 100.0}


    option_static = EuropeanOptionStatic.from_dict({
        'valuation_date': val_d, 'expiry_date': (val_d + relativedelta(years=1)),
        'strike_price': 105.0, 'option_type': 'call',
        'currency': DEMO_CURRENCY, 'underlying_symbol': OPT_UNDERLYING_SYMBOL})
    bs_pricer_option = BlackScholesPricer(option_static, risk_free_rate=0.025, dividend_yield=0.01)

    # Simplified TFF Configuration Logic (inspired by TFFConfigurationFactory)
    def get_tff_calibration_inputs(product_static, sg, default_tenors, behavior_params=None, pricer_params=None):
        if behavior_params is None: behavior_params = {}
        if pricer_params is None: pricer_params = {}

        raw_names = []
        raw_base_vals = []
        fixed_training_params = {}
        opt_order = 0
        actual_pillars = np.array([])
        pricer_cfg_worker = {}

        if isinstance(product_static, EuropeanOptionStatic):
            s0fn = f"{product_static.currency}_{product_static.underlying_symbol}_S0"
            volfn = f"{product_static.currency}_{product_static.underlying_symbol}_VOL"
            raw_names = [s0fn, volfn]
            raw_base_vals = [sg.base_s0_map[s0fn], sg.base_vol_map[volfn]]
            opt_order = behavior_params.get('option_feature_order', 0)
            pricer_cfg_worker['bs_pricer_config'] = {
                'risk_free_rate': pricer_params.get('bs_risk_free_rate', 0.025), # Get from pricer_params
                'dividend_yield': pricer_params.get('bs_dividend_yield', 0.01)
            }
        elif isinstance(product_static, QuantLibBondStaticBase):
            rate_names = [f"{product_static.currency}_{product_static.index_stub}_{t:.2f}Y" for t in default_tenors]
            base_rates = [sg.base_rates_map[name] for name in rate_names]
            raw_names.extend(rate_names)
            raw_base_vals.extend(base_rates)
            actual_pillars = _parse_numeric_pillars_from_factor_names(rate_names)
            pricer_cfg_worker['bond_pricer_config'] = {'method': 'discount'} # Default

            if isinstance(product_static, CallableBondStaticBase):
                pricer_cfg_worker['bond_pricer_config']['method'] = 'g2'
                pricer_cfg_worker['bond_pricer_config']['grid_steps'] = pricer_params.get('g2_grid_steps', 32)
                if pricer_params.get('g2_params'): fixed_training_params['g2_params'] = pricer_params['g2_params']

            elif isinstance(product_static, ConvertibleBondStaticBase):
                pricer_cfg_worker['bond_pricer_config']['method'] = 'convertible_binomial'
                pricer_cfg_worker['bond_pricer_config']['convertible_engine_steps'] = pricer_params.get('conv_engine_steps', 50)

                s0fn_cb = f"{product_static.currency}_{product_static.underlying_symbol}_S0"
                raw_names.append(s0fn_cb); raw_base_vals.append(sg.base_s0_map[s0fn_cb])

                conv_all_dynamic = behavior_params.get('convertible_tff_market_inputs_as_factors', False)
                if conv_all_dynamic:
                    divfn = f"{product_static.currency}_{product_static.underlying_symbol}_DIVYIELD"
                    volfn = f"{product_static.currency}_{product_static.underlying_symbol}_EQVOL"
                    csfn  = f"{product_static.currency}_{product_static.underlying_symbol}_CS"
                    raw_names.extend([divfn, volfn, csfn])
                    raw_base_vals.extend([sg.base_s0_map[divfn], sg.base_vol_map[volfn], sg.base_s0_map[csfn]]) # Assuming DIV/CS in s0_map for demo
                else: # S0 and Rates dynamic, others fixed
                    fixed_params = pricer_params # These are the fixed ones for training
                    fixed_training_params['dividend_yield'] = fixed_params.get('dividend_yield')
                    fixed_training_params['equity_volatility'] = fixed_params.get('equity_volatility')
                    fixed_training_params['credit_spread'] = fixed_params.get('credit_spread')
                    # s0_val is not needed in fixed_training_params as S0 is dynamic
                    if any(v is None for k,v in fixed_training_params.items() if k in ['dividend_yield', 'equity_volatility', 'credit_spread']):
                        raise ValueError(f"Missing fixed CB params for TFF training. Got: {fixed_params}")
        return raw_names, np.array(raw_base_vals), product_static.to_dict(), pricer_cfg_worker, actual_pillars, opt_order, fixed_training_params


    instrument_setups = [
        {"id": "Vanilla_Bond_QL", "static": vanilla_static, "full_pricer": ql_pricer_vanilla, "tff_behavior_params": {}},
        {"id": "Callable_Bond_G2", "static": callable_static, "full_pricer": ql_pricer_callable,
         "pricer_params_for_tff": {"g2_params": default_g2_p}, # These are fixed for TFF training
         "full_pricer_kwargs": {"g2_params": default_g2_p} # For full pricer benchmark call
        },
        {"id": "Convertible_S0_Dynamic", "static": conv_static_s0_dynamic, "full_pricer": ql_pricer_conv_s0_dynamic,
         "tff_behavior_params": {"convertible_tff_market_inputs_as_factors": False},
         "pricer_params_for_tff": conv_s0_dynamic_fixed_params, # Fixed params for TFF
         "full_pricer_kwargs": conv_s0_dynamic_fixed_params # For full pricer benchmark
        },
        {"id": "European_Option_BS", "static": option_static, "full_pricer": bs_pricer_option,
         "pricer_params_for_tff": {'bs_risk_free_rate': 0.025, 'bs_dividend_yield': 0.01}, # For pricer_config_for_worker
         "tff_behavior_params": {"option_feature_order": option_feature_eng_order}
        },
    ]

    print("\n--- Calibrating TFF Models ---")
    fitted_tffs = {}

    for setup in instrument_setups:
        instrument_id = setup['id']
        #print(f"  Calibrating TFF for: {instrument_id}")
        fit_time = -1.0
        try:
            if instrument_id == "Callable_Bond_G2" and setup.get("pricer_params_for_tff", {}).get("g2_params") is None:
                print(f"    WARNING: g2_params are None for {instrument_id}. Skipping TFF calibration.")
                continue

            # Get TFF inputs using our local helper
            raw_names, raw_bases, ps_worker_dict, pricer_cfg_worker, act_pillars, opt_ord, fixed_train_params = \
                get_tff_calibration_inputs(
                    product_static=setup['static'],
                    sg=scenario_gen_global,
                    default_tenors=numeric_rate_tenors,
                    behavior_params=setup.get("tff_behavior_params",{}),
                    pricer_params=setup.get("pricer_params_for_tff", {}) # Pass pricer_params for BS config
                )

            tff_cal = TensorFunctionalFormCalibrate(
                pricer_template=setup['full_pricer'],
                tff_input_raw_factor_names=raw_names,
                tff_input_raw_base_values=raw_bases,
                product_static_params_for_worker=ps_worker_dict,
                pricer_config_for_worker=pricer_cfg_worker,
                actual_rate_pillars=act_pillars
            )

            scens_for_this_tff_domain, _ = scenario_gen_global.generate_scenarios(
                num_scenarios=100,
                target_factor_names=tff_cal.tff_input_raw_factor_names )
            if scens_for_this_tff_domain.size == 0 and tff_cal.tff_input_raw_factor_names:
                print(f"    WARNING: Could not generate domain scenarios for {instrument_id}. Skipping TFF.")
                continue

            start_calibration_time = time.time()
            model, _, _, rmse, norm_params = tff_cal.sample_and_fit(
                full_market_scenarios_for_tff_factors=scens_for_this_tff_domain,
                n_train=n_tff_train_samples, n_test=n_tff_test_samples,
                random_seed=42, parallel_workers=parallel_workers,
                option_feature_order=opt_ord, # Use order from get_tff_calibration_inputs
                **fixed_train_params # Pass fixed params for training
            )
            fit_time = time.time() - start_calibration_time

            if model and norm_params:
                fitted_tffs[instrument_id] = {
                    "model": model, "norm_params": norm_params,
                    "raw_input_names": tff_cal.tff_input_raw_factor_names,
                    "option_feature_order": opt_ord, "rmse": rmse,
                    "calibration_time": fit_time
                }
                #print(f"    TFF for {instrument_id} calibrated. RMSE: {rmse:.6f}, Calib Time: {fit_time:.2f}s")
            else: print(f"    WARNING: TFF fitting failed for {instrument_id}.")
        except Exception as e:
            print(f"    ERROR calibrating TFF for {instrument_id}: {e}")
            import traceback; traceback.print_exc()


    print("\n--- Performance Benchmarking ---")
    results_summary = []

    for setup in instrument_setups:
        instrument_id = setup['id']
        print(f"  Benchmarking: {instrument_id}")
        full_pricer = setup['full_pricer']

        tff_data = fitted_tffs.get(instrument_id)
        if not tff_data or tff_data.get("model") is None:
            print(f"    Skipping {instrument_id} as TFF model was not successfully calibrated.")
            continue

        tff_model = tff_data['model']
        tff_norm_params = tff_data['norm_params']
        tff_raw_input_names = tff_data['raw_input_names']
        tff_opt_order = tff_data['option_feature_order']
        tff_calib_time = tff_data['calibration_time']

        scenarios_for_this_instrument_raw = get_scenario_slice(benchmark_scenarios, benchmark_factor_names, tff_raw_input_names)
        if scenarios_for_this_instrument_raw.size == 0 and tff_raw_input_names:
            print(f"    Could not slice benchmark scenarios for {instrument_id}. Skipping.")
            continue

        start_time_full = time.time()
        full_pricer_call_kwargs = setup.get("full_pricer_kwargs", {})

        if isinstance(setup['static'], EuropeanOptionStatic):
            if scenarios_for_this_instrument_raw.shape[1] == 2:
                s0_scens = scenarios_for_this_instrument_raw[:, 0]
                vol_scens = scenarios_for_this_instrument_raw[:, 1]
                full_prices = full_pricer.price(stock_price=s0_scens, volatility=vol_scens, **full_pricer_call_kwargs)
            else:
                print(f"    WARNING: Option scenario data for full pricer has incorrect shape {scenarios_for_this_instrument_raw.shape}. Skipping benchmark.")
                continue
        else:
            full_prices = full_pricer.price(
                pillar_times=numeric_rate_tenors,
                market_scenario_data=scenarios_for_this_instrument_raw,
                **full_pricer_call_kwargs)
        time_full = time.time() - start_time_full

        tff_inputs_for_eval = scenarios_for_this_instrument_raw
        if isinstance(setup['static'], EuropeanOptionStatic) and tff_norm_params.get('is_engineered', False):
            if scenarios_for_this_instrument_raw.shape[1] != 2:
                 print(f"    Skipping TFF eval for {instrument_id}: Option requires 2 raw inputs for engineering.")
                 continue
            eng_feat, _ = engineer_option_features(scenarios_for_this_instrument_raw[:,0], scenarios_for_this_instrument_raw[:,1], order=tff_opt_order)
            np_means = np.array(tff_norm_params['means']) if tff_norm_params.get('means') is not None else None
            np_stds = np.array(tff_norm_params['stds']) if tff_norm_params.get('stds') is not None else None
            tff_inputs_for_eval, _, _ = normalize_features(eng_feat, np_means, np_stds)

        start_time_tff_eval = time.time()
        tff_prices = tff_model(tff_inputs_for_eval)
        time_tff_eval = time.time() - start_time_tff_eval

        speedup_eval_only = time_full / time_tff_eval if time_tff_eval > 1e-9 else float('inf')
        total_tff_time_for_run = tff_calib_time + (time_tff_eval * (n_benchmark_scenarios / scenarios_for_this_instrument_raw.shape[0] if scenarios_for_this_instrument_raw.shape[0] > 0 else 1) ) # Extrapolate eval time
        speedup_total = time_full / total_tff_time_for_run if total_tff_time_for_run > 1e-9 else float('inf')

        results_summary.append({
            "Instrument": instrument_id,
            "Full Pricer Time (s)": f"{time_full:.4f}",
            "TFF Calib Time (s)": f"{tff_calib_time:.2f}",
            "TFF Eval Time (s)": f"{time_tff_eval:.4f}",
            "Speedup (Eval Only)": f"{speedup_eval_only:.1f}x",
            "Speedup (Incl. Calib for this run)": f"{speedup_total:.1f}x",
            "TFF RMSE": f"{tff_data['rmse']:.6f}"
        })
        print(f"    {instrument_id}: Full Time={time_full:.4f}s, TFF Calib={tff_calib_time:.2f}s, TFF Eval={time_tff_eval:.4f}s, Speedup(Eval)={speedup_eval_only:.1f}x, Speedup(Total)={speedup_total:.1f}x, RMSE={tff_data['rmse']:.6f}")


    print("\n--- Performance Summary ---")
    if results_summary:
        header = results_summary[0].keys()
        max_lengths = [len(h) for h in header]
        for row in results_summary:
            for i, val in enumerate(row.values()):
                max_lengths[i] = max(max_lengths[i], len(str(val)))

        header_fmt = " | ".join([f"{{:<{max_lengths[i]}}}" for i in range(len(header))])
        separator_fmt = "-|-".join(['-' * max_lengths[i] for i in range(len(header))])

        print(header_fmt.format(*header))
        print(separator_fmt)
        for row in results_summary:
            print(header_fmt.format(*row.values()))
    else:
        print("No TFF models were successfully benchmarked.")

    print("\n--- End of Performance Demonstration ---")


if __name__ == "__main__":
    try:
        print(f"QuantLib version: {ql.__version__}")
        run_performance_demonstration(
            n_benchmark_scenarios=512 ,
            n_tff_train_samples=32,
            n_tff_test_samples=8,
            enable_parallel_tff_fitting=False,
            option_feature_eng_order=2 # Using the value from the original call
        )
    except NameError as e:
        if any(cn in str(e) for cn in ['ProductStaticBase','TensorFunctionalFormCalibrate']):
            print(f"ERROR: Class not defined. Ensure all notebook cells for class definitions are executed. Details: {e}")
        else: print(f"A NameError: {e}"); import traceback; traceback.print_exc()
    except Exception as e:
        print(f"An unexpected error: {e}"); import traceback; traceback.print_exc()


QuantLib version: 1.38
--- TFF Performance Demonstration ---
Benchmarking with 512 scenarios.
TFF training with 32 samples, testing with 8 samples.
Generated 512 benchmark scenarios with 14 factors.

--- Calibrating TFF Models ---

--- Performance Benchmarking ---
  Benchmarking: Vanilla_Bond_QL
    Vanilla_Bond_QL: Full Time=0.0358s, TFF Calib=0.01s, TFF Eval=0.0001s, Speedup(Eval)=476.1x, Speedup(Total)=4.9x, RMSE=0.001221
  Benchmarking: Callable_Bond_G2
    Callable_Bond_G2: Full Time=7.4660s, TFF Calib=0.60s, TFF Eval=0.0003s, Speedup(Eval)=23633.8x, Speedup(Total)=12.3x, RMSE=0.015620
  Benchmarking: Convertible_S0_Dynamic
    Convertible_S0_Dynamic: Full Time=0.0812s, TFF Calib=0.01s, TFF Eval=0.0001s, Speedup(Eval)=779.8x, Speedup(Total)=5.8x, RMSE=0.000031
  Benchmarking: European_Option_BS
    European_Option_BS: Full Time=0.0006s, TFF Calib=0.01s, TFF Eval=0.0000s, Speedup(Eval)=14.3x, Speedup(Total)=0.1x, RMSE=0.010826

--- Performance Summary ---
Instrument             | F

In [77]:
# parallel_tff_calibration_demo.py
"""
Demonstrates parallel TFF calibration for a large number of instruments
using the InstrumentProcessor.
"""
import QuantLib as ql
import numpy as np
from datetime import date, datetime
from dateutil.relativedelta import relativedelta
import time
import os
import json # For potentially saving/loading portfolio specs

def generate_instrument_definitions(num_instruments: int, val_date_param: date) -> list[dict]: # Changed param name for clarity
    """
    Generates a list of diverse instrument definitions.
    Uses val_date_param consistently.
    """
    definitions = []
    DEMO_CURRENCY = "USD"
    DEMO_RATE_INDEX_STUB = "IR"

    for i in range(num_instruments):
        instrument_type_choice = i % 4 # Cycle through 4 types
        instrument_id_suffix = f"INSTRUMENT_{i+1}"

        # Common params
        maturity_years = np.random.randint(2, 11)
        maturity_dt = val_date_param + relativedelta(years=maturity_years) # Use val_date_param
        coupon = 0.02 + np.random.rand() * 0.03 # Random coupon between 2% and 5%

        if instrument_type_choice == 0: # Vanilla Bond
            instrument_id = f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_VANILLA_{instrument_id_suffix}"
            definitions.append({
                "instrument_id": instrument_id,
                "product_type": "VanillaBond", "pricing_preference": "TFF",
                "params": {
                    "valuation_date": val_date_param, "maturity_date": maturity_dt, # Use val_date_param
                    "coupon_rate": coupon, "face_value": 100.0, "currency": DEMO_CURRENCY,
                    "index_stub": DEMO_RATE_INDEX_STUB, "freq": 2, "settlement_days": 0 },
                "tff_config": {"n_train": 32, "n_test": 4, "seed": i}
            })
        elif instrument_type_choice == 1: # Callable Bond
            instrument_id = f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_CALLABLE_{instrument_id_suffix}"
            call_offset = maturity_years // 2
            call_dates_list = []
            if call_offset >=1 :
                 call_dates_list = [(val_date_param + relativedelta(years=y)).isoformat() for y in range(call_offset, maturity_years)] # Use val_date_param

            definitions.append({
                "instrument_id": instrument_id,
                "product_type": "CallableBond", "pricing_preference": "TFF",
                "params": {
                    "valuation_date": val_date_param, "maturity_date": maturity_dt, # Use val_date_param
                    "coupon_rate": coupon, "face_value": 100.0, "currency": DEMO_CURRENCY,
                    "index_stub": DEMO_RATE_INDEX_STUB, "freq": 2,
                    "call_dates": call_dates_list,
                    "call_prices": [100.0 + len(call_dates_list) - j for j in range(len(call_dates_list))] if call_dates_list else []
                },
                "pricer_params": {"g2_params": (0.01, 0.003, 0.015, 0.006, -0.75)},
                "tff_config": {"n_train": 32, "n_test": 4, "seed": i}
            })
        elif instrument_type_choice == 2: # Convertible Bond (S0 dynamic)
            conv_underlying_sym = f"STOCK_{i%10}"
            instrument_id = f"{DEMO_CURRENCY}_{conv_underlying_sym}_CONV_S0_DYN_{instrument_id_suffix}"
            definitions.append({
                "instrument_id": instrument_id,
                "product_type": "ConvertibleBond", "pricing_preference": "TFF",
                "params": {
                    'valuation_date': val_date_param, 'issue_date': (val_date_param - relativedelta(months=6)),  # Use val_date_param
                    'maturity_date': maturity_dt, 'coupon_rate': coupon,
                    'conversion_ratio': 15.0 + np.random.rand() * 10, 'face_value': 100.0,
                    'currency': DEMO_CURRENCY, 'index_stub': DEMO_RATE_INDEX_STUB,
                    'underlying_symbol': conv_underlying_sym, 'freq': 2
                },
                "pricer_params": {
                    'dividend_yield': 0.01 + np.random.rand()*0.01,
                    'equity_volatility': 0.20 + np.random.rand()*0.1,
                    'credit_spread': 0.01 + np.random.rand()*0.01,
                    's0_val': 90 + np.random.rand()*20
                },
                "tff_config": {"n_train": 64, "n_test": 4, "seed": i,
                               "convertible_tff_market_inputs_as_factors": False
                              }
            })
        elif instrument_type_choice == 3: # European Option
            opt_underlying_sym = f"STOCK_{i%10}"
            strike = 90 + np.random.rand() * 20
            instrument_id = f"{DEMO_CURRENCY}_{opt_underlying_sym}_EURO_CALL_1Y_K{int(strike)}_{instrument_id_suffix}"
            definitions.append({
                "instrument_id": instrument_id,
                "product_type": "EuropeanOption", "pricing_preference": "TFF",
                "params": {
                    'valuation_date': val_date_param, 'expiry_date': (val_date_param + relativedelta(years=1)), # Use val_date_param
                    'strike_price': strike, 'option_type': 'call',
                    'currency': DEMO_CURRENCY, 'underlying_symbol': opt_underlying_sym,
                },
                "pricer_params": { 'bs_risk_free_rate': 0.025, 'bs_dividend_yield': 0.01 },
                "tff_config": {"n_train": 64, "n_test": 10, "option_feature_order": 2, "seed": i}
            })
    return definitions


def run_parallel_calibration_demo(
    num_instruments_to_generate: int = 100,
    num_parallel_workers: int = None
    ):
    print(f"--- Parallel TFF Calibration Demo ---")
    print(f"Generating {num_instruments_to_generate} instruments...")
    print(f"Parallel workers for InstrumentProcessor: {num_parallel_workers if num_parallel_workers is not None else 'os.cpu_count()'}")

    # --- Global Setup ---
    val_d_main = date(2025, 5, 18) # Renamed to avoid conflict with function param if any
    numeric_rate_tenors = np.array([0.25, 0.5, 1.0, 2.0, 3.0, 5.0, 7.0, 10.0])
    DEMO_CURRENCY = "USD"
    DEMO_RATE_INDEX_STUB = "IR"

    base_rates_map = {f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_{t:.2f}Y": 0.02 + t*0.001 for t in numeric_rate_tenors}

    all_underlying_symbols = list(set([f"STOCK_{i%10}" for i in range(num_instruments_to_generate)]))
    base_s0_map = {}
    base_vol_map = {}
    base_other_map = {}

    for sym in all_underlying_symbols:
        base_s0_map[f"{DEMO_CURRENCY}_{sym}_S0"] = 90 + np.random.rand() * 20
        base_vol_map[f"{DEMO_CURRENCY}_{sym}_VOL"] = 0.20 + np.random.rand() * 0.1
        base_vol_map[f"{DEMO_CURRENCY}_{sym}_EQVOL"] = 0.20 + np.random.rand() * 0.1
        base_other_map[f"{DEMO_CURRENCY}_{sym}_DIVYIELD"] = 0.01 + np.random.rand() * 0.01
        base_other_map[f"{DEMO_CURRENCY}_{sym}_CS"] = 0.01 + np.random.rand() * 0.01

    merged_s0_map = {**base_s0_map, **base_other_map}

    scenario_gen_global = SimpleRandomScenarioGenerator(
        base_rates_map=base_rates_map,
        base_s0_map=merged_s0_map,
        base_vol_map=base_vol_map,
        random_seed=42
    )

    N_DOMAIN_SCENARIOS = 500
    global_market_scenarios, global_factor_names = scenario_gen_global.generate_scenarios(N_DOMAIN_SCENARIOS)
    print(f"Generated {N_DOMAIN_SCENARIOS} global scenarios for TFF domain with {len(global_factor_names)} factors.")

    # Generate instrument definitions, passing val_d_main
    instrument_definitions = generate_instrument_definitions(num_instruments_to_generate, val_d_main)
    print(f"Created {len(instrument_definitions)} instrument definitions.")

    # --- Instrument Processing with Parallel TFF Calibration ---
    print("\\n--- Instrument Processing & TFF Calibration ---")

    instrument_processor = InstrumentProcessor(
        scenario_generator=scenario_gen_global,
        global_valuation_date=val_d_main, # Use val_d_main
        default_numeric_rate_tenors=numeric_rate_tenors,
        default_g2_params=(0.01, 0.003, 0.015, 0.006, -0.75),
        default_bs_risk_free_rate=0.025,
        default_bs_dividend_yield=0.01,
        parallel_workers_tff=num_parallel_workers,
        n_scenarios_for_tff_domain=100
    )

    start_processing_time = time.time()
    model_registry = instrument_processor.process_instruments(
        instrument_definitions,
        global_market_scenarios,
        global_factor_names
    )
    total_processing_time = time.time() - start_processing_time

    print(f"\nTotal instrument processing time: {total_processing_time:.2f} seconds.")

    successful_tff_calibrations = 0
    failed_tff_calibrations = 0
    full_pricers_used = 0

    for instrument_id, entry in model_registry.items():
        if entry.get('pricing_method') == 'TFF' and 'tff_model_dict' in entry and not entry.get('error_tff_calibration'):
            successful_tff_calibrations += 1
        elif entry.get('pricing_method') == 'TFF' and entry.get('error_tff_calibration'):
            failed_tff_calibrations +=1
            print(f"  TFF failed for {instrument_id}, fell back to FULL. Error: {entry['error_tff_calibration']}")
        elif entry.get('pricing_method') == 'FULL':
            full_pricers_used +=1
        elif entry.get('error'):
            print(f"  Error processing {instrument_id}: {entry['error']}")

    print("\n--- Processing Summary ---")
    print(f"Total instruments processed: {len(instrument_definitions)}")
    print(f"Total instrument processing time: {total_processing_time:.2f} seconds.")
    print(f"Total instrument processing time per instrument: {total_processing_time/len(instrument_definitions):.5f} seconds.")
    print(f"Instrument processing time per instrument per core: {(total_processing_time/len(instrument_definitions)*num_parallel_workers):.5f} seconds.")
    print(f"Successfully calibrated TFF models: {successful_tff_calibrations}")
    print(f"Failed TFF calibrations (fell back to FULL or error): {failed_tff_calibrations}")
    print(f"Instruments set to use FULL pricer: {full_pricers_used}")

    print("\\n--- End of Parallel Calibration Demonstration ---")


if __name__ == "__main__":
    try:
        print(f"QuantLib version: {ql.__version__}")

        num_workers = os.cpu_count()
        # num_workers = False # To run InstrumentProcessor sequentially for comparison

        run_parallel_calibration_demo(
            num_instruments_to_generate=1024,
            num_parallel_workers=16
        )
    except NameError as e:
        if any(cn in str(e) for cn in ['ProductStaticBase','InstrumentProcessor','PortfolioBuilder','PortfolioAnalytics']): # Added PortfolioAnalytics
            print(f"ERROR: Class not defined. Ensure all notebook cells for class definitions are executed. Details: {e}")
        elif 'QuantLib' in str(e) or 'ql' in str(e): print("ERROR: QuantLib not found/imported.")
        else: print(f"A NameError: {e}"); import traceback; traceback.print_exc()
    except Exception as e:
        print(f"An unexpected error: {e}"); import traceback; traceback.print_exc()


QuantLib version: 1.38
--- Parallel TFF Calibration Demo ---
Generating 1024 instruments...
Parallel workers for InstrumentProcessor: 16
Generated 500 global scenarios for TFF domain with 58 factors.
Created 1024 instrument definitions.
\n--- Instrument Processing & TFF Calibration ---
Processing 1024 instrument definitions...
  Processing instruments in parallel (workers=16)...


Processing Instruments: 100%|██████████| 1024/1024 [00:17<00:00, 59.58it/s]

Finished processing instrument definitions.

Total instrument processing time: 17.40 seconds.

--- Processing Summary ---
Total instruments processed: 1024
Total instrument processing time: 17.40 seconds.
Total instrument processing time per instrument: 0.01700 seconds.
Instrument processing time per instrument per core: 0.27193 seconds.
Successfully calibrated TFF models: 1024
Failed TFF calibrations (fell back to FULL or error): 0
Instruments set to use FULL pricer: 0
\n--- End of Parallel Calibration Demonstration ---





In [96]:
# parallel_tff_calibration_demo.py
"""
Demonstrates parallel TFF calibration for a large number of instruments
using the InstrumentProcessor, followed by portfolio construction and analytics.
"""
import QuantLib as ql
import numpy as np
from datetime import date, datetime
from dateutil.relativedelta import relativedelta
import time
import os
import json

# Note: Pricer and TFFCalibrate classes are used internally by InstrumentProcessor

def generate_instrument_definitions(num_instruments: int, val_date_param: date) -> list[dict]:
    """
    Generates a list of diverse instrument definitions.
    Uses val_date_param consistently.
    """
    definitions = []
    DEMO_CURRENCY = "USD"
    DEMO_RATE_INDEX_STUB = "IR"

    # Define fixed parameters for "S0_DYNAMIC" convertibles for consistency in TFF training
    conv_s0_dynamic_fixed_pricer_params = {
        's0_val': 100.0, # This is a base for the dynamic S0, but pricer will get scenario S0
        'dividend_yield': 0.01,
        'equity_volatility': 0.25,
        'credit_spread': 0.015
    }
    # Default pricer params for options
    default_option_pricer_params = { 'bs_risk_free_rate': 0.025, 'bs_dividend_yield': 0.01 }
    conv_s0_dynamic_fixed_params = {'dividend_yield': 0.01, 'equity_volatility': 0.25, 'credit_spread': 0.015, 's0_val': 100.0}
    default_g2_p = (0.01, 0.003, 0.015, 0.006, -0.75)

    for i in range(num_instruments):
        instrument_type_choice = i % 4
        instrument_id_suffix = f"INSTRUMENT_{i+1}"

        maturity_years = np.random.randint(2, 11)
        maturity_dt = val_date_param + relativedelta(years=maturity_years)
        coupon = 0.02 + np.random.rand() * 0.03

        if instrument_type_choice == 0: # Vanilla Bond
            instrument_id = f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_VANILLA_{instrument_id_suffix}"
            definitions.append({
                "instrument_id": instrument_id,
                "product_type": "VanillaBond", "pricing_preference": "TFF",
                "params": {
                    "valuation_date": val_date_param, "maturity_date": maturity_dt,
                    "coupon_rate": coupon, "face_value": 100.0, "currency": DEMO_CURRENCY,
                    "index_stub": DEMO_RATE_INDEX_STUB, "freq": 2, "settlement_days": 0 },
                "tff_config": {"n_train": 32, "n_test": 4, "seed": i}
            })
        elif instrument_type_choice == 1: # Callable Bond
            instrument_id = f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_CALLABLE_{instrument_id_suffix}"
            call_offset = maturity_years // 2
            call_dates_list = []
            if call_offset >=1 :
                 call_dates_list = [(val_date_param + relativedelta(years=y)).isoformat() for y in range(call_offset, maturity_years)]

            definitions.append({
                "instrument_id": instrument_id,
                "product_type": "CallableBond", "pricing_preference": "TFF",
                "params": {
                    "valuation_date": val_date_param, "maturity_date": maturity_dt,
                    "coupon_rate": coupon, "face_value": 100.0, "currency": DEMO_CURRENCY,
                    "index_stub": DEMO_RATE_INDEX_STUB, "freq": 2,
                    "call_dates": call_dates_list,
                    "call_prices": [100.0 + len(call_dates_list) - j for j in range(len(call_dates_list))] if call_dates_list else []
                },
                "pricer_params": {"g2_params": (0.01, 0.003, 0.015, 0.006, -0.75)},
                "tff_config": {"n_train": 64, "n_test": 4, "seed": i}
            })
        elif instrument_type_choice == 2: # Convertible Bond (S0 dynamic, others fixed for TFF)
            conv_underlying_sym = f"STOCK_{i%10}"
            instrument_id = f"{DEMO_CURRENCY}_{conv_underlying_sym}_CONV_S0_DYN_{instrument_id_suffix}"
            definitions.append({
                "instrument_id": instrument_id,
                "product_type": "ConvertibleBond", "pricing_preference": "TFF",
                "params": {
                    'valuation_date': val_date_param, 'issue_date': (val_date_param - relativedelta(months=6)),
                    'maturity_date': maturity_dt, 'coupon_rate': coupon,
                    'conversion_ratio': 15.0 + np.random.rand() * 10, 'face_value': 100.0,
                    'currency': DEMO_CURRENCY, 'index_stub': DEMO_RATE_INDEX_STUB,
                    'underlying_symbol': conv_underlying_sym, 'freq': 2
                },
                "pricer_params": conv_s0_dynamic_fixed_params, # Used for fixed params in TFF training
                "tff_config": {"n_train": 64, "n_test": 4, "seed": i,
                               "convertible_tff_market_inputs_as_factors": False
                              }
            })
        elif instrument_type_choice == 3: # European Option
            opt_underlying_sym = f"STOCK_{i%10}"
            strike = 90 + np.random.rand() * 20
            instrument_id = f"{DEMO_CURRENCY}_{opt_underlying_sym}_EURO_CALL_1Y_K{int(strike)}_{instrument_id_suffix}"
            definitions.append({
                "instrument_id": instrument_id,
                "product_type": "EuropeanOption", "pricing_preference": "TFF",
                "params": {
                    'valuation_date': val_date_param, 'expiry_date': (val_date_param + relativedelta(years=1)),
                    'strike_price': strike, 'option_type': 'call',
                    'currency': DEMO_CURRENCY, 'underlying_symbol': opt_underlying_sym,
                },
                "pricer_params": default_option_pricer_params,
                "tff_config": {"n_train": 64, "n_test": 10, "option_feature_order": 2, "seed": i}
            })
    return definitions


def run_parallel_calibration_demo(
    num_instruments_to_generate: int = 1000, # Changed to 1000
    num_parallel_workers: int = None
    ):
    print(f"--- Parallel TFF Calibration Demo ---")
    print(f"Generating {num_instruments_to_generate} instruments...")
    actual_workers = num_parallel_workers if num_parallel_workers is not None else os.cpu_count()
    print(f"Parallel workers for InstrumentProcessor: {actual_workers}")

    # --- Global Setup ---
    val_d_main = date(2025, 5, 18)
    numeric_rate_tenors = np.array([0.25, 0.5, 1.0, 2.0, 3.0, 5.0, 7.0, 10.0])
    DEMO_CURRENCY = "USD"
    DEMO_RATE_INDEX_STUB = "IR"
    default_g2_p = (0.01, 0.003, 0.015, 0.006, -0.75)

    base_rates_map = {f"{DEMO_CURRENCY}_{DEMO_RATE_INDEX_STUB}_{t:.2f}Y": 0.02 + t*0.001 for t in numeric_rate_tenors}

    all_underlying_symbols = list(set([f"STOCK_{i%10}" for i in range(num_instruments_to_generate)]))
    base_s0_map = {}
    base_vol_map = {}
    base_other_map = {}

    for sym in all_underlying_symbols:
        base_s0_map[f"{DEMO_CURRENCY}_{sym}_S0"] = 90 + np.random.rand() * 20
        base_vol_map[f"{DEMO_CURRENCY}_{sym}_VOL"] = 0.20 + np.random.rand() * 0.1
        base_vol_map[f"{DEMO_CURRENCY}_{sym}_EQVOL"] = 0.20 + np.random.rand() * 0.1
        base_other_map[f"{DEMO_CURRENCY}_{sym}_DIVYIELD"] = 0.01 + np.random.rand() * 0.01
        base_other_map[f"{DEMO_CURRENCY}_{sym}_CS"] = 0.01 + np.random.rand() * 0.01

    merged_s0_map = {**base_s0_map, **base_other_map}

    scenario_gen_global = SimpleRandomScenarioGenerator(
        base_rates_map=base_rates_map,
        base_s0_map=merged_s0_map,
        base_vol_map=base_vol_map,
        random_seed=42
    )

    N_DOMAIN_SCENARIOS = 100 # Reduced for faster domain generation in demo
    global_market_scenarios, global_factor_names = scenario_gen_global.generate_scenarios(N_DOMAIN_SCENARIOS)
    print(f"Generated {N_DOMAIN_SCENARIOS} global scenarios for TFF domain with {len(global_factor_names)} factors.")

    instrument_definitions = generate_instrument_definitions(num_instruments_to_generate, val_d_main)
    print(f"Created {len(instrument_definitions)} instrument definitions.")

    # --- Instrument Processing with Parallel TFF Calibration ---
    print("\\n--- Instrument Processing & TFF Calibration ---")

    instrument_processor = InstrumentProcessor(
        scenario_generator=scenario_gen_global,
        global_valuation_date=val_d_main,
        default_numeric_rate_tenors=numeric_rate_tenors,
        default_g2_params=(0.01, 0.003, 0.015, 0.006, -0.75),
        default_bs_risk_free_rate=0.025,
        default_bs_dividend_yield=0.01,
        parallel_workers_tff=actual_workers, # Use calculated actual_workers
        n_scenarios_for_tff_domain=50 # Smaller set for faster sample_and_fit domain slicing
    )

    start_processing_time = time.time()
    model_registry = instrument_processor.process_instruments(
        instrument_definitions,
        global_market_scenarios,
        global_factor_names
    )
    total_processing_time = time.time() - start_processing_time

    print(f"\nTotal instrument processing time: {total_processing_time:.2f} seconds.")

    successful_tff_calibrations = 0
    failed_calibrations = 0
    for instrument_id, entry in model_registry.items():
        if entry.get('pricing_method') == 'TFF' and 'tff_model_dict' in entry and not entry.get('error_tff_calibration'):
            successful_tff_calibrations += 1
        elif entry.get('error') or entry.get('error_tff_calibration'):
            failed_calibrations +=1
            # print(f"  TFF failed for {instrument_id}. Error: {entry.get('error_tff_calibration', entry.get('error'))}")


    print("\n--- Processing Summary ---")
    print(f"Total instruments to process: {len(instrument_definitions)}")
    print(f"Successfully calibrated TFF models: {successful_tff_calibrations}")
    print(f"Failed TFF calibrations/processing errors: {failed_calibrations}")
    if len(instrument_definitions) > 0 :
        print(f"Average processing time per instrument: {total_processing_time/len(instrument_definitions):.3f} seconds.")


    # --- Portfolio Construction & Analytics ---
    port_time = time.time()
    print("\n--- Portfolio Construction & Analytics ---")
    if successful_tff_calibrations == 0:
        print("No TFF models were successfully calibrated. Skipping portfolio analytics.")
    else:
        holdings_data = []
        for instrument_id, entry in model_registry.items():
            if entry.get('pricing_method') == 'TFF' and 'tff_model_dict' in entry and not entry.get('error_tff_calibration'):
                holdings_data.append({
                    "client_id": "ClientParallelDemo",
                    "instrument_id": instrument_id,
                    "num_holdings": 1000 # As requested
                })

        if not holdings_data:
            print("No successfully calibrated TFF instruments to add to portfolio.")
        else:
            print(f"Creating portfolio with {len(holdings_data)} successfully TFF-calibrated instruments.")
            initial_portfolio_specs = generate_portfolio_specs_for_serialization(
                holdings_data=holdings_data,
                model_registry=model_registry,
                instrument_definitions_data_for_pricer_params=instrument_definitions # Pass original defs
            )

            portfolio_builder = PortfolioBuilder(model_registry)
            client_portfolios = portfolio_builder.build_portfolios_from_specs(
                portfolio_specs_list=initial_portfolio_specs,
                global_valuation_date=val_d_main,
                default_g2_params=default_g2_p, # Pass defaults for full pricer reconstruction if needed
                default_bs_rfr=0.025,
                default_bs_div=0.01
            )

            if portfolio_builder.uncalculated_instruments:
                print(f"  WARNING: Uncalculated instruments during portfolio build: {portfolio_builder.uncalculated_instruments}")

            if client_portfolios.get("ClientParallelDemo"):
                portfolio_analyzer = PortfolioAnalytics(
                    client_portfolios=client_portfolios,
                    global_market_scenarios=global_market_scenarios, # Use the smaller domain scenarios for VaR demo
                    global_factor_names=global_factor_names,
                    numeric_rate_tenors=numeric_rate_tenors,
                    scenario_generator_for_base_values=scenario_gen_global )

                print("\n--- VaR Calculation for 'ClientParallelDemo' Portfolio ---")
                var_results = portfolio_analyzer.run_var_analysis(var_percentiles=[1.0, 5.0])
            else:
                print("Portfolio 'ClientParallelDemo' not built.")
    print(f"Portfolio construction and analytics time: {time.time() - port_time:.5f} seconds.")
    print("\\n--- End of Parallel Calibration Demonstration ---")


if __name__ == "__main__":
    try:
        print(f"QuantLib version: {ql.__version__}")

        num_workers = os.cpu_count()
        # num_workers = 2 # For testing with fewer workers
        # num_workers = False # To run InstrumentProcessor sequentially for comparison

        run_parallel_calibration_demo(
            num_instruments_to_generate=128, # As per user request
            num_parallel_workers=16
        )
    except NameError as e:
        # Add TFFConfigurationFactory if it's directly used and might cause NameError
        if any(cn in str(e) for cn in ['ProductStaticBase','InstrumentProcessor','PortfolioBuilder','PortfolioAnalytics', 'TFFConfigurationFactory']):
            print(f"ERROR: Class not defined. Ensure all notebook cells for class definitions are executed. Details: {e}")
        elif 'QuantLib' in str(e) or 'ql' in str(e): print("ERROR: QuantLib not found/imported.")
        else: print(f"A NameError: {e}"); import traceback; traceback.print_exc()
    except Exception as e:
        print(f"An unexpected error: {e}"); import traceback; traceback.print_exc()


QuantLib version: 1.38
--- Parallel TFF Calibration Demo ---
Generating 128 instruments...
Parallel workers for InstrumentProcessor: 16
Generated 100 global scenarios for TFF domain with 58 factors.
Created 128 instrument definitions.
\n--- Instrument Processing & TFF Calibration ---
Processing 128 instrument definitions...
  Processing instruments in parallel (workers=16)...


Processing Instruments: 100%|██████████| 128/128 [00:04<00:00, 30.87it/s]

Finished processing instrument definitions.

Total instrument processing time: 4.35 seconds.

--- Processing Summary ---
Total instruments to process: 128
Successfully calibrated TFF models: 128
Failed TFF calibrations/processing errors: 0
Average processing time per instrument: 0.034 seconds.

--- Portfolio Construction & Analytics ---
Creating portfolio with 128 successfully TFF-calibrated instruments.
Building portfolios from 128 detailed specifications...
Finished building 1 portfolios from detailed specs.

--- VaR Calculation for 'ClientParallelDemo' Portfolio ---
Running VaR Analysis for percentiles: ['99%', '95%']
  Analyzing portfolio for ClientParallelDemo...
    Client ClientParallelDemo: Base Value=67934190.41, Mean Scen. Value=68066436.69, VaRs: {'var_99pct': np.float64(5319164.739779026), 'var_95pct': np.float64(4492584.654667854)}
Portfolio construction and analytics time: 0.02433 seconds.
\n--- End of Parallel Calibration Demonstration ---



