In [339]:
!pip3 install dataclasses
!pip3 install ortools

import copy
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import List, Dict, Tuple, Iterator, Union

from ortools.sat.python import cp_model
from ortools.sat.python.cp_model import IntVar, CpModel, CpSolver



In [340]:
# Identifier of a monkey (a plain string name).
Monkey = str

# Full roster. Every name here gets one boolean decision variable per (pool, shift).
monkeys: List[Monkey] = [
    'one',
    'two',
    'three',
    'four',
    'five',
    'six',
    'seven',
    'eight',
    'nine',
    'old1',
    'old2',
    'old3'
]

# Scheduling window; shifts are generated for every day from start_date through end_date.
start_date = datetime(2022, 3, 1)
end_date = datetime(2022, 4, 30)

# Number of parallel pools; every shift is staffed by exactly one monkey in each pool.
pools_num = 2

# Per pool: after working a shift, how many of the following shifts the monkey may not take in
# the same pool. Counted in shifts, not calendar days (a Thu-Sat weekend shift counts as one).
pool_to_min_rest_days = [
    6,
    2
]

# Same kind of gap, but applied to the following shifts in a *different* pool.
rest_days_between_pools = 3

# monkey -> entries; each entry is either a single datetime (must be the start date of a shift)
# or a (first_day, last_day) tuple selecting the shifts that lie within that range.
AvailabilityInfoSchema = Dict[Monkey, List[Union[datetime, Tuple[datetime, datetime]]]]

# Hard: the monkey is never assigned these shifts (in any pool).
unavailability_info: AvailabilityInfoSchema = {
    'one': [
        (datetime(2022, 3, 1), datetime(2022, 3, 14))
    ]
}

# Soft: assignments listed here are meant to be discouraged through the solver objective,
# not forbidden outright.
soft_unavailability_info: AvailabilityInfoSchema = {
    'two': [
    ]
}

# Hard: the monkey is forced onto these shifts in the first pool (pool 0).
preference_info: AvailabilityInfoSchema = {
    'three': [
        datetime(2022, 3, 2)
    ]
}

# Soft: assignments listed here (first pool only) are rewarded through the solver objective.
soft_preference_info: AvailabilityInfoSchema = {
    'four': [
        (datetime(2022, 3, 15), datetime(2022, 4, 17))
    ]
}
    
# Carry-over balance per shift kind: monkey -> one int per pool; missing monkeys start at 0.
# A negative value may raise that monkey's lower bound to the per-monkey upper bound so it
# catches up (see distribute_shifts_evenly).
weekday_balance = {}
weekend_balance = {'eight': [-1, 0], 'nine': [-1, 0]}
holiday_balance = {}

# Monkey classes. 'max_shifts' maps a ShiftKind value to one entry per pool:
#   None -> the class shares the remaining shifts evenly with the other None classes;
#   int  -> each monkey of the class works exactly that many shifts of that kind in that pool.
monkey_classes = [
    {
        'id': 'regular', 
        'monkeys': [
            'one', 
            'two', 
            'three', 
            'four', 
            'five', 
            'six', 
            'seven', 
            'eight', 
            'nine'
        ], 
        'max_shifts': {
            'Weekend': [None, None],
            'Weekday': [None, None],
            'Holiday': [None, None]
        }
    },
    {
        'id': 'old', 
        'monkeys': [
            'old1', 
            'old2', 
            'old3'
        ],
        'max_shifts': {
            'Weekend': [0, None],
            'Weekday': [2, None],
            'Holiday': [0, None]
        }
    }
]

# These monkeys may not take the shift that directly follows a weekend or holiday shift.
# hard_enforce_after_holiday: True -> hard constraint; False -> soft (objective) constraint.
monkeys_that_cannot_occupy_shift_right_after_holiday = ['six', 'seven']
hard_enforce_after_holiday = True

In [341]:
def date_range(start_date: datetime, end_date: datetime, include_end_date: bool = True) -> Iterator[datetime]:
    """
    Lazily yield consecutive days, one day apart, beginning at ``start_date``.

    :param start_date: first day yielded
    :param end_date: last day of the range
    :param include_end_date: when true (the default) ``end_date`` itself is yielded as well
    :return: iterator of datetimes; empty when ``end_date`` precedes ``start_date``
    """
    day_count = (end_date - start_date).days
    day_count += int(include_end_date)

    offset = 0
    while offset < day_count:
        yield start_date + timedelta(offset)
        offset += 1

In [342]:
class ShiftKind(Enum):
    """Category of a shift; the values double as the keys of monkey_classes[...]['max_shifts']."""
    WEEKDAY = 'Weekday'
    WEEKEND = 'Weekend'
    HOLIDAY = 'Holiday'
    
@dataclass(frozen=True)
class Shift:
    """
    One schedulable shift covering the days [start_date, end_date) -- end_date is exclusive.

    Frozen (immutable) because Shift instances are used as dictionary keys (e.g. shift_pools,
    date_to_shift); mutating a field after it has been hashed would silently corrupt lookups.
    eq defaults to True, so frozen=True also generates the field-based __hash__.
    """
    start_date: datetime  # first day covered by the shift
    end_date: datetime  # day after the last covered day (exclusive)
    kind: ShiftKind

In [343]:
THURSDAY = 3
FRIDAY = 4
SATURDAY = 5
WEEKEND = [FRIDAY, SATURDAY]

def create_shifts(start_date: datetime, end_date: datetime) -> List[Shift]:
    """
    Build the chronologically ordered shifts covering ``start_date`` through ``end_date``.

    - A Thursday opens a 3-day weekend shift (Thu, Fri, Sat); its ``end_date`` is exclusive.
    - Fridays and Saturdays are skipped: the Thursday weekend shift already covers them.
    - Every other day gets a 1-day weekday shift.

    No holiday shifts are produced here.
    :param start_date: first day to schedule
    :param end_date: last day to schedule (inclusive)
    :return: list of Shift objects ordered by start_date
    """
    def _shift_starting_on(day: datetime) -> Union[Shift, None]:
        # Returns None for days that are already covered by the preceding weekend shift.
        weekday = day.weekday()
        if weekday == THURSDAY:
            return Shift(start_date=day, end_date=day + timedelta(days=3), kind=ShiftKind.WEEKEND)
        if weekday in WEEKEND:
            return None
        return Shift(start_date=day, end_date=day + timedelta(days=1), kind=ShiftKind.WEEKDAY)

    candidates = (_shift_starting_on(day) for day in date_range(start_date, end_date))
    return [candidate for candidate in candidates if candidate is not None]

In [344]:
class ShiftKindInput:
    """Solver input for one ShiftKind: its shifts plus every monkey's per-pool carry-over balance."""

    def __init__(self, shift_kind: ShiftKind, all_shifts: List[Shift], monkeys: List[Monkey],
                 balance: Union[Dict[Monkey, List[int]], None] = None):
        """
        :param shift_kind: the kind of shift this input describes
        :param all_shifts: every shift in the schedule; only those of ``shift_kind`` are kept
        :param monkeys: the full roster; every monkey gets a balance entry
        :param balance: optional carry-over balance, monkey -> list with one int per pool.
            Monkeys missing from it start at 0 in every pool (the list length comes from the
            module-level ``pools_num``).
        """
        self.shift_kind = shift_kind
        self.shifts = [shift for shift in all_shifts if shift.kind is shift_kind]
        balance = balance or {}
        # Copy each per-pool list. The balance is incremented in place after solving; without a
        # copy that would also mutate the caller's config dict (e.g. weekend_balance), so
        # re-running the cells would start from an already-incremented balance.
        self.balance = {
            monkey: list(balance.get(monkey, [0] * pools_num))
            for monkey in monkeys
        }

    def get_stats(self, pool: int) -> Tuple[int, int, Dict[Monkey, int]]:
        """
        Summarize the balance of one pool.

        :param pool: pool index
        :return: (lowest balance, highest balance, monkey -> balance minus the lowest balance);
            (0, 0, {}) when there are no monkeys
        """
        pool_values = [value[pool] for value in self.balance.values()]
        min_shifts = min(pool_values, default=0)
        max_shifts = max(pool_values, default=0)
        normalized_balance = {key: value[pool] - min_shifts for key, value in self.balance.items()}

        return min_shifts, max_shifts, normalized_balance

    def print_stats(self, pool: int):
        """Print the kind name followed by the min / max / normalized balance of ``pool``."""
        min_shifts, max_shifts, normalized_balance = self.get_stats(pool)

        rows = [
            self.shift_kind.value,
            f'  - min: {min_shifts}',
            f'  - max: {max_shifts}',
            f'  - balance: {normalized_balance}'
        ]
        print('\n'.join(rows))

In [345]:
all_shifts = create_shifts(start_date=start_date, end_date=end_date)
# Index the shifts by their first day; single-date availability entries are resolved through it.
date_to_shift = dict((s.start_date, s) for s in all_shifts)

In [346]:
def date_to_shifts_by_range(shifts: List[Shift], start_date: datetime, end_date: datetime = None) -> List[Shift]:
    """
    Return the shifts that lie entirely within the inclusive day range ``start_date``..``end_date``.

    A shift covers the days [shift.start_date, shift.end_date): its end_date is exclusive (the
    solution printout iterates it with include_end_date=False). A shift therefore fits when it
    starts on/after ``start_date`` and its last covered day is on/before ``end_date``.

    :param shifts: candidate shifts
    :param start_date: first day of the range (inclusive)
    :param end_date: last day of the range (inclusive); defaults to ``start_date`` (a single day)
    :return: the matching shifts, in their original order
    """
    end_date = end_date or start_date
    # Shift.end_date is exclusive, so compare it against the day *after* the inclusive range end.
    # Comparing against end_date directly dropped the shift starting on the last day and made the
    # single-day (end_date=None) case match nothing.
    range_end_exclusive = end_date + timedelta(days=1)
    return [
        shift for shift in shifts
        if start_date <= shift.start_date and shift.end_date <= range_end_exclusive
    ]

def transform_availability(availability_info: AvailabilityInfoSchema) -> Dict[Monkey, List[Shift]]:
    """
    Resolve user-facing availability entries into concrete shifts for every monkey.

    Each entry is either a (first_day, last_day) tuple, expanded over ``all_shifts`` with
    ``date_to_shifts_by_range``, or a single datetime looked up in ``date_to_shift`` (it must be
    the start date of a shift). Any other single date prints a warning and is skipped.

    :param availability_info: monkey -> list of dates / date ranges
    :return: mapping with an entry (possibly empty) for every monkey in ``monkeys``
    """
    resolved = {name: [] for name in monkeys}

    for name, entries in availability_info.items():
        for entry in entries:
            if isinstance(entry, tuple):
                resolved[name] += date_to_shifts_by_range(
                    shifts=all_shifts,
                    start_date=entry[0],
                    end_date=entry[1]
                )
            elif entry in date_to_shift:
                resolved[name].append(date_to_shift[entry])
            else:
                print(f'Warning: {entry} not in range')

    return resolved
            
# Resolve each availability config into monkey -> [Shift]; calls happen in this exact order.
unavailability, soft_unavailability, preference, soft_preference = (
    transform_availability(info)
    for info in (unavailability_info, soft_unavailability_info, preference_info, soft_preference_info)
)


In [347]:
# One ShiftKindInput per shift kind, each seeded with the matching carry-over balance from the
# config cell. holiday_balance is passed too; it was defined in the config but previously ignored.
shift_kind_to_input = {
    ShiftKind.WEEKDAY: ShiftKindInput(
        ShiftKind.WEEKDAY,
        all_shifts=all_shifts,
        monkeys=monkeys,
        balance=weekday_balance
    ),
    ShiftKind.WEEKEND: ShiftKindInput(
        ShiftKind.WEEKEND,
        all_shifts=all_shifts,
        monkeys=monkeys,
        balance=weekend_balance,
    ),
    ShiftKind.HOLIDAY: ShiftKindInput(
        ShiftKind.HOLIDAY,
        all_shifts=all_shifts,
        monkeys=monkeys,
        balance=holiday_balance,
    ),
}

In [348]:
model = CpModel()

# Decision variables: shift_pools[pool, monkey, shift] is a boolean that is true iff `monkey`
# takes `shift` in pool `pool`. Variables are created pool-major, then monkey, then shift.
shift_pools: Dict[Tuple[int, Monkey, Shift], IntVar] = {
    (pool_index, monkey, shift): model.NewBoolVar(f'shift_{pool_index}_{monkey}_{shift}')
    for pool_index in range(pools_num)
    for monkey in monkeys
    for shift in all_shifts
}

In [349]:
# Every shift is staffed by exactly one monkey in each pool.
for shift in all_shifts:
    for pool_index in range(pools_num):
        staffing = [shift_pools[pool_index, monkey, shift] for monkey in monkeys]
        model.Add(sum(staffing) == 1)

In [350]:
# A monkey can take a given shift in at most one pool (no double-booking across pools).
for shift in all_shifts:
    for monkey in monkeys:
        across_pools = [shift_pools[p, monkey, shift] for p in range(pools_num)]
        model.Add(sum(across_pools) <= 1)

In [351]:
# Availability / Preference

# The shift that directly follows a weekend or holiday shift is made unavailable to every monkey
# in monkeys_that_cannot_occupy_shift_right_after_holiday: as a hard constraint when
# hard_enforce_after_holiday is True, otherwise as a soft (objective-penalized) one.
after_holiday_target = unavailability if hard_enforce_after_holiday else soft_unavailability
for previous_shift, shift in zip(all_shifts, all_shifts[1:]):
    if previous_shift.kind not in (ShiftKind.HOLIDAY, ShiftKind.WEEKEND):
        continue

    for monkey in monkeys_that_cannot_occupy_shift_right_after_holiday:
        # Skip shifts already present so re-running this cell does not duplicate entries
        # (duplicates would double-count soft penalties).
        if shift not in after_holiday_target[monkey]:
            after_holiday_target[monkey].append(shift)

# unavailability (hard): the monkey may not take the shift in any pool.
for monkey, monkey_shifts in unavailability.items():
    for shift in monkey_shifts:
        for pool_index in range(pools_num):
            model.Add(shift_pools[pool_index, monkey, shift] == 0)

# soft-unavailability: count assignments the monkey would rather avoid (any pool).
num_unavailable = 0
for monkey, monkey_shifts in soft_unavailability.items():
    for shift in monkey_shifts:
        for pool_index in range(pools_num):
            num_unavailable += shift_pools[pool_index, monkey, shift]

# preference (hard, first pool only): force the monkey onto the shift.
for monkey, monkey_shifts in preference.items():
    for shift in monkey_shifts:
        model.Add(shift_pools[0, monkey, shift] == 1)

# soft-preference (first pool only): count desired assignments.
num_prefer = 0
for monkey, monkey_shifts in soft_preference.items():
    for shift in monkey_shifts:
        num_prefer += shift_pools[0, monkey, shift]

# A CP-SAT model holds a single objective: each Minimize()/Maximize() call replaces the previous
# one. Both soft terms must therefore be combined into one objective; setting them separately
# would let the second call discard soft unavailability entirely.
model.Maximize(num_prefer - num_unavailable)

In [352]:
def get_min_and_max_shifts(
        shifts: List[Shift],
        shift_kind: ShiftKind,
        max_shifts: Union[int, None], 
        pool_index: int, 
        monkey_class: Dict, 
        monkey_classes: List[Dict]
):
    """
    Compute the (lower, upper) bound on how many ``shift_kind`` shifts each monkey of
    ``monkey_class`` works in pool ``pool_index``.

    If ``max_shifts`` is given (the class has an explicit quota), both bounds equal it. Otherwise
    every other class's fixed quota (quota * class size) is subtracted from ``len(shifts)``. The
    remainder is split evenly over this class plus every other class without a quota for this
    kind/pool. When the split is uneven, the upper bound is one above the lower.

    :return: tuple (min_shifts_per_monkey, max_shifts_per_monkey)
    """
    if max_shifts is not None:
        return max_shifts, max_shifts

    remaining_shifts = len(shifts)
    sharing_monkeys = len(monkey_class['monkeys'])

    others = (candidate for candidate in monkey_classes if candidate['id'] != monkey_class['id'])
    for other in others:
        other_quota = other['max_shifts'][shift_kind.value][pool_index]
        other_size = len(other['monkeys'])
        if other_quota is None:
            sharing_monkeys += other_size
        else:
            remaining_shifts -= other_quota * other_size

    base, leftover = divmod(remaining_shifts, sharing_monkeys)
    upper = base + 1 if leftover else base
    return base, upper

def distribute_shifts_evenly(
        model: CpModel,
        shift_pools: Dict[Tuple[int, Monkey, Shift], IntVar],
        pools_num: int,  # not read by this function; pools come from monkey_class['max_shifts']
        shift_kind_to_input: Dict[ShiftKind, ShiftKindInput],
        monkey_class: Dict,
        monkey_classes: List[Dict],
):
    """
    Add per-monkey lower/upper bounds on how many shifts of each kind every monkey of
    ``monkey_class`` works, separately for each pool.

    Bounds come from get_min_and_max_shifts. An explicit quota in ``monkey_class['max_shifts']``
    pins both bounds to that value; otherwise shifts are split evenly and, when the split is
    uneven, the upper bound is one higher than the lower. A monkey with a negative carry-over
    balance may have its lower bound raised to the upper bound so it catches up by one shift.

    :param model: CP-SAT model the constraints are added to
    :param shift_pools: (pool, monkey, shift) -> boolean assignment variable
    :param pools_num: unused here
    :param shift_kind_to_input: per-kind shifts and balances
    :param monkey_class: the class being constrained (keys 'id', 'monkeys', 'max_shifts')
    :param monkey_classes: all classes; used to account for other classes' quotas
    """
    
    # Keys of 'max_shifts' are ShiftKind values ('Weekday', ...); each value is a per-pool list.
    for shift_kind, max_shifts_definition in monkey_class['max_shifts'].items():
        shift_kind = ShiftKind(shift_kind)
        shift_kind_input = shift_kind_to_input[shift_kind]
        
        for pool_index, max_shifts in enumerate(max_shifts_definition):
            min_shifts_per_monkey, max_shifts_per_monkey = get_min_and_max_shifts(
                shifts=shift_kind_input.shifts,
                shift_kind=shift_kind,
                max_shifts=max_shifts, 
                pool_index=pool_index, 
                monkey_class=monkey_class, 
                monkey_classes=monkey_classes
            ) 
            
            # How many monkeys of this class (so far) were granted the raised lower bound.
            need_balance_monkey_num = 0
            
            for monkey in monkey_class['monkeys']:
                num_shifts_worked = 0
                
                
                # Improve balance of unbalanced monkeys (at most by one)
                # NOTE(review): this capacity check counts the current monkey at
                # min_shifts_per_monkey rather than max_shifts_per_monkey, and compares against the
                # total number of shifts of this kind instead of what remains after other classes'
                # fixed quotas (as get_min_and_max_shifts computes). It may admit more catch-up
                # monkeys than there are spare shifts, making the model infeasible -- confirm.
                if (
                    shift_kind_input.balance[monkey][pool_index] < 0 
                    and (
                        min_shifts_per_monkey * (len(monkey_class['monkeys']) - need_balance_monkey_num) 
                        # The number of shifts monkeys that are already balanced need to do
                        
                        <= 
                        
                        (len(shift_kind_input.shifts) - (need_balance_monkey_num * max_shifts_per_monkey))
                        # The number of total shifts without the already unbalanced who need to do max_shifts 
                    )
                ):
                    need_balance_monkey_num += 1
                    min_shifts = max_shifts_per_monkey
                else:
                    min_shifts = min_shifts_per_monkey
        
        
                # Linear expression: number of shifts of this kind the monkey works in this pool.
                for index, shift in enumerate(shift_kind_input.shifts):
                    num_shifts_worked += shift_pools[pool_index, monkey, shift]

                model.Add(min_shifts <= num_shifts_worked)
                model.Add(num_shifts_worked <= max_shifts_per_monkey)
            
# Apply the even-distribution bounds class by class.
for current_class in monkey_classes:
    distribute_shifts_evenly(
        model=model,
        shift_pools=shift_pools,
        pools_num=pools_num,
        shift_kind_to_input=shift_kind_to_input,
        monkey_class=current_class,
        monkey_classes=monkey_classes,
    )

In [353]:
# Minimum gap after each assignment, counted in *shifts* (a weekend shift counts as one):
#   - pool_to_min_rest_days[pool] following shifts may not be taken in the same pool;
#   - rest_days_between_pools following shifts may not be taken in any other pool.
for monkey in monkeys:
    # One timeline of assignment variables per pool, in shift order. Built once per monkey
    # because it does not depend on the pool being constrained.
    timelines = [
        [shift_pools[pool, monkey, shift] for shift in all_shifts]
        for pool in range(pools_num)
    ]

    for pool_index, own_timeline in enumerate(timelines):
        for i, work in enumerate(own_timeline):
            for other_pool_index, other_timeline in enumerate(timelines):
                if other_pool_index == pool_index:
                    gap = pool_to_min_rest_days[pool_index]
                else:
                    gap = rest_days_between_pools
                # Slicing clamps at the end of the timeline, so no explicit bound is needed.
                following = other_timeline[i + 1: i + 1 + gap]
                model.AddBoolAnd([later.Not() for later in following]).OnlyEnforceIf(work)
            
            

In [354]:
solver = CpSolver()
status = solver.Solve(model)
if status != cp_model.OPTIMAL:
    print('No optimal solution found !')
else:
    # Accumulate this solution's assignments into a deep copy of the inputs. Incrementing
    # shift_kind_to_input in place would make the cell non-idempotent: every re-run would add the
    # same assignments again (and, through lists shared with the config, change *_balance too).
    solved_shift_kind_to_input = copy.deepcopy(shift_kind_to_input)

    print('Solution:')
    for shift in all_shifts:
        row_parts = []

        # Collect the assigned monkey of each pool, in pool order.
        for pool_index in range(pools_num):
            for monkey in monkeys:
                if not solver.Value(shift_pools[pool_index, monkey, shift]) == 1:
                    continue

                solved_shift_kind_to_input[shift.kind].balance[monkey][pool_index] += 1
                row_parts.append(monkey)

        # One printed row per calendar day the shift covers (Shift.end_date is exclusive).
        for date in date_range(shift.start_date, shift.end_date, include_end_date=False):
            print(f'{date.strftime("%b-%d")} | {date.strftime("%a")} | {" | ".join(row_parts)}')

    print('\n\nStatistics:')

    for pool_index in range(pools_num):
        print()
        print(f'-- Pool {pool_index + 1} --')
        for shift_kind_input in solved_shift_kind_to_input.values():
            shift_kind_input.print_stats(pool_index)

    print('\nSolution Statistics:')
    print('  - conflicts: %i' % solver.NumConflicts())
    print('  - branches : %i' % solver.NumBranches())
    print('  - wall time: %f s' % solver.WallTime())


Solution:
Mar-01 | Tue | nine | seven
Mar-02 | Wed | three | old3
Mar-03 | Thu | eight | six
Mar-04 | Fri | eight | six
Mar-05 | Sat | eight | six
Mar-06 | Sun | four | old1
Mar-07 | Mon | old2 | nine
Mar-08 | Tue | old3 | seven
Mar-09 | Wed | five | eight
Mar-10 | Thu | six | old1
Mar-11 | Fri | six | old1
Mar-12 | Sat | six | old1
Mar-13 | Sun | nine | three
Mar-14 | Mon | two | old2
Mar-15 | Tue | seven | old1
Mar-16 | Wed | four | six
Mar-17 | Thu | one | nine
Mar-18 | Fri | one | nine
Mar-19 | Sat | one | nine
Mar-20 | Sun | eight | five
Mar-21 | Mon | three | old3
Mar-22 | Tue | six | two
Mar-23 | Wed | old2 | seven
Mar-24 | Thu | five | eight
Mar-25 | Fri | five | eight
Mar-26 | Sat | five | eight
Mar-27 | Sun | four | two
Mar-28 | Mon | one | old3
Mar-29 | Tue | seven | old2
Mar-30 | Wed | eight | five
Mar-31 | Thu | two | old3
Apr-01 | Fri | two | old3
Apr-02 | Sat | two | old3
Apr-03 | Sun | nine | one
Apr-04 | Mon | six | old1
Apr-05 | Tue | four | three
Apr-06 | Wed | five 