In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

class GrangerMPP(torch.jit.ScriptModule):
    """Multivariate point-process model trained with a sampled E-step / gradient M-step loop.

    One ``ProbRNN`` is kept per process. In every EM step a "cause" process is
    sampled for each event of each process from the corresponding row of
    ``GrangerMatrix`` (E-step); the model of the target process is then updated
    on the feature row associated with that cause (M-step).

    Attributes:
        processes: list with one tensor of event times per process.
        memory_dim: number of past cause events that form one feature row.
        n_processes: ``len(processes)``.
        GrangerMatrix: ``(n_processes, n_processes)`` parameter; row ``i`` holds the
            logits used to sample the cause of the events of process ``i``.
        models: one ``ProbRNN(memory_dim)`` per process.
        sweep_dict: ``sweep_dict[i][j]`` is ``sweep(processes[i], processes[j])``.
        optimizer: a single Adam optimizer over all parameters of this module.
        causes: causes sampled by the most recent ``em_step`` iteration
            (empty list before the first one).

    NOTE(review): ``GrangerMatrix`` is handed to the optimizer, but the loss built
    in ``m_step`` only depends on ``models[i](X)``, so it looks like the matrix
    never receives a gradient -- confirm whether that is intended.
    """

    def __init__(self, processes: List[torch.Tensor], memory_dim: int = 10):
        """Build the per-process models, the feature table and the optimizer.

        Args:
            processes: one tensor of event times per process.
            memory_dim: window length used for the features; must be >= 1.

        Raises:
            ValueError: if ``memory_dim`` is smaller than 1.
        """
        super().__init__()
        if memory_dim < 1:
            raise ValueError(f"memory_dim must be >= 1, got {memory_dim}")
        self.processes = processes
        self.memory_dim = memory_dim
        self.n_processes = len(self.processes)
        # torch.Tensor(n, n) would expose uninitialized memory (possibly NaN/inf) as
        # sampling logits; zeros give every candidate cause the same initial probability.
        self.GrangerMatrix = nn.Parameter(torch.zeros(self.n_processes, self.n_processes))
        self.models = nn.ModuleList([ProbRNN(self.memory_dim) for _ in range(self.n_processes)])
        self.sweep_dict = self.make_sweep_dict()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        # torch.empty() without a size raises; an empty, explicitly typed list matches
        # what compute_causes returns and what em_step later assigns to this attribute.
        self.causes = torch.jit.Attribute([], List[torch.Tensor])

    @torch.jit.script_method
    def e_step(self, in_weights: torch.Tensor, points_current_pp: torch.Tensor) -> torch.Tensor:
        """
        Sample one cause index per event of the current process.

        Args:
            in_weights (torch.Tensor): 1-D logits over the candidate cause processes.
            points_current_pp (torch.Tensor): events of the current process; only its
                first dimension (the number of events) is used.

        Returns:
            torch.Tensor: long tensor with one sampled cause index per event.
        """
        n = points_current_pp.size(0)
        # One batched hard Gumbel-softmax draw (one row per event) instead of n
        # separate draws; every row uses the same logits.
        logits = in_weights.unsqueeze(0).expand(n, in_weights.size(0))
        one_hot = F.gumbel_softmax(logits, hard=True)
        return torch.argmax(one_hot, dim=1)

    @torch.jit.script_method
    def compute_causes(self, n_processes: int, GrangerMatrix: torch.Tensor, processes: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Run ``e_step`` for the first ``n_processes`` processes.

        Args:
            n_processes (int): number of processes to handle.
            GrangerMatrix (torch.Tensor): row ``i`` holds the cause logits of process ``i``.
            processes (List[torch.Tensor]): events of every process.

        Returns:
            List[torch.Tensor]: sampled causes, one tensor per process.
        """
        # Results are returned rather than stored on self; the caller assigns them.
        causes: List[torch.Tensor] = []
        for i_proc in range(n_processes):
            causes.append(self.e_step(GrangerMatrix[i_proc], processes[i_proc]))
        return causes

    def em_step(self, n_steps: int):
        """Run ``n_steps`` EM iterations.

        Args:
            n_steps: number of iterations.

        Returns:
            dict mapping each process index to the list of losses returned by
            ``m_step`` for that process, in the order the updates were made.
        """
        losses = {i: [] for i in range(self.n_processes)}
        # Read the module attributes once instead of on every inner-loop iteration.
        n_processes = self.n_processes
        memory_dim = self.memory_dim
        processes = self.processes
        sweep_dict = self.sweep_dict

        for step in range(n_steps):
            causes = self.compute_causes(n_processes, self.GrangerMatrix, processes)
            self.causes = causes
            should_log = step == 0 or (step + 1) % 25 == 0

            for i_proc in range(n_processes):
                # Loss of the latest update of *this* process in *this* step. Tracking
                # it here keeps the log line from referencing an unbound name or a
                # value left over from another process.
                last_loss = None

                for j, cause_to_ith in enumerate(causes[i_proc].tolist()):
                    effect_j_on_i = sweep_dict[i_proc][cause_to_ith]

                    if cause_to_ith == i_proc and j >= memory_dim:
                        # Self-caused event: take the memory_dim entries preceding
                        # index j, reversed and shifted by the first entry of the window.
                        # NOTE(review): sweep() builds features as `t - tc`; this branch
                        # measures relative to the window's first entry instead --
                        # confirm this is the intended feature.
                        window = processes[i_proc][j - memory_dim : j]
                        X_to_pass = window.flip(dims=(0,)) - window[0]
                    elif len(effect_j_on_i) > j:
                        X_to_pass = effect_j_on_i[j]
                    else:
                        # No feature row available for this event.
                        continue

                    last_loss = self.m_step(i_proc, X_to_pass.unsqueeze(0))
                    losses[i_proc].append(last_loss)

                if should_log:
                    shown = 'n/a' if last_loss is None else last_loss
                    print(f'Step: {step + 1}, Model: {i_proc}, Loss: {shown}')

        return losses

    def m_step(self, i_proc_: int, X: torch.Tensor) -> float:
        """Take one optimizer step for the model of process ``i_proc_`` on input ``X``.

        The loss is the negated, summed second output of the model. When it is
        NaN or infinite the update is skipped and a message is printed.

        Args:
            i_proc_: index into ``self.models``.
            X: input passed to the model.

        Returns:
            The loss as a Python float (possibly NaN/inf when the update was skipped).
        """
        model = self.models[i_proc_]

        self.optimizer.zero_grad()
        _, log_prob = model(X)
        loss = (-1 * log_prob).sum()

        if torch.isfinite(loss):
            loss.backward()
            self.optimizer.step()
        else:
            print('NaN found in epoch')

        return loss.item()

    def make_sweep_dict(self):
        """Return ``{i: {j: sweep(processes[i], processes[j])}}`` for all process pairs."""
        return {
            i: {j: self.sweep(target, cause) for j, cause in enumerate(self.processes)}
            for i, target in enumerate(self.processes)
        }

    def sweep(self, pa: torch.Tensor, pc: torch.Tensor) -> torch.Tensor:
        """Build one feature row per event of ``pa`` from the preceding events of ``pc``.

        Events of both processes are merged and visited in time order (at equal
        times the ``pa`` event is visited first). For every ``pa`` event that is
        preceded by at least ``memory_dim`` events of ``pc``, the row holds the
        time differences to the last ``memory_dim`` of them, oldest first.

        Args:
            pa: 1-D tensor with the event times of the target process.
            pc: 1-D tensor with the event times of the candidate cause process.

        Returns:
            Float tensor of shape ``(n_rows, memory_dim)``; ``n_rows`` is 0 when no
            ``pa`` event has enough history.
        """
        from collections import deque

        lim = self.memory_dim

        # Plain Python floats sort far more cheaply than tuples of 0-d tensors and
        # give the same ordering.
        events = sorted(
            [(t, 'a') for t in pa.tolist()] + [(t, 'c') for t in pc.tolist()]
        )

        # Bounded window over the most recent `lim` cause events.
        recent = deque(maxlen=lim)
        rows = []
        for t, kind in events:
            if kind == 'c':
                recent.append(t)
            elif len(recent) >= lim:
                rows.append([t - tc for tc in recent])

        if not rows:
            # Keep the column dimension so callers always see (n_rows, memory_dim).
            return torch.empty((0, lim), dtype=torch.float)
        return torch.tensor(rows, dtype=torch.float)

from typing import Tuple

class ProbRNN(torch.jit.ScriptModule):
    """Linear -> LSTM network that outputs a sample and its mixture log-probability.

    The input is projected to 64 features, passed through a 2-layer LSTM with
    128 hidden units, and mapped by two linear heads to a mean and a standard
    deviation of size ``memory_size``. ``gmm_weights`` holds one mixture weight
    per output dimension.
    """

    def __init__(self, memory_size: int):
        """
        Args:
            memory_size: size of the last input dimension and of the mean / std outputs.
        """
        super().__init__()

        self.memory_size = memory_size
        self.linear = nn.Sequential(
            nn.Linear(self.memory_size, 64),
            nn.Tanh()
        )
        self.lstm = nn.LSTM(input_size=64, hidden_size=128, num_layers=2, batch_first=True)
        self.linear_mu = nn.Sequential(nn.Linear(128, self.memory_size))
        self.linear_std = nn.Sequential(nn.Linear(128, self.memory_size))
        # Registered as a parameter so the optimizer updates it; it starts as a
        # normalized probability vector but is not constrained afterwards.
        self.gmm_weights = nn.Parameter(torch.softmax(torch.rand(1, self.memory_size), dim=1))

    @torch.jit.script_method
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: input whose last dimension is ``memory_size``.

        Returns:
            ``(sample, log_prob)``: a sample drawn from the predicted Gaussians and
            the mixture log-probability of that sample (last dimension reduced).
        """
        x = self.linear(x)
        x, _ = self.lstm(x)
        mu = self.linear_mu(x)
        # abs() alone can yield an exact zero, which makes log(std) and
        # (x - mu) / std blow up; the small offset keeps std strictly positive.
        std = torch.abs(self.linear_std(x)) + 1e-6

        new_X = self.sample(mu, std)
        log_prob = self.compute_log_prob(new_X, mu, std)

        return new_X, log_prob

    def compute_log_prob(self, x: torch.Tensor, mus: torch.Tensor, stds: torch.Tensor) -> torch.Tensor:
        """Mixture log-probability of ``x`` under per-dimension Gaussians.

        Args:
            x: values to score.
            mus: Gaussian means, broadcastable with ``x``.
            stds: Gaussian standard deviations (expected to be > 0).

        Returns:
            ``logsumexp`` over the last dimension of the weighted component log-densities.
        """
        log_two_pi = 1.8378770664093453  # log(2 * pi); avoids building a tensor on every call
        z = (x - mus) / stds
        component_log_probs = -0.5 * (z * z + 2.0 * torch.log(stds) + log_two_pi)
        # gmm_weights is updated without constraints, so entries may drift to <= 0 and
        # stop summing to one. Clamp and renormalize before taking the log; for the
        # initial softmax output this leaves the weights unchanged.
        weights = self.gmm_weights.clamp_min(1e-12)
        log_weights = torch.log(weights) - torch.log(weights.sum(dim=-1, keepdim=True))
        return torch.logsumexp(component_log_probs + log_weights, dim=-1)

    def sample(self, mu: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        """Draw ``mu + eps * std`` with ``eps`` standard normal (reparameterized sample)."""
        return mu + torch.randn_like(std) * std

In [16]:
# Seed the global RNG so the toy event times and the Gumbel sampling in the
# E-step are reproducible across kernel restarts.
torch.manual_seed(0)

# Five toy processes with ten event times each. The times are sorted because
# em_step slices processes[i][j - memory_dim : j] as "the events preceding j",
# which only makes sense for time-ordered data.
toy_processes = [torch.sort(torch.rand(10)).values for _ in range(5)]

model = GrangerMPP(processes=toy_processes, memory_dim=2)
# GrangerMPP already subclasses torch.jit.ScriptModule; the scripted handle is
# kept for later cells.
scripted_model = torch.jit.script(model)


In [17]:
# Run a single EM iteration; the cell displays the returned per-process loss lists.
model.em_step(n_steps=1)

IndexError: list index out of range