In [0]:
from IPython.display import display, HTML, clear_output
# Load the autotime IPython extension; if that fails (presumably because the
# package is not installed yet), pip-install it together with the other
# packages this notebook uses, then load it again.
# NOTE(review): the bare `except:` also catches KeyboardInterrupt/SystemExit.
try:
    %reload_ext autotime
except:
    %pip install -U ipython-autotime ipywidgets codetiming openpyxl
    %reload_ext autotime
clear_output()  # clear the extension/installer output from the cell
import pathlib, shutil, warnings, dataclasses, numpy as np, pandas as pd
from codetiming import Timer
seed = 42  # NOTE(review): not referenced in the code shown here; presumably used by later cells
# Prefix glued directly onto table names, e.g. f'{catalog}spriden' -> 'dev.bronze.saturnspriden'
catalog = 'dev.bronze.saturn'
# Directory holding cached parquet files (e.g. root / 'terms')
root = pathlib.Path('/Workspace/Users/scook@tarleton.edu/admitted_matriculation_predictor_2025/')

##########################################
############ helper functions ############
##########################################
tab = ' ' * 4  # one indentation level (four spaces) used when nesting generated SQL
def listify(*args):
    """Coerce the argument(s) to a list.

    - no arguments -> ``[]``
    - a single missing scalar (None, any NaN, pd.NA, pd.NaT) -> ``[]``
    - a single string -> ``[string]`` (not split into characters)
    - a single iterable -> ``list(iterable)``
    - a single non-iterable -> ``[value]``
    - several arguments -> ``list(args)``
    """
    if len(args) == 1:
        x = args[0]
        if isinstance(x, str):
            return [x]
        # Identity checks such as `x is np.nan` miss NaN objects that are not the
        # np.nan singleton (float('nan'), np.float64('nan'), ...); pd.isna on a
        # scalar recognises all of them.
        if x is None or (pd.api.types.is_scalar(x) and pd.isna(x)):
            return list()
    try:
        return list(*args)
    except TypeError:
        # not iterable, or more than one argument given
        return list(args)

def rjust(x, width, fillchar=' '):
    """Right-justify ``str(x)`` to ``width`` characters, left-padding with ``str(fillchar)``."""
    text, pad = str(x), str(fillchar)
    return text.rjust(width, pad)

def ljust(x, width, fillchar=' '):
    """Left-justify ``str(x)`` to ``width`` characters, right-padding with ``str(fillchar)``."""
    text, pad = str(x), str(fillchar)
    return text.ljust(width, pad)

def join(lst, sep='\n,', pre='', post=''):
    """Join the items of ``lst`` (coerced with listify) as strings separated by ``sep``.

    The result is wrapped in ``pre`` and ``post``. The default separator puts
    each item on its own line with a leading comma (the SQL column-list style
    used in this file).
    """
    body = str(sep).join(str(item) for item in listify(lst))
    return f"{pre}{body}{post}"

def alias(dct):
    """Turn a ``{source expression: new name}`` mapping into ``['source as new', ...]``."""
    return ['{} as {}'.format(src, name) for src, name in dct.items()]

def indent(x, lev=1):
    """Indent every line of ``x`` after the first by ``lev`` copies of ``tab``.

    A negative ``lev`` returns ``x`` unchanged. This branch previously referred to
    an undefined name ``qry``, so ``subqry(..., lev=0)`` raised NameError.
    """
    if lev < 0:
        return x
    return x.replace('\n', '\n' + tab * lev)

def subqry(qry, lev=1):
    """Wrap ``qry`` in parentheses as a subquery.

    The stripped query starts on a new line indented ``lev`` levels, and the
    closing parenthesis is indented one level less.
    """
    body = indent('\n' + qry.strip(), lev)
    closing = indent(")", lev - 1)
    return f"({body}{closing}"

def run(qry, show=False):
    """Execute ``qry`` with Spark SQL and return the result as a prepped pandas DataFrame.

    A ``qry`` containing no space characters is treated as a table name and
    expanded to ``select * from <qry>``. With ``show=True`` the final SQL is printed.
    """
    if len(qry.split(' ')) == 1:
        qry = "select * from " + qry
    if show:
        print(qry)
    result = spark.sql(qry).toPandas()
    return result.prep()

############ annoying warnings to suppress ############
import re

# Known-noisy warnings to silence. filterwarnings() treats `message` as a
# regular expression, so each text is escaped to match literally, and the
# surrounding `.*` lets it match anywhere in the warning message.
for _msg in (
    "Could not infer format, so each element will be parsed individually, falling back to `dateutil`",
    "Engine has switched to 'python' because numexpr does not support extension array dtypes",
    "The default of observed=False is deprecated and will be changed to True in a future version of pandas",
    "errors='ignore' is deprecated and will raise in a future version",
    "The behavior of DataFrame concatenation with empty or all-NA entries is deprecated",
    "The behavior of array concatenation with empty entries is deprecated",
    "DataFrame is highly fragmented",
):
    warnings.filterwarnings(action='ignore', message=f".*{re.escape(_msg)}.*")
del _msg

############ pandas functions ############
pd.options.display.max_columns = None  # never elide columns when displaying frames
def disp(df, rows=4, head=True):
    """Show the first (``head=True``) or last ``rows`` rows of ``df`` as HTML and return ``df``.

    Returning ``df`` unchanged lets ``disp`` sit in the middle of a method chain.
    The tail branch previously called the non-existent ``df.tails``.
    """
    with pd.option_context('display.min_rows', rows, 'display.max_rows', rows):
        X = df.head(rows) if head else df.tail(rows)
        display(HTML(X.to_html()))
    return df

def to_numeric(df, downcast='integer', errors='ignore', **kwargs):
    """Convert each column of ``df`` to a numeric dtype where possible.

    Object/string columns are first cast to pandas ``string``, lower-cased and
    stripped. Every non-datetime column is then passed to ``pd.to_numeric``, and
    ``convert_dtypes`` finally switches the frame to nullable dtypes.

    Parameters
    ----------
    downcast : passed to ``pd.to_numeric``.
    errors : {'ignore', 'raise', 'coerce'}
        'ignore' (default) leaves a column unchanged when it cannot be parsed.
        pandas deprecated ``errors='ignore'`` in 2.2, so that case is handled
        here by catching the exception instead of forwarding the option.
    **kwargs : forwarded to ``pd.to_numeric``.
    """
    def _prep_strings(s):
        return s.astype('string').str.lower().str.strip() if s.dtype in ['object', 'string'] else s

    def _to_num(s):
        if pd.api.types.is_datetime64_any_dtype(s):
            return s
        if errors == 'ignore':
            try:
                return pd.to_numeric(s, downcast=downcast, **kwargs)
            except (ValueError, TypeError):  # what pandas itself swallowed for errors='ignore'
                return s
        return pd.to_numeric(s, downcast=downcast, errors=errors, **kwargs)

    return (
        df
        .apply(_prep_strings)
        .apply(_to_num)
        .convert_dtypes()
    )

def prep(df, **kwargs):
    """Standardize ``df``: numeric conversion (via ``to_numeric``) plus snake_case column names.

    The same cleaning is applied to the index levels. The cleaned levels are
    reattached as a ``pd.MultiIndex``, even when ``df`` had a single index level.
    ``kwargs`` are forwarded to ``to_numeric``.
    """
    def snake(name):
        if isinstance(name, str):
            return name.lower().strip().replace(' ', '_').replace('-', '_')
        return name

    def clean(frame):
        return frame.to_numeric(**kwargs).rename(columns=snake)

    # move the index into columns (dropping the data columns) and clean it the same way
    index_frame = clean(df[[]].reset_index())
    body = clean(df).reset_index(drop=True)
    return body.set_index(pd.MultiIndex.from_frame(index_frame))

def wrap(fcn):
    """Adapt a DataFrame helper so it can also be attached to ``pd.Series``.

    The input is converted to a DataFrame and ``fcn`` is applied. For a Series
    input, a DataFrame result is squeezed back to a Series along the column axis
    only. A plain ``squeeze()`` would turn a length-1 result into a scalar.
    ``None`` results pass through unchanged. ``functools.wraps`` keeps ``fcn``'s
    name and docstring on the wrapper.
    """
    import functools

    @functools.wraps(fcn)
    def wrapper(X, *args, **kwargs):
        df = fcn(pd.DataFrame(X), *args, **kwargs)
        if df is None:
            return None
        if isinstance(X, pd.Series):
            return df.squeeze(axis=1) if isinstance(df, pd.DataFrame) else df.squeeze()
        return df
    return wrapper

# Monkey-patch the helpers onto pandas so they can be called with method syntax
# (df.prep().disp()); Series get the wrap()-adapted version.
for f in [disp, to_numeric, prep]:
    setattr(pd.DataFrame, f.__name__, f)
    setattr(pd.Series, f.__name__, wrap(f))


@dataclasses.dataclass
class MyBaseClass:
    """Base class whose attributes can also be used with item syntax: ``self['attr']``.

    ``'attr' in self`` tests ``hasattr``. Deleting a missing attribute with
    ``del self['attr']`` is a silent no-op.
    """

    def __contains__(self, key):
        return hasattr(self, key)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, val):
        setattr(self, key, val)

    def __delitem__(self, key):
        if key not in self:
            return
        delattr(self, key)

    def get(self, fcn, nm, parq, overwrite=False):
        """Return ``(self[nm], new)``, loading or building the value on first use.

        Sources, in order of preference:
        1. the attribute ``nm`` already set on ``self``;
        2. the parquet file ``parq`` (suffix forced to ``.parquet``);
        3. ``fcn().prep()``, which is then also written to that parquet file.

        ``new`` is True only when ``fcn`` was called. ``overwrite=True`` first
        discards both the attribute and the parquet file.
        """
        path = parq.with_suffix('.parquet')
        if overwrite:
            del self[nm]
            path.unlink(missing_ok=True)
        created = False
        if nm not in self:
            if path.exists():
                self[nm] = pd.read_parquet(path)
            else:
                print(f'creating {path}')
                created = True
                with Timer():
                    path.parent.mkdir(parents=True, exist_ok=True)
                    self[nm] = fcn().prep()
                    self[nm].to_parquet(path)
        return self[nm], created

#######################################################
############ process flags reports archive ############
#######################################################

@dataclasses.dataclass
class Flags(MyBaseClass):
    """Convert the raw admitted-student flags Excel reports into parquet files."""

    def get_spriden(self):
        """Return the id -> pidm crosswalk, querying spriden once and caching it as ``self.spriden``.

        ``process_flags`` merges it onto each report by ``id`` so rows carry ``pidm``.
        """
        if 'spriden' not in self:
            qry = f"""
            select distinct
                spriden_id as id,
                spriden_pidm as pidm
            from
                {catalog}spriden
            where
                spriden_change_ind is null
                and spriden_activity_date between '2000-09-01' and '2025-09-01'
                and spriden_id REGEXP '^[0-9]+'
            """
            self.spriden = spark.sql(qry).toPandas().prep()
        return self.spriden

    @staticmethod
    def _parse_report_date(stem):
        """Return ``(report_date, multi)`` parsed from a lower-cased file stem, or None.

        Two naming conventions were used at different times:
        - first 10 characters (``_`` -> ``-``) are a date: one sheet per term (multi=True);
        - otherwise the last 6 characters are a date: a single sheet (multi=False).
        """
        try:
            return pd.to_datetime(stem[:10].replace('_', '-')), True
        except (ValueError, OverflowError):
            pass
        try:
            return pd.to_datetime(stem[-6:]), False
        except (ValueError, OverflowError):
            return None

    def process_flags(self, overwrite=False):
        """Write one parquet file per (term code, report date) from the raw flags reports.

        Files are visited in reverse-sorted name order. The walk stops once 5
        consecutive files have produced no new parquet file (skipped,
        unparseable, or already archived).

        Parameters
        ----------
        overwrite : bool
            Rebuild parquet files even when they already exist.
        """
        source = pathlib.Path('/Volumes/aiml/scook/scook_files/admitted_flags_raw')
        target = pathlib.Path('/Volumes/aiml/flags/flags_volume')
        counter = 0
        for src in sorted(source.iterdir(), reverse=True):
            counter += 1
            if counter > 5:
                break
            # Path.stem/suffix cope with names that have no dot or several dots
            a = src.stem.lower()
            if src.suffix.lower() != '.xlsx' or 'melt' in a or 'admitted' not in a:
                print(a, 'SKIP')
                continue
            parsed = self._parse_report_date(a)
            if parsed is None:
                print(a, 'FAIL')
                continue
            dt, multi = parsed
            print(a, dt.date())
            # context manager closes the workbook's file handle when done
            with pd.ExcelFile(src, engine='openpyxl') as book:
                if multi:
                    # one sheet per term: numeric sheet names whose value mod 100 is 1, 6 or 8
                    sheets = {sheet: sheet for sheet in book.sheet_names if sheet.isnumeric() and int(sheet) % 100 in [1, 6, 8]}
                else:
                    # single sheet; the term code is the first 6 characters of the file name
                    sheets = {a[:6]: book.sheet_names[0]}
                for term_code, sheet in sheets.items():
                    def fcn(sheet=sheet, dt=dt, book=book):
                        return (
                            self.get_spriden()
                            .assign(current_date=dt)
                            .merge(book.parse(sheet).prep(), on='id', how='right')
                            .drop(columns=['id','last_name','first_name','mi','pref_fname','street1','street2','primary_phone','call_em_all','email'], errors='ignore')
                        )
                    # get() caches by attribute name, so drop the previous report's
                    # 'flags'; otherwise every report after the first would be served
                    # from that cache and never written to its own parquet file.
                    del self['flags']
                    # NOTE(review): the directory is the literal text 'term_code', not the
                    # term code value; kept so existing archives are still found - confirm intent.
                    if self.get(fcn, 'flags', target / f'term_code/flags_{term_code}_{dt.date()}', overwrite)[1]:
                        counter = 0


@dataclasses.dataclass
class Term(Flags):
    """Data for one term, taken as of a chosen number of days before its stable date.

    Fields
    ------
    term_code : int
        Term to pull, e.g. 202408.
    cycle_day : int, optional
        Days before ``stable_date`` at which to take the snapshot. When None,
        ``get_terms`` derives it from ``cycle_date``.
    cycle_date : str, optional
        Snapshot date, used only when ``cycle_day`` is None (defaults to today).
        ``get_terms`` always replaces it with ``stable_date - cycle_day`` as a
        ``datetime.date``.

    Call ``get_terms`` before ``get_registrations``: it sets ``stable_date``,
    ``cycle_day``, ``cycle_date`` and ``stem``, which ``get_registrations`` reads.
    """
    term_code: int = 202408
    cycle_day: int = None
    cycle_date: str = None

    def get_terms(self, overwrite=False, show=False):
        """Return ``(terms, new)`` and set the snapshot attributes for ``self.term_code``.

        ``terms`` is indexed by term_code and cached at ``root / 'terms'``.
        Side effects: sets ``stable_date``, ``cycle_day``, ``cycle_date``
        (a ``datetime.date``) and ``stem`` (``'<term_code>/<cycle_day zero-padded to 3>'``).
        """
        def fcn():
            qry = f"""
select
    A.stvterm_code as term_code
    ,replace(A.stvterm_desc, ' ', '') as term_desc
    ,A.stvterm_start_date as start_date
    ,A.stvterm_end_date as end_date
    ,A.stvterm_fa_proc_yr as fa_proc_yr
    ,A.stvterm_housing_start_date as housing_start_date
    ,A.stvterm_housing_end_date as housing_end_date
    ,B.sobptrm_census_date as census_date
from
    {catalog}stvterm A
    ,{catalog}sobptrm B
where
    A.stvterm_code = B.sobptrm_term_code
    and B.sobptrm_ptrm_code='1'
"""
            df = run(qry, show).set_index('term_code')
            # Friday of the week following census: +7 days to next week, then
            # +(4 - weekday) to reach Friday (weekday(): Monday=0)
            df['stable_date'] = df['census_date'].apply(lambda x: x+pd.Timedelta(days=7+4-x.weekday()))
            return df
        df, new = self.get(fcn, 'terms', root / 'terms', overwrite)

        self.stable_date = df.loc[self.term_code,'stable_date']
        if self.cycle_day is None:
            if self.cycle_date is None:
                self.cycle_date = pd.Timestamp.now()
            self.cycle_date = pd.to_datetime(self.cycle_date).normalize()
            self.cycle_day = (self.stable_date - self.cycle_date).days
        self.cycle_date = (self.stable_date - pd.Timedelta(days=self.cycle_day)).date()
        self.stem = f'{self.term_code}/{rjust(self.cycle_day,3,0)}'
        return df, new

    def get_registrations(self, overwrite=False, show=False):
        """Return ``(registrations, new)`` for ``self.term_code`` as of ``self.cycle_date``.

        Each row carries pidm, levl_code, styp_code, term_code, crn, crse_code
        and credit_hr. Two summary rows per key group are appended: '_allcrse'
        (sum of credit_hr) and '_anycrse' (1 if that sum > 0, else 0). The result
        is cached under ``root / 'registrations'`` using ``self.stem``.
        """
        def fcn():
            dct = {
                'A.sfrstcr_pidm':'pidm',
                'C.sgbstdn_levl_code':'levl_code',
                'C.sgbstdn_styp_code':'styp_code',
                'B.ssbsect_term_code':'term_code',
                'B.ssbsect_crn':'crn',
            }
            # C = each student's latest sgbstdn record effective on or before
            # self.term_code (previously hard-coded to 202408 for every term).
            # NOTE(review): the rsts_code list below is OR-ed in regardless of
            # rsts_date and mixes drop-like and enroll-like codes; confirm which
            # codes mean "still enrolled".
            qry = f"""
select
    {indent(join(alias(dct)))}
    ,lower(B.ssbsect_subj_code) || B.ssbsect_crse_numb as crse_code
    ,max(B.ssbsect_credit_hrs) as credit_hr
from
    {catalog}sfrstcr as A
inner join
    {catalog}ssbsect as B
on
    A.sfrstcr_term_code = B.ssbsect_term_code
    and A.sfrstcr_crn = B.ssbsect_crn
inner join (
    select
        *
    from
        {catalog}sgbstdn
    where
        sgbstdn_term_code_eff <= {self.term_code}
    qualify
        row_number() over (partition by sgbstdn_pidm order by sgbstdn_term_code_eff desc) = 1
    ) as C
on
    A.sfrstcr_pidm = C.sgbstdn_pidm
where
    A.sfrstcr_term_code = {self.term_code}
    and A.sfrstcr_ptrm_code not in ('28','R3')
    and A.sfrstcr_add_date <= '{self.cycle_date}' -- added before cycle_day
    and (A.sfrstcr_rsts_date > '{self.cycle_date}' or A.sfrstcr_rsts_code in ('DC','DL','RD','RE','RW','WD','WF')) -- dropped after cycle_day or still enrolled
    and B.ssbsect_subj_code <> 'INST'
group by
    {indent(join(dct.keys()))}
    ,B.ssbsect_subj_code
    ,B.ssbsect_crse_numb
"""

            # NOTE(review): the summary rows group by every dct value, which
            # includes crn, so '_allcrse'/'_anycrse' are per section rather than
            # per student - confirm this is intended.
            qry = f"""
with A as {subqry(qry)}
select * from A
union all
select
    {indent(join(dct.values()))}
    ,'_allcrse' as crse_code
    ,sum(credit_hr) as credit_hr
from A
group by
    {indent(join(dct.values()))}
union all
select
    {indent(join(dct.values()))}
    ,'_anycrse' as crse_code
    ,case when sum(A.credit_hr) > 0 then 1 else 0 end as credit_hr
from A
group by
    {indent(join(dct.values()))}
"""
            df = run(qry, show)
            return df
        return self.get(fcn, 'registrations', root / f'registrations/registrations_{self.stem}', overwrite)


# Rebuild (overwrite=True) and print the SQL for term 202408's calendar and its
# registrations snapshot taken 50 days before that term's stable date.
self = Term(term_code=202408, cycle_day=50)
self.get_terms(overwrite=True, show=True)
self.get_registrations(overwrite=True, show=True)
print('done')