In [None]:
import os

import torch

# Location of the serialized graph. The original machine-specific path stays as the
# default so existing runs are unchanged; set FED_SPEECH_GRAPH_PATH to load the file
# from somewhere else without editing the notebook.
path = os.environ.get(
    "FED_SPEECH_GRAPH_PATH",
    r"C:\Users\MainUser\project\cs224w_cb_graph\info_folder\fed_speech_graph.pt",
)

# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python objects,
# so only point this at files you trust. Presumably it is needed because the file holds
# a full graph object rather than plain tensors - TODO confirm.
data = torch.load(path, map_location="cpu", weights_only=False)

In [None]:
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import HeteroConv, SAGEConv, Linear

class HeteroSpeechRegressor(nn.Module):
    """Two heterogeneous SAGE layers followed by a linear head on 'speech' nodes.

    Args:
        metadata: Pair whose second element lists the edge types; one SAGEConv
            is created per edge type in each of the two layers.
        hidden_dim: Output width of every SAGEConv and input width of the head.
    """

    def __init__(self, metadata, hidden_dim=128):
        super().__init__()
        relations = metadata[1]

        def make_layer():
            # (-1, -1) leaves the source/destination input sizes to be inferred lazily.
            per_relation = {rel: SAGEConv((-1, -1), hidden_dim) for rel in relations}
            return HeteroConv(per_relation, aggr='sum')

        self.conv1 = make_layer()
        self.conv2 = make_layer()
        self.head = Linear(hidden_dim, 1)

    def _apply_layer(self, layer, x_dict, edge_index_dict):
        """Run one layer, keep node types it did not emit, then apply ReLU to all.

        Node types missing from the layer output are carried over from ``x_dict``
        unchanged (apart from the ReLU), so later layers still see them.
        """
        result = layer(x_dict, edge_index_dict)
        if isinstance(result, tuple):
            # Some layers return (features, extras); only the features are used.
            result = result[0]
        for node_type, feats in x_dict.items():
            result.setdefault(node_type, feats)
        return {node_type: F.relu(feats) for node_type, feats in result.items()}

    def forward(self, x_dict, edge_index_dict):
        """Return one scalar prediction per 'speech' node."""
        hidden = x_dict
        for conv in (self.conv1, self.conv2):
            hidden = self._apply_layer(conv, hidden, edge_index_dict)
        return self.head(hidden['speech']).squeeze(-1)



In [None]:
# Keep only speech nodes whose target is finite (drops NaN / +-inf labels).
# Assumes data['speech'].y holds one value per speech node - TODO confirm.
mask_finite = torch.isfinite(data['speech'].y)

if not bool(mask_finite.all()):
    # Removing rows shifts the position of every later speech node, so any relation
    # that touches 'speech' must (a) drop edges attached to removed nodes and
    # (b) renumber the surviving endpoints. Without this, edge_index would keep
    # pointing at the old row positions.
    _new_position = torch.cumsum(mask_finite, dim=0) - 1  # old row -> new row
    for _rel in data.edge_types:
        _src, _, _dst = _rel
        if 'speech' not in (_src, _dst):
            continue
        _edge_index = data[_rel].edge_index
        _keep_edge = torch.ones(_edge_index.size(1), dtype=torch.bool,
                                device=_edge_index.device)
        if _src == 'speech':
            _keep_edge &= mask_finite[_edge_index[0]]
        if _dst == 'speech':
            _keep_edge &= mask_finite[_edge_index[1]]

        _rows, _cols = _edge_index[0][_keep_edge], _edge_index[1][_keep_edge]
        if _src == 'speech':
            _rows = _new_position[_rows]
        if _dst == 'speech':
            _cols = _new_position[_cols]
        data[_rel].edge_index = torch.stack([_rows, _cols], dim=0)

        if 'edge_attr' in data[_rel]:
            data[_rel].edge_attr = data[_rel].edge_attr[_keep_edge]
        # NOTE(review): other per-edge attributes (if any) are not filtered here.

data['speech'].y = data['speech'].y[mask_finite]
data['speech'].x = data['speech'].x[mask_finite]
data['speech'].date = data['speech'].date[mask_finite]

In [None]:
# Sanity check: the node count, date vector and target vector should all agree.
_speech = data['speech']
_size_report = (
    ("data['speech'].num_nodes:", _speech.num_nodes),
    ("len(speech_dates):", len(_speech.date)),
    ("len(data['speech'].y):", len(_speech.y)),
)
for _label, _value in _size_report:
    print(_label, _value)


data['speech'].num_nodes: 53
len(speech_dates): 53
len(data['speech'].y): 53


In [None]:
# Use the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# NOTE(review): the model only gets a conv for the edge types present in
# data.metadata() at this point. A later cell applies ToUndirected, so if the cells
# run in notebook order the reverse relations are not part of the model - confirm
# that this is intended.
model = HeteroSpeechRegressor(data.metadata(), hidden_dim=128)
model = model.to(device)
data = data.to(device)

# Mean-squared error for the scalar regression target; Adam with weight decay.
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)


  self.conv1 = HeteroConv({et: SAGEConv((-1, -1), hidden_dim)
  self.conv2 = HeteroConv({et: SAGEConv((-1, -1), hidden_dim)


In [None]:
from torch_geometric.transforms import ToUndirected

# Add reverse relations so messages can flow in both directions.
data = ToUndirected()(data)

# Placeholder features must live on the same device as the existing ones; the graph
# may already have been moved to a GPU by an earlier cell. Falls back to the CPU when
# no node type carries features yet.
_feature_device = next(
    (data[nt].x.device for nt in data.node_types if hasattr(data[nt], 'x')),
    torch.device('cpu'),
)

for nt in data.node_types:
    if not hasattr(data[nt], 'x'):
        N = data[nt].num_nodes
        if N is None:
            # Without features or an explicit count there is no size to allocate.
            raise ValueError(
                f"Cannot create placeholder features for node type {nt!r}: "
                "num_nodes is unknown"
            )
        # Featureless node types get a single constant (zero) feature column.
        data[nt].x = torch.zeros((N, 1), dtype=torch.float, device=_feature_device)
    if not hasattr(data[nt], 'num_nodes') or data[nt].num_nodes is None:
        data[nt].num_nodes = data[nt].x.size(0)


In [None]:
import numpy as np
import torch

# Chronological 70 / 15 / 15 split of the speech nodes: the earliest speeches are used
# for training, the next slice for validation and the most recent ones for testing.
N = data['speech'].num_nodes

# The dates may be a tensor that an earlier cell moved to the GPU; NumPy can only sort
# host memory, so bring tensors back to the CPU first.
_speech_dates = data["speech"].date
if torch.is_tensor(_speech_dates):
    _speech_dates = _speech_dates.detach().cpu().numpy()
# A stable sort keeps the original order among speeches with identical dates, so the
# split is reproducible from run to run.
idx_sorted = np.argsort(_speech_dates, kind="stable")
assert len(idx_sorted) == N, "expected exactly one date per speech node"

train_end = int(0.7 * N)
val_end   = int(0.85 * N)

train_idx = torch.tensor(idx_sorted[:train_end], dtype=torch.long)
val_idx   = torch.tensor(idx_sorted[train_end:val_end], dtype=torch.long)
test_idx  = torch.tensor(idx_sorted[val_end:], dtype=torch.long)


def _index_mask(indices, size):
    """Return a boolean mask of length ``size`` that is True at ``indices``."""
    mask = torch.zeros(size, dtype=torch.bool)
    mask[indices] = True
    return mask


train_mask = _index_mask(train_idx, N)
val_mask   = _index_mask(val_idx, N)
test_mask  = _index_mask(test_idx, N)

data['speech'].train_mask = train_mask
data['speech'].val_mask   = val_mask
data['speech'].test_mask  = test_mask


In [None]:
# Full-graph training: every epoch runs one forward/backward pass over the whole graph
# and scores only the speech nodes selected by train_mask. Validation loss is reported
# every 20th epoch.
for epoch in range(200):
    speech_nodes = data['speech']

    model.train()
    optimizer.zero_grad()

    pred = model(data.x_dict, data.edge_index_dict)
    train_sel = speech_nodes.train_mask
    loss = loss_fn(pred[train_sel], speech_nodes.y[train_sel])
    loss.backward()
    optimizer.step()

    if epoch % 20 != 0:
        continue

    # Periodic validation pass; no gradients are needed here.
    model.eval()
    with torch.no_grad():
        val_pred = model(data.x_dict, data.edge_index_dict)
        val_sel = speech_nodes.val_mask
        val_loss = loss_fn(val_pred[val_sel], speech_nodes.y[val_sel]).item()

    print(f"Epoch {epoch:03d} | Train Loss {loss.item():.4f} | Val Loss {val_loss:.4f}")


Epoch 000 | Train Loss 0.0084 | Val Loss 0.0098
Epoch 020 | Train Loss 0.0004 | Val Loss 0.0027
Epoch 040 | Train Loss 0.0000 | Val Loss 0.0030
Epoch 060 | Train Loss 0.0000 | Val Loss 0.0030
Epoch 080 | Train Loss 0.0000 | Val Loss 0.0031
Epoch 100 | Train Loss 0.0000 | Val Loss 0.0031
Epoch 120 | Train Loss 0.0000 | Val Loss 0.0029
Epoch 140 | Train Loss 0.0000 | Val Loss 0.0032
Epoch 160 | Train Loss 0.0000 | Val Loss 0.0030
Epoch 180 | Train Loss 0.0000 | Val Loss 0.0031


In [None]:
# Final evaluation: loss on the speech nodes selected by test_mask.
model.eval()
speech_nodes = data['speech']
held_out = speech_nodes.test_mask
with torch.no_grad():
    pred = model(data.x_dict, data.edge_index_dict)
    test_loss = loss_fn(pred[held_out], speech_nodes.y[held_out]).item()
print(f"Test MSE: {test_loss:.6f}")


Test MSE: 0.002906


In [None]:
from torch.nn import Linear, ModuleDict, ModuleList, Sequential, ReLU, LSTM
from torch_geometric.nn import GATv2Conv

class hetero_attn(torch.nn.Module):
  """Stack of heterogeneous GATv2 layers over the speech / topic / speaker graph.

  Args:
    graph_data: Graph whose "speech", "topic" and "speaker" entries expose an ``x``
      feature matrix; only the feature widths are read here.
    hidden: Per-head output width of every GATv2Conv and of the input projections.
    heads: Number of attention heads per GATv2Conv.
    layers: Number of stacked HeteroConv layers.

  Attributes:
    dim_final: ``hidden * heads``, stored for downstream modules that consume the
      embeddings.
  """
  def __init__ (self, graph_data, hidden, heads, layers):
    super().__init__()

    # One learnable linear projection per node type, sized to that type's feature
    # width, so every node type enters the attention layers with `hidden` columns.
    self.by_type = ModuleDict()
    self.by_type["speech"] = Linear(graph_data["speech"].x.shape[1], hidden)
    self.by_type["topic"] = Linear(graph_data["topic"].x.shape[1], hidden)
    self.by_type["speaker"] = Linear(graph_data["speaker"].x.shape[1], hidden)

    self.convs = torch.nn.ModuleList()
    for _ in range(layers):
      # Forward relations first, then their "rev_" counterparts. -1 leaves the input
      # width to be inferred lazily; edge_dim=1 configures each relation for
      # one-dimensional edge features.
      layer = HeteroConv({
          ("speech", "discusses", "topic"): GATv2Conv(-1, hidden, heads=heads, edge_dim=1, add_self_loops=False),
          ("speech", "follows", "speech"): GATv2Conv(-1, hidden, heads=heads, edge_dim=1, add_self_loops=False),
          ("speaker", "authored", "speech"): GATv2Conv(-1, hidden, heads=heads, edge_dim=1, add_self_loops=False),
          ("speaker", "influences", "topic"): GATv2Conv(-1, hidden, heads=heads, edge_dim=1, add_self_loops=False),

          ("topic", "rev_discusses", "speech"): GATv2Conv(-1, hidden, heads=heads, edge_dim=1, add_self_loops=False),
          ("speech", "rev_follows", "speech"): GATv2Conv(-1, hidden, heads=heads, edge_dim=1, add_self_loops=False),
          ("speech", "rev_authored", "speaker"): GATv2Conv(-1, hidden, heads=heads, edge_dim=1, add_self_loops=False),
          ("topic", "rev_influences", "speaker"): GATv2Conv(-1, hidden, heads=heads, edge_dim=1, add_self_loops=False),
      }, aggr="mean")

      self.convs.append(layer)

    # Width of the embeddings handed to downstream modules (used for the time-series
    # model later).
    self.dim_final = hidden * heads

  def forward (self, params_map, edge_idx_map, edge_attr_map=None):
    """Return a dict of ReLU-activated node embeddings keyed by node type.

    Args:
      params_map: Node features keyed by node type.
      edge_idx_map: Edge indices keyed by relation.
      edge_attr_map: Optional edge features keyed by relation.
    """
    # Project each node type with its own linear map (types without a registered
    # projection pass through unchanged), then apply the non-linearity.
    hidden_map = {}
    for node_type, feats in params_map.items():
      if node_type in self.by_type:
        feats = self.by_type[node_type](feats)
      hidden_map[node_type] = feats.relu()

    for layer in self.convs:
      # Feed each layer the previous layer's output so the stack actually composes.
      # The edge-attribute dict is only forwarded when the caller supplied one.
      if edge_attr_map is None:
        hidden_map = layer(hidden_map, edge_idx_map)
      else:
        hidden_map = layer(hidden_map, edge_idx_map, edge_attr_map)
      hidden_map = {key : value.relu() for key, value in hidden_map.items()}

    return hidden_map

class time_series_pred (torch.nn.Module):
  """Graph encoder + LSTM over windows of speech embeddings + per-topic regression head.

  Args:
    dim_in: Passed straight through as the first argument of ``hetero_attn``.
    hidden, heads, layers: Forwarded to ``hetero_attn``.
    length: Accepted for interface compatibility; not used by this module.
  """
  def __init__  (self, dim_in, hidden, heads, layers, length):
    super().__init__()
    self.gat = hetero_attn(dim_in, hidden, heads, layers)
    dim_final = self.gat.dim_final

    # Sequence model over the per-speech embeddings of one window.
    self.seq = LSTM(input_size=dim_final, hidden_size=dim_final, batch_first=True)

    # Input is the concatenation [sequence summary, topic embedding], hence 2 * dim_final.
    self.pred_head = Sequential(Linear(dim_final * 2, dim_final), ReLU(), Linear(dim_final, 1))

  @staticmethod
  def _first_attr (obj, names):
    """Return the first attribute in ``names`` that exists on ``obj`` and is not None."""
    for name in names:
      value = getattr(obj, name, None)
      if value is not None:
        return value
    return None

  def forward (self, data, seqs, topic_idx):
    """Predict one scalar per (window, topic) pair.

    Args:
      data: Graph container. ``params_map`` / ``edge_idx_map`` / ``edge_attr_map`` are
        read when present; otherwise ``x_dict`` / ``edge_index_dict`` /
        ``edge_attr_dict`` are used.
      seqs: Integer index tensor of speech nodes, one row per window.
      topic_idx: Index of the topic node paired with each window.

    Raises:
      AttributeError: If ``data`` exposes neither naming scheme for features or edges.
    """
    features = self._first_attr(data, ('params_map', 'x_dict'))
    edge_index = self._first_attr(data, ('edge_idx_map', 'edge_index_dict'))
    if features is None or edge_index is None:
      raise AttributeError(
          "data must expose node features (params_map or x_dict) and edge indices "
          "(edge_idx_map or edge_index_dict)")
    edge_attr = self._first_attr(data, ('edge_attr_map', 'edge_attr_dict'))

    gat_output_emb = self.gat(features, edge_index, edge_attr)

    for_speeches = gat_output_emb['speech']
    for_topics = gat_output_emb['topic']

    # Gather the embeddings of every speech in every window.
    for_seqs = for_speeches[seqs]

    # Only the final hidden state is used as the summary of each window.
    _, (hidden_state, _) = self.seq(for_seqs)
    relevant = hidden_state.squeeze(0)

    topics = for_topics[topic_idx]

    concatenated = torch.cat([relevant, topics], dim=-1)
    return self.pred_head(concatenated).squeeze(-1)

In [None]:
def subgraph_up_to_time (data, idx):
  """Return a copy of ``data`` restricted to the first ``idx`` speech nodes.

  Speech features are truncated to rows ``[:idx]`` and every relation drops the edges
  whose speech endpoint has index >= ``idx`` (together with their ``edge_attr`` rows).

  Assumes speech nodes are stored in chronological order so that "index < idx" means
  "earlier in time" - TODO confirm against the graph construction code.
  NOTE(review): only ``x`` is truncated on the speech store; other per-speech
  attributes keep their full length.
  NOTE(review): relies on ``copy.copy(data)`` giving the copy its own per-type stores
  so the assignments below do not leak into ``data`` - confirm for the graph class used.

  Args:
    data: Heterogeneous graph with a 'speech' node type.
    idx: Number of leading speech nodes to keep.

  Returns:
    The restricted shallow copy.
  """
  import copy  # local import: `copy` is not imported in the visible notebook cells

  graph_so_far = copy.copy(data)

  graph_so_far['speech'].x = data['speech'].x[:idx]

  # In class these were called "relation types" (an edge type being one component of
  # a relation type); the same naming convention is used here.
  for rel in graph_so_far.edge_types:
    src, _, dst = rel
    edge_idx = graph_so_far[rel].edge_index

    # Build the mask on the edges' own device so the in-place AND below also works
    # when the graph lives on a GPU.
    temporal_mask = torch.ones(edge_idx.size(1), dtype=torch.bool, device=edge_idx.device)
    if src == 'speech':
      temporal_mask &= (edge_idx[0] < idx)
    if dst == 'speech':
      temporal_mask &= (edge_idx[1] < idx)

    graph_so_far[rel].edge_index = edge_idx[:, temporal_mask]

    if 'edge_attr' in graph_so_far[rel]:
      graph_so_far[rel].edge_attr = graph_so_far[rel].edge_attr[temporal_mask]

  return graph_so_far

def sliding_window (speech_idx, length):
  """Split an ordered index sequence into fixed-length windows and next-step targets.

  Window ``k`` is ``speech_idx[k : k + length]`` and its target is
  ``speech_idx[k + length]``.

  Args:
    speech_idx: Ordered speech indices; a tensor or anything ``torch.as_tensor`` accepts.
    length: Window size (number of items preceding each target).

  Returns:
    ``(seqs_out, pred_targets_out)``: a tensor with one window per row and a tensor
    with the matching targets. When there are not more than ``length`` items, both
    results are empty (zero rows) instead of raising.

  Raises:
    ValueError: If ``length`` is negative.
  """
  if length < 0:
    raise ValueError(f"length must be non-negative, got {length}")

  idx = torch.as_tensor(speech_idx)
  total = idx.size(0)
  trailing = tuple(idx.shape[1:])

  if total <= length:
    # Not enough history for even a single window: return correctly-shaped empties
    # rather than letting torch.stack fail on an empty list.
    return idx.new_empty((0, length) + trailing), idx.new_empty((0,) + trailing)

  seqs_out = torch.stack([idx[end - length : end] for end in range(length, total)])
  # clone() so the targets do not alias the caller's tensor.
  pred_targets_out = idx[length:].clone()

  return seqs_out, pred_targets_out