In [44]:
import pandas as pd
import numpy as np
import yaml
from functools import partial
import gc
gc.enable()

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import download_url, extract_zip, HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.nn.conv import MessagePassing
import torch_geometric.transforms as T
from torch_geometric import EdgeIndex
from torch_geometric.utils import add_self_loops, spmm, is_sparse
from torch_geometric.typing import Adj, OptPairTensor, SparseTensor

%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../scripts')
from utils import generate_slot_data, generate_heist_data, generate_thief_data, conflict_interval, conflict_schedule, is_unqualified

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [45]:
config_problemdef_PATH = '../configs/v1_problemdef.yaml'
with open(config_problemdef_PATH, 'r') as f:
    config = yaml.load(f, Loader=yaml.SafeLoader)

schedule_size = config['schedule_size']
qual_size = len(config['qualifications'])
sat_size = len(config['job_satisfaction'])
heist_size = len(config['heist_features'])
slot_size = len(config['slot_features']) + qual_size + sat_size
thief_size = schedule_size + qual_size + sat_size

In [46]:
# Load generation parameters
config_params_PATH = '../configs/v1_params.yaml'
with open(config_params_PATH, 'r') as f:
    config = yaml.load(f, Loader=yaml.SafeLoader)

sdchedule_size = config['schedule_size']
heist_dur_max = config['heist_dur_max']
n_slots_max = config['n_slots_max']
n_slots_min = config['n_slots_min']
n_heists_max = config['n_heists_max']
n_heists_min = config['n_heists_min']
qual_max = config['qual_max']
qual_min = config['qual_min']
sat_max = config['sat_max']
sat_min = config['sat_min']

In [47]:
# Set partial functions using problem definition sizes
generate_heist_data = partial(generate_heist_data, schedule_size, heist_dur_max, n_slots_max, n_slots_min)
generate_thief_data = partial(generate_thief_data, schedule_size, n_heists_max, n_heists_min, qual_max, qual_min, qual_size, sat_max, sat_min, sat_size)
generate_slot_data = partial(generate_slot_data, qual_max, qual_min, qual_size, sat_size)

In [48]:
# Hard code features to idx mappings
featidx_h_start = 0
featidx_h_end = 1
featidx_h_n_slots = 2
featidx_h_n_slots_req = 3
featidx_t_schedule = (0,schedule_size)
featidx_t_qual = (schedule_size, schedule_size + qual_size)
featidx_t_sat = (schedule_size + qual_size, thief_size)
featidx_s_req = 0
featidx_s_qual = (1, qual_size+1)
featidx_s_sat = (qual_size+1, slot_size)

In [53]:
env = CustomEnvironment()
env.reset()

  index_d = torch.stack([torch.tensor(self.index_d_heist1), torch.tensor(self.index_d_heist2)])


In [52]:
class CustomEnvironment():
    def __init__(self,**config):
        pass

    def reset(self, **kwargs):

        n_thieves = 30
        n_heists = 50

        self.init_indices()
        self.init_dictionaries()
        self.heist_df = self.gen_n_heists(n_heists) # will update index_d_heist1, index_d_heist2 in place
        self.slot_df = self.gen_n_slots() # will update index_c_slot, index_c_heist, heist2slot, slot2heist in place
        self.thief_df = self.gen_n_thieves(n_thieves) # will update index_a_thief, index_a_slot, slot2thief in place
    
        self.data = self.reset_data()

    def action(self, edge_idx): # NOTE: editing self.data inplace
        # Get thief, slot from index_a
        thief, slot = self.data['thief','possible','slot'].edge_index[:,edge_idx]

        # 1. Remove (thief, slot) from index_a
        self.data['thief','possible','slot'].edge_index = self.data['thief','possible','slot'].edge_index[(self.data['thief','possible','slot'].edge_index[0]==thief) &
                                                                                                          (self.data['thief','possible','slot'].edge_idnex[1]==slot)]
        pass

    def init_indices(self):
        self.index_a_thief, self.index_a_slot = [], []
        self.index_b_thief, self.index_b_slot = [], []
        self.index_c_slot, self.index_c_heist = [], []
        self.index_d_heist1, self.index_d_heist2 = [], []
    
    def init_dictionaries(self):
        self.heist2slot = {}
        self.slot2heist = {}
        self.heist2assigned = {}
        self.slot2thief = {}

    def gen_n_heists(self, n_heists): # TODO: add optional config file of size n_heists x dims
        heist_df = pd.DataFrame()
        for h in range(n_heists):
            # Get heist data
            heist_data = generate_heist_data()

            # Add to heist_df
            tmp_df = pd.DataFrame(heist_data).T
            tmp_df.index = [h]
            heist_df = pd.concat([heist_df, tmp_df])

            # Update index_d
            newInterval = (heist_data[featidx_h_start], heist_data[featidx_h_end])
            conflictMask = [conflict_interval(newInterval, (x,y)) for x,y in zip(heist_df[featidx_h_start], heist_df[featidx_h_end])]
            indices = list(np.compress(heist_df.index, conflictMask))
            for h2 in indices:
                self.index_d_heist1.append(h)
                self.index_d_heist2.append(h2)

        heist_df.index.rename('heistId', inplace=True)
        return heist_df
    
    def gen_n_slots(self): 
        i = 0 # counter of total slots so far
        slot_df = pd.DataFrame()
        # For each heist
        for h, heist in self.heist_df.iterrows():
            # Number of slots per heist
            n_slots = int(heist[featidx_h_n_slots])
            n_slots_req = heist[featidx_h_n_slots_req]
            
            # For each slots
            for s in range(n_slots):
                # Get slot data based on required status
                slot_data = generate_slot_data(required=(s < n_slots_req))

                # Update index_c
                self.index_c_slot.append(i)
                self.index_c_heist.append(h)

                # Update heist2slot and slot2heist dictionaries
                if self.heist2slot.get(h) is not None:
                    self.heist2slot[h].append(i)
                else: self.heist2slot[h] = [i]
                self.slot2heist[i] = h

                # Add to slot_df
                tmp_df = pd.DataFrame(slot_data).T
                tmp_df.index = [i]
                slot_df = pd.concat([slot_df, tmp_df])

                i += 1 # update slot counter
        slot_df.index.rename('slotId', inplace=True)
        return slot_df
    
    def gen_n_thieves(self, n_thieves): 
        thief_df = pd.DataFrame()
        # For each thief
        for t in range(n_thieves):
            # Get thief data
            thief_data = generate_thief_data()
            schedule = thief_data[featidx_t_schedule[0]: featidx_t_schedule[1]]

            # Create thief-slot edges
            # For each heist
            for h, heist in self.heist_df.iterrows():
                # If schedule conflict: no edge
                if conflict_schedule(schedule, (heist[featidx_h_start], heist[featidx_h_end])):
                    continue
                # If no schedule conflict: get indices of slot on heist
                slot_idx = self.heist2slot[h]
                # For each eligible slot
                for s in slot_idx:
                    # If thief unqualified: no edge
                    thief_qual = thief_data[featidx_t_qual[0]: featidx_t_qual[1]]
                    slot_qual = self.slot_df.iloc[s, featidx_s_qual[0]: featidx_s_qual[1]]
                    if is_unqualified(thief_qual, slot_qual):
                        continue
                    # If thief qualified: update index_a
                    self.index_a_thief.append(t)
                    self.index_a_slot.append(s)
                    # Update slot2thief dictionary
                    self.slot2thief[s] = t
        
            # Add to thief_df
            tmp_df = pd.DataFrame(thief_data).T
            tmp_df.index = [t]
            thief_df = pd.concat([thief_df, tmp_df])

        thief_df.index.rename('thiefId', inplace=True)
        return thief_df

    def reset_data(self):
        index_a = torch.stack([torch.tensor(self.index_a_thief), torch.tensor(self.index_a_slot)])
        index_b = torch.stack([torch.tensor(self.index_b_thief), torch.tensor(self.index_b_slot)])
        index_c = torch.stack([torch.tensor(self.index_c_slot), torch.tensor(self.index_c_heist)])
        index_d = torch.stack([torch.tensor(self.index_d_heist1), torch.tensor(self.index_d_heist2)])

        data = HeteroData()

        # Add node indices
        data['thief'].node_id = torch.tensor(self.thief_df.index)
        data['slot'].node_id = torch.tensor(self.slot_df.index)
        data['heist'].node_id = torch.tensor(self.heist_df.index)

        # Add node features
        data["thief"].x = torch.tensor(self.thief_df.values).to(torch.float)
        data["slot"].x = torch.tensor(self.slot_df.values).to(torch.float)
        data["heist"].x = torch.tensor(self.heist_df.values).to(torch.float)

        # Add edge indices
        data["thief","possible","slot"].edge_index = index_a # has shape (2, num_edges)
        data["thief","assigned","slot"].edge_index = index_b # has shape (2, num_edges)
        data["slot","on","heist"].edge_index = index_c # has shape (2, num_edges)
        data["heist","conflicts","heist"].edge_index = index_d # has shape (2, num_edges)

        # Add reverse edge
        data = T.ToUndirected()(data)
        return data