# Graph Classification Demo

Demo classifying the work status of day-long schedules using node lables (activity and zone).

In [1]:
from pathlib import Path

from torch import stack, nn
from torch_geometric.loader import DataLoader
from ntsx.data_loader import nx_to_torch_geo
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

from ntsx import core, ops
from ntsx import read_nts
from ntsx.encoders.trip_encoder import TripEncoder
from ntsx.encoders.table_encoder import TableTokeniser

In [2]:
# load dummy data (synthesised from UK NTS)

dir = Path("data/dummyNTS/")
trips_path = dir / "trips.tab"
attributes_path = dir / "individuals.tab"
hhs_path = dir / "households.tab"

years = [2021]

write_dir = Path("tmp")
write_dir.mkdir(exist_ok=True)

In [3]:
# load data from disk
trips, labels = read_nts.load_nts(trips_path, attributes_path, hhs_path, years=years)

# assign human readable values to the labels
labels = read_nts.label_mapping(labels)

# initaite the encoders
label_encoder = TableTokeniser(labels, verbose=False)
trip_encoder = TripEncoder(trips)

display(labels[["work_status"]].head())
display(trips.head())

HIDs in people and households do not match, attempting to fix...
Fixed: People 6 -> 7, HHs 5 -> 5


Unnamed: 0_level_0,work_status
iid,Unnamed: 1_level_1
1,unemployed
2,employed
3,unemployed
4,unemployed
5,employed


Unnamed: 0,tid,year,day,iid,hid,seq,mode,oact,dact,freq,tst,tet,ozone,dzone,did,pid
0,1,2021,2,1,1,1,car,home,social,0.989618,675,683,7,7,0,1_1
1,2,2021,2,1,1,2,car,social,other,1.002945,720,735,7,7,0,1_1
2,3,2021,2,1,1,3,car,other,social,0.989618,770,780,7,7,0,1_1
3,4,2021,2,1,1,4,taxi,social,home,0.989618,1110,1130,7,7,0,1_1
4,5,2021,3,1,1,1,car,home,social,0.999891,760,770,7,7,1,1_1


In [4]:
# first encode the trips table
trips_encoded = trip_encoder.encode_trips_table(trips)
print(f"Activity mapping: {trip_encoder.encoders["oact"].mapping}")

# then build a graph from the trips table, note that we only merge on home (1)
gs = core.to_nx(trips_encoded)
gs = ops.anchor_activities(gs, [2])
gs = ops.merge_similar(gs, duration_tolerance=0.2)

# now we can create a graph for each day
days = [g for _, g in ops.iter_days(gs, stop=None)]
print(f"Node Labels: {next(iter(days[0].nodes(data=True)))[1].keys()}")
print(f"Node label sizes: {trip_encoder.embed_sizes()}")
print(f"Number of days: {len(days)}")

# here we retrieve iids for the days
iids = [list(d.edges.data())[0][2]["iid"] for d in days]

# so that we can retrieve the work status labels from the labels table (also encoded)
work_status_tokens = label_encoder.encode_series(labels.work_status[iids])
print(f"Work status tokens: {label_encoder.embed_sizes()["work_status"]}")
print(f"Number of work status tokens: {len(work_status_tokens)}")

# now we can create a graph dataset
dataset = nx_to_torch_geo(days)

# add graph level labels manually to dataset -> yuck?
for data, label in zip(dataset, work_status_tokens):
    data.y = label
print(f"Number of graphs: {len(dataset)}")

# finally we can create a dataloader
loader = DataLoader(dataset, batch_size=16, shuffle=True)
for data in loader:
    print(data)

Activity mapping: {'education': 0, 'escort': 1, 'home': 2, 'hotel': 3, 'medical': 4, 'other': 5, 'shop': 6, 'social': 7, 'work': 8}
Node Labels: dict_keys(['act', 'location'])
Node label sizes: {'mode': 5, 'oact': 9, 'dact': 9, 'day': 39, 'tst': 1, 'tet': 1, 'ozone': 2, 'dzone': 2}
Number of days: 39
Work status tokens: 2
Number of work status tokens: 39
Number of graphs: 39
DataBatch(edge_index=[2, 71], x=[59, 2], edge_attr=[71, 6], y=[16], batch=[59], ptr=[17])
DataBatch(edge_index=[2, 60], x=[53, 2], edge_attr=[60, 6], y=[16], batch=[53], ptr=[17])
DataBatch(edge_index=[2, 18], x=[17, 2], edge_attr=[18, 6], y=[7], batch=[17], ptr=[8])


In [5]:
class MultiTokenEmbedSum(nn.Module):
    def __init__(self, label_embed_sizes: list[int], hidden_size: int = 32):
        """Embed tokens and add them together."""
        super(MultiTokenEmbedSum, self).__init__()
        self.embeds = nn.ModuleList(
            [nn.Embedding(s, hidden_size) for s in label_embed_sizes]
        )

    def forward(self, x):
        return stack(
            [embed(x[:, i]) for i, embed in enumerate(self.embeds)], dim=-1
        ).sum(dim=-1)


class BasicGraphLabeller(torch.nn.Module):
    def __init__(
        self, node_embed_sizes: list[int], target_size: int, hidden_size: int = 32
    ):
        """A simple GNN model for graph classification."""
        super().__init__()
        self.node_embed = MultiTokenEmbedSum(node_embed_sizes, hidden_size)
        self.conv1 = GCNConv(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, target_size)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.node_embed(x)
        x = F.relu(x)
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        x = F.dropout(x, training=self.training)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)


# train
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

node_embed_sizes = [
    trip_encoder.embed_sizes()["oact"],
    trip_encoder.embed_sizes()["ozone"],
]
target_size = label_encoder.embed_sizes()["work_status"]

model = BasicGraphLabeller(
    node_embed_sizes=node_embed_sizes, target_size=target_size, hidden_size=32
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-3)

model.train()
for epoch in range(10):
    for data in loader:
        data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.nll_loss(out, data.y.long())
        preds = out.argmax(dim=1)
        correct = (preds == data.y).sum().item()
        acc = correct / len(data.y)
        print(f"Epoch {epoch}: Loss {loss.item():.4f}. Accuracy: {acc:.4f}")
        loss.backward()
        optimizer.step()

Epoch 0: Loss 0.8339. Accuracy: 0.4375
Epoch 0: Loss 0.8766. Accuracy: 0.3125
Epoch 0: Loss 0.6177. Accuracy: 0.7143
Epoch 1: Loss 0.7462. Accuracy: 0.6250
Epoch 1: Loss 0.4892. Accuracy: 0.7500
Epoch 1: Loss 0.6420. Accuracy: 0.7143
Epoch 2: Loss 0.6991. Accuracy: 0.6250
Epoch 2: Loss 0.5797. Accuracy: 0.6250
Epoch 2: Loss 0.6674. Accuracy: 0.2857
Epoch 3: Loss 0.4837. Accuracy: 0.8125
Epoch 3: Loss 0.6789. Accuracy: 0.6875
Epoch 3: Loss 0.5268. Accuracy: 0.7143
Epoch 4: Loss 0.5807. Accuracy: 0.7500
Epoch 4: Loss 0.5212. Accuracy: 0.7500
Epoch 4: Loss 0.7620. Accuracy: 0.4286
Epoch 5: Loss 0.6091. Accuracy: 0.6875
Epoch 5: Loss 0.6587. Accuracy: 0.5000
Epoch 5: Loss 0.7839. Accuracy: 0.4286
Epoch 6: Loss 0.6035. Accuracy: 0.6250
Epoch 6: Loss 0.4850. Accuracy: 0.8750
Epoch 6: Loss 0.6264. Accuracy: 0.5714
Epoch 7: Loss 0.3894. Accuracy: 0.8750
Epoch 7: Loss 0.5156. Accuracy: 0.8125
Epoch 7: Loss 0.6636. Accuracy: 0.5714
Epoch 8: Loss 0.5175. Accuracy: 0.7500
Epoch 8: Loss 0.4305. Acc