# Case Data Set

In [2]:
import os
import sys
module_path = os.path.abspath(os.path.join('../..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import pickle as pickle
import pandas as pd
import numpy as np

import torch
from torch.nn import functional as F
from torch.utils.data import Dataset

In [3]:
class CaseDataset(Dataset):
    def __init__(self, project_data_path="../../Data/Training", split_pattern="811split", input_data="source",
                 data_version="_train", embedding_version="", earliness_requirement=True):

        file_name = input_data + data_version + ".pkl"
        self.data_path = os.path.join(project_data_path, split_pattern, file_name)
        self.intime_requirement = earliness_requirement

        self.data_all = pd.read_pickle(self.data_path)
        self.data_all = self.data_all.reset_index()
        
        self.data_all["CaseLength"] = self.data_all["Event_Name"].apply(len)
        self.max_case_len = self.data_all["CaseLength"].max()
        
        self.prefix_length = 1
        self.data_pool = self.data_all[["Case_ID", "Event_Name", "CaseLapse", "CaseLength", 'Label', "OverallCriteria"]].copy()
        
        self.act_encoding = pd.read_pickle("../../Embedding/embedding_" + input_data + embedding_version + ".pkl")
        self.data_pool["Feature"] = self.data_pool.apply(self.encode_feature, axis=1)
        
    def encode_feature(self, sample):
        activity_encoded = self.encode_act(sample["Event_Name"])
        feature = torch.from_numpy(activity_encoded)
        # print(feature.shape)
        lapse_time = torch.from_numpy(sample["CaseLapse"])
        feature = torch.cat([feature, lapse_time.reshape((-1, 1))], dim=1)
        return feature.numpy()
    
    def encode_act(self, feature):
        return np.array(self.act_encoding.loc[feature]["Values"].values.tolist())
    
    def shuffle_data(self):
        self.data_pool = self.data_pool.sample(frac=1)
        
    def one_hot_po(self, feature):
        return F.one_hot(torch.from_numpy(feature).long(), num_classes=self.sub_seq_size)
    
    def set_prefix_length(self, length):
        self.prefix_length = length
    
    def update_data_pool(self):
        data_temp = self.data_pool[self.data_pool["CaseLength"] > self.prefix_length]
        if self.intime_requirement:
            data_temp = data_temp[data_temp["CaseLapse"].apply(lambda x: x[self.prefix_length]) < data_temp["OverallCriteria"]]
        return data_temp
    
    def convert_feature_vec_np(self, data):
        data_tmp = np.stack(data[:self.prefix_length])
        data_output = torch.from_numpy(data_tmp)
        return data_output

    def __len__(self):
        data_temp = self.update_data_pool()
        return data_temp.shape[0]

    def __getitem__(self, idx):
        data_temp = self.update_data_pool()
        if torch.is_tensor(idx):
            idx = idx.tolist()

        x = data_temp["Feature"].apply(self.convert_feature_vec_np).values[idx].tolist()
        y = data_temp["Label"].values[idx].tolist()

        if len(x) == 0 :
            return None
        if len(x) == 1:
            return torch.unsqueeze(x[0], 0), torch.tensor(y).reshape((-1,1))
        return torch.stack(x) , torch.tensor(y).reshape((-1,1))
        

In [4]:
t1 = CaseDataset(data_version="_train", embedding_version="_w2v")
t1.data_pool

Unnamed: 0,Case_ID,Event_Name,CaseLapse,CaseLength,Label,OverallCriteria,Feature
0,55134019626153,"[Create PO Item, Create PO Item, Record Goods ...","[0.0, 0.0, 0.06766940728701254]",3,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
1,55134019620954,"[Create PO Item, Create PO Item, Record Goods ...","[0.0, 0.0, 0.0676721463961258]",3,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
2,55134019631756,"[Create PO Item, Record Goods Receipt]","[0.0, 0.06766950173905094]",2,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
3,55134019632858,"[Create PO Item, Create PO Item, Create PO Ite...","[0.0, 0.0, 0.0, 0.06766969064312771]",4,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
4,55134019635756,"[Create PO Item, Record Goods Receipt]","[0.0, 0.0676681794105135]",2,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
...,...,...,...,...,...,...,...
22880,55134010591257,"[Create PO Item, Create PO Item, Create PO Ite...","[0.0, 0.0, 0.0, 0.06767195749204902, 0.0676719...",8,0,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
22881,55134010680841,"[Create PO Item, Record Goods Receipt, Record ...","[0.0, 0.06767327982058646, 0.2403704257822329]",3,0,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
22882,55134010680752,"[Create PO Item, Record Goods Receipt, Record ...","[0.0, 0.06767516886135423, 0.076676164763647]",3,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
22883,55134010680656,"[Create PO Item, Record Goods Receipt]","[0.0, 0.06767346872466323]",2,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."


In [5]:
t1.set_prefix_length(3)
t1.data_pool


Unnamed: 0,Case_ID,Event_Name,CaseLapse,CaseLength,Label,OverallCriteria,Feature
0,55134019626153,"[Create PO Item, Create PO Item, Record Goods ...","[0.0, 0.0, 0.06766940728701254]",3,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
1,55134019620954,"[Create PO Item, Create PO Item, Record Goods ...","[0.0, 0.0, 0.0676721463961258]",3,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
2,55134019631756,"[Create PO Item, Record Goods Receipt]","[0.0, 0.06766950173905094]",2,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
3,55134019632858,"[Create PO Item, Create PO Item, Create PO Ite...","[0.0, 0.0, 0.0, 0.06766969064312771]",4,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
4,55134019635756,"[Create PO Item, Record Goods Receipt]","[0.0, 0.0676681794105135]",2,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
...,...,...,...,...,...,...,...
22880,55134010591257,"[Create PO Item, Create PO Item, Create PO Ite...","[0.0, 0.0, 0.0, 0.06767195749204902, 0.0676719...",8,0,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
22881,55134010680841,"[Create PO Item, Record Goods Receipt, Record ...","[0.0, 0.06767327982058646, 0.2403704257822329]",3,0,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
22882,55134010680752,"[Create PO Item, Record Goods Receipt, Record ...","[0.0, 0.06767516886135423, 0.076676164763647]",3,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."
22883,55134010680656,"[Create PO Item, Record Goods Receipt]","[0.0, 0.06767346872466323]",2,1,0.142549,"[[0.12065998464822769, 0.1271606683731079, 0.0..."


In [6]:
t1[:2]

(tensor([[[ 1.2066e-01,  1.2716e-01,  6.0703e-02,  5.5349e-01,  2.2370e-01,
           -1.3895e-01,  6.0783e-02, -4.8118e-01,  8.5380e-02,  1.1326e-01,
           -4.3482e-01,  6.7169e-02,  4.5126e-01,  2.9466e-01, -4.2997e-02,
            1.6328e-01,  1.5305e-01, -2.0903e-01,  4.5590e-01, -8.9393e-02,
            5.1694e-01, -6.2002e-01,  7.7931e-02,  1.3757e-01,  1.0799e-01,
           -8.5072e-01, -5.8348e-01, -1.0099e-02, -2.9333e-04, -4.5655e-01,
            2.1949e+00,  4.5823e-02, -8.5897e-02,  1.7432e-01, -3.7483e-01,
            7.7810e-02, -2.1400e-02,  2.7822e-01, -1.9335e-01, -3.5139e-01,
            2.7227e-01, -2.0199e-02,  4.5046e-03,  1.7841e-01,  1.0243e-01,
            5.7020e-02, -1.6793e-01, -5.3007e-02,  6.9136e-01,  9.1927e-02,
            0.0000e+00],
          [ 1.2066e-01,  1.2716e-01,  6.0703e-02,  5.5349e-01,  2.2370e-01,
           -1.3895e-01,  6.0783e-02, -4.8118e-01,  8.5380e-02,  1.1326e-01,
           -4.3482e-01,  6.7169e-02,  4.5126e-01,  2.9466e-01, 