In [1]:
import torch
from torch.utils.data import DataLoader

In [2]:
import sys
sys.path.append('/home/joao.pires/MPP/neural_mpp/EM/Refact/')
from models import NormalizingFlow
from sweep import HawkesSweep

In [3]:
hk = HawkesSweep([[1, 2, 3], [4, 5, 6]], 2)

In [None]:
hk.make_dict()

In [70]:
dl = hk.get_events(d, 1, 0, 3)

In [72]:
for _, (X, y) in enumerate(dl):
    print (X,y)
    break

tensor([[1., 2.],
        [2., 3.],
        [3., 4.],
        [1., 2.]]) tensor([0, 0, 0, 1])


In [None]:
class GrangerMPP(nn.Module):

    def __init__(self, Process):

        super().__init__()

        self.Process = Process
        self.n_processes = self.Process.n_processes
        self.memory_dim = self.Process.memory_dim
        self.GrangerMatrix = nn.Parameter((torch.empty(self.n_processes, self.n_processes)))
        nn.init.normal_(self.GrangerMatrix, mean=0.5, std=0.1) # very important

        self.models = nn.ModuleList([NormalizingFlow(num_features = 1, memory_size = self.memory_dim, hidden_dim = 32, num_layers = 4) for i in range(self.n_processes)])
        self.optimizers = [torch.optim.Adam(list(self.models[i].parameters()), lr=1e-4, weight_decay = 1e-5) for i in range(self.n_processes)]
        self.g_optimizer = torch.optim.Adam([self.GrangerMatrix], lr = 1e-3, weight_decay=1e-5)
        self.log_GrangerMatrix = []


    def em_step(self, n_steps):
        
        dic = {}
        self.causes = [[], [], []]
        for i in range(self.n_processes):
            dic[i] = []


        taus = torch.linspace(1, 0.5, steps = n_steps)
        
        for self.step in range(n_steps):
          for i_proc in range(self.n_processes):
              self.causes[i_proc] = []
              curr = processes[i_proc]
              len_curr = len(curr)
              idx_start = 0
              while idx_start < len_curr:
                self.num_events = 5
                events = self.get_events(self.num_events, idx_start, i_proc, 1.0) ## the get_events does the e_step!
                if events:
                  DL = DataLoader(events, batch_size = len(events))

                  for X, cause_rank in DL:
                    X = X.unsqueeze(-1)
                    loss = self.m_step(i_proc, X, cause_rank)
                    dic[i_proc].append(loss)

                idx_start += self.num_events

              if (self.step + 1) % 5 == 0 or self.step == 0:
                  print(f'Step: {self.step + 1}, Model: {i_proc}, Loss: {loss}')


        return dic

    def m_step(self, i_proc, X, cause_rank):

        model = self.models[i_proc]
        self.optimizers[i_proc].zero_grad()
        self.g_optimizer.zero_grad()
        z, logp = model.log_prob(X)
        loss = -1*logp

        loss_rnn = (loss * cause_rank).sum()  + -1*(torch.log(cause_rank + 1e-7)).sum() + 0.001*self.GrangerMatrix[i_proc].norm(p=1)


        if not (torch.isnan(loss_rnn) | torch.isinf(loss_rnn)):


            loss_rnn.backward(retain_graph = True)

            self.optimizers[i_proc].step()
            self.g_optimizer.step()
            self.log_GrangerMatrix.append(self.GrangerMatrix.clone().detach())


        else:
            print(f'NaN found in epoch: {self.step}')

        return loss_rnn.item()

    def new_e_step(self, num_events, i_proc, tau):

      in_ = self.GrangerMatrix[i_proc]#.softmax(dim = 0)
      rv = []
      for i in range(num_events):
        cause = F.gumbel_softmax(
            in_,
            tau = tau,
            hard = False
        )
        rv.append(cause)

      self.causes[i_proc].append(rv)

      return rv




In [6]:
class WoldSweep(Sweep):
    def construct_wold_dict(self):
        dict = {}

        events = []
        for id, process in enumerate(self.processes):
            for t in process:
                events.append((t.item(), id))

        events.sort()

        deltas = {}
        last = {}
        cur = -1

        for t, id in events:
            dict[t] = {}
            deltas[id] = Memory(self.memory_dim)
            last[id] = [0, 0]

        for t, id in events:
            if t != cur:
                
                # updating
                cur = t
                for _id, _delta in deltas.items():
                    dict[cur][_id] = _delta.copy()
            
            last[id][1] = last[id][0]
            last[id][0] = t
            if last[id][1] != 0:
                deltas[id].push(last[id][0] - last[id][1])
            
        return dict
        
    # TODO: check idx_start semantics
    def make_dict(self):
        wold = self.construct_wold_dict()
        dic = {}
        for i in range(self.n_processes):
            target = self.processes[i]
            dic[i] = {}
            #for j in range(self.n_processes):
            #    cause = self.processes[j]
            #    dic[i][j] = self.sweep(target, cause)
            ret = {}
            for _t in target:
                t = _t.item()
                for j in range(self.n_processes):
                    if j not in ret:
                        ret[j] = []
                    #print(t, j, wold[t][j])
                    #return None
                    if len(wold[t][j]) < self.memory_dim:
                        ret[j].append([-1] * self.memory_dim)
                    else:
                        ret[j].append(wold[t][j])

            #return None
            for j in range(self.n_processes):
                dic[i][j] = torch.tensor(ret[j], dtype=torch.float)

        return dic


In [7]:
wd = WoldSweep(torch.tensor([[i for i in range(6)], [i for i in range(0, 6*4, 4)]]), 2)

wd.make_dict()

{0: {0: tensor([[-1., -1.],
          [-1., -1.],
          [-1., -1.],
          [-1., -1.],
          [ 1.,  1.],
          [ 1.,  1.]]),
  1: tensor([[-1., -1.],
          [-1., -1.],
          [-1., -1.],
          [-1., -1.],
          [-1., -1.],
          [-1., -1.]])},
 1: {0: tensor([[-1., -1.],
          [ 1.,  1.],
          [ 1.,  1.],
          [ 1.,  1.],
          [ 1.,  1.],
          [ 1.,  1.]]),
  1: tensor([[-1., -1.],
          [-1., -1.],
          [-1., -1.],
          [-1., -1.],
          [ 4.,  4.],
          [ 4.,  4.]])}}

In [8]:
class GetEvents(Sweep):
    pass



In [None]:
GetEvents