# Graph Classification Demo

Demo classifying the work status of day-long schedules using node lables (activity and zone).

In [1]:
from pathlib import Path

from torch import stack, cat, nn
from torch_geometric.loader import DataLoader
from ntsx.nx_to_torch import nx_to_torch_geo
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool, GCNConv

from ntsx import graph_ops, nts_to_nx
from ntsx import read_nts
from ntsx.encoders.trip_encoder import TripEncoder
from ntsx.encoders.table_encoder import TableTokeniser

In [2]:
# load dummy data (synthesised from UK NTS)

dir = Path("data/dummyNTS/")
trips_path = dir / "trips.tab"
attributes_path = dir / "individuals.tab"
hhs_path = dir / "households.tab"

years = [2021]

write_dir = Path("tmp")
write_dir.mkdir(exist_ok=True)

In [3]:
# load data from disk
trips, labels = read_nts.load_nts(trips_path, attributes_path, hhs_path, years=years)

# assign human readable values to the labels
labels = read_nts.label_mapping(labels)

# initaite the encoders
label_encoder = TableTokeniser(labels, verbose=False)
trip_encoder = TripEncoder(trips)

display(labels[["work_status"]].head())
display(trips.head())

HIDs in people and households do not match, attempting to fix...
Fixed: People 6 -> 7, HHs 5 -> 5


Unnamed: 0_level_0,work_status
iid,Unnamed: 1_level_1
1,unemployed
2,employed
3,unemployed
4,unemployed
5,employed


Unnamed: 0,tid,year,day,iid,hid,seq,mode,oact,dact,freq,tst,tet,ozone,dzone,did,pid
0,1,2021,2,1,1,1,car,home,social,0.989618,675,683,7,7,0,1_1
1,2,2021,2,1,1,2,car,social,other,1.002945,720,735,7,7,0,1_1
2,3,2021,2,1,1,3,car,other,social,0.989618,770,780,7,7,0,1_1
3,4,2021,2,1,1,4,taxi,social,home,0.989618,1110,1130,7,7,0,1_1
4,5,2021,3,1,1,1,car,home,social,0.999891,760,770,7,7,1,1_1


In [4]:
# first encode the trips and lables tables
trips_encoded = trip_encoder.encode_trips_table(trips)
print(f"Activity mapping: {trip_encoder.encoders["oact"].mapping}")

labels_encoded = label_encoder.encode_table(labels)

Activity mapping: {'education': 0, 'escort': 1, 'home': 2, 'hotel': 3, 'medical': 4, 'other': 5, 'shop': 6, 'social': 7, 'work': 8}


In [5]:
# then build individuals and then days graphs from the trips table, note that we only merge on home (2)
individuals = nts_to_nx.to_individuals_nx(trips_encoded, attribute_data=labels_encoded)
days = []
for ind in individuals:
    g = graph_ops.anchor_activities(ind, [2])
    g = graph_ops.merge_similar(g, duration_tolerance=0.2)

    # now we can create a graph for each day
    indiv_days = [d for _, d in graph_ops.iter_days(g, stop=None)]
    days.extend(indiv_days)

# now we can create a graph dataset
dataset = nx_to_torch_geo(days)

# finally we can create a dataloader
loader = DataLoader(dataset, batch_size=16, shuffle=True)
for data in loader:
    print(data)

DataBatch(edge_index=[2, 59], act=[53], location=[53], duration=[59], day=[16], tst=[59], tet=[59], travel=[59], iid=[16], age=[16], gender=[16], ethnicity=[16], education=[16], license=[16], car_access=[16], work_status=[16], year=[16], area=[16], income=[16], hh_size=[16], hh_composition=[16], hh_children=[16], hh_cars=[16], hh_bikes=[16], hh_motorcycles=[16], num_nodes=53, batch=[53], ptr=[17])
DataBatch(edge_index=[2, 68], act=[62], location=[62], duration=[68], day=[16], tst=[68], tet=[68], travel=[68], iid=[16], age=[16], gender=[16], ethnicity=[16], education=[16], license=[16], car_access=[16], work_status=[16], year=[16], area=[16], income=[16], hh_size=[16], hh_composition=[16], hh_children=[16], hh_cars=[16], hh_bikes=[16], hh_motorcycles=[16], num_nodes=62, batch=[62], ptr=[17])
DataBatch(edge_index=[2, 22], act=[20], location=[20], duration=[22], day=[7], tst=[22], tet=[22], travel=[22], iid=[7], age=[7], gender=[7], ethnicity=[7], education=[7], license=[7], car_access=[7

In [None]:
class MultiTokenEmbedSum(nn.Module):
    def __init__(self, label_embed_sizes: list[int], hidden_size: int = 32):
        """Embed tokens and add them together."""
        super(MultiTokenEmbedSum, self).__init__()
        self.embeds = nn.ModuleList(
            [nn.Embedding(s, hidden_size) for s in label_embed_sizes]
        )

    def forward(self, x):
        return stack([embed(x[i]) for i, embed in enumerate(self.embeds)], dim=-1).sum(
            dim=-1
        )


class GCNGraphLabeller(torch.nn.Module):
    def __init__(
        self, node_embed_sizes: list[int], target_size: int, hidden_size: int = 32
    ):
        """A simple GNN model for graph classification."""
        super().__init__()
        self.node_embed = MultiTokenEmbedSum(node_embed_sizes, hidden_size)
        self.conv1 = GCNConv(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, target_size)

    def forward(self, data):
        x = [data.act, data.location]
        edge_index, batch = data.edge_index, data.batch
        x = self.node_embed(x)
        x = F.relu(x)
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        x = F.dropout(x, training=self.training)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)


# train
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

node_embed_sizes = [
    trip_encoder.embed_sizes()["oact"],
    trip_encoder.embed_sizes()["ozone"],
]
target_size = label_encoder.embed_sizes()["work_status"]

model = GCNGraphLabeller(
    node_embed_sizes=node_embed_sizes, target_size=target_size, hidden_size=32
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-3)

model.train()
for epoch in range(10):
    for data in loader:
        data.to(device)
        optimizer.zero_grad()
        out = model(data)
        y = data.work_status
        loss = F.nll_loss(out, y)
        preds = out.argmax(dim=1)
        correct = (preds == y).sum().item()
        acc = correct / len(y)
        print(f"Epoch {epoch}: Loss {loss.item():.4f}. Accuracy: {acc:.4f}")
        loss.backward()
        optimizer.step()

Epoch 0: Loss 0.6982. Accuracy: 0.4375
Epoch 0: Loss 0.6808. Accuracy: 0.4375
Epoch 0: Loss 0.7134. Accuracy: 0.4286
Epoch 1: Loss 0.5922. Accuracy: 0.6250
Epoch 1: Loss 0.6982. Accuracy: 0.6250
Epoch 1: Loss 0.6706. Accuracy: 0.4286
Epoch 2: Loss 0.4851. Accuracy: 0.7500
Epoch 2: Loss 0.6352. Accuracy: 0.6250
Epoch 2: Loss 0.6362. Accuracy: 0.5714
Epoch 3: Loss 0.6249. Accuracy: 0.7500
Epoch 3: Loss 0.4433. Accuracy: 0.8750
Epoch 3: Loss 0.7449. Accuracy: 0.4286
Epoch 4: Loss 0.5545. Accuracy: 0.7500
Epoch 4: Loss 0.4544. Accuracy: 0.8750
Epoch 4: Loss 0.5864. Accuracy: 0.5714
Epoch 5: Loss 0.5060. Accuracy: 0.8125
Epoch 5: Loss 0.5349. Accuracy: 0.6875
Epoch 5: Loss 0.5253. Accuracy: 0.7143
Epoch 6: Loss 0.5190. Accuracy: 0.8125
Epoch 6: Loss 0.5193. Accuracy: 0.7500
Epoch 6: Loss 0.4131. Accuracy: 1.0000
Epoch 7: Loss 0.4998. Accuracy: 0.6875
Epoch 7: Loss 0.4739. Accuracy: 0.7500
Epoch 7: Loss 0.4029. Accuracy: 0.8571
Epoch 8: Loss 0.6202. Accuracy: 0.7500
Epoch 8: Loss 0.4100. Acc

In [7]:
class GATGraphLabeller(torch.nn.Module):
    def __init__(
        self,
        node_embed_sizes: list[int],
        edge_embed_sizes: list[int],
        target_size: int,
        hidden_size: int = 32,
    ):
        """A simple GAT model for graph classification with edge and node attributes."""
        super().__init__()
        self.node_embed = MultiTokenEmbedSum(node_embed_sizes, hidden_size)
        self.edge_embed = MultiTokenEmbedSum(edge_embed_sizes, hidden_size)
        self.conv1 = GATConv(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, target_size)

    def forward(self, data):
        x = [data.act, data.location]
        x = self.node_embed(x)

        x_edge_cont = stack([data.duration, data.tst, data.tet], dim=1)
        x_edge_cat = [data.travel]
        x_edge_cat = self.edge_embed(x_edge_cat)
        x_edge = cat([x_edge_cat, x_edge_cont], dim=-1)

        edge_index, batch = data.edge_index, data.batch

        x = F.relu(x)
        x = self.conv1(x, edge_index, x_edge)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        x = F.dropout(x, training=self.training)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)


# train
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

node_embed_sizes = [
    trip_encoder.embed_sizes()["oact"],
    trip_encoder.embed_sizes()["ozone"],
]
edge_embed_sizes = [
    trip_encoder.embed_sizes()["mode"],
]

target_size = label_encoder.embed_sizes()["work_status"]

model = GATGraphLabeller(
    node_embed_sizes=node_embed_sizes,
    edge_embed_sizes=edge_embed_sizes,
    target_size=target_size,
    hidden_size=32,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-3)

model.train()
for epoch in range(10):
    for data in loader:
        data.to(device)
        optimizer.zero_grad()
        out = model(data)
        y = data.work_status
        loss = F.nll_loss(out, y)
        preds = out.argmax(dim=1)
        correct = (preds == y).sum().item()
        acc = correct / len(y)
        print(f"Epoch {epoch}: Loss {loss.item():.4f}. Accuracy: {acc:.4f}")
        loss.backward()
        optimizer.step()

Epoch 0: Loss 0.7630. Accuracy: 0.3125
Epoch 0: Loss 0.7557. Accuracy: 0.4375
Epoch 0: Loss 0.9625. Accuracy: 0.4286
Epoch 1: Loss 0.6378. Accuracy: 0.6250
Epoch 1: Loss 0.6666. Accuracy: 0.6875
Epoch 1: Loss 0.7317. Accuracy: 0.4286
Epoch 2: Loss 0.5808. Accuracy: 0.6875
Epoch 2: Loss 0.6903. Accuracy: 0.6875
Epoch 2: Loss 0.7519. Accuracy: 0.5714
Epoch 3: Loss 0.6767. Accuracy: 0.5625
Epoch 3: Loss 0.5379. Accuracy: 0.8125
Epoch 3: Loss 0.5676. Accuracy: 0.7143
Epoch 4: Loss 0.5060. Accuracy: 0.8125
Epoch 4: Loss 0.5918. Accuracy: 0.6250
Epoch 4: Loss 0.5716. Accuracy: 0.7143
Epoch 5: Loss 0.5680. Accuracy: 0.7500
Epoch 5: Loss 0.5130. Accuracy: 0.6875
Epoch 5: Loss 0.5985. Accuracy: 0.7143
Epoch 6: Loss 0.4837. Accuracy: 0.7500
Epoch 6: Loss 0.5133. Accuracy: 0.6250
Epoch 6: Loss 0.3809. Accuracy: 0.7143
Epoch 7: Loss 0.6088. Accuracy: 0.6875
Epoch 7: Loss 0.5137. Accuracy: 0.7500
Epoch 7: Loss 0.2472. Accuracy: 0.8571
Epoch 8: Loss 0.3487. Accuracy: 0.8125
Epoch 8: Loss 0.4632. Acc