In [96]:
import torch
path = r"C:\Users\MainUser\project\cs224w_cb_graph\info_folder\fed_speech_graph.pt"
data = torch.load(path, map_location="cpu", weights_only=False)

In [None]:
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import HeteroConv, SAGEConv, Linear

class HeteroSpeechRegressor(nn.Module):
    def __init__(self, metadata, hidden_dim=128):
        super().__init__()
        edge_types = metadata[1]

        self.conv1 = HeteroConv({et: SAGEConv((-1, -1), hidden_dim)
                                 for et in edge_types}, aggr='sum')
        self.conv2 = HeteroConv({et: SAGEConv((-1, -1), hidden_dim)
                                 for et in edge_types}, aggr='sum')
        self.head = Linear(hidden_dim, 1)

    def _apply_layer(self, layer, x_dict, edge_index_dict):
        out = layer(x_dict, edge_index_dict)
        if isinstance(out, tuple):
            out = out[0]
        for nt, x in x_dict.items():
            if nt not in out:
                out[nt] = x
        out = {k: F.relu(v) for k, v in out.items()}
        return out

    def forward(self, x_dict, edge_index_dict):
        h = self._apply_layer(self.conv1, x_dict, edge_index_dict)
        h = self._apply_layer(self.conv2, h, edge_index_dict)
        return self.head(h['speech']).squeeze(-1)



In [98]:
mask_finite = torch.isfinite(data['speech'].y)
data['speech'].y = data['speech'].y[mask_finite]
data['speech'].x = data['speech'].x[mask_finite]
data['speech'].date = data['speech'].date[mask_finite]

In [99]:
print("data['speech'].num_nodes:", data['speech'].num_nodes)
print("len(speech_dates):", len(data['speech'].date))
print("len(data['speech'].y):", len(data['speech'].y))


data['speech'].num_nodes: 53
len(speech_dates): 53
len(data['speech'].y): 53


In [100]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HeteroSpeechRegressor(data.metadata(), hidden_dim=128).to(device)
data = data.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
loss_fn = nn.MSELoss()


  self.conv1 = HeteroConv({et: SAGEConv((-1, -1), hidden_dim)
  self.conv2 = HeteroConv({et: SAGEConv((-1, -1), hidden_dim)


In [101]:
from torch_geometric.transforms import ToUndirected
data = ToUndirected()(data)
for nt in data.node_types:
    if not hasattr(data[nt], 'x'):
        N = data[nt].num_nodes
        data[nt].x = torch.zeros((N, 1), dtype=torch.float)
    if not hasattr(data[nt], 'num_nodes') or data[nt].num_nodes is None:
        data[nt].num_nodes = data[nt].x.size(0)


In [102]:
import numpy as np
import torch

N = data['speech'].num_nodes
idx_sorted = np.argsort(data["speech"].date)

train_end = int(0.7 * N)
val_end   = int(0.85 * N)

train_idx = torch.tensor(idx_sorted[:train_end], dtype=torch.long)
val_idx   = torch.tensor(idx_sorted[train_end:val_end], dtype=torch.long)
test_idx  = torch.tensor(idx_sorted[val_end:], dtype=torch.long)

train_mask = torch.zeros(N, dtype=torch.bool)
val_mask   = torch.zeros(N, dtype=torch.bool)
test_mask  = torch.zeros(N, dtype=torch.bool)

train_mask[train_idx] = True
val_mask[val_idx]     = True
test_mask[test_idx]   = True

data['speech'].train_mask = train_mask
data['speech'].val_mask   = val_mask
data['speech'].test_mask  = test_mask


In [104]:
for epoch in range(200):

    model.train()
    optimizer.zero_grad()

    pred = model(data.x_dict, data.edge_index_dict)
    loss = loss_fn(pred[data['speech'].train_mask], data['speech'].y[data['speech'].train_mask])
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        model.eval()
        with torch.no_grad():
            val_pred = model(data.x_dict, data.edge_index_dict)
            val_loss = loss_fn(val_pred[data['speech'].val_mask], data['speech'].y[data['speech'].val_mask]).item()

        print(f"Epoch {epoch:03d} | Train Loss {loss.item():.4f} | Val Loss {val_loss:.4f}")


Epoch 000 | Train Loss 0.0084 | Val Loss 0.0098
Epoch 020 | Train Loss 0.0004 | Val Loss 0.0027
Epoch 040 | Train Loss 0.0000 | Val Loss 0.0030
Epoch 060 | Train Loss 0.0000 | Val Loss 0.0030
Epoch 080 | Train Loss 0.0000 | Val Loss 0.0031
Epoch 100 | Train Loss 0.0000 | Val Loss 0.0031
Epoch 120 | Train Loss 0.0000 | Val Loss 0.0029
Epoch 140 | Train Loss 0.0000 | Val Loss 0.0032
Epoch 160 | Train Loss 0.0000 | Val Loss 0.0030
Epoch 180 | Train Loss 0.0000 | Val Loss 0.0031


In [105]:
model.eval()
with torch.no_grad():
    pred = model(data.x_dict, data.edge_index_dict)
    test_loss = loss_fn(pred[data['speech'].test_mask], data['speech'].y[data['speech'].test_mask]).item()
print(f"Test MSE: {test_loss:.6f}")


Test MSE: 0.002906
