In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import egg.core as core
from options import Options
opts = Options()

from typing import Any, List, Optional, Sequence, Union

import torch.utils.data
from torch_geometric.data import Batch, Dataset, Data
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.data.on_disk_dataset import OnDiskDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
import torch
import random
from torch_geometric.data import Dataset
from graph.build import create_family_tree, create_data_object

class FamilyGraphDataset(Dataset):
    """
    Dataset class for generating family graph data.

    Args:
        root (str): Root directory path.
        number_of_graphs (int): Number of graphs to generate.
        generations (int): Number of generations in each family tree.

    Returns:
        Data(x=[8, 2], edge_index=[2, 20], edge_attr=[20], labels=[8])
    """
    def __init__(self, root, number_of_graphs, generations, transform=None, pre_transform=None):
        self.number_of_graphs = number_of_graphs
        self.generations = generations
        super(FamilyGraphDataset, self).__init__(root, transform, pre_transform)
        self.data = None
        self.process()

    @property
    def processed_file_names(self):
        return ['family_graphs.pt']

    def len(self):
        return len(self.data)

    def get(self, idx):
        return self.data[idx]
    
    def generate_labels(self, num_nodes):
        target_node_idx = random.randint(0, num_nodes - 1)
        return target_node_idx
    
    def generate_root(self, num_nodes, target_node_idx):
        node_indices = list(range(num_nodes))
        node_indices.remove(target_node_idx)
        root_idx = random.choice(node_indices)
        return root_idx
    
    def process(self):
        if not os.path.isfile(self.processed_paths[0]):
            self.data = []
            for _ in range(self.number_of_graphs):
                family_tree = create_family_tree(self.generations)
                graph_data = create_data_object(family_tree)

                # Generate random labels for each node
                target_node_idx = self.generate_labels(graph_data.num_nodes)

                # Store the labels as an attribute of the graph_data
                graph_data.target_node_idx = target_node_idx

                root_idx = self.generate_root(graph_data.num_nodes, target_node_idx)

                graph_data.root_idx = root_idx

                self.data.append(graph_data)

            torch.save(self.data, self.processed_paths[0])
        else:
            self.data = torch.load(self.processed_paths[0])

In [3]:
dataset = FamilyGraphDataset(root='/Users/meeslindeman/Library/Mobile Documents/com~apple~CloudDocs/Thesis/Code/data', number_of_graphs=5000, generations=2)
print(dataset[0])
print(dataset[1])

Processing...


Data(x=[4, 2], edge_index=[2, 10], edge_attr=[10, 3], target_node_idx=3, root_idx=1)
Data(x=[4, 2], edge_index=[2, 10], edge_attr=[10, 3], target_node_idx=2, root_idx=3)


Done!


In [4]:
from graph.dataset import FamilyGraphDataset
dataset = FamilyGraphDataset(root=f'data/gens={opts.generations}')

In [5]:
print(len(dataset))
print(dataset[0].num_nodes)

5000
4


In [6]:
from graph.draw import draw_graph

# plot = draw_graph(dataset[1])

In [7]:
total_nodes = 0
for i in range(0, len(dataset)):
    total_nodes += dataset[i].x.shape[0]  # shape[0] gives the number of nodes in each graph

average_nodes = total_nodes / len(dataset)
print("Average number of nodes:", average_nodes)

Average number of nodes: 4.0


In [8]:
print(dataset[0].edge_attr.shape)

torch.Size([10, 3])


In [9]:
graph = dataset[0]
print(*graph, sep="\n")

('x', tensor([[ 0., 92.],
        [ 1., 96.],
        [ 0., 69.],
        [ 0., 64.]]))
('edge_index', tensor([[0, 2, 0, 3, 0, 1, 2, 1, 3, 1],
        [1, 0, 2, 0, 3, 0, 1, 2, 1, 3]]))
('edge_attr', tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.]]))
('target_node_idx', 1)
('root_idx', 2)


In [10]:
from torch_geometric.nn import GATv2Conv

class GAT(torch.nn.Module):
    def __init__(self, num_node_features, embedding_size, heads):
        super().__init__()
        self.out_heads = 1

        self.conv1 = GATv2Conv(num_node_features, embedding_size, edge_dim=3, heads=heads, concat=True)
        self.conv2 = GATv2Conv(-1, embedding_size, edge_dim=3, heads=self.out_heads, concat=True)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        h = self.conv1(x=x, edge_index=edge_index, edge_attr=edge_attr)     
        h = F.leaky_relu(h)

        h = self.conv2(x=h, edge_index=edge_index, edge_attr=edge_attr)     

        return h
    
from torch_geometric.nn import TransformerConv

class Transform(torch.nn.Module):
    def __init__(self, num_node_features, embedding_size, heads):
        super().__init__()
        self.out_heads = 1

        self.conv1 = TransformerConv(num_node_features, embedding_size, edge_dim=3, heads=heads, concat=True)
        self.conv2 = TransformerConv(-1, embedding_size, edge_dim=3, heads=self.out_heads, concat=True)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        h = self.conv1(x=x, edge_index=edge_index, edge_attr=edge_attr)     
        h = F.leaky_relu(h)

        h = self.conv2(x=h, edge_index=edge_index, edge_attr=edge_attr)     

        return h

In [11]:
class SenderDual(nn.Module):
    def __init__(self, embedding_size, heads, hidden_size, temperature):
        super(SenderDual, self).__init__()
        self.num_node_features = dataset.num_node_features
        self.heads = heads
        self.hidden_size = hidden_size
        self.temp = temperature

        self.transform = Transform(self.num_node_features, embedding_size, heads)
        self.gat = GAT(self.num_node_features, embedding_size, heads) 
        self.fc = nn.Linear(embedding_size, hidden_size) 

    def forward(self, x, _aux_input):
        data = _aux_input

        batch_ptr = data.ptr
        target_node_idx = data.target_node_idx

        h_t = self.transform(data)
        h_g = self.gat(data)
        h = h_t + h_g

        adjusted_target_node_idx = target_node_idx + batch_ptr[:-1]

        target_embedding = h[adjusted_target_node_idx]       

        output = self.fc(target_embedding)                           

        return output.view(-1, self.hidden_size)

class ReceiverDual(nn.Module):
    def __init__(self, embedding_size, heads, hidden_size):
        super(ReceiverDual, self).__init__()
        self.num_node_features = dataset.num_node_features
        self.heads = heads
        self.hidden_size = hidden_size

        self.transform = Transform(self.num_node_features, embedding_size, heads)
        self.gat = GAT(self.num_node_features, embedding_size, heads) 
        self.fc = nn.Linear(hidden_size, embedding_size)

    def forward(self, message, _input, _aux_input):
        data = _aux_input

        h_t = self.transform(data)
        h_g = self.gat(data)
        h = h_t + h_g   

        # Reshape h for batched operation
        batch_size = data.num_graphs  # Assuming this attribute is available
        num_nodes_per_graph = data.num_nodes // batch_size  # Assuming equal number of nodes in each graph
        h = h.view(batch_size, num_nodes_per_graph, -1)

        message_embedding = self.fc(message)  
        message_embedding = message_embedding.view(batch_size, -1, 1)

        dot_products = torch.bmm(h, message_embedding).squeeze(-1)   

        probabilities = F.log_softmax(dot_products, dim=1)                      

        return probabilities

In [12]:
class SenderGAT(nn.Module):
    def __init__(self, embedding_size, heads, hidden_size, temperature):
        super(SenderGAT, self).__init__()
        self.num_node_features = dataset.num_node_features
        self.heads = heads
        self.hidden_size = hidden_size
        self.temp = temperature

        self.gat = GAT(self.num_node_features, embedding_size, heads) 
        self.fc = nn.Linear(embedding_size, hidden_size) 

    def forward(self, x, _aux_input):
        data = _aux_input

        target_node_idx = data.target_node_idx

        h = self.gat(data)

        target_embedding = h[target_node_idx]           

        output = self.fc(target_embedding)                           

        return output.view(-1, self.hidden_size)

class ReceiverGAT(nn.Module):
    def __init__(self, embedding_size, heads, hidden_size):
        super(ReceiverGAT, self).__init__()
        self.num_node_features = dataset.num_node_features
        self.heads = heads
        self.hidden_size = hidden_size

        self.gat = GAT(self.num_node_features, embedding_size, heads)
        self.fc = nn.Linear(hidden_size, embedding_size)

    def forward(self, message, _input, _aux_input):
        data = _aux_input

        h = self.gat(data)   

        message_embedding = self.fc(message)                 

        dot_products = torch.matmul(h, message_embedding.t()).t()   

        probabilities = F.log_softmax(dot_products, dim=1)                      

        return probabilities

In [13]:
class SenderTransform(nn.Module):
    def __init__(self, embedding_size, heads, hidden_size, temperature):
        super(SenderTransform, self).__init__()
        self.num_node_features = dataset.num_node_features
        self.heads = heads
        self.hidden_size = hidden_size
        self.temp = temperature
          
        self.transform = Transform(self.num_node_features, embedding_size, heads) 
        self.fc = nn.Linear(embedding_size, hidden_size) 

    def forward(self, x, _aux_input):
        data = _aux_input

        target_node_idx = data.target_node_idx

        h = self.transform(data)

        target_embedding = h[target_node_idx]           

        output = self.fc(target_embedding)                           

        return output.view(-1, self.hidden_size)

class ReceiverTransform(nn.Module):
    def __init__(self, embedding_size, heads, hidden_size):
        super(ReceiverTransform, self).__init__()
        self.num_node_features = dataset.num_node_features
        self.heads = heads
        
        self.transform = Transform(self.num_node_features, embedding_size, heads)
        self.fc = nn.Linear(hidden_size, embedding_size)

    def forward(self, message, _input, _aux_input):
        data = _aux_input

        h = self.transform(data)   

        message_embedding = self.fc(message)                 

        dot_products = torch.matmul(h, message_embedding.t()).t()   

        probabilities = F.log_softmax(dot_products, dim=1)                      

        return probabilities

In [14]:
class SenderRel(nn.Module):
    def __init__(self, embedding_size, heads, hidden_size, temperature):
        super(SenderRel, self).__init__()
        self.num_node_features = 4
        self.heads = heads
        self.hidden_size = hidden_size
        self.temp = temperature

        self.transform = Transform(self.num_node_features, embedding_size, heads)
        self.gat = GAT(self.num_node_features, embedding_size, heads) 
        self.fc = nn.Linear(embedding_size * 2, hidden_size) 

    def forward(self, x, _aux_input):
        data = _aux_input

        target_node_idx, root_idx = data.target_node_idx, data.root_idx

        h_t = self.transform(data)

        h_g = self.gat(data)

        h = h_t + h_g                       

        target = h[target_node_idx].squeeze()
        root = h[root_idx].squeeze()

        target_embedding = torch.cat((target, root))

        output = self.fc(target_embedding)                           

        return output.view(-1, self.hidden_size)

class ReceiverRel(nn.Module):
    def __init__(self, embedding_size, heads, hidden_size):
        super(ReceiverRel, self).__init__()
        self.num_node_features = 4
        self.heads = heads
        self.hidden_size = hidden_size

        self.transform = Transform(self.num_node_features, embedding_size, heads)
        self.gat = GAT(self.num_node_features, embedding_size, heads) 
        self.fc = nn.Linear(hidden_size, embedding_size)

    def forward(self, message, _input, _aux_input):
        data = _aux_input

        h_t = self.transform(data)

        h_g = self.gat(data)

        h = h_t + h_g   

        message_embedding = self.fc(message)        

        dot_products = torch.matmul(h, message_embedding.t()).t()   

        probabilities = F.log_softmax(dot_products, dim=1)                      

        return probabilities

## Ouput for a single graph

In [15]:
agents = "dual"

In [16]:
if agents == "dual":
    sender = SenderDual(embedding_size=opts.embedding_size, heads=opts.heads, hidden_size=opts.hidden_size, temperature=opts.gs_tau) 
    receiver = ReceiverDual(embedding_size=opts.embedding_size, heads=opts.heads, hidden_size=opts.hidden_size)
elif agents == "transform":
    sender = SenderTransform(embedding_size=opts.embedding_size, heads=opts.heads, hidden_size=opts.hidden_size, temperature=opts.gs_tau) 
    receiver = ReceiverTransform(embedding_size=opts.embedding_size, heads=opts.heads, hidden_size=opts.hidden_size) 
elif agents == "gat":
    sender = SenderGAT(embedding_size=opts.embedding_size, heads=opts.heads, hidden_size=opts.hidden_size, temperature=opts.gs_tau) 
    receiver = ReceiverGAT(embedding_size=opts.embedding_size, heads=opts.heads, hidden_size=opts.hidden_size) 
elif agents == "rel":
    sender = SenderRel(embedding_size=opts.embedding_size, heads=opts.heads, hidden_size=opts.hidden_size, temperature=opts.gs_tau) 
    receiver = ReceiverRel(embedding_size=opts.embedding_size, heads=opts.heads, hidden_size=opts.hidden_size) 
else:
    print("Invalid agent type")

In [17]:
def process_graph_to_sequence(graph):
    # Vocabulary
    vocab = {'(': 0, ')': 1, 'male': 2, 'female': 3, 'married-to': 4, 'child-of': 5, 'gave-birth-to': 6}

    # Function to get node information
    def get_node_info(graph):
        node_info = {}
        relationship_types = {tuple([1., 0., 0.]): 'married-to', tuple([0., 1., 0.]): 'child-of', tuple([0., 0., 1.]): 'gave-birth-to'}

        for index, features_tensor in enumerate(graph.x):
            features = features_tensor.tolist()
            gender = 'male' if features[0] == 0 else 'female'
            features_dict = {'gender': gender, 'features': features}

            relationships = []
            for i in range(graph.edge_index.shape[1]):
                if graph.edge_index[0, i] == index:
                    target_node = graph.edge_index[1, i]
                    rel_type = relationship_types[tuple(graph.edge_attr[i].tolist())]
                    relationships.append({'node': target_node.item(), 'relationship': rel_type})

            features_dict['relationships'] = relationships
            node_info[index] = features_dict

        return node_info

    # Function to sort tree
    def sort_tree(node_data, start_node):
        visited = {start_node}

        def get_sorted_children(node_index):
            relationships = [r for r in node_data[node_index]['relationships'] if r['node'] not in visited]
            married_child = [child['node'] for child in relationships if child['relationship'] == 'married-to']
            other_children = sorted([child['node'] for child in relationships if child['node'] not in married_child], 
                                    key=lambda x: node_data[x]['gender'] == 'male')
            visited.update(married_child + other_children)
            return married_child + other_children

        def build_tree(node_index):
            children_indices = get_sorted_children(node_index)
            children = [build_tree(child_index) for child_index in children_indices]
            return {"index": node_index, "children": children}

        return build_tree(start_node)

    # Function to build sequence
    def build_sequence(tree, node_data, vocab):
        def node_sequence(node):
            sequence = [vocab[node_data[node['index']]['gender']]]
            for child in node['children']:
                relationship = next(r['relationship'] for r in node_data[node['index']]['relationships'] if r['node'] == child['index'])
                sequence.append(vocab[relationship])

                if child['children']:
                    sequence.append(vocab['('])  # Opening parenthesis
                    sequence.extend(node_sequence(child))
                    sequence.append(vocab[')'])  # Closing parenthesis
                else:
                    sequence.extend(node_sequence(child))

            return sequence

        sequence = node_sequence(tree)
        return torch.tensor([sequence], dtype=torch.float32)

    # Process graph to sequence
    node_info = get_node_info(graph)
    tree = sort_tree(node_info, graph.target_node_idx)
    sequence_tensor = build_sequence(tree, node_info, vocab)

    return sequence_tensor

# Example usage (assuming graph and graph.target_node_idx are defined)
sequence_tensor = process_graph_to_sequence(graph)
print(sequence_tensor)


tensor([[3., 4., 2., 6., 2., 6., 2.]])


## Batching

In [27]:
class Collater:
    def __init__(
        self,
        game_size: int,  # the number of graphs for a game
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
    ):
        self.game_size = game_size
        self.dataset = dataset
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch: List[Any]) -> Any:
        elem = batch[0]
        if isinstance(elem, BaseData):
            batch = batch[:((len(batch) // self.game_size) * self.game_size)]  # we throw away the last batch_size % game_size
            batch = Batch.from_data_list(
                batch,
                follow_batch=self.follow_batch,
                exclude_keys=self.exclude_keys,
            )
            # we return a tuple (sender_input, labels, receiver_input, aux_input)
            # we use aux_input to store minibatch of graphs
            return (
                batch.x.view(1, -1),  # we don't need sender_input --> create a fake one
                batch.target_node_idx,  # the target is aways the first graph among game_size graphs
                None,  # we don't care about receiver_input
                batch  # this is a compact data for batch_size graphs 
            )

        raise TypeError(f"DataLoader found invalid type: '{type(elem)}'")

    def collate_fn(self, batch: List[Any]) -> Any:
        if isinstance(self.dataset, OnDiskDataset):
            return self(self.dataset.multi_get(batch))
        return self(batch)


class DataLoader(torch.utils.data.DataLoader):
    def __init__(
        self,
        game_size: int,  # the number of graphs for a game
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        self.game_size = game_size
        # Remove for PyTorch Lightning:
        kwargs.pop('collate_fn', None)

        # Save for PyTorch Lightning < 1.6:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        self.collator = Collater(game_size, dataset, follow_batch, exclude_keys)

        if isinstance(dataset, OnDiskDataset):
            dataset = range(len(dataset))

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=self.collator.collate_fn,
            **kwargs,
        )

In [28]:
from sklearn.model_selection import train_test_split

# Split the dataset into training and validation sets
train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)

# Print the lengths of the training and validation sets
print("Training set length:", len(train_data))
print("Validation set length:", len(val_data))

Training set length: 4000
Validation set length: 1000


In [29]:
train_loader = DataLoader(game_size=1, dataset=train_data, batch_size=50, shuffle=True)
val_loader = DataLoader(game_size=1, dataset=val_data, batch_size=50, shuffle=True)

## Outputs for batch

In [41]:
batch = next(iter(train_loader))
print(batch[3])

DataBatch(x=[200, 2], edge_index=[2, 500], edge_attr=[500, 3], target_node_idx=[50], root_idx=[50], batch=[200], ptr=[51])


In [42]:
graphs = batch[3].to_data_list()
for graph in range(0, len(graphs)):
    sequence_tensor = process_graph_to_sequence(graph)
    print(sequence_tensor)

AttributeError: 'int' object has no attribute 'x'

In [31]:
print(batch[3])
print(batch[3].target_node_idx)

# Sender produces a message
sender_output = sender(None, batch[3])
# print("Sender's message:", sender_output)
print("Sender's shape: ", sender_output.shape)

# Receiver tries to identify the target node
receiver_output = receiver(sender_output, None, batch[3])
# print("Receiver's output:", receiver_output)
print("Receiver's shape: ", receiver_output.shape)

DataBatch(x=[200, 2], edge_index=[2, 500], edge_attr=[500, 3], target_node_idx=[50], root_idx=[50], batch=[200], ptr=[51])
tensor([3, 1, 0, 2, 0, 1, 3, 3, 0, 1, 3, 2, 3, 3, 3, 3, 3, 3, 3, 2, 1, 3, 3, 0,
        2, 3, 2, 1, 3, 0, 1, 2, 2, 1, 0, 1, 0, 1, 2, 2, 2, 3, 2, 2, 0, 0, 2, 2,
        0, 1])
Sender's shape:  torch.Size([50, 20])
Receiver's shape:  torch.Size([50, 4])


In [32]:
sender_gs = core.RnnSenderGS(sender, opts.vocab_size, opts.embedding_size, opts.hidden_size, max_len=opts.max_len, temperature=opts.gs_tau, cell=opts.sender_cell)
receiver_gs = core.RnnReceiverGS(receiver, opts.vocab_size, opts.embedding_size, opts.hidden_size, cell=opts.sender_cell)

# Sender produces a message
sender_output = sender_gs(None, batch[3])
# print("Sender's message:", sender_output)
print("Sender's shape:", sender_output.shape) # batch size x max_len+1 x vocab size

# Receiver tries to identify the target node
receiver_output = receiver_gs(sender_output, None, batch[3])
# print("Receiver's output:", receiver_output)
print("Receiver's output shape:", receiver_output.shape)

Sender's shape: torch.Size([50, 5, 100])
Receiver's output shape: torch.Size([50, 5, 4])


In [34]:
def loss_nll(
    _sender_input, _message, _receiver_input, receiver_output, labels, _aux_input):
    """
    NLL loss - differentiable and can be used with both GS and Reinforce
    """
    # print(receiver_output)
    # print(labels)
    nll = F.nll_loss(receiver_output, labels, reduction="none")
    acc = (labels == receiver_output.argmax(dim=1)).float().mean()
    return nll, {"acc": acc}

In [37]:
print(batch[0].shape)

torch.Size([1, 400])


In [35]:
game = core.SenderReceiverRnnGS(sender_gs, receiver_gs, loss_nll)

core.init(params=['--random_seed=7', 
                '--lr=1e-2',   
                f'--batch_size={opts.batch_size}',
                '--optimizer=adam'])

optimizer = torch.optim.Adam(game.parameters())

topographic_similarity = core.TopographicSimilarity(
    sender_input_distance_fn="cosine", 
    message_distance_fn="euclidean", 
    compute_topsim_train_set=True, 
    compute_topsim_test_set=True, 
    is_gumbel=True
)

trainer = core.Trainer(
    game=game, 
    optimizer=optimizer, 
    train_data=train_loader,
    validation_data=val_loader, 
    callbacks=[core.ConsoleLogger(as_json=True, print_train_loss=True), topographic_similarity]
)

trainer.train(n_epochs=50)

{"loss": 1.3164544105529785, "acc": 0.41999974846839905, "length": 4.999938488006592, "mode": "train", "epoch": 1}


ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 3160 and the array at index 1 has size 7998000