# Notebook

## Setup

In [2]:
%load_ext kedro.ipython

The kedro.ipython extension is already loaded. To reload it, use:
  %reload_ext kedro.ipython


In [3]:
from itertools import product

import numpy as np
import pandas as pd
import torch
from pomegranate.distributions import Categorical
from pomegranate.hmm import DenseHMM, SparseHMM
from torch.masked import as_masked_tensor
from torch.nn.utils.rnn import pad_sequence

# np.set_printoptions(threshold=np.inf, linewidth=150, precision=3)

## Load Datasets

In [4]:
# `catalog` is provided by the kedro.ipython extension loaded in the setup cell.

# Cleaned source datasets
alarms = catalog.load("clean_alarms")
status = catalog.load("clean_station_overviews")

# Serial-number mapping
sn_map = catalog.load("map_serial_number")

# Domain-level datasets used below to build the observation space and training sequences
domain_alarms, domain_status, domain_observations = (
    catalog.load(dataset_name)
    for dataset_name in ("domain_alarms", "domain_status", "domain_observations")
)

## Construct $\mathcal{S}$, $\mathcal{A}$, $R$, and $\mathcal{O}$

### Define State Space $\mathcal{S}$

In [5]:
# State space S: every combination (f, c, e) of the three component labels below.
# "⊥" in E marks the non-event case.
S_F = pd.Series(["A", "F", "N"], name="f")
S_C = pd.Series(["R", "U"], name="c")
E = pd.Series(["FR", "FC", "PU", "PD", "NR", "NL", "OE", "⊥"], name="e")

# Cartesian product with f varying slowest and e fastest: 3 * 2 * 8 = 48 states,
# indexed 0..47 (the row position is the state id used by R, T_0, P_s, ...).
f, c, e = zip(*product(S_F, S_C, E))
S = pd.DataFrame({"f": f, "c": c, "e": e})
S

Unnamed: 0,f,c,e
0,A,R,FR
1,A,R,FC
2,A,R,PU
3,A,R,PD
4,A,R,NR
5,A,R,NL
6,A,R,OE
7,A,R,⊥
8,A,U,FR
9,A,U,FC


### Define Action Space $\mathcal{A}$

In [6]:
# Action space: the positional index (0 = do nothing, 1 = alert) is the column index of R.
A = pd.Series(data=["Do Nothing", "Alert Operator"], name="action")
A


[1;36m0[0m        Do Nothing
[1;36m1[0m    Alert Operator
Name: action, dtype: object

### Define Reward Function $R$

In [7]:
# Dysfunctional states: rows of S whose f-component is "F" or whose c-component is "U".
D = S[S["f"].eq("F") | S["c"].eq("U")]
D

Unnamed: 0,f,c,e
8,A,U,FR
9,A,U,FC
10,A,U,PU
11,A,U,PD
12,A,U,NR
13,A,U,NL
14,A,U,OE
15,A,U,⊥
16,F,R,FR
17,F,R,FC


In [8]:
# Reward function R[s, a]: rows are state positions in S, columns are action positions in A.
# Rewards depend only on whether the state is dysfunctional (in D) and on the action.
# Any (state, action) combination not listed here, e.g. doing nothing in a
# functional state, gets reward 0.
REWARD_RULES = {
    # (state is dysfunctional, action position): reward
    (True, 1): 10,    # alerting the operator in a dysfunctional state
    (False, 1): -10,  # alerting the operator in a functional state
    (True, 0): -20,   # not alerting the operator in a dysfunctional state
}
R = np.array(
    [[REWARD_RULES.get((s in D.index, a), 0.0) for a in range(len(A))] for s in range(len(S))],
    dtype=float,
).reshape(len(S), len(A))
R


[1;35marray[0m[1m([0m[1m[[0m[1m[[0m  [1;36m0[0m., [1;36m-10[0m.[1m][0m,
       [1m[[0m  [1;36m0[0m., [1;36m-10[0m.[1m][0m,
       [1m[[0m  [1;36m0[0m., [1;36m-10[0m.[1m][0m,
       [1m[[0m  [1;36m0[0m., [1;36m-10[0m.[1m][0m,
       [1m[[0m  [1;36m0[0m., [1;36m-10[0m.[1m][0m,
       [1m[[0m  [1;36m0[0m., [1;36m-10[0m.[1m][0m,
       [1m[[0m  [1;36m0[0m., [1;36m-10[0m.[1m][0m,
       [1m[[0m  [1;36m0[0m., [1;36m-10[0m.[1m][0m,
       [1m[[0m[1;36m-20[0m.,  [1;36m10[0m.[1m][0m,
       [1m[[0m[1;36m-20[0m.,  [1;36m10[0m.[1m][0m,
       [1m[[0m[1;36m-20[0m.,  [1;36m10[0m.[1m][0m,
       [1m[[0m[1;36m-20[0m.,  [1;36m10[0m.[1m][0m,
       [1m[[0m[1;36m-20[0m.,  [1;36m10[0m.[1m][0m,
       [1m[[0m[1;36m-20[0m.,  [1;36m10[0m.[1m][0m,
       [1m[[0m[1;36m-20[0m.,  [1;36m10[0m.[1m][0m,
       [1m[[0m[1;36m-20[0m.,  [1;36m10[0m.[1m][0m,
       [1m[[0m[1;36m-20[

### Define Observation Space $\mathcal{O}$

In [9]:
# Status observations: distinct values of domain_status.state, sorted for a stable order.
# (The displayed values appear to be stringified (station status, network status) pairs.)
O_S = sorted(domain_status["state"].unique())
display(O_S, len(O_S))


[1m[[0m
    [32m"[0m[32m([0m[32m'Available', 'Reachable'[0m[32m)[0m[32m"[0m,
    [32m"[0m[32m([0m[32m'Faulted', 'Reachable'[0m[32m)[0m[32m"[0m,
    [32m"[0m[32m([0m[32m'Unavailable', 'Reachable'[0m[32m)[0m[32m"[0m,
    [32m"[0m[32m([0m[32m'Unavailable', '⊥'[0m[32m)[0m[32m"[0m,
    [32m"[0m[32m([0m[32m'Unreachable', 'Reachable'[0m[32m)[0m[32m"[0m,
    [32m"[0m[32m([0m[32m'Unreachable', 'Unreachable'[0m[32m)[0m[32m"[0m
[1m][0m

[1;36m6[0m

In [10]:
# Log observations: distinct alarm messages from domain_alarms, sorted for a stable order.
O_L = sorted(domain_alarms["alarm"].unique())
display(O_L, len(O_L))


[1m[[0m
    [32m'Boot up'[0m,
    [32m'Bootup Due to POWER ON'[0m,
    [32m'Bootup Due to SOFT RESET'[0m,
    [32m'Bootup Due to SWITCH'[0m,
    [32m'Bootup Due to WATCHDOG'[0m,
    [32m'Circuit Sharing Current Reduced'[0m,
    [32m'Circuit Sharing Current Restored'[0m,
    [32m'Data Partition Full'[0m,
    [32m'Earth Fault Station In Service'[0m,
    [32m'Earth Fault Station Out Of Service'[0m,
    [32m'Fault Cleared'[0m,
    [32m'GFCI Hard Trip'[0m,
    [32m'IP Mismatch Detected'[0m,
    [32m'Maintenance Required'[0m,
    [32m'Manual Intervention'[0m,
    [32m'Over Current Hard Trip Detected'[0m,
    [32m'Pilot Unreachable [0m[32m([0m[32m18[0m[32m)[0m[32m'[0m,
    [32m'Pilot current level exceeded'[0m,
    [32m'Powered Off'[0m,
    [32m'RFID Update Failed'[0m,
    [32m'Reachable'[0m,
    [32m'Relay Stuck Close'[0m,
    [32m'Station Not Activated'[0m,
    [32m'Unknown RFID'[0m,
    [32m'Unreachable'[0m
[1m][0m

[1;36m25[0m

In [11]:
# Observation space: status observations first, then log observations. P_o, P_s_o and
# P_o_s all use this positional order.
# The categories are given explicitly so that `.cat.codes` of any series cast to
# O.dtype equals the value's position in O. With dtype="category", pandas would sort
# the categories lexicographically, and that only matches this order by accident.
# Duplicate labels across O_S and O_L raise a ValueError here instead of silently
# merging two observation columns.
O = pd.Series(O_S + O_L, dtype=pd.CategoricalDtype(categories=O_S + O_L), name="O")
O


[1;36m0[0m             [1m([0m[32m'Available'[0m, [32m'Reachable'[0m[1m)[0m
[1;36m1[0m               [1m([0m[32m'Faulted'[0m, [32m'Reachable'[0m[1m)[0m
[1;36m2[0m           [1m([0m[32m'Unavailable'[0m, [32m'Reachable'[0m[1m)[0m
[1;36m3[0m                   [1m([0m[32m'Unavailable'[0m, [32m'⊥'[0m[1m)[0m
[1;36m4[0m           [1m([0m[32m'Unreachable'[0m, [32m'Reachable'[0m[1m)[0m
[1;36m5[0m         [1m([0m[32m'Unreachable'[0m, [32m'Unreachable'[0m[1m)[0m
[1;36m6[0m                                Boot up
[1;36m7[0m                 Bootup Due to POWER ON
[1;36m8[0m               Bootup Due to SOFT RESET
[1;36m9[0m                   Bootup Due to SWITCH
[1;36m10[0m                Bootup Due to WATCHDOG
[1;36m11[0m       Circuit Sharing Current Reduced
[1;36m12[0m      Circuit Sharing Current Restored
[1;36m13[0m                   Data Partition Full
[1;36m14[0m        Earth Fault Station In Service
[1;36m15[0m  

## State Transition and Observation Emission Dynamics

### Initial Estimation

#### Transition Dynamics $T$

In [12]:
# Non-event states: rows of S whose event component is "⊥".
S_bot = S[S["e"].eq("⊥")]
S_bot

Unnamed: 0,f,c,e
7,A,R,⊥
15,A,U,⊥
23,F,R,⊥
31,F,U,⊥
39,N,R,⊥
47,N,U,⊥


In [13]:
# Transition Dynamics T: initial row-stochastic estimate T_0[s, s_next].
#   * non-event state -> itself or any event state: allowed
#   * non-event state -> a different non-event state: forbidden
#   * event state     -> any non-event state: allowed
#   * event state     -> any event state: forbidden
# All allowed successors of a state are equally likely.
is_bot = np.isin(np.arange(len(S)), S_bot.index)
forbidden = (np.outer(is_bot, is_bot) & ~np.eye(len(S), dtype=bool)) | np.outer(~is_bot, ~is_bot)
T_0 = np.where(forbidden, 0.0, 1.0)
T_0 /= T_0.sum(axis=1, keepdims=True)

with np.printoptions(threshold=np.inf, linewidth=150, precision=3):
    display(T_0)


[1;35marray[0m[1m([0m[1m[[0m[1m[[0m[1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0.167[0m, [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0.167[0m, [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   ,
        [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0.167[0m, [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0.167[0m, [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0.167[0m,
        [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.   , [1;36m0.167[0m[1m][0m,
       [1m[[0m[1;36m0[0m.   , [1;36m0[0m.   , [1;36m0[0m.  

#### Observation Emission Dynamics $Z$

In [14]:
# P(s): prior over hidden states. 90% of the mass is spread evenly over the
# non-event states and 10% evenly over the event states.
P_s = np.array(
    [
        0.9 / len(S_bot) if s in S_bot.index else 0.1 / (len(S) - len(S_bot))
        for s in range(len(S))
    ],
    dtype=float,
)
P_s


[1;35marray[0m[1m([0m[1m[[0m[1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m,
       [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.15[0m      , [1;36m0.00238095[0m, [1;36m0.00238095[0m,
       [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m,
       [1;36m0.15[0m      , [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m,
       [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.15[0m      , [1;36m0.00238095[0m,
       [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m,
       [1;36m0.00238095[0m, [1;36m0.15[0m      , [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m,
       [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.00238095[0m, [1;36m0.15[0m    

In [15]:
# P(o): prior over observations, in O order (status observations, then log observations).
# Status observations share 90% of the mass evenly and log observations share 10%.
P_o = np.concatenate(
    (
        np.full(len(O_S), 0.9) / len(O_S),
        np.full(len(O_L), 0.1) / len(O_L),
    )
)
P_o


[1;35marray[0m[1m([0m[1m[[0m[1;36m0.15[0m , [1;36m0.15[0m , [1;36m0.15[0m , [1;36m0.15[0m , [1;36m0.15[0m , [1;36m0.15[0m , [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m,
       [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m,
       [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m,
       [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m, [1;36m0.004[0m[1m][0m[1m)[0m

In [16]:
# P(s | o): hand-specified from domain knowledge in P_s_o.csv.
# Layout: one row per observation (O order), one column per state (S order).
# Each row is normalised here into a distribution over states.
P_s_o = np.loadtxt("P_s_o.csv", delimiter=",", ndmin=2)
if P_s_o.shape != (len(O), len(S)):
    raise ValueError(
        f"P_s_o.csv must have shape {(len(O), len(S))} (|O| x |S|), got {P_s_o.shape}"
    )

# An all-zero or non-finite row would become NaN after normalisation and then
# propagate silently into P(o | s) and the HMM emission distributions.
P_s_o_row_mass = P_s_o.sum(axis=1, keepdims=True)
bad_rows = np.flatnonzero(~np.isfinite(P_s_o_row_mass[:, 0]) | (P_s_o_row_mass[:, 0] <= 0))
if bad_rows.size:
    raise ValueError(f"P_s_o.csv rows with no usable probability mass: {bad_rows.tolist()}")
P_s_o /= P_s_o_row_mass

P_s_o


[1;35marray[0m[1m([0m[1m[[0m[1m[[0m[1;36m0[0m.        , [1;36m0[0m.        , [1;36m0[0m.        , [33m...[0m, [1;36m0[0m.        , [1;36m0[0m.        ,
        [1;36m0[0m.        [1m][0m,
       [1m[[0m[1;36m0[0m.        , [1;36m0[0m.        , [1;36m0[0m.        , [33m...[0m, [1;36m0[0m.        , [1;36m0[0m.        ,
        [1;36m0[0m.        [1m][0m,
       [1m[[0m[1;36m0[0m.        , [1;36m0[0m.        , [1;36m0[0m.        , [33m...[0m, [1;36m0[0m.        , [1;36m0[0m.        ,
        [1;36m0[0m.        [1m][0m,
       [33m...[0m,
       [1m[[0m[1;36m0.02380952[0m, [1;36m0.02380952[0m, [1;36m0.02380952[0m, [33m...[0m, [1;36m0.02380952[0m, [1;36m0.02380952[0m,
        [1;36m0[0m.        [1m][0m,
       [1m[[0m[1;36m0.02380952[0m, [1;36m0.02380952[0m, [1;36m0.02380952[0m, [33m...[0m, [1;36m0.02380952[0m, [1;36m0.02380952[0m,
        [1;36m0[0m.        [1m][0m,
       [1m[[0m[1;36m0[

In [17]:
# P(o | s) = P(s | o) * P(o) / P(s)
# Unnormalized. Bayes' rule is applied element-wise in one vectorised step, giving an
# (|S|, |O|) array whose [s, x] entry is P_s_o[x, s] * P_o[x] / P_s[s]. The inputs are
# sliced to |O| and |S| so exactly the same entries are read as in an explicit index loop.
P_o_s_unnormalized = (
    P_s_o[: len(O), : len(S)].T * P_o[np.newaxis, : len(O)] / P_s[: len(S), np.newaxis]
)
P_o_s_unnormalized


[1;35marray[0m[1m([0m[1m[[0m[1m[[0m[1;36m0[0m.  , [1;36m0[0m.  , [1;36m0[0m.  , [33m...[0m, [1;36m0.04[0m, [1;36m0.04[0m, [1;36m0[0m.  [1m][0m,
       [1m[[0m[1;36m0[0m.  , [1;36m0[0m.  , [1;36m0[0m.  , [33m...[0m, [1;36m0.04[0m, [1;36m0.04[0m, [1;36m0[0m.  [1m][0m,
       [1m[[0m[1;36m0[0m.  , [1;36m0[0m.  , [1;36m0[0m.  , [33m...[0m, [1;36m0.04[0m, [1;36m0.04[0m, [1;36m0[0m.  [1m][0m,
       [33m...[0m,
       [1m[[0m[1;36m0[0m.  , [1;36m0[0m.  , [1;36m0[0m.  , [33m...[0m, [1;36m0.04[0m, [1;36m0.04[0m, [1;36m0.28[0m[1m][0m,
       [1m[[0m[1;36m0[0m.  , [1;36m0[0m.  , [1;36m0[0m.  , [33m...[0m, [1;36m0.04[0m, [1;36m0.04[0m, [1;36m0[0m.  [1m][0m,
       [1m[[0m[1;36m0[0m.  , [1;36m0[0m.  , [1;36m0[0m.  , [33m...[0m, [1;36m0[0m.  , [1;36m0[0m.  , [1;36m0[0m.  [1m][0m[1m][0m, [33mshape[0m=[1m([0m[1;36m48[0m, [1;36m31[0m[1m)[0m[1m)[0m

In [18]:
# Normalize P(o | s): rescale each state's row into a distribution over observations.
# A state with no observation mass would otherwise yield a NaN row. The later
# zero-smoothing does not repair NaN, so it would reach the HMM emissions.
emission_mass = P_o_s_unnormalized.sum(axis=1, keepdims=True)
empty_states = np.flatnonzero(~(emission_mass[:, 0] > 0))  # also catches NaN sums
if empty_states.size:
    raise ValueError(f"P(o | s) has no probability mass for state(s) {empty_states.tolist()}")
P_o_s = P_o_s_unnormalized / emission_mass
P_o_s


[1;35marray[0m[1m([0m[1m[[0m[1m[[0m[1;36m0[0m.        , [1;36m0[0m.        , [1;36m0[0m.        , [33m...[0m, [1;36m0.01351351[0m, [1;36m0.01351351[0m,
        [1;36m0[0m.        [1m][0m,
       [1m[[0m[1;36m0[0m.        , [1;36m0[0m.        , [1;36m0[0m.        , [33m...[0m, [1;36m0.09090909[0m, [1;36m0.09090909[0m,
        [1;36m0[0m.        [1m][0m,
       [1m[[0m[1;36m0[0m.        , [1;36m0[0m.        , [1;36m0[0m.        , [33m...[0m, [1;36m0.05555556[0m, [1;36m0.05555556[0m,
        [1;36m0[0m.        [1m][0m,
       [33m...[0m,
       [1m[[0m[1;36m0[0m.        , [1;36m0[0m.        , [1;36m0[0m.        , [33m...[0m, [1;36m0.09090909[0m, [1;36m0.09090909[0m,
        [1;36m0.63636364[0m[1m][0m,
       [1m[[0m[1;36m0[0m.        , [1;36m0[0m.        , [1;36m0[0m.        , [33m...[0m, [1;36m0.25[0m      , [1;36m0.25[0m      ,
        [1;36m0[0m.        [1m][0m,
       [1m[[0m[1;36m0[

In [29]:
import pickle

# Persist the hand-specified initial estimates (T_0, P(o | s)) before data-driven refinement.
# The context manager ensures the file is flushed and closed even if pickling fails.
with open("initial-estimates.pkl", "wb") as fh:
    pickle.dump((T_0, P_o_s), fh)

### Data-Driven Refinement

In [22]:
N = len(S)  # number of hidden states
K = len(O)  # number of observation symbols

torch_device = torch.device("cpu")
# torch_device = torch.device("cuda")

# Step 1: Create categorical distributions from P_o_s.
# Z[s, o] = P(o | s) as a float32 tensor; rows are states, columns are observations.
Z = torch.tensor(P_o_s, dtype=torch.float)
if Z.shape != (N, K):
    raise ValueError(f"expected P_o_s of shape {(N, K)}, got {tuple(Z.shape)}")

# Floor exact zeros at a small value so every observation keeps non-zero likelihood
# under every state (under EM re-estimation a zero emission can never become
# non-zero), then renormalise each row. This is done out-of-place, so Z keeps the
# values of P_o_s.
Z_floored = torch.where(Z == 0, torch.full_like(Z, 1e-6), Z)
Z_floored = Z_floored / Z_floored.sum(dim=1, keepdim=True)

# One categorical emission distribution per state. Each gets its own (1, K) copy of
# its row, so the distributions do not share storage.
dists = [
    Categorical(probs=Z_floored[s].reshape(1, K).clone()).to(torch_device)
    for s in range(N)
]
[d.probs.round(decimals=3) for d in dists]


[1m[[0m
    [1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m[1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m,
         [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0950[0m, [1;36m0.0950[0m, [1;36m0.0950[0m, [1;36m0.0000[0m, [1;36m0.0950[0m,
         [1;36m0.0140[0m, [1;36m0.0950[0m, [1;36m0.0140[0m, [1;36m0.0950[0m, [1;36m0.0950[0m, [1;36m0.0950[0m, [1;36m0.0000[0m, [1;36m0.0950[0m, [1;36m0.0000[0m,
         [1;36m0.0950[0m, [1;36m0.0140[0m, [1;36m0.0140[0m, [1;36m0.0000[0m[1m][0m[1m][0m[1m)[0m,
    [1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m[1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m,
         [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000[0m, [1;36m0.0000

In [23]:
# Prepare Observation Sequences
# Encode each observation by its code in O's categorical dtype.
domain_observations["o"] = domain_observations.observation.astype(O.dtype).cat.codes

# Values not present in O (or missing) are encoded as -1, which is also the padding
# sentinel used below. Such observations would be silently masked out as padding,
# so fail loudly instead.
unmapped = domain_observations.loc[domain_observations["o"] == -1, "observation"]
if not unmapped.empty:
    raise ValueError(
        f"{len(unmapped)} observation(s) not in O, e.g. {unmapped.unique()[:5].tolist()}"
    )

# One observation sequence per cid (groupby sorts by cid), in the row order of
# domain_observations within each cid.
cid_observations = domain_observations[["cid", "o"]].groupby("cid").agg(list).o.values

# Convert observation sequences to torch tensors
cid_observations_torch = [torch.tensor(x, dtype=torch.int) for x in cid_observations]

# Select training data: the first 80% of sequences in cid order (not shuffled).
# NOTE: despite its name, num_batches counts sequences, not batches.
num_batches = int(len(cid_observations_torch) * 0.8)
observations_batches = cid_observations_torch[:num_batches]

# Pad to (n_sequences, max_len). The check above guarantees -1 only ever marks padding.
padded_o = pad_sequence(observations_batches, batch_first=True, padding_value=-1)

# Create a mask to ignore padding during training
mask = padded_o != -1

# Convert to masked tensor
masked_o = as_masked_tensor(padded_o, mask)

# Add a trailing feature dimension: (n_sequences, max_len, 1)
X_train = masked_o.unsqueeze(-1)
print(f"Shape of training data: {X_train.shape}")

Shape of training data: torch.Size([82, 3583, 1])


In [28]:
# Print the observation codes of the first training sequence, one per line.
for code in observations_batches[0].tolist():
    print(code)

0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
30
18
26
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
30
26
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
16
8
16
0
0
0
0
30
5
5
5
26
16
6
16
0
5
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
30
5
5
5
5
30
5
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
30
5
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
30
5
16
6
16
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
30
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
30
18
30
5
5
30
30
5
0
0

In [None]:
# Initialize HMM
# T_0 rows already sum to 1. Renormalising after the float64 -> float32 cast keeps each
# row exactly stochastic in single precision.
T_0_train = torch.tensor(T_0, dtype=torch.float)
T_0_train /= T_0_train.sum(dim=1, keepdim=True)

# verbose=True reports the improvement at each training iteration.
model = DenseHMM(dists, edges=T_0_train, verbose=True).to(torch_device)
# model = DenseHMM(dists, verbose=True).to(torch_device)

In [None]:
# Train
# Re-estimate the HMM's transition matrix and emission distributions (EM / Baum-Welch)
# on the padded, masked training sequences. The fitted model is the cell's output.
model.fit(X_train.to(torch_device))

In [None]:
# Save model
# Serialises the whole model object (pickle-based), so reloading needs pomegranate importable.
# NOTE(review): recent torch versions default torch.load to weights_only=True; loading this
# file may require weights_only=False. Consider saving model.state_dict() instead.
torch.save(model, "hmm_model.pt")

In [None]:
# Get the final observation emissions matrix: each state's learned probs stacked
# along a new leading state axis.
O_final = torch.stack(tuple(dist.probs for dist in model.distributions))
O_final

In [None]:
# Get the final Transition Matrix: model.edges stores log-probabilities, so exponentiate.
T_final = model.edges.exp()
T_final

## Save

In [None]:
import pickle

# Bundle the model ingredients and the learned HMM parameters for downstream use:
# (S, A, R, O, learned transition matrix, learned emissions with singleton axes squeezed,
#  per-cid observation sequences).
export = (S, A, R, O, T_final.cpu().numpy(), O_final.cpu().numpy().squeeze(), cid_observations)

# The context manager ensures the file is flushed and closed even if pickling fails.
with open("data.pkl", "wb") as fh:
    pickle.dump(export, fh)