In [1]:
import utils
import copy
import random


# Problem representation

## Variables

Since nothing is simple until it becomes complicated, we will use the following classes to represent the three major elements of our problem:

* `Teacher`: contains the subjects a teacher can teach and the preferred days/time slots.
* `Classroom`: contains the subjects that can be held in a classroom and its capacity.
* `Course`: contains the teachers who can teach a subject, the classrooms where it can be held, and the number of students enrolled.

In [2]:
class Teacher:
    def __init__(self, name, teacher_info):
        self.name = name
        self.preferred_days = set()
        self.preferred_intervals = set()
        self.courses = set()
        self.__days = set(['Luni', 'Marti', 'Miercuri', 'Joi', 'Vineri'])

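        # No pause limit unless a '!Pauza > N' constraint says otherwise
        self.pause = float('inf')

        for constraint in teacher_info['Constrangeri']:
            # '!Pauza > N' is the only negative constraint we keep: N is the
            # teacher's preferred maximum pause. It must be checked before the
            # generic '!' test below, otherwise it is never reached.
            if constraint.startswith('!Pauza > '):
                self.pause = int(constraint.split(' ')[2])
                continue

            # Skip the other negative constraints: if a day/interval is not in
            # the positive constraint list, it is negative by default
            if constraint[0] == '!':
                continue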

            if constraint in self.__days:
                self.preferred_days.add(constraint)
            else:
                two_hour_intervals = self.__get_two_hour_intervals_from_interval(constraint)
                self.preferred_intervals.update(two_hour_intervals)

        self.courses.update(teacher_info[utils.MATERII])

    def __str__(self):
        return f"{self.name} teaches {self.courses} and prefers {self.preferred_days} and {self.preferred_intervals}"

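    def __get_two_hour_intervals_from_interval(self, interval):
        """ Split a 'start-end' interval string into two-hour (start, end) tuples """
        # Parse the bounds explicitly instead of calling eval on input text
        start, end = (int(bound) for bound in interval.split('-'))

        result = []

        # '<' rather than '!=' so an odd-length interval cannot loop forever
        while start < end:
            result.append((start, start + 2))
            start += 2

        return result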

class Course:
    def __init__(self, name, student_count):
        self.name = name
        self.student_count = student_count
        self.students_not_allocated = student_count
        self.teachers = set()
        self.classrooms = set()

    def add_teacher(self, teacher):
        self.teachers.add(teacher)

    def add_classroom(self, classroom):
        self.classrooms.add(classroom)

    def __str__(self):
        return f"{self.name} with {self.student_count} students, teachers {(list(map(lambda t: t.name, self.teachers)))} and classrooms {(list(map(lambda c: c.name, self.classrooms)))}"

class Classroom:
    def __init__(self, name, capacity, courses):
        self.name = name
        self.capacity = capacity
        self.courses = courses

    def __str__(self):
        return f"{self.name} with capacity {self.capacity} and courses {self.courses}"
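
To illustrate how a teacher's preferred interval turns into timetable slots, here is a minimal standalone sketch of the expansion performed inside `Teacher`. The function name `split_into_two_hour_slots` and the sample strings are assumptions made for this example; they are not part of the notebook's code.

```python
def split_into_two_hour_slots(interval):
    """Expand a 'start-end' string into consecutive two-hour (start, end) slots."""
    start, end = (int(part) for part in interval.split('-'))
    slots = []
    while start < end:
        slots.append((start, start + 2))
        start += 2
    return slots


print(split_into_two_hour_slots('8-12'))   # [(8, 10), (10, 12)]
print(split_into_two_hour_slots('14-16'))  # [(14, 16)]
```

Storing the slots as tuples lets them be compared directly with the interval keys of the timetable.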

## Parsing

The following function parses the input file and creates objects based on the classes mentioned above.

These objects model the conditions of the problem in a way that is easy to use for generating states.

In [3]:
def parse_yaml(yaml_path):
    """ Parse a given yaml file containing the variables and constraints of the problem
        and return dictionaries of teachers, courses, classrooms and a dictionary of courses
        to remaining students that need to be allocated to a classroom """
    yaml_dict = utils.read_yaml_file(yaml_path)
    courses_yaml = yaml_dict[utils.MATERII]
    classrooms_yaml = yaml_dict[utils.SALI]
    teachers_yaml = yaml_dict[utils.PROFESORI]

    teachers = dict()
    courses = dict()
    courses_to_remaining_students = dict()
    classrooms = dict()

    # Initialize courses
    for course_name, student_count in courses_yaml.items():
        course = Course(course_name, student_count)
        courses[course_name] = course
        courses_to_remaining_students[course_name] = student_count

    # Parse classroom info
    for classroom_name, classroom_info in classrooms_yaml.items():
        classroom = Classroom(classroom_name, classroom_info['Capacitate'], classroom_info[utils.MATERII])
        classrooms[classroom_name] = classroom
        for course in classroom_info[utils.MATERII]:
            courses[course].add_classroom(classroom)

    # Parse teacher info
    for teacher_name, teacher_info in teachers_yaml.items():
        teacher = Teacher(teacher_name, teacher_info)
        teachers[teacher_name] = teacher
        for course in teacher_info[utils.MATERII]:
            courses[course].add_teacher(teacher)

    return teachers, courses, courses_to_remaining_students, classrooms
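
The two parsing loops invert the input: the file lists courses per classroom and per teacher, while each `Course` ends up listing its own classrooms and teachers. A minimal sketch of that inversion on plain dictionaries, with made-up room and course names (assumptions for illustration only):

```python
from collections import defaultdict

# Hypothetical input: classroom name -> courses it can host
classroom_courses = {
    'R1': ['Math', 'Physics'],
    'R2': ['Math'],
}

# Invert into: course name -> set of classrooms that can host it
course_classrooms = defaultdict(set)
for classroom, hosted in classroom_courses.items():
    for course in hosted:
        course_classrooms[course].add(classroom)

print(sorted(course_classrooms['Math']))     # ['R1', 'R2']
print(sorted(course_classrooms['Physics']))  # ['R1']
```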

In the following global variables, we'll store the static information of the problem (essentially the constraints that must be respected). Keeping this information outside the states means it is never copied when a state is cloned, which saves memory in the long run.

In [4]:
teachers, courses, courses_to_remaining_students, classrooms = None, None, None, None

## States

We will represent a timetable using the `State` class, which contains a 3-level dictionary:
- the 1st key is the day,
- the 2nd key is the time slot,
- the 3rd key is the classroom.

The value is a `(teacher_name, course_name)` tuple describing who teaches what in that slot, or `None` if nothing is scheduled in that slot.

It also contains a dictionary that links courses to the number of unallocated students.

Thus, we have a relatively lightweight representation of the solution's state that doesn't create many issues when cloning (something we will do frequently in the following algorithms).
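
As a small illustration of this layout, the sketch below builds a tiny timetable and clones it. The day, room, teacher and course names are placeholders chosen for the example, not values from the input files.

```python
import copy

days = ['Mon', 'Tue']                # hypothetical day names
intervals = [(8, 10), (10, 12)]      # two-hour slots as (start, end) tuples
rooms = ['R1', 'R2']                 # hypothetical classroom names

# day -> interval -> classroom -> (teacher_name, course_name) or None
timetable = {
    day: {interval: {room: None for room in rooms} for interval in intervals}
    for day in days
}

# Cloning must be deep, otherwise the nested dictionaries would be shared
clone = copy.deepcopy(timetable)
clone['Mon'][(8, 10)]['R1'] = ('Teacher A', 'Math')

print(timetable['Mon'][(8, 10)]['R1'])  # None
print(clone['Mon'][(8, 10)]['R1'])      # ('Teacher A', 'Math')
```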

The class also has methods that generate neighboring states and actions from the current state and from the constraints stored in the data structures defined earlier. Its `eval` method scores a state from the number of conflicts (violated soft constraints), the number of occupied slots (each weighted by how few courses its teacher can teach), and the number of fully allocated courses.

Neighboring states can be generated in three ways:
- filling an empty slot with a subject + teacher,
- replacing an occupied slot with a different subject + teacher,
- swapping two occupied slots.
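
The three moves, together with the bookkeeping of students still to be placed, can be sketched on a flat dictionary. This is a simplified illustration under stated assumptions: it mutates one timetable in place instead of cloning a state, and the capacities, names and helper functions (`fill`, `replace`, `swap`) are hypothetical.

```python
# Hypothetical data: classroom capacities and students still to place per course
capacity = {'R1': 30, 'R2': 50}
remaining = {'Math': 100, 'Physics': 60}

# Flat timetable for brevity: (day, interval, classroom) -> (teacher, course) or None
slots = {('Mon', (8, 10), 'R1'): None, ('Mon', (8, 10), 'R2'): None}


def fill(key, teacher, course):
    """Place a course in an empty slot."""
    slots[key] = (teacher, course)
    remaining[course] -= capacity[key[2]]


def replace(key, teacher, course):
    """Overwrite an occupied slot: the old course gets its seats back."""
    _, old_course = slots[key]
    remaining[old_course] += capacity[key[2]]
    slots[key] = (teacher, course)
    remaining[course] -= capacity[key[2]]


def swap(key1, key2):
    """Exchange two occupied slots; each course trades one room's capacity for the other's."""
    (_, course1), (_, course2) = slots[key1], slots[key2]
    remaining[course1] += capacity[key1[2]] - capacity[key2[2]]
    remaining[course2] += capacity[key2[2]] - capacity[key1[2]]
    slots[key1], slots[key2] = slots[key2], slots[key1]


a, b = ('Mon', (8, 10), 'R1'), ('Mon', (8, 10), 'R2')
fill(a, 'T1', 'Math')       # Math: 100 - 30 = 70
fill(b, 'T2', 'Physics')    # Physics: 60 - 50 = 10
swap(a, b)                  # Math now sits in R2 (50 seats), Physics in R1 (30 seats)
print(remaining)            # {'Math': 50, 'Physics': 30}
replace(a, 'T1', 'Math')    # slot a (R1) held Physics; Math takes it over
print(remaining)            # {'Math': 20, 'Physics': 60}
```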

In [21]:
class State:
    """ Class that represents a partial timetable allocation """
    def __init__(self, input_yaml_path = None):
        self.timetable = dict()
        self.courses_to_remaining_students = dict()
        self.conflict_count = 0

        if input_yaml_path is not None:
            self.__parse_yaml_dict(utils.read_yaml_file(input_yaml_path))
            self.courses_to_remaining_students = copy.deepcopy(courses_to_remaining_students)

    def __parse_yaml_dict(self, yaml_dict):
        """ Parses the input yaml file and initializes an initial state """
        # Initialize timetable with empty classrooms
        self.timetable.update({
            day: {
                eval(interval): {
                    classroom: None
                    for classroom in yaml_dict[utils.SALI]
                }
                for interval in yaml_dict[utils.INTERVALE]
            }
            for day in yaml_dict[utils.ZILE]
        })

    def is_final(self):
        """ Checks if the current state is a final state """
        for course_name in self.courses_to_remaining_students:
            if self.courses_to_remaining_students[course_name] > 0:
                return False

        return True

    def clone(self):
        """ Returns a deep copy of the current state """
        state = State()
        state.timetable = copy.deepcopy(self.timetable)
        state.courses_to_remaining_students = copy.deepcopy(self.courses_to_remaining_students)
        state.conflict_count = self.conflict_count

        return state

    def fill_slot(self, day, interval, classroom, teacher_name, course_name):
        """ Fills a given slot with the given course and teacher """
        new_state = self.clone()
        new_state.timetable[day][interval][classroom] = (teacher_name, course_name)

        # Update the conflict count
        new_state.update_conflict_count()

        # Update the course not allocated count
        new_state.courses_to_remaining_students[course_name] -= classrooms[classroom].capacity

        return new_state

    def replace_slot(self, day, interval, classroom, teacher_name, course_name):
        """ Replaces a given slot with the given course and teacher """
        new_state = self.clone()
        new_state.timetable[day][interval][classroom] = (teacher_name, course_name)

        # Update the conflict count
        new_state.update_conflict_count()

        # Update the course not allocated count
        new_state.courses_to_remaining_students[course_name] -= classrooms[classroom].capacity
        new_state.courses_to_remaining_students[self.timetable[day][interval][classroom][1]] += classrooms[classroom].capacity

        return new_state

    def swap_slots(self, day1, interval1, classroom1, day2, interval2, classroom2):
        """ Swaps two slots """
        new_state = self.clone()
        slot1 = self.timetable[day1][interval1][classroom1]
        slot2 = self.timetable[day2][interval2][classroom2]
        new_state.timetable[day1][interval1][classroom1] = slot2
        new_state.timetable[day2][interval2][classroom2] = slot1

        # Update the conflict count
        new_state.update_conflict_count()

        # Update course not allocated count: each course gives back the capacity
        # of the classroom it leaves and takes the capacity of the one it moves to
        capacity1 = classrooms[classroom1].capacity
        capacity2 = classrooms[classroom2].capacity
        new_state.courses_to_remaining_students[slot1[1]] += capacity1 - capacity2
        new_state.courses_to_remaining_students[slot2[1]] += capacity2 - capacity1

        return new_state

    def check_constraints(self, day, interval, classroom, teacher_name, course_name):
        """ Checks if the given slot can be filled with the given course and teacher """
        # Check if the course can be taught by the teacher
        if course_name not in teachers[teacher_name].courses:
            return False

        # Check if the course can be held in the classroom
        if course_name not in classrooms[classroom].courses:
            return False

        # Check if the teacher has reached the maximum number of intervals
        if self.get_teacher_interval_count(teacher_name) >= 7:
            return False

        # Check if the teacher is available in the given interval
        if not self.teacher_available_in_interval(teacher_name, interval, day):
            return False

        return True

    def get_next_actions_replacements(self):
        return [(day, interval, classroom, teacher_name, course_name)
            for day in self.timetable
            for interval in self.timetable[day]
            for classroom in classrooms
            for teacher_name in teachers
            for course_name in teachers[teacher_name].courses
            if self.timetable[day][interval][classroom] is not None and # slot already filled
            self.courses_to_remaining_students[course_name] != 0 and # course has not been fully allocated
            self.check_constraints(day, interval, classroom, teacher_name, course_name) and
            (self.timetable[day][interval][classroom][1] != course_name or
            self.timetable[day][interval][classroom][0] != teacher_name) # new allocation is different from the old one
        ]

    def get_next_actions_with_soft_constraints(self, course_name = None):
        if course_name is None:
            course_name = self.get_smallest_course()

        return [(day, interval, classroom, teacher_name, course_name)
            for day in self.timetable
            for interval in self.timetable[day]
            for classroom in classrooms
            for teacher_name in teachers
            if self.timetable[day][interval][classroom] is None and # slot not already filled
            self.courses_to_remaining_students[course_name] != 0 and # course has not been fully allocated
            self.check_constraints(day, interval, classroom, teacher_name, course_name) and
            day in teachers[teacher_name].preferred_days and
            interval in teachers[teacher_name].preferred_intervals
        ]

    def get_next_actions_no_soft_constraints(self):
        return [(day, interval, classroom, teacher_name, course_name)
            for day in self.timetable
            for interval in self.timetable[day]
            for classroom in classrooms
            for teacher_name, teacher in teachers.items()
            for course_name in teacher.courses
            if self.timetable[day][interval][classroom] is None and # slot not already filled
            course_name in teacher.courses and # teacher can teach the course
            course_name in classrooms[classroom].courses and # course can be held in the classroom
            self.courses_to_remaining_students[course_name] != 0 and # course has not been fully allocated
            self.get_teacher_interval_count(teacher_name) < 7 and # teacher has not reached the maximum number of intervals
            self.teacher_available_in_interval(teacher_name, interval, day)
        ]

    def get_next_actions_greedy(self):
        """ Returns a greedy list of actions that can be taken in the current state """
        return self.get_next_actions_with_soft_constraints(self.get_smallest_course())

    def get_next_actions_full(self):
        """ Returns all actions that can be taken """
        sorted_courses = sorted(self.courses_to_remaining_students, key=lambda course: courses.get(course).student_count)

        for course_name in sorted_courses:
            if self.courses_to_remaining_students[course_name] == 0:
                continue

            actions = self.get_next_actions_with_soft_constraints(course_name)

            if len(actions) > 0:
                return actions

        # Compute the replacement actions once instead of twice
        replacements = self.get_next_actions_replacements()
        if replacements:
            return replacements

        return self.get_next_actions_no_soft_constraints()

    def get_next_fill_states(self, course_name):
        """ Returns a generator of all possible states that can be reached from the current state
         by filling a slot with the given course """
        return (self.fill_slot(day, interval, classroom, teacher_name, course_name)
            for day in self.timetable
            for interval in self.timetable[day]
            for classroom in classrooms
            for teacher_name in teachers
            if self.timetable[day][interval][classroom] is None and # slot not already filled
            self.courses_to_remaining_students[course_name] != 0 and # course has not been fully allocated
            self.check_constraints(day, interval, classroom, teacher_name, course_name)
        )

    def get_next_replacement_states(self, course_name):
        """ Returns a generator of all possible states that can be reached from the current state
         by replacing an already filled slot with the given course """
        return (self.replace_slot(day, interval, classroom, teacher_name, course_name)
            for day in self.timetable
            for interval in self.timetable[day]
            for classroom in classrooms
            for teacher_name in teachers
            if self.timetable[day][interval][classroom] is not None and # slot already filled
            self.courses_to_remaining_students[course_name] != 0 and # course has not been fully allocated
            self.check_constraints(day, interval, classroom, teacher_name, course_name) and
            (self.timetable[day][interval][classroom][1] != course_name or
            self.timetable[day][interval][classroom][0] != teacher_name) # new allocation is different from the old one
        )

    def get_next_swap_states(self):
        """ Returns a generator of all possible states that can be reached from the current state
         by swapping two slots """
        return (self.swap_slots(day1, interval1, classroom1, day2, interval2, classroom2)
            for day1 in self.timetable
            for interval1 in self.timetable[day1]
            for classroom1 in self.timetable[day1][interval1]
            for day2 in self.timetable
            for interval2 in self.timetable[day2]
            for classroom2 in self.timetable[day2][interval2]
            if self.timetable[day1][interval1][classroom1] is not None and # slot1 is filled
            self.timetable[day2][interval2][classroom2] is not None and # slot2 is filled
            self.check_constraints(day1, interval1, classroom1,
                                    self.timetable[day2][interval2][classroom2][0],
                                    self.timetable[day2][interval2][classroom2][1]) and
            self.check_constraints(day2, interval2, classroom2,
                                    self.timetable[day1][interval1][classroom1][0],
                                    self.timetable[day1][interval1][classroom1][1])
        )

    def get_next_states(self):
        """ Returns the states reachable from the current state and how many were generated.
            Tries to fill an empty slot with the smallest unallocated course first and falls
            back to replacing an occupied slot; swaps are generated only when no course is
            left to allocate. """
        smallest_course = self.get_smallest_course()
        states = 0
        next_states = None

        if smallest_course is None:
            # every course is fully allocated, so the only move left is swapping slots
            next_states = list(self.get_next_swap_states())
        else:
            # try to fill slots first
            next_states = list(self.get_next_fill_states(smallest_course))
            states += len(next_states)

            # if no allocations can be made, try to replace slots
            if len(next_states) == 0:
                next_states = list(self.get_next_replacement_states(smallest_course))
                states += len(next_states)

        return next_states, states

    def eval(self):
        """ Grades the current state based on the quality of the allocated slots """
        total = 0

        for day in self.timetable:
            for interval in self.timetable[day]:
                for classroom in self.timetable[day][interval]:
                    if self.timetable[day][interval][classroom] is not None:
                        teacher_name, _ = self.timetable[day][interval][classroom]

                        # base value for allocated slots
                        total += 500

                        # bonus for assigning teachers with few courses
                        total += 100 / len(teachers[teacher_name].courses)

        # penalty for conflicts
        total -= self.conflict_count * 250

        # bonus for fully allocated courses
        for course in self.courses_to_remaining_students:
            if self.courses_to_remaining_students[course] == 0:
                total += 10000

        # bonus for final state
        if self.is_final():
            total += 1000000

        return total

    def teacher_available_in_interval(self, teacher_name, interval, day):
        """ Checks if the teacher isn't already teaching the given interval """
        for classroom in classrooms:
            timetable_entry = self.timetable[day][interval][classroom]

            if timetable_entry is not None and timetable_entry[0] == teacher_name:
                return False

        return True

    def get_teacher_interval_count(self, teacher_name):
        """ Returns the number of intervals the teacher is already teaching """
        interval_count = 0

        for day in self.timetable:
            for interval in self.timetable[day]:
                for classroom in self.timetable[day][interval]:
                    if self.timetable[day][interval][classroom] is not None and self.timetable[day][interval][classroom][0] == teacher_name:
                        interval_count += 1

        return interval_count

    def update_conflict_count(self):
        """ Recalculates the total number of conflicts """
        self.conflict_count = 0

        for day in self.timetable:
            for interval in self.timetable[day]:
                for classroom in self.timetable[day][interval]:
                    if self.timetable[day][interval][classroom] is not None:
                        teacher_name, _ = self.timetable[day][interval][classroom]
                        teacher = teachers[teacher_name]

                        if day not in teacher.preferred_days:
                            self.conflict_count += 1

                        if interval not in teacher.preferred_intervals:
                            self.conflict_count += 1

    def get_smallest_course(self):
        """ Returns the course with the fewest number of students that has not been fully allocated """

        courses_not_allocated = {course: self.courses_to_remaining_students[course]
                                   for course in self.courses_to_remaining_students
                                if self.courses_to_remaining_students[course] > 0}

        if len(courses_not_allocated) == 0:
            return None

        course_name = min(courses_not_allocated, key=lambda course: courses[course].student_count)

        return course_name

    def get_total_possibilities_for_slot(self, day, interval, classroom):
        total = 0

        for _, teacher in teachers.items():
            for course_name in teacher.courses:
                if course_name in classrooms[classroom].courses and day in teacher.preferred_days and interval in teacher.preferred_intervals:
                    total += 1

        return total

    def get_conflict_count_with_pauses(self):
        result = self.conflict_count

        # check if any teacher has a pause longer than their preferred maximum pause
        for teacher_name, teacher in teachers.items():
            pause_intervals = 0

            for day in self.timetable:
                for interval in self.timetable[day]:
                    for classroom in self.timetable[day][interval]:
                        if self.timetable[day][interval][classroom] is not None and self.timetable[day][interval][classroom][0] == teacher_name:
                            pause_intervals = 0
                        else:
                            pause_intervals += 1

                        if pause_intervals > teacher.pause:
                            result += 1

        return result


# Hill Climbing

We'll apply a generic steepest ascent hill climbing algorithm. The key aspect is how we select the next states and how we evaluate which ones are better.

The focus here is on the `get_next_states` and `eval` methods. To keep the pool of candidates small, `get_next_states` first tries to generate only states that fill an empty slot with the unallocated course that has the fewest students (in theory, the course that should cause the fewest problems in the long term). If no such state exists, it generates states in which an already occupied slot is replaced with that course. Swap states are generated only when no course is left to allocate. In every case, `check_constraints` is used to avoid generating states that violate hard constraints.

Thus, we generate a pool of next states that is always relatively small and of high quality. To select the best state, we use the eval method, which evaluates the quality of a state based on the number of soft constraints violated, the number of occupied time slots, and the number of courses covered. Furthermore, a higher score is given to slots occupied by courses taught by teachers who have fewer overall courses. The idea here is to prioritize teachers who don't have many options, thereby forcing teachers who can teach both the same and other courses to take on those additional ones.
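
To make the weighting concrete, here is a toy version of the scoring rule that works on plain numbers instead of a `State`. The function `score` and its parameters are assumptions made for this sketch; the constants are the ones used by the `eval` method.

```python
def score(teacher_course_counts, conflicts, fully_allocated, is_final):
    """Toy version of the scoring rule.

    teacher_course_counts: for each occupied slot, how many courses
    the teacher assigned to that slot is able to teach.
    """
    total = 0.0
    for count in teacher_course_counts:
        total += 500            # base value for an occupied slot
        total += 100 / count    # larger bonus for teachers with fewer options
    total -= 250 * conflicts    # penalty per violated soft constraint
    total += 10_000 * fully_allocated
    if is_final:
        total += 1_000_000
    return total


# Same number of slots, but the second assignment uses a single-course teacher
print(score([4, 4], conflicts=0, fully_allocated=0, is_final=False))  # 1050.0
print(score([4, 1], conflicts=0, fully_allocated=0, is_final=False))  # 1125.0
print(score([4, 1], conflicts=1, fully_allocated=0, is_final=False))  # 875.0
```

With these constants a single conflict costs less than one occupied slot earns, so filling a slot at the price of one conflict still raises the score.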

In [6]:
def hill_climbing(initial: State, max_iters: int = 500):
    """ A la carte hill climbing algorithm """
    iters, states = 0, 0
    state = initial.clone()

    while iters < max_iters:
        iters += 1

        if state.is_final():
            break

        next_states, states_generated = state.get_next_states()
        states += states_generated

        if len(next_states) == 0:
            break

        new_state = max(next_states, key=lambda s: s.eval())

        if new_state.eval() <= state.eval():
            break

        state = new_state

    return state.is_final(), iters, states, state


## Testing the algorithm

In [7]:
def run_hill_climb_test(test_path, overwrite_ref = False):
    global teachers, courses, courses_to_remaining_students, classrooms
    yaml_path = f"inputs/{test_path}.yaml"
    teachers, courses, courses_to_remaining_students, classrooms = parse_yaml(yaml_path)
    initial_state = State(yaml_path)

    is_final, iters, states, final_state = hill_climbing(initial_state)

    print(is_final)
    print(f"{final_state.conflict_count} conflicts")
    print(f"{iters} iterations")
    print(f"{states} states")
    print(utils.pretty_print_timetable(final_state.timetable, f"inputs/{test_path}.yaml"))

    if overwrite_ref:
        # write to file
        with open(f"outputs/{test_path}.txt", "w") as f:
            f.write(utils.pretty_print_timetable(final_state.timetable, f"inputs/{test_path}.yaml"))

In [8]:
run_hill_climb_test("dummy", overwrite_ref=True)

True
0 conflicts
12 iterations
204 states
|           Interval           |             Luni             |             Marti            |           Miercuri           |              Joi             |            Vineri            |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            8 - 10            |      MS : (EG324 - RG)       |      MS : (EG324 - CD)       |      MS : (EG324 - RG)       |
|                              |      DS : (EG390 - EG)       |      EG390 - goala           |      DS : (EG390 - EG)       |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            10 - 12           |      IA : (EG324 - PF)       |      EG324 - goala           |      IA : (EG324 - PF)       |
|               

In [9]:
run_hill_climb_test("orar_mic_exact", overwrite_ref=True)

True
0 conflicts
37 iterations
7053 states
|           Interval           |             Luni             |             Marti            |           Miercuri           |              Joi             |            Vineri            |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            8 - 10            |      ED010 - goala           |      ED010 - goala           |      ED010 - goala           |      ED010 - goala           |      ED010 - goala           |
|                              |      PL : (ED020 - AM)       |      PL : (ED020 - AM)       |      PCom : (ED020 - IS)     |      PCom : (ED020 - IS)     |      ED020 - goala           |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            10 -

In [10]:
run_hill_climb_test("orar_mediu_relaxat", overwrite_ref=True)

True
0 conflicts
75 iterations
41333 states
|           Interval           |             Luni             |             Marti            |           Miercuri           |              Joi             |            Vineri            |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            8 - 10            |      MS : (ED038 - CA)       |      MS : (ED038 - CA)       |      SOC : (ED038 - RA)      |      PL : (ED038 - RE)       |      AA : (ED038 - EV)       |
|                              |      MS : (ED041 - IC)       |      MS : (ED041 - IC)       |      SOC : (ED041 - IG)      |      AA : (ED041 - EV)       |      ED041 - goala           |
|                              |      MS : (ED069 - PD)       |      MS : (ED069 - IG)       |      SOC : (ED069 - MA)      |      SOC : (ED069 - RA)      |      ED069 - goala           |
|               

In [11]:
run_hill_climb_test("orar_mare_relaxat", overwrite_ref=True)

True
0 conflicts
97 iterations
68862 states
|           Interval           |             Luni             |             Marti            |           Miercuri           |              Joi             |            Vineri            |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            8 - 10            |      MS : (ED090 - RA)       |      MS : (ED090 - RA)       |      ASC : (ED090 - EI)      |      MS : (ED090 - RA)       |      PL : (ED090 - AD)       |
|                              |      PCom : (ED091 - CC)     |      PCom : (ED091 - CC)     |      PCom : (ED091 - MA)     |      PCom : (ED091 - MA)     |      ASC : (ED091 - IG)      |
|                              |      MS : (EG346 - RS)       |      EG346 - goala           |      EG346 - goala           |      ASC : (EG346 - EA2)     |      EG346 - goala           |
|               

In [12]:
run_hill_climb_test("orar_constrans_incalcat", overwrite_ref=True)

True
3 conflicts
56 iterations
15107 states
|           Interval           |             Luni             |             Marti            |           Miercuri           |              Joi             |            Vineri            |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            8 - 10            |      PCom : (EC109 - ME)     |      PCom : (EC109 - IV)     |      DS : (EC109 - DD)       |      PM : (EC109 - CP2)      |      SO : (EC109 - CA)       |
|                              |      PM : (ED043 - CP2)      |      PM : (ED043 - CA)       |      ED043 - goala           |      PM : (ED043 - VD)       |      SO : (ED043 - CP)       |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            10 

# Monte Carlo Tree Search

The algorithm itself is not very different from the traditional version, but the differences lie in how we generate actions, how we run simulations, and how we calculate the reward of a simulation.

The `get_next_actions_greedy` method is similar to `get_next_states` from hill climbing: it returns actions that place the unallocated course with the fewest students in an empty slot without violating soft constraints. If no such actions exist, the `get_next_actions_full` method is called, which also returns actions that replace an occupied slot and, as a last resort, actions that violate soft constraints.

I decided that a simulation should only go down to a certain depth, rather than going all the way to a final state. The reason is that often we could end up with a simulation that takes a very long time to reach a final state, without providing much useful information about how good or bad an action is.

The evaluation is similar to that used in hill climbing: we penalize conflicts and reward slots occupied by teachers who teach few courses.
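
The idea of a depth-limited simulation can be sketched independently of the timetable classes. The function `rollout` and its callable parameters (`next_actions`, `step`, `is_final`, `evaluate`) are hypothetical names introduced for this example, and the counting problem at the end is only a stand-in for a real state.

```python
import random


def rollout(state, next_actions, step, is_final, evaluate, max_depth=5, rng=random):
    """Play random actions for at most max_depth steps, then score the reached state."""
    for _ in range(max_depth):
        if is_final(state):
            break
        actions = next_actions(state)
        if not actions:
            break
        state = step(state, rng.choice(actions))
    return evaluate(state)


# Toy problem (an assumption for the demo): count upward from 0 in steps of 1 or 2
reward = rollout(
    0,
    next_actions=lambda s: [1, 2],
    step=lambda s, a: s + a,
    is_final=lambda s: s >= 10,
    evaluate=lambda s: s,
    max_depth=3,
)
print(3 <= reward <= 6)  # True: three steps of size 1 or 2
```

Scoring the state reached at the cut-off, rather than waiting for a final state, bounds the cost of every simulation by `max_depth` steps.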

In [64]:
N = 'N'
Q = 'Q'
STATE = 'state'
PARENT = 'parent'
ACTIONS = 'actions'

def init_node(state, parent = None):
    return {N: 0, Q: 0, STATE: state, PARENT: parent, ACTIONS: {}}

from math import sqrt, log

# constant for balancing exploration and exploitation
CP = 1.0 / sqrt(2.0)

def select_action(node, c = CP):
    """ Select the action with the highest UCB value from the given node """
    N_node = node[N]
    actions = node[ACTIONS]

    best = None
    best_value = None

    for action in actions:
        Q_a = actions[action][Q]
        N_a = actions[action][N]

        action_value = Q_a / N_a + c * sqrt(2 * log(N_node) / N_a)

        if best is None or action_value >= best_value:
            best = action
            best_value = action_value

    return best
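
To see how the two terms of the UCB value trade off, here is a small worked example. The `ucb` helper is a hypothetical standalone copy of the expression in `select_action`, and the visit counts and rewards are made up for illustration. With `c = CP` the exploration term simplifies to `sqrt(log(N_node) / N_a)`.

```python
from math import sqrt, log

CP = 1.0 / sqrt(2.0)


def ucb(q_a, n_a, n_node, c=CP):
    """Same expression as in select_action: mean reward plus exploration bonus."""
    return q_a / n_a + c * sqrt(2 * log(n_node) / n_a)


# Parent visited 10 times; A is well explored, B has only 2 visits.
# Rewards in the hundreds: A's higher mean decides.
a_raw = ucb(q_a=800, n_a=8, n_node=10)  # mean 100
b_raw = ucb(q_a=190, n_a=2, n_node=10)  # mean 95

# Same means divided by 100: the exploration bonus now decides, and B wins.
a_scaled = ucb(q_a=8.0, n_a=8, n_node=10)  # mean 1.0
b_scaled = ucb(q_a=1.9, n_a=2, n_node=10)  # mean 0.95
```

The exploration bonus does not depend on the reward scale, so with rewards in the hundreds or thousands the choice is driven almost entirely by the mean reward. Normalizing rewards or raising `c` would be one way to give exploration more weight, if that turns out to be needed.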


In [80]:
def apply_action(state, action):
    day, interval, classroom, _, _ = action
    new_state = None

    if state.timetable[day][interval][classroom] is None:
        new_state = state.fill_slot(*action)
    else:
        new_state = state.replace_slot(*action)

    return new_state


def mcts(state0, budget, tree):
    """ Runs MCTS from state0 for `budget` iterations, reusing `tree` if one is given.
        Returns the chosen action, the subtree under that action and the number of states generated """
    if tree is None:
        tree = init_node(state0)

    state_count = 0

    for _ in range(budget):
        # start selection from root
        node = tree

        # traverse the tree until a final state is reached or a node that's not fully explored is found
        while node[ACTIONS] and not node[STATE].is_final() and \
                all([act in node[ACTIONS] for act in node[STATE].get_next_actions_greedy()]):
            action = select_action(node)
            node = node[ACTIONS][action]

        # expand: add one unexplored action as a new child node
        if not node[STATE].is_final():
            candidates = node[STATE].get_next_actions_greedy()
            if not candidates:
                candidates = node[STATE].get_next_actions_full()
            unexplored = [act for act in candidates if act not in node[ACTIONS]]

            # random.choice raises IndexError on an empty list,
            # so only expand when an unexplored action is left
            if unexplored:
                action = random.choice(unexplored)
                state = apply_action(node[STATE], action)

                node_nou = init_node(state, node)
                node[ACTIONS][action] = node_nou
                node = node_nou
                state_count += 1

        # start a simulation from the new node
        state = node[STATE]
        depth = 10
        while not state.is_final() and depth != 0:
            actions = state.get_next_actions_greedy()
            if len(actions) == 0:
                actions = state.get_next_actions_replacements()

            if len(actions) == 0:
                break

            action = random.choice(actions)
            state = apply_action(state, action)

            depth -= 1

            state_count += 1

        # calculate and propagate the reward back to the root
        reward = 0
        if state.is_final():
            reward += 100
        reward += 10 - state.conflict_count

        for day in state.timetable:
            for interval in state.timetable[day]:
                for classroom in state.timetable[day][interval]:
                    if state.timetable[day][interval][classroom] is not None:
                        reward += 50

                        teacher_name, _ = state.timetable[day][interval][classroom]
                        reward += 100 / len(teachers[teacher_name].courses)

        for course in state.courses_to_remaining_students:
            if state.courses_to_remaining_students[course] == 0:
                reward += 1000

        while node:
            node[N] += 1
            node[Q] += reward
            node = node[PARENT]

    # return the best action from the root node
    final_action = select_action(tree, 0.0)
    return final_action, tree[ACTIONS][final_action], state_count
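
Because `mcts` returns the chosen child as the tree for the next call, that child's `PARENT` still points at the old root, so backpropagation keeps walking up through nodes that are no longer searched. The sketch below shows one possible way to cut the link. `detach` and `backpropagate` are hypothetical helpers written for this illustration, and the node layout is copied from `init_node`.

```python
N, Q, STATE, PARENT, ACTIONS = 'N', 'Q', 'state', 'parent', 'actions'


def init_node(state, parent=None):
    return {N: 0, Q: 0, STATE: state, PARENT: parent, ACTIONS: {}}


def detach(subtree):
    """Hypothetical helper: turn a subtree into a root by dropping its parent link."""
    subtree[PARENT] = None
    return subtree


def backpropagate(node, reward):
    """Same loop as in mcts(); also returns how many nodes were updated."""
    updated = 0
    while node:
        node[N] += 1
        node[Q] += reward
        node = node[PARENT]
        updated += 1
    return updated


root = init_node('s0')
child = init_node('s1', root)
root[ACTIONS]['a'] = child
leaf = init_node('s2', child)
child[ACTIONS]['b'] = leaf

before = backpropagate(leaf, 1)  # walks leaf -> child -> root
detach(child)
after = backpropagate(leaf, 1)   # walks leaf -> child only
```

Detaching does not change which action is selected, since `select_action` only reads the counters of the current root and its children. It only shortens the backpropagation walk and lets the discarded part of the old tree be garbage collected.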

In [67]:
def run_MCTS_test(test_path, budget, overwrite_ref = False):
    """ Run the MCTS algorithm on the given test path with the given budget """
    global teachers, courses, courses_to_remaining_students, classrooms
    yaml_path = f"inputs/{test_path}.yaml"
    teachers, courses, courses_to_remaining_students, classrooms = parse_yaml(yaml_path)
    current_state = State(yaml_path)
    tree_node = None
    states = 0

    # apply MCTS to get the next action at each step
    while current_state and not current_state.is_final():
        action, tree_node, state_count = mcts(current_state, budget, tree_node)
        current_state = apply_action(current_state, action)
        states += state_count

    print(current_state.is_final())
    print(f"{current_state.conflict_count} conflicts")
    print(f"{states} states")

    # format the timetable once and reuse the text for printing and saving
    timetable_text = utils.pretty_print_timetable(current_state.timetable, yaml_path)
    print(timetable_text)

    if overwrite_ref:
        # save the timetable as the reference output for this test
        with open(f"outputs/{test_path}.txt", "w") as f:
            f.write(timetable_text)


## Testing the algorithm

In [78]:
run_MCTS_test("dummy", 10, overwrite_ref=True)

True
0 conflicts
562 states
|           Interval           |             Luni             |             Marti            |           Miercuri           |              Joi             |            Vineri            |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            8 - 10            |      EG324 - goala           |      MS : (EG324 - CD)       |      MS : (EG324 - RG)       |
|                              |      DS : (EG390 - EG)       |      EG390 - goala           |      DS : (EG390 - EG)       |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            10 - 12           |      IA : (EG324 - PF)       |      IA : (EG324 - AD)       |      EG324 - goala           |
|                             

In [81]:
run_MCTS_test("orar_mic_exact", 5, overwrite_ref=True)

True
0 conflicts
1699 states
|           Interval           |             Luni             |             Marti            |           Miercuri           |              Joi             |            Vineri            |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            8 - 10            |      ED010 - goala           |      PA : (ED010 - AM)       |      ED010 - goala           |      PA : (ED010 - AM)       |      ED010 - goala           |
|                              |      ED020 - goala           |      ED020 - goala           |      PL : (ED020 - IS)       |      ED020 - goala           |      PCom : (ED020 - IS)     |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            10 - 12           

In [73]:
run_MCTS_test("orar_mediu_relaxat", 3, overwrite_ref=True)

True
0 conflicts
2013 states
|           Interval           |             Luni             |             Marti            |           Miercuri           |              Joi             |            Vineri            |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            8 - 10            |      ED038 - goala           |      ED038 - goala           |      MS : (ED038 - PD)       |      ED038 - goala           |      ED038 - goala           |
|                              |      ED041 - goala           |      ED041 - goala           |      ED041 - goala           |      ED041 - goala           |      ED041 - goala           |
|                              |      MS : (ED069 - IC)       |      PL : (ED069 - MA2)      |      AA : (ED069 - MA2)      |      MS : (ED069 - IG)       |      ED069 - goala           |
|                              

In [54]:
run_MCTS_test("orar_mare_relaxat", 2, overwrite_ref=True)

True
0 conflicts
3076 states
|           Interval           |             Luni             |             Marti            |           Miercuri           |              Joi             |            Vineri            |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            8 - 10            |      MS : (ED090 - RA)       |      MS : (ED090 - RA)       |      ASC : (ED090 - EI)      |      MS : (ED090 - RS)       |      ASC : (ED090 - IG)      |
|                              |      ASC : (ED091 - CC)      |      ED091 - goala           |      ED091 - goala           |      PCom : (ED091 - MA)     |      ED091 - goala           |
|                              |      ASC : (EG346 - MA)      |      ASC : (EG346 - CC)      |      EG346 - goala           |      EG346 - goala           |      EG346 - goala           |
|                              

In [77]:
run_MCTS_test("orar_constrans_incalcat", 2, overwrite_ref=True)

True
63 conflicts
11303 states
|           Interval           |             Luni             |             Marti            |           Miercuri           |              Joi             |            Vineri            |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            8 - 10            |      PM : (EC109 - RA)       |      SO : (EC109 - MP)       |      PM : (EC109 - MP)       |      DS : (EC109 - CA)       |      DS : (EC109 - ME)       |
|                              |      PM : (ED043 - CP2)      |      PM : (ED043 - CI)       |      SO : (ED043 - VA)       |      PM : (ED043 - CP)       |      SO : (ED043 - IV)       |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|            10 - 12         