In [None]:
import random

In [7]:
class ExamTimetableEnv:
    """Search environment for a small exam-timetabling problem.

    A *state* is a dict mapping every course name to a ``(timeslot, room)``
    tuple.  The environment provides a random starting state, neighbour
    generation (re-assigning a single course) and a cost function that
    penalises same-slot conflicts and room overflow.

    Attributes:
        courses: dict ``{course: size}``.
        rooms: list of ``(room, capacity)`` pairs.
        conflicts: dict ``{(c1, c2): overlap}``.
        timeslots: list of timeslot labels.
    """

    # Defaults kept as class attributes so subclasses can tune them.
    DEFAULT_TIMESLOTS = ("T1", "T2")
    CLASH_PENALTY = 10      # cost per unit of overlap sharing a timeslot
    OVERFLOW_PENALTY = 5    # cost per seat over a room's capacity

    def __init__(self, courses, rooms, conflicts, timeslots=None):
        """Store the problem data.

        Args:
            courses: dict ``{course: size}``.
            rooms: list of ``(room, capacity)`` pairs.
            conflicts: dict ``{(c1, c2): overlap}``.
            timeslots: optional iterable of timeslot labels; defaults to
                ``["T1", "T2"]`` when omitted.
        """
        self.courses = courses  # {course: size}
        self.rooms = rooms      # [(room,capacity)]
        self.conflicts = conflicts  # {(c1,c2): overlap}
        if timeslots is None:
            timeslots = self.DEFAULT_TIMESLOTS
        self.timeslots = list(timeslots)

    def _placements(self):
        """Return every ``(timeslot, room)`` pair, timeslot-major order."""
        return [(t, r) for t in self.timeslots for r, _ in self.rooms]

    def initial_state(self):
        """Return a state with a random timeslot and room for each course."""
        return {
            c: (random.choice(self.timeslots), random.choice(self.rooms)[0])
            for c in self.courses
        }

    def neighbours(self, state):
        """Return all states differing from *state* in one course's placement.

        The input state is not modified; each neighbour is a fresh dict.
        """
        placements = self._placements()
        neighs = []
        for c in self.courses:
            current = state[c]
            for placement in placements:
                if placement != current:
                    new = dict(state)
                    new[c] = placement
                    neighs.append(new)
        return neighs

    def random_neighbour(self, state):
        """Return a copy of *state* with one random course moved elsewhere.

        The new placement is drawn only from placements that differ from the
        course's current one, so the result is a genuine neighbour (matching
        what ``neighbours`` enumerates).  If no alternative placement exists
        (e.g. one timeslot and one room), an unchanged copy is returned.
        """
        c = random.choice(list(self.courses))
        current = state[c]
        alternatives = [p for p in self._placements() if p != current]
        new = dict(state)
        if alternatives:
            new[c] = random.choice(alternatives)
        return new

    def cost(self, state):
        """Return the total penalty of *state* (0 means no violations).

        Two terms are summed:
          * for every conflict pair placed in the same timeslot,
            ``CLASH_PENALTY * overlap``;
          * for every (timeslot, room) whose assigned course sizes exceed
            the room capacity, ``OVERFLOW_PENALTY * excess``.

        Raises:
            KeyError: if the state refers to a course or room that the
                environment does not know.
        """
        total = 0
        for (c1, c2), overlap in self.conflicts.items():
            if state[c1][0] == state[c2][0]:
                total += self.CLASH_PENALTY * overlap

        # First listed capacity wins if a room name appears more than once.
        capacity = {}
        for room, cap in self.rooms:
            capacity.setdefault(room, cap)

        # Single pass: accumulate seats used per (timeslot, room).
        used = {}
        for course, placement in state.items():
            used[placement] = used.get(placement, 0) + self.courses[course]

        for (_, room), seats in used.items():
            cap = capacity[room]
            if seats > cap:
                total += self.OVERFLOW_PENALTY * (seats - cap)
        return total

    @staticmethod
    def random_instance(n_courses=5, n_rooms=2, max_size=50, seed=None):
        """Build a random problem instance.

        Courses are named ``C1..Cn`` and rooms ``R1..Rm``.  Course sizes are
        drawn from ``[10, max_size]`` and room capacities from
        ``[20, max_size]``; the lower bounds are clamped to ``max_size`` so
        small values of ``max_size`` do not produce an empty range.  Up to
        ``n_courses`` random conflict pairs are generated (a repeated pair
        overwrites the earlier one); none are generated when there are fewer
        than two courses.

        Note: this calls ``random.seed(seed)``, i.e. it reseeds the global
        ``random`` module generator as a side effect.

        Args:
            n_courses: number of courses to create.
            n_rooms: number of rooms to create.
            max_size: upper bound for course sizes and room capacities.
            seed: value passed to ``random.seed``.

        Returns:
            A new ``ExamTimetableEnv``.
        """
        random.seed(seed)
        min_course = min(10, max_size)
        min_room = min(20, max_size)
        courses = {
            f"C{i+1}": random.randint(min_course, max_size)
            for i in range(n_courses)
        }
        rooms = [
            (f"R{j+1}", random.randint(min_room, max_size))
            for j in range(n_rooms)
        ]
        conflicts = {}
        cl = list(courses)
        if len(cl) >= 2:  # a conflict needs two distinct courses
            for _ in range(n_courses):
                c1, c2 = random.sample(cl, 2)
                # Keep the upper bound at least 1 so randint has a valid range.
                max_overlap = max(1, min(courses[c1], courses[c2]) // 2)
                conflicts[(c1, c2)] = random.randint(1, max_overlap)
        return ExamTimetableEnv(courses, rooms, conflicts)
