In [1]:
!pip3 install dataclasses
!pip3 install ortools

from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import List, Dict, Tuple, Iterator, Union

from ortools.sat.python import cp_model
from ortools.sat.python.cp_model import IntVar, CpModel, CpSolver



In [85]:
Monkey = str

monkeys: List[Monkey] = [
    'one',
    'two',
    'three',
    'four',
    'five',
    'six',
    'seven',
    'eight',
    'nine'
]

start_date = datetime(2022, 3, 1)
end_date = datetime(2022, 4, 30)

pools_num = 2

pool_to_min_rest_days = [
    6,
    2
]

rest_days_between_pools = 2

AvailabilityInfoSchema = Dict[Monkey, List[Union[datetime, Tuple[datetime, datetime]]]]

unavailability_info: AvailabilityInfoSchema = {
    'one': [
        (datetime(2022, 3, 1), datetime(2022, 3, 14))
    ]
}

soft_unavailability_info: AvailabilityInfoSchema = {
    'two': [
    ]
}

preference_info: AvailabilityInfoSchema = {
    'three': [
        datetime(2022, 3, 2)
    ]
}

soft_preference_info: AvailabilityInfoSchema = {
    'four': [
        (datetime(2022, 3, 15), datetime(2022, 4, 17))
    ]
}
    
weekday_balance = {}
weekend_balance = {'eight': -1, 'nine': -1}
holiday_balance = {}

In [86]:
def date_range(start_date: datetime, end_date: datetime, include_end_date: bool = True) -> Iterator[datetime]:
    """
    Creates an iterator that yields all dates between two given dates
    :param start_date:
    :param end_date:
    :param include_end_date
    :return:
    """
    for n in range(int((end_date - start_date).days) + int(include_end_date)):
        yield start_date + timedelta(n)

In [87]:
class ShiftKind(Enum):
    WEEKDAY = 'Weekday'
    WEEKEND = 'Weekend'
    HOLIDAY = 'Holiday'
    
@dataclass(eq=True, unsafe_hash=True)
class Shift:
    start_date: datetime
    end_date: datetime
    kind: ShiftKind

In [88]:
THURSDAY = 3
FRIDAY = 4
SATURDAY = 5
WEEKEND = [FRIDAY, SATURDAY]

def create_shifts(start_date: datetime, end_date: datetime) -> List[Shift]:
    """
    Creates a list of all shifts between two dates (considers weekdays, weekends and holidays)
    :param start_date:
    :param end_date:
    :return:
    """
    shifts = []

    for date in date_range(start_date, end_date):
        if date.weekday() == THURSDAY:
            shifts.append(Shift(
                start_date=date,
                end_date=date + timedelta(days=3),
                kind=ShiftKind.WEEKEND
            ))
            continue

        if date.weekday() in WEEKEND:
            continue

        shifts.append(Shift(
            start_date=date,
            end_date=date + timedelta(days=1),
            kind=ShiftKind.WEEKDAY
        ))

    return shifts

In [89]:
class ShiftKindInput:
    def __init__(self, shift_kind: ShiftKind, all_shifts: List[Shift], monkeys: List[Monkey],
                 balance: Dict[Monkey, int] = None):
        self.shift_kind = shift_kind
        self.shifts = list(filter(lambda shift: shift.kind is self.shift_kind, all_shifts))
        balance = balance or {}
        self.balance = {monkey: balance.get(monkey, 0) for monkey in monkeys}

    def get_stats(self) -> Tuple[int, int, Dict[Monkey, int]]:
        min_shifts = min(self.balance.values())
        max_shifts = max(self.balance.values())
        normalized_balance = {key: value - min_shifts for key, value in self.balance.items()}

        return min_shifts, max_shifts, normalized_balance

    def print_stats(self):
        min_shifts, max_shifts, normalized_balance = self.get_stats()

        rows = [
            self.shift_kind.value,
            f'  - min: {min_shifts}',
            f'  - max: {max_shifts}',
            f'  - balance: {normalized_balance}'
        ]
        print('\n'.join(rows))

In [90]:
all_shifts = create_shifts(start_date=start_date, end_date=end_date)
date_to_shift = {shift.start_date: shift for shift in all_shifts}

In [91]:
def date_to_shifts_by_range(shifts: List[Shift], start_date: datetime, end_date: datetime = None) -> List[Shift]:
    """
    Filters a list of shifts to only the ones that are between two given days
    :param shifts:
    :param start_date:
    :param end_date:
    :return:
    """
    end_date = end_date or start_date
    return list(filter(lambda shift: start_date <= shift.start_date and shift.end_date <= end_date, shifts))

def transform_availability(availability_info: AvailabilityInfoSchema) -> Dict[Monkey, List[Shift]]:   
    new_availability_info = {monkey: [] for monkey in availability_info}
    for monkey, availabilities in availability_info.items():
        availabilities = availabilities.copy()
        
        for availability in availabilities:
            if isinstance(availability, tuple):
                new_availability_info[monkey] += date_to_shifts_by_range(
                    shifts=all_shifts,
                    start_date=availability[0],
                    end_date=availability[1]
                )
                
                continue
            
            if availability not in date_to_shift:
                print(f'Warning: {availability} not in range')
                continue
                
            new_availability_info[monkey].append(date_to_shift[availability])
    
    return new_availability_info
            
unavailability = transform_availability(unavailability_info)
soft_unavailability = transform_availability(soft_unavailability_info)
preference = transform_availability(preference_info)
soft_preference = transform_availability(soft_preference_info)


In [92]:
shift_kind_to_input = {
    ShiftKind.WEEKDAY: ShiftKindInput(
        ShiftKind.WEEKDAY,
        all_shifts=all_shifts,
        monkeys=monkeys,
        balance=weekday_balance
    ),
    ShiftKind.WEEKEND: ShiftKindInput(
        ShiftKind.WEEKEND,
        all_shifts=all_shifts,
        monkeys=monkeys,
        balance=weekend_balance,
    ),
    ShiftKind.HOLIDAY: ShiftKindInput(
        ShiftKind.HOLIDAY,
        all_shifts=all_shifts,
        monkeys=monkeys,
    ),
}

In [93]:
model = CpModel()

shift_pools: Dict[Tuple[int, Monkey, Shift], IntVar] = {}

# Set up the shifts dataset. This is a dict that resolves a combination of a monkey and a shift-slot to the boolean
# variable that describes the possibility that the monkey will be assigned to that shift.
for pool_index in range(pools_num):
    for monkey in monkeys:
        for shift in all_shifts:
            shift_pools[pool_index, monkey, shift] = model.NewBoolVar(f'shift_{pool_index}_{monkey}_{shift}')            

In [94]:
# Each shift will be assigned to exactly one monkey.
for shift in all_shifts:
    for pool_index in range(pools_num):
        model.Add(sum(shift_pools[pool_index, monkey, shift] for monkey in monkeys) == 1)

In [95]:
# Each monkey will be assigned to at most one shift each day.
for shift in all_shifts:
    for monkey in monkeys:
        model.Add(sum(shift_pools[pool_index, monkey, shift] for pool_index in range(pools_num)) <= 1)

In [96]:
# Availability / Preference

# unavailability
for monkey, monkey_shifts in unavailability.items():
    for shift in monkey_shifts:
        for pool_index in range(pools_num):
            model.Add(shift_pools[pool_index, monkey, shift] == 0)

# soft-unavailability
num_unavailable = 0
for monkey, monkey_shifts in soft_unavailability.items():
    for shift in monkey_shifts:
        for pool_index in range(pools_num):
            num_unavailable += shift_pools[pool_index, monkey, shift]

model.Minimize(num_unavailable)

# preference (first pool only)
for monkey, monkey_shifts in preference.items():
    for shift in monkey_shifts:
        model.Add(shift_pools[0, monkey, shift] == 1)

# soft-preference (first pool only)
num_prefer = 0
for monkey, monkey_shifts in soft_preference.items():
    for shift in monkey_shifts:
        num_prefer += shift_pools[0, monkey, shift]

model.Maximize(num_prefer)

In [97]:
def distribute_shifts_evenly(
        model: CpModel,
        shift_pools: Dict[Tuple[int, Monkey, Shift], IntVar],
        pools_num: int,
        all_shifts: List[Shift],
        monkeys: List[Monkey],
        monkey_balance: Dict[Monkey, int],
):
    """
    Try to distribute day shifts evenly, so that each monkey works
    min_shifts_per_monkey shifts. If this is not possible, because the total
    number of shifts is not divisible by the number of nurses, some monkeys will
    be assigned one more shift.
    :param model:
    :param shifts:
    :param all_shifts:
    :param monkeys:
    :param monkey_balance:
    :return:
    """

    num_shifts = len(all_shifts)
    num_monkeys = len(monkeys)
    min_shifts_per_monkey = num_shifts // num_monkeys
    if num_shifts % num_monkeys == 0:
        max_shifts_per_monkey = min_shifts_per_monkey
    else:
        max_shifts_per_monkey = min_shifts_per_monkey + 1
    
    for pool_index in range(pools_num):
        for monkey in monkeys:
            # TODO: monkey_balance.get(monkey, 0)
            # But what if monkey balance is too low to pass min_shifts_per_monkey
            # And, still better that if 

            num_shifts_worked = 0 

            for index, shift in enumerate(all_shifts):
                num_shifts_worked += shift_pools[pool_index, monkey, shift]

            model.Add(min_shifts_per_monkey <= num_shifts_worked)
            model.Add(num_shifts_worked <= max_shifts_per_monkey)

# Distribute shifts evenly for each shift kind
for shift_kind, shift_kind_input in shift_kind_to_input.items():
    distribute_shifts_evenly(
        model=model,
        shift_pools=shift_pools,
        pools_num=pools_num,
        all_shifts=shift_kind_input.shifts,
        monkeys=monkeys,
        monkey_balance=shift_kind_input.balance
    )

In [98]:
# There will be at least x rest-days between each shift (determined by pool_to_min_rest_days)
for monkey in monkeys:
    for pool_index in range(pools_num):
        works = [shift_pools[pool_index, monkey, shift] for shift in all_shifts]
        
        other_pool_works = [
            [shift_pools[other_pool_index, monkey, shift] for shift in all_shifts]
            for other_pool_index in range(pools_num)
        ]
        
        for i, work in enumerate(works):
            for other_pool_index, other_works in enumerate(other_pool_works):
                actual_min = min(
                    (
                        pool_to_min_rest_days[pool_index] 
                        if other_pool_index == pool_index 
                        else rest_days_between_pools
                    ), 
                    len(works) - 1 - i
                )
                
                model.AddBoolAnd([
                    sequence_work.Not()
                    for sequence_work in other_works[i + 1: i + 1 + actual_min]
                ]).OnlyEnforceIf(work)
            
            

In [99]:
# # There will be at least rest_days_between_pools rest-days between shifts in different pools
# for monkey in monkeys:
#     for pool_index in range(pools_num):
#         works = []
#         for other_pool_index in range(pools_num):
#             works.append(shift_pools[pool_index, monkey, shift] for shift in all_shifts)

#             for i, work in enumerate(works):
#                 actual_min = min(rest_days_between_pools, len(works) - 1 - i)

#                 model.AddBoolAnd([
#                     sequence_work.Not()
#                     for sequence_work in works[i + 1: i + 1 + actual_min]
#                 ]).OnlyEnforceIf(work)

In [100]:
solver = CpSolver()
status = solver.Solve(model)

In [101]:
for i in range(pools_num):
    print(pool_to_min_rest_days[i])

6
2


In [102]:
status

4

In [103]:
solver = CpSolver()
status = solver.Solve(model)
if status != cp_model.OPTIMAL:
    print('No optimal solution found !')
else:
    print('Solution:')
    for shift in all_shifts:
        row_parts = []
        
        for pool_index in range(pools_num):
            for monkey in monkeys:
                if not solver.Value(shift_pools[pool_index, monkey, shift]) == 1:
                    continue

                shift_kind_to_input[shift.kind].balance[monkey] += 1
                row_parts.append(monkey)
                

        for date in date_range(shift.start_date, shift.end_date, include_end_date=False):
            print(f'{date.strftime("%b-%d")} | {date.strftime("%a")} | {" | ".join(row_parts)}')

    print('\n\nStatistics:')

    for shift_kind_input in shift_kind_to_input.values():
        shift_kind_input.print_stats()

    print('\nSolution Statistics:')
    print('  - conflicts: %i' % solver.NumConflicts())
    print('  - branches : %i' % solver.NumBranches())
    print('  - wall time: %f s' % solver.WallTime())


Solution:
Mar-01 | Tue | seven | four
Mar-02 | Wed | three | six
Mar-03 | Thu | nine | five
Mar-04 | Fri | nine | five
Mar-05 | Sat | nine | five
Mar-06 | Sun | four | eight
Mar-07 | Mon | two | three
Mar-08 | Tue | five | six
Mar-09 | Wed | eight | nine
Mar-10 | Thu | seven | two
Mar-11 | Fri | seven | two
Mar-12 | Sat | seven | two
Mar-13 | Sun | three | five
Mar-14 | Mon | one | six
Mar-15 | Tue | nine | seven
Mar-16 | Wed | four | eight
Mar-17 | Thu | five | three
Mar-18 | Fri | five | three
Mar-19 | Sat | five | three
Mar-20 | Sun | six | nine
Mar-21 | Mon | two | one
Mar-22 | Tue | seven | four
Mar-23 | Wed | three | nine
Mar-24 | Thu | eight | one
Mar-25 | Fri | eight | one
Mar-26 | Sat | eight | one
Mar-27 | Sun | four | two
Mar-28 | Mon | nine | seven
Mar-29 | Tue | six | three
Mar-30 | Wed | five | one
Mar-31 | Thu | two | four
Apr-01 | Fri | two | four
Apr-02 | Sat | two | four
Apr-03 | Sun | three | nine
Apr-04 | Mon | eight | one
Apr-05 | Tue | four | two
Apr-06 | Wed | ni