In [0]:
username = 'scook'
from IPython.display import display, HTML, clear_output
try:
    %reload_ext autotime
except:
    %pip install -U ipython-autotime ipywidgets codetiming openpyxl
    %reload_ext autotime
clear_output()

import pathlib, shutil, warnings, dataclasses, numpy as np, pandas as pd
from codetiming import Timer
# Default seed for the `repeatable (...)` clause that run() appends when sampling a table.
seed = 42
# Prefix prepended to table names in every query built below (note the trailing dot).
catalog = 'dev.bronze.'
# Folder under which Term.get() stores its cached parquet files.
root = pathlib.Path(f'/Workspace/Users/{username}@tarleton.edu/admitted_matriculation_predictor_2025/')
# Folder scanned by Term.process_flags() for raw .xlsx flags reports.
flags_raw = pathlib.Path('/Volumes/aiml/scook/scook_files/admitted_flags_raw')
# Folder holding the processed flags parquet files, one sub-folder per term_code.
flags_prc = pathlib.Path('/Volumes/aiml/flags/flags_volume/')

##########################################
############ helper functions ############
##########################################
# One indentation level, used by indent() when nesting SQL text.
tab = '    '
# Visual separator printed between processing steps.
divider = '\n##############################################################################################################\n'

def listify(*args):
    """Return the argument(s) as a list.

    Called with a single argument:
      * a missing value (None, pd.NA, NaN, NaT) -> ``[]``
      * a string -> ``[string]`` (never split into characters)
      * any other iterable -> ``list(iterable)``
      * a non-iterable scalar -> ``[scalar]``
    Called with zero or several arguments, the arguments themselves become the list.
    """
    if len(args) == 1:
        x = args[0]
        # Test missing-ness by value, not identity: float('nan') and np.float64('nan')
        # are not the same object as np.nan, so `x is np.nan` would miss them.
        if x is None or x is pd.NA or (pd.api.types.is_scalar(x) and pd.isna(x)):
            return []
        if isinstance(x, str):
            return [x]
        try:
            return list(x)
        except TypeError:  # not iterable
            return [x]
    return list(args)

def rjust(x, width, fillchar=' '):
    """Right-justify str(x) in a field of `width` characters, padding with str(fillchar)."""
    text, pad = str(x), str(fillchar)
    return text.rjust(width, pad)

def ljust(x, width, fillchar=' '):
    """Left-justify str(x) in a field of `width` characters, padding with str(fillchar)."""
    text, pad = str(x), str(fillchar)
    return text.ljust(width, pad)

def join(lst, sep='\n,', pre='', post=''):
    """Join the items of `lst` (passed through listify and str) with `sep`, wrapped in `pre`/`post`."""
    pieces = [str(item) for item in listify(lst)]
    body = str(sep).join(pieces)
    return f"{pre}{body}{post}"

def alias(dct):
    """Turn {original column name: new column name} into a list of 'original as new' strings."""
    clauses = []
    for old, new in dct.items():
        clauses.append(f'{old} as {new}')
    return clauses

def indent(x, lev=1):
    """Indent every line after the first by `lev` levels of `tab`; return x unchanged when lev is not positive."""
    if lev > 0:
        return x.replace('\n', '\n' + tab*lev)
    return x

def subqry(qry, lev=1):
    """Wrap `qry` in parentheses on their own lines, indented by `lev` levels, for use as a subquery."""
    body = '\n' + qry.strip() + '\n)'
    return "(" + indent(body, lev)

def run(qry, show=False, sample='10 rows', seed=seed):
    """Run `qry` through spark.sql and return the result as a prepped pandas DataFrame.

    If `qry` contains no space it is treated as a bare table name: it is expanded to
    `select * from {catalog}{qry}` and, unless `sample` is None, a repeatable
    `tablesample` clause is appended. With `show=True` the final query is printed first.
    """
    if ' ' not in qry:
        # a single token is a table name, not a full statement
        parts = [f'select * from {catalog}{qry}']
        if sample is not None:
            parts.append(f' tablesample ({sample}) repeatable ({seed})')
        qry = ''.join(parts)
    if show:
        print(qry)
    result = spark.sql(qry)
    return result.toPandas().prep()

def rm(path, root=False):
    """Delete a file, or the contents of a directory; return the path as a pathlib.Path.

    For a directory, `root=True` removes the directory itself (recursively), while
    `root=False` empties it but leaves the directory in place. Missing paths are ignored.
    """
    path = pathlib.Path(path)
    if path.is_dir():
        if root:
            shutil.rmtree(path)
        else:
            # keep the folder itself, remove everything inside it
            for child in path.iterdir():
                rm(child, True)
    elif path.is_file():
        path.unlink()
    return path

############ annoying warnings to suppress ############
# [warnings.filterwarnings(action='ignore', message=f".*{w}.*") for w in [
#     "Could not infer format, so each element will be parsed individually, falling back to `dateutil`",
#     "Engine has switched to 'python' because numexpr does not support extension array dtypes",
#     "The default of observed=False is deprecated and will be changed to True in a future version of pandas",
#     "errors='ignore' is deprecated",
#     "The behavior of DataFrame concatenation with empty or all-NA entries is deprecated",
#     "The behavior of array concatenation with empty entries is deprecated",
#     "DataFrame is highly fragmented",
# ]]

############ pandas functions ############
# Do not limit the number of columns shown when a DataFrame is displayed.
pd.options.display.max_columns = None
def disp(df, rows=4, head=True):
    """Display the first (head=True) or last (head=False) `rows` rows as HTML, then print df.shape.

    Returns None; the output goes to the notebook cell.
    """
    from IPython.display import display, HTML
    with pd.option_context('display.min_rows', rows, 'display.max_rows', rows):
        # DataFrame has no `.tails`; the method for the last rows is `.tail`
        X = df.head(rows) if head else df.tail(rows)
        display(HTML(X.to_html()))
        print(df.shape)

def to_numeric(df, downcast='integer', errors='ignore', **kwargs):
    """Convert every column to a numeric dtype where possible.

    Steps, applied column by column:
      1. object/string columns are cast to 'string', lower-cased and stripped;
      2. non-datetime columns go through pd.to_numeric(downcast=downcast, **kwargs);
      3. the result is passed through convert_dtypes() to get nullable dtypes.

    errors:
        'ignore' (default) leaves a column unchanged when it cannot be parsed,
        'raise' and 'coerce' are forwarded to pd.to_numeric as-is.
    """
    def clean_strings(s):
        if s.dtype in ['object', 'string']:
            return s.astype('string').str.lower().str.strip()
        return s

    def convert(s):
        if pd.api.types.is_datetime64_any_dtype(s):
            return s
        if errors == 'ignore':
            # pd.to_numeric(errors='ignore') is deprecated; emulate it by catching
            # the parse failure and returning the column untouched.
            try:
                return pd.to_numeric(s, downcast=downcast, errors='raise', **kwargs)
            except (ValueError, TypeError):
                return s
        return pd.to_numeric(s, downcast=downcast, errors=errors, **kwargs)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return (
            df
            .apply(clean_strings)  # prep strings
            .apply(convert)  # convert to numeric if possible
            .convert_dtypes()  # convert to new nullable dtypes
        )

def prep(df, **kwargs):
    """Standardize a DataFrame: numeric conversion plus lower-case snake-style column names.

    The same clean-up is applied to the index (by moving it into columns first); the
    cleaned index is then re-attached as a MultiIndex. kwargs are forwarded to to_numeric.
    """
    def clean_name(name):
        if not isinstance(name, str):
            return name
        return name.lower().strip().replace(' ','_').replace('-','_')

    def tidy(frame):
        return frame.to_numeric(**kwargs).rename(columns=clean_name)

    # drop all columns and reset_index so the index levels become columns, then tidy them
    index_frame = tidy(df[[]].reset_index())
    body = tidy(df).reset_index(drop=True)
    # put the tidied index back on the tidied data
    return body.set_index(pd.MultiIndex.from_frame(index_frame))

def wrap(fcn):
    """Adapt a DataFrame function so it also accepts a Series.

    The input is always converted to a DataFrame before calling `fcn`. If the caller
    passed a Series the result is squeezed back; a None result is returned as None.
    """
    def wrapper(X, *args, **kwargs):
        result = fcn(pd.DataFrame(X), *args, **kwargs)
        if result is None:
            return None
        if isinstance(X, pd.Series):
            return result.squeeze()  # back to a Series for Series input
        return result
    return wrapper

# Monkey-patch the helpers into the pandas Series & DataFrame classes so they can be
# called with df.method(...) syntax.
for f in [disp, to_numeric, prep]:
    setattr(pd.DataFrame, f.__name__, f)  # DataFrame receives the function unchanged
    setattr(pd.Series, f.__name__, wrap(f))  # Series goes through wrap() to round-trip via a DataFrame


#########################################
################## AMP ##################
#########################################
@dataclasses.dataclass
class Term():
    """Settings and cached datasets for one term at one point in its admissions cycle.

    Datasets loaded through get() are stored as attributes on the instance.
    """
    # Term identifier; cast to int in get_terms().
    term_code: int = 202408
    # Number of days between the cycle date and the term's stable_date (see get_cycle()).
    cycle_day: int = None
    # Cycle date; get_cycle() rewrites it as a 'YYYY-MM-DD' string.
    cycle_date: str = None

    #Allows self['attr'] and self.attr syntax
    def __contains__(self, key):
        return hasattr(self, key)
    def __getitem__(self, key):
        return getattr(self, key)
    def __setitem__(self, key, val):
        setattr(self, key, val)
    def __delitem__(self, key):
        if key in self:
            delattr(self, key)

    def __post_init__(self):
        """Run get_terms() right after the dataclass-generated __init__."""
        self.get_terms()

    def get(self, fcn, nm, overwrite=False, path=None, divide=True):
        dst = root/f'{nm}/{self.term_code}/{nm}_{self.stem}.parquet' if path is None else pathlib.Path(path).with_suffix('.parquet')
        new = False
        if overwrite:
            del self[nm]
            dst.unlink(missing_ok=True)
        if not nm in self:
            if dst.exists():
                self[nm] = pd.read_parquet(dst)
            else:
                print(f'creating {dst}: ', end='')
                new = True
                with Timer():
                    dst.parent.mkdir(parents=True, exist_ok=True)
                    self[nm] = fcn().prep()
                    self[nm].to_parquet(dst)
                if divide:
                    print(divider)
        return self[nm], new

########################################################
################# get term information #################
########################################################
    def get_terms(self, overwrite=False, show=False):
        """Load the term calendar into self.terms and resolve this instance's cycle settings.

        overwrite: rebuild the cached terms parquet instead of reusing it.
        show:      print the SQL before running it.

        Side effects: casts self.term_code to int, then sets self.cycle_day,
        self.cycle_date and self.stem from get_cycle().
        """
        self.term_code = int(self.term_code)
        def fcn():
            # One row per term from stvterm, joined to sobptrm (part of term '1') for the census date.
            qry = f"""
select
    stvterm_code as term_code
    ,replace(stvterm_desc, ' ', '') as term_desc
    ,stvterm_start_date as start_date
    ,stvterm_end_date as end_date
    ,stvterm_fa_proc_yr as fa_proc_yr
    ,stvterm_housing_start_date as housing_start_date
    ,stvterm_housing_end_date as housing_end_date
    ,sobptrm_census_date as census_date
from
    {catalog}saturnstvterm as A
inner join
    {catalog}saturnsobptrm as B
on
    stvterm_code = sobptrm_term_code
where
    sobptrm_ptrm_code='1'
"""
            df = run(qry, show).set_index('term_code')
            # weekday() is 0 for Monday, so census - weekday is that week's Monday; +4 is its Friday; +7 moves to the next week
            df['stable_date'] = df['census_date'].apply(lambda x: x+pd.Timedelta(days=7+4-x.weekday())) # Friday of week following census
            return df
        # terms is shared across term codes, so it is cached at a fixed path rather than a per-term one
        self.get(fcn, 'terms', overwrite, path=root/'terms')
        self.cycle_day, self.cycle_date, self.stem = self.get_cycle(self.term_code, self.cycle_day, self.cycle_date)


    def get_cycle(self, term_code, cycle_day=None, cycle_date=None):
        """Resolve (cycle_day, cycle_date, stem) for `term_code`.

        cycle_day counts days from the cycle date to the term's stable_date. If it is
        given it wins and cycle_date is derived from it; otherwise it is computed from
        cycle_date (defaulting to now, normalized to midnight).

        Returns cycle_day (int), cycle_date ('YYYY-MM-DD' string) and stem, the file-name
        fragment '{term_code}_{cycle_date}_{sign}{abs(cycle_day) zero-padded to 3}'.
        """
        stable_date = self.terms.loc[int(term_code),'stable_date']
        if cycle_day is None:
            when = pd.Timestamp.now() if cycle_date is None else cycle_date
            when = pd.to_datetime(when).normalize()
            cycle_day = (stable_date - when).days
        # always re-derive the date from the day count so the two stay consistent
        offset = pd.Timedelta(days=cycle_day)
        cycle_date = str((stable_date - offset).date())
        sign = "-" if cycle_day < 0 else "+"
        stem = f'{term_code}_{cycle_date}_{sign}{rjust(abs(cycle_day),3,0)}'
        return cycle_day, cycle_date, stem

#######################################################
############ process flags reports archive ############
#######################################################
    def get_spriden(self, overwrite=False, show=False):
        # Get id-pidm crosswalk so we can replace id by pidm in flags below
        if 'spriden' not in self:
            qry = f"""
            select distinct
                spriden_id as id,
                spriden_pidm as pidm
            from
                {catalog}saturnspriden as A
            where
                spriden_change_ind is null
                and spriden_activity_date between '2000-09-01' and '2025-09-01'
                and spriden_id REGEXP '^[0-9]+'
            """
            self.spriden = run(qry, show)
        return self.spriden


    def process_flags(self, overwrite=False, show=False):
        """Convert raw .xlsx flags reports in flags_raw into per-term parquet files in flags_prc.

        Files are visited in reverse name order. The scan stops once more than 5 files
        in a row (skipped files included) have produced no new parquet file.
        Each sheet is merged with the id -> pidm crosswalk and identifying columns are dropped.

        overwrite: rebuild parquet files that already exist.
        show:      accepted for signature consistency; not used here.
        """
        parse_errors = (ValueError, TypeError, OverflowError)  # what a failed date parse can raise
        counter = 0
        for src in sorted(flags_raw.iterdir(), reverse=True):
            counter += 1
            if counter > 5:
                break
            # stem/suffix instead of split('.') so names with zero or several dots are skipped, not a crash
            a, b = src.stem.lower(), src.suffix.lower()
            if b != '.xlsx' or 'melt' in a or 'admitted' not in a:
                print(a, 'SKIP')
                continue
            # Handles 2 naming conventions that were used at different times
            try:
                cycle_date = pd.to_datetime(a[:10].replace('_','-'))  # leading yyyy_mm_dd
                multi = True
            except parse_errors:
                try:
                    cycle_date = pd.to_datetime(a[-6:])  # trailing 6-character date
                    multi = False
                except parse_errors:
                    print(a, 'FAIL')
                    continue
            # context manager so the workbook handle is closed even if a sheet fails
            with pd.ExcelFile(src, engine='openpyxl') as book:
                # Again, handles the 2 different versions with different sheet names
                if multi:
                    sheets = {sheet:sheet for sheet in book.sheet_names if sheet.isnumeric() and int(sheet) % 100 in [1,6,8]}
                else:
                    sheets = {a[:6]: book.sheet_names[0]}
                for term_code, sheet in sheets.items():
                    cycle_day, cycle_date, stem = self.get_cycle(term_code, cycle_date=cycle_date)
                    def fcn():
                        # fcn is called by self.get() within this iteration, so the loop variables it closes over are current
                        df = (
                            self.get_spriden()
                            .assign(cycle_day=cycle_day, cycle_date=cycle_date)
                            .merge(book.parse(sheet).prep(), on='id', how='right')
                            .drop(columns=['id','last_name','first_name','mi','pref_fname','street1','street2','primary_phone','call_em_all','email'], errors='ignore')
                        )
                        return df
                    if self.get(fcn, 'flags', overwrite, path=flags_prc/f'{term_code}/flags_{stem}', divide=False)[1]:
                        counter = 0  # something new was written; keep scanning
                    del self['flags']  # do not let this snapshot shadow the next one
        print(divider)


    def get_flags_history(self, overwrite=False, show=False, cutoff=202206):
        """Summarize which flags columns are present in every snapshot of each term.

        Builds (and caches as 'flags_history') one boolean row per processed flags file
        marking which columns its parquet schema contains. For terms with
        term_code >= cutoff, the result has one row per variable, one boolean column per
        term (True when the variable appears in as many snapshots as the most common
        variable of that term) and 'n', the number of such terms. Sorted by n descending,
        then variable name.
        """
        def fcn():
            import pyarrow.parquet as pq
            print()
            rows = []
            for term_dir in sorted(flags_prc.iterdir(), reverse=True):
                print(term_dir)
                # if int(term_dir.stem) < 202506:
                #     break
                for src in term_dir.iterdir():
                    _, term_code, cycle_date, _ = src.stem.split('_')
                    # only the schema is read, not the data
                    names = pq.ParquetFile(src).schema.names
                    row = pd.DataFrame(columns=names).assign(term_code=[int(term_code)], cycle_date=[cycle_date]).fillna(True)
                    rows.append(row)
            # columns missing from a snapshot come out of concat as NaN -> False
            history = pd.concat(rows).fillna(False).set_index(['cycle_date','term_code']).sort_index()
            return history[sorted(history.columns)]

        history, _ = self.get(fcn, 'flags_history', overwrite, path=root/'flags_history')
        recent = history.query(f'term_code>={cutoff}')
        counts = recent.groupby('term_code').sum().sort_index(ascending=False).T.rename_axis('variable')
        complete = counts == counts.max()
        complete.insert(0, 'n', complete.sum(axis=1))
        return complete.reset_index().sort_values(['n', 'variable'], ascending=[False, True])


    def get_flags(self, overwrite=False, show=False):
        """Load the flags snapshot in effect at self.cycle_date for the summer and fall terms.

        For term_code-2 and term_code, takes the most recent processed flags file dated
        strictly before self.cycle_date, concatenates them and converts date-like columns
        to datetime. The result is cached through get() as 'flags'.

        overwrite: rebuild the cached parquet.
        show:      accepted for signature consistency; not used here.
        """
        def fcn():
            L = []
            for term_code in [self.term_code-2,self.term_code]:  # summer & fall
                # Newest first: with an ascending sort the first match was always the
                # oldest snapshot, regardless of cycle_date.
                for src in sorted((flags_prc/f'{term_code}').iterdir(), reverse=True):
                    if src.stem.split('_')[2] < self.cycle_date:  # latest flags before cycle_date
                        L.append(pd.read_parquet(src))
                        break
            df = pd.concat(L, ignore_index=True)
            for k in ['dob',*df.filter(like='date').columns]:  # convert date columns
                if k in df.columns:  # 'dob' is not guaranteed to be in every report
                    df[k] = pd.to_datetime(df[k], errors='coerce')
            return df
        self.get(fcn, 'flags', overwrite)

########################################################
############### get course registrations ###############
########################################################
    def get_registrations(self, overwrite=False, show=False):
        """Build course registrations for self.term_code as of self.cycle_date, cached as 'registrations'.

        The result has one row per student/section-level grouping with its course code
        and credit hours, plus two summary rows per student grouping:
        '_allcrse' (total credit hours) and '_anycrse' (1 if any credit hours, else 0).

        overwrite: rebuild the cached parquet.
        show:      print the SQL before running it.
        """
        def fcn():
            # source column -> output name; also reused for the group by clauses
            dct = {
                'sfrstcr_pidm':'pidm',
                'ssbsect_term_code':'term_code',
                'sgbstdn_levl_code':'levl_code',
                'sgbstdn_styp_code':'styp_code',
                'ssbsect_crn':'crn',
            }
            # Base query: registrations joined to sections and to each student's latest
            # sgbstdn record effective on or before the term.
            qry = f"""
select
    {indent(join(alias(dct)))}
    ,lower(ssbsect_subj_code) || ssbsect_crse_numb as crse_code
    ,max(ssbsect_credit_hrs) as credit_hr
from
    {catalog}saturnsfrstcr as A
inner join
    {catalog}saturnssbsect as B
on
    sfrstcr_term_code = ssbsect_term_code
    and sfrstcr_crn = ssbsect_crn
inner join (
    select
        *
    from
        {catalog}sgbstdn_amp_v
    where
        sgbstdn_term_code_eff <= {self.term_code}
    qualify
        row_number() over (partition by sgbstdn_pidm order by sgbstdn_term_code_eff desc) = 1
    ) as C
on
    sfrstcr_pidm = sgbstdn_pidm
where
    sfrstcr_term_code = {self.term_code}
    and sfrstcr_ptrm_code not in ('28','R3') -- drop weird parts of term
    and sfrstcr_add_date <= '{self.cycle_date}' -- added before cycle_day
    and (sfrstcr_rsts_date > '{self.cycle_date}' or sfrstcr_rsts_code in ('DC','DL','RD','RE','RW','WD','WF')) -- dropped after cycle_day or still enrolled
    and ssbsect_subj_code <> 'INST' -- exceptional sections
group by
    {indent(join(dct.keys()))}
    ,ssbsect_subj_code
    ,ssbsect_crse_numb
"""

            # Wrap the base query as CTE A and append the '_allcrse' and '_anycrse' summary rows.
            qry = f"""
with A as {subqry(qry)}
select * from A

union all

select
    {indent(join(dct.values()))}
    ,'_allcrse' as crse_code
    ,sum(credit_hr) as credit_hr
from A
group by
    {indent(join(dct.values()))}

union all

select
    {indent(join(dct.values()))}
    ,'_anycrse' as crse_code
    ,case when sum(credit_hr) > 0 then 1 else 0 end as credit_hr
from A
group by
    {indent(join(dct.values()))}
"""
            df = run(qry, show)
            return df
        self.get(fcn, 'registrations', overwrite)

##############################################
############### get admissions ###############
##############################################
    def get_admissions(self, overwrite=False, show=False):
        def fcn():
            L = [f"""
select
    *
from (
    select distinct
        A.*
        ,min(current_date) over (partition by pidm, appl_no) as first_date
        ,max(current_date) over (partition by pidm, appl_no) as final_date
    from
        dev.opeir.opeiradmissions_{self.terms.loc[term_code,'term_desc']} as A
        --dev.opeir.admissions_{self.terms.loc[term_code,'term_desc']}_v as A
    inner join
        {catalog}saturnstvapdc as B
    on
        apdc_code = stvapdc_code
    where
        stvapdc_inst_acc_ind is not null  --only accepted
    qualify
        '{self.cycle_date}' between first_date and final_date --keep only pidm, appl_no where cycle_date falls between its first and last records
    )
where
    current_date <= '{self.cycle_date}'  -- only records before cycle_date
qualify
    row_number() over (partition by pidm, appl_no order by current_date desc) = 1  -- most current record remaiming for this pidm, appl_no
"""
                for term_code in [self.term_code-2, self.term_code]]
            qry = join(L, '\nunion all\n')


            stat_codes = ['AL','AR','AZ','CA','CO','CT','DC','DE','FL','GA','IA','ID','IL','IN','KS','KY','LA','MA','MD','ME','MI','MN','MO','MS','MT','NC','ND','NE','NH','NJ','NM','NV','NY','OH','OK','OR','PA','RI','SC','SD','TN','TX','UT','VA','VT','WA','WI','WV','WY'] # not AK & HI b/c can't get driving distance

            def get_desc(code):
                for nm in code.split('_'):
                    if len(nm) == 4:
                        break
                return [f'{code} as {nm}_code, stv{nm}_desc as {nm}_desc', f'left join {catalog}saturnstv{nm} on {code} = stv{nm}_code']

            qry = f"""
select
    pidm
    ,{self.cycle_day} as cycle_day
    ,{self.cycle_date} as cycle_date
    ,current_date
    ,first_date
    ,final_date
    ,{get_desc('term_code')[0]}
    ,appl_no
    ,{get_desc('apst_code')[0]}
    ,{get_desc('apdc_code')[0]}
    ,{get_desc('admt_code')[0]}
    ,{get_desc('wrsn_code')[0]}
    ,{get_desc('levl_code')[0]}
    ,{get_desc('styp_code')[0]}    
    ,{get_desc('camp_code')[0]}
    ,{get_desc('coll_code_1')[0]}
    ,{get_desc('dept_code')[0]}
    ,{get_desc('majr_code_1')[0]}
    ,{get_desc('saradap_resd_code')[0]}
    ,gender
    ,birth_date
    ,{get_desc('spbpers_lgcy_code')[0]}
    ,gorvisa_vtyp_code is not null as international
    ,gorvisa_natn_code_issue as natn_code, stvnatn_nation as natn_desc
    ,spbpers_ethn_cde=2 as race_hispanic
    ,race_asian
    ,race_black
    ,race_native
    ,race_pacific
    ,race_white
    ,hs_percentile
    ,sbgi_code
    ,enrolled_ind

from {subqry(qry)} as A

left join
    {catalog}spbpers_amp_v
on
    pidm = spbpers_pidm

left join (
    select
        *
    from
        {catalog}generalgorvisa
    qualify
        row_number() over (partition by gorvisa_pidm order by gorvisa_seq_no desc) = 1
    )
on
    pidm = gorvisa_pidm

left join (
    select
        gorprac_pidm
        ,max(gorprac_race_cde='AS') as race_asian
        ,max(gorprac_race_cde='BL') as race_black
        ,max(gorprac_race_cde='IN') as race_native
        ,max(gorprac_race_cde='HA') as race_pacific
        ,max(gorprac_race_cde='WH') as race_white
    from
        {catalog}generalgorprac
    group by
        gorprac_pidm
    )
on
    pidm = gorprac_pidm

{get_desc('term_code')[1]}
{get_desc('levl_code')[1]}
{get_desc('styp_code')[1]}
{get_desc('admt_code')[1]}
{get_desc('wrsn_code')[1]}
{get_desc('apst_code')[1]}
{get_desc('apdc_code')[1]}
{get_desc('camp_code')[1]}
{get_desc('coll_code_1')[1]}
{get_desc('dept_code')[1]}
{get_desc('majr_code_1')[1]}
{get_desc('saradap_resd_code')[1]}
{get_desc('gorvisa_natn_code_issue')[1]}
{get_desc('spbpers_lgcy_code')[1]}

qualify
    min(levl_code='UG') over (partition by pidm) = True  -- remove pidm's with graduate admission even it if also has an undergradute admission, min acts like logical and
    and row_number() over (partition by pidm order by appl_no desc) = 1  -- de-duplicate the few remaining pidms with multiple record by keeping highest appl_no
"""
# don't delete - could be useful & was hard to create
#     ,{get_desc('spraddr_cnty_code')[0]}
#     ,{get_desc('spraddr_stat_code')[0]}
#     ,zip_code

# left join (
#     select
#         *
#         ,try_to_number(left(spraddr_zip, 5), '00000') as zip_code
#         ,case
#             when spraddr_atyp_code = 'PA' then 6
#             when spraddr_atyp_code = 'PR' then 5
#             when spraddr_atyp_code = 'MA' then 4
#             when spraddr_atyp_code = 'BU' then 3
#             when spraddr_atyp_code = 'BI' then 2
#             when spraddr_atyp_code = 'P1' then 1
#             when spraddr_atyp_code = 'P2' then 0
#             end as spraddr_atyp_rank
#     from
#         {catalog}spraddr_amp_v
#     where
#         spraddr_stat_code in ('{join(stat_codes, "','")}')
#         and spraddr_zip is not null
#     qualify
#         row_number() over (partition by spraddr_pidm order by spraddr_atyp_rank desc, spraddr_seqno desc) = 1
# )
# on
#     pidm = spraddr_pidm

# {get_desc('spraddr_cnty_code')[1]}
# {get_desc('spraddr_stat_code')[1]}

            df = run(qry, show)
            return df 
        self.get(fcn, 'admissions', overwrite)

####################################
############### main ###############
####################################
# Build the Term instance for this run; construction loads the term calendar via __post_init__.
self = Term(
    term_code=202308,
    cycle_day=50,
)

# GA's should not have permissions to run this because it can see pii
# self.process_flags(
#     # overwrite=True,
#     # show=True,
# )

# Summary of which flags columns are consistently available for terms from the cutoff onward.
H = self.get_flags_history(
    # overwrite=True,
    # show=True,
    cutoff=202206,
)
H.disp(100)

# self.get_terms(
#     overwrite=True,
#     show=True,
# )

# self.get_flags(
#     overwrite=True,
#     show=True,
# )

# self.get_registrations(
#     overwrite=True,
#     show=True,
# )

# Rebuild the admissions dataset from scratch and print the generated SQL.
self.get_admissions(
    overwrite=True,
    show=True,
)