In [0]:
userlastname = 'cook'
from IPython.display import clear_output, display
try:
    %reload_ext autotime
except:
    %pip install -U ipython-autotime ipywidgets codetiming Jinja2 openpyxl numpy pandas geopandas scikit-learn pgeocode flaml[automl] git+https://github.com/AnotherSamWilson/miceforest.git
    dbutils.library.restartPython()
    clear_output()
    dbutils.notebook.exit('Rerun to use newly installed/updated packages')

import os, sys, copy, pathlib, shutil, pickle, warnings, requests, dataclasses, time, codetiming, numpy as np, pandas as pd, sklearn as sk, geopandas as gpd, pgeocode, miceforest as mf, flaml as fl
from pgeocode import Nominatim
from sklearn.metrics import log_loss
from sklearn.metrics.pairwise import haversine_distances
from sklearn.model_selection import cross_val_predict
clear_output()
pd.options.display.max_columns = None  # show every column when displaying DataFrames
sk.set_config(transform_output="pandas")  # sklearn transformers return DataFrames instead of arrays
now = pd.Timestamp.now()  # time this cell ran; Data uses it as the default and latest allowed amp_date
eps = np.finfo(float).eps  # machine epsilon for float64
tab = '    '  # one indentation level for generated SQL (used by indent)
divider = '##############################################################################################################'  # printed after each newly created artifact
catalog = 'dev.bronze.'  # prefix for source tables in generated SQL
root = pathlib.Path(f'/Volumes/aiml/amp')
geo = root/f'amp_cook_files/202508/geo'  # reference files stored without date/term subfolders (zips, drivetimes, terms)
shr = root/f'amp_cook_files/202508/data'  # shared artifacts, stored under <amp_date>/<term_code>/ subfolders
usr = root/f'amp_{userlastname}_files/202508/output'  # this user's artifacts, same subfolder layout as shr
flags_raw = pathlib.Path('/Volumes/aiml/scook/scook_files/admitted_flags_raw')  # raw Flags "admitted" Excel reports
flags_prc = pathlib.Path('/Volumes/aiml/flags/flags_volume/')  # processed per-term and per-year flags parquet files

# Silence noisy library warnings in notebook output
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=sk.exceptions.InconsistentVersionWarning)
for w in [
    "Could not infer format, so each element will be parsed individually, falling back to `dateutil`",
    ]:
    warnings.filterwarnings(action='ignore', message=f".*{w}.*")


############ helper functions ############
def dt(*args):
    """Return the earliest of `args` as a Timestamp normalized to midnight.

    Missing values (None, '', NaT) are skipped before taking the minimum.
    """
    stamps = pd.to_datetime(args)
    earliest = stamps.dropna().min()
    return earliest.normalize()

def setmeth(cls, fcn):
    """Attach `fcn` to `cls` under the function's own name (a monkey-patch).

    Classes that forbid new attributes (e.g. builtins such as int) raise TypeError.
    """
    method_name = fcn.__name__
    setattr(cls, method_name, fcn)

def listify(*args, sort=False, reverse=False):
    """Coerce the arguments into a list.

    - a single None / np.nan / pd.NA (identity check) -> []
    - a single string -> [string] (strings are not split into characters)
    - a single iterable -> list(iterable); a single non-iterable -> [value]
    - several arguments -> list of the arguments
    With sort=True the result is sorted (descending if reverse=True) when its items are
    comparable; otherwise it is returned unsorted.
    """
    if len(args) == 1:
        only = args[0]
        if any(only is missing for missing in (None, np.nan, pd.NA)):
            return []
        if isinstance(only, str):
            return [only]
    try:
        items = list(*args)
    except Exception:
        items = list(args)
    if not sort:
        return items
    try:
        return sorted(items, reverse=reverse)
    except Exception:
        return items

def setify(*args):
    """Like listify, but return a set (duplicates collapse, order is not kept)."""
    return {*listify(*args)}

def unpack(*args, **kwargs):
    """Flatten arbitrarily nested lists/tuples/sets into a single flat list.

    Non-container arguments are expanded with listify (so a string stays whole and
    None disappears). kwargs (sort, reverse) are forwarded to the final listify.
    """
    flat = []
    for item in args:
        if isinstance(item, (list, tuple, set)):
            flat.extend(unpack(*item))
        else:
            flat.extend(listify(item))
    return listify(flat, **kwargs)

def unique(*args, **kwargs):
    """Flatten the arguments (see unpack) and drop repeats, keeping first-seen order.

    kwargs (sort, reverse) are forwarded to listify.
    """
    first_seen = dict.fromkeys(unpack(*args))  # dicts preserve insertion order
    return listify(list(first_seen), **kwargs)

def difference(A, B, **kwargs):
    """Return the unique items of A that do not appear in B, in A's order.

    Both arguments are coerced with listify. B is materialized exactly once: rebuilding it
    for every element of A was quadratic, and silently wrong when B was a one-shot iterable
    (a generator would be exhausted after the first membership test).
    kwargs (sort, reverse) are forwarded to unique.
    """
    excluded = listify(B)
    return unique([x for x in listify(A) if x not in excluded], **kwargs)

def rjust(x, width, fillchar=' '):
    """Right-align str(x) in a field of `width`, padding on the left with str(fillchar)."""
    text = str(x)
    return text.rjust(width, str(fillchar))

def ljust(x, width, fillchar=' '):
    """Left-align str(x) in a field of `width`, padding on the right with str(fillchar)."""
    text = str(x)
    return text.ljust(width, str(fillchar))

def join(lst, sep='\n,', pre='', post=''):
    """Join the stringified items of `lst` (coerced with listify) with `sep`, wrapped in pre/post.

    The default separator is a newline followed by a comma, giving one SQL select item per line.
    """
    pieces = [str(item) for item in listify(lst)]
    glue = str(sep)
    return f"{pre}{glue.join(pieces)}{post}"

def alias(dct):
    """Turn {original_column: new_name} into SQL select items ['original_column as new_name', ...]."""
    pairs = dct.items()
    return [f'{source} as {target}' for source, target in pairs]

def indent(x, lev=1):
    """Indent every line of `x` after the first by `lev` copies of the module-level `tab`.

    Returns `x` unchanged when lev <= 0.
    """
    if lev <= 0:
        return x
    return x.replace('\n', '\n' + tab * lev)

def subqry(qry, lev=1):
    """Format `qry` for embedding in an enclosing query, indented by `lev` levels.

    The text is stripped and placed on a new line. If it contains a SELECT (matched
    case-insensitively, so upper-case SQL is wrapped too), it is parenthesized as a
    subquery; otherwise (e.g. a plain table expression) it is left unwrapped.
    """
    qry = '\n' + qry.strip()
    if 'select' in qry.lower():
        qry = '(' + qry + '\n)'
    return indent(qry, lev)

def run(qry, show=False, sample='10 rows', seed=42):
    """Execute SQL with the active Spark session and return a prepped, index-sorted pandas DataFrame.

    A `qry` containing no spaces is treated as a table name under `catalog`; it is read with
    `tablesample (<sample>) repeatable (<seed>)` unless sample is None. show=True prints the SQL.
    """
    words = qry.split(' ')
    if len(words) == 1:
        qry = f'select * from {catalog}{words[0]}'
        if sample is not None:
            qry += f' tablesample ({sample}) repeatable ({seed})'
    if show:
        print(qry)
    result = spark.sql(qry).toPandas()
    return result.prep().sort_index()


############ pandas functions ############
def disp(X, max_rows=3, sort=False):
    """Print a quick summary of a DataFrame: shape, per-column dtype / % missing, and leading rows.

    Parameters
    ----------
    X : DataFrame (its index is moved into columns for display)
    max_rows : number of rows to show; None or a negative value shows every row
        (pandas' head(-1) would otherwise drop the last row, but callers pass -1 meaning "all")
    sort : if True, columns are shown in sorted order
    """
    print(type(X), X.shape)
    X = (X.sort_index(axis=1) if sort else X).reset_index()
    Y = pd.DataFrame({'dtype':X.dtypes.astype('string'), 'missing_pct':X.isnull().mean()*100}).T.rename_axis('column').reset_index().prep(case='')
    print(X.shape)
    display(Y)
    display(X if max_rows is None or max_rows < 0 else X.head(max_rows))

# def disp(X, max_rows=3, precision=None, sort=False, **props):
# def disp(X, max_rows=3, sort=False):#, precision=3, **props):
#     """convenient display method"""
#     X = (X.sort_index(axis=1) if sort else X).reset_index()
#     Y = pd.DataFrame({'dtype':X.dtypes.astype('string'), 'missing_pct':X.isnull().mean()*100}).T.rename_axis('column').prep(case='')
#     print(X.shape)
#     display(Y)
#     display(X.head(max_rows))
    # X = X.head(max_rows)
    # display(X)
    # props = {
    #     'text-align': 'center',
    #     'vertical-align': 'top',
    #     'border': '1px dotted black',
    #     'width': 'auto',
    #     'font-size': '16px',
    #     } | props
    # fmt = {'precision': precision, 'hyperlinks': 'html'}
    # # display(X.head(max_rows).reset_index())
    # # display(X)
    # display(X.style
    #     .format(**fmt)
    #     # .format_index(**fmt, axis=0)
    #     # .format_index(**fmt, axis=1)
    #     # .set_table_styles([{'selector':k, 'props':[*props.items()]} for k in ['th','td']])
    #     # .set_table_attributes('style="border-collapse: collapse"')
    # )
    # assert 1==2

def to_numeric(df, case='lower', downcast='integer', errors='ignore', category=False, **kwargs):
    """Convert columns to numeric dtypes where possible, after normalizing text columns.

    Parameters
    ----------
    df : DataFrame
    case : name of a pandas ``.str`` method applied to text columns after stripping
        (e.g. 'lower', 'upper'); a name that is not a ``.str`` method falls back to 'strip'.
    downcast, errors, **kwargs : forwarded to ``pd.to_numeric`` (datetime columns are skipped).
        With ``errors='ignore'`` (the default) a column that cannot be parsed is left unchanged.
    category : if True, remaining string columns become ``category`` dtype.

    Returns a new DataFrame with pandas nullable dtypes and integer columns as ``Int64``.
    """
    def _numeric(s):
        if pd.api.types.is_datetime64_any_dtype(s):
            return s
        if errors == 'ignore':
            # pandas 2.2 deprecated errors='ignore' (later versions remove it), so emulate it:
            # try a strict conversion and keep the column untouched when it fails.
            try:
                return pd.to_numeric(s, downcast=downcast, errors='raise', **kwargs)
            except (ValueError, TypeError):
                return s
        return pd.to_numeric(s, downcast=downcast, errors=errors, **kwargs)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        case = case if case in dir(pd.Series().str) else 'strip'
        return (
            df
            .apply(lambda s: getattr(s.astype('string').str.strip().str,case)() if s.dtype in ['object','string'] else s)  # prep strings
            .apply(_numeric)  # convert to numeric if possible
            .convert_dtypes()  # convert to new nullable dtypes
            .apply(lambda s: s.astype('Int64') if pd.api.types.is_integer_dtype(s) else s.astype('category') if s.dtype=='string' and category else s)
        )

def prep(df, **kwargs):
    """Standardize a DataFrame: apply the to_numeric method and snake_case the column names.

    The index is moved into columns, cleaned the same way as the data, and set back as a
    MultiIndex built from those cleaned columns. kwargs are forwarded to to_numeric.
    """
    def _snake(name):
        # lower-case, trim, and turn spaces/hyphens into underscores; leave non-strings alone
        if not isinstance(name, str):
            return name
        return name.lower().strip().replace(' ', '_').replace('-', '_')

    def _clean(frame):
        return frame.to_numeric(**kwargs).rename(columns=_snake)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        index_cols = _clean(df[[]].reset_index())  # index levels only, as columns
        cleaned = _clean(df).reset_index(drop=True)
        return cleaned.set_index(pd.MultiIndex.from_frame(index_cols))

def groupb(df, by=None, **kwargs):
    """groupby with this project's preferred defaults; any default can be overridden via kwargs.

    Defaults: as_index=True, sort=False (first-seen group order), group_keys=False,
    observed=False, dropna=False (missing keys form their own group).
    The old explicit axis=0 default was dropped: it is already pandas' behavior, the
    keyword is deprecated since pandas 2.1, and later versions reject it.
    """
    defaults = {'level':None,'as_index':True,'sort':False,'group_keys':False,'observed':False,'dropna':False}
    return df.groupby(by, **(defaults | kwargs))

def get_incoming(df):
    """Keep rows with levl_code 'ug' and styp_code in 'n', 'r' or 't'."""
    incoming_filter = "levl_code=='ug' & styp_code in ['n','r','t']"
    return df.query(incoming_filter)

def get_duplicates(df, subset='pidm', quit=True, rows=10):
    """Return the rows whose `subset` key occurs more than once.

    When duplicates exist they are displayed (first `rows` rows) and, if quit=True,
    an Exception reporting their count is raised instead of returning.
    """
    group_sizes = df.groupb(subset, sort=True).transform('size')
    dupes = group_sizes > 1
    if not dupes.any():
        return df[dupes]
    df[dupes].disp(rows)
    if quit:
        raise Exception(f'{dupes.sum()} duplicates detected')
    return df[dupes]

def get_missing(df, rows=-1):
    """Return the percentage of missing values per column, displaying the non-zero ones.

    The displayed table is sorted descending and rounded to 1 decimal; `rows` is passed to disp.
    """
    pct = df.isnull().mean() * 100
    if pct.any():
        nonzero = pct[pct > 0]
        nonzero.sort_values(ascending=False).round(1).disp(rows)
    return pct

def wrap(fcn):
    """Adapt a DataFrame helper so it can also be installed as a pandas Series method.

    The wrapper converts its first argument to a DataFrame before calling `fcn`. When that
    argument was a Series:
    - a DataFrame result is squeezed along the column axis only, so a single-column result
      becomes a Series even when it has one row (a plain squeeze() turned 1x1 into a scalar);
    - a Series result is squeezed as before (a length-1 Series becomes a scalar);
    - any other result (e.g. None or a GroupBy) is returned unchanged.
    DataFrame inputs get the helper's result as-is.
    """
    import functools  # local import: copies __name__, __doc__, __qualname__ onto the wrapper

    @functools.wraps(fcn)
    def wrapper(X, *args, **kwargs):
        df = fcn(pd.DataFrame(X), *args, **kwargs)
        if df is None or not isinstance(X, pd.Series):
            return df
        if isinstance(df, pd.DataFrame):
            return df.squeeze(axis=1)
        if isinstance(df, pd.Series):
            return df.squeeze()
        return df
    return wrapper

# Monkey-patch the helpers above into pandas so they can be called with method syntax
# (df.prep(), s.to_numeric(), ...). Series receive the `wrap`-ped version, which converts
# the Series to a one-column DataFrame before calling the helper.
# (This was previously a bare string literal inside the loop body, i.e. a no-op statement.)
for fcn in [
    disp,
    to_numeric,
    prep,
    get_incoming,
    get_duplicates,
    get_missing,
    groupb,
    ]:
    setmeth(pd.DataFrame, fcn)
    setmeth(pd.Series, wrap(fcn))

############ file i/o functions ############
def get_size(path):
    """Print the disk usage of `path` as reported by ``du -h``.

    The path is passed as its own argument with no shell involved, so names containing
    spaces or shell metacharacters are neither split nor interpreted. du's output is
    captured and re-printed through Python so it appears in the notebook's cell output
    (output from a child process writing straight to the kernel's stdout does not).
    """
    import subprocess  # local import keeps this change self-contained
    result = subprocess.run(['du', '-h', str(path)], capture_output=True, text=True, check=False)
    print(result.stdout + result.stderr, end='')

def rm(path, root=False):
    """Delete a file, or a directory's contents (root=False) / the whole directory (root=True).

    Paths that do not exist are ignored. Returns `path` as a pathlib.Path.
    """
    path = pathlib.Path(path)
    if path.is_file():
        path.unlink()
        return path
    if not path.is_dir():
        return path
    if root:
        shutil.rmtree(path)
    else:
        for child in path.iterdir():
            rm(child, True)
    return path

def mkdir(path):
    """Ensure the directory for `path` exists (creating parents) and return `path` as a Path.

    A path with a suffix is treated as a file, so only its parent directory is created.
    """
    path = pathlib.Path(path)
    target = path.parent if path.suffix else path
    target.mkdir(parents=True, exist_ok=True)
    return path

def reset(path):
    """Clear `path` (delete the file / empty the directory), then recreate its directory.

    Returns the argument unchanged.
    """
    for step in (rm, mkdir):
        step(path)
    return path

def prepr(X):
    """Recursively apply .prep() to every pandas object inside dicts/lists/tuples/sets.

    Containers are rebuilt with the same type; other values are returned unchanged.
    """
    if isinstance(X, (pd.DataFrame, pd.Series)):
        return X.prep()
    if isinstance(X, dict):
        return {key: prepr(val) for key, val in X.items()}
    if isinstance(X, (list, tuple, set)):
        return type(X)(map(prepr, X))
    return X

def dump(path, obj):
    """Clear `path`, write the prepr'd `obj` to it, and return the written object.

    The format follows the suffix: .parquet and .csv go through pd.DataFrame(obj);
    anything else is pickled with the highest protocol.
    """
    path = reset(path)
    obj = prepr(obj)
    suffix = path.suffix
    if suffix == '.parquet':
        # explicit pd.DataFrame works around "Object of type PlanMetrics is not JSON serializable" from to_parquet under pandas 2.2.3
        pd.DataFrame(obj).to_parquet(path)
    elif suffix == '.csv':
        pd.DataFrame(obj).to_csv(path)
    else:
        with open(path, 'wb') as fh:
            pickle.dump(obj, fh, pickle.HIGHEST_PROTOCOL)
    return obj

def load(path):
    """Read a file written by dump: .parquet / .csv via pandas, anything else via pickle.

    NOTE: pickle.load can execute arbitrary code; only use this on files this project wrote.
    """
    path = pathlib.Path(path)
    readers = {'.parquet': pd.read_parquet, '.csv': pd.read_csv}
    reader = readers.get(path.suffix)
    if reader is not None:
        return reader(path)
    with open(path, 'rb') as fh:
        return pickle.load(fh)

def get_desc(code):
    """SQL select items for a code column plus its description from the matching stv<xxxx> table.

    The 4-letter validation name is the first '_'-separated part of `code` with length 4
    (falling back to the last part), e.g. 'coll_code_1' -> coll, 'saradap_resd_code' -> resd.
    """
    parts = code.split('_')
    nm = next((part for part in parts if len(part) == 4), parts[-1])
    return f'{code} as {nm}_code, (select stv{nm}_desc from {catalog}saturnstv{nm} where {code} = stv{nm}_code limit 1) as {nm}_desc'

def coalesce(x, y=False):
    """SQL select item replacing NULLs in column `x` with `y`, keeping the column name `x`."""
    expression = f'coalesce({x}, {y}) as {x}'
    return expression

# Names of the race indicator columns produced by Data.newest (race_hispanic comes from spbpers_ethn_cde).
races = [f'race_{r}' for r in ['asian','black','hispanic','native','pacific','white']]

def prediction(clf, X, y, cross=False):
    """Tabulate predicted probabilities next to the actual labels for a fitted FLAML AutoML.

    Parameters
    ----------
    clf : fitted AutoML (uses predict_proba, .model and .best_loss)
    X, y : features and binary target (y is summed to count positives)
    cross : if True and y has more than one positive, use out-of-fold probabilities from
        cross_val_predict on clf.model with cv = min(10, number of positives);
        otherwise use clf.predict_proba(X) (in-sample).

    Returns a DataFrame indexed like y with columns prediction, actual,
    error (= prediction - actual) and cv_score (clf.best_loss on every row).
    """
    # (the former unused `Z = X.copy()` made a needless full copy of X and was removed)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if cross and y.sum() > 1:
            proba = cross_val_predict(clf.model, X, y, cv=min(10, y.sum()), method='predict_proba')
        else:
            proba = clf.predict_proba(X)
    p = proba.T[1]  # second probability column, i.e. the positive class for a 0/1 or False/True target
    return pd.DataFrame({'prediction': p, 'actual': y, 'error': p-y, 'cv_score':clf.best_loss})
# Install `prediction` as a method so a fitted FLAML model can be called as automl.prediction(X, y)
setmeth(fl.automl.automl.AutoML, prediction)


def custom_log_loss(X_val, y_val, estimator, labels, X_train, y_train, weight_val=None, weight_train=None, config=None, groups_val=None, groups_train=None):
    """FLAML custom metric: log loss with labels pinned to [False, True].

    Some (crse, styp) groups are entirely False, which makes the built-in log_loss fail;
    fixing `labels` avoids that (https://microsoft.github.io/FLAML/docs/Use-Cases/Task-Oriented-AutoML/).
    Returns (val_loss, {"val_loss", "train_loss", "pred_time"}) where pred_time is the
    validation prediction time per row.
    """
    both_labels = [False, True]
    tic = time.time()
    val_proba = estimator.predict_proba(X_val)
    pred_time = (time.time() - tic) / len(X_val)
    val_loss = log_loss(y_val, val_proba, labels=both_labels, sample_weight=weight_val)
    train_proba = estimator.predict_proba(X_train)
    train_loss = log_loss(y_train, train_proba, labels=both_labels, sample_weight=weight_train)
    metrics = {"val_loss": val_loss, "train_loss": train_loss, "pred_time": pred_time}
    return val_loss, metrics

###########################################################################################################################
###########################################################################################################################
@dataclasses.dataclass
class Data():
    """Per-term data pipeline that builds, caches on disk, and reloads the modeling inputs.

    Most get_* methods delegate to `get`, which returns an artifact from memory if it is
    already an attribute, otherwise loads it from its file, otherwise builds and writes it.
    """
    term_code: int
    amp_date: str = ''
    overwrite: set = None  # artifact names to delete and rebuild on their next access
    seed: int = 42

    # Allows self['attr'] and self.attr syntax.
    # NOTE: hasattr also sees methods and fields, so artifact names must not collide with them.
    def __contains__(self, key):
        return hasattr(self, key)
    def __getitem__(self, key):
        return getattr(self, key)
    def __setitem__(self, key, val):
        setattr(self, key, val)
    def __delitem__(self, key):
        if key in self:  # deleting a missing artifact is a no-op
            delattr(self, key)


    def __post_init__(self):
        """Normalize `overwrite`, snap `amp_date` to a Wednesday, and derive the term's dates."""
        self.overwrite = setify(self.overwrite)
        # Because these take about 1 hour each, force user to manually delete drivetimes parquet files to avoid accidental deletion requiring lengthy re-creation
        self.overwrite.discard('drivetimes')
        # dt() takes the earliest date, so amp_date can never be later than today ('' means today)
        self.amp_date = dt(self.amp_date, now)
        # move to the Wednesday (Flags release) of the same Monday-based week: Mon/Tue move forward, Thu-Sun move back
        self.amp_date += pd.Timedelta(days=2-self.amp_date.weekday())
        self.term_code, self.amp_date, self.amp_day, self.stable_date, self.term_desc, self.stem = self.get_dates(self.term_code, self.amp_date)
        # computed after get_dates has int-converted term_code, so a string term code also works
        self.year = self.term_code // 100


    def get_dates(self, term_code, current_date):
        """Return (term_code, current_date, days until stable_date, stable_date, term_desc, stem).

        stem is '<current_date>_<term_code>_<sign><days>' with days zero-padded to 3 digits;
        sign is '-' once current_date is past stable_date.
        """
        term_code = int(term_code)
        term_desc, stable_date = self.get_terms().loc[term_code,['term_desc','stable_date']]
        current_date = dt(current_date).date()
        stable_date = dt(stable_date).date()
        current_day = (stable_date - current_date).days
        stem = f'{current_date}_{term_code}_{"-" if current_day < 0 else "+"}{rjust(abs(current_day),3,0)}'
        return term_code, current_date, current_day, stable_date, term_desc, stem

    def get_dst(self, path, nm, suffix='.parquet'):
        """File path for artifact `nm`.

        Under shr/usr it lives in <amp_date>/<term_code>/<first part of nm>/<stem>_<nm>;
        elsewhere it is simply path/nm. The suffix replaces any existing one.
        """
        if path in [shr,usr]:
            dst = path/f"{self.amp_date}/{self.term_code}/{nm.split('_')[0]}/{self.stem}_{nm}"
        else:
            dst = path/nm
        return dst.with_suffix(suffix)


    def get(self, fcn, path, nm, suffix='.parquet', prereq=(), divide=True, read=True, **kwargs):
        """Return (artifact, created) for artifact `nm`, building it with `fcn` only when needed.

        - If nm is in self.overwrite, its in-memory copy and file are deleted first (once).
        - If nm is already an attribute, it is returned as-is.
        - Else if its file exists, it is loaded (read=True) or recorded as None (read=False).
        - Else each callable in `prereq` runs, fcn(**kwargs) is written to the file with dump
          and kept in memory, and `created` is True.
        """
        dst = self.get_dst(path, nm, suffix)
        if nm in self.overwrite:
            del self[nm]
            reset(dst)
            self.overwrite.remove(nm)

        new = False
        if nm not in self:
            if not dst.exists():
                for f in unique(prereq):
                    f()
                print(f'creating {dst}', end=': ')
                with codetiming.Timer():
                    self[nm] = dump(dst, fcn(**kwargs))
                if divide:
                    print(divider)
                new = True
            elif read:
                self[nm] = load(dst)
            else:
                self[nm] = None
        return self[nm], new
##################################################
################# get drivetimes #################
##################################################
    def get_zips(self):
        """US postal codes from pgeocode's 'us' table, renamed to `zip`, excluding rows with no state or state 'mh'."""
        def fcn():
            df = (
                pgeocode.Nominatim('us')._data  # get all zips
                .prep()
                .rename(columns={'postal_code':'zip'})
                .query("state_code.notnull() & state_code not in [None,'mh']")
            )
            return df
        df, new = self.get(fcn, geo, 'zips')
        return df


    def get_states(self):
        """Set of state codes appearing in get_zips()."""
        return set(self.get_zips()['state_code'])


    def get_drivetimes(self):
        """Drive times in minutes (columns zip, camp_code, drivetime) from each campus to each zip.

        Each ZCTA gets the minimum OSRM drive time over 10 random points inside it; USPS zips
        without a ZCTA borrow the values of the nearest ZCTA (haversine distance).
        Each campus's table (drivetimes_<camp>) is cached separately because it is slow to build.
        """
        def fcn():
            print()
            campus_coords = {  # "lon,lat" as OSRM expects
                's': '-98.215784,32.216217',
                'm': '-97.432975,32.582436',
                'w': '-97.172176,31.587908',
                'r': '-96.467920,30.642055',
                'l': '-96.983211,32.462267',
                }
            url = "https://www2.census.gov/geo/tiger/GENZ2020/shp/cb_2020_us_zcta520_500k.zip"
            gdf = gpd.read_file(url).prep().set_index('zcta5ce20')  # get all ZCTA https://www.census.gov/programs-surveys/geography/guidance/geo-areas/zctas.html
            pts = gdf.sample_points(size=10, method="uniform").explode().apply(lambda g: f"{g.x},{g.y}")  # sample 10 random points in each ZCTA
            M = []
            for k, v in campus_coords.items():
                def fcn1():
                    # query the OSRM table service in batches of 200 destinations; the campus is source 0
                    print()
                    L = []
                    i = 0
                    di = 200
                    I = pts.shape[0]
                    while i < I:
                        u = join([v, *pts.iloc[i:i+di]],';')
                        url = f"http://router.project-osrm.org/table/v1/driving/{u}?sources={0}&annotations=duration,distance&fallback_speed=1&fallback_coordinate=snapped"
                        response = requests.get(url).json()
                        L.append(np.squeeze(response['durations'])[1:]/60)  # drop campus->campus entry; seconds -> minutes
                        i += di
                        print(k,i,round(i/I*100))
                    df = pts.to_frame()[[]]
                    df[k] = np.concatenate(L)
                    return df
                df, new = self.get(fcn1, geo, f'drivetimes_{k}')
                M.append(df)
            # minimum over the sampled points of each ZCTA, then long format (zip, camp_code, drivetime)
            D = pd.concat(M, axis=1).groupb(level=0).min().stack().reset_index().set_axis(['zip','camp_code','drivetime'], axis=1)

            # There are a few USPS zips without equivalent ZCTA, so we assign them drivetimes for the nearest
            Z = self.get_zips().merge(D.query("camp_code=='s'"), how='left').set_index('zip')
            mask = Z['drivetime'].isnull()  # zips without a ZTCA
            Z = Z[['latitude','longitude']]
            X = np.radians(Z[~mask])
            Y = np.radians(Z[mask])
            M = (
                pd.DataFrame(haversine_distances(X, Y), index=X.index, columns=Y.index) # haversine distance between pairs with and without ZCTA
                .idxmin()  # find nearest ZCTA
                .reset_index()
                .set_axis(['new_zip','zip'], axis=1)
                .prep()
                .merge(D)  # merge the drivetimes for that ZCTA
                .drop(columns='zip')
                .rename(columns={'new_zip':'zip'})
            )
            df = pd.concat([D,M], ignore_index=True)
            return df
        df, new = self.get(fcn, geo, 'drivetimes', prereq=self.get_zips)
        return df
########################################################
################# get term information #################
########################################################
    def get_terms(self, show=False):
        """Term calendar indexed by term_code (stvterm joined to its part-of-term '1' row in sobptrm).

        Adds stable_date = the Friday of the week after the census date.
        """
        def fcn():
            qry = f"""
select
    stvterm_code as term_code
    ,replace(stvterm_desc, ' ', '') as term_desc
    ,stvterm_start_date as start_date
    ,stvterm_end_date as end_date
    ,stvterm_fa_proc_yr as fa_proc_yr
    ,stvterm_housing_start_date as housing_start_date
    ,stvterm_housing_end_date as housing_end_date
    ,sobptrm_census_date as census_date
from
    {catalog}saturnstvterm as A
inner join
    {catalog}saturnsobptrm as B
on
    stvterm_code = sobptrm_term_code
where
    sobptrm_ptrm_code='1'
"""
            df = run(qry, show).set_index('term_code')
            df['stable_date'] = df['census_date'].apply(lambda x: x+pd.Timedelta(days=7+4-x.weekday())) # Friday of week following census
            return df
        df, new = self.get(fcn, geo, 'terms')
        return df
#######################################################
############ process flags reports archive ############
#######################################################
    def get_spriden(self, show=False):
        """id -> pidm crosswalk (with names), kept in memory only and never written to disk."""
        # Get id-pidm crosswalk so we can replace id by pidm in flags below
        # GA's should not have permissions to run this because it can see pii
        # NOTE(review): the activity-date window ends at a hard-coded 2025-09-01, so current rows
        # changed after that date are excluded — confirm this cutoff is intended.
        if 'spriden' not in self:
            qry = f"""
            select distinct
                spriden_id as id
                ,spriden_pidm as pidm
                ,spriden_last_name as last_name
                ,spriden_first_name as first_name

            from
                {catalog}saturnspriden as A
            where
                spriden_change_ind is null
                and spriden_activity_date between '2000-09-01' and '2025-09-01'
                and spriden_id REGEXP '^[0-9]+'
            """
            self.spriden = run(qry, show)
        return self.spriden


    def process_flags(self, early_stop=3):
        """Convert raw Flags "admitted" Excel reports into per-term parquet files keyed by pidm.

        Files in flags_raw are visited in reverse-sorted name order; the scan stops after
        `early_stop` consecutive files that wrote nothing new. When anything new is written,
        the affected yearly combined files are deleted and rebuilt by combine_flags.
        """
        # GA's should not have permissions to run this because it can see pii
        counter = 0
        divide = False
        for src in sorted(flags_raw.iterdir(), reverse=True):
            counter += 1
            if counter > early_stop:
                break
            # stem/suffix rather than split('.'), which crashed on names with zero or several dots
            a, b = src.stem.lower(), src.suffix.lower().lstrip('.')
            if b != 'xlsx' or 'melt' in a or 'admitted' not in a:
                print(a, 'SKIP')
                continue
            # Handles 2 naming conventions that were used at different times
            try:
                current_date = pd.to_datetime(a[:10].replace('_','-'))
                multi = True
            except (ValueError, TypeError, OverflowError):
                try:
                    current_date = pd.to_datetime(a[-6:])
                    multi = False
                except (ValueError, TypeError, OverflowError):
                    print(a, 'FAIL')
                    continue
            # the context manager closes the workbook once all of its sheets are processed
            with pd.ExcelFile(src, engine='openpyxl') as book:
                # Again, handles the 2 different versions with different sheet names
                if multi:
                    sheets = {sheet:sheet for sheet in book.sheet_names if sheet.isnumeric() and int(sheet) % 100 in [1,6,8]}
                else:
                    sheets = {a[:6]: book.sheet_names[0]}
                for term_code, sheet in sheets.items():
                    term_code, current_date, current_day, stable_date, term_desc, stem = self.get_dates(term_code, current_date)
                    def fcn():
                        B = book.parse(sheet).prep()
                        B['id'] = B['id'].to_numeric(errors='coerce')  # CRITICAL step - id is stored as string dtype to allow leading 0's, but this opens the door for serious data entry errors (ex: ID="D") which can have catastrophic effects downstream.  This step convert such issues to null, which get removed during the merge below.
                        mask = B['id'].isnull()
                        if mask.any():
                            print(f'WARNING: {mask.sum()} non-numeric ids')
                            B[mask].disp(5)
                        df = (
                            self.get_spriden()[['pidm','id']]
                            .assign(current_date=current_date)
                            .merge(B, on='id', how='inner')
                            .drop(columns=['id','last_name','first_name','mi','pref_fname','street1','street2','primary_phone','call_em_all','email'], errors='ignore')
                        )
                        return df
                    if self.get(fcn, flags_prc, f'{term_code}/{stem}_flags', read=False, divide=False)[1]:
                        divide = True
                        counter = 0
                        dst = flags_prc/f'{term_code//100}_flags.parquet'
                        rm(dst)
                        # also forget any in-memory marker from an earlier combine_flags call; otherwise
                        # get() would think the yearly file exists and never rebuild the file just deleted
                        del self[f'{term_code//100}_flags']
        if divide:
            print(divider)
            self.combine_flags()


    def combine_flags(self):
        """Build <year>_flags.parquet (if missing) by concatenating that year's per-term flags files."""
        def fcn(year):
            # term folders are named by term code; match them on term_code // 100 (as the loop below does)
            # rather than a substring test, which could also match e.g. '2008' inside '202008'
            L = [pd.read_parquet(src) for path in flags_prc.iterdir() if path.is_dir() and int(path.stem) // 100 == year for src in path.glob('*.parquet')]
            df = pd.concat(L, ignore_index=True).prep()
            del L
            for k in ['dob',*df.filter(like='date').columns]:  # convert date columns
                if k in df:
                    df[k] = pd.to_datetime(df[k], errors='coerce')
            return df
        for year in {int(x.stem)//100 for x in flags_prc.iterdir() if x.is_dir()}:
            self.get(fcn, flags_prc, f'{year}_flags', read=False, year=year)


    def get_flags(self):
        """Latest flags row per (pidm, term_code) on or before amp_date from this year's combined file.

        Zips whose state is not in get_states() are nulled; the rest are cut to their first
        5 characters before any '-' and converted to numbers (non-numeric -> missing).
        """
        def fcn():
            df = (
                pd.read_parquet(flags_prc/f'{self.year}_flags.parquet')
                .query(f"current_date<='{self.amp_date}'")
                .sort_values(['pidm','current_date'])
                .drop_duplicates(subset=['pidm','term_code'], keep='last')
            )
            df.loc[~df['state'].isin(self.get_states()),'zip'] = pd.NA
            # astype('string') because prep may already have made zip numeric, and .str would then fail
            df['zip'] = df['zip'].astype('string').str.split('-', expand=True)[0].str[:5].to_numeric(errors='coerce')
            return df
        df, new = self.get(fcn, shr, 'flags', prereq=[self.combine_flags, self.get_states])
        return df
##########################################
############ get student data ############
##########################################
    def newest(self, qry, prt, tbl='', sel=''):
        """Build SQL returning, per record, its most recent snapshot as of amp_date plus student attributes.

        The OPEIR daily snapshot experienced occasional glitches causing incomplete copies.
        Consequently, records can have a "gap" where they vanish then reappear later. This function fixes this issue.

        Parameters
        ----------
        qry : FROM/WHERE body selecting the snapshot rows (must expose current_date and the `prt` columns)
        prt : column(s) identifying one record across snapshots
        tbl : table re-joined (using prt and current_date) to fetch the remaining columns; defaults to qry
        sel : extra select expressions appended to the output
        """
        prt = join(prt, ', ')
        if tbl == '':
            tbl = qry
        if sel != '':
            sel = ','+join(sel)

        qry = f"""
select
    {prt}
    ,current_date
    ,min(current_date) over (partition by {prt}) as first_date  --first date this record appeared
    ,max(current_date) over (partition by {prt}) as last_date  --last date this record appeared
    ,least(timestamp('{self.amp_date}'), max(current_date) over ()) as amp_date  --clip amp_date to last date of ANY record
    --,least(greatest(timestamp('{self.amp_date}'), min(current_date) over ()), max(current_date) over ()) as amp_date  --clip amp_date between first & last date of ANY record
from
    {qry.strip()}
qualify
    amp_date between first_date and dateadd(last_date, 5)  -- keep records where amp_date falls between that record's first & last appearance (+5 days in case we are in a gap right now - the record will reappear but has not yet done so
"""

        qry = f"""
select
    *
from {subqry(qry)}
where
    --current_date <= '{self.amp_date}'  -- discard records after amp_date
    current_date <= amp_date  -- discard records after amp_date
qualify
    row_number() over (partition by {prt} order by current_date desc) = 1  -- keep most recent remaining record
"""

        # natn_desc looks the visa nation code up by stvnatn_code (stvnatn_nation is the description column)
        qry = f"""
select distinct
    pidm
    ,first_date
    ,current_date
    ,amp_date
    ,last_date
    ,{get_desc('term_code')}
    ,{get_desc('levl_code')}
    ,{get_desc('styp_code')}
    ,{get_desc('camp_code')}
    ,{get_desc('coll_code_1')}
    ,{get_desc('dept_code')}
    ,{get_desc('majr_code_1')}
    --,gender
    ,spbpers_sex as gender
    ,birth_date
    ,{get_desc('spbpers_lgcy_code')}
    ,gorvisa_vtyp_code is not null as international
    ,gorvisa_natn_code_issue as natn_code, (select stvnatn_nation from {catalog}saturnstvnatn where gorvisa_natn_code_issue = stvnatn_code limit 1) as natn_desc
    ,{coalesce('race_asian')}
    ,{coalesce('race_black')}
    ,coalesce(spbpers_ethn_cde=2, False) as race_hispanic
    ,{coalesce('race_native')}
    ,{coalesce('race_pacific')}
    ,{coalesce('race_white')}
    {indent(sel)}
from {subqry(qry)} as A

left join
    {tbl}
using
    ({prt}, current_date)

left join
    {catalog}spbpers_v
on
    pidm = spbpers_pidm

left join (
    select
        *
    from
        {catalog}generalgorvisa
    qualify
        row_number() over (partition by gorvisa_pidm order by gorvisa_seq_no desc) = 1
    )
on
    pidm = gorvisa_pidm

left join (
    select
        gorprac_pidm
        ,max(gorprac_race_cde='AS') as race_asian
        ,max(gorprac_race_cde='BL') as race_black
        ,max(gorprac_race_cde='IN') as race_native
        ,max(gorprac_race_cde='HA') as race_pacific
        ,max(gorprac_race_cde='WH') as race_white
    from
        {catalog}generalgorprac
    group by
        gorprac_pidm
    )
on
    pidm = gorprac_pidm
"""
        return qry


    def get_registrations(self, show=False):
        """Registrations as of amp_date: one row per (pidm, crse_code) joined to each student's latest attributes.

        Course rows have count 1; the synthetic crse_code '_tot_sch' holds total credit hours,
        and '_headcnt' / '_proba' hold 1 per registered student. An empty placeholder is
        returned when the term's registration table does not exist.
        """
        def fcn():
            # previous source: f'dev.opeir.opeirregistration_{self.term_desc}'
            tbl = f'dev.opeir.registration_{self.term_desc}_v'
            if spark.catalog.tableExists(tbl):
                qry = self.newest(
                    tbl = tbl,
                    prt = ['pidm','crn'],
                    sel = ['credit_hr as count', 'subj_code || crse_numb as crse_code'],
                    qry = f"""
    {tbl} as A
where
    credit_hr > 0
    and subj_code <> 'INST'""")
                A = run(qry, show)
                B = A.groupb(['pidm','crse_code'])['count'].sum().reset_index('crse_code')
                C = B.groupb('pidm')[['count']].sum().assign(crse_code='_tot_sch')
                D = C.copy()
                B['count'] = 1
                D['count'] = 1
                D['crse_code'] = '_headcnt'
                E = D.copy()
                E['crse_code'] = '_proba'
                F = pd.concat([B,C,D,E])
                G = A.drop(columns=['count','crse_code']).sort_values('current_date').groupb('pidm', sort=True).last()
                df = G.join(F).sort_index()
            else:
                # placeholder if table DNE
                df = pd.DataFrame(columns=['pidm','levl_code','styp_code','count','crse_code']).set_index('pidm')
            df.get_duplicates(['pidm','crse_code'])
            return df
        df, new = self.get(fcn, shr, 'registrations')
        return df


    def get_admissions(self, show=False):
        """Accepted applications (summer and fall of self.year) as of amp_date.

        Keeps only pidms whose applications are all at level 'ug', and only each pidm's
        highest appl_no; raises if any pidm is still duplicated.
        """
        def fcn():
            def fcn1(season):
                # previous source: f'dev.opeir.opeiradmissions_{season}{self.year}'
                tbl = f'dev.opeir.admissions_{season}{self.year}_v'
                return self.newest(
                    tbl = tbl,
                    prt = ['pidm', 'appl_no'],
                    sel = [
                        'appl_no',
                        get_desc('apst_code'),
                        get_desc('apdc_code'),
                        get_desc('admt_code'),
                        get_desc('saradap_resd_code'),
                        'hs_percentile',
                    ],
                    qry = f"""
    {tbl} as A
inner join
    {catalog}saturnstvapdc as B
on
    apdc_code = stvapdc_code
where
    stvapdc_inst_acc_ind is not null  --only accepted""")
            L = [run(fcn1(season), show) for season in ['summer','fall']]
            df = pd.concat(L, ignore_index=True)
            mask = df.groupb('pidm').apply(lambda x: (x['levl_code']=='ug').all() & (x['appl_no']==x['appl_no'].max()))
            df = df.loc[mask]
            df.get_duplicates()
            return df
        df, new = self.get(fcn, shr, 'admissions')
        return df


    def get_students(self):
        """One row per admitted student (indexed by pidm): admissions + flags + drive times + derived features.

        Raises if more than 10 incoming students have no matching flags record.
        """
        def fcn():
            df = (
                self.get_admissions()
                .merge(self.get_flags(), on=['pidm','term_code'], how='left', suffixes=['', '_drop'])
                .merge(self.get_drivetimes(), on=['zip','camp_code'], how='left', suffixes=['', '_drop'])
                .prep()
            )
            for c in ['gap_score','t_gap_score','ftic_gap_score']:
                if c not in df:
                    df[c] = pd.NA
            # styp 'n' prefers ftic_gap_score, everyone else t_gap_score; gap_score is the last fallback
            df['gap_score'] = np.where(
                df['styp_code']=='n',
                df['ftic_gap_score'].combine_first(df['t_gap_score']).combine_first(df['gap_score']),
                df['t_gap_score'].combine_first(df['ftic_gap_score']).combine_first(df['gap_score']))
            df['oriented'] = df['orientation_hold_exists'].isnull() | df['orien_sess'].notnull() | df['registered'].notnull()
            df['verified'] = df['selected_for_ver'].isnull() | df['ver_complete'].notnull()
            # linear rescale of the SAT total (590-1600) onto the ACT range (9-36)
            df['sat10_total_score'] = (36-9) / (1600-590) * (df['sat10_total_score']-590) + 9
            df['act_equiv'] = df[['act_new_comp_score','sat10_total_score']].max(axis=1)
            df['eager'] = (dt(self.stable_date) - df['first_date']).dt.days  # days between first application snapshot and stable_date
            df['age'] = (dt(self.stable_date) - df['birth_date']).dt.days / 365
            for k in ['reading', 'writing', 'math']:
                df[f'tsi_{k}'] = ~df[k].isin(['not college ready', 'retest required', pd.NA, None, np.nan])
            repl = {'ae':0, 'n1':1, 'n2':2, 'n3':3, 'n4':4, 'r1':1, 'r2':2, 'r3':3, 'r4':4}
            # percentile bands -> 4 (below 25) ... 0 (90 and above); falls back to the apdc_code mapping
            # NOTE(review): admissions selects hs_percentile; confirm hs_pctl is supplied by the flags file
            df['hs_qrtl'] = pd.cut(df['hs_pctl'], bins=[-1,25,50,75,90,101], labels=[4,3,2,1,0], right=False).combine_first(df['apdc_code'].map(repl))
            df['lgcy'] = ~df['lgcy_code'].isin(['o',pd.NA,None,np.nan])
            df['resd'] = df['resd_code'] == 'r'
            for k in ['waiver_desc','fafsa_app','ssb_last_accessed','finaid_accepted','schlship_app']:
                df[k.split('_')[0]] = df[k].notnull()
            df.get_duplicates()
            # incoming students without a flags match (flags' current_date came through as current_date_drop)
            M = df.get_incoming().query('current_date_drop.isnull()')
            if not M.empty:
                print(M.shape)
                display(M.sort_values('first_date').head(50))
                if M.shape[0] > 10:
                    raise Exception('Too many unmatched students')
            return df.loc[:, ~df.columns.str.contains('_drop')].prep().set_index('pidm').sort_index()
        df, new = self.get(fcn, shr, 'students', prereq=[self.get_drivetimes, self.get_flags, self.get_admissions])
        return df
###########################################################################################################################
###########################################################################################################################
@dataclasses.dataclass
class Model(Data):
    """One admissions cycle (``term_code``) of the AMP pipeline.

    Each ``get_*`` method builds its result in a local ``fcn`` and hands it to
    ``self.get`` (defined on ``Data``), which presumably caches it to a file
    under ``shr``/``usr`` (note the '.pkl' suffixes).
    Per-course results are keyed by ``self.crse_code`` (see ``set_crse_code``).

    Fields:
        is_learner: train FLAML learners on this term (AMP passes False for
            terms that are not earlier than the forecast term).
        submodels: column(s) splitting students into separately modeled
            subpopulations (passed to ``groupb``).
        aggregates: columns forecasts are summarised over; ``__post_init__``
            removes ``submodels`` and always includes ``'crse_code'``.
        features: ``{column: fill value}``; keys are the model features, values
            are used in ``fillna`` before MICE imputation.
        flaml: overrides merged into the default ``flaml.AutoML`` settings.
    """
    is_learner: bool = True
    submodels: object = 'styp_desc'
    aggregates: list = None
    features: dict = None
    flaml: dict = None


    def __post_init__(self):
        super().__post_init__()
        self.aggregates = unique('crse_code', difference(self.aggregates, self.submodels))
        self.flaml = self.flaml if self.flaml is not None else dict()
        # snapshot of this term as of stable_date; supplies the target ('stable') and the 'actual' counts
        self.stable = Data(amp_date=self.stable_date, **{k:self[k] for k in ['term_code','overwrite']})


    def get_imputed(self):
        """Return ``{subpop: imputed feature frame}`` for incoming students.

        Values are first filled from ``self.features``. Whatever is still
        missing (features whose fill value is itself NA) is imputed with
        10 miceforest iterations seeded by ``self.seed``.
        """
        def fcn():
            def fcn1(subpop, df):
                X = df.fillna(self.features)[self.features.keys()].prep(category=True)
                # the index is dropped for miceforest and restored on the completed data
                imp = mf.ImputationKernel(X.reset_index(drop=True), random_state=self.seed)
                imp.mice(10)
                XX = imp.complete_data().set_index(X.index)
                XX.get_missing()
                return XX
            return {subpop: fcn1(subpop, df) for subpop, df in self.get_students().get_incoming().groupb(self.submodels)}
        dct, new = self.get(fcn, usr, 'imputed', '.pkl', prereq=self.get_students)
        return dct


    def set_crse_code(self, crse_code='_headcnt'):
        """added to allow a single Model to run many crse_codes (fix out of memory issues)

        Sets ``self.crse_code``. If a generic step name ('prepared', 'learners',
        'predictions', 'forecasts') is in ``overwrite``, it also adds its
        ``<step>_<crse_code>`` variant. Finally it deletes every attribute whose
        name contains one of those step names, so they are rebuilt for the new course.
        """
        self.crse_code = crse_code
        J = {'prepared','learners','predictions','forecasts'}
        self.overwrite |= {f"{j}_{crse_code}" for j in J.intersection(self.overwrite)}
        for k in {k for j in J for k in self.__dict__.keys() if j in k}:
            del self[k]


    def get_prepared(self):
        """Return ``{subpop: [X, y]}`` for ``self.crse_code``.

        ``X`` is the imputed feature frame joined with two registration columns
        from this term's snapshot:
        - ``_tot_sch``: the ``'_tot_sch'`` count, filled with 0.
        - ``current``: whether the student is registered in ``self.crse_code``, filled with False.

        ``y`` is ``stable``: whether the student is registered in
        ``self.crse_code`` in the ``self.stable`` snapshot, filled with False.
        """
        def fcn():
            g = lambda obj=self, crse=self.crse_code, nm=None: obj.get_registrations().query(f"crse_code=='{crse}'")['count'].rename(nm if nm is not None else crse)
            Z = {subpop: X
                 .join(g(crse='_tot_sch'))
                 .join(g(nm='current').astype('boolean'))
                 .join(g(nm='stable', obj=self.stable).astype('boolean'))
                .fillna({'_tot_sch':0, 'current':False, 'stable':False})
                for subpop, X in self.get_imputed().items()}
            for z in Z.values():
                z.get_duplicates()  # called only for its side effect; result discarded
            prepared = dict()
            for subpop, X in Z.items():
                y = X.pop('stable')  # target column; removed from X in place
                prepared[subpop] = [X, y]
            return prepared
        dct, new = self.get(fcn, usr, f'prepared_{self.crse_code}', '.pkl', prereq=[self.get_imputed,self.get_registrations,self.stable.get_registrations])
        return dct


    def get_learners(self):
        """Train one FLAML AutoML classifier per subpopulation for ``self.crse_code``.

        Returns ``{subpop: fitted AutoML}``, or None when ``is_learner`` is False.
        The defaults below (30s budget, 3-fold CV, xgboost only, custom log-loss
        metric) are overridden by ``self.flaml``.
        TODO: biggest bottleneck - can we run multiple (crse_code, year) in parallel?
        """
        def fcn():
            def fcn1(subpop, Z):
                log_file = self.get_dst(usr, f'learners_{self.crse_code}_{subpop}', suffix='.log')
                reset(log_file)
                dct = {
                    'time_budget':30,
                    # 'max_iter': 100,
                    'task':'classification',
                    'log_file_name': log_file,
                    'log_type': 'all',
                    'log_training_metric':True,
                    'verbose':0,
                    'metric':custom_log_loss,
                    'eval_method':'cv',
                    'n_splits':3,
                    'seed':self.seed,
                    # 'early_stop':True,
                    'estimator_list': ['xgboost'],
                } | self.flaml
                learner = fl.AutoML(**dct)
                learner.fit(*Z, **dct)  # Z is [X, y] from get_prepared
                return learner
            return {subpop: fcn1(subpop, Z) for subpop, Z in self.get_prepared().items()}
        if self.is_learner:
            clf, new = self.get(fcn, usr, f'learners_{self.crse_code}', '.pkl', prereq=[self.get_prepared,self.get_enrollments])
        else:
            clf = None
        return clf


    def get_predictions(self, models=None):
        """Generates predictions when models dict is passed
        Otherwise, reads predictions from self or file, throwing error if neither exists (NoneType' object has no attribute 'items)
        run_prediction uses this to quickly load existing predictions and only trigger model training on error

        Each learner model (``models`` values with ``is_learner``) predicts this
        term's prepared data, subpopulation by subpopulation. Subpops with no
        students in this term are skipped. ``cross`` is True when predicting the
        model's own term. Each row carries the model's ``_headcnt`` eager-to-final
        multiplier ``mlt``. The result is indexed by
        (crse_code, prediction_term_code, model_term_code, pidm); it is an empty
        DataFrame if nothing was predicted.
        """
        def fcn():
            items = models.items()  # AttributeError when models is None -- intentional, see docstring
            prepared = self.get_prepared()
            L = [
                learner.prediction(*prepared[subpop], cross=self.term_code==model.term_code)
                .assign(
                    crse_code=self.crse_code,
                    prediction_term_code=self.term_code,
                    model_term_code=model.term_code,
                    mlt=model.get_enrollments()['crse_code'].loc['_headcnt'].loc[subpop]['mlt']
                ).reset_index().set_index(['crse_code','prediction_term_code','model_term_code','pidm'])
                for term_code, model in items if model.is_learner
                for subpop, learner in model.get_learners().items()
                if subpop in prepared]  # this term may have no students in a subpop the model saw
            return pd.concat(L).sort_index() if len(L)>0 else pd.DataFrame()
        df, new = self.get(fcn, usr, f'predictions_{self.crse_code}', prereq=self.get_prepared)
        return df


    def get_enrollments(self):
        """Return ``{agg: frame}`` of ``current``, ``actual`` and ``mlt`` counts.

        Rows are grouped by crse_code, submodels and ``agg``.
        - ``current``: summed ``count`` of this term's incoming students, joined
          with the stable registrations.
        - ``actual``: summed ``count`` from the stable registrations, joined with
          the stable students.
        - ``mlt = actual / current``: used to inflate predictions.
        NOTE(review): ``current == 0`` gives ``mlt`` = inf (or NaN when ``actual`` is also 0).
        """
        def fcn():
            def fcn1(agg):
                grp = unique('crse_code', self.submodels, agg)
                g = lambda X, Y: X.join(Y, rsuffix='_y').get_incoming().groupb(grp)['count'].sum()  # get stuff from Y that is not in X
                df = pd.DataFrame({
                    'current':g(self.get_students(), self.stable.get_registrations()),
                    'actual' :g(self.stable.get_registrations(), self.stable.get_students()),
                    }).fillna(0)
                df['mlt'] = df['actual'] / df['current']
                return df.sort_index()
            return {agg: fcn1(agg) for agg in self.aggregates}
        dct, new = self.get(fcn, usr, 'enrollments', '.pkl', prereq=[self.get_students,self.stable.get_students,self.stable.get_registrations])
        return dct


    def get_forecasts(self):
        """Return ``{agg: forecast frame}`` summing ``mlt``-inflated predictions.

        Forecasts are summed per aggregate, prediction term and model term.
        For learner terms the frame also gets:
        - ``actual``: the actual headcount.
        - ``error``: prediction - actual.
        - ``error_pct``: error / actual * 100.
        - ``cv_score``: mean cross-validation score, times 100.
        """
        def fcn():
            dct = {'prediction':'sum'}
            err = [
                'cv_score',
                # 'mse',
                # 'mae',
                # 'log_loss',
                ]
            if self.is_learner:
                dct |= {k:'mean' for k in err}
            Z = self.get_students().join(self.get_predictions(), how='inner').copy()
            Z['prediction'] = Z['prediction'] * Z['mlt']  # inflate eager-student probabilities
            def fcn1(agg):
                df = Z.groupb(unique('crse_code',self.submodels,agg,'prediction_term_code','model_term_code')).agg(dct)
                df['prediction'] = df['prediction'].round()
                if self.is_learner:
                    df = df.join(self.get_enrollments()[agg]['actual']).fillna(0)
                    df['error'] = df['prediction'] - df['actual']
                    df['error_pct'] = df['error'] / df['actual'] * 100  # inf/NaN when actual == 0
                    df[err] *= 100
                    df = df[['prediction','actual','error','error_pct',*err]]
                return df.prep()
            return {agg: fcn1(agg) for agg in self.aggregates}
        dct, new = self.get(fcn, usr, f'forecasts_{self.crse_code}', '.pkl')
        return dct
###########################################################################################################################
###########################################################################################################################
@dataclasses.dataclass
class AMP(Model):
    """Admitted Matriculation Projections driver.

    Builds one ``Model`` per term in ``model_term_codes``. ``term_code`` is
    always included, and only earlier terms are learners. The driver then
    runs predictions for every course in ``crse_codes`` and writes the Excel
    reports (summary, details, em).
    """
    model_term_codes: tuple = (202408,202508)
    crse_codes: tuple = '_headcnt'

    def __post_init__(self):
        print(self.amp_date)
        self.amp_date = dt(self.amp_date, now)
        # move back to the preceding Wednesday (flags release); (weekday-2)%7 is always 0..6
        self.amp_date -= pd.Timedelta(days=(self.amp_date.weekday()-2)%7)
        super().__post_init__()
        self.crse_codes = unique(self.crse_codes, sort=True)
        self.model_term_codes = unique(self.term_code, self.model_term_codes, sort=True, reverse=True)
        kwargs = {k: copy.deepcopy(self[k]) for k in ['features','submodels','aggregates','flaml','overwrite']}
        self.models = {
            term_code: Model(
                term_code=term_code,
                is_learner=term_code<self.term_code,
                # Shift by whole years relative to the forecast term, so every model
                # sees the same point in its own admissions cycle (even when amp_date
                # falls in the calendar year before term_code). DateOffset maps
                # Feb 29 to Feb 28 instead of raising like Timestamp.replace.
                amp_date=self.amp_date + pd.DateOffset(years=term_code//100 - self.term_code//100),
                **kwargs)
            for term_code in self.model_term_codes}
    

    def get_students(self):
        """Incoming students of every model, indexed by (pidm, prediction_term_code); cached on self."""
        nm = 'students'
        if nm not in self:
            self[nm] = pd.concat([model.get_students().get_incoming().assign(prediction_term_code=term_code) for term_code, model in self.models.items()]).prep().set_index('prediction_term_code', append=True).sort_index()
        return self[nm]


    def run_predictions(self, crse_code='_headcnt'):
        """Return all models' predictions for ``crse_code``, training only when needed.

        It first tries to read each model's saved predictions. If that fails
        (``get_predictions`` without ``models`` errors when nothing is saved),
        it prepares data, trains learners and predicts with every model. For
        ``'_proba'`` the race, gender and international features are dropped
        before training.
        """
        for model in self.models.values():
            model.set_crse_code(crse_code)
        try:
            L = [model.get_predictions() for model in self.models.values()]
            print(f'{crse_code} predictions read from file')
        except Exception:  # not bare: let KeyboardInterrupt/SystemExit propagate
            for model in self.models.values():
                model.get_prepared()
                if crse_code == '_proba':
                    model[f'prepared_{crse_code}'] = {subpop: [X.drop([*races,'gender','international'], errors='ignore'), y] for subpop, [X,y] in model.get_prepared().items()}
                model.get_learners()
            L = [model.get_predictions(self.models) for model in self.models.values()]
        rm(usr/'prepared')
        return pd.concat(L).sort_index()


    def get_results(self):
        """Return ``{'predictions'|agg: frame}`` stacked over all crse_codes (except '_proba') and models.

        Index levels whose names contain 'term_code' are sorted descending; all
        other levels are sorted ascending.
        """
        def fcn():
            print()
            results = dict()
            for crse_code in self.crse_codes:
                if crse_code != '_proba':
                    self.run_predictions(crse_code)
                    for model in self.models.values():
                        dct = {'predictions': model.get_predictions(), **model.get_forecasts()}
                        for key, val in dct.items():
                            results.setdefault(key, []).append(val)
            return {agg: pd.concat(L).sort_index(ascending=['term_code' not in k for k in L[0].index.names]) for agg, L in results.items()}
        dct, new = self.get(fcn, usr, 'results', '.pkl')
        return dct


    def get_AMP(self):
        """Write the AMP_summary, AMP_details and AMP_em workbooks plus the key-indexed prediction/student files."""
        # NOTE: the year-specific text below (2022-2025, 202408, 202508) is hard-coded and must be updated each cycle
        instructions = pd.DataFrame({"":[
            f"Admitted Matriculation Projections (AMP) for {self.amp_date}",
            '',
            f'''Executive Summary''',
            f'''AMP is a predictive model designed to forecast the incoming (not continuing) Fall cohort to help leaders proactively allocate resources (instructors, sections, labs, etc) in advance.''',
            f'''It is intended to supplement, not replace, the knowledge, expertise, and insights developed by institutional leaders through years of experience.''',
            f'''Like all AI/ML models (and humans), AMP is fallible and should be used with caution.''',
            f'''AMP learns exclusively from historical data captured in EM’s Flags reports and IDA’s daily admissions and registration snapshots.''',
            f'''It cannot account for factors not present in these datasets, including curriculum changes, policy shifts, structural changes, demographic variation, changes in oversight, etc.''',
            f'''''',
            f'''AMP provides both “Summary” and “Details” files. For most users, rows in the “Summary” file with model_term_code = 202408 will suffice.''',
            f'''Because AMP’s accuracy varies across courses, the “Details” file includes historical error analyses to help users assess the reliability of each forecast (details below).''',
            '',
            f'''As widely requested, AMP includes predictions for the Fall 2025 cohort in Ft. Worth, despite having no prior Ft. Worth FTIC example to learn from.''',
            f'''These are a good-faith effort to offer my best data-driven insights, but due to the lack of training data,''',
            f'''they are inherently more speculative and should be treated with lower confidence (details below).''',
            f'''''',
            f'''Definitions''',
            f'''crse_code = course code (_headcnt = total headcount)''',
            f'''styp_desc = student type; returning = re-enrolling after a previous attempt (not continuing)''',
            f'''prediction_term_code = cohort being forecast''',
            f'''model_term_code = cohort used to train AMP''',
            f'''prediction = forecast headcount''',
            f'''*actual = true headcount''',
            f'''*error = prediction - actual''',
            f'''*error_pct = error / actual * 100''',
            f'''*cv_score = average validation log-loss from 3-fold cross-validation''',
            f'''*=appears only in “Details” & not available for 2025 (since actuals are not yet known)''',
            '',
            f'''Methodology''',
            f'''AMP uses XGBoost, a machine learning algorithm, to forecast the number, characteristics, and likely course enrollments of incoming Fall students.''',
            f'''Predictions are based on application and pre-semester engagement (orientation, course registration, financial aid, etc.) from EM’s Flags and IDA’s daily snapshots.''',
            f'''For each student admitted for Fall 2025, AMP identifies similar students from past Fall cohorts, analyzes their course enrollments (if any),''',
            f'''learns relevant patterns, then forecasts Fall 2025 course enrollment for the admitted student in question.''',
            f'''More precisely, for each (incoming student, course)-pair, AMP assigns a probability whether that student will be enrolled in that course on the Friday after Census.''',
            f'''These (student, course)-level probabilities are then aggregated in many different ways to forecast headcounts for courses, campuses, majors, colleges, TSI statuses, etc.''',
            f'''These appear on different sheets in this workbook.''',
            f'''''',
            f'''Since admissions and registration data evolve through the spring and summer, AMP is trained only on data available as of the same date in previous years.''',
            f'''AMP's forecast for Ft. Worth's Fall 2025 cohort are necessarily based on previous Stephenville cohorts since no Ft. Worth FTIC's existed on this date.''',
            f'''Suppose AMP predicts, "Based on similar FTIC's in Stephenville in 2024, I predict Alice has a 75% probability to matriculate in Fall 2025".''',
            f'''If Alice is applying to Ft. Worth, then 0.75 is added to Ft. Worth's forecast.''',
            f'''However, AMP can not yet understand how to adjust its 75% projection to reflect how Ft. Worth FTIC's behave differently than Stephenville FTIC since there are no Ft. Worth FTIC's to learn from.''',
            f'''Though not ideal, this is the best idea we've found to forecast Ft. Worth FTIC in the absence of valid training examples.''',
            f'',
            f'''AMP is trained separately using the Fall 2024, 2023, and 2022 cohorts.''',
            f'''Most users should focus on prediction_term_code = 202508 and model_term_code = 202408, as Fall 2025 is likely to resemble Fall 2024 more closely than Fall 2023 & Fall 2022.''',
            f'''Users with domain expertise may choose to incorporate older cohorts (e.g., weighted average of model_term_codes 2024, 2023, & 2022) if they believe those terms are similarly relevant.''',
            f'',
            f'''Rows for prediction_term_code < 202508 appear only in the “Details” file and include retrospective "predictions" and actual outcomes.''',
            f'''This allows users to assess AMP's ability to forecast each individual course and calibrate their confidence accordingly.''',
            f'',
            f'''Predictions for small values are less reliable than for large numbers (Central Limit Theorem).''',
            f'',
            f'''AMP only models students who have already applied and been admitted (eager).''',
            f'''However, more students will apply between now and start of term, especially transfer & returning (lagging).''',
            f'''AMP generates forecasts based on eager students then inflates using the eager-lagging ratio from that model_term_code.''',
            f'''This assumes the eager-lagging behavior will be approximately the same this year.''',
            f'''While this assumption cannot be verified in advance, we must make SOME assumption. This one has proven sufficiently accurate in past cycles.''',
            f'',
            f'''Dr. Scott Cook is eager to provide as much additional detail as the user desires: scook@tarleton.edu.''',
            f'''source code: https://github.com/drscook/admitted_matriculation_predictor'''
        ]}).set_index("")

        def format_xlsx(sheet):
            """Add an autofilter, left-align the header row, size columns to their content and freeze the header."""
            from openpyxl.styles import Alignment
            sheet.auto_filter.ref = sheet.dimensions
            for cell in sheet[1]:
                cell.alignment = Alignment(horizontal="left")
            for column in sheet.columns:
                # +3 on the header cell, presumably room for the autofilter arrow
                width = 1+max(len(str(cell.value))+3*(i==0) for i, cell in enumerate(column))
                # Excel's maximum column width is 255 characters
                sheet.column_dimensions[column[0].column_letter].width = min(width, 255)
            sheet.freeze_panes = "A2"

        def write_xlsx(rpt, sheets):
            """Write ``sheets`` ({sheet name: frame}) to AMP_<rpt>.xlsx unless it exists and is not in ``overwrite``."""
            file = f'AMP_{rpt}'
            dst = self.get_dst(usr, file, '.xlsx')
            if file in self.overwrite or not dst.exists():
                print(f'creating {dst.name}')
                # build the workbook in a local temp file, then copy it to dst
                src = "report.xlsx"
                reset(src)
                try:
                    with pd.ExcelWriter(src, mode="w", engine="openpyxl") as writer:
                        for sheet_name, df in sheets.items():
                            df.reset_index().to_excel(writer, sheet_name=sheet_name, index=False)
                            format_xlsx(writer.sheets[sheet_name])
                    reset(dst)
                    shutil.copy(src, dst)
                finally:
                    rm(src)  # never leave the temp workbook behind, even on failure
        
        self.get_results()

        # unique row key: pidm followed by the 6-digit prediction_term_code
        get_key = lambda df: df.reset_index().assign(key=lambda X: X['pidm']*1000000+X['prediction_term_code']).prep().set_index('key').sort_index()
        R = self.get(lambda: get_key(self.get_results()['predictions']), usr, 'AMP_predictions')[0]
        S = self.get(lambda: get_key(self.get_students()), usr, 'AMP_students')[0]

        for rpt, select in {
            'summary':lambda df: df.query(f"prediction_term_code=={self.term_code}").iloc[:,:1],
            'details':lambda df: df,
            }.items():
            sheets = {'instructions':instructions} | {agg: select(df).round().prep() for agg, df in self.get_results().items() if agg!='predictions'}
            write_xlsx(rpt, sheets)

        A = self.get_spriden().set_index('pidm')
        B = self.get_students().loc[:,'term_code':]
        # '_proba' predictions from the model one year earlier (term_code-100)
        C = self.run_predictions('_proba').loc[('_proba',self.term_code,self.term_code-100),'prediction']
        Z = A.join(B, how='inner').join(C, how='inner').prep().sort_values('prediction', ascending=False)
        sheets = {'instructions':instructions} | {subpop: df for subpop, df in Z.groupb(self.submodels)}
        write_xlsx('em', sheets)

    # def get_dashboard(self):
    #     # self.get_results()
    #     get_key = lambda df: df.reset_index().assign(key=lambda X: X['pidm']*1000000+X['prediction_term_code']).prep().set_index('key').sort_index()
    #     self.get(lambda: get_key(self.get_results()['predictions']), usr, 'AMP_predictions')[0]
    #     self.get(lambda: get_key(self.get_students()), usr, 'AMP_attributes')[0]


# Fall 2025 AMP run: configure the driver, then build every report.
# To force regeneration of a cached step, uncomment its name in `overwrite`.
self = AMP(
    term_code=202508,
    model_term_codes=[202208, 202308, 202408],
    flaml={"time_budget": 60},
    # whitespace-separated course codes; '_headcnt' = total headcount, '_proba' = EM probability report
    crse_codes=set("""
        _headcnt _proba
        agec2317
        ansc1119 ansc1319
        anth2302 anth2351
        arts1301 arts1303 arts1304 arts3331
        biol1305 biol1406 biol1407 biol2401 biol2402
        busi1301 busi1307
        chem1111 chem1112 chem1302 chem1311 chem1312 chem1407 chem1409
        cnst1301
        comm1311 comm1315 comm2302
        crij1301
        dram1310 dram2361
        easc2310
        econ1301 econ2301
        educ1301
        engl1301 engl1302 engl2307 engl2320 engl2321 engl2326 engl2340 engl2350
        engl2360 engl2362 engl2364 engl2366 engl2368
        engr1211 engr2303
        envs1302
        fina1360
        geog1303 geog1320 geog1451 geog2301
        geol1403 geol1404 geol1407 geol1408
        govt2305 govt2306
        hist1301 hist1302 hist2321 hist2322
        huma1315
        kine1301 kine1338 kine2315
        math1314 math1316 math1324 math1325 math1332 math1342 math1352 math2412 math2413
        musi1303 musi1310 musi1311 musi2350 musi3325
        phil1301 phil1304 phil2303 phil3301
        phys1302 phys1401 phys1402 phys1403 phys1411 phys2425 phys2426
        psyc2301 psyc3303 psyc3307
        soci1301 soci1306 soci2303
        univ0010 univ0200 univ0204 univ0301 univ0314 univ0324 univ0332 univ0342
    """.split()),
    # columns the forecasts are summarised by (order preserved)
    aggregates=[
        "styp_desc",
        "camp_desc",
        "coll_desc",
        "dept_desc",
        "majr_desc",
        "hs_qrtl",
        "tsi_math",
        "tsi_reading",
        "tsi_writing",
        # "gender",
        # *races,
        # "international",
        "resd_desc",
        "oriented",
        "waiver",
        "lgcy",
    ],
    # model features -> value used to fill missing entries before imputation (pd.NA = leave for MICE)
    features={
        "act_equiv": pd.NA,
        "age": pd.NA,
        "camp_desc": "stephenville",
        "drivetime": pd.NA,
        "eager": pd.NA,
        "fafsa": False,
        "finaid": False,
        "gap_score": 0,
        "gender": pd.NA,
        "hs_qrtl": pd.NA,
        "international": False,
        "lgcy": False,
        "oriented": False,
        **dict.fromkeys(races, False),
        "schlship": False,
        "ssb": False,
        "tsi_math": False,
        "tsi_reading": False,
        "tsi_writing": False,
        "verified": False,
        "waiver": False,
    },
    overwrite=set([
        # "registrations",
        # "admissions",
        # "flags",
        # "students",
        # "imputed",
        # "prepared",
        # "enrollments",
        # "learners",
        # "predictions",
        # "forecasts",
        # "results",
        # "AMP_details",
        # "AMP_summary",
        # "AMP_em",
        # "AMP_predictions",
        # "AMP_attributes",
    ]),
    # amp_date="2025-05-14",
)

# self.process_flags()
self.get_AMP()

In [0]:
### old useful code - do not delete


# import pandas as pd
# A = run(f'saturnstvcamp').T
# repl = dict(zip(A.iloc[0], A.iloc[1]))
# df = pd.read_parquet('/Volumes/aiml/amp/amp_files/2025/geo/drivetimes.parquet').pivot_table(index='zip', columns='camp_code', values='drivetime').rename(columns=repl).rename_axis(columns=None).reset_index()
# df.to_csv('/Volumes/aiml/amp/amp_files/2025/geo/drivetimes.csv', index=False)



    # def get_reports(self):
    #     self.get_em_report()

    #     instructions = pd.DataFrame({"":[
    #         f"Admitted Matriculation Projections (AMP) for {self.amp_date}",
    #         '',
    #         f'''Executive Summary''',
    #         f'''AMP is a predictive model designed to forecasr the incoming (not continuing) Fall cohort to help leaders proactively allocate resources (instructors, sections, labs, etc) in advance.''',
    #         f'''It is intended to supplement, not replace, the knowledge, expertise, and insights developed by institutional leaders over years of experience.''',
    #         f'''Like all AI/ML models (and humans), AMP is fallible and should be used with caution.''',
    #         f'''AMP learns exclusively from historical data captured in EM’s Flags reports and IDA’s daily admissions and registration snapshots.''',
    #         f'''It cannot account for factors not present in these datasets, including curriculum changes, policy shifts, structural changes, demographic variation, changes in oversight, etc.''',
    #         f'''''',
    #         f'''AMP provides both “Summary” and “Details” files. For most users, rows in the “Summary” file with learner_year = 2024 will suffice.''',
    #         f'''Because AMP’s accuracy varies across courses, the “Details” file includes historical error analyses to help users assess the reliability of each forecast.''',
    #         '',
    #         f'''As widely requested, AMP includes predictions for the 2025 cohort in Ft. Worth, despite having no prior Ft. Worth FTIC example to learn from.''',
    #         f'''These are a good-faith effort to offer my best data-driven insights, but due to the lack of training data,''',
    #         f'''they are inherently more speculative and should be treated with lower confidence (details below).''',
    #         f'''''',
    #         f'''Definitions''',
    #         f'''crse_code = course code (_headcnt = total headcount)''',
    #         f'''styp_desc = student type; returning = re-enrolling after a previous attempt (not continuing)''',
    #         f'''prediction_year = cohort being forecast''',
    #         f'''model_year = cohort used to train AMP''',
    #         f'''prediction = forecast headcount''',
    #         f'''*actual = true headcount''',
    #         f'''*error = prediction - actual''',
    #         f'''*error_pct = error / actual * 100''',
    #         f'''*cv_score = average validation log-loss from 3-fold cross-validation''',
    #         f'''*=appears only in “Details” & not available for 2025 (since actuals are not yet known)''',
    #         '',
    #         f'''Methodology''',
    #         f'''AMP uses XGBoost, a machine learning algorithm, to forecast the number, characteristics, and likely course enrollments of incoming Fall students.''',
    #         f'''Predictions are based on application and pre-semester engagement (orientation, course registration, financial aid, etc.) from EM’s Flags and IDA’s daily snapshots.''',
    #         f'''For each student admitted for Fall 2025, AMP identifies similar students from past Fall admits, analyzes their course enrollments (if any),''',
    #         f'''learns relevant patterns, then forecasts Fall 2025 course enrollment for the admitted student in question.''',
    #         f'''More precisely, for each (incoming student, course)-pair, AMP assigns a probability whether that student will be enrolled in that course on the Friday after Census.''',
    #         f'''These (student, course)-level predictions are then aggregated in many different ways to forecast headcounts for courses, campuss, majors, colleges, TSI statuses, etc.''',
    #         f'''These appear on different sheets in this workbook.''',
    #         f'''''',
    #         f'''Since admissions and registration data evolve through the spring and summer, AMP is trained only on data available as of the same date in previous years.''',
    #         f'''AMP's forecast for Ft. Worth's 2025 cohort are necessarily based on previous Stephenville cohorts since no Ft. Worth FTIC's existed on this date.''',
    #         f'''Suppose AMP predicts, "Based on similar FTIC's in Stephenville in 2024, I predict Alice has a 75% probability to matriculate in Fall 2025".''',
    #         f'''If Alice is applying to Ft. Worth, then 0.75 is added to Ft. Worth's forecast.''',
    #         f'''However, AMP can not yet understand how to adjust its 75% projection to reflect how Ft. Worth FTIC's behave differently than Stephenville FTIC since there are no Ft. Worth FTIC's to learn from.''',
    #         f'''Though not ideal, this appears to be the most reasonable mechanism to forecast Ft. Worth FTIC in the absence of training examples.''',
    #         '',
    #         f'''AMP is trained separately using the 2024, 2023, and 2022 cohorts.''',
    #         f'''Most users should focus on prediction_year = 2025 and learner_year = 2024, as 2025 is likely to resemble 2024 more closely than 2023 & 2022.''',
    #         f'''Users with domain expertise may choose to incorporate older cohorts (e.g., weighted average of learner_years 2024, 2023, & 2022) if they believe those years are similarly relevant.''',
    #         '',
    #         f'''Rows for prediction_year < 2025 appear only in the “Details” file and include retrospective "predictions" and actual outcomes.''',
    #         f'''This allows users to assess AMP's ability to forecast each individual course and calibrate their confidence accordingly.''',
    #         '',
    #         f'''Predictions for small values are less reliable than for large numbers (central limit theorem).''',
    #         '',
    #         f'''AMP only models students who have already applied and been admitted (eager).''',
    #         f'''However, more students will apply between now and start of term, especially transfer & returning (lagging).''',
    #         f'''AMP generates forecasts based on eager students then inflates using the eager-lagging ratio from the learner_year.''',
    #         f'''This assumes the eager-lagging behavior will be approximately the same this year. approach has proven sufficiently accurate in prior years.''',
    #         f'''While this assumption cannot be verified in advance, some assumption is needed. This one has proven sufficiently accurate in past cycles.''',
    #         '',
    #         f'''Dr. Scott Cook is eager to provide as much additional detail on AMP's workings as the user desires - email scook@tarleton.edu.''',
    #         f'''source code: https://github.com/drscook/admitted_matriculation_predictor'''
    #     ]})

    #     def format_xlsx(sheet):
    #         from openpyxl.styles import Alignment
    #         sheet.auto_filter.ref = sheet.dimensions
    #         for cell in sheet[1]:
    #             cell.alignment = Alignment(horizontal="left")
    #         for j, column in enumerate(sheet.columns):
    #             width = max(len(str(cell.value))+3*(i==0) for i, cell in enumerate(column))
    #             sheet.column_dimensions[chr(65+j)].width = width
    #         sheet.freeze_panes = "A2"

    #     def fcn_details(df):
    #         return df

    #     def fcn_summary(df):
    #         return df.query(f"prediction_year==prediction_year.max()").iloc[:,:1]

    #     for nm, fcn in {'details':fcn_details, 'summary':fcn_summary}.items():
    #         src = "report.xlsx"
    #         reset(src)
    #         with pd.ExcelWriter(src, mode="w", engine="openpyxl") as writer:
    #             instructions.to_excel(writer, sheet_name='instructions', index=False)
    #             for key, df in self.get_results().items():
    #                 fcn(df).reset_index().round().prep().to_excel(writer, sheet_name=key, index=False)
    #                 format_xlsx(writer.sheets[key])
    #         dst = data/f"reports/{self.term_code}/{self.amp_date}/AMP_{self.amp_date}_{nm}.xlsx"
    #         reset(dst)
    #         shutil.copy(src, dst)
    #         rm(src)


    # def get_reports_em(self):
    #     def fcn():
    #         df = self.get_spriden().set_index('pidm').join(self.models[self.term_code].get_predictions(), how='inner').loc[:,:'prediction']
    #         return df
    #     df, new = self.get(fcn, 'reports_amp_em', self.run, suffix='.csv')
    #     # df, new = self.get(fcn, data/f"reports/{self.term_code}/{self.amp_date}/AMP_{self.amp_date}_EM", self.run, suffix='.csv')
    #     return df


## useful old code
#     def get_registrations(self, overwrite=False, show=False):
#         def fcn():
#             dct = {
#                 'sfrstcr_pidm':'pidm',
#                 'ssbsect_term_code':'term_code',
#                 'sgbstdn_levl_code':'levl_code',
#                 'sgbstdn_styp_code':'styp_code',
#                 'ssbsect_crn':'crn',
#             }
#             qry = f"""
# select
#     {indent(join(alias(dct)))}
#     ,lower(ssbsect_subj_code) || ssbsect_crse_numb as crse_code
#     ,max(ssbsect_credit_hrs) as credit_hr
# from
#     {catalog}saturnsfrstcr as A
# inner join
#     {catalog}saturnssbsect as B
# on
#     sfrstcr_term_code = ssbsect_term_code
#     and sfrstcr_crn = ssbsect_crn
# inner join (
#     select
#         *
#     from
#         {catalog}sgbstdn_amp_v
#     where
#         sgbstdn_term_code_eff <= {self.term_code}
#     qualify
#         row_number() over (partition by sgbstdn_pidm order by sgbstdn_term_code_eff desc) = 1
#     ) as C
# on
#     sfrstcr_pidm = sgbstdn_pidm
# where
#     sfrstcr_term_code = {self.term_code}
#     and sfrstcr_ptrm_code not in ('28','R3') -- drop weird term part
#     and sfrstcr_add_date <= '{self.amp_date}' -- added before amp_day
#     and (sfrstcr_rsts_date > '{self.amp_date}' or sfrstcr_rsts_code in ('DC','DL','RD','RE','RW','WD','WF')) -- dropped after amp_day or still enrolled
#     and ssbsect_subj_code <> 'INST' -- exceptional sections
# group by
#     {indent(join(dct.keys()))}
#     ,ssbsect_subj_code
#     ,ssbsect_crse_numb
# """

#             qry = f"""
# with A as {subqry(qry)}
# select * from A

# union all

# select
#     {indent(join(dct.values()))}
#     ,'_allcrse' as crse_code
#     ,sum(credit_hr) as credit_hr
# from A
# group by
#     {indent(join(dct.values()))}

# union all

# select
#     {indent(join(dct.values()))}
#     ,'_anycrse' as crse_code
#     ,case when sum(credit_hr) > 0 then 1 else 0 end as credit_hr
# from A
# group by
#     {indent(join(dct.values()))}
# """
#             df = run(qry, show)
#             return df
#         df, new = self.get(fcn, 'registrations', overwrite)
#         return df


#     def get_registrations(self, overwrite=False, show=False):
#         def fcn():
#             dct = {
#                 'sfrstcr_pidm':'pidm',
#                 'ssbsect_term_code':'term_code',
#             }
#             qry = f"""
# select
#     {indent(join(alias(dct)))}
#     ,lower(ssbsect_subj_code) || ssbsect_crse_numb as crse_code
#     ,max(ssbsect_credit_hrs) as credit_hr
# from
#     {catalog}saturnsfrstcr as A
# inner join
#     {catalog}saturnssbsect as B
# on
#     sfrstcr_term_code = ssbsect_term_code
#     and sfrstcr_crn = ssbsect_crn
# where
#     sfrstcr_term_code = {self.term_code}
#     and sfrstcr_error_flag is null
#     and sfrstcr_ptrm_code not in ('28','R3') -- drop weird term part
#     and sfrstcr_add_date <= '{self.amp_date}' -- added before amp_day
#     and (sfrstcr_rsts_date > '{self.amp_date}' or sfrstcr_rsts_code in ('DC','DL','RD','RE','RW','WD','WF')) -- dropped after amp_day or still enrolled
#     and ssbsect_subj_code <> 'INST' -- exceptional sections
# group by
#     {indent(join(dct.keys()))}
#     ,ssbsect_subj_code
#     ,ssbsect_crse_numb
# """

#             qry = f"""
# with A as {subqry(qry)}
# select * from A

# union all

# select
#     {indent(join(dct.values()))}
#     ,'_allcrse' as crse_code
#     ,sum(credit_hr) as credit_hr
# from A
# group by
#     {indent(join(dct.values()))}

# union all

# select
#     {indent(join(dct.values()))}
#     ,'_anycrse' as crse_code
#     ,case when sum(credit_hr) > 0 then 1 else 0 end as credit_hr
# from A
# group by
#     {indent(join(dct.values()))}
# """
#             df = run(qry, show)
#             return df
#         df, new = self.get(fcn, 'registrations', overwrite)

#         S = self.students[['pidm','term_code','styp_code']]
#         return df



    #     def fcn():
#             qry = f"""
# select
#     sfrstcr_pidm as pidm
#     ,sfrstcr_term_code as term_code
#     ,lower(ssbsect_subj_code) || ssbsect_crse_numb as crse_code
#     ,max(ssbsect_credit_hrs) as credit_hr
# from
#     {catalog}saturnsfrstcr as A
# inner join
#     {catalog}saturnssbsect as B
# on
#     sfrstcr_term_code = ssbsect_term_code
#     and sfrstcr_crn = ssbsect_crn
# where
#     sfrstcr_term_code = {self.term_code}
#     --and sfrstcr_error_flag is null
#     and sfrstcr_ptrm_code not in ('28','R3') -- drop weird term part
#     and sfrstcr_add_date <= '{self.amp_date}' -- added before amp_day
#     and (sfrstcr_rsts_date > '{self.amp_date}' or sfrstcr_rsts_code in ('DC','DL','RD','RE','RW','WD','WF')) -- dropped after amp_day or still enrolled
#     and ssbsect_subj_code <> 'INST' -- exceptional sections
# group by
#     sfrstcr_pidm
#     ,sfrstcr_term_code
#     ,ssbsect_subj_code
#     ,ssbsect_crse_numb
# """

#             qry = f"""
# with A as {subqry(qry)}
# select * from A

# union all

# select
#     pidm
#     ,term_code
#     ,'_allcrse' as crse_code
#     ,sum(credit_hr) as credit_hr
# from A
# group by
#     pidm
#     ,term_code

# union all

# select
#     pidm
#     ,term_code
#     ,'_anycrse' as crse_code
#     ,case when sum(credit_hr) > 0 then 1 else 0 end as credit_hr
# from A
# group by
#     pidm
#     ,term_code
# """


# don't delete - could be useful & was hard to create
            # stat_codes = ['AL','AR','AZ','CA','CO','CT','DC','DE','FL','GA','IA','ID','IL','IN','KS','KY','LA','MA','MD','ME','MI','MN','MO','MS','MT','NC','ND','NE','NH','NJ','NM','NV','NY','OH','OK','OR','PA','RI','SC','SD','TN','TX','UT','VA','VT','WA','WI','WV','WY'] # not AK & HI b/c can't get driving distance
#     ,{get_desc('spraddr_cnty_code')[0]}
#     ,{get_desc('spraddr_stat_code')[0]}
#     ,zip_code

# left join (
#     select
#         *
#         ,try_to_number(left(spraddr_zip, 5), '00000') as zip_code
#         ,case
#             when spraddr_atyp_code = 'PA' then 6
#             when spraddr_atyp_code = 'PR' then 5
#             when spraddr_atyp_code = 'MA' then 4
#             when spraddr_atyp_code = 'BU' then 3
#             when spraddr_atyp_code = 'BI' then 2
#             when spraddr_atyp_code = 'P1' then 1
#             when spraddr_atyp_code = 'P2' then 0
#             end as spraddr_atyp_rank
#     from
#         {catalog}spraddr_amp_v
#     where
#         spraddr_stat_code in ('{join(stat_codes, "','")}')
#         and spraddr_zip is not null
#     qualify
#         row_number() over (partition by spraddr_pidm order by spraddr_atyp_rank desc, spraddr_seqno desc) = 1
# )
# on
#     pidm = spraddr_pidm

# {get_desc('spraddr_cnty_code')[1]}
# {get_desc('spraddr_stat_code')[1]}



    # def get_zips(self, show=False):
    #     """takes ~3 hours to get zip codes and find nearest point on road network to the provided representative point"""
    #     def fcn():
    #         from pgeocode import Nominatim
    #         nomi = Nominatim('us')
    #         df = nomi.query_postal_code(pd.Series(nomi._data['postal_code'])).query("state_code.notnull() & state_code not in ['AK', 'HI', 'MH']").prep().set_index('postal_code').rename_axis('zip')
    #         nearest = lambda x: join(requests.get(f"http://router.project-osrm.org/nearest/v1/driving/{x['longitude']},{x['latitude']}").json()['waypoints'][0]['location'],',')
    #         df['point'] = df.apply(nearest, axis=1)
    #         return df
    #     df, new = self.get(fcn, root/'zips')
    #     self.states = set(df['state_code'])
    #     return df


    # def get_drivetimes(self, show=False):
    #     def fcn():
    #         campus_coords = {
    #             's': [-98.215784,32.216217],
    #             # 'm': '-97.432975,32.582436',
    #             # 'w': 76708,
    #             # 'r': 77807,
    #             }

    #         url = "https://www2.census.gov/geo/tiger/GENZ2020/shp/cb_2020_us_zcta520_500k.zip"
    #         gdf = gpd.read_file(url).prep().set_index('zcta5ce20').iloc[:5]
    #         pts = gdf.sample_points(size=5,method="uniform").explode()#.apply(lambda geom: f"{geom.x},{geom.y}")
    #         df = pts.to_frame()[[]]
    #         url = "http://router.project-osrm.org/table/v1/driving"
    #         headers = {"Content-Type": "application/json"}
    #         for k, v in campus_coords.items():
    #             u = [v, *pts]
    #             print(u)
    #             data = {
    #                 "coordinates": u,
    #                 "annotations": ["duration", "distance"],
    #                 "sources": 0,
    #             }
    #             response = requests.post(url, json=data, headers=headers)
    #             print(response.json())
    #             assert 1==2

            # for k, v in campus_coords.items():
            #     u = join([v, *pts], ';')
            #     url = f"http://router.project-osrm.org/table/v1/driving/{u}?destinations={0}&annotations=duration,distance"
            #     print(url)
            #     print(requests.get(url))
            #     df[k] = np.squeeze(requests.get(url).json()['durations'])[1:]/60
    #         # df.disp(10)
    #         # df = df.groupby('zip').min()
    #         # df.disp(10)
    #         df = df.groupby(level=0).min().stack().reset_index().set_axis(['zip','camp_code','drivetime'], axis=1)
    #         return df

            # df = self.zips.iloc[34339:34349].copy()
            # u = join(df.apply(lambda x: f"{x['longitude']},{x['latitude']}", axis=1),';')
            # for k, z in campus_zips.items():
            #     df[k] = np.squeeze(
            #     requests.get(f"http://router.project-osrm.org/table/v1/driving/{u}?destinations={df.index.get_loc(z)}&annotations=duration,distance&fallback_speed=26.8&fallback_coordinate=snapped"
            #     ).json()['distances'])/1609

            # for k, z in campus_zips.items():
            #     df[k] = np.squeeze(
            #     requests.get(f"http://router.project-osrm.org/table/v1/driving/{u}?destinations={df.index.get_loc(z)}&annotations=duration,distance&fallback_speed=26.8&fallback_coordinate=snapped"
            #     ).json()['durations'])/60

            # self.zips = self.zips.iloc[34339:34349]
            # self.zips.disp(20)
            # u = join(self.zips['point'],';')
            
            # dct = dict()
            # for k, z in campus_zips.items():
            #     i = self.zips.index.get_loc(z)
            #     print(self.zips.iloc[i])
            #     url = f"http://router.project-osrm.org/table/v1/driving/{u}?destinations={0}&annotations=duration,distance&fallback_speed=26.8&fallback_coordinate=snapped"
            #     print(url)
            #     response = requests.get(url).json()
            #     for k,v in response.items():
            #         print(k)
            #         display(v)
            #         print()
            #     print(response['distance'])
            #     dct[k] = np.squeeze(response['durations'])

            # # dct = {k: np.squeeze(
            # #     requests.get(f"http://router.project-osrm.org/table/v1/driving/{u}?destinations={self.zips.index.get_loc(z)}&fallback_speed=600&fallback_coordinate=snapped"
            # #     ).json()['durations'])/60 for k, z in campus_zips.items()}
            # dct = {k: np.squeeze(
            #     requests.get(f"http://router.project-osrm.org/table/v1/driving/{u}?destinations={self.zips.index.get_loc(z)}&annotations=duration,distance&fallback_speed=26.8&fallback_coordinate=snapped"
            #     ).json()['distances']) for k, z in campus_zips.items()}

            # df = pd.DataFrame(dct, index=self.zips.index).stack().rename_axis(['zip','camp_code']).rename('drivetime') / 1609
            # return df
            # print()
            # dct = {k: self.zips.loc[y] for k, y in {
            #     's': 76402,
            #     'm': 76036,
            #     'w': 76708,
            #     'r': 77807,
            #     }.items()}
            # L = [
            #     self.get(
            #         lambda: X.apply(get_driving_distance, y=y, axis=1).rename('distance').reset_index().assign(camp_code=k),
            #         root/f'distances/distances_{s}_{k}',
            #         divide=False,
            #     )[0] for s, X in self.zips.groupby('state_code') for k, y in dct.items()]
            # return pd.concat(L, ignore_index=True)
        # df, new = self.get(fcn, root/'drivetimes')#, self.get_zips)
        # return df




    # def get_flags_history(self, cutoff=202206):
    #     def fcn():
    #         import pyarrow.parquet as pq
    #         print()
    #         L = []
    #         for path in sorted(flags_prc.iterdir(), reverse=True):
    #             print(path)
    #             for src in path.iterdir():
    #                 _, term_code, amp_date, amp_day = src.stem.split('_')
    #                 col = pq.ParquetFile(src).schema.names
    #                 df = pd.DataFrame(columns=col).assign(term_code=[int(term_code)], amp_date=[amp_date]).fillna(True)
    #                 L.append(df)
    #         df = pd.concat(L).fillna(False).set_index(['amp_date','term_code']).sort_index()
    #         return df[sorted(df.columns)]
    #     df, new = self.get(fcn, path=data/'flags_history')
    #     A = df.query(f'term_code>={cutoff}').groupby('term_code').sum().sort_index(ascending=False).T.rename_axis('variable')
    #     B = A == A.max()
    #     B.insert(0, 'n', B.sum(axis=1))
    #     return B.reset_index().sort_values(['n', 'variable'], ascending=[False, True])



############ annoying warnings to suppress ############
# [warnings.filterwarnings(action='ignore', message=f".*{w}.*") for w in [
#     "Could not infer format, so each element will be parsed individually, falling back to `dateutil`",
#     "Engine has switched to 'python' because numexpr does not support extension array dtypes",
#     "The default of observed=False is deprecated and will be changed to True in a future version of pandas",
#     "errors='ignore' is deprecated"
#     "The behavior of DataFrame concatenation with empty or all-NA entries is deprecated",
#     "The behavior of array concatenation with empty entries is deprecated",
#     "DataFrame is highly fragmented",
# ]]



# for fore in amp.values():
#     for base in amp.values():
#         if base.year < max(years):
#             for styp, clf in base.get_models().items():
#                 for k in ['predictions','headcounts']:
#                     fore.__dict__.setdefault(k, dict()).setdefault(styp, dict())
#                 y = clf.prediction(fore.get_prepared()[styp])
#                 fore.predictions[styp][base.year] = y
#                 s = (
#                     (y[['pred']].sum() * base.get_enrollments().loc[base.crse_code,styp]['mlt']).round()
#                     .rename(base.year).to_frame().T.rename_axis('base_year')
#                     .assign(styp_code=styp, forecast_year=fore.year)
#                     .reset_index().set_index(['styp_code','forecast_year','base_year'])
#                 )
#                 if fore.year < max(years):
#                     s['true'] = fore.get_enrollments().loc[base.crse_code,styp]['stable']
#                     s['error'] = s['pred'] - s['true']
#                     s['error_pct'] = round(s['error'] / s['true'] * 100, 2)
#                 fore.headcounts[styp][base.year] = s.prep()
#     fore.forecasts = {styp: pd.concat(v.values()) for styp, v in fore.headcounts.items()}
# amp[2023].forecasts['n']



# def get_desc(code):
#     for nm in code.split('_'):
#         if len(nm) == 4:
#             break
#     return [f'{code} as {nm}_code, stv{nm}_desc as {nm}_desc', f'left join {catalog}saturnstv{nm} on {code} = stv{nm}_code']


#             qry = f"""
# select
#     pidm
#     ,{self.amp_day} as amp_day
#     ,timestamp('{self.amp_date}') as amp_date
#     ,current_date
#     ,first_date
#     ,final_date
#     ,{get_desc('term_code')[0]}
#     ,appl_no
#     ,{get_desc('apst_code')[0]}
#     ,{get_desc('apdc_code')[0]}
#     ,{get_desc('admt_code')[0]}
#     ,{get_desc('wrsn_code')[0]}
#     ,{get_desc('levl_code')[0]}
#     ,{get_desc('styp_code')[0]}
#     ,{get_desc('camp_code')[0]}
#     ,{get_desc('coll_code_1')[0]}
#     ,{get_desc('dept_code')[0]}
#     ,{get_desc('majr_code_1')[0]}
#     ,{get_desc('saradap_resd_code')[0]}
#     ,gender
#     ,birth_date
#     ,{get_desc('spbpers_lgcy_code')[0]}
#     ,gorvisa_vtyp_code is not null as international
#     ,gorvisa_natn_code_issue as natn_code, stvnatn_nation as natn_desc
#     ,{coalesce('race_asian')}
#     ,{coalesce('race_black')}
#     ,coalesce(spbpers_ethn_cde=2, False) as race_hispanic
#     ,{coalesce('race_native')}
#     ,{coalesce('race_pacific')}
#     ,{coalesce('race_white')}
#     ,hs_percentile
#     ,sbgi_code
#     ,enrolled_ind='Y' as enrolled_ind

# from {subqry(qry)} as A

# left join
#     {catalog}spbpers_amp_v
# on
#     pidm = spbpers_pidm

# left join (
#     select
#         *
#     from
#         {catalog}generalgorvisa
#         --{catalog}gorvisa_amp_v
#     qualify
#         row_number() over (partition by gorvisa_pidm order by gorvisa_seq_no desc) = 1
#     )
# on
#     pidm = gorvisa_pidm

# left join (
#     select
#         gorprac_pidm
#         ,max(gorprac_race_cde='AS') as race_asian
#         ,max(gorprac_race_cde='BL') as race_black
#         ,max(gorprac_race_cde='IN') as race_native
#         ,max(gorprac_race_cde='HA') as race_pacific
#         ,max(gorprac_race_cde='WH') as race_white
#     from
#         {catalog}generalgorprac
#     group by
#         gorprac_pidm
#     )
# on
#     pidm = gorprac_pidm

# {get_desc('term_code')[1]}
# {get_desc('levl_code')[1]}
# {get_desc('styp_code')[1]}
# {get_desc('admt_code')[1]}
# {get_desc('wrsn_code')[1]}
# {get_desc('apst_code')[1]}
# {get_desc('apdc_code')[1]}
# {get_desc('camp_code')[1]}
# {get_desc('coll_code_1')[1]}
# {get_desc('dept_code')[1]}
# {get_desc('majr_code_1')[1]}
# {get_desc('saradap_resd_code')[1]}
# {get_desc('gorvisa_natn_code_issue')[1]}
# {get_desc('spbpers_lgcy_code')[1]}

# qualify
#     min(levl_code='UG') over (partition by pidm) = True  -- remove pidms with any graduate admission, even if they also have an undergraduate admission; min acts like logical and
#     and row_number() over (partition by pidm order by appl_no desc) = 1  -- de-duplicate the few remaining pidms with multiple record by keeping highest appl_no
# """



#     def get_registrations(self, show=False):
#         def fcn():
#             dct = {
#                 'sfrstcr_pidm':'pidm',
#                 'ssbsect_term_code':'term_code',
#                 'sgbstdn_levl_code':'levl_code',
#                 'sgbstdn_styp_code':'styp_code',
#             }
#             qry = f"""
# select
#     {indent(join(alias(dct)))}
#     ,lower(ssbsect_subj_code) || ssbsect_crse_numb as crse_code
#     ,max(ssbsect_credit_hrs) as credit_hr
# from
#     {catalog}saturnsfrstcr as A
# inner join
#     {catalog}saturnssbsect as B
# on
#     sfrstcr_term_code = ssbsect_term_code
#     and sfrstcr_crn = ssbsect_crn
# inner join (
#     select
#         *
#     from
#         {catalog}sgbstdn_amp_v
#     where
#         sgbstdn_term_code_eff <= {self.term_code}
#     qualify
#         row_number() over (partition by sgbstdn_pidm order by sgbstdn_term_code_eff desc) = 1
#     ) as C
# on
#     sfrstcr_pidm = sgbstdn_pidm
# where
#     sfrstcr_term_code = {self.term_code}
#     and sfrstcr_error_flag is null
#     and sfrstcr_ptrm_code not in ('28','R3') -- drop weird term part
#     and sfrstcr_add_date <= '{self.amp_date}' -- added before amp_day
#     and (sfrstcr_rsts_date > '{self.amp_date}' or sfrstcr_rsts_code in ('DC','DL','RD','RE','RW','WD','WF')) -- dropped after amp_day or still enrolled
#     and ssbsect_subj_code <> 'INST' -- exceptional sections
#     and ssbsect_credit_hrs > 0
# group by
#     {indent(join(dct.keys()))}
#     ,ssbsect_subj_code
#     ,ssbsect_crse_numb
# """

#             qry = f"""
# with A as {subqry(qry)}
# select * from A

# union all

# select
#     {indent(join(dct.values()))}
#     ,'_allcrse' as crse_code
#     ,sum(credit_hr) as credit_hr
# from A
# group by
#     {indent(join(dct.values()))}

# union all

# select
#     {indent(join(dct.values()))}
#     ,'_anycrse' as crse_code
#     ,1 as credit_hr
# from A
# group by
#     {indent(join(dct.values()))}
# """
#             df = run(qry, show).set_index(['crse_code','pidm'])
#             return df
#         df, new = self.get(fcn, 'registrations')
#         return df


#     def get_registrations(self, show=False):
#         def fcn():
#             grp = join([
#                 'pidm',
#                 'term_code','term_desc',
#                 'levl_code','levl_desc',
#                 'styp_code','styp_desc',
#                 ], ', ')

#             qry = f"""
# select
#     sfrstcr_pidm as pidm
#     ,{get_desc('ssbsect_term_code')}
#     ,{get_desc('sgbstdn_levl_code')}
#     ,{get_desc('sgbstdn_styp_code')}
#     ,lower(ssbsect_subj_code) || ssbsect_crse_numb as crse_code
#     ,max(ssbsect_credit_hrs) as credit_hr
# from
#     {catalog}saturnsfrstcr as A
# inner join
#     {catalog}saturnssbsect as B
# on
#     sfrstcr_term_code = ssbsect_term_code
#     and sfrstcr_crn = ssbsect_crn
# inner join (
#     select
#         *
#     from
#         {catalog}sgbstdn_amp_v
#     where
#         sgbstdn_term_code_eff <= {self.term_code}
#     qualify
#         row_number() over (partition by sgbstdn_pidm order by sgbstdn_term_code_eff desc) = 1
#     ) as C
# on
#     sfrstcr_pidm = sgbstdn_pidm
# where
#     sfrstcr_term_code = {self.term_code}
#     and sfrstcr_error_flag is null
#     and sfrstcr_ptrm_code not in ('28','R3') -- drop weird term parts
#     and sfrstcr_add_date <= '{self.amp_date}' -- added before amp_day
#     and (sfrstcr_rsts_date > '{self.amp_date}' or sfrstcr_rsts_code in ('DC','DL','RD','RE','RW','WD','WF')) -- dropped after amp_day or still enrolled
#     and ssbsect_subj_code <> 'INST' -- exceptional sections
#     and ssbsect_credit_hrs > 0
# group by
#     {grp}, crse_code
# """

#             qry = f"""
# with CTE as {subqry(qry)}

# --individual courses
# select
#     *
# from
#     CTE

# union all

# --total credit hours
# select
#     {grp}
#     ,'_total_sch' as crse_code
#     ,sum(credit_hr) as credit_hr
# from
#     CTE
# group by
#     {grp}

# union all

# --headcount 
# select
#     {grp}
#     ,'_headcount' as crse_code
#     ,1 as credit_hr
# from
#     CTE
# group by
#     {grp}
# """
#             df = run(qry, show).set_index(['crse_code','pidm'])
#             return df
#         df, new = self.get(fcn, 'registrations')
#         return df



    def get_students(self, show=False):
        """Return the per-student modeling table, indexed by ``pidm``.

        The table is obtained through ``self.get`` under the name ``'students'``,
        with ``get_admissions``, ``get_flags`` and ``get_drivetimes`` passed as
        dependencies (presumably so a cached copy is rebuilt when they change --
        confirm against ``get``). The builder ``fcn``:

        * left-joins ``self.admissions`` to ``self.get_flags()`` on
          (pidm, term_code) and to ``self.get_drivetimes()`` on (zip, camp_code);
        * shows rows that have a zip and a camp_code other than ``'o'`` but no
          drivetime, and displays styp_code counts for rows with no flags match;
        * coalesces the three gap scores (FTIC score first for styp_code ``'n'``,
          transfer score first otherwise, generic score last);
        * derives features: oriented, verified, act_equiv, eager, age, tsi_*,
          hs_qrtl, lgcy, resd, and presence flags (waiver, fafsa, ssb, finaid,
          schlship).

        ``show`` is accepted but not used by this method.
        """
        def fcn():
            # left joins keep every admissions row; overlapping right-hand column
            # names get the '_flags' / '_zips' suffix
            students = self.admissions.merge(
                self.get_flags(), on=['pidm', 'term_code'], how='left', suffixes=['', '_flags'],
            )
            students = students.merge(
                self.get_drivetimes(), on=['zip', 'camp_code'], how='left', suffixes=['', '_zips'],
            )

            # the default 'pandas' eval parser gives '&' the precedence of 'and',
            # so this is three separate conditions, not (a & b & camp_code) != 'o'
            no_drivetime = students.eval("drivetime.isnull() & zip.notnull() & camp_code!='o'")
            if no_drivetime.any():
                unmatched = students[no_drivetime].set_index(['state', 'city', 'zip', 'camp_code'])[[]]
                unmatched.sort_index().reset_index().disp(50)

            # make sure all three gap-score columns exist before coalescing them
            for score_col in ('gap_score', 't_gap_score', 'ftic_gap_score'):
                if score_col not in students:
                    students[score_col] = pd.NA
            ftic = students['ftic_gap_score']
            transfer = students['t_gap_score']
            generic = students['gap_score']
            students['gap_score'] = np.where(
                students['styp_code'] == 'n',
                ftic.combine_first(transfer).combine_first(generic),
                transfer.combine_first(ftic).combine_first(generic),
            )

            # oriented: no orientation hold, or an orientation session, or already registered
            students['oriented'] = (
                students['orientation_hold_exists'].isnull()
                | students['orien_sess'].notnull()
                | students['registered'].notnull()
            )
            # verified: not selected for verification, or verification completed
            students['verified'] = students['selected_for_ver'].isnull() | students['ver_complete'].notnull()

            # NOTE: overwrites sat10_total_score in place with a linear map of
            # 590..1600 onto 9..36 so it can be maxed against the ACT composite
            students['sat10_total_score'] = (36 - 9) / (1600 - 590) * (students['sat10_total_score'] - 590) + 9
            students['act_equiv'] = students[['act_new_comp_score', 'sat10_total_score']].max(axis=1)

            # day counts (.dt.days) from each date column up to self.stable_date
            for feature, date_col in (('eager', 'first_date'), ('age', 'birth_date')):
                students[feature] = (self.stable_date - students[date_col]).dt.days

            # tsi_<subject> is True unless the result is missing or marked not-ready
            not_ready = ['not college ready', 'retest required', pd.NA, None, np.nan]
            for subject in ('reading', 'writing', 'math'):
                students['tsi_' + subject] = ~students[subject].isin(not_ready)

            # percentile bands [-1,25) [25,50) [50,75) [75,90) [90,101) -> 4,3,2,1,0;
            # rows with no band fall back to the rank implied by apdc_code
            pctl_band = pd.cut(
                students['hs_pctl'], bins=[-1, 25, 50, 75, 90, 101], labels=[4, 3, 2, 1, 0], right=False,
            )
            apdc_rank = {'ae': 0, 'n1': 1, 'n2': 2, 'n3': 3, 'n4': 4, 'r1': 1, 'r2': 2, 'r3': 3, 'r4': 4}
            students['hs_qrtl'] = pctl_band.combine_first(students['apdc_code'].map(apdc_rank))

            students['lgcy'] = ~students['lgcy_code'].isin(['o', pd.NA, None, np.nan])
            students['resd'] = students['resd_code'].eq('r')

            # presence flags named after the text before the first '_' (e.g. 'fafsa_app' -> 'fafsa')
            for src_col in ('waiver_desc', 'fafsa_app', 'ssb_last_accessed', 'finaid_accepted', 'schlship_app'):
                students[src_col.partition('_')[0]] = students[src_col].notnull()

            # rows from admissions not on flags - should not be any
            no_flags = students['amp_date_flags'].isnull()
            if no_flags.any():
                display(students.loc[no_flags, 'styp_code'].value_counts().sort_index().to_frame().T)
            return students.set_index(['pidm'])

        students, _ = self.get(fcn, 'students', [self.get_admissions, self.get_flags, self.get_drivetimes])
        return students



for A in list(flags_prc.iterdir()):
#     if A.is_dir():
#         for src in list(A.iterdir()):
#             if src.is_file():
#                 date = [x for x in src.name.split('_') if '-' in x][0]
#                 dst = A / date / src.name
#                 reset(dst)
#                 print(src, dst)
#                 # assert False
#                 src.rename(dst)

#             # for src in list(B.iterdir()):
#             #     if src.is_file():
#             #         date = [x for x in src.name.split('_') if '-' in x][0]
#             #         dst = B / date / src.name
#             #         reset(dst)
#             #         print(src, dst)
#             #         src.rename(dst)