In [None]:
# !rm -r /kaggle/working
# !git clone https://github.com/ce-muzzamil/semiconductor_fabrication_scheduling.git
# !cp -v -r /kaggle/working/semiconductor_fabrication_scheduling/* /kaggle/working/
# !rm -r /kaggle/working/semiconductor_fabrication_scheduling

# import sys
# sys.path.append("/kaggle/working/")

In [None]:
import numpy as np
from simulation.file_instance import FileInstance
from simulation.read import read_all
from simulation.dispatching.dispatcher import dispatcher_map

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.optim import Adam

from logger import Logger
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
class SCFabEnv:
    """Step-wise dispatching environment wrapped around the simulator's ``FileInstance``.

    At every simulator decision point the usable machines are bucketed by
    machine family, and the lots waiting at a family are bucketed by the name
    of their current step ("lot groups").  A family with at least as many
    usable machines as lot groups is dispatched automatically; a family with
    fewer machines than lot groups is "conflicting" and is handed to the
    agent, which picks one lot group per call to ``step``.

    An observation is a float32 array with one row per lot group of a single
    conflicting family; each row is 13 machine-family features followed by
    17 lot-group features (30 columns).
    """

    def __init__(self, dataset, days=1, dispatcher='fifo', seed=0):
        """Load the dataset files; the simulator itself is created by ``reset``.

        Args:
            dataset: directory name below ``datasets/`` that is passed to ``read_all``.
            days: simulated horizon in days.
            dispatcher: key into ``dispatcher_map``; the entry is stored but not
                used anywhere else in this class.
            seed: stored as ``seed_val``; nothing in this class seeds an RNG with it.
        """
        self.files = read_all('datasets/' + dataset)
        self.instance = None
        self.days = days
        self.seed_val = seed
        self.dispatcher = dispatcher_map[dispatcher]

        # Maps missing values (None or empty string) to 0.0, leaves anything else alone.
        self.none2zero = lambda x: 0.0 if x is None or x == '' else x

    def reset(self, hard=True):
        """Reset the episode statistics and, when ``hard``, rebuild the simulator.

        Args:
            hard: if False only ``throughput`` and ``tardiness`` are zeroed and
                ``None`` is returned; if True a fresh ``FileInstance`` is created.

        Returns:
            For a hard reset, the observation of the first conflicting family,
            or ``None`` when the first decision points produced no conflict.
            The no-conflict path used to return ``step``'s whole
            ``(obs, reward, done, info)`` tuple; it now returns the observation
            only, matching the conflicting path.
        """
        self.throughput = 0
        self.tardiness = 0
        if not hard:
            return None

        run_to = 3600 * 24 * self.days
        self.instance = FileInstance(self.files, run_to, True, [])
        self.num_lots_done = 0
        self.lots_done = []
        self.lots_dispatched = []

        # groups: machine group -> set of families; families: family -> its group.
        self.groups = {}
        self.families = {}
        for machine in self.instance.machines:
            self.groups.setdefault(machine.group, set()).add(machine.family)
            self.families.setdefault(machine.family, machine.group)

        # Stable integer codes for the categorical machine attributes.
        locations = sorted({m.loc for m in self.instance.machines})
        self.loc_2_index = {loc: index for index, loc in enumerate(locations)}
        self.group_2_index = {group: index for index, group in enumerate(sorted(self.groups))}

        self.instance.next_decision_point()
        self.eid = np.random.randint(999_999_999)  # random id used to label this episode
        machines = self.preprocess()
        self.get_state(machines)

        for family in self.conflicting_machines:
            return self.state_tensor(family)

        obs, _, _, _ = self.step(None)
        return obs

    @property
    def get_machines_t(self):
        """Usable machines at the current time, bucketed by family.

        Returns:
            dict mapping family -> ``{"counts", "machines", "lots_groups",
            "conflicting"}``.  ``lots_groups`` maps step name -> lots waiting at
            the family's first machine, in first-seen order, so the position of
            a lot group (and therefore the meaning of an action index) no
            longer depends on set iteration order.
        """
        machines = {}
        for machine in self.instance.usable_machines:
            entry = machines.setdefault(machine.family, {"counts": 0, "machines": []})
            entry["counts"] += 1
            entry["machines"].append(machine)

        for entry in machines.values():
            lots_groups = {}
            # Only the first machine's queue is read for the whole family.
            for lot in entry["machines"][0].waiting_lots:
                lots_groups.setdefault(lot.actual_step.step_name, []).append(lot)
            entry["lots_groups"] = lots_groups
            # Conflict: fewer usable machines than distinct lot groups.
            entry["conflicting"] = entry["counts"] < len(lots_groups)
        return machines

    def dispatch_non_conflicting(self, machines):
        """Dispatch every lot group of each non-conflicting family to its own machine."""
        for entry in machines.values():
            if entry["conflicting"]:
                continue
            # counts >= number of groups here, so zip pairs every group with a machine.
            for machine, lot_group in zip(entry["machines"], entry["lots_groups"].values()):
                self.instance.dispatch(machine, lot_group)

    def preprocess(self):
        """Snapshot the usable machines and dispatch the non-conflicting families."""
        self.machines = self.get_machines_t
        self.dispatch_non_conflicting(self.machines)
        return self.machines

    def get_state(self, machines):
        """Store and return the subset of ``machines`` flagged as conflicting."""
        conflicting_machines = {
            family: entry for family, entry in machines.items() if entry["conflicting"]
        }
        self.conflicting_machines = conflicting_machines
        return conflicting_machines

    def step(self, action):
        """Apply one agent decision, or advance the simulator when nothing is pending.

        Args:
            action: index of the lot group to dispatch for the first pending
                conflicting family; ignored when no family is pending.

        Returns:
            ``(obs, reward, done, info)``.  ``obs`` is the next observation or
            ``None`` when no family needs a decision.  ``reward`` is always 0
            here.  ``info`` holds ``"time"``, ``"done_lots"`` (lot objects
            finished since the last call that were dispatched through this
            method) and ``"dispatched_lots"`` (lot ids dispatched by this call).
            ``done`` is now propagated on every return path; it used to be
            reported as ``False`` when the episode ended without a conflict.
        """
        # Drop families whose machines were all used by earlier decisions.
        for family in list(self.conflicting_machines):
            if self.conflicting_machines[family]["counts"] == 0:
                self.conflicting_machines.pop(family)

        info = {"time": self.instance.current_time, "done_lots": [], "dispatched_lots": []}

        for lot in self.instance.done_lots[self.num_lots_done:]:
            # Lots that only ever went through automatic dispatch are not counted.
            if lot.idx in self.lots_dispatched:
                self.lots_done.append(lot.idx)
                info['done_lots'].append(lot)
                self.throughput += 1
                # Counts late lots (0/1 per lot), not a duration.
                lateness_hours = 0 if lot.deadline_at >= lot.done_at else 1
                self.tardiness += lateness_hours
        self.num_lots_done = len(self.instance.done_lots)

        done = False
        if len(self.conflicting_machines) == 0:
            # NOTE(review): the conflict set is computed before the simulator is
            # advanced to the next decision point; confirm this order is intended.
            machines = self.preprocess()
            self.get_state(machines)
            done = bool(self.instance.next_decision_point())
            done = done or self.instance.current_time > 3600 * 24 * self.days
            for family in self.conflicting_machines:
                return self.state_tensor(family), 0, done, info
        else:
            # Only the first pending family is served per call.
            family = next(iter(self.conflicting_machines))
            entry = self.conflicting_machines[family]
            for i, step_name in enumerate(entry["lots_groups"]):
                if i == action:
                    machine = entry["machines"].pop()
                    lot_group = entry["lots_groups"].pop(step_name)
                    entry["counts"] -= 1
                    lot_ids = [lot.idx for lot in lot_group]
                    info["dispatched_lots"].extend(lot_ids)
                    self.lots_dispatched.extend(lot_ids)
                    self.instance.dispatch(machine, lot_group)
                    break

            for family in self.conflicting_machines:
                if self.conflicting_machines[family]["counts"] > 0:
                    return self.state_tensor(family), 0, False, info

        return None, 0, done, info

    def state_tensor(self, family):
        """Build the observation for one conflicting family.

        Returns:
            float32 array of shape ``(number of lot groups, 30)``: the same 13
            machine-family features repeated on every row, followed by 17
            features of that row's lot group.
        """
        entry = self.conflicting_machines[family]
        units = entry["machines"]
        group = self.families[family]

        def mean_of(attr):
            return np.mean([getattr(m, attr) for m in units])

        runs_left = [self.none2zero(m.min_runs_left) for m in units]
        machine_features = np.array([
            len(units),                        # num_units
            self.group_2_index[group],         # group_idx
            len(self.groups[group]),           # num_machine_families_in_group
            mean_of("load_time") / 3600,       # load_time_hr
            mean_of("unload_time") / 3600,     # unload_time_hr
            self.loc_2_index[units[0].loc],    # loc_idx
            len(units[0].waiting_lots),        # num_waiting_lots
            mean_of("utilized_time"),
            mean_of("setuped_time"),
            mean_of("pmed_time"),
            mean_of("bred_time"),
            np.max(runs_left),                 # min_runs_left_max
            np.min(runs_left),                 # min_runs_left_min
        ])

        now = self.instance.current_time
        rows = []
        for lot_group in entry["lots_groups"].values():
            slack_hr = [(lot.deadline_at - now) / 3600 for lot in lot_group]
            rel_deadline_hr = [lot.relative_deadline / 3600 for lot in lot_group]
            waited_hr = [(now - lot.free_since) / 3600 for lot in lot_group]
            steps_left = [len(lot.remaining_steps) for lot in lot_group]
            ratios = [lot.cr(now) for lot in lot_group]
            priorities = [lot.priority for lot in lot_group]
            step = lot_group[0].actual_step
            # 1 when the group's step needs a setup the first machine does not have.
            needs_setup = 0 if step.setup_needed == '' or step.setup_needed == units[0].current_setup else 1

            lot_group_features = [
                len(lot_group),
                np.mean(slack_hr), np.max(slack_hr),
                np.mean(rel_deadline_hr), np.max(rel_deadline_hr),
                np.mean(waited_hr), np.max(waited_hr),
                np.mean(steps_left), np.max(steps_left),
                np.mean(ratios), np.max(ratios),
                np.mean(priorities), np.max(priorities),
                step.processing_time.avg(),
                step.batch_max,
                step.batch_min,
                needs_setup,
            ]
            rows.append(np.concatenate([machine_features, np.array(lot_group_features)]))

        return np.stack(rows, axis=0).astype("float32")

In [None]:
class LearnablePositionalEncoding(nn.Module):
    """Adds a trainable, randomly initialised offset to every sequence position.

    The parameter has shape ``(1, seq_len, d_model)``; inputs are indexed as
    ``(batch, position, feature)`` and may be shorter than ``seq_len``.
    """

    def __init__(self, seq_len, d_model):
        super().__init__()
        initial_offsets = torch.randn(1, seq_len, d_model)
        self.pos_embedding = nn.Parameter(initial_offsets)

    def forward(self, x):
        # Use only as many positional rows as the input has positions.
        length = x.size(1)
        offsets = self.pos_embedding[:, :length]
        return x + offsets

class FeatureExtractor(nn.Module):
    """Learnable positional encoding followed by a batch-first Transformer encoder stack."""

    def __init__(self, seq_len, input_dim, num_layers=4, nhead=8, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.pos_encoder = LearnablePositionalEncoding(seq_len, input_dim)

        # One layer is configured here and passed to nn.TransformerEncoder with num_layers.
        layer_config = dict(
            d_model=input_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_config),
            num_layers=num_layers,
        )

    def forward(self, x):
        # N,L,E -> N,L,E
        return self.transformer_encoder(self.pos_encoder(x))

class EMB(nn.Module):
    """Per-position MLP that maps ``embed_size`` features back to ``embed_size``.

    Layout: two ``Linear -> Dropout -> GELU`` stages of width ``hdim`` and a
    final ``Linear`` back to ``embed_size``.
    """

    def __init__(self, embed_size, hdim, drp=0.1):
        super().__init__()
        stages = []
        for in_features in (embed_size, hdim):
            stages += [nn.Linear(in_features, hdim), nn.Dropout(drp), nn.GELU()]
        stages.append(nn.Linear(hdim, embed_size))
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        # N,L,E -> N,L,E
        return self.mlp(x)


class Actor(nn.Module):
    """Policy head: scores every sequence position with a single logit."""

    def __init__(self, embed_size, hdim, drp=0.1):
        super().__init__()
        hidden_stack = [
            nn.Linear(embed_size, hdim), nn.Dropout(drp), nn.GELU(),
            nn.Linear(hdim, hdim), nn.Dropout(drp), nn.GELU(),
        ]
        score = nn.Linear(hdim, 1)
        self.mlp = nn.Sequential(*hidden_stack, score)

    def forward(self, x):
        # N,L,E -> N,L,1 -> N,L
        scores = self.mlp(x)
        return scores.squeeze(-1)
    
class critic(nn.Module):
    """Value head: averages the sequence over its positions and maps it to one scalar."""

    def __init__(self, embed_size, hdim, drp=0.1):
        super().__init__()

        def stage(in_features):
            # Linear -> Dropout -> GELU, always hdim wide on the output side.
            return nn.Linear(in_features, hdim), nn.Dropout(drp), nn.GELU()

        self.mlp = nn.Sequential(*stage(embed_size), *stage(hdim), nn.Linear(hdim, 1))

    def forward(self, x):
        # N,L,E -> N,E (mean over L) -> N,1 -> N
        pooled = x.mean(1)
        value = self.mlp(pooled)
        return value.squeeze(-1)

class Model(nn.Module):
    """Actor-critic network: linear embedding, per-position MLP, Transformer encoder, two heads.

    ``forward`` accepts a batched ``(N, L, input_dim)`` tensor or a single
    ``(L, input_dim)`` sequence; for a single sequence the added batch axis is
    removed again from both outputs.
    """

    def __init__(self, input_dim, embed_size, seq_len, num_enc, num_heads, hdim, drp=0.1):
        super().__init__()

        self.embedding = nn.Linear(input_dim, embed_size)
        self.fe = EMB(embed_size, hdim, drp)
        # Positional order: seq_len, input_dim, num_layers, nhead, dim_feedforward, dropout.
        self.feature_extractor = FeatureExtractor(seq_len, embed_size, num_enc, num_heads, hdim, drp)
        self.actor = Actor(embed_size, hdim, drp)
        self.critic = critic(embed_size, hdim, drp)

    def forward(self, x):
        single_sequence = x.ndim == 2
        batch = x.unsqueeze(0) if single_sequence else x

        embedded = self.fe(self.embedding(batch))
        encoded = self.feature_extractor(embedded)  # (N, L, E)
        logits = self.actor(encoded)
        values = self.critic(encoded)

        if not single_sequence:
            return logits, values
        return logits.squeeze(0), values.squeeze(0)

In [None]:
def collect_rollout(first_obs, env, model, *args, rollout_len=2048):
    """Run the policy in ``env`` and append the transitions to the shared buffers.

    Args:
        first_obs: observation to start from, or ``None`` when the environment
            has no pending decision (the environment is then stepped with
            ``None`` until it produces one).
        env: object with ``step(action) -> (obs, reward, done, info)`` where
            ``info`` has ``"done_lots"`` and ``"dispatched_lots"``.
        model: callable mapping an observation tensor to ``(logits, value)``.
        *args: the eight buffers ``obs_buf, action_buf, reward_buf, done_buf,
            logp_buf, value_buf, info_buf, used_indices``; the first seven are
            extended in place.
        rollout_len: stop once more than this many lots have been reported done.

    Returns:
        The seven buffers, the buffer positions that received reward during
        this call (each listed once), and the last observation (``None`` when
        the episode ended).

    Fixes over the previous version:
        * a ``None`` observation no longer reads an unbound or stale ``action``;
          ``env.step(None)`` is called instead.
        * each (dispatch step, lot) pair is rewarded once.  Credited lot ids
          are recorded under ``"credited_lots"`` in the dispatching step's info
          dict, because only the previous call's indices are passed in and
          older steps used to be rewarded again on later calls.
    """
    obs_buf, action_buf, reward_buf, done_buf, logp_buf, value_buf, info_buf, used_indices = args
    last_used_indices = set(used_indices)
    used_indices = []

    obs = first_obs
    counter = 0
    while counter <= rollout_len:
        action = dist = value = None
        if obs is not None:
            with torch.no_grad():
                logits, value = model(torch.from_numpy(obs))
            dist = Categorical(probs=F.softmax(logits, dim=0))
            action = dist.sample()

        next_obs, _, done, info = env.step(None if action is None else action.item())

        if action is not None:
            action_buf.append(action)
            logp_buf.append(dist.log_prob(action))
            value_buf.append(value.squeeze(-1))
            # Filled in later, once the dispatched lots are reported done.
            reward_buf.append(torch.tensor(0, dtype=torch.float32))
        else:
            # Placeholders keep every buffer aligned with info_buf.
            action_buf.append(None)
            logp_buf.append(None)
            value_buf.append(None)
            reward_buf.append(None)

        obs_buf.append(obs)
        done_buf.append(done)
        info_buf.append(info)

        counter += len(info["done_lots"])

        obs = next_obs
        if done:
            obs = None
            break

    # Credit every finished lot to the earlier steps that dispatched it.
    newly_used = set()
    for i, step_info in enumerate(info_buf):
        for lot in step_info["done_lots"]:
            for j in range(i):
                if j in last_used_indices:
                    continue
                dispatch_info = info_buf[j]
                if lot.idx not in dispatch_info["dispatched_lots"]:
                    continue
                credited = dispatch_info.setdefault("credited_lots", set())
                if lot.idx in credited:
                    continue
                credited.add(lot.idx)
                # Hours of slack at completion; negative when the lot is late.
                reward_buf[j] += (lot.deadline_at - lot.done_at) / 3600
                if j not in newly_used:
                    newly_used.add(j)
                    used_indices.append(j)

    last_obs = obs
    return obs_buf, action_buf, reward_buf, done_buf, logp_buf, value_buf, info_buf, used_indices, last_obs

In [None]:
def ppo_update(model, optimizer, obs_buf, action_buf, reward_buf, done_buf, logp_buf, value_buf,
               gamma=0.95, lam=0.95, clip_ratio=0.2, epochs=1, batch_size=32):
    """Run clipped-surrogate PPO updates over the given transitions.

    Args:
        model: callable mapping one observation tensor to ``(logits, value)``.
        optimizer: optimizer over the model's parameters.
        obs_buf: numpy observations; each one is evaluated on its own, so
            their row counts may differ.
        action_buf: sampled actions, one tensor per observation.
        reward_buf, done_buf, value_buf: per-step reward, episode-end flag and
            the value predicted at collection time.
        logp_buf: log-probability of each action at collection time.
        gamma, lam: discount and GAE smoothing factors.
        clip_ratio: PPO clipping range for the probability ratio.
        epochs, batch_size: passes over the data and minibatch size.

    Returns:
        ``(mean policy loss, mean value loss)`` over all minibatches, or
        ``(nan, nan)`` when the buffers are empty.

    The critic outputs are now concatenated with ``torch.cat`` instead of being
    rebuilt with ``torch.tensor``; the latter produced a tensor cut off from
    the autograd graph, so the value loss never reached the critic.
    """
    # Generalised advantage estimation, walking the buffer backwards.
    advs, returns = [], []
    gae = 0.0
    last_value = 0.0
    for t in reversed(range(len(reward_buf))):
        mask = 1.0 - float(done_buf[t])
        reward = float(reward_buf[t])
        value = float(value_buf[t])
        delta = reward + gamma * last_value * mask - value
        gae = delta + gamma * lam * mask * gae
        advs.append(gae)
        returns.append(gae + value)
        last_value = value
    advs.reverse()
    returns.reverse()

    advs = torch.tensor(advs, dtype=torch.float32)
    returns = torch.tensor(returns, dtype=torch.float32)

    ploss, vloss = [], []
    for _ in range(epochs):
        for start in range(0, len(obs_buf), batch_size):
            stop = start + batch_size

            new_logp, new_values = [], []
            for obs, action in zip(obs_buf[start:stop], action_buf[start:stop]):
                logits, value = model(torch.from_numpy(obs))
                new_logp.append(Categorical(logits=logits).log_prob(action).reshape(-1))
                new_values.append(value.reshape(-1))
            new_logp = torch.cat(new_logp)
            new_values = torch.cat(new_values)  # stays attached to the graph

            old_logp = torch.cat([
                torch.as_tensor(lp, dtype=torch.float32).reshape(-1)
                for lp in logp_buf[start:stop]
            ]).detach()

            ratio = torch.exp(new_logp - old_logp)
            adv_batch = advs[start:stop]
            ret_batch = returns[start:stop]

            surr1 = ratio * adv_batch
            surr2 = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * adv_batch
            policy_loss = -torch.min(surr1, surr2).mean()

            value_loss = F.mse_loss(new_values, ret_batch)
            loss = policy_loss + 0.25 * value_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ploss.append(policy_loss.item())
            vloss.append(value_loss.item())

    if not ploss:
        # Nothing to train on; avoid numpy's empty-mean warning.
        return float("nan"), float("nan")
    return np.mean(ploss), np.mean(vloss)


In [None]:
# Environment over the "SMT2020_HVLM" dataset directory with a 90-day horizon.
env = SCFabEnv(dataset="SMT2020_HVLM", days=90, dispatcher="fifo", seed=42)

first_obs = env.reset()

# Larger configuration kept for reference:
# model = Model(30, 128, 50, 4, 4, 256, 0.1)
# input_dim=30 matches the 13 machine + 17 lot-group columns of an observation row.
model = Model(
    input_dim=30,
    embed_size=2,
    seq_len=50,
    num_enc=1,
    num_heads=1,
    hdim=4,
    drp=0.0,
)
optimizer = Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Rollout storage shared across training iterations: eight independent, initially empty lists.
obs_buf, action_buf, reward_buf, done_buf = [], [], [], []
logp_buf, value_buf, info_buf, used_indices = [], [], [], []

# Metrics sink; receives "logs" as its only argument.
logger = Logger("logs")

In [None]:
def filter_buf(*args, used_indices):
    """Keep, from every buffer in ``args``, only the entries whose position is in ``used_indices``."""
    keep = set(used_indices)
    return [[item for pos, item in enumerate(buf) if pos in keep] for buf in args]


def _obs_only(result):
    """Return the observation when ``result`` is an ``(obs, reward, done, info)`` tuple."""
    return result[0] if isinstance(result, tuple) else result


# Observation the next rollout starts from.  It is carried over between
# iterations; restarting every rollout from ``first_obs`` would pair stale
# observations with actions applied to the environment's current state.
next_obs = _obs_only(first_obs)

for run in range(100):

    if any(done_buf):
        # The previous episode finished: drop its transitions and start a new one.
        (obs_buf,
         action_buf,
         reward_buf,
         done_buf,
         logp_buf,
         value_buf,
         info_buf,
         used_indices) = [], [], [], [], [], [], [], []
        next_obs = _obs_only(env.reset())

    # Advance through decision points that need no agent decision, recording
    # them with the same placeholders collect_rollout uses so the buffers stay
    # aligned and lots finishing in the meantime are still seen.
    while next_obs is None:
        next_obs, _, done, info = env.step(None)
        obs_buf.append(None)
        action_buf.append(None)
        reward_buf.append(None)
        logp_buf.append(None)
        value_buf.append(None)
        done_buf.append(done)
        info_buf.append(info)
        if done:
            break
    if next_obs is None:
        # Episode ended while idling; the next iteration starts a fresh one.
        continue

    (obs_buf,
     action_buf,
     reward_buf,
     done_buf,
     logp_buf,
     value_buf,
     info_buf,
     used_indices,
     last_obs) = collect_rollout(next_obs,
                                 env,
                                 model,
                                 obs_buf,
                                 action_buf,
                                 reward_buf,
                                 done_buf,
                                 logp_buf,
                                 value_buf,
                                 info_buf,
                                 used_indices,
                                 rollout_len=10)
    next_obs = last_obs

    # Train only on the steps whose reward was filled in by this rollout.
    (obs_buf_ij,
     action_buf_ij,
     reward_buf_ij,
     done_buf_ij,
     logp_buf_ij,
     value_buf_ij,
     info_buf_ij) = filter_buf(obs_buf,
                               action_buf,
                               reward_buf,
                               done_buf,
                               logp_buf,
                               value_buf,
                               info_buf,
                               used_indices=used_indices)

    pl, vl = ppo_update(model,
                        optimizer,
                        obs_buf_ij,
                        action_buf_ij,
                        reward_buf_ij,
                        done_buf_ij,
                        logp_buf_ij,
                        value_buf_ij,
                        gamma=0.95,
                        lam=0.95,
                        clip_ratio=0.2,
                        epochs=5,
                        batch_size=5)

    # Plain floats keep numpy away from tensor lists and from the empty-mean warning.
    mean_reward = np.mean([float(r) for r in reward_buf_ij]) if reward_buf_ij else float("nan")

    logger.add_to_pool(eid=env.eid,
                       reward=mean_reward,
                       throughput=env.throughput,
                       tardiness=env.tardiness,
                       policy_loss=pl,
                       value_loss=vl)
    logger.commit()
    # Soft reset: zero the per-run throughput/tardiness counters only.
    env.reset(hard=False)