In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import math

from torch.nn import LayerNorm, BatchNorm1d

In [2]:
def drop_edge(edge_index, edge_attr, drop_prob=0.2, training=True):
    """Randomly remove edges (DropEdge-style augmentation).

    Each edge (a column of ``edge_index`` and the matching row of
    ``edge_attr``) is kept independently with probability ``1 - drop_prob``.

    Args:
        edge_index: tensor whose columns are edges (presumably ``[2, E]``).
        edge_attr: per-edge features, one row per edge.
        drop_prob: probability of dropping each edge, in ``[0, 1]``.
        training: when False the inputs are returned unchanged. This is a
            plain function, so callers pass their module's ``self.training``.

    Returns:
        ``(edge_index, edge_attr)`` restricted to the kept edges. When no
        dropping happens, the input objects themselves are returned.

    Raises:
        ValueError: if ``drop_prob`` is outside ``[0, 1]``.
    """
    if not 0.0 <= drop_prob <= 1.0:
        raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
    if not training or drop_prob == 0.0:
        return edge_index, edge_attr

    # Draw the mask on the same device as the edges so GPU inputs work.
    keep = torch.rand(edge_index.size(1), device=edge_index.device) > drop_prob
    return edge_index[:, keep], edge_attr[keep]

In [3]:
class EGATLayer(nn.Module):
    """Edge-aware graph layer with sigmoid-gated message passing.

    For every edge ``src -> dst`` a message ``alpha * (W_n x_src + W_e e)``
    is formed, where ``alpha = sigmoid(a([W_n x_src, W_n x_dst, W_e e]))``
    is an independent per-edge gate (it is not normalised across a node's
    neighbours). Messages are summed into their destination node and the
    result is layer-normalised. A node with no incoming edges holds a zero
    vector before normalisation.

    Args:
        node_dim: input node feature size.
        edge_dim: input edge feature size.
        out_dim: output embedding size.
    """

    def __init__(self, node_dim, edge_dim, out_dim):
        super().__init__()
        self.node_proj = nn.Linear(node_dim, out_dim)
        self.edge_proj = nn.Linear(edge_dim, out_dim)
        self.attn = nn.Linear(3 * out_dim, 1)
        self.norm = LayerNorm(out_dim)

    def forward(self, x, edge_index, edge_attr):
        """Aggregate gated edge messages into destination nodes.

        Args:
            x: node features indexed by node id along dim 0
                (assumed ``[N, node_dim]`` — TODO confirm).
            edge_index: pair ``(src, dst)`` of node-id tensors, one entry per edge.
            edge_attr: per-edge features, one row per edge.

        Returns:
            Layer-normalised node embeddings with the same leading size as ``x``.
        """
        src, dst = edge_index
        # Project every node once and gather per edge. This replaces running
        # node_proj on x[src], on x[dst] and on x just to size the buffer.
        h = self.node_proj(x)
        h_src = h[src]
        h_dst = h[dst]
        e = self.edge_proj(edge_attr)

        alpha = torch.sigmoid(self.attn(torch.cat([h_src, h_dst, e], dim=-1)))
        messages = alpha * (h_src + e)

        out = torch.zeros_like(h)
        out.index_add_(0, dst, messages)
        return self.norm(out)

## Causal Path Encoder

In [4]:
class CausalPathEncoder(nn.Module):
    """Residual MLP block with row-level (node) dropout.

    During training each row of ``h`` is zeroed with probability ``dropout``.
    Surviving rows are rescaled by ``1 / (1 - dropout)`` (inverted dropout),
    so the expected input magnitude matches what the block sees in eval mode.
    The output is ``relu(fc(h)) + h`` computed on the (possibly dropped) input.

    Args:
        dim: feature size of input and output.
        dropout: per-row drop probability in ``[0, 1]``.
    """

    def __init__(self, dim, dropout=0.3):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.dropout = dropout

    def forward(self, h):
        if self.training and self.dropout > 0.0:
            keep_prob = 1.0 - self.dropout
            if keep_prob <= 0.0:
                # Every row is dropped; avoid dividing by zero.
                h = torch.zeros_like(h)
            else:
                # Draw the mask on h's device so GPU inputs work. Reshape it
                # to broadcast over all non-row dimensions.
                keep = torch.rand(h.size(0), device=h.device) > self.dropout
                keep = keep.view(-1, *([1] * (h.dim() - 1))).to(h.dtype)
                h = h * keep / keep_prob

        return F.relu(self.fc(h)) + h   # residual

## Temporal Dropout (randomly drops whole time steps to reduce overfitting)

In [5]:
class TemporalDropout(nn.Module):
    """Randomly drop whole time steps (dim 1) from a sequence tensor.

    In training mode each index along dim 1 is kept with probability
    ``1 - p``. Kept steps are returned in their original order, so the output
    is shorter along dim 1 than the input. At least one step is always kept,
    so the result is never an empty sequence. Eval mode, ``p == 0`` and
    already-empty sequences pass through unchanged (the same object is returned).

    Args:
        p: probability of dropping each time step.
    """

    def __init__(self, p=0.2):
        super().__init__()
        self.p = p

    def forward(self, x):
        if not self.training or self.p == 0.0:
            return x
        num_steps = x.size(1)
        if num_steps == 0:
            return x

        # Mask on x's device so GPU inputs work.
        keep = torch.rand(num_steps, device=x.device) > self.p
        if not bool(keep.any()):
            # Everything was dropped: keep one randomly chosen step instead.
            keep[torch.randint(num_steps, (1,), device=x.device)] = True
        return x[:, keep, :]


## Multi-Scale Temporal Encoder

In [6]:
class MultiScaleTemporal(nn.Module):
    """Summarise a batch-first sequence with three temporal encoders.

    Three views of the sequence are mixed with softmax-normalised learnable
    weights ``alpha``:

    * CNN: width-2 temporal convolution, averaged over time;
    * LSTM: output at the last time step;
    * GRU: attention-pooled outputs. The per-step score is the GRU output
      averaged over features (no extra parameters).

    Args:
        dim: feature size of the input and of every branch output.
    """

    def __init__(self, dim):
        super().__init__()

        self.cnn = nn.Conv1d(dim, dim, kernel_size=2)
        self.lstm = nn.LSTM(dim, dim, batch_first=True)
        self.gru = nn.GRU(dim, dim, batch_first=True)

        self.alpha = nn.Parameter(torch.ones(3))

    def forward(self, x):
        """
        Args:
            x: ``[B, T, D]`` sequence with ``D == dim``.

        Returns:
            ``[B, D]`` weighted mix of the three branch summaries.
        """
        # CNN branch: Conv1d wants [B, D, T]. A sequence shorter than the
        # kernel (e.g. T == 1) would make Conv1d raise. In that case,
        # replicate the earliest step on the left up to the kernel width.
        # Sequences of at least kernel length are unchanged.
        conv_in = x.transpose(1, 2)
        kernel = self.cnn.kernel_size[0]
        short_by = kernel - conv_in.size(-1)
        if short_by > 0:
            conv_in = F.pad(conv_in, (short_by, 0), mode="replicate")
        cnn_out = self.cnn(conv_in).mean(-1)

        lstm_out, _ = self.lstm(x)
        lstm_out = lstm_out[:, -1]

        gru_out, _ = self.gru(x)
        attn = torch.softmax(gru_out.mean(-1), dim=1)  # [B, T] weights over time
        gru_out = (gru_out * attn.unsqueeze(-1)).sum(1)

        weights = torch.softmax(self.alpha, dim=0)
        return (
            weights[0] * cnn_out +
            weights[1] * lstm_out +
            weights[2] * gru_out
        )


## Mobility Attention (speed-conditioned gate with a lower floor of 0.3)

In [7]:
class MobilityAttention(nn.Module):
    """Rescale embeddings by a learned gate computed from ``speed``.

    The scalar speed of each row is lifted to ``dim`` gate logits by a
    linear layer and squashed with a sigmoid. The gate is then floored at 0.3,
    so no channel is scaled below 30% of its value.
    """

    def __init__(self, dim):
        super().__init__()
        # one input feature (speed) -> one gate logit per embedding channel
        self.fc = nn.Linear(1, dim)

    def forward(self, h, speed):
        speed_feature = speed.unsqueeze(-1)
        gate_logits = self.fc(speed_feature)
        # sigmoid already tops out at 1.0; the effective bound is the 0.3 floor
        bounded_gate = torch.sigmoid(gate_logits).clamp(min=0.3, max=1.0)
        return h * bounded_gate


In [8]:
class Bottleneck(nn.Module):
    """Compress features to ``bottleneck_dim``: linear -> BatchNorm -> dropout(0.3).

    No nonlinearity is applied, so the compression itself is affine.
    """

    def __init__(self, dim, bottleneck_dim):
        super().__init__()
        # Creation order is significant: it fixes parameter init under a seed.
        self.proj = nn.Linear(dim, bottleneck_dim)
        self.bn = BatchNorm1d(bottleneck_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, h):
        compressed = self.proj(h)
        normalized = self.bn(compressed)
        return self.dropout(normalized)


In [9]:
class FailureHead(nn.Module):
    """Binary failure head: a single linear logit per row, returned as a probability."""

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1)

    def forward(self, h):
        logit = self.fc(h)
        return logit.sigmoid()


class TypeHead(nn.Module):
    """Multi-class head producing raw (unnormalised) class scores.

    Args:
        dim: input feature size.
        num_classes: number of output classes (default 4).
    """

    def __init__(self, dim, num_classes=4):
        super().__init__()
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, h):
        class_logits = self.fc(h)
        return class_logits


class TimeHead(nn.Module):
    """Scalar regression head: one unbounded linear output per row."""

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1)

    def forward(self, h):
        prediction = self.fc(h)
        return prediction


class UncertaintyHead(nn.Module):
    """Positive scale head: softplus of a linear output, floored at 1e-3.

    The floor keeps the value safely away from zero when it is used as a
    divisor or inside a log.
    """

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1)

    def forward(self, h):
        raw_scale = self.fc(h)
        positive_scale = F.softplus(raw_scale)
        return positive_scale.clamp(min=1e-3)


In [10]:
class TGNN(nn.Module):
    """Temporal graph network with failure / type / time / uncertainty heads.

    Forward pipeline:
      1. Each graph snapshot in ``x_seq`` goes through ``EGATLayer`` and then
         ``CausalPathEncoder``. The same ``edge_index``/``edge_attr`` are used
         for every snapshot.
      2. The per-snapshot embeddings are stacked on dim 1 and summarised by
         ``MultiScaleTemporal``.
      3. The summary is gated by ``speed`` (``MobilityAttention``) and
         compressed by ``Bottleneck``.
      4. The four heads are applied to the compressed embedding.

    The constructor arguments default to the sizes that were previously
    hard-coded (15 node features, 3 edge features, 64 hidden, 32 bottleneck,
    4 types). ``TGNN()`` therefore builds the same network as before.

    Args:
        node_dim: node feature size of each snapshot.
        edge_dim: edge feature size.
        hidden_dim: embedding size inside the graph/temporal encoders.
        bottleneck_dim: embedding size fed to the heads.
        num_classes: number of failure types predicted by the type head.
    """

    def __init__(self, node_dim=15, edge_dim=3, hidden_dim=64,
                 bottleneck_dim=32, num_classes=4):
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # initialisation under a fixed seed is unchanged.
        self.egat = EGATLayer(node_dim, edge_dim, hidden_dim)
        self.causal = CausalPathEncoder(hidden_dim)
        self.temporal = MultiScaleTemporal(hidden_dim)
        self.mobility = MobilityAttention(hidden_dim)
        self.bottleneck = Bottleneck(hidden_dim, bottleneck_dim)

        self.fail_head = FailureHead(bottleneck_dim)
        self.type_head = TypeHead(bottleneck_dim, num_classes)
        self.time_head = TimeHead(bottleneck_dim)
        self.uncertainty_head = UncertaintyHead(bottleneck_dim)

    def forward(self, x_seq, edge_index, edge_attr, speed):
        """
        Args:
            x_seq: iterable of node-feature snapshots, one per time step
                (presumably each ``[N, node_dim]`` — TODO confirm).
            edge_index, edge_attr: graph structure shared by all snapshots.
            speed: per-node speed; a size-1 feature dim is added internally
                (presumably ``[N]`` — TODO confirm).

        Returns:
            dict with ``"failure"`` (probabilities), ``"type"`` (logits),
            ``"time"`` and ``"uncertainty"`` (positive scale) predictions.
        """
        h_seq = [self.causal(self.egat(x, edge_index, edge_attr)) for x in x_seq]

        h = self.temporal(torch.stack(h_seq, dim=1))  # time becomes dim 1
        h = self.mobility(h, speed)
        h = self.bottleneck(h)

        return {
            "failure": self.fail_head(h),
            "type": self.type_head(h),
            "time": self.time_head(h),
            "uncertainty": self.uncertainty_head(h)
        }


In [11]:
def uncertainty_loss(y, mu, sigma):
    """Element-wise Gaussian negative log-likelihood, without the constant term.

    Computes ``(y - mu)^2 / (2 sigma^2) + log(sigma)`` with no reduction.
    ``sigma`` must be strictly positive.
    """
    squared_error = (y - mu) ** 2
    two_variance = 2 * sigma ** 2
    log_scale = torch.log(sigma)
    return squared_error / two_variance + log_scale


In [12]:
def total_loss(out, y_fail, y_type, y_time,
               λ1=1.0, λ2=0.5, λ3=0.5, λ4=0.2, λ5=0.2):
    """Weighted sum of the TGNN training losses.

    Terms:
        * ``λ1`` × BCE between ``out["failure"]`` (probabilities) and ``y_fail``
        * ``λ2`` × cross-entropy between ``out["type"]`` (logits) and ``y_type``
        * ``λ3`` × smooth-L1 between ``out["time"]`` and ``y_time``
        * ``λ5`` × mean Gaussian NLL of ``y_time`` with mean ``out["time"]``
          and scale ``out["uncertainty"]`` (``uncertainty_loss``)

    ``y_fail`` and ``y_time`` are reshaped to their prediction's shape and cast
    to its dtype. Without this, ``[N]`` targets against ``[N, 1]`` head
    outputs make BCE raise and make smooth-L1/NLL silently broadcast to
    ``[N, N]``. Integer 0/1 failure labels would also make BCE raise.

    Args:
        out: dict with ``"failure"``, ``"type"``, ``"time"``, ``"uncertainty"``.
        y_fail, y_time: targets with the same number of elements as the
            matching prediction.
        y_type: class targets accepted by ``F.cross_entropy``.
        λ1..λ5: term weights. NOTE(review): ``λ4`` is accepted but not used
            by any term; it is kept for call compatibility. Confirm the
            intended mapping.

    Returns:
        Scalar loss tensor.
    """
    p_fail = out["failure"]
    t_pred = out["time"]
    y_fail = y_fail.reshape(p_fail.shape).to(p_fail.dtype)
    y_time = y_time.reshape(t_pred.shape).to(t_pred.dtype)

    # NOTE(review): BCE on sigmoid outputs is less stable than BCE-with-logits;
    # switching would require the failure head to return logits.
    L_fail = F.binary_cross_entropy(p_fail, y_fail)
    L_type = F.cross_entropy(out["type"], y_type)
    L_time = F.smooth_l1_loss(t_pred, y_time)
    L_unc = uncertainty_loss(y_time, t_pred, out["uncertainty"]).mean()

    return λ1*L_fail + λ2*L_type + λ3*L_time + λ5*L_unc


In [13]:
def train_step(model, batch, optimizer, loss_fn=None):
    """Run one optimisation step on a single batch.

    Args:
        model: module whose ``forward`` takes the non-target entries of
            ``batch`` as keyword arguments and returns the output dict
            consumed by ``loss_fn``.
        batch: mapping of model inputs plus the targets ``"y_fail"``,
            ``"y_type"`` and ``"y_time"``. It is not modified.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``(out, y_fail, y_type, y_time) -> scalar tensor``.
            Defaults to ``total_loss``.

    Returns:
        float: the loss value computed before the parameter update.

    Raises:
        KeyError: if a target key is missing from ``batch``.
    """
    if loss_fn is None:
        loss_fn = total_loss

    # Targets must not reach the model: TGNN.forward only accepts
    # (x_seq, edge_index, edge_attr, speed) and rejects y_* keywords.
    target_keys = ("y_fail", "y_type", "y_time")
    inputs = {k: v for k, v in batch.items() if k not in target_keys}

    model.train()
    optimizer.zero_grad()

    out = model(**inputs)
    loss = loss_fn(out,
                   batch["y_fail"],
                   batch["y_type"],
                   batch["y_time"])

    loss.backward()
    optimizer.step()
    return loss.item()
