
Process Prediction
==
 - Load data
 - Categorize / normalize / fill missing values
 - Create the data structure for the language model
 - Train a baseline model and evaluate next-step and suffix prediction

# Imports

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
from exp.eventlog import *

In [3]:
from exp.dl_utils import *

In [4]:
import editdistance as ed

# Load Data

In [5]:
# Fetch the BPIC 2012 data with untar_data and parse it with import_xes;
# log.events and log.traceAttributes are merged in the next section.
log=import_xes(untar_data(URLs.BPIC_2012))

# Data Processing

1. Merge trace attributes and event attributes into one DataFrame first; this makes it easy to copy the trace attributes onto every event.
2. Split into train and test sets.
3. Create traces from the DataFrame.

In [6]:
# One row per event, with the trace-level attributes (indexed by trace id) joined onto each event.
df=pd.merge(log.events,log.traceAttributes,left_on='trace_id',right_index=True)
df

Unnamed: 0,trace_id,event_id,org:resource,lifecycle:transition,concept:name,time:timestamp,REG_DATE,AMOUNT_REQ
0,173688,0,112,COMPLETE,A_SUBMITTED,2011-09-30 22:38:44.546000+00:00,2011-10-01 00:38:44.546000+02:00,20000
1,173688,1,112,COMPLETE,A_PARTLYSUBMITTED,2011-09-30 22:38:44.880000+00:00,2011-10-01 00:38:44.546000+02:00,20000
2,173688,2,112,COMPLETE,A_PREACCEPTED,2011-09-30 22:39:37.906000+00:00,2011-10-01 00:38:44.546000+02:00,20000
3,173688,3,112,SCHEDULE,W_Completeren aanvraag,2011-09-30 22:39:38.875000+00:00,2011-10-01 00:38:44.546000+02:00,20000
4,173688,4,,START,W_Completeren aanvraag,2011-10-01 09:36:46.437000+00:00,2011-10-01 00:38:44.546000+02:00,20000
5,173688,5,10862,COMPLETE,A_ACCEPTED,2011-10-01 09:42:43.308000+00:00,2011-10-01 00:38:44.546000+02:00,20000
6,173688,6,10862,COMPLETE,O_SELECTED,2011-10-01 09:45:09.243000+00:00,2011-10-01 00:38:44.546000+02:00,20000
7,173688,7,10862,COMPLETE,A_FINALIZED,2011-10-01 09:45:09.243000+00:00,2011-10-01 00:38:44.546000+02:00,20000
8,173688,8,10862,COMPLETE,O_CREATED,2011-10-01 09:45:11.197000+00:00,2011-10-01 00:38:44.546000+02:00,20000
9,173688,9,10862,COMPLETE,O_SENT,2011-10-01 09:45:11.380000+00:00,2011-10-01 00:38:44.546000+02:00,20000


# Split into Train, Test and Validation

First split the traces into a train set and a test set only. The train set is used to train the model; the test set is held out to evaluate the model later on. The validation set is then carved out of the train set (see below).

In [7]:
# Name of the column that identifies the trace (case) each event belongs to.
trace_id='trace_id'

In [8]:
def random_split_traces(d,split=0.8,trace_id='trace_id',seed=None):
    """Randomly partition the unique trace ids of ``d`` into two arrays.

    Args:
        d: event-level DataFrame (several rows may share a trace id).
        split: fraction of the unique trace ids that go into the first array.
        trace_id: name of the column holding the trace id of each event.
        seed: optional integer seed for a reproducible split. ``None`` (default)
            draws from NumPy's global random state, as before.

    Returns:
        (first_ids, second_ids): numpy arrays that together contain every unique
        trace id exactly once; ``first_ids`` has ``int(n_unique * split)`` entries.
    """
    traces=d[trace_id].drop_duplicates()
    # np.random (global state) keeps the historical behaviour; a seeded RandomState
    # makes the split reproducible without touching global state.
    rng=np.random if seed is None else np.random.RandomState(seed)
    shuffled=traces.iloc[rng.permutation(len(traces))].values
    n_first=int(len(traces)*split)
    return shuffled[:n_first],shuffled[n_first:]

In [9]:
# 80/20 split of the trace ids.
# NOTE(review): no random seed is set, so the split (and every score below) changes between runs.
train_trace_ids,test_trace_ids=random_split_traces(df,0.8)

In [10]:
def get_df(t,df,trace_id='trace_id'):
    """Return the rows of ``df`` whose trace id is contained in ``t``.

    Args:
        t: iterable of trace ids to keep.
        df: event-level DataFrame.
        trace_id: name of the trace-id column (previously read from a notebook global).

    Returns:
        The filtered DataFrame (original row order and index preserved).
    """
    return df[df[trace_id].isin(t)]
# Event rows of the training traces (split further into train/validation below).
train_df=get_df(train_trace_ids,df)

Split train into train and validation
--

In [11]:
# Hold out 10% of the training traces for validation.
# NOTE(review): `train_traces` holds trace ids here but is reassigned to a DataFrame of
# traces later in the notebook; a distinct name would avoid confusion.
train_traces,validation_trace_ids=random_split_traces(train_df,0.9)


In [12]:
train_traces

array(['184351', '211718', '208259', '212983', ..., '202866', '180544', '177773', '182723'], dtype=object)

In [13]:
# Event-level frames for the final train / test / validation partitions.
train_df=get_df(train_traces,df)
test_df=get_df(test_trace_ids,df)
valid_df=get_df(validation_trace_ids,df)

# Process Data

In [14]:
def normalize_cont_column(x, mean, std,eps=1e-7):
    """Standardise ``x``: subtract ``mean`` and divide by ``std``.

    ``eps`` is added to the denominator so a zero standard deviation does not
    divide by zero.
    """
    centred = x - mean
    return centred / (eps + std)

In [40]:
# Special tokens in index order. When this list is passed as CategoryProcessor's
# default_token, the tokens take ids 0-7 (xxunk=0, xxpad=1, xxbos=2, xxeos=3, ...).
default_spec_tok = "xxunk xxpad xxbos xxeos xxrep xxwrep xxup xxmaj".split()
UNK, PAD, BOS, EOS, TK_REP, TK_WREP, TK_UP, TK_MAJ = default_spec_tok

from collections import OrderedDict

def uniqueify(x, sort=False):
    """Return the distinct elements of ``x`` as a list.

    Elements keep the order of their first appearance unless ``sort`` is true,
    in which case the list is sorted ascending.
    """
    # dict keys preserve insertion order, so this de-duplicates stably.
    distinct = list(dict.fromkeys(x))
    return sorted(distinct) if sort else distinct

class Processor():
    """Base class for data processors; the default ``process`` is the identity."""

    def process(self, items):
        """Return ``items`` unchanged."""
        return items

class CategoryProcessor(Processor):
    """Map categorical values to integer ids using a vocabulary fixed on first use.

    The first call builds ``vocab`` from the distinct values of ``items`` in order
    of first appearance. When ``default_token`` is given, those tokens are moved to
    the front of the vocabulary in their given order. Values missing from the
    vocabulary are encoded as 0: the first default token when one is supplied
    (xxunk for default_spec_tok), otherwise the first value seen on the first call.
    """

    def __init__(self,default_token=None):
        self.vocab=None
        self.default_token=default_token

    def __call__(self, items):
        """Encode ``items``, building the vocabulary on the first call only."""
        if self.vocab is None:
            self._build_vocab(items)
        return [self.proc1(item) for item in items]

    def _build_vocab(self, items):
        # Distinct values first, then pull each default token to the front
        # (iterating in reverse keeps the default tokens in their given order).
        vocab = uniqueify(items)
        if self.default_token is not None:
            for tok in reversed(self.default_token):
                if tok in vocab:
                    vocab.remove(tok)
                vocab.insert(0, tok)
        self.vocab = vocab
        self.otoi = {tok: pos for pos, tok in enumerate(vocab)}

    def proc1(self, item):
        """Id of ``item``; unknown values map to 0."""
        return self.otoi.get(item, 0)

    def deprocess(self, idxs):
        """Decode a sequence of ids back to vocabulary values."""
        assert self.vocab is not None
        return list(map(self.deproc1, idxs))

    def deproc1(self, idx):
        """Vocabulary value at id ``idx``."""
        return self.vocab[idx]

In [41]:
class TraceProcessor(Processor):
    """Encode an event-level DataFrame for the model.

    Each date column in ``date_names`` is expanded with ``add_datepart(..., utc=True)``
    into extra categorical calendar columns plus one continuous ``_Elapsed`` column.
    Every categorical column is mapped to integer ids by its own CategoryProcessor
    (seeded with ``default_spec_tok``). Every continuous column is cast to float and
    standardised with ``normalize_cont_column``. Vocabularies and (mean, std) pairs
    are stored in ``self.vocabs`` on the first call and reused afterwards, so call
    the processor on the training data first.
    """
    def __init__(self,cat_names,cont_names,date_names,vocabs=None):
        # A fresh dict per instance. A mutable default argument would be shared by
        # every TraceProcessor and leak vocabularies/statistics between them.
        self.vocabs={} if vocabs is None else vocabs
        self.cat_names,self.cont_names,self.date_names=cat_names,cont_names,date_names

    def __call__(self,df):
        """Return an encoded copy of ``df``; the input frame is not modified."""
        # Work on a copy: callers typically pass filtered slices of a larger frame.
        df=df.copy()
        cat_names,cont_names=list(self.cat_names),list(self.cont_names)
        for d in self.date_names:
            df,cat,cont=add_datepart(df,d,utc=True)
            cat_names+=listify(cat)
            cont_names+=listify(cont)

        for c in cat_names:
            if c not in self.vocabs:
                self.vocabs[c]=CategoryProcessor(default_spec_tok)
            df[c]=self.vocabs[c](df[c])

        for c in cont_names:
            df[c]=df[c].astype(float)
            # Statistics are fitted once (first call) and reused for later frames.
            if c not in self.vocabs:
                self.vocabs[c]=df[c].mean(),df[c].std()
            df[c]=normalize_cont_column(df[c],*self.vocabs[c])

        return df

    def deprocess(self,items,columns):
        # TODO: decoding is not implemented yet; this currently returns None.
        pass

In [42]:
def add_datepart(df, fldname, drop=True, time=False,utc=False):
    """Expand the datetime column ``fldname`` into calendar feature columns.

    Works on a copy; the caller's DataFrame is not modified.

    Args:
        df: input DataFrame.
        fldname: column to expand. Columns that are not already datetimes (naive or
            tz-aware) are parsed with ``pd.to_datetime``.
        drop: drop ``fldname`` from the result.
        time: also add Hour, Minute and Second columns.
        utc: forwarded to ``pd.to_datetime`` when parsing is needed.

    Returns:
        (df, cols, elapsed): the new frame; the list of added calendar column names
        (``<prefix>_Year``, ``<prefix>_Month``, ...); and the name of the
        ``<prefix>_Elapsed`` column (integer seconds since the Unix epoch).
        ``<prefix>`` is ``fldname`` with a trailing "Date"/"date" removed.
    """
    df=df.copy()
    fld=df[fldname]
    # Handles both numpy datetime64 and tz-aware DatetimeTZDtype columns without
    # raising on other extension dtypes.
    if not pd.api.types.is_datetime64_any_dtype(fld):
        # infer_datetime_format is deprecated (and a no-op) in pandas >= 2.0.
        df[fldname]=fld=pd.to_datetime(fld,utc=utc)
    targ_pre=re.sub('[Dd]ate$','',fldname)
    attr=['Year','Month','Week','Day','Dayofweek','Dayofyear',
          'Is_month_end','Is_month_start','Is_quarter_end','Is_quarter_start','Is_year_end','Is_year_start']
    if time: attr=attr+['Hour','Minute','Second']
    cols=[]
    for n in attr:
        col_name=targ_pre+"_"+n
        if n=='Week' and hasattr(fld.dt,'isocalendar'):
            # Series.dt.week was deprecated in pandas 1.1 and removed in 2.0;
            # isocalendar().week is the same ISO week number.
            week=fld.dt.isocalendar().week
            vals=week.astype('float64') if week.isna().any() else week.astype(np.int64)
        else:
            vals=getattr(fld.dt,n.lower())
        df[col_name]=vals
        cols.append(col_name)
    elapsed=targ_pre+'_Elapsed'
    df[elapsed]=fld.astype(np.int64)//10**9
    if drop: df=df.drop(columns=fldname)
    return df,cols,elapsed

In [43]:
# encode data and create vocab
# Column order matters: after create_traces drops trace_id, 'concept:name' sits at
# position 3, which the model, loss and metrics below index directly.
cat_names=['event_id','org:resource','lifecycle:transition','concept:name',]
date_names=['time:timestamp','REG_DATE']
cont_names=['AMOUNT_REQ']

In [44]:
# Vocabularies and (mean, std) statistics are fitted on the first call (train_df below)
# and reused for the validation and test frames.
tp=TraceProcessor(cat_names,cont_names,date_names)

In [45]:
train_df

Unnamed: 0,trace_id,event_id,org:resource,lifecycle:transition,concept:name,time:timestamp,REG_DATE,AMOUNT_REQ
0,173688,0,112,COMPLETE,A_SUBMITTED,2011-09-30 22:38:44.546000+00:00,2011-10-01 00:38:44.546000+02:00,20000
1,173688,1,112,COMPLETE,A_PARTLYSUBMITTED,2011-09-30 22:38:44.880000+00:00,2011-10-01 00:38:44.546000+02:00,20000
2,173688,2,112,COMPLETE,A_PREACCEPTED,2011-09-30 22:39:37.906000+00:00,2011-10-01 00:38:44.546000+02:00,20000
3,173688,3,112,SCHEDULE,W_Completeren aanvraag,2011-09-30 22:39:38.875000+00:00,2011-10-01 00:38:44.546000+02:00,20000
4,173688,4,,START,W_Completeren aanvraag,2011-10-01 09:36:46.437000+00:00,2011-10-01 00:38:44.546000+02:00,20000
5,173688,5,10862,COMPLETE,A_ACCEPTED,2011-10-01 09:42:43.308000+00:00,2011-10-01 00:38:44.546000+02:00,20000
6,173688,6,10862,COMPLETE,O_SELECTED,2011-10-01 09:45:09.243000+00:00,2011-10-01 00:38:44.546000+02:00,20000
7,173688,7,10862,COMPLETE,A_FINALIZED,2011-10-01 09:45:09.243000+00:00,2011-10-01 00:38:44.546000+02:00,20000
8,173688,8,10862,COMPLETE,O_CREATED,2011-10-01 09:45:11.197000+00:00,2011-10-01 00:38:44.546000+02:00,20000
9,173688,9,10862,COMPLETE,O_SENT,2011-10-01 09:45:11.380000+00:00,2011-10-01 00:38:44.546000+02:00,20000


In [46]:
# First call: fits the vocabularies and normalisation statistics on the training data.
train_proc=tp(train_df)

In [47]:
# Reuses the vocabularies/statistics fitted on the training data.
valid_proc=tp(valid_df)

In [48]:
test_proc=tp(test_df) # insert unknown token: categories unseen in training are encoded as 0 (xxunk)

# Create Traces

In [49]:
def create_traces(event_df,trace_id='trace_id'):
    """Group an event-level DataFrame into one row per trace.

    Args:
        event_df: one row per event, including the ``trace_id`` column.
        trace_id: name of the column used for grouping.

    Returns:
        DataFrame indexed by trace id (in groupby order) with every other column of
        ``event_df``; each cell is the list of that column's values for the trace,
        in the order the events appear in ``event_df``.
    """
    feature_cols=list(event_df)
    feature_cols.remove(trace_id)
    groups=list(event_df.groupby(trace_id))
    rows=[[list(group[col]) for col in feature_cols] for _, group in groups]
    traces=pd.DataFrame(rows,columns=feature_cols)
    traces.index=[key for key, _ in groups]
    return traces


In [50]:
# One row per trace; each cell holds that column's values for the trace's events.
train_traces=create_traces(train_proc)
valid_traces=create_traces(valid_proc)

# Language Model DataLoader

In [51]:
# Batch size and window length (positions per sample) for the language-model DataLoaders.
bs,bptt=128,70

In [52]:
class LMDataSet():
    """Language-model dataset over a DataFrame of traces.

    ``df`` has one row per trace and list-valued cells. Every trace is wrapped in
    BOS/EOS markers, all traces are concatenated column-wise into one stream, and
    that stream is cut into ``bs`` parallel rows of ``n_batch`` positions. Item
    ``idx`` is the ``bptt``-long window of row ``idx % bs`` at offset
    ``(idx // bs) * bptt``. The target ``y`` is the same window shifted by one.

    Args:
        df: trace DataFrame; every column is stacked into the output tensors.
        bs: number of parallel rows; consecutive items walk across these rows, so it
            should match the DataLoader batch size.
        bptt: window length.
        shuffle: shuffle the trace order once, at construction time.
    """
    # Marker values; they equal the ids of xxbos/xxeos in default_spec_tok order.
    BOS_IDX,EOS_IDX=2.0,3.0

    def __init__(self, df, bs=64, bptt=70, shuffle=False):
        self.bs,self.bptt,self.shuffle = bs,bptt,shuffle
        self.cols=list(df)
        # Each trace contributes its longest column plus one BOS and one EOS marker.
        # Counting the markers keeps the end of the stream from being truncated.
        trace_lens=[max(len(listify(row[k])) for k in self.cols) for _, row in df.iterrows()]
        total_len=sum(trace_lens)+2*len(trace_lens)
        self.n_batch = total_len // self.bs
        self.batched=self.batchify(df)

    def __len__(self): return ((self.n_batch-1) // self.bptt) * self.bs

    def __getitem__(self, idx):
        """Return (x, y), each of shape (n_cols, bptt); y is x shifted one position ahead."""
        source = self.batched[:,idx % self.bs]
        seq_idx = (idx // self.bs) * self.bptt
        x,y=source[:,seq_idx:seq_idx+self.bptt],source[:,seq_idx+1:seq_idx+self.bptt+1]
        return x,y

    def batchify(self,df):
        """Return a (n_cols, bs, n_batch) float tensor holding the concatenated stream."""
        if self.shuffle: df=df.sample(frac=1)
        dd={c:[] for c in self.cols}
        for _, row in df.iterrows():
            # Broadcast single-valued cells to the trace length so all columns stay aligned.
            l=max(len(listify(row[c])) for c in self.cols)
            for c in self.cols:
                dd[c].append(tensor(row[c]).expand(l))
        bos,eos=tensor([self.BOS_IDX]),tensor([self.EOS_IDX])
        for c in self.cols:
            s=torch.cat([torch.cat((bos,t.float(),eos)) for t in dd[c]])
            dd[c]=s[:self.n_batch * self.bs].view(self.bs, self.n_batch)
        return torch.stack([dd[c] for c in self.cols])

In [53]:
def get_dls(train_ds, valid_ds,  **kwargs):
    """Build (train, valid) DataLoaders over LMDataSet views of two trace DataFrames.

    Uses the notebook-level ``bs`` and ``bptt``. ``bs`` is passed both to LMDataSet
    (number of parallel rows) and to DataLoader (batch size), so each batch holds
    one ``bptt`` window from every row at the same offset. Extra keyword arguments
    are forwarded to both DataLoaders. Only the training set is shuffled (once, at
    construction).
    """
    return (DataLoader(LMDataSet(train_ds, bs=bs, bptt=bptt, shuffle=True), batch_size=bs, **kwargs),
            DataLoader(LMDataSet(valid_ds, bs=bs, bptt=bptt, shuffle=False), batch_size=bs, **kwargs))

In [54]:
valid_traces

Unnamed: 0,event_id,org:resource,lifecycle:transition,concept:name,AMOUNT_REQ,time:timestamp_Year,time:timestamp_Month,time:timestamp_Week,time:timestamp_Day,time:timestamp_Dayofweek,...,REG_DATE_Day,REG_DATE_Dayofweek,REG_DATE_Dayofyear,REG_DATE_Is_month_end,REG_DATE_Is_month_start,REG_DATE_Is_quarter_end,REG_DATE_Is_quarter_start,REG_DATE_Is_year_end,REG_DATE_Is_year_start,REG_DATE_Elapsed
173721,"[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,...","[8, 8, 8, 8, 33, 33, 33, 33, 33, 33, 33, 33, 3...","[8, 8, 8, 9, 10, 8, 8, 8, 8, 8, 9, 8, 10, 8, 1...","[8, 9, 10, 11, 11, 12, 13, 14, 15, 16, 17, 11,...","[-0.5323385977449078, -0.5323385977449078, -0....","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...",...,"[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[-1.71618525009305, -1.71618525009305, -1.7161..."
173793,"[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,...","[8, 8, 8, 8, 36, 36, 9, 14, 14, 14, 14, 14, 9,...","[8, 8, 8, 9, 10, 8, 10, 8, 8, 8, 8, 8, 9, 8, 1...","[8, 9, 10, 11, 11, 11, 11, 12, 13, 14, 15, 16,...","[-0.7782020970763474, -0.7782020970763474, -0....","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 1...",...,"[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[-1.7083370104084947, -1.7083370104084947, -1...."
173883,"[8, 9, 10]","[8, 8, 8]","[8, 8, 8]","[8, 9, 26]","[-1.0650428462963601, -1.0650428462963601, -1....","[8, 8, 8]","[9, 9, 9]","[8, 8, 8]","[19, 19, 19]","[14, 14, 14]",...,"[10, 10, 10]","[10, 10, 10]","[10, 10, 10]","[9, 9, 9]","[8, 8, 8]","[9, 9, 9]","[8, 8, 8]","[8, 8, 8]","[8, 8, 8]","[-1.6913446601096802, -1.6913446601096802, -1...."
173982,"[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,...","[8, 8, 8, 8, 36, 36, 36, 36, 36, 36, 36, 36, 3...","[8, 8, 8, 9, 10, 8, 10, 8, 8, 8, 8, 8, 9, 8, 1...","[8, 9, 10, 11, 11, 11, 11, 12, 13, 14, 15, 16,...","[-0.6142930975220543, -0.6142930975220543, -0....","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10,...","[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1...",...,"[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[-1.6729471766369386, -1.6729471766369386, -1...."
174012,"[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,...","[8, 8, 8, 8, 17, 17, 9, 14, 14, 14, 14, 14, 9,...","[8, 8, 8, 9, 10, 8, 10, 8, 8, 8, 8, 8, 9, 8, 1...","[8, 9, 10, 11, 11, 11, 11, 12, 14, 13, 15, 16,...","[-0.6552703474106276, -0.6552703474106276, -0....","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1...",...,"[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[-1.671458413830585, -1.671458413830585, -1.67..."
174036,"[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,...","[8, 8, 8, 8, 17, 17, 17, 17, 18, 18, 18, 18, 1...","[8, 8, 8, 9, 10, 8, 10, 8, 10, 8, 10, 8, 8, 8,...","[8, 9, 10, 11, 11, 11, 11, 11, 11, 11, 11, 12,...","[-0.04061159908202868, -0.04061159908202868, -...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1...",...,"[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[-1.6705858693684263, -1.6705858693684263, -1...."
174099,"[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,...","[8, 8, 8, 8, 9, 9, 9, 9, 38, 38, 38, 38, 38, 3...","[8, 8, 8, 9, 10, 8, 10, 8, 10, 8, 8, 8, 8, 8, ...","[8, 9, 10, 11, 11, 11, 11, 11, 11, 12, 14, 13,...","[-0.04061159908202868, -0.04061159908202868, -...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[13, 13, 13, 13, 13, 13, 20, 20, 20, 20, 20, 2...","[10, 10, 10, 10, 10, 10, 12, 12, 12, 12, 12, 1...",...,"[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[-1.6684984944562833, -1.6684984944562833, -1...."
174150,"[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,...","[8, 8, 8, 18, 18, 18, 18, 18, 18, 9, 9, 40, 40...","[8, 8, 9, 10, 8, 9, 8, 10, 8, 10, 8, 10, 8, 10...","[8, 9, 28, 28, 10, 11, 28, 11, 11, 11, 11, 11,...","[0.12329740047226437, 0.12329740047226437, 0.1...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[13, 13, 13, 13, 13, 13, 13, 13, 13, 20, 20, 2...","[10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 12, 1...",...,"[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[-1.6655188975214106, -1.6655188975214106, -1...."
174159,"[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,...","[8, 8, 8, 8, 18, 18, 9, 9, 38, 38, 40, 40, 40,...","[8, 8, 8, 9, 10, 8, 10, 8, 10, 8, 10, 8, 10, 8...","[8, 9, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11,...","[2.0082508953466345, 2.0082508953466345, 2.008...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[13, 13, 13, 13, 13, 13, 20, 20, 20, 20, 20, 2...","[10, 10, 10, 10, 10, 10, 12, 12, 12, 12, 12, 1...",...,"[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[-1.6650619120686776, -1.6650619120686776, -1...."
174180,"[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,...","[8, 8, 8, 8, 18, 18, 18, 18, 18, 18, 18, 18, 1...","[8, 8, 8, 9, 10, 8, 8, 8, 8, 8, 9, 8, 10, 8, 8...","[8, 9, 10, 11, 11, 12, 14, 13, 15, 16, 17, 11,...","[0.3691608998037039, 0.3691608998037039, 0.369...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1...",...,"[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[-1.6642354545246985, -1.6642354545246985, -1...."


In [55]:
# Train/valid DataLoaders over the trace frames, bundled for the Learner.
data = DataBunch(*get_dls(train_traces, valid_traces))

In [56]:
# Peek at one training batch: (batch, n_cols, bptt).
iter_dl = iter(data.train_dl)
xb,yb = next(iter_dl)
xb.size()

torch.Size([128, 31, 70])

# Basic Model

In [57]:
class BasicModel(nn.Module):
    """Baseline per-position classifier over a single input column.

    Each position of column ``col`` of the input is embedded, passed through a
    Linear-ReLU-Linear stack and mapped to ``n_out`` logits.

    Args:
        n_in: number of embedding rows; must exceed every index in column ``col``.
        n_out: number of output classes.
        n_emb: NOTE(review): currently ignored; the embedding width is fixed at 7.
        nh: hidden layer width.
        col: input column to embed (3 is 'concept:name' in this notebook's trace layout).
    """
    def __init__(self, n_in,n_out,n_emb,nh,col=3):
        super().__init__()
        self.col=col
        # padding_idx=1 corresponds to xxpad, the second special token of each vocabulary.
        self.emb=nn.Embedding(n_in, 7, padding_idx=1)
        self.lin1=nn.Linear(7,nh)
        self.relu=nn.ReLU()
        self.lin2=nn.Linear(nh,n_out)

    def forward(self, x):
        """Map a (batch, n_cols, seq_len) tensor to (batch, seq_len, n_out) logits."""
        # Defining forward (instead of overriding __call__) keeps nn.Module hooks working.
        x=x[:,self.col].long()
        x=self.emb(x)
        x=self.relu(self.lin1(x))
        return self.lin2(x).float()

In [58]:
def getBasicModel():
    """Build a BasicModel predicting the next 'concept:name' (activity) id.

    Input and output sizes both come from the activity vocabulary fitted by the
    notebook-level TraceProcessor ``tp``.
    """
    vocab=len(tp.vocabs['concept:name'].vocab)
    n_emb,nh=vocab//2,10
    # The model embeds activity ids (all < vocab), so the embedding table is sized by
    # the vocabulary rather than by bs*bptt as before.
    return BasicModel(vocab,vocab,n_emb,nh)

In [59]:
len((tp.vocabs['concept:name']).vocab)

32

In [60]:
xb.shape

torch.Size([128, 31, 70])

In [61]:
xb[None,0,:,1:10].shape

torch.Size([1, 31, 9])

In [62]:
# Sanity check: output shape of the untrained model vs. the target shape.
model=getBasicModel()
pred = model(xb)
pred.shape,yb[:,0].shape

(torch.Size([128, 70, 32]), torch.Size([128, 70]))

In [63]:
def cross_entropy_activity(input, target, col=3):
    """Cross-entropy between per-position logits and one column of the target tensor.

    Args:
        input: logits of shape (batch, seq_len, n_classes).
        target: (batch, n_cols, seq_len) tensor; column ``col`` holds class ids
            (3 is 'concept:name' in this notebook's trace layout).
        col: target column to score.

    Returns:
        Scalar mean cross-entropy over all batch positions.
    """
    target=target[:,col]
    # reshape (not view) also accepts non-contiguous logits.
    return F.cross_entropy(input.reshape(-1, input.size(-1)), target.reshape(-1).long())
# Loss of the untrained model on the sample batch.
cross_entropy_activity(pred,yb)

tensor(3.4857, grad_fn=<NllLossBackward>)

In [64]:
def accuracy_activity(input, target, col=3):
    """Fraction of positions whose argmax logit equals the target id.

    Args:
        input: logits of shape (batch, seq_len, n_classes).
        target: (batch, n_cols, seq_len) tensor; column ``col`` holds class ids
            (3 is 'concept:name' in this notebook's trace layout).
        col: target column to score.

    Returns:
        Scalar float tensor in [0, 1].
    """
    target=target[:,col]
    preds=input.reshape(-1, input.size(-1)).argmax(dim=1)
    return (preds==target.reshape(-1).long()).float().mean()
# Accuracy of the untrained model on the sample batch.
accuracy_activity(pred,yb)

tensor(0.0104)

# Training Loop

**Callbacks**

In [65]:
# Learning-rate schedule built from two cosine segments: presumably 0.3 -> 0.6 over the
# first 30% of training, then 0.6 -> 0.2 (confirm against combine_scheds).
sched = combine_scheds([0.3, 0.7], [sched_cos(0.3, 0.6), sched_cos(0.6, 0.2)])

In [66]:
class CudaCallback(Callback):
    """Move the model to the GPU in begin_fit and each batch's inputs/targets in begin_batch."""

    def begin_fit(self):
        self.model.cuda()

    def begin_batch(self):
        xb_gpu, yb_gpu = self.xb.cuda(), self.yb.cuda()
        self.run.xb, self.run.yb = xb_gpu, yb_gpu

In [67]:
# Callbacks: accuracy statistics, GPU transfer, recording, LR scheduling and a progress bar.
cbfs = [partial(AvgStatsCallback,accuracy_activity),
        CudaCallback, 
        Recorder,
        partial(ParamScheduler, 'lr', sched),
        ProgressBarCallback,
       ]

**Model**

In [68]:
# Optimizer factory: Optimizer with plain SGD as its only stepper.
opt_func = partial(Optimizer, steppers=[sgd_step])

In [69]:
# Fresh, untrained model for the training run.
model=getBasicModel()

In [70]:
# NOTE(review): `opt` is not passed to the Learner below (it receives opt_func instead),
# so this optimizer and its lr=0.5 appear unused.
opt = opt_func(model.parameters(), lr=0.5)


**Learner**

In [71]:
# Loss is cross-entropy on the 'concept:name' column; the Learner presumably builds its
# optimizer from opt_func.
learn = Learner(model,data,cross_entropy_activity,cb_funcs=cbfs,opt_func=opt_func)


In [72]:
# NOTE(review): the recorded output below lists only epochs 0-9, so it is stale
# relative to fit(20); re-run before relying on these numbers.
learn.fit(20)

epoch,train_loss,train_accuracy_activity,valid_loss,valid_accuracy_activity,time
0,3.132911,0.251617,2.784872,0.266127,00:00
1,2.571219,0.29841,2.354194,0.403348,00:00
2,2.184912,0.469997,2.008629,0.511217,00:00
3,1.849331,0.538219,1.694882,0.562109,00:00
4,1.576334,0.577052,1.469645,0.580246,00:00
5,1.399247,0.603604,1.337631,0.614118,00:00
6,1.2924,0.615168,1.250861,0.615402,00:00
7,1.21932,0.64063,1.191905,0.648605,00:00
8,1.168093,0.649042,1.149328,0.648605,00:00
9,1.131562,0.65043,1.118194,0.650446,00:00


In [73]:
# Keep the trained model; output shape is (batch, bptt, vocab).
basic_model=learn.model
basic_model(xb.cuda()).shape

torch.Size([128, 70, 32])

# Testing

In [74]:
test_proc

Unnamed: 0,trace_id,event_id,org:resource,lifecycle:transition,concept:name,AMOUNT_REQ,time:timestamp_Year,time:timestamp_Month,time:timestamp_Week,time:timestamp_Day,...,REG_DATE_Day,REG_DATE_Dayofweek,REG_DATE_Dayofyear,REG_DATE_Is_month_end,REG_DATE_Is_month_start,REG_DATE_Is_quarter_end,REG_DATE_Is_quarter_start,REG_DATE_Is_year_end,REG_DATE_Is_year_start,REG_DATE_Elapsed
165,173712,8,8,8,8,1.188706,8,9,8,9,...,9,9,9,9,9,9,9,8,8,-1.717258
166,173712,9,8,8,9,1.188706,8,9,8,9,...,9,9,9,9,9,9,9,8,8,-1.717258
167,173712,10,8,9,28,1.188706,8,9,8,9,...,9,9,9,9,9,9,9,8,8,-1.717258
168,173712,11,16,10,28,1.188706,8,9,8,9,...,9,9,9,9,9,9,9,8,8,-1.717258
169,173712,12,16,8,10,1.188706,8,9,8,9,...,9,9,9,9,9,9,9,8,8,-1.717258
170,173712,13,16,9,11,1.188706,8,9,8,9,...,9,9,9,9,9,9,9,8,8,-1.717258
171,173712,14,16,8,28,1.188706,8,9,8,9,...,9,9,9,9,9,9,9,8,8,-1.717258
172,173712,15,39,10,11,1.188706,8,9,8,9,...,9,9,9,9,9,9,9,8,8,-1.717258
173,173712,16,39,8,11,1.188706,8,9,8,9,...,9,9,9,9,9,9,9,8,8,-1.717258
174,173712,17,36,10,11,1.188706,8,9,9,13,...,9,9,9,9,9,9,9,8,8,-1.717258


In [75]:
# Group the processed test events into one row per trace.
test_traces=create_traces(test_proc)


In [76]:
test_traces.iloc[0:1]

Unnamed: 0,event_id,org:resource,lifecycle:transition,concept:name,AMOUNT_REQ,time:timestamp_Year,time:timestamp_Month,time:timestamp_Week,time:timestamp_Day,time:timestamp_Dayofweek,...,REG_DATE_Day,REG_DATE_Dayofweek,REG_DATE_Dayofyear,REG_DATE_Is_month_end,REG_DATE_Is_month_start,REG_DATE_Is_quarter_end,REG_DATE_Is_quarter_start,REG_DATE_Is_year_end,REG_DATE_Is_year_start,REG_DATE_Elapsed
173712,"[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,...","[8, 8, 8, 16, 16, 16, 16, 39, 39, 36, 36, 36, ...","[8, 8, 9, 10, 8, 9, 8, 10, 8, 10, 8, 10, 8, 8]","[8, 9, 28, 28, 10, 11, 28, 11, 11, 11, 11, 11,...","[1.1887058975751692, 1.1887058975751692, 1.188...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 13, 13, 13, 13, 13]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10]",...,"[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]","[-1.7172581949747072, -1.7172581949747072, -1...."


## Next Step Prediction

In [77]:
def process_data_for_next_step_prediction(test,input_cols=None,output_col=3,startIndex=1):
    """Expand each trace into (prefix, next value) pairs.

    For every trace and every cut point ``i`` in ``[startIndex, trace length)``, the
    prefix holds the first ``i`` values of each input column, and the label is the
    value of column ``output_col`` at position ``i``. Trace length is taken from the
    first column.

    Args:
        test: trace DataFrame (one row per trace, list-valued cells).
        input_cols: column name(s) to keep in the prefixes; defaults to all columns.
            Entries that are not column labels are used as column positions.
        output_col: position of the label column (3 is 'concept:name' here).
        startIndex: shortest prefix length to emit.

    Returns:
        (DataFrame of prefixes with columns ``input_cols``, list of labels).
    """
    all_cols=list(test)
    if input_cols is None: input_cols=all_cols
    input_cols=listify(input_cols)
    # Resolve names to positions so that a subset selects those columns' data,
    # not simply the first len(input_cols) columns.
    positions=[all_cols.index(c) if c in all_cols else c for c in input_cols]
    xs,ys=[],[]
    for trace in test.values:
        for i in range(startIndex,len(listify(trace[0]))):
            xs.append([trace[p][:i] for p in positions])
            ys.append(trace[output_col][i])
    return pd.DataFrame(xs,columns=input_cols),ys


In [81]:
# NOTE(review): `x` is first assigned in the next-step evaluation cell further down
# (In [80]); on Restart & Run All this cell raises NameError. Move it after that cell.
x

Unnamed: 0,event_id,org:resource,lifecycle:transition,concept:name,AMOUNT_REQ,time:timestamp_Year,time:timestamp_Month,time:timestamp_Week,time:timestamp_Day,time:timestamp_Dayofweek,...,REG_DATE_Day,REG_DATE_Dayofweek,REG_DATE_Dayofyear,REG_DATE_Is_month_end,REG_DATE_Is_month_start,REG_DATE_Is_quarter_end,REG_DATE_Is_quarter_start,REG_DATE_Is_year_end,REG_DATE_Is_year_start,REG_DATE_Elapsed
0,[8],[8],[8],[8],[1.1887058975751692],[8],[9],[8],[9],[9],...,[9],[9],[9],[9],[9],[9],[9],[8],[8],[-1.7172581949747072]
1,"[8, 9]","[8, 8]","[8, 8]","[8, 9]","[1.1887058975751692, 1.1887058975751692]","[8, 8]","[9, 9]","[8, 8]","[9, 9]","[9, 9]",...,"[9, 9]","[9, 9]","[9, 9]","[9, 9]","[9, 9]","[9, 9]","[9, 9]","[8, 8]","[8, 8]","[-1.7172581949747072, -1.7172581949747072]"
2,"[8, 9, 10]","[8, 8, 8]","[8, 8, 9]","[8, 9, 28]","[1.1887058975751692, 1.1887058975751692, 1.188...","[8, 8, 8]","[9, 9, 9]","[8, 8, 8]","[9, 9, 9]","[9, 9, 9]",...,"[9, 9, 9]","[9, 9, 9]","[9, 9, 9]","[9, 9, 9]","[9, 9, 9]","[9, 9, 9]","[9, 9, 9]","[8, 8, 8]","[8, 8, 8]","[-1.7172581949747072, -1.7172581949747072, -1...."
3,"[8, 9, 10, 11]","[8, 8, 8, 16]","[8, 8, 9, 10]","[8, 9, 28, 28]","[1.1887058975751692, 1.1887058975751692, 1.188...","[8, 8, 8, 8]","[9, 9, 9, 9]","[8, 8, 8, 8]","[9, 9, 9, 9]","[9, 9, 9, 9]",...,"[9, 9, 9, 9]","[9, 9, 9, 9]","[9, 9, 9, 9]","[9, 9, 9, 9]","[9, 9, 9, 9]","[9, 9, 9, 9]","[9, 9, 9, 9]","[8, 8, 8, 8]","[8, 8, 8, 8]","[-1.7172581949747072, -1.7172581949747072, -1...."
4,"[8, 9, 10, 11, 12]","[8, 8, 8, 16, 16]","[8, 8, 9, 10, 8]","[8, 9, 28, 28, 10]","[1.1887058975751692, 1.1887058975751692, 1.188...","[8, 8, 8, 8, 8]","[9, 9, 9, 9, 9]","[8, 8, 8, 8, 8]","[9, 9, 9, 9, 9]","[9, 9, 9, 9, 9]",...,"[9, 9, 9, 9, 9]","[9, 9, 9, 9, 9]","[9, 9, 9, 9, 9]","[9, 9, 9, 9, 9]","[9, 9, 9, 9, 9]","[9, 9, 9, 9, 9]","[9, 9, 9, 9, 9]","[8, 8, 8, 8, 8]","[8, 8, 8, 8, 8]","[-1.7172581949747072, -1.7172581949747072, -1...."
5,"[8, 9, 10, 11, 12, 13]","[8, 8, 8, 16, 16, 16]","[8, 8, 9, 10, 8, 9]","[8, 9, 28, 28, 10, 11]","[1.1887058975751692, 1.1887058975751692, 1.188...","[8, 8, 8, 8, 8, 8]","[9, 9, 9, 9, 9, 9]","[8, 8, 8, 8, 8, 8]","[9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9]",...,"[9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9]","[8, 8, 8, 8, 8, 8]","[8, 8, 8, 8, 8, 8]","[-1.7172581949747072, -1.7172581949747072, -1...."
6,"[8, 9, 10, 11, 12, 13, 14]","[8, 8, 8, 16, 16, 16, 16]","[8, 8, 9, 10, 8, 9, 8]","[8, 9, 28, 28, 10, 11, 28]","[1.1887058975751692, 1.1887058975751692, 1.188...","[8, 8, 8, 8, 8, 8, 8]","[9, 9, 9, 9, 9, 9, 9]","[8, 8, 8, 8, 8, 8, 8]","[9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9]",...,"[9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9]","[8, 8, 8, 8, 8, 8, 8]","[8, 8, 8, 8, 8, 8, 8]","[-1.7172581949747072, -1.7172581949747072, -1...."
7,"[8, 9, 10, 11, 12, 13, 14, 15]","[8, 8, 8, 16, 16, 16, 16, 39]","[8, 8, 9, 10, 8, 9, 8, 10]","[8, 9, 28, 28, 10, 11, 28, 11]","[1.1887058975751692, 1.1887058975751692, 1.188...","[8, 8, 8, 8, 8, 8, 8, 8]","[9, 9, 9, 9, 9, 9, 9, 9]","[8, 8, 8, 8, 8, 8, 8, 8]","[9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9]",...,"[9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9]","[8, 8, 8, 8, 8, 8, 8, 8]","[8, 8, 8, 8, 8, 8, 8, 8]","[-1.7172581949747072, -1.7172581949747072, -1...."
8,"[8, 9, 10, 11, 12, 13, 14, 15, 16]","[8, 8, 8, 16, 16, 16, 16, 39, 39]","[8, 8, 9, 10, 8, 9, 8, 10, 8]","[8, 9, 28, 28, 10, 11, 28, 11, 11]","[1.1887058975751692, 1.1887058975751692, 1.188...","[8, 8, 8, 8, 8, 8, 8, 8, 8]","[9, 9, 9, 9, 9, 9, 9, 9, 9]","[8, 8, 8, 8, 8, 8, 8, 8, 8]","[9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9]",...,"[9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9]","[8, 8, 8, 8, 8, 8, 8, 8, 8]","[8, 8, 8, 8, 8, 8, 8, 8, 8]","[-1.7172581949747072, -1.7172581949747072, -1...."
9,"[8, 9, 10, 11, 12, 13, 14, 15, 16, 17]","[8, 8, 8, 16, 16, 16, 16, 39, 39, 36]","[8, 8, 9, 10, 8, 9, 8, 10, 8, 10]","[8, 9, 28, 28, 10, 11, 28, 11, 11, 11]","[1.1887058975751692, 1.1887058975751692, 1.188...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[8, 8, 8, 8, 8, 8, 8, 8, 8, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 13]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 10]",...,"[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8]","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8]","[-1.7172581949747072, -1.7172581949747072, -1...."


In [78]:
def predict_next_step(model,df):
    """Predict the next class for every prefix in ``df``.

    Args:
        model: module called on a (1, n_cols, prefix_len) float tensor; ``[0][-1]``
            of its output is taken as the logits for the next position.
        df: one row per prefix, one list-valued cell per column (as produced by
            process_data_for_next_step_prediction).

    Returns:
        numpy array with the argmax of those logits for each prefix (empty if
        ``df`` has no rows).

    Side effects: puts ``model`` in eval mode and moves it to the CPU.
    """
    model.eval()
    model.cpu()
    n_cols=len(list(df))
    preds=[]
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for e in df.values:
            t=torch.stack([tensor(e[c]).float() for c in range(n_cols)])
            preds.append(model(t[None])[0][-1].tolist())
    if not preds: return np.array([],dtype=np.int64)
    return np.argmax(np.array(preds),axis=1)



In [79]:
def next_step_measure(preds,ys):
    """Plain accuracy: fraction of positions where the prediction equals the label.

    NOTE(review): the result is unweighted; whether it should be weighted (e.g. per
    trace) still needs to be checked against the reference paper.
    """
    hits = np.asarray(preds) == np.asarray(ys)
    return hits.mean()

In [80]:
# Build every (prefix, next activity) pair from the test traces and score the
# model's argmax predictions with plain accuracy.
x,y=process_data_for_next_step_prediction(test_traces)
preds=predict_next_step(basic_model,x)
next_step_measure(preds,y)

0.6476051902493473

## Suffix Prediction

In [None]:
def process_data_for_suffix_prediction(test,input_cols=None,output_col=3,startIndex=1):
    """Expand each trace into (prefix, remaining suffix) pairs.

    For every trace and every cut point ``i`` in ``[startIndex, trace length)``, the
    prefix holds the first ``i`` values of each input column, and the label is the
    list of values of column ``output_col`` from position ``i`` to the end. Trace
    length is taken from the first column.

    Args:
        test: trace DataFrame (one row per trace, list-valued cells).
        input_cols: column name(s) to keep in the prefixes; defaults to all columns.
            Entries that are not column labels are used as column positions.
        output_col: position of the label column (3 is 'concept:name' here).
        startIndex: shortest prefix length to emit.

    Returns:
        (DataFrame of prefixes with columns ``input_cols``, list of suffix lists).
    """
    all_cols=list(test)
    if input_cols is None: input_cols=all_cols
    input_cols=listify(input_cols)
    # Resolve names to positions so that a subset selects those columns' data,
    # not simply the first len(input_cols) columns.
    positions=[all_cols.index(c) if c in all_cols else c for c in input_cols]
    xs,ys=[],[]
    for trace in test.values:
        for i in range(startIndex,len(listify(trace[0]))):
            xs.append([trace[p][:i] for p in positions])
            ys.append(trace[output_col][i:])
    return pd.DataFrame(xs,columns=input_cols),ys

In [None]:
# Every (prefix, remaining activities) pair from the test traces.
x,y=process_data_for_suffix_prediction(test_traces)

In [None]:
def predict_suffix(model,df,col=3,eos=3,max_len=1000):
    """Sample a continuation of column ``col`` for every prefix in ``df``.

    Starting from each prefix, an index is sampled with torch.multinomial from the
    softmax of the model's output at the last position and appended. This repeats
    until ``eos`` is drawn or ``max_len`` steps have been taken. Sampling is
    stochastic, so results differ between runs unless torch's RNG is seeded.

    Args:
        model: module called on a (1, n_cols, length) float tensor; ``[0][-1]`` of
            its output is taken as the logits for the next position.
        df: one row per prefix, one list-valued cell per column.
        col: column that receives the sampled values (3 is 'concept:name' here).
        eos: index that ends a suffix (3 is xxeos in default_spec_tok order).
        max_len: maximum sampled steps per prefix; ``None`` means no limit.

    Returns:
        list of int lists, one per prefix. ``eos`` itself is kept only when it is
        the very first sample, so no suffix is empty when ``max_len >= 1``.

    Side effects: puts ``model`` in eval mode and moves it to the CPU.

    NOTE(review): after the first sampled step every column is replaced by the
    sampled ``col`` sequence, so only models that read column ``col`` alone (like
    BasicModel) see meaningful inputs.
    """
    model.eval()
    model.cpu()
    n_cols=len(list(df))
    rl=[]
    with torch.no_grad():
        for x in progress_bar(df.values):
            t=torch.stack([tensor(x[c]).float() for c in range(n_cols)])
            res=[]
            steps=0
            # Bounded loop: sampling may never draw eos.
            while max_len is None or steps<max_len:
                steps+=1
                pred=model(t[None])[0][-1]
                p=torch.multinomial(torch.softmax(pred,0),1).float()
                is_eos=int(p.item())==eos
                # Keep a leading eos so the suffix is never empty; drop it otherwise.
                if not is_eos or len(res)==0: res.append(p)
                if is_eos: break
                k=torch.cat((t[col],p))
                t=torch.stack([k for _ in range(n_cols)])
            rl.append(torch.cat(res,0).int().tolist() if res else [])
    return rl

In [None]:
def suffix_measure(preds,ys):
    """Compare predicted and true suffixes with the edit distance (``ed.eval``).

    Args:
        preds: predicted suffixes (sequences of ids).
        ys: true suffixes, paired with ``preds`` by position.

    Returns:
        (mean, min, max) of the edit distances, and the mean normalised similarity
        ``1 - d / max(len(pred), len(true))``.
    """
    edits=[]
    sim=[]
    for p,y in zip(preds,ys):
        d=ed.eval(p,y)
        l=max(len(p),len(y))
        edits.append(d)
        # Two empty sequences are identical; avoid 0/0.
        sim.append(1-d/l if l else 1.0)
    edits=np.array(edits)
    return edits.mean(),edits.min(),edits.max(),np.array(sim).mean()


In [None]:
# NOTE: suffixes are sampled (torch.multinomial), so the scores below vary between runs.
preds=predict_suffix(basic_model,x)

In [None]:
# Mean / min / max edit distance between predicted and true suffixes, and the mean
# normalised similarity.
mean_edit,min_edit,max_edit,sim=suffix_measure(preds,y)
mean_edit,min_edit,max_edit,sim