<a href="https://colab.research.google.com/github/ikaru55/DocumentGPT/blob/main/Shift.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random

# Environment parameters
NUM_DAYS = 30
NUM_PERSONNEL = 10
MAX_CONSECUTIVE_OFF = 4
REQUIRED_OFF = 12  # fixed number of off days per person
MIN_PD_DAYS = 4   # minimum number of PD duty days per person

class DutySchedulerEnv:
    """Day-by-day duty-roster environment driven by integer bitmask actions.

    Each call to ``step`` assigns people to RD, PD and CQ duty for the
    current day, updates the per-person counters and returns a reward.

    Args:
        cq_schedule: Mapping ``day -> list of person ids`` that must be on
            CQ duty on that day.
        leave_schedule: Mapping ``day -> list of person ids`` on leave on
            that day; they are granted an off day regardless of the action.
    """

    def __init__(self, cq_schedule, leave_schedule):
        self.current_day = 0
        self.personnel = {i: self._new_record() for i in range(NUM_PERSONNEL)}
        self.cq_schedule = cq_schedule
        self.leave_schedule = leave_schedule
        self.daily_duties = []  # one (rd, pd, cq) tuple per accepted day

    @staticmethod
    def _new_record():
        """Return a fresh per-person counter record with everything zeroed."""
        return {
            'off_count': 0,
            'rd_count': 0,    # number of RD duty days
            'pd_count': 0,    # number of PD duty days
            'last_off_days': 0,
            'cq_next_day_off': False
        }

    def get_state(self):
        """Return the observation as a flat float32 array.

        The array holds five values per person (off_count, rd_count,
        pd_count, last_off_days, cq_next_day_off as 0/1) followed by the
        current day index.
        """
        state = []
        for p in self.personnel.values():
            state.extend([
                p['off_count'],
                p['rd_count'],
                p['pd_count'],
                p['last_off_days'],
                int(p['cq_next_day_off'])
            ])
        return np.array(state + [self.current_day], dtype=np.float32)

    def reset(self):
        """Reset the day counter, all personnel records and the duty log.

        Returns:
            The initial observation (see ``get_state``).
        """
        self.current_day = 0
        for p in self.personnel.values():
            # Update in place so references to the record dicts stay valid.
            p.update(self._new_record())
        self.daily_duties = []
        return self.get_state()

    def _action_to_schedule(self, action):
        """Decode an integer action bitmask into today's duty lists.

        Bit layout (bit 0 is the least significant bit), with
        N = NUM_PERSONNEL: bits 0..N-1 select RD, bits N..2N-1 select PD,
        bits 2N..3N-1 select CQ.

        Returns:
            Tuple ``(current_day, rd_personnel, pd_personnel, cq_personnel)``
            where each list holds person ids in ascending order.
        """
        rd_personnel = []
        pd_personnel = []
        cq_personnel = []

        for i in range(NUM_PERSONNEL):
            if action & (1 << i):
                rd_personnel.append(i)
            if action & (1 << (i + NUM_PERSONNEL)):
                pd_personnel.append(i)
            if action & (1 << (i + 2 * NUM_PERSONNEL)):
                cq_personnel.append(i)

        return self.current_day, rd_personnel, pd_personnel, cq_personnel

    def step(self, action):
        """Apply one day's assignment.

        An action is rejected (reward -100, episode ends, no state change)
        when RD does not have exactly 4 people, when the day has a CQ
        schedule entry that the CQ assignment does not match, or when a
        person is assigned to more than one duty.

        Args:
            action: Integer bitmask, see ``_action_to_schedule``.

        Returns:
            Tuple ``(next_state, reward, done)``.
        """
        day, rd_personnel, pd_personnel, cq_personnel = self._action_to_schedule(action)
        rd_set, pd_set, cq_set = set(rd_personnel), set(pd_personnel), set(cq_personnel)
        reward = 0
        done = False

        # Basic validation: RD always needs exactly 4 people.
        valid = len(rd_personnel) == 4
        # NOTE(review): CQ assignments on days without a schedule entry are
        # accepted as-is — confirm whether they should be rejected instead.
        if day in self.cq_schedule and cq_set != set(self.cq_schedule[day]):
            valid = False

        # Duplicate-assignment check: nobody may hold two duties on one day.
        if len(rd_set | pd_set | cq_set) != len(rd_personnel) + len(pd_personnel) + len(cq_personnel):
            valid = False

        if not valid:
            return self.get_state(), -100, True

        # Record the accepted assignment.
        self.daily_duties.append((rd_personnel, pd_personnel, cq_personnel))

        # Update personnel states.
        on_leave = self.leave_schedule.get(day, ())
        for p_id in range(NUM_PERSONNEL):
            person = self.personnel[p_id]

            # Rest owed from CQ duty on an earlier day is consumed first:
            # the person is off today whatever the action says, and the flag
            # is cleared so the rest day is granted exactly once and the day
            # is not additionally counted as RD/PD duty.
            if person['cq_next_day_off']:
                person['cq_next_day_off'] = False
                self._grant_off(p_id)
                continue

            if p_id in on_leave:
                self._grant_off(p_id)
                continue

            if p_id in rd_set:
                person['rd_count'] += 1
                self._reset_consecutive_off(p_id)
            elif p_id in pd_set:
                person['pd_count'] += 1
                self._reset_consecutive_off(p_id)
            elif p_id in cq_set:
                # CQ duty earns a rest day that is granted on the next step.
                person['cq_next_day_off'] = True
            else:
                self._grant_off(p_id)

        # Per-day shaping reward.
        reward += self._calculate_rewards()

        self.current_day += 1
        if self.current_day >= NUM_DAYS:
            done = True
            reward += self._final_validation()

        return self.get_state(), reward, done

    def _grant_off(self, p_id):
        """Give ``p_id`` an off day, or a PD day once the off quota is full."""
        if self.personnel[p_id]['off_count'] < REQUIRED_OFF:
            self.personnel[p_id]['off_count'] += 1
            self.personnel[p_id]['last_off_days'] += 1
        else:
            # The REQUIRED_OFF quota is already used up, so the day is
            # counted as PD duty instead.
            self.personnel[p_id]['pd_count'] += 1
            self.personnel[p_id]['last_off_days'] = 0

    def _reset_consecutive_off(self, p_id):
        """End the consecutive-off streak of ``p_id``.

        If a post-CQ rest day is still pending it is granted and the flag is
        cleared instead of resetting the streak.
        """
        if self.personnel[p_id]['cq_next_day_off']:
            self._grant_off(p_id)
            self.personnel[p_id]['cq_next_day_off'] = False
        else:
            self.personnel[p_id]['last_off_days'] = 0

    def _calculate_rewards(self):
        """Return the per-day shaping reward summed over all personnel."""
        reward = 0

        for p in self.personnel.values():
            # Consecutive-off bonus / penalty.
            if 1 < p['last_off_days'] <= MAX_CONSECUTIVE_OFF:
                reward += p['last_off_days'] * 2
            elif p['last_off_days'] > MAX_CONSECUTIVE_OFF:
                reward -= 20

            # Bonus once the minimum number of PD days is reached.
            if p['pd_count'] >= MIN_PD_DAYS:
                reward += 5

        return reward

    def _final_validation(self):
        """Return the end-of-episode reward summed over all personnel."""
        reward = 0

        for p in self.personnel.values():
            # Bonus for exactly REQUIRED_OFF off days, penalty per day of deviation.
            if p['off_count'] == REQUIRED_OFF:
                reward += 50
            else:
                reward -= abs(p['off_count'] - REQUIRED_OFF) * 20

            # Day-total check (off + RD + PD, plus one for a pending CQ rest day).
            # NOTE(review): CQ duty days themselves are not added to any
            # counter, so a person who served CQ is penalised here — confirm
            # whether that is intended.
            total_days = p['off_count'] + p['rd_count'] + p['pd_count']
            if p['cq_next_day_off']:
                total_days += 1

            if total_days != NUM_DAYS:
                reward -= abs(NUM_DAYS - total_days) * 15

            # Penalty for each PD day short of the minimum.
            if p['pd_count'] < MIN_PD_DAYS:
                reward -= (MIN_PD_DAYS - p['pd_count']) * 10

        return reward

    def get_schedule_summary(self):
        """Return a summary of the schedule built so far.

        Returns:
            Dict with ``'personnel_stats'`` (per-person counters copied out
            of the records) and ``'daily_duties'`` (the accepted assignments).
        """
        summary = {
            'personnel_stats': {
                i: {
                    'off_days': p['off_count'],
                    'rd_days': p['rd_count'],
                    'pd_days': p['pd_count'],
                    'cq_next_off': p['cq_next_day_off']
                } for i, p in self.personnel.items()
            },
            'daily_duties': self.daily_duties
        }
        return summary

class DQN(nn.Module):
    """Three-layer fully connected network mapping a state to per-action values.

    Args:
        input_size: Number of input features.
        output_size: Number of outputs (one per action).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(input_size, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, output_size),
        ]
        # Stored as ``fc`` so parameter names in the state dict stay the same.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the raw outputs."""
        q_values = self.fc(x)
        return q_values

class Agent:
    """Epsilon-greedy DQN agent with a replay buffer and a target network.

    Args:
        state_size: Length of the state vector fed to the network.
        action_size: Number of discrete actions (network output size).
    """

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=10000)
        self.gamma = 0.95
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DQN(state_size, action_size).to(self.device)
        self.target_model = DQN(state_size, action_size).to(self.device)
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)

    def act(self, state):
        """Pick an action: random with probability ``epsilon``, else greedy.

        Returns:
            The chosen action index as a Python int.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)

        with torch.no_grad():
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.model(state)
            return torch.argmax(q_values).item()

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def replay(self, batch_size):
        """Run one gradient step on a random minibatch and decay epsilon.

        Does nothing when ``batch_size`` is not positive or the buffer holds
        fewer than ``batch_size`` transitions.
        """
        if batch_size < 1 or len(self.memory) < batch_size:
            return

        minibatch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)

        # Stack the per-transition arrays into one ndarray first; building a
        # tensor straight from a list of ndarrays is slow and makes PyTorch warn.
        states = torch.from_numpy(np.asarray(states, dtype=np.float32)).to(self.device)
        next_states = torch.from_numpy(np.asarray(next_states, dtype=np.float32)).to(self.device)
        actions = torch.as_tensor(actions, dtype=torch.int64, device=self.device)
        rewards = torch.as_tensor(rewards, dtype=torch.float32, device=self.device)
        dones = torch.as_tensor(dones, dtype=torch.float32, device=self.device)

        # squeeze(1) rather than squeeze(): with a batch of one, a bare
        # squeeze() would also drop the batch dimension and the loss would be
        # computed on mismatched shapes.
        current_q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            next_q_values = self.target_model(next_states).max(1)[0]

        # Terminal transitions (done == 1) do not bootstrap from the next state.
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        loss = nn.MSELoss()(current_q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

# Example run: train the agent on a fixed CQ / leave schedule and print the
# per-person statistics of the highest-reward episode.
if __name__ == "__main__":
    cq_schedule = {5: [3], 15: [7]}
    leave_schedule = {10: [2], 20: [5]}

    env = DutySchedulerEnv(cq_schedule, leave_schedule)
    state_size = len(env.reset())
    # NOTE(review): this is 2**30 actions, and DQN's last layer is
    # Linear(256, action_size), i.e. about 2**38 weights — likely far too
    # large to allocate. Confirm, and consider a factored action space.
    action_size = 2 ** (NUM_PERSONNEL * 3)  # 10 bits each for RD, PD and CQ

    agent = Agent(state_size, action_size)
    batch_size = 64
    episodes = 1000
    target_update_frequency = 10

    best_reward = float('-inf')
    best_schedule = None

    for e in range(episodes):
        state = env.reset()
        total_reward = 0

        for day in range(NUM_DAYS):
            action = agent.act(state)
            next_state, reward, done = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward

            # Train once the replay buffer holds at least one full batch.
            if len(agent.memory) >= batch_size:
                agent.replay(batch_size)

            if done:
                break

        # Keep the summary of the best episode seen so far.
        if total_reward > best_reward:
            best_reward = total_reward
            best_schedule = env.get_schedule_summary()

        # Sync the target network every target_update_frequency episodes.
        if e % target_update_frequency == 0:
            agent.update_target_model()

        print(f"Episode: {e+1}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")

    # Print the best schedule found
    if best_schedule:
        print("\nBest Schedule Found:")
        for pid, stats in best_schedule['personnel_stats'].items():
            print(f"\nPersonnel {pid}:")
            print(f"  Off Days: {stats['off_days']}")
            print(f"  RD Days: {stats['rd_days']}")
            print(f"  PD Days: {stats['pd_days']}")
            print(f"  CQ Next Off: {stats['cq_next_off']}")
            total_days = stats['off_days'] + stats['rd_days'] + stats['pd_days'] + int(stats['cq_next_off'])
            print(f"  Total Days: {total_days}")