In [0]:
username = 'scook'
from IPython.display import display, HTML, clear_output
try:
    %reload_ext autotime
except:
    %pip install -U ipython-autotime ipywidgets codetiming openpyxl numpy pandas geopandas pgeocode flaml[automl] git+https://github.com/AnotherSamWilson/miceforest.git
    dbutils.library.restartPython()
    clear_output()
    dbutils.notebook.exit('Rerun to use newly installed/updated packages')

import pathlib, shutil, pickle, warnings, requests, dataclasses, codetiming, numpy as np, pandas as pd, geopandas as gpd, pgeocode, flaml as fl, miceforest as mf
clear_output()
# Schema prefix prepended to bare table names in the SQL built below
catalog = 'dev.bronze.'
# Project folder in the user's workspace; generated data files are stored under root/'data'
root = pathlib.Path(f'/Workspace/Users/{username}@tarleton.edu/admitted_matriculation_predictor_2025/')
data = root/'data'
# flags_raw holds the raw flags workbooks read by process_flags; flags_prc receives the processed parquet files
flags_raw = pathlib.Path('/Volumes/aiml/scook/scook_files/admitted_flags_raw')
flags_prc = pathlib.Path('/Volumes/aiml/flags/flags_volume/')

############ annoying warnings to suppress ############
# Fragments of noisy third-party warning messages. Each fragment is wrapped in ".*" and
# registered as a regex "ignore" filter (append=True keeps any earlier filters ahead of these).
suppressed_warnings = [
    "Could not infer format, so each element will be parsed individually, falling back to `dateutil`",
    "Engine has switched to 'python' because numexpr does not support extension array dtypes",
    "The default of observed=False is deprecated and will be changed to True in a future version of pandas",
    "errors='ignore' is deprecated",  # the comma here was missing, which silently fused this entry with the next one
    "The behavior of DataFrame concatenation with empty or all-NA entries is deprecated",
    "The behavior of array concatenation with empty entries is deprecated",
    "DataFrame is highly fragmented",
    "DataFrameGroupBy.apply operated on the grouping columns",
]
for w in suppressed_warnings:
    warnings.filterwarnings(action='ignore', message=f".*{w}.*", append=True)

##########################################
############ helper functions ############
##########################################
tab = '    '  # one indentation level (4 spaces), used by indent() when formatting generated SQL
divider = '\n##############################################################################################################\n'  # printed between sections of progress output

def listify(*args):
    """Return the arguments as a list.

    A single None / np.nan / pd.NA gives an empty list and a single string gives a
    one-element list. Otherwise the arguments are handed to list(); when that fails
    (a non-iterable scalar, or several arguments) the arguments themselves become the list.
    """
    if len(args) == 1:
        only = args[0]
        if isinstance(only, str):
            # a string is iterable, but we want it kept whole
            return [only]
        if any(only is missing for missing in (None, np.nan, pd.NA)):
            return []
    try:
        return list(*args)
    except Exception:
        return list(args)

def setify(*args):
    """Collect the arguments into a set, normalising them through listify first."""
    items = listify(*args)
    return set(items)

def unique(*args):
    """Drop duplicate items while keeping first-seen order."""
    seen = dict.fromkeys(listify(*args))  # dict keys are unique and preserve insertion order
    return listify(seen)

def difference(A, B):
    """Return the unique items of A that do not appear in B, in A's original order.

    B is materialised once up front. Previously listify(B) was re-evaluated for every
    element of A, which was wasted work and gave wrong results when B was a one-shot
    iterator (exhausted after the first membership test).
    """
    exclude = listify(B)
    return unique([x for x in listify(A) if x not in exclude])

def rjust(x, width, fillchar=' '):
    """Right-justify str(x) to width, padding on the left with str(fillchar)."""
    text, pad = str(x), str(fillchar)
    return text.rjust(width, pad)

def ljust(x, width, fillchar=' '):
    """Left-justify str(x) to width, padding on the right with str(fillchar)."""
    text, pad = str(x), str(fillchar)
    return text.ljust(width, pad)

def join(lst, sep='\n,', pre='', post=''):
    """Concatenate the items of lst (each converted with str) using sep, wrapped in pre and post."""
    body = str(sep).join(str(item) for item in listify(lst))
    return f"{pre}{body}{post}"

def alias(dct):
    """Turn {original column name: new column name} into a list of 'original as new' strings."""
    clauses = []
    for old, new in dct.items():
        clauses.append(f'{old} as {new}')
    return clauses

def indent(x, lev=1):
    """Indent every line after the first by lev levels of tab; lev <= 0 returns x unchanged."""
    if lev > 0:
        return x.replace('\n', '\n' + tab*lev)
    return x

def subqry(qry, lev=1):
    """Wrap qry in parentheses, indented lev levels, so it can be embedded as a subquery."""
    wrapped = '\n' + qry.strip() + '\n)'
    return '(' + indent(wrapped, lev)

def run(qry, show=False, sample='10 rows', seed=42):
    """Execute qry with spark and return the result as a prepped, index-sorted pandas DataFrame.

    A qry without any space is treated as a bare table name: it is expanded to
    `select * from {catalog}{qry}` and, unless sample is None, restricted with a
    repeatable tablesample. When show is true the final SQL is printed first.
    """
    if ' ' not in qry:
        # bare table name -> build the full select statement
        qry = f'select * from {catalog}{qry}'
        if sample is not None:
            qry += f' tablesample ({sample}) repeatable ({seed})'
    if show:
        print(qry)
    result = spark.sql(qry).toPandas()
    return result.prep().sort_index()

def rm(path, root=False):
    """Delete a file, or clear out a directory, and return the path as a pathlib.Path.

    A file is unlinked. A directory is removed entirely when root is true; otherwise
    only its contents are removed and the (now empty) directory is kept. A path that
    is neither a file nor a directory is left alone.
    """
    target = pathlib.Path(path)
    if target.is_file():
        target.unlink()
        return target
    if not target.is_dir():
        return target
    if root:
        shutil.rmtree(target)
        return target
    for child in target.iterdir():
        rm(child, True)  # children are always removed completely
    return target

def get_desc(code):
    """Build SQL selecting code as <nm>_code plus a <nm>_desc looked up in the matching stv<nm> table.

    nm is the first underscore-separated piece of code that is exactly 4 characters long,
    or the last piece when no piece has that length.
    """
    parts = code.split('_')
    nm = next((p for p in parts if len(p) == 4), parts[-1])
    return f'{code} as {nm}_code, (select stv{nm}_desc from {catalog}saturnstv{nm} where {code} = stv{nm}_code limit 1) as {nm}_desc'

def coalesce(x, y=False):
    """Return the SQL fragment 'coalesce(x, y) as x', i.e. column x with y as its fallback value."""
    return 'coalesce({0}, {1}) as {0}'.format(x, y)

############ pandas functions ############
pd.options.display.max_columns = None
def disp(df, rows=4, head=True, sort=False):
    """Print df's shape and display a compact HTML summary.

    The displayed table stacks, in order: the column dtypes, the number of missing
    values per column, the percentage missing (2 decimals), and then the first
    `rows` rows (head=True) or the last `rows` rows (head=False).

    Args:
        df: DataFrame to summarise.
        rows: number of data rows to show.
        head: show the first rows when true, the last rows otherwise.
        sort: sort the columns alphabetically before displaying.
    """
    with pd.option_context('display.min_rows', rows, 'display.max_rows', rows):
        print(df.shape)
        df = df.sort_index(axis=1) if sort else df
        missing = df.isnull().sum().to_frame().T
        pct_missing = (missing/df.shape[0]*100).round(2)
        sample = df.head(rows) if head else df.tail(rows)  # was df.tails(), which does not exist
        X = pd.concat([df.dtypes.to_frame().T, missing, pct_missing, sample])
        display(HTML(X.to_html()))

def inser(df, column, value, loc=0):
    """Insert column into df in place at position loc and return df so calls can be chained."""
    df.insert(loc=loc, column=column, value=value)
    return df

def to_numeric(df, downcast='integer', errors='ignore', **kwargs):
    """Convert every column to a numeric dtype where possible.

    Steps, applied column by column and in this order:
      1. object/string columns are cast to 'string', lower-cased and stripped;
      2. datetime columns are left alone, every other column goes through
         pd.to_numeric with the given downcast / errors / kwargs;
      3. the frame is converted to pandas' nullable dtypes;
      4. every integer column is cast to Int64.
    All warnings raised while doing so are suppressed.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return (
            df
            .apply(lambda s: s.astype('string').str.lower().str.strip() if s.dtype in ['object','string'] else s)  # prep strings
            .apply(lambda s: s if pd.api.types.is_datetime64_any_dtype(s) else pd.to_numeric(s, downcast=downcast, errors=errors, **kwargs))  # convert to numeric if possible
            .convert_dtypes()  # convert to new nullable dtypes
            .apply(lambda s: s.astype('Int64') if pd.api.types.is_integer_dtype(s) else s)  # undo the downcast so all integer columns share one dtype
        )

def prep(df, **kwargs):
    """Normalise a DataFrame's values, column labels and index.

    Values go through the to_numeric helper (called as a DataFrame method, with **kwargs);
    string column labels are lower-cased, stripped, and have spaces and hyphens replaced
    by underscores. The index levels receive the same treatment and are put back as a
    MultiIndex. All warnings raised while doing so are suppressed.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # h = clean the values and then the column labels of a frame
        h = lambda x: x.to_numeric(**kwargs).rename(columns=lambda s: s.lower().strip().replace(' ','_').replace('-','_') if isinstance(s, str) else s)
        idx = h(df[[]].reset_index())  # drop columns, reset_index to move index to columns, then apply h
        return h(df).reset_index(drop=True).set_index(pd.MultiIndex.from_frame(idx))  # set idx back to df's index

def wrap(fcn):
    """Adapt a DataFrame helper so that it also works when called on a Series.

    The returned wrapper converts its first argument to a DataFrame, calls fcn, and
    squeezes the result back down when the original input was a Series. A None
    result from fcn is passed through unchanged.
    """
    def wrapper(X, *args, **kwargs):
        result = fcn(pd.DataFrame(X), *args, **kwargs)
        if result is None:
            return None
        if isinstance(X, pd.Series):
            return result.squeeze()  # input was a Series, so collapse the frame again
        return result
    return wrapper

# Monkey-patch my helpers into the pandas Series & DataFrame classes so we can use df.method syntax.
# DataFrame gets each helper directly; Series gets a wrapped version that round-trips through a DataFrame.
# (This note used to be a string literal inside the loop body, i.e. a no-op statement rather than a comment.)
for f in [disp, inser, to_numeric, prep]:
    setattr(pd.DataFrame, f.__name__, f)
    setattr(pd.Series, f.__name__, wrap(f))

def prediction(clf, X, y):
    """Score X with clf and return a copy of X extended with per-row diagnostics.

    Added columns: 'prediction' (second column of clf.predict_proba(X)), 'actual' (y),
    'mae' (absolute difference between the two) and 'log_loss' (binary cross-entropy
    of each row). X itself is not modified.
    """
    Z = X.copy()
    Z['prediction'] = clf.predict_proba(X)[:, 1]
    Z['actual'] = y
    pred, actual = Z['prediction'], Z['actual']
    Z['mae'] = np.abs(pred - actual)
    cross_entropy = actual*np.log(pred) + (1-actual)*np.log(1-pred)
    Z['log_loss'] = -1*cross_entropy
    return Z

# Attach the helpers above to flaml's AutoML class so they can be called as methods (the AutoML object is passed as clf)
for f in [prediction]:
    setattr(fl.automl.automl.AutoML, f.__name__, f)

#########################################
################## AMP ##################
#########################################
@dataclasses.dataclass
class Term():
    term_code: int = 202408
    cycle_day: int = None
    cycle_date: str = None
    overwrite: set = None
    seed: int = 42

    #Allows self['attr'] and self.attr syntax
    def __contains__(self, key):
        # `key in self` is True whenever an attribute of that name can be looked up (instance or class level)
        return hasattr(self, key)
    def __getitem__(self, key):
        # self[key] reads the attribute named key
        return getattr(self, key)
    def __setitem__(self, key, val):
        # self[key] = val sets the attribute named key
        setattr(self, key, val)
    def __delitem__(self, key):
        # del self[key] removes the attribute; a key that is not present is silently ignored
        if key in self:
            delattr(self, key)


    def __post_init__(self):
        """Normalise the constructor arguments and derive the cycle attributes.

        Sets term_code (as int), overwrite (as a set) and, from get_cycle, the
        attributes cycle_day, cycle_date, stable_date, stem and term_desc.
        """
        self.term_code = int(self.term_code)
        # overwrite must be a set BEFORE get_cycle runs: get_cycle -> get_terms -> get evaluates
        # `nm in self.overwrite`, which raised TypeError while overwrite was still the default None
        self.overwrite = setify(self.overwrite)
        # Because these take about 1 hour each, we force user to manually delete parquet files to avoid accidental deletion requiring lengthy re-creation
        self.overwrite -= {f'drivetimes_{k}' for k in 'smwrl'}
        self.cycle_day, self.cycle_date, self.stable_date, self.stem, self.term_desc = self.get_cycle(self.term_code, self.cycle_day, self.cycle_date)


    def get(self, fcn, dst, prereq=None, *, divide=True, read=True, suffix='.parquet', **kwargs):
        """Return a cached result, creating and saving it with fcn when it does not exist yet.

        Lookup order: attribute already on self -> file on disk -> run fcn(**kwargs) and write the file.

        Args:
            fcn: callable producing the result; only called when neither the attribute nor the file exists.
            dst: either a path containing '/' (its suffix is replaced by `suffix` and its stem is the
                cache name) or a bare name, stored at data/<name>/<term_code>/<name>_<stem><suffix>.
            prereq: a callable or list of callables, each called without arguments before fcn.
                None (the default) means no prerequisites.
            divide: print the divider after a file has been created.
            read: load the file into self[name]; when False, self[name] is set to None instead.
            suffix: '.parquet' and '.csv' are written/read with pandas (after prep); any other
                suffix is pickled. The legacy spelling 'csv' (no dot) is still accepted.
            **kwargs: forwarded to fcn.

        Returns:
            (self[name], new) where new is True only when the file was created by this call.
        """
        nm = str(dst)
        if '/' in nm:
            dst = pathlib.Path(dst).with_suffix(suffix)
            nm = dst.stem
        else:
            dst = data/f'{dst}/{self.term_code}/{dst}_{self.stem}{suffix}'
        # the comparison used to be against 'csv' only, which a real '.csv' suffix never matched (it fell through to pickle)
        is_csv = suffix in ('.csv', 'csv')

        # overwrite can still be None when get is reached from __post_init__ (via get_cycle -> get_terms)
        if self.overwrite is not None and nm in self.overwrite:
            del self[nm]
            dst.unlink(missing_ok=True)
            self.overwrite.remove(nm)  # only force the rebuild once

        new = False
        if nm not in self:
            if not dst.exists():
                new = True
                for f in listify(prereq):
                    f()
                print(f'creating {dst.name}: ', end='')
                with codetiming.Timer():
                    rslt = fcn(**kwargs)
                    dst.parent.mkdir(parents=True, exist_ok=True)
                    if suffix == '.parquet':
                        pd.DataFrame(rslt).prep().to_parquet(dst)  # forced to wrap with explicit pd.DataFrame to due strange error under pandas 2.2.3 "Object of type PlanMetrics is not JSON serializable" with to_parquet
                    elif is_csv:
                        pd.DataFrame(rslt).prep().to_csv(dst)
                    else:
                        with open(dst, 'wb') as f:
                            pickle.dump(rslt, f, pickle.HIGHEST_PROTOCOL)
                if divide:
                    print(divider)
            if read:
                if suffix == '.parquet':
                    self[nm] = pd.read_parquet(dst)
                elif is_csv:
                    self[nm] = pd.read_csv(dst)
                else:
                    # NOTE(review): pickle.load is only safe because these files are written by this class
                    with open(dst, 'rb') as f:
                        self[nm] = pickle.load(f)
            else:
                self[nm] = None
        return self[nm], new
##################################################
################# get drivetimes #################
##################################################
    def get_zips(self, show=False):
        """Return the cached table of US zip codes (root/'geo/zips'), building it from pgeocode when missing.

        Rows without a state_code, and rows whose state_code is 'mh', are excluded.
        """
        def fcn():
            all_zips = pgeocode.Nominatim('us')._data  # get all zips
            renamed = all_zips.prep().rename(columns={'postal_code':'zip'})
            return renamed.query("state_code.notnull() & state_code not in [None,'mh']")
        zips, _ = self.get(fcn, root/'geo/zips')
        return zips


    def get_states(self, show=False):
        return set(self.get_zips()['state_code'])


    def get_drivetimes(self, show=False):
        """Return the cached zip/camp_code/drivetime table (root/'geo/drivetimes'), building it when missing.

        Building it:
          1. downloads the 2020 Census ZCTA shapes and samples 10 random points inside each ZCTA;
          2. for each campus coordinate, asks the public OSRM table service for the driving duration from
             the campus to every sampled point, 200 points per request (cached per campus as drivetimes_<k>);
          3. keeps the minimum duration per ZCTA and campus;
          4. gives every zip from get_zips that has no drivetime the drivetimes of the nearest zip
             (haversine distance on latitude/longitude) that does have one.
        """
        def fcn():
            from pgeocode import Nominatim
            from sklearn.metrics.pairwise import haversine_distances
            print()
            # campus code -> 'longitude,latitude' string in the form used in the OSRM url
            campus_coords = {
                's': '-98.215784,32.216217',
                'm': '-97.432975,32.582436',
                'w': '-97.172176,31.587908',
                'r': '-96.467920,30.642055',
                'l': '-96.983211,32.462267',
                }
            url = "https://www2.census.gov/geo/tiger/GENZ2020/shp/cb_2020_us_zcta520_500k.zip"
            gdf = gpd.read_file(url).prep().set_index('zcta5ce20')  # get all ZCTA https://www.census.gov/programs-surveys/geography/guidance/geo-areas/zctas.html
            pts = gdf.sample_points(size=10, method="uniform").explode().apply(lambda g: f"{g.x},{g.y}")  # sample 10 random points in each ZCTA
            M = []
            for k, v in campus_coords.items():
                # fcn1 closes over k and v; self.get calls it right away, so it sees the current campus
                def fcn1():
                    print()
                    L = []
                    i = 0
                    di = 200  # number of sampled points sent per request
                    I = pts.shape[0]
                    while i < I:
                        u = join([v, *pts.iloc[i:i+di]],';')  # campus first, so it is coordinate 0
                        url = f"http://router.project-osrm.org/table/v1/driving/{u}?sources={0}&annotations=duration,distance&fallback_speed=1&fallback_coordinate=snapped"
                        response = requests.get(url).json()
                        L.append(np.squeeze(response['durations'])[1:]/60)  # drop the campus-to-campus entry; /60 presumably converts seconds to minutes - TODO confirm
                        i += di
                        print(k,i,round(i/I*100))
                    df = pts.to_frame()[[]]
                    df[k] = np.concatenate(L)
                    return df
                df, new = self.get(fcn1, root/f'geo/drivetimes_{k}')
                M.append(df)
            # minimum over the sampled points of each ZCTA, then reshape to one row per (zip, campus)
            D = pd.concat(M, axis=1).groupby(level=0).min().stack().reset_index().set_axis(['zip','camp_code','drivetime'], axis=1)

            # There are a few USPS zips without equivalent ZCTA, so we assign them drivetimes for the nearest
            Z = self.get_zips().merge(D.query("camp_code=='s'"), how='left').set_index('zip')
            mask = Z['drivetime'].isnull()  # zips without a ZTCA
            Z = Z[['latitude','longitude']]
            X = np.radians(Z[~mask])
            Y = np.radians(Z[mask])
            M = (
                pd.DataFrame(haversine_distances(X, Y), index=X.index, columns=Y.index) # haversine distance between pairs with and without ZCTA
                .idxmin()  # find nearest ZCTA
                .reset_index()
                .set_axis(['new_zip','zip'], axis=1)
                .prep()
                .merge(D)  # merge the drivetimes for that ZCTA
                .drop(columns='zip')
                .rename(columns={'new_zip':'zip'})
            )
            df = pd.concat([D,M], ignore_index=True)
            return df
        df, new = self.get(fcn, root/'geo/drivetimes', self.get_zips)
        return df

########################################################
################# get term information #################
########################################################
    def get_terms(self, show=False):
        def fcn():
            qry = f"""
select
    stvterm_code as term_code
    ,replace(stvterm_desc, ' ', '') as term_desc
    ,stvterm_start_date as start_date
    ,stvterm_end_date as end_date
    ,stvterm_fa_proc_yr as fa_proc_yr
    ,stvterm_housing_start_date as housing_start_date
    ,stvterm_housing_end_date as housing_end_date
    ,sobptrm_census_date as census_date
from
    {catalog}saturnstvterm as A
inner join
    {catalog}saturnsobptrm as B
on
    stvterm_code = sobptrm_term_code
where
    sobptrm_ptrm_code='1'
"""
            df = run(qry, show).set_index('term_code')
            df['stable_date'] = df['census_date'].apply(lambda x: x+pd.Timedelta(days=7+4-x.weekday())) # Friday of week following census
            return df
        df, new = self.get(fcn, data/'terms')
        return df


    def get_cycle(self, term_code, cycle_day=None, cycle_date=None, show=False):
        term_desc, stable_date = self.get_terms().loc[term_code,['term_desc','stable_date']]
        if cycle_day is None:
            if cycle_date is None:
                cycle_date = pd.Timestamp.now()
            cycle_date = min(pd.to_datetime(cycle_date), pd.Timestamp.now()).normalize()
            cycle_day = (stable_date - cycle_date).days
        cycle_date = (stable_date - pd.Timedelta(days=cycle_day)).date()
        stem = f'{term_code}_{cycle_date}_{"-" if cycle_day < 0 else "+"}{rjust(abs(cycle_day),3,0)}'
        return cycle_day, cycle_date, stable_date, stem, term_desc
#######################################################
############ process flags reports archive ############
#######################################################
    def get_spriden(self, show=False):
        # Get id-pidm crosswalk so we can replace id by pidm in flags below
        # GA's should not have permissions to run this because it can see pii
        if 'spriden' not in self:
            qry = f"""
            select distinct
                spriden_id as id,
                spriden_pidm as pidm
            from
                {catalog}saturnspriden as A
            where
                spriden_change_ind is null
                and spriden_activity_date between '2000-09-01' and '2025-09-01'
                and spriden_id REGEXP '^[0-9]+'
            """
            self.spriden = run(qry, show)
        return self.spriden


    def process_flags(self, show=False):
        """Convert raw flags workbooks into one parquet file per term and snapshot date.

        Files in flags_raw are visited in reverse name order. Only .xlsx files whose name
        contains 'admitted' and not 'melt' are processed. The snapshot date is parsed from
        the first 10 characters of the name (multi-sheet layout) or, failing that, from the
        last 6 characters (single-sheet layout). Each selected sheet is merged with the
        id/pidm crosswalk, stripped of identifying columns and saved under
        flags_prc/<term_code>/. Scanning stops once more than 5 files in a row produced
        nothing new. When anything new was written, the combined per-term files are rebuilt.
        """
        # GA's should not have permissions to run this because it can see pii
        counter = 0
        divide = False
        for src in sorted(flags_raw.iterdir(), reverse=True):
            counter += 1
            if counter > 5:
                break
            # stem/suffix instead of split('.'): names with no dot or several dots no longer raise ValueError
            a, b = src.stem.lower(), src.suffix.lower().lstrip('.')
            if b != 'xlsx' or 'melt' in a or 'admitted' not in a:
                print(a, 'SKIP')
                continue
            # Handles 2 naming conventions that were used at different times
            try:
                cycle_date = pd.to_datetime(a[:10].replace('_','-'))
                multi = True
            except Exception:  # not a bare except: Ctrl-C / SystemExit must still propagate
                try:
                    cycle_date = pd.to_datetime(a[-6:])
                    multi = False
                except Exception:
                    print(a, 'FAIL')
                    continue
            book = pd.ExcelFile(src, engine='openpyxl')
            # Again, handles the 2 different versions with different sheet names
            if multi:
                sheets = {sheet:sheet for sheet in book.sheet_names if sheet.isnumeric() and int(sheet) % 100 in [1,6,8]}
            else:
                sheets = {a[:6]: book.sheet_names[0]}
            for term_code, sheet in sheets.items():
                cycle_day, cycle_date, stable_date, stem, term_desc = self.get_cycle(term_code, cycle_date=cycle_date)
                # fcn closes over sheet/cycle_day/cycle_date; self.get calls it immediately, so it sees the current values
                def fcn():
                    df = (
                        self.get_spriden()
                        .assign(cycle_day=cycle_day, cycle_date=cycle_date)
                        .merge(book.parse(sheet).prep(), on='id', how='right')
                        .drop(columns=['id','last_name','first_name','mi','pref_fname','street1','street2','primary_phone','call_em_all','email'], errors='ignore')
                    )
                    return df
                if self.get(fcn, flags_prc/f'{term_code}/flags_{stem}', read=False, divide=False)[1]:
                    divide = True
                    counter = 0  # something new was found, so keep scanning older files
                    dst = flags_prc/f'flags_{term_code}.parquet'
                    dst.unlink(missing_ok=True)  # the combined file for this term is now stale
        if divide:
            print(divider)
            self.combine_flags()


    def combine_flags(self, show=False):
        """Build flags_prc/flags_<term_code>.parquet for every term folder whose code ends in 8.

        Each combined file concatenates all parquet files found in the term's own folder
        and in the folder of term_code-2, with 'dob' and every column whose name contains
        'date' converted to datetimes (unparseable values become NaT).
        """
        def fcn(term_code):
            folders = [flags_prc/f'{term_code}', flags_prc/f'{term_code-2}']
            files = [src for folder in folders for src in sorted(folder.glob('*.parquet'))]
            df = pd.concat([pd.read_parquet(src) for src in files], ignore_index=True).prep()
            date_cols = ['dob', *df.filter(like='date').columns]  # convert date columns
            for k in date_cols:
                if k in df:
                    df[k] = pd.to_datetime(df[k], errors='coerce')
            return df

        created = False
        for x in flags_prc.iterdir():
            if not x.is_dir():
                continue
            term_code = int(x.stem)
            if term_code % 10 != 8:
                continue
            _, new = self.get(fcn, flags_prc/f'flags_{term_code}', read=False, divide=False, term_code=term_code)
            if new:
                created = True
        if created:
            print(divider)


    def get_flags(self, show=False):
        """Return the cached flags table for this term and cycle date, building it when missing.

        Built from the combined flags file of the term: rows after cycle_date are dropped,
        the latest remaining row per (pidm, term_code) is kept, zip is blanked where the
        state is not in get_states(), and zip is reduced to its first 5 characters before
        any '-' and converted to a number.
        """
        def fcn():
            flags = pd.read_parquet(flags_prc/f'flags_{self.term_code}.parquet')
            flags = flags.query(f"cycle_date<='{self.cycle_date}'")
            flags = flags.sort_values(['pidm','cycle_date'])
            df = flags.drop_duplicates(subset=['pidm','term_code'], keep='last')
            unknown_state = ~df['state'].isin(self.get_states())
            df.loc[unknown_state,'zip'] = pd.NA
            zip5 = df['zip'].str.split('-', expand=True)[0].str[:5]
            df['zip'] = zip5.to_numeric(errors='coerce')
            return df
        flags, _ = self.get(fcn, 'flags', self.combine_flags)
        return flags
##########################################
############ get student data ############
###############################
    def get_students(self, show=False):
        """Return the cached per-student feature table (indexed by pidm), building it when missing.

        Built by left-joining self.admissions with the flags (on pidm, term_code) and the
        drivetimes (on zip, camp_code) and then deriving the model features below.
        self.admissions is expected to exist because get_admissions is listed as a prerequisite.
        """
        def fcn():
            df = (self.admissions
                  .merge(self.get_flags(), on=['pidm','term_code'], how='left', suffixes=['', '_flags'])
                  .merge(self.get_drivetimes(), on=['zip','camp_code'], how='left', suffixes=['', '_zips'])                
            )
            # show zips that have no drivetime even though a zip is known (campus 'o' excluded)
            mask = df.eval("drivetime.isnull() & zip.notnull() & camp_code!='o'")
            if mask.any():
                df[mask].set_index(['state','city','zip','camp_code'])[[]].sort_index().reset_index().disp(50)

            # make sure all three gap score columns exist before combining them
            for c in ['gap_score','t_gap_score','ftic_gap_score']:
                if c not in df:
                    df[c] = pd.NA
            
            # styp_code 'n' prefers ftic_gap_score, everyone else prefers t_gap_score; the remaining columns are fallbacks
            df['gap_score'] = np.where(
                df['styp_code']=='n',
                df['ftic_gap_score'].combine_first(df['t_gap_score']).combine_first(df['gap_score']),
                df['t_gap_score'].combine_first(df['ftic_gap_score']).combine_first(df['gap_score']))
            
            
            # df['oriented'] = np.where(df['orien_sess'].notnull() | df['registered'].notnull(), 'y', np.where(df['orientation_hold_exists'].notnull(), 'n', 'w'))
            df['oriented'] = df['orientation_hold_exists'].isnull() | df['orien_sess'].notnull() | df['registered'].notnull()

            # df['verified'] = np.where(df['ver_complete'].notnull(), 'y', np.where(df['selected_for_ver'].notnull(), 'n', 'w'))
            df['verified'] = df['selected_for_ver'].isnull() | df['ver_complete'].notnull()
            
            # linearly rescale sat10_total_score so that 590 maps to 9 and 1600 maps to 36, then keep the better of the two scores
            df['sat10_total_score'] = (36-9) / (1600-590) * (df['sat10_total_score']-590) + 9
            df['act_equiv'] = df[['act_new_comp_score','sat10_total_score']].max(axis=1)

            # both measured in days before stable_date
            df['eager'] = (self.stable_date - df['first_date']).dt.days
            df['age'] = (self.stable_date - df['birth_date']).dt.days

            # tsi_<k> is False for 'not college ready', 'retest required' and missing values
            for k in ['reading', 'writing', 'math']:
                df[f'tsi_{k}'] = ~df[k].isin(['not college ready', 'retest required', pd.NA, None, np.nan])
            
            # hs_qrtl comes from binning hs_pctl; where hs_pctl is missing it falls back to a mapping of apdc_code
            repl = {'ae':0, 'n1':1, 'n2':2, 'n3':3, 'n4':4, 'r1':1, 'r2':2, 'r3':3, 'r4':4}
            df['hs_qrtl'] = pd.cut(df['hs_pctl'], bins=[-1,25,50,75,90,101], labels=[4,3,2,1,0], right=False).combine_first(df['apdc_code'].map(repl))

            df['lgcy'] = ~df['lgcy_code'].isin(['o',pd.NA,None,np.nan])
            df['resd'] = df['resd_code'] == 'r'

            # boolean "has a value" indicators named after the first word of each source column
            for k in ['waiver_desc','fafsa_app','ssb_last_accessed','finaid_accepted','schlship_app']:
                df[k.split('_')[0]] = df[k].notnull()



            # df['majr_code'] = df['majr_code'].replace({'0000':pd.NA, 'und':pd.NA, 'eled':'eted', 'agri':'unda'})

            # df['coll_code'] = df['coll_code'].replace({'ae':'an', 'eh':'ed', 'hs':'hl', 'st':'sm', '00':pd.NA})

            # df['coll_desc'] = df['coll_code'].map({
            #     'an': 'ag & natural_resources',
            #     'ba': 'business',
            #     'ed': 'education',
            #     'en': 'engineering',
            #     'hl': 'health sciences',
            #     'la': 'liberal & fine arts',
            #     'sm': 'science & mathematics',
            #     pd.NA: 'no college designated',
            # })



            # checks = [
            #     'cycle_day >= 0',
            #     'eager >= cycle_day',
            #     'age >= 5000',
            #     'distance >= 0',
            #     'hs_pctl >=0',
            #     'hs_pctl <= 100',
            #     'hs_qrtl >= 0',
            #     'hs_qrtl <= 4',
            #     'act_equiv >= 1',
            #     'act_equiv <= 36',
            #     'gap_score >= 0',
            #     'gap_score <= 100',
            # ]
            # for check in checks:
            #     mask = df.eval(check)
            #     assert mask.all(), [check,df[~mask].disp(5)]
            mask = df['cycle_date_flags'].isnull()  # rows from admissions not on flags - should not be any
            if mask.any():
                display(df[mask]['styp_code'].value_counts().sort_index().to_frame().T)
            return df.set_index(['pidm'])
        df, new = self.get(fcn, 'students', [self.get_admissions,self.get_flags,self.get_drivetimes])
        return df


    def newest(self, qry, part, sel='0 as temp'):
        """The OPEIR daily snapshot experienced occasional glitched causing incomplete copies.
        Consequently, record can vanished then reappear later. This function fixes this issue."""
        A, B = [indent(s.strip()) for s in qry.rsplit('from',1)]
        qry = f"""
select
    *
from (
    {A}
        ,min(current_date) over (partition by {part}) as first_date
        ,max(current_date) over (partition by {part}) as last_date
        ,least(greatest(timestamp('{self.cycle_date}'), min(current_date) over ()), max(current_date) over ()) as cycle_date
    from
        {B}
    qualify
        cycle_date between first_date and last_date  -- keep records where cycle_date falls between its first & last appearance (+5 days for safety)
    )
where
    current_date <= '{self.cycle_date}'  -- discard records after cycle_date
qualify
    row_number() over (partition by {part} order by current_date desc) = 1  -- keep most recent remaining record
"""

        qry = f"""
select
    pidm
    --,{self.cycle_day} as cycle_day
    ,cycle_date
    ,current_date
    ,first_date
    ,last_date
    ,{get_desc('term_code')}
    ,{get_desc('levl_code')}
    ,{get_desc('styp_code')}
    ,{get_desc('camp_code')}
    ,{get_desc('coll_code_1')}
    ,{get_desc('dept_code')}
    ,{get_desc('majr_code_1')}
    --,gender
    ,spbpers_sex as gender
    ,birth_date
    ,{get_desc('spbpers_lgcy_code')}
    ,gorvisa_vtyp_code is not null as international
    ,gorvisa_natn_code_issue as natn_code, (select stvnatn_nation from {catalog}saturnstvnatn where gorvisa_natn_code_issue = stvnatn_nation limit 1) as natn_desc
    ,{coalesce('race_asian')}
    ,{coalesce('race_black')}
    ,coalesce(spbpers_ethn_cde=2, False) as race_hispanic
    ,{coalesce('race_native')}
    ,{coalesce('race_pacific')}
    ,{coalesce('race_white')}
    ,{indent(join(sel))}
from {subqry(qry)} as A

left join
    {catalog}spbpers_v
on
    pidm = spbpers_pidm

left join (
    select
        *
    from
        {catalog}generalgorvisa
    qualify
        row_number() over (partition by gorvisa_pidm order by gorvisa_seq_no desc) = 1
    )
on
    pidm = gorvisa_pidm

left join (
    select
        gorprac_pidm
        ,max(gorprac_race_cde='AS') as race_asian
        ,max(gorprac_race_cde='BL') as race_black
        ,max(gorprac_race_cde='IN') as race_native
        ,max(gorprac_race_cde='HA') as race_pacific
        ,max(gorprac_race_cde='WH') as race_white
    from
        {catalog}generalgorprac
    group by
        gorprac_pidm
    )
on
    pidm = gorprac_pidm
"""
        return qry


    def get_admissions(self, show=False):
        """Return the cached admissions table for this term and cycle date, querying it when missing.

        The query reads dev.opeir.admissions_<term_desc>_v, keeps rows whose decision code has
        stvapdc_inst_acc_ind set, and passes through newest() with one record per (pidm, appl_no).
        The commented-out blocks below are earlier attempts at removing graduate/duplicate applicants.
        """
        def fcn():
            qry = self.newest(
                part = 'pidm, appl_no',
                sel = [
                    'appl_no',
                    get_desc('apst_code'),
                    get_desc('apdc_code'),
                    get_desc('admt_code'),
                    get_desc('saradap_resd_code'),
                    'hs_percentile',
                    # 'sbgi_code',
                ],
                qry = f"""
select distinct
    A.*
from
    --dev.opeir.opeiradmissions_{self.term_desc} as A
    dev.opeir.admissions_{self.term_desc}_v as A
inner join
    {catalog}saturnstvapdc as B
on
    apdc_code = stvapdc_code
where
    stvapdc_inst_acc_ind is not null  --only accepted
""")
            # qry = join(L, '\nunion all\n')

#             qry = f"""
# select
#     *
#     ,min(levl_code='UG') over (partition by pidm) as lev
# from {subqry(qry)}
# """

#             qry = f"""
# select
#     *
# from {subqry(qry)}
# where
#     lev
# """

# qualify
#     min(levl_code='UG') over (partition by pidm) = True  -- remove pidm's with graduate admission even it if also has an undergradute admission, min acts like logical and
#     and row_number() over (partition by pidm order by appl_no desc) = 1  -- de-duplicate the few remaining pidms with multiple record by keeping highest appl_no
# """
            df = run(qry, show)
            # df = (
            #     run(qry, show)
            #     .sort_values('appl_no')
            #     .groupby('pidm')
            #     .filter(lambda x: (x['levl_code']=='ug').all())
            # )
            
            
            return df
        df, new = self.get(fcn, 'admissions')
        return df


    def get_registrations(self, show=False):
        def fcn():
            grp = 'pidm, term_code, levl_code, styp_code'
            qry = self.newest(
                part = 'pidm, subj_code, crse_numb',
                sel = [
                    'subj_code || crse_numb as crse_code',
                    'crn',
                    'credit_hr',
                    ],
                qry=f"""
select distinct
    A.*
from
    dev.opeir.opeirregistration_{self.term_desc} as A
""")
            df = run(qry, show).set_index(['crse_code','pidm'])
            return df
        df, new = self.get(fcn, 'registrations')
        return df

# df = pd.concat([Term(term_code=term_code, cycle_date='2024-09-09', overwrite=['admissions']).get_admissions(show=True).copy() for term_code in [202406, 202408]], ignore_index=True)
# df = df[df.groupby('pidm', group_keys=False).apply(lambda x: (x['levl_code']=='ug').all() & (x['appl_no']==x['appl_no'].max()))]#.

# L = {term_code: Term(term_code=term_code, cycle_date='2024-09-09', overwrite=['admissions']).get_admissions(show=True).copy() for term_code in [202406, 202408]}

# self = Term(
#     # cycle_day=0,
#     cycle_date='2024-09-09',
#     term_code=term_code,
#     overwrite=[
#         # 'terms',
#         # 'zips',
#         # 'drivetimes',
#         # 'flags',
#         'admissions',
#         # 'students',
#         # 'registrations',
#     ]
# )
# # self.process_flags()
# # self.get_zips()
# # self.get_drivetimes()
# # self.get_registrations(show=True)
# A = self.get_admissions(show=True)
# # A = self.get_admissions(show=True)
# # self.get_students()


In [0]:
@dataclasses.dataclass
class AMP(Term):
    """A Term fixed to the fall term (code ending in 8) of `year`, observed on the month-day `date`.

    On construction it also creates three companion Term objects for the same cycle date
    (current = same term, summer = term code ending in 6, stable = same term at cycle_day 0),
    stores the feature defaults, and calls get_prepared().
    """
    date: str = ''  # 'mm-dd'; combined with year to form cycle_date
    crse_code: str = '_headcnt'
    year: int = 2024
    time_budget: int = 60
    overwrite: set = None

    def __post_init__(self):
        assert len(self.date)==5, "Please specify date using 'mm-dd' format (2 digit month & 2 digit day)"
        self.year = int(self.year)
        self.term_code = 100*self.year+8
        self.cycle_date = f'{self.year}-{self.date}'
        # Term.__post_init__ was called twice (and kwargs built twice); the second pass only repeated the same work
        super().__post_init__()
        kwargs = {k: self[k] for k in ['cycle_date','term_code','overwrite']}
        self.current = Term(**kwargs)
        self.summer  = Term(**kwargs | {'term_code': 100*self.year+6})
        self.stable  = Term(**kwargs | {'cycle_day': 0})

        # self.idx = 'styp_desc'
        self.idx = 'styp_code'
        self.agg = ['styp_desc','camp_desc']
        # feature name -> value stored alongside it (presumably the fill value for missing data - TODO confirm against get_prepared)
        self.features = {
            'act_equiv':pd.NA,
            'age':pd.NA,
            'camp_desc':'stephenville',
            'drivetime':pd.NA,
            'eager':pd.NA,
            'fafsa': False,
            'finaid': False,
            'gap_score':0,
            'gender':pd.NA,
            'hs_qrtl':pd.NA,
            'lgcy':False,
            'oriented':False,
            'race_asian':False,
            'race_black':False,
            'race_hispanic':False,
            'race_native':False,
            'race_pacific':False,
            'race_white':False,
            'schlship':False,
            'ssb':False,
            'tsi_math':False,
            'tsi_reading':False,
            'tsi_writing':False,
            'verified':False,
            'waiver':False,
            'credit_hr':0,
            'current':False,
        }
        self.get_prepared()  # not defined in this part of the file
        # if self.year < max(years):
        #     self.get_multipliers()
        #     self.get_models()


    def get_prepared(self):
        def fcn():
            df = pd.concat([x.get_admissions() for x in [self.summer, self.current]], ignore_index=True)
            df = df[df.groupby('pidm', group_keys=False).apply(lambda x: (x['levl_code']=='ug').all() & (x['appl_no']==x['appl_no'].max()))]

            # df = self.current.get_students()
            try:
                df['credit_hr'] = self.current.get_registrations().loc['_tot_sch','credit_hr']
            except:
                df['credit_hr'] = 0
                # print('credit_hr')
                clear_output()

            try:
                df['current'] = self.current.get_registrations().loc[self.crse_code,'credit_hr']>0
            except:
                df['current'] = False
                # print('current')
                clear_output()

            try:
                df['stable'] = self.stable.get_registrations().loc[self.crse_code,'credit_hr']>0
            except:
                df['stable'] = False
                # print('stable')
                clear_output()

            df = (
                df
                # self.current.get_students()
                # .join(self.current.get_registrations().loc['_tot_sch'  ]['credit_hr'])
                # .join(self.current.get_registrations().loc[self.crse_code]['credit_hr'].rename('current')>0)
                # .join(self.stable .get_registrations().loc[self.crse_code]['credit_hr'].rename('stable' )>0)
                .fillna(self.features|{'stable':False})
                .query("levl_code=='ug' & styp_code in ['n','r','t']")
                .prep()
                .set_index(difference(listify(self.idx)+listify(self.agg), self.features), append=True)
                .sort_index()
            )
            for k,v in df.select_dtypes('string').items():
                df[k] = pd.Categorical(v)

            def fcn1(Z):
                y = Z.pop('stable')
                imp = mf.ImputationKernel(Z[Z.columns.intersection(self.features)].reset_index(drop=True), random_state=self.seed)
                imp.mice(10)
                return imp.complete_data().set_index(Z.index), y
            return {key: fcn1(Z) for key, Z in df.groupby(self.idx)}
        dct, new = self.get(fcn, f'prepared', self.current.get_students, suffix='.pkl')
        return dct


    def get_enrollments(self):
        def fcn():
            df = (
                self.stable.get_registrations()
                .query(f"levl_code=='ug' & styp_code in ['n','t','r']")
                .join(self.current.get_students().assign(admitted=lambda x: x['current_date'].notnull().prep())['admitted'])
                .fillna({'admitted':False})
            )
            return df
        df, new = self.get(fcn, f'enrollments', [self.stable.get_registrations,self.current.get_students])
        return df

    
    def get_multipliers(self):
        def fcn():
            return 1 / self.get_enrollments().groupby(['crse_code',*listify(self.idx)])['admitted'].mean().rename('mlt')
        df, new = self.get(fcn, f'multipliers', self.get_enrollments)
        return df


    def get_models(self):
        def fcn():
            def fcn1(Z):
                dct = {
                    'time_budget':self.time_budget,
                    'task':'classification',
                    'verbose':0,
                    'metric':'log_loss',
                    'eval_method':'cv',
                    'n_splits':3,
                    'seed':self.seed,
                    # 'early_stop':True,
                    'estimator_list': ['xgboost','lgbm','rf'],
                }
                clf = fl.AutoML(**dct)
                clf.fit(*Z, **dct)
                # y = X[[]].join(self.stable.get_registrations().loc[self.crse_code]['credit_hr'].rename('stable')>0).fillna(False)
                # y = X[[]].join(self.stable.get_registrations().loc[self.crse_code]['credit_hr'].rename('stable')).fillna(0)>0
                # y = X.join(self.stable.get_registrations().loc[self.crse_code])['credit_hr'].fillna(0)>0
                
                # ['credit_hr'].rename('stable')).fillna(0)>0
                # print(type(y))
                # y.disp(1)
                # clf.fit(X, y, **dct)
                return clf
            return {key: fcn1(Z) for key, Z in self.get_prepared().items()}
        clf, new = self.get(fcn, f'models', self.get_enrollments, suffix='.pickle')
        return clf


    def get_predictions(self, learners=dict()):
        def fcn():
            dct = {
                'crse_code': self.crse_code,
                'prediction_year': self.year,
                'learner_year': learner.year,
            }

            L = [
                clf.prediction(*self.get_prepared()[key])
                .assign(**dct, mlt=learner.get_multipliers().loc[self.crse_code,*key]['mlt'])
                .set_index(list(dct), append=True)

                # .set_index(list(dct.keys()))

                #  .reset_index()
                # .inser('mlt', learner.get_multipliers().loc[self.crse_code,*key]['mlt'])
                # .inser('learner_year', learner.year)
                # .inser('prediction_year', self.year)
                # .inser('crse_code', self.crse_code)
                for year, learner in learners.items() if 'models' in learner
                for key, clf in learner.get_models().items()]
            return pd.concat(L) if len(L)>0 else pd.DataFrame()
        df, new = self.get(fcn, f'predictions', self.get_prepared)
        return df
    

    def get_forecasts(self):
        def fcn():
            return {key:
                self.get_predictions()
                .assign(forecast=lambda Z: Z['prediction']*Z['mlt'])
                .groupby(unique('crse_code','prediction_year','learner_year',*self.idx,key))
                [['forecast']].sum()
                for key in self.agg}
            # def fcn1(Z):
            #     return (Z['prediction']*Z['mlt']).sum()
            # df = self.get_predictions().copy()
            # df['prediction'] *= df['mlt']
            # return df.groupby(unique('crse_code','prediction_year','learner_year',*self.idx,key))['prediction'].sum()
            # return self.get_predictions().groupby(unique('crse_code','prediction_year','learner_year',*self.idx,key)).apply(fcn1)
        dct, new = self.get(fcn, f'forecasts', suffix='pickle')
        return dct


# Driver cell: build one AMP per year, then score each year against all of them.
# (alternative year list: [2022,2023,2024,2025])
years = listify([2025,2024])


def _make_amp(year):
    """Construct the AMP for one year using this run's fixed settings."""
    return AMP(
        date='03-28',
        crse_code='_headcnt',
        time_budget=10,
        year=year,
        # a fresh set per instance; commented entries are not recomputed
        overwrite={
            # 'flags',
            # 'admissions',
            # 'students',
            # 'registrations',
            # 'prepared',
            'enrollments',
            'multipliers',
            'models',
            'predictions',
            'forecasts',
            },
    )


amps = {year: _make_amp(year) for year in years}

# `self` is deliberately the loop name: it stays bound to the last AMP
# afterwards, which the final expression of this cell relies on.
for self in amps.values():
    self.get_predictions(amps)


# F = pd.concat([self.forecasts for self in amp.values()]).sort_index(ascending=[True, False,False])
# F.to_csv(data/f'forecast

# self = amps[2024]
# self.get_predictions()
# self.get_forecasts()
# self.get_admissions()
# self.current.get_registrations(show=True)
self.predictions