In [0]:
from IPython.display import display, HTML, clear_output
try:
    %reload_ext autotime
except:
    %pip install -U ipython-autotime ipywidgets codetiming openpyxl
    %reload_ext autotime
clear_output()
import pathlib, shutil, warnings, dataclasses, numpy as np, pandas as pd
from codetiming import Timer
seed = 42  # NOTE(review): not referenced in the visible code — presumably a random seed for later modeling; confirm
# Prefix placed directly in front of table names when SQL is built (e.g. f"{catalog}stvterm").
# NOTE(review): no separator is inserted between this prefix and the table name — confirm the resulting identifiers are the intended ones.
catalog = 'dev.bronze.saturn'
# Base directory under which Term.get writes/reads its cached parquet files.
root = pathlib.Path('/Workspace/Users/scook@tarleton.edu/admitted_matriculation_predictor_2025/')
# Directory that Term.process_flags scans for raw .xlsx flags reports.
flags_raw = pathlib.Path('/Volumes/aiml/scook/scook_files/admitted_flags_raw')
# Directory where Term.process_flags stores the processed flags parquet files.
flags_prc = pathlib.Path('/Volumes/aiml/flags/flags_volume/')
##########################################
############ helper functions ############
##########################################
tab = '    '  # one indentation level, used by indent() when formatting SQL text
# Visual separator printed between processing steps.
divider = '\n##############################################################################################################\n'
def listify(*args):
    """Coerce the argument(s) into a list.

    A single None / np.nan / pd.NA gives an empty list and a single string
    gives a one-element list. Anything else is passed to ``list``; if that
    fails (non-iterable or several positional arguments) the positional
    arguments themselves become the list.
    """
    if len(args) == 1:
        only = args[0]
        if any(only is missing for missing in (None, np.nan, pd.NA)):
            return []
        if isinstance(only, str):
            # keep strings whole instead of exploding them into characters
            return [only]
    try:
        return list(*args)
    except Exception:
        return list(args)

def rjust(x, width, fillchar=' '):
    """Right-justify str(x) to ``width`` characters, padding with str(fillchar)."""
    text, pad = str(x), str(fillchar)
    return text.rjust(width, pad)

def ljust(x, width, fillchar=' '):
    """Left-justify str(x) to ``width`` characters, padding with str(fillchar)."""
    text, pad = str(x), str(fillchar)
    return text.ljust(width, pad)

def join(lst, sep='\n,', pre='', post=''):
    """Stringify the items of ``lst`` (via listify), join them with ``sep``, and wrap the result in ``pre``/``post``."""
    pieces = [str(item) for item in listify(lst)]
    body = str(sep).join(pieces)
    return f"{pre}{body}{post}"

def alias(dct):
    """Turn {original column: new column} into a list of 'original as new' SQL fragments."""
    return [f'{source} as {target}' for source, target in dct.items()]

def indent(x, lev=1):
    """Indent every line after the first by ``lev`` levels of ``tab``; return x untouched when lev <= 0."""
    if lev > 0:
        return x.replace('\n', '\n' + tab * lev)
    return x

def subqry(qry, lev=1):
    """Wrap qry in parentheses, indenting its body by ``lev`` levels and the closing parenthesis by ``lev - 1``."""
    body = indent('\n' + qry.strip(), lev)
    closing = indent(")", lev - 1)
    return "(" + body + '\n' + closing

def run(qry, show=False):
    """Execute qry with spark and return the result as a prepped pandas DataFrame.

    A qry containing no space character is treated as a bare table name and
    expanded to "select * from <catalog><name>". With show=True the final SQL
    is printed before it runs.
    """
    tokens = qry.split(' ')
    if len(tokens) == 1:
        qry = "select * from " + catalog + tokens[0]
    if show:
        print(qry)
    result = spark.sql(qry)
    return result.toPandas().prep()

############ annoying warnings to suppress ############
# Register one "ignore" filter per message; each message is wrapped in ".*" so it can match anywhere in the warning text.
for _msg in (
    "Could not infer format, so each element will be parsed individually, falling back to `dateutil`",
    "Engine has switched to 'python' because numexpr does not support extension array dtypes",
    "The default of observed=False is deprecated and will be changed to True in a future version of pandas",
    "errors='ignore' is deprecated and will raise in a future version",
    "The behavior of DataFrame concatenation with empty or all-NA entries is deprecated",
    "The behavior of array concatenation with empty entries is deprecated",
    "DataFrame is highly fragmented",
):
    warnings.filterwarnings(action='ignore', message=f".*{_msg}.*")

############ pandas functions ############
pd.options.display.max_columns = None
def disp(df, rows=4, head=True):
    """Render the first (head=True) or last (head=False) ``rows`` rows of df as HTML.

    Returns df unchanged so the call can sit in the middle of a method chain.
    """
    with pd.option_context('display.min_rows', rows, 'display.max_rows', rows):
        # DataFrame has .tail, not .tails — the head=False branch used to raise AttributeError
        X = df.head(rows) if head else df.tail(rows)
        display(HTML(X.to_html()))
    return df

def to_numeric(df, downcast='integer', errors='ignore', **kwargs):
    """Convert each column to a numeric dtype where possible.

    Object/string columns are first lower-cased and stripped. Datetime columns
    are left alone. With errors='ignore' (the default) a column that cannot be
    parsed as numeric is returned as-is; any other ``errors`` value is passed
    straight to ``pd.to_numeric``. The result is run through ``convert_dtypes``
    to get the nullable dtypes.
    """
    def clean_text(s):
        # prep strings
        return s.astype('string').str.lower().str.strip() if s.dtype in ['object','string'] else s

    def convert(s):
        if pd.api.types.is_datetime64_any_dtype(s):
            return s
        if errors == 'ignore':
            # pandas deprecated errors='ignore'; emulate it by attempting a strict
            # conversion and falling back to the untouched column on failure
            try:
                return pd.to_numeric(s, downcast=downcast, errors='raise', **kwargs)
            except (ValueError, TypeError):
                return s
        return pd.to_numeric(s, downcast=downcast, errors=errors, **kwargs)

    return (
        df
        .apply(clean_text)
        .apply(convert)  # convert to numeric if possible
        .convert_dtypes()  # convert to new nullable dtypes
    )

def prep(df, **kwargs):
    """Normalize dtypes and column names of df, applying the same treatment to its index.

    Column names that are strings are lower-cased, stripped, and have spaces
    and hyphens replaced by underscores. The index is rebuilt as a MultiIndex
    from the cleaned index columns.
    """
    def tidy_name(s):
        return s.lower().strip().replace(' ','_').replace('-','_') if isinstance(s, str) else s

    def tidy(frame):
        return frame.to_numeric(**kwargs).rename(columns=tidy_name)

    # drop all columns and move the index into columns so it can be cleaned like regular data
    index_frame = tidy(df[[]].reset_index())
    body = tidy(df).reset_index(drop=True)
    # put the cleaned index back on the cleaned data
    return body.set_index(pd.MultiIndex.from_frame(index_frame))

def wrap(fcn):
    """Adapt a DataFrame-only function so it also accepts a Series.

    The input is converted to a DataFrame before calling fcn. A None result is
    passed through; otherwise the result is squeezed when the input was a
    Series and returned unchanged when it was not.
    """
    def wrapper(X, *args, **kwargs):
        result = fcn(pd.DataFrame(X), *args, **kwargs)
        if result is None:
            return None
        if isinstance(X, pd.Series):
            return result.squeeze()
        return result
    return wrapper

# Monkey-patch the helpers into the pandas Series & DataFrame classes so the df.method syntax works.
# Series receive the wrap()-ed version; DataFrames receive the function itself.
for f in (disp, to_numeric, prep):
    setattr(pd.Series, f.__name__, wrap(f))
    setattr(pd.DataFrame, f.__name__, f)


#########################################
################## AMP ##################
#########################################
@dataclasses.dataclass
class Term():
    """A term observed at one point of its cycle, with parquet-cached query results.

    cycle_day is the number of days from cycle_date until the term's
    stable_date (see get_terms / get_cycle). Whichever of cycle_day and
    cycle_date is missing is derived in __post_init__; if both are missing,
    today's date is used.
    """
    term_code: int = 202408
    cycle_day: int = None
    cycle_date: str = None

    #Allows self['attr'] and self.attr syntax
    def __contains__(self, key):
        return hasattr(self, key)
    def __getitem__(self, key):
        return getattr(self, key)
    def __setitem__(self, key, val):
        setattr(self, key, val)
    def __delitem__(self, key):
        # deleting a missing attribute is silently a no-op
        if key in self:
            delattr(self, key)


    def __post_init__(self):
        self.term_code = int(self.term_code)
        # stem is the file-name suffix shared by this term's cached files
        self.cycle_day, self.cycle_date, self.stem = self.get_cycle(self.term_code, self.cycle_day, self.cycle_date)


    def get(self, fcn, nm, overwrite=False, path=None, divide=True):
        """Return ``(self[nm], new)``, loading or creating the dataset as needed.

        Lookup order: attribute already on self, then the parquet file, then
        ``fcn()`` (whose result is prepped, stored on self, and written to parquet).

        fcn: zero-argument callable that builds the DataFrame.
        nm: attribute name (and default file-name stem) of the dataset.
        overwrite: drop the attribute and delete the parquet file first.
        path: explicit target path (suffix forced to .parquet); defaults to
            root/nm/term_code/nm_stem.parquet.
        divide: print the divider after a fresh build.
        new: True only when fcn was actually run.
        """
        # the default target is built only when path is None, so calls made before self.stem exists must pass path
        targ = root / f'{nm}/{self.term_code}/{nm}_{self.stem}.parquet' if path is None else pathlib.Path(path).with_suffix('.parquet')
        new = False
        if overwrite:
            del self[nm]
            targ.unlink(missing_ok=True)
        if nm not in self:
            if targ.exists():
                self[nm] = pd.read_parquet(targ)
            else:
                print(f'creating {targ.name}: ', end='')
                new = True
                with Timer():
                    targ.parent.mkdir(parents=True, exist_ok=True)
                    self[nm] = fcn().prep()
                    self[nm].to_parquet(targ)
                if divide:
                    print(divider)
        return self[nm], new

########################################################
################# get term information #################
########################################################
    def get_terms(self, overwrite=False, show=False):
        """Return ``(terms, new)``: one row per term_code with its dates plus a derived stable_date."""
        def fcn():
            qry = f"""
select
    stvterm_code as term_code
    ,replace(stvterm_desc, ' ', '') as term_desc
    ,stvterm_start_date as start_date
    ,stvterm_end_date as end_date
    ,stvterm_fa_proc_yr as fa_proc_yr
    ,stvterm_housing_start_date as housing_start_date
    ,stvterm_housing_end_date as housing_end_date
    ,sobptrm_census_date as census_date
from
    {catalog}stvterm as A
inner join
    {catalog}sobptrm as B
on
    stvterm_code = sobptrm_term_code
where
    sobptrm_ptrm_code='1'
"""
            df = run(qry, show).set_index('term_code')
            df['stable_date'] = df['census_date'].apply(lambda x: x+pd.Timedelta(days=7+4-x.weekday())) # Friday of week following census
            return df
        # terms are shared across term codes, so they live in a single file directly under root
        return self.get(fcn, 'terms', overwrite, path=root / 'terms')


    def get_cycle(self, term_code, cycle_day=None, cycle_date=None):
        """Return ``(cycle_day, cycle_date, stem)`` for term_code.

        If cycle_day is given it wins and cycle_date is recomputed from it.
        Otherwise cycle_day is the number of days from cycle_date (today when
        None) to the term's stable_date. stem is
        '<term_code>_<cycle_day zero-padded to 3>_<cycle_date>'.
        """
        stable_date = self.get_terms()[0].loc[int(term_code),'stable_date']
        if cycle_day is None:
            if cycle_date is None:
                cycle_date = pd.Timestamp.now()
            cycle_date = pd.to_datetime(cycle_date).normalize()
            cycle_day = (stable_date - cycle_date).days
        cycle_date = (stable_date - pd.Timedelta(days=cycle_day)).date()
        stem = f'{term_code}_{rjust(cycle_day,3,0)}_{cycle_date}'
        return cycle_day, cycle_date, stem

#######################################################
############ process flags reports archive ############
#######################################################
    def get_spriden(self, show=False):
        """Return the id-to-pidm crosswalk, querying it once and keeping it on self."""
        # Get id-pidm crosswalk so we can replace id by pidm in flags below
        if 'spriden' not in self:
            qry = f"""
            select distinct
                spriden_id as id,
                spriden_pidm as pidm
            from
                {catalog}spriden as A
            where
                spriden_change_ind is null
                and spriden_activity_date between '2000-09-01' and '2025-09-01'
                and spriden_id REGEXP '^[0-9]+'
            """
            self.spriden = run(qry, show)
        return self.spriden


    def process_flags(self, overwrite=False):
        """Convert raw .xlsx flags reports in flags_raw into per-term parquet files under flags_prc.

        Files are visited in reverse-sorted name order. The scan stops once more
        than 5 consecutive files have been visited without creating a new
        parquet file; the counter resets whenever one is created.
        """
        counter = 0
        for src in sorted(flags_raw.iterdir(), reverse=True):
            counter += 1
            if counter > 5:
                break
            # stem/suffix instead of unpacking name.split('.'), which raised on names without exactly one dot
            a, b = src.stem.lower(), src.suffix.lower().lstrip('.')
            if b != 'xlsx' or 'melt' in a or 'admitted' not in a:
                print(a, 'SKIP')
                continue
            # Handles 2 naming conventions that were used at different times
            try:
                file_date = pd.to_datetime(a[:10].replace('_','-'))
                multi = True
            except Exception:
                try:
                    file_date = pd.to_datetime(a[-6:])
                    multi = False
                except Exception:
                    print(a, 'FAIL')
                    continue
            # context manager so the workbook handle is closed after its sheets are processed
            with pd.ExcelFile(src, engine='openpyxl') as book:
                # Again, handles the 2 different versions with different sheet names
                if multi:
                    sheets = {sheet:sheet for sheet in book.sheet_names if sheet.isnumeric() and int(sheet) % 100 in [1,6,8]}
                else:
                    sheets = {a[:6]: book.sheet_names[0]}
                for term_code, sheet in sheets.items():
                    # always derive from the file's date rather than the previous sheet's result
                    cycle_day, cycle_date, stem = self.get_cycle(term_code, cycle_date=file_date)
                    if cycle_day >= 0:
                        def fcn():
                            # called immediately by self.get below, so the loop variables it closes over are current
                            return (
                                self.get_spriden()
                                .assign(cycle_day=cycle_day, cycle_date=cycle_date)
                                .merge(book.parse(sheet).prep(), on='id', how='right')
                                .drop(columns=['id','last_name','first_name','mi','pref_fname','street1','street2','primary_phone','call_em_all','email'], errors='ignore')
                            )
                        if self.get(fcn, 'flags', overwrite, path=flags_prc / f'{term_code}/flags_{stem}', divide=False)[1]:
                            counter = 0
                        # drop the attribute so the next sheet/file is not mistaken for already loaded
                        del self['flags']
        print(divider)

########################################################
############### get course registrations ###############
########################################################
    def get_registrations(self, overwrite=False, show=False):
        """Return ``(registrations, new)`` for self.term_code as of self.cycle_date.

        One row per pidm/crn with credit hours, plus two summary rows per
        student: '_allcrse' (total credit hours) and '_anycrse' (1 if the total
        is positive, else 0).
        """
        def fcn():
            dct = {
                'sfrstcr_pidm':'pidm',
                'ssbsect_term_code':'term_code',
                'sgbstdn_levl_code':'levl_code',
                'sgbstdn_styp_code':'styp_code',
                'ssbsect_crn':'crn',
            }
            # the sgbstdn cutoff follows self.term_code (it used to be hard-coded to 202408)
            qry = f"""
select
    {indent(join(alias(dct)))}
    ,lower(ssbsect_subj_code) || ssbsect_crse_numb as crse_code
    ,max(ssbsect_credit_hrs) as credit_hr
from
    {catalog}sfrstcr as A
inner join
    {catalog}ssbsect as B
on
    sfrstcr_term_code = ssbsect_term_code
    and sfrstcr_crn = ssbsect_crn
inner join (
    select
        *
    from
        {catalog}sgbstdn
    where
        sgbstdn_term_code_eff <= {self.term_code}
    qualify
        row_number() over (partition by sgbstdn_pidm order by sgbstdn_term_code_eff desc) = 1
    ) as C
on
    sfrstcr_pidm = sgbstdn_pidm
where
    sfrstcr_term_code = {self.term_code}
    and sfrstcr_ptrm_code not in ('28','R3') -- drop weird parts of term
    and sfrstcr_add_date <= '{self.cycle_date}' -- added before cycle_day
    and (sfrstcr_rsts_date > '{self.cycle_date}' or sfrstcr_rsts_code in ('DC','DL','RD','RE','RW','WD','WF')) -- dropped after cycle_day or still enrolled
    and ssbsect_subj_code <> 'INST' -- exceptional sections
group by
    {indent(join(dct.keys()))}
    ,ssbsect_subj_code
    ,ssbsect_crse_numb
"""

            qry = f"""
with A as {subqry(qry)}
select * from A
union all
select
    {indent(join(dct.values()))}
    ,'_allcrse' as crse_code
    ,sum(credit_hr) as credit_hr
from A
group by
    {indent(join(dct.values()))}
union all
select
    {indent(join(dct.values()))}
    ,'_anycrse' as crse_code
    ,case when sum(credit_hr) > 0 then 1 else 0 end as credit_hr
from A
group by
    {indent(join(dct.values()))}
"""
            df = run(qry, show)
            return df
        return self.get(fcn, 'registrations', overwrite)

##############################################
############### get admissions ###############
##############################################
    def get_admissions(self, overwrite=False, show=False):
        """Return ``(admissions, new)``: one row per pidm from the accepted-application query as of self.cycle_date."""
        def fcn():
            qry = f"""
select
    saradap_pidm as pidm
    ,{self.term_code} as term_code
    ,saradap_levl_code as levl_code
    ,saradap_styp_code as styp_code
from 
    {catalog}saradap as A
inner join (
    select
        *
    from
        {catalog}sarappd as A
    where
        sarappd_term_code_entry in ({self.term_code-2}, {self.term_code})  -- consider both summer and fall admits since summer admits will be in fall cohort
        and sarappd_apdc_date <= '{self.cycle_date}' -- decision made before cycle_date
    qualify
        row_number() over (partition by sarappd_pidm, sarappd_term_code_entry, sarappd_appl_no order by sarappd_seq_no desc) = 1  -- most current decision
) as B
on
    sarappd_pidm = saradap_pidm
    and sarappd_term_code_entry = saradap_term_code_entry
    and sarappd_appl_no = saradap_appl_no
inner join 
    {catalog}stvapdc as C
on
    sarappd_apdc_code = stvapdc_code
where
    stvapdc_inst_acc_ind is not null
qualify
    max(saradap_levl_code<>'UG') over (partition by saradap_pidm) = False
    and row_number() over (partition by saradap_pidm order by saradap_appl_no desc) = 1  -- most current application
"""
            # incomplete
            # this extracts the unique list of undergraduates admitted before cycle_date
            # now we need to join their data from flags & other tables
            df = run(qry, show)
            return df
        return self.get(fcn, 'admissions', overwrite)

####################################
############### main ###############
####################################

# Build the Term for term 202408 at cycle day 50; construction runs __post_init__, which derives cycle_date and stem via get_cycle.
self = Term(
    term_code=202408,
    cycle_day=50,
)
# self.get_terms(overwrite=True)
# self.process_flags()
# Rebuild the registrations dataset from scratch (overwrite=True) and print its SQL (show=True).
self.get_registrations(
    overwrite=True,
    show=True,
)
# Same for the admissions dataset.
self.get_admissions(
    overwrite=True,
    show=True,
)
# NOTE(review): a lone `;` is not valid plain-Python syntax; presumably it relies on notebook input handling to keep the last result from being echoed — confirm.
;