In [2]:
import copy
import random
import time
import matplotlib.pyplot as plt
from dataclasses import dataclass
from enum import Enum

@dataclass
class Node:
    """One state of the timetable search.

    Both attributes are dictionaries (the original ``list`` annotations were
    wrong): ``initial_node_initializer`` and ``assign`` store a dict in
    ``student_info`` and ``do_calendar`` returns a dict for ``calendar``.
    """

    # Student -> list of (Subject, Day) pairs for that student's classes.
    student_info: dict
    # Day -> list of Subjects taught on that day.
    calendar: dict

class Student(Enum):
    """Students who need a timetable; values count up from 0 in declaration order."""

    MARIA, GUS, DIEGO, JOSE = range(4)

class Subject(Enum):
    """Subjects that can be handed out to students; values count up from 0."""

    MATH, HISTORY, PHYSICS, CHEMISTRY, SPANISH, ENGLISH, BIOLOGY = range(7)

class Day(Enum):
    """Days on which a class can be scheduled; values count up from 0."""

    MONDAY, TUESDAY, WEDNESDAY = range(3)

# Classes each student takes. initial_node_initializer() draws that many
# distinct days with random.sample(list(Day), ...), so it must not exceed len(Day).
MAX_CLASSES_PER_STUDENT = 3
# Distinct subjects a day is expected to hold, sized to the number of students.
MAX_CLASSES_PER_DAY = len(list(Student))

def initial_node_initializer():
    """Build a random starting Node.

    For every Student (in declaration order), subjects come from
    ``pick_subjects()`` and are zipped with ``MAX_CLASSES_PER_STUDENT``
    distinct days drawn by ``random.sample``; the calendar is then derived
    from that mapping with ``do_calendar``.
    """
    assignments = {
        # pick_subjects() is evaluated before random.sample(), as before.
        student: list(
            zip(
                pick_subjects(),
                random.sample(list(Day), MAX_CLASSES_PER_STUDENT),
            )
        )
        for student in Student
    }
    return Node(assignments, do_calendar(assignments))

def pick_subjects():
    """Return up to ``MAX_CLASSES_PER_STUDENT`` distinct subjects for one student.

    Subjects are popped from the module-level ``global_subjects`` pool, which
    must already exist before the first call. An empty pool is refilled with
    ``list(Subject)``. When fewer than ``MAX_CLASSES_PER_STUDENT`` remain,
    the remainder is used and the pool is refilled with a freshly shuffled
    ``list(Subject)`` to top the student up.

    Returns:
        A list of Subject members containing no duplicates.
    """
    global global_subjects

    if not global_subjects:
        global_subjects = list(Subject)

    if len(global_subjects) >= MAX_CLASSES_PER_STUDENT:
        random.shuffle(global_subjects)
        return [global_subjects.pop() for _ in range(MAX_CLASSES_PER_STUDENT)]

    # Too few left: take everything that remains, in pop() order (reversed).
    classes = global_subjects[::-1]
    needed = MAX_CLASSES_PER_STUDENT - len(classes)

    # Top up from a freshly shuffled pool. Subjects this student already has
    # are skipped (and stay in the pool); popping blindly could hand the same
    # subject to one student twice.
    global_subjects = list(Subject)
    random.shuffle(global_subjects)
    top_up = [s for s in reversed(global_subjects) if s not in classes][:needed]
    for subject in top_up:
        global_subjects.remove(subject)
    classes.extend(top_up)
    return classes

def do_calendar(student_info):
    """Index every scheduled class by the day it is taught.

    Args:
        student_info: mapping from each Student to its (subject, day) pairs.
            Every Student member must be present, otherwise KeyError is raised.

    Returns:
        A dict with one entry per Day (empty list for days with no classes).
        Each entry lists the subjects for that day, grouped by Student
        declaration order.
    """
    schedule = {}
    for weekday in Day:
        schedule[weekday] = []

    every_class = (
        pair for student in Student for pair in student_info[student]
    )
    for subject, weekday in every_class:
        schedule[weekday].append(subject)
    return schedule

def scoring_function(node: Node):
    """Return a penalty for ``node``; lower is better.

    * Each day adds one point per slot it falls short of
      ``MAX_CLASSES_PER_DAY`` distinct subjects. A day with fewer classes is
      penalised even if nothing on it repeats.
    * Each student adds one point per slot they fall short of
      ``MAX_CLASSES_PER_STUDENT`` distinct days.

    The result is 0 when neither shortfall occurs.
    """
    day_shortfall = sum(
        max(0, MAX_CLASSES_PER_DAY - len(set(subjects)))
        for subjects in node.calendar.values()
    )
    student_shortfall = sum(
        max(0, MAX_CLASSES_PER_STUDENT - len({day for _, day in classes}))
        for classes in node.student_info.values()
    )
    return day_shortfall + student_shortfall

def backtracking_search(initial_node):
    """Find a valid day assignment for the classes chosen in ``initial_node``.

    Each student keeps the subjects listed in ``initial_node.student_info``.
    Only the day of each class is searched over. A schedule is accepted when:

    * no student has two classes on the same day, and
    * no subject is taught twice on the same day.

    Args:
        initial_node: Node whose ``student_info`` maps each student to a list
            of ``(subject, day)`` pairs.

    Returns:
        ``initial_node`` itself if it is already complete and consistent;
        otherwise a new Node with repaired days, or ``None`` if no valid
        assignment exists.
    """
    # The initial node already gives every student a full load, so a search
    # that only asks "is it complete?" would hand it back unchecked. Verify it
    # and fall through to an actual search when it is not valid.
    if is_complete(initial_node) and is_consistent(initial_node):
        return initial_node

    # One variable per class: (student, subject, day originally drawn).
    variables = [
        (student, subject, day)
        for student, classes in initial_node.student_info.items()
        for subject, day in classes
    ]
    assignment = {student: [] for student in initial_node.student_info}
    subjects_on = {day: [] for day in Day}

    def fits(student, subject, day):
        # A student attends at most one class per day ...
        if any(taken == day for _, taken in assignment[student]):
            return False
        # ... and a subject is taught at most once per day.
        return subject not in subjects_on[day]

    def backtrack(index):
        if index == len(variables):
            return True
        student, subject, preferred = variables[index]
        # Try the originally drawn day first so a repair stays close to the input.
        candidates = [preferred] + [day for day in Day if day != preferred]
        for day in candidates:
            if not fits(student, subject, day):
                continue
            assignment[student].append((subject, day))
            subjects_on[day].append(subject)
            if backtrack(index + 1):
                return True
            # Undo this choice and try the next day.
            assignment[student].pop()
            subjects_on[day].pop()
        return False

    if not backtrack(0):
        return None
    return Node(assignment, do_calendar(assignment))

def is_complete(node):
    """Return True when every student has at least MAX_CLASSES_PER_STUDENT classes."""
    return all(
        len(classes) >= MAX_CLASSES_PER_STUDENT
        for classes in node.student_info.values()
    )

def select_unassigned_variable(node):
    """Return the next student that still needs classes, plus their first class.

    Walks ``node.student_info`` in order and, for the first student with
    fewer than MAX_CLASSES_PER_STUDENT classes *and at least one class*,
    returns ``(student, subject, day)`` taken from that student's first class.

    Students with no classes at all are skipped. If no student qualifies, the
    function returns ``None``, so callers that unpack the result into three
    names must handle that case.
    """
    for student, classes in node.student_info.items():
        if len(classes) >= MAX_CLASSES_PER_STUDENT or not classes:
            continue
        first_subject, first_day = classes[0]
        return student, first_subject, first_day
    return None

def order_domain_values(node, student):
    """Return the days on which ``student`` has no class yet.

    Args:
        node: current search state.
        student: key into ``node.student_info``.

    Returns:
        A list of Day members in declaration (Monday-first) order.
    """
    taken = {day for _, day in node.student_info[student]}
    # Iterate Day instead of calling sorted() on a set of days: plain Enum
    # members define no ordering, so sorted() raises TypeError as soon as
    # two days have to be compared.
    return [day for day in Day if day not in taken]

def assign(node, student, subject, day, value):
    """Return a new Node with ``(subject, value)`` added to ``student``'s classes.

    ``node`` is not modified: its ``student_info`` is deep-copied before the
    new pair is appended, and the calendar is rebuilt from the copy. The
    ``day`` argument is not used; ``value`` is the day recorded for the class.
    """
    updated_info = copy.deepcopy(node.student_info)
    student_classes = updated_info[student]
    student_classes.append((subject, value))
    rebuilt_calendar = do_calendar(updated_info)
    return Node(updated_info, rebuilt_calendar)

def is_consistent(node):
    """Check the day and student thresholds of a schedule.

    Returns True only if every day in ``node.calendar`` has at least
    MAX_CLASSES_PER_DAY distinct subjects *and* every student's classes fall
    on at least MAX_CLASSES_PER_STUDENT distinct days.

    These are thresholds on a full schedule, so a partially filled schedule
    is reported as inconsistent.
    """
    every_day_full = all(
        len(set(subjects)) >= MAX_CLASSES_PER_DAY
        for subjects in node.calendar.values()
    )
    if not every_day_full:
        return False
    return all(
        len({day for _, day in classes}) >= MAX_CLASSES_PER_STUDENT
        for classes in node.student_info.values()
    )

# Run the backtracking search once on a random starting schedule and report it.
# pick_subjects() draws from this module-level pool, so it must exist before
# initial_node_initializer() is called.
global_subjects = list(Subject)
init = initial_node_initializer()

# perf_counter() is monotonic and high-resolution; time.time() follows the
# wall clock, which can jump, and is not meant for timing short intervals.
start_time = time.perf_counter()
result = backtracking_search(init)
end_time = time.perf_counter()

if result is None:
    print("No solution found.")
else:
    print("Solution found:")
    for student, subjects in result.student_info.items():
        print(f"{student.name}:")
        for subject, day in subjects:
            print(f"  {subject.name} on {day.name}")
        print()

print("Execution time:", end_time - start_time)


Solution found:
MARIA:
  ENGLISH on WEDNESDAY
  BIOLOGY on TUESDAY
  HISTORY on MONDAY

GUS:
  SPANISH on TUESDAY
  PHYSICS on MONDAY
  CHEMISTRY on WEDNESDAY

DIEGO:
  MATH on WEDNESDAY
  ENGLISH on MONDAY
  HISTORY on TUESDAY

JOSE:
  CHEMISTRY on MONDAY
  PHYSICS on TUESDAY
  SPANISH on WEDNESDAY

Execution time: 0.0005233287811279297
