In [1]:
# TODO: load the real TA and shift data from the database to populate ta_metadata, shift_metadata, and preference_matrix (replacing the dummy-data import below)

In [2]:
# Import dummy data
from dummy_data_big import ta_metadata, shift_metadata, preference_matrix

In [3]:
# ============================================================
# DATA ACCESS HELPERS
# ============================================================

def get_pref(ta_id, shift_id):
    """Look up the preference score a TA gave a shift (0 means unavailable).

    Both ids are 1-indexed, so each is shifted down by one before indexing
    ``preference_matrix`` (first index: TA, second index: shift).
    """
    ta_row = preference_matrix[ta_id - 1]
    return ta_row[shift_id - 1]

def is_available(ta_id, shift_id):
    """Return True when the TA's preference score for the shift is positive."""
    score = get_pref(ta_id, shift_id)
    return score > 0

def get_ta(ta_id):
    """Return the ta_metadata entry for a TA; ta_id is 1-indexed into the list."""
    position = ta_id - 1
    return ta_metadata[position]

def get_shift(shift_id):
    """Return the shift_metadata entry for a shift; shift_id is 1-indexed into the list."""
    position = shift_id - 1
    return shift_metadata[position]

def get_available_tas(shift_id):
    """Return the ta_ids of every TA in ta_metadata who is available for the shift."""
    available = []
    for record in ta_metadata:
        if is_available(record["ta_id"], shift_id):
            available.append(record["ta_id"])
    return available

def get_eligible_leads(shift_id):
    """Get all TAs available for a shift who can be a lab lead.

    A TA qualifies when their ``lab_admin_status`` is at least 3. The status
    is treated as a minimum level (``>=``), matching how
    ``get_eligible_lab_tas`` and ``has_min_status_for_role`` interpret it,
    rather than as an exact match.
    """
    return [
        ta_id
        for ta_id in get_available_tas(shift_id)
        if get_ta(ta_id)["lab_admin_status"] >= 3
    ]

def get_eligible_lab_tas(shift_id):
    """Return the available TAs for a shift whose lab_admin_status is 2 or higher."""
    qualified = []
    for candidate_id in get_available_tas(shift_id):
        status = get_ta(candidate_id)["lab_admin_status"]
        if status >= 2:
            qualified.append(candidate_id)
    return qualified

def get_shifts_for_ta(ta_id):
    """Return the shift_ids of every shift in shift_metadata the TA is available for."""
    open_shift_ids = []
    for entry in shift_metadata:
        if is_available(ta_id, entry["shift_id"]):
            open_shift_ids.append(entry["shift_id"])
    return open_shift_ids

def ta_id_to_name(ta_id):
    """Return the "name" field of the TA's metadata record."""
    record = get_ta(ta_id)
    return record["name"]

def shift_id_to_name(shift_id):
    """Return the "name" field of the shift's metadata record."""
    record = get_shift(shift_id)
    return record["name"]

In [4]:
# # Helper function tests
# print(list(map(ta_id_to_name, get_available_tas(5))))      # Tue B3 Lab — should be Alice, Bob, Eve, Grace, Hank
# print(list(map(ta_id_to_name, get_eligible_leads(5))))     # should be Alice, Bob, Eve, Grace (lab_admin_status 3)
# print(list(map(ta_id_to_name, get_eligible_lab_tas(5))))   # should be Alice, Bob, Eve, Grace, Hank
# print(list(map(shift_id_to_name, get_shifts_for_ta(6))))   # Frank — should be shifts 1, 3, 8, 9, 10, 11 (no labs)

In [5]:
# ============================================================
# CONSTRAINT LOGIC HELPERS
# ============================================================

def shift_duration_hours(shift):
    """Length of a shift in hours, as a float.

    Reads the ``hour`` and ``minute`` attributes of ``shift["start"]`` and
    ``shift["end"]`` and divides the difference in minutes by 60.
    """
    begin, finish = shift["start"], shift["end"]
    elapsed_minutes = (finish.hour * 60 + finish.minute) - (begin.hour * 60 + begin.minute)
    return elapsed_minutes / 60

def shifts_overlap(shift_a, shift_b):
    """Return True when two shifts share a day and their time ranges intersect.

    Shifts that merely touch (one ends exactly when the other starts) do not
    count as overlapping, because both comparisons are strict.
    """
    if shift_a["day"] == shift_b["day"]:
        return shift_a["start"] < shift_b["end"] and shift_b["start"] < shift_a["end"]
    return False

def has_time_conflict(ta_id, shift_id, current_assignments):
    """Return True if any shift already assigned to the TA overlaps the candidate shift.

    ``current_assignments`` maps ta_id to the list of shift_ids the TA holds.
    """
    target = get_shift(shift_id)
    return any(
        shifts_overlap(target, get_shift(held_id))
        for held_id in current_assignments[ta_id]
    )

def is_in_any_lab(ta_id, current_assignments):
    """Return True if at least one of the TA's assigned shifts has a truthy "is_lab"."""
    for held_id in current_assignments[ta_id]:
        if get_shift(held_id)["is_lab"]:
            return True
    return False

def would_exceed_max_hours(ta_id, shift_id, hours_assigned):
    """Return True if adding this shift puts the TA strictly above their "max_hours".

    ``hours_assigned`` maps ta_id to the hours the TA already holds. Landing
    exactly on the maximum is allowed.
    """
    ta_record = get_ta(ta_id)
    added_hours = shift_duration_hours(get_shift(shift_id))
    projected_total = hours_assigned[ta_id] + added_hours
    return projected_total > ta_record["max_hours"]

def has_min_status_for_role(ta_id, role):
    """Return True if the TA's lab_admin_status reaches the minimum for the role.

    Minimums: "oh_ta" needs 1, "lab_ta" needs 2, "lead" needs 3. Any other
    role raises KeyError.
    """
    status = get_ta(ta_id)["lab_admin_status"]
    required = {"oh_ta": 1, "lab_ta": 2, "lead": 3}[role]
    return status >= required

# ============================================================
# COMBINED ELIGIBILITY
# ============================================================

def get_eligible_tas_for_role(shift_id, role, current_assignments, hours_assigned):
    """
    Collect the ta_ids that may take `role` on `shift_id`, in ta_metadata order.

    A TA is kept only if all of these hold, checked in this order:
      1. they are available for the shift,
      2. their status meets the minimum for the role,
      3. the shift would not push them over their max hours,
      4. the shift does not overlap one they already hold,
      5. for "lead" / "lab_ta" roles, they are not already in a lab.
    """
    is_lab_role = role in ("lead", "lab_ta")

    def passes_all_checks(ta_id):
        # Short-circuits in the same order as the numbered list above.
        return (
            is_available(ta_id, shift_id)
            and has_min_status_for_role(ta_id, role)
            and not would_exceed_max_hours(ta_id, shift_id, hours_assigned)
            and not has_time_conflict(ta_id, shift_id, current_assignments)
            and not (is_lab_role and is_in_any_lab(ta_id, current_assignments))
        )

    return [record["ta_id"] for record in ta_metadata if passes_all_checks(record["ta_id"])]

In [6]:
# # Eligibility tests

# # Dictionaries where ta_id is the key
# current_assignments = {ta["ta_id"]: [] for ta in ta_metadata} # List of shift_ids assigned to this ta
# hours_assigned      = {ta["ta_id"]: 0  for ta in ta_metadata} # Current number of hours for this ta

# # Shift 5 = Tue 12:00 Lab
# print(list(map(ta_id_to_name, get_eligible_tas_for_role(5, "lead",   current_assignments, hours_assigned))))
# # Expect: Alice(1), Bob(2), Eve(5), Grace(7) — status 3 only, available for shift 5

# print(list(map(ta_id_to_name, get_eligible_tas_for_role(5, "lab_ta", current_assignments, hours_assigned))))
# # Expect: Alice(1), Bob(2), Eve(5), Grace(7), Hank(8) — status >= 2, available for shift 5

# print(list(map(ta_id_to_name, get_eligible_tas_for_role(6, "oh_ta",  current_assignments, hours_assigned))))
# # Expect: Alice(1), Bob(2), Carol(3), Dave(4), Grace(7) — available for shift 6 (Tue 12:00 OH)

# # Now assign Alice to the Tue 12:00 Lab as a lead
# current_assignments[1] = [5]
# hours_assigned[1] = shift_duration_hours(get_shift(5))  # 1.25 hours

# print(list(map(ta_id_to_name, get_eligible_tas_for_role(6, "oh_ta", current_assignments, hours_assigned))))
# # Expect: Bob(2), Carol(3), Dave(4), Grace(7) — Alice drops out due to time conflict

# print(list(map(ta_id_to_name, get_eligible_tas_for_role(7, "lead", current_assignments, hours_assigned))))
# # Expect: Bob(2), Eve(5), Grace(7) — Alice drops out due to one-lab rule

In [7]:
# ============================================================
# ENFORCE FAIRNESS FLOOR TO EVEN OUT SHIFTS
# ============================================================

def apply_fairness_floor(threshold=0.7):
    """
    Raise each TA's min_hours toward a fair share of the total workload.

    The fair share is the total TA-hours needed across all shifts divided by
    the number of TAs; the fairness floor is ``threshold * fair_share``.
    For each TA the floor is capped at:
      * the hours they are actually available for (preference > 0), and
      * their own max_hours,
    so the resulting min_hours is never an impossible constraint. An existing
    min_hours is only ever raised, never lowered.

    Mutates ta_metadata in place. Does nothing when ta_metadata is empty.

    Parameters:
        threshold: fraction of the fair share to enforce as a floor.
    """
    # No TAs means there is no fair share to compute (and avoids dividing by zero).
    if not ta_metadata:
        return

    # Total TA-hours needed across all shifts: duration times head count per shift.
    total_hours_needed = sum(
        shift_duration_hours(shift) * sum(shift["staffing"])
        for shift in shift_metadata
    )

    fair_share = total_hours_needed / len(ta_metadata)
    fairness_floor = fair_share * threshold

    for i, ta in enumerate(ta_metadata):
        # How many hours is this TA actually available for?
        available_hours = sum(
            shift_duration_hours(shift)
            for j, shift in enumerate(shift_metadata)
            if preference_matrix[i][j] > 0
        )

        # Don't set a floor higher than what they can work: bounded both by
        # their availability and by their own max_hours.
        adjusted_floor = min(fairness_floor, available_hours, ta["max_hours"])

        # Only raise the floor, never lower an existing min.
        ta["min_hours"] = max(ta["min_hours"], adjusted_floor)

# Apply fairness floor before running anything else: this rewrites min_hours in
# ta_metadata, which the greedy scheduler's candidate scoring reads later on.
apply_fairness_floor()

In [8]:
# ============================================================
# ASSIGNMENT HELPERS
# ============================================================

def assign_ta(ta_id, shift_id, current_assignments, hours_assigned):
    """Record the shift on the TA's assignment list and add its duration to their hours.

    Both dictionaries are keyed by ta_id and are mutated in place.
    """
    current_assignments[ta_id].append(shift_id)
    duration = shift_duration_hours(get_shift(shift_id))
    hours_assigned[ta_id] = hours_assigned[ta_id] + duration

def remove_ta(ta_id, shift_id, current_assignments, hours_assigned):
    """Drop the shift from the TA's assignment list and subtract its duration from their hours.

    Both dictionaries are keyed by ta_id and are mutated in place. Raises
    ValueError (from ``list.remove``) if the TA does not hold the shift.
    """
    current_assignments[ta_id].remove(shift_id)
    duration = shift_duration_hours(get_shift(shift_id))
    hours_assigned[ta_id] = hours_assigned[ta_id] - duration

In [9]:
# ============================================================
# INITIAL SCHEDULE - GREEDY APPROACH
# ============================================================

def greedy_assign():
    """
    Produce an initial valid schedule by filling shifts greedily.

    Labs are filled first (most constrained), then OH shifts. Within a lab,
    leads are filled before lab TAs. For each role the eligible TAs are ranked
    by preference plus a boost for TAs still below their min_hours, and the
    top ``num_needed`` are assigned.

    Returns:
        (schedule, current_assignments, hours_assigned) where
        - schedule maps shift_id -> {"leads", "lab_tas", "oh_tas",
          "unschedulable", "error"}; "error" is None, or a string describing
          every role that could not be fully staffed (joined with "; "),
        - current_assignments maps ta_id -> list of assigned shift_ids,
        - hours_assigned maps ta_id -> total hours assigned.
    """

    # Schedule output: shift_id -> assigned TAs by role + error info
    schedule = {
        s["shift_id"]: {
            "leads":   [],
            "lab_tas": [],
            "oh_tas":  [],
            "unschedulable": False,
            "error": None
        }
        for s in shift_metadata
    }

    current_assignments = {ta["ta_id"]: [] for ta in ta_metadata}
    hours_assigned      = {ta["ta_id"]: 0  for ta in ta_metadata}

    def candidate_score(ta_id, shift_id):
        """Preference score plus a 0..1 boost for how far the TA is below min_hours."""
        ta = get_ta(ta_id)
        pref = get_pref(ta_id, shift_id)
        hours_below_min = max(0, ta["min_hours"] - hours_assigned[ta_id])
        balance_boost   = hours_below_min / ta["min_hours"] if ta["min_hours"] > 0 else 0

        return pref + balance_boost

    def fill_role(shift, role, num_needed):
        """Assign up to num_needed of the best-scoring eligible TAs to (shift, role)."""
        shift_id = shift["shift_id"]
        entry = schedule[shift_id]

        eligible = get_eligible_tas_for_role(
            shift_id, role, current_assignments, hours_assigned
        )
        # Sort by candidate score descending
        ranked = sorted(eligible, key=lambda ta_id: candidate_score(ta_id, shift_id), reverse=True)
        selected = ranked[:num_needed]

        role_key = role + "s"  # "leads", "lab_tas", "oh_tas"
        for ta_id in selected:
            entry[role_key].append(ta_id)
            assign_ta(ta_id, shift_id, current_assignments, hours_assigned)

        # Unable to schedule enough TAs
        if len(selected) < num_needed:
            entry["unschedulable"] = True
            message = f"Could only fill {len(selected)}/{num_needed} {role} slots"
            # A lab fills two roles; keep an earlier shortfall message instead
            # of overwriting it with the later one.
            if entry["error"] is None:
                entry["error"] = message
            else:
                entry["error"] = f"{entry['error']}; {message}"

    # Labs first, then OH (sorted is stable, so order within each group is kept)
    sorted_shifts = sorted(shift_metadata, key=lambda s: not s["is_lab"])

    for shift in sorted_shifts:
        # staffing is indexed [oh_ta count, lab_ta count, lead count]
        if shift["is_lab"]:
            fill_role(shift, "lead",   shift["staffing"][2])
            fill_role(shift, "lab_ta", shift["staffing"][1])
        else:
            fill_role(shift, "oh_ta",  shift["staffing"][0])

    return schedule, current_assignments, hours_assigned

In [10]:
# ============================================================
# SCHEDULE SCORING
# ============================================================

# Multiplier applied to a TA's preference score, per role, when scoring a
# schedule: lead assignments count the most, OH assignments the least.
ROLE_WEIGHTS = dict(lead=1.0, lab_ta=0.8, oh_ta=0.6)

def score_schedule(schedule):
    """Total score of a schedule: sum of role weight x preference over every assignment.

    For each shift, leads are added first, then lab TAs, then OH TAs, each
    weighted by the matching entry of ROLE_WEIGHTS.
    """
    weighted_pools = (
        ("lead",   "leads"),
        ("lab_ta", "lab_tas"),
        ("oh_ta",  "oh_tas"),
    )
    running_total = 0
    for shift_id, slots in schedule.items():
        for role, pool_key in weighted_pools:
            for ta_id in slots[pool_key]:
                running_total += ROLE_WEIGHTS[role] * get_pref(ta_id, shift_id)
    return running_total

In [11]:
# ============================================================
# SIMULATED ANNEALING
# ============================================================

import math
import random
from copy import deepcopy

def get_random_filled_shift(schedule):
    """Pick a random assigned TA from the schedule.

    Chooses, in this order: a random shift with at least one assignment, a
    random non-empty role pool on that shift, and a random TA in that pool.

    Returns:
        (ta_id, shift_id, role) where role is "lead", "lab_ta" or "oh_ta".
    """
    role_pools = (("lead", "leads"), ("lab_ta", "lab_tas"), ("oh_ta", "oh_tas"))

    # Shifts that have at least one TA in any role pool
    staffed_shift_ids = [
        sid for sid, slots in schedule.items()
        if any(slots[pool_key] for _, pool_key in role_pools)
    ]
    shift_id = random.choice(staffed_shift_ids)
    slots = schedule[shift_id]

    # Non-empty pools on the chosen shift, in lead / lab_ta / oh_ta order
    populated = [pair for pair in role_pools if slots[pair[1]]]
    role, pool_key = random.choice(populated)

    ta_out = random.choice(slots[pool_key])
    return ta_out, shift_id, role

def simulated_annealing(initial_temp=10.0, cooling_rate=0.995, num_iterations=10000, seed=None):
    """
    Improve the greedy schedule by simulated annealing over single-TA swaps.

    Each iteration removes one randomly chosen assigned TA, picks a random
    eligible replacement for the same (shift, role) slot, and accepts the swap
    if it does not lower the score, or otherwise with probability
    ``exp(delta / temperature)``. The temperature is multiplied by
    ``cooling_rate`` after every iteration.

    Parameters:
        initial_temp:   starting temperature.
        cooling_rate:   multiplicative decay applied to the temperature each iteration.
        num_iterations: number of swap attempts.
        seed:           optional seed; when given, the module-level ``random``
                        generator is seeded so the run is reproducible.

    Returns:
        (best_schedule, best_hours_assigned, best_score) for the best-scoring
        schedule seen. If the greedy schedule contains no assignments at all,
        it is returned unchanged.
    """
    if seed is not None:
        # The module-level generator is seeded (not a private one) because
        # get_random_filled_shift draws from it too.
        random.seed(seed)

    # Start from a greedy schedule
    schedule, current_assignments, hours_assigned = greedy_assign()
    current_score = score_schedule(schedule)

    best_schedule       = deepcopy(schedule)
    best_hours_assigned = deepcopy(hours_assigned)
    best_score          = current_score

    # Swaps are one-for-one, so a schedule that starts with no assignments
    # never gains any; bail out instead of failing on an empty random.choice.
    has_assignment = any(
        slots["leads"] or slots["lab_tas"] or slots["oh_tas"]
        for slots in schedule.values()
    )
    if not has_assignment:
        return best_schedule, best_hours_assigned, best_score

    temperature = initial_temp

    for _ in range(num_iterations):

        # Pick a random filled (shift, role) slot to mutate
        ta_out, shift_id, role = get_random_filled_shift(schedule)
        role_key = role + "s"  # "leads", "lab_tas", "oh_tas"

        # Temporarily remove the TA so eligibility is evaluated without them
        schedule[shift_id][role_key].remove(ta_out)
        remove_ta(ta_out, shift_id, current_assignments, hours_assigned)

        # Find an eligible replacement: not already on this shift, and not the
        # TA just removed (re-adding them would be a no-op "swap").
        already_on_shift = set(
            schedule[shift_id]["leads"] +
            schedule[shift_id]["lab_tas"] +
            schedule[shift_id]["oh_tas"]
        )
        eligible = [
            ta_id for ta_id in get_eligible_tas_for_role(
                shift_id, role, current_assignments, hours_assigned
            )
            if ta_id != ta_out and ta_id not in already_on_shift
        ]

        if not eligible:
            # No replacement found — restore and move on
            schedule[shift_id][role_key].append(ta_out)
            assign_ta(ta_out, shift_id, current_assignments, hours_assigned)
            temperature *= cooling_rate
            continue

        ta_in = random.choice(eligible)  # random, not greedy — important for annealing

        # Apply the swap and score it
        schedule[shift_id][role_key].append(ta_in)
        assign_ta(ta_in, shift_id, current_assignments, hours_assigned)

        new_score = score_schedule(schedule)
        delta     = new_score - current_score

        # Accept or reject
        if delta >= 0:
            # exp(0) == 1, so a zero-delta move was always accepted anyway.
            accept = True
        elif temperature > 0:
            accept = random.random() < math.exp(delta / temperature)
        else:
            # Fully cooled: only non-worsening moves are taken (avoids dividing by zero).
            accept = False

        if accept:
            current_score = new_score
            if current_score > best_score:
                best_score          = current_score
                best_schedule       = deepcopy(schedule)
                best_hours_assigned = deepcopy(hours_assigned)
        else:
            # Revert the swap
            schedule[shift_id][role_key].remove(ta_in)
            remove_ta(ta_in,  shift_id, current_assignments, hours_assigned)
            schedule[shift_id][role_key].append(ta_out)
            assign_ta(ta_out, shift_id, current_assignments, hours_assigned)

        temperature *= cooling_rate

    return best_schedule, best_hours_assigned, best_score

In [12]:
# ============================================================
# TEST ON DUMMY DATA
# ============================================================

def display_results(schedule, hours_assigned, score):
    """Print the score, each shift's staffing by role, and every TA's assigned hours.

    For each shift, prints its name, a warning line if it is flagged
    unschedulable, and one line per non-empty role pool with the TA names.
    Ends with one line per TA showing assigned hours next to their
    min_hours / max_hours.
    """
    role_labels = (
        ("leads",   "  Leads:   "),
        ("lab_tas", "  Lab TAs: "),
        ("oh_tas",  "  OH TAs:  "),
    )

    print(f"Total Score: {score:.2f}\n")

    for shift_id, slots in schedule.items():
        print(f"{get_shift(shift_id)['name']}")
        if slots["unschedulable"]:
            print(f"  ⚠ UNSCHEDULABLE: {slots['error']}")
        for pool_key, label in role_labels:
            members = slots[pool_key]
            if members:
                names = [get_ta(member)["name"] for member in members]
                print(f"{label}{names}")
        print()

    print("Hours per TA:")
    for record in ta_metadata:
        worked = hours_assigned[record["ta_id"]]
        print(f"  {record['name']:8}: {worked:.2f}h  (min: {record['min_hours']}, max: {record['max_hours']})")

# Run the annealer with its default parameters and print the best schedule found.
# The search draws from the `random` module without seeding it here, so the
# printed schedule can differ from one run to the next.
schedule, hours_assigned, score = simulated_annealing()
display_results(schedule, hours_assigned, score)

Total Score: 100.00

Mon 9:00 OH
  OH TAs:  ['Iris', 'Tina']

Mon 10:30 OH
  OH TAs:  ['Jack', 'Olivia']

Mon 12:00 OH
  OH TAs:  ['Karen', 'Pete']

Mon 13:30 OH
  OH TAs:  ['Leo', 'Victor']

Mon 15:00 OH
  OH TAs:  ['Mia', 'Quinn']

Mon 16:30 OH
  OH TAs:  ['Nathan', 'Rachel']

Mon 19:00 OH
  OH TAs:  ['Sam']

Mon 20:30 OH
  OH TAs:  ['Xavier']

Tue 9:00 OH
  OH TAs:  ['Alice', 'Grace']

Tue 10:30 OH
  OH TAs:  ['Uma', 'Bob']

Tue 12:00 Lab
  Leads:   ['Alice', 'Bob']
  Lab TAs: ['Dave', 'Grace']

Tue 12:00 OH
  OH TAs:  ['Carol', 'Iris']

Tue 13:30 OH
  OH TAs:  ['Dave', 'Jack']

Tue 15:00 Lab
  Leads:   ['Carol', 'Eve']
  Lab TAs: ['Hank', 'Frank']

Tue 15:00 OH
  OH TAs:  ['Wendy', 'Mia']

Tue 16:30 OH
  OH TAs:  ['Rachel', 'Eve']

Tue 19:00 OH
  OH TAs:  ['Yara']

Tue 20:30 OH
  OH TAs:  ['Sam']

Wed 9:00 OH
  OH TAs:  ['Tina', 'Bob']

Wed 10:30 OH
  OH TAs:  ['Olivia', 'Uma']

Wed 12:00 OH
  OH TAs:  ['Karen', 'Pete']

Wed 13:30 OH
  OH TAs:  ['Hank', 'Leo']

Wed 15:00 OH
  OH TA

In [13]:
# TODO: somehow send these results to backend