# Version 2 (sparse matrices...)
### I created the last notebook naively and completely missed that the discusses method operates on sparse matrices, so I'm starting again from scratch

# Generating a sparse matrix representation of a graph...

In [46]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_nodes = 100
average_out_degree = 9
total_edge_classes = 36
embedding_size = 128
output_dim = 4



# First we generate some arbitrary graph
def generate_graph(num_nodes, average_out_degree, total_edge_classes):
    num_edges = average_out_degree*num_nodes
    indices = torch.randint(0, num_nodes, (2, num_edges))
    print(indices.size())
    edges = torch.ones(num_edges)
    print(edges.size())
    A = torch.sparse_coo_tensor(indices, edges, (num_nodes, num_nodes))
    return A
# A = generate_graph(15,3,5).to_dense()

# Now maybe it's useful to ensure it's connected...
def generate_connected_graph(num_nodes, average_out_degree, total_edge_classes, embedding_size):
    edges_set = set()
    num_edges = average_out_degree * num_nodes

    def sample_random_edge_class():
        return torch.randint(0, total_edge_classes, (1,)).item()
    
    def generate_random_edge():
        u, v = torch.randint(0, num_nodes, (2,)).tolist()
        #no self loops
        while u == v:
            u, v = torch.randint(0, num_nodes, (2,)).tolist()
        return (u, v)
    
    def generate_random_node():
        return torch.randint(0, num_nodes, (1,)).item()
    
    def sample_random_node(A):
        connected_nodes = A._indices()[1].tolist() #select from incoming to...
        if not connected_nodes:
            return generate_random_node()
        return connected_nodes[torch.randint(0, len(connected_nodes), (1,)).item()]#torch.nonzero(A)
    
    initial_edge = generate_random_edge()
    edges_set.add(initial_edge)
    indices = torch.tensor([[initial_edge[0]], [initial_edge[1]]], dtype=torch.long, device=device)
    edges = torch.ones(1, device=device)
    A = torch.sparse_coo_tensor(indices, edges, (num_nodes, num_nodes), device=device)
    A_list = []
    for edge_class in range(total_edge_classes):
        for _ in range(num_edges - 1):
            u = sample_random_node(A)
            v = generate_random_node()
            while u == v:
                v = generate_random_node()
            new_edge = (u, v)#torch.cat([u, v], dim=0).unsqueeze(1)
            if new_edge in edges_set:
                continue
            edges_set.add(new_edge)
            edge_tensor = torch.tensor([[u], [v]], dtype=torch.long, device=device)
            indices = torch.cat([indices, edge_tensor], dim=1)
            edges = torch.ones(indices.size(1), device=device)
            A = torch.sparse_coo_tensor(indices, edges, (num_nodes, num_nodes), device=device)
            # indices = torch.stack((indices,torch.cat([u,v]).unsqueeze(1)),dim=0)
        # print(edges_set)
        A_list.append(A)
    random_labels = torch.rand((num_nodes, embedding_size), device=device)
    return A_list, random_labels 

A, X = generate_connected_graph(num_nodes,average_out_degree,total_edge_classes, embedding_size)

# Now per edge type? or randomly assign edges...? maybe multinomial_sample(1/k)^k?

In [24]:
X

tensor([[0.6704, 0.5895, 0.8493,  ..., 0.2298, 0.1610, 0.7519],
        [0.9928, 0.9682, 0.4486,  ..., 0.8425, 0.8177, 0.9624],
        [0.1649, 0.8949, 0.3657,  ..., 0.5884, 0.1401, 0.0713],
        ...,
        [0.4362, 0.5824, 0.0432,  ..., 0.9224, 0.6386, 0.7155],
        [0.6440, 0.7035, 0.2164,  ..., 0.4567, 0.7283, 0.8311],
        [0.2540, 0.7553, 0.0965,  ..., 0.8924, 0.4883, 0.6456]])

# GCN

In [25]:
import torch.nn as nn
import torch.functional as f

k_hop = 3

def normalize_adjacency(A, self=True): # Self-loop doesn't work with R-GCN
    size = A.size()[0]
    A = torch.add(torch.eye(size,device=device).to_sparse(), A)
    degree = torch.sparse.sum(A, dim=1).to_dense()
    # print(degree.size())
    d_inv_sqrt = degree.pow(-0.5)
    D_inv_sqrt = torch.diag(d_inv_sqrt)
    A = A.to_dense()
    normalized_A = D_inv_sqrt @ A @ D_inv_sqrt
    return normalized_A
A_prime = normalize_adjacency(A)


class GCNLayer(nn.Module):
    def __init__(self, x_dim, y_dim):
        super(GCNLayer, self).__init__()
        self.x_dim = x_dim
        self.y_dim = y_dim
        # self.A_norm = A_norm

        self.lin = nn.Linear(x_dim, y_dim)

    def forward(self, A_norm, X):
        device = next(self.parameters()).device
        # print(A_norm, X)
        transformed = self.lin(X)
        aggregated = torch.matmul(A_norm, transformed)
        return aggregated


class GCN(nn.Module):
    def __init__(self, x_dim, h_dim, y_dim, max_k_hop):
        super(GCN, self).__init__()
        self.max_k_hop =max_k_hop
        self.x_dim = x_dim
        self.gcns = nn.ModuleList([GCNLayer(x_dim, h_dim) for _ in range(max_k_hop)])
        self.final_gcn = GCNLayer(x_dim, y_dim) 
        self.act = nn.ReLU()
        self.drop = nn.Dropout(p=0.2)
        # self.norm = nn.LayerNorm(y_dim)
    
    def forward(self, A, X):
        # outer layers
        device = next(self.parameters()).device
        for i in range(self.max_k_hop):
            H = self.gcns[i](A, X)
            H = self.act(H)
            H = self.drop(H)
        Y = self.final_gcn(A, H)
        out = self.act(Y)
        return out#nn.LogSoftmax(Y)
model = GCN(X.size(1), embedding_size, output_dim, k_hop).to(device)
out = model(A_prime, X)


AttributeError: 'list' object has no attribute 'size'

# R-GCN

### Block diagonal weight matrix (one is held in memory for each relational weight per layer, so block diagonal sparse matrices save some memory at the cost of some layer-level information flow)

In [None]:
block_size = embedding_size//4 #some partition... ensure it's a round number
relation_weights = [torch.randn(block_size, block_size), torch.randn(block_size, block_size)]

block_diag_matrix = torch.block_diag(*relation_weights).to_sparse()
block_diag_matrix.to_dense()

In [26]:
print( embedding_size//4)

32


# R-GCN model:

In [47]:
def normalize_adjacency(A, self=True): # Self-loop doesn't work with R-GCN
    size = A.size()[0]
    A = torch.add(torch.eye(size,device=device).to_sparse(), A)
    degree = torch.sparse.sum(A, dim=1).to_dense()
    # print(degree.size())
    d_inv_sqrt = degree.pow(-0.5)
    D_inv_sqrt = torch.diag(d_inv_sqrt)
    A = A.to_dense()
    normalized_A = D_inv_sqrt @ A @ D_inv_sqrt
    return normalized_A
# A_prime = normalize_adjacency(A)

A_list = [normalize_adjacency(a, False) for a in A]


##save non-error version
def block_diag (weights):
    block_diag_matrix = torch.block_diag(*weights).to_sparse()
    return block_diag_matrix
    

class R_GCNLayer(nn.Module):
    def __init__(self, x_dim, y_dim, num_relations, block_split):
        super(R_GCNLayer, self).__init__()
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.num_relations = num_relations
        gain = nn.init.calculate_gain('relu')
        max_block_size = 1
        for i in range(1,11):
            if x_dim % i == 0 and y_dim % i == 0:
                max_block_size +=1 
        x_block_size = x_dim // block_split
        y_block_size = y_dim // block_split
        self.W = nn.ParameterList()
        for _ in range(num_relations):
            wr = nn.ParameterList()
            # relation_weights = []
            for i in range(block_split):
                w = nn.Parameter(torch.randn(x_block_size, y_block_size)) 
                nn.init.kaiming_uniform_(w, a=gain)
                # relation_weights.append(w)
             #move blocks to params, not block diag
            # print(block_diag_matrix)
                wr.append(w)
                # self.Wr.append(wr)
            self.W.append(wr)

        self.bias = nn.Parameter(torch.zeros(y_dim))
        nn.init.zeros_(self.bias)

#dense version change forward loop such that weigted is matmul, etc of blockdiag(w)
# def block_diag (weights):
#     block_diag_matrix = torch.block_diag(*weights)
#     return block_diag_matrix
    

# class R_GCNLayer(nn.Module):
#     def __init__(self, x_dim, y_dim, num_relations, block_split):
#         super(R_GCNLayer, self).__init__()
#         self.x_dim = x_dim
#         self.y_dim = y_dim
#         self.num_relations = num_relations
#         gain = nn.init.calculate_gain('relu')
#         max_block_size = 1
#         for i in range(1,11):
#             if x_dim % i == 0 and y_dim % i == 0:
#                 max_block_size +=1 
#         x_block_size = x_dim // block_split
#         y_block_size = y_dim // block_split
#         self.W = nn.ParameterList()
#         for _ in range(num_relations):
#             # wr = nn.ParameterList()
#             relation_weights = []
#             for i in range(block_split):
#                 w = torch.randn(x_block_size, y_block_size)
#                 relation_weights.append(w)
#                 nn.init.kaiming_uniform_(block_diag_dense, a=gain)
#             block_diag_dense = block_diag(relation_weights)
#             # print(block_diag_dense)
#              #move blocks to params, not block diag
#             # print(block_diag_matrix)
#                 # wr.append(w)
#             self.W.append(block_diag_dense)
#             # self.W.append(wr)

#         self.bias = nn.Parameter(torch.zeros(y_dim))
#         nn.init.zeros_(self.bias)
        

    

    def forward(self, A, X):
        device = next(self.parameters()).device
        aggregated = torch.zeros((X.size(0), self.y_dim), device=device)
        # print(self.W[1], self.W[1].size())
        for r in range(self.num_relations):
            # print(block_diag(self.W[r]).to_dense())
            # print(X.size(), self.Wr[r])
            weighted = torch.matmul(X, block_diag(self.W[r]))  # (num_nodes, out_dim)
            # print(weighted.size())
            transformed = torch.sparse.mm(A[r], weighted)
            # aggregated_r = torch.matmul(A_norm, transformed)
            # print(aggregated.size(), transformed.size())
            aggregated += transformed
        aggregated += self.bias
        return aggregated


class R_GCN(nn.Module):
    def __init__(self, x_dim, h_dim, y_dim, max_k_hop, num_relations):
        super(R_GCN, self).__init__()
        self.max_k_hop =max_k_hop
        self.num_relations = num_relations
        self.x_dim = x_dim
        self.block_split = 2
        self.gcns = nn.ModuleList([R_GCNLayer(x_dim, h_dim, num_relations, self.block_split) for _ in range(max_k_hop)])
        self.final_r_gcn = R_GCNLayer(x_dim, y_dim,num_relations, self.block_split) 
        self.act = nn.ReLU()
        self.drop = nn.Dropout(p=0.2)
        # self.norm = nn.LayerNorm(y_dim)
        self.norm = nn.BatchNorm1d(y_dim)


    
        
    def forward(self, A, X):
        # outer layers
        device = next(self.parameters()).device
        for i in range(self.max_k_hop):
            H = self.gcns[i](A, X)
            H = self.act(H)
            H = self.drop(H)
        Y = self.final_r_gcn(A, H)
        out = self.act(Y)
        out = self.norm(out)
        return out#nn.LogSoftmax(Y)
model = R_GCN(X.size(1), embedding_size, output_dim, k_hop, total_edge_classes).to(device)
out = model(A_list, X)
out

tensor([[-0.2029,  0.4855,  0.0000,  0.4661],
        [-0.2029,  0.5779,  0.0000,  0.8182],
        [-0.2029, -1.4555,  0.0000,  0.7315],
        [-0.2029,  0.3165,  0.0000,  0.4782],
        [-0.2029, -0.2421,  0.0000,  0.1120],
        [-0.2029,  0.5388,  0.0000, -0.4432],
        [-0.2029, -1.8544,  0.0000, -1.3735],
        [-0.2029,  1.0408,  0.0000,  0.8749],
        [-0.2029, -2.0593,  0.0000, -0.5386],
        [-0.2029,  1.2324,  0.0000, -0.8722],
        [-0.2029, -1.1084,  0.0000,  0.4878],
        [-0.2029,  0.3290,  0.0000,  1.5753],
        [-0.2029, -0.4637,  0.0000, -0.2371],
        [-0.2029, -0.5747,  0.0000,  0.3715],
        [-0.2029, -0.9568,  0.0000, -0.2955],
        [-0.2029,  1.8275,  0.0000, -0.7516],
        [-0.2029,  1.2218,  0.0000,  0.3498],
        [-0.2029, -0.2814,  0.0000,  1.6360],
        [-0.2029, -0.4658,  0.0000,  0.3716],
        [-0.2029, -0.3599,  0.0000, -0.1129],
        [ 7.3236,  0.2993,  0.0000, -1.1330],
        [ 0.7067, -1.5002,  0.0000

In [31]:
output_dim

3

# Train

In [48]:
from torchviz import make_dot
from tqdm import tqdm
criterion = nn.CrossEntropyLoss()#nn.NLLLoss()
params = list(model.parameters())
# params.extend(list(classifier.parameters()))
optimizer = torch.optim.Adam(params,lr=0.01)

num_epochs = 600

true_labels = torch.randint(0, output_dim, (num_nodes,), dtype=torch.long).to(device)

# print(true_labels)

def save_gradient_hook(grad):
    gradients.append(grad)


for epoch in tqdm(range(num_epochs)):#, desc = f"epoch {epoch}/{num_epochs}"):
    model.train()
    epoch_loss = 0.0
    correct = 0
    total = 0
    i = 0

    predicted_labels = model(A_list, X)
    
    loss = criterion(predicted_labels, true_labels)

    # for param in params:
    #     param.register_hook(save_gradient_hook)
    
    # if i % batch_size == 0:
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    i += 1
    epoch_loss += loss.item()

    _, predicted = torch.max(predicted_labels.data, 1)
    total += true_labels.size(0)
    correct += (predicted == true_labels).sum().item()
    avg_loss = epoch_loss
    accuracy = 100 * correct / total
    if epoch % 5 == 0:
        print(f"epoch: {epoch}\navg loss: {avg_loss} accuracy: {accuracy}")
        # make_dot(predicted_labels, params=dict(list(model.named_parameters()))).render("r_gcn_torchviz", format="png")


           

  0%|▏                                          | 2/600 [00:00<02:14,  4.44it/s]

epoch: 0
avg loss: 1.524507999420166 accuracy: 22.0


  1%|▌                                          | 7/600 [00:01<01:32,  6.43it/s]

epoch: 5
avg loss: 1.4015671014785767 accuracy: 34.0


  2%|▊                                         | 12/600 [00:02<02:22,  4.13it/s]

epoch: 10
avg loss: 1.3542121648788452 accuracy: 43.0


  3%|█▏                                        | 17/600 [00:03<01:31,  6.35it/s]

epoch: 15
avg loss: 1.3514102697372437 accuracy: 33.0


  4%|█▍                                        | 21/600 [00:04<02:40,  3.61it/s]

epoch: 20
avg loss: 1.3495310544967651 accuracy: 40.0


  4%|█▉                                        | 27/600 [00:05<01:43,  5.55it/s]

epoch: 25
avg loss: 1.3472623825073242 accuracy: 41.0


  5%|██▏                                       | 32/600 [00:06<01:25,  6.62it/s]

epoch: 30
avg loss: 1.3439117670059204 accuracy: 35.0


  6%|██▌                                       | 37/600 [00:07<01:21,  6.90it/s]

epoch: 35
avg loss: 1.309199333190918 accuracy: 36.0


  7%|██▉                                       | 42/600 [00:07<01:22,  6.78it/s]

epoch: 40
avg loss: 1.3285068273544312 accuracy: 34.0


  8%|███▎                                      | 47/600 [00:08<01:19,  6.92it/s]

epoch: 45
avg loss: 1.2921274900436401 accuracy: 38.0


  9%|███▋                                      | 52/600 [00:09<01:18,  7.02it/s]

epoch: 50
avg loss: 1.3122432231903076 accuracy: 39.0


 10%|███▉                                      | 57/600 [00:10<01:21,  6.66it/s]

epoch: 55
avg loss: 1.2917721271514893 accuracy: 41.0


 10%|████▎                                     | 62/600 [00:10<01:17,  6.93it/s]

epoch: 60
avg loss: 1.3459113836288452 accuracy: 37.0


 11%|████▋                                     | 67/600 [00:11<01:16,  6.96it/s]

epoch: 65
avg loss: 1.2794016599655151 accuracy: 47.0


 12%|█████                                     | 72/600 [00:12<01:15,  6.98it/s]

epoch: 70
avg loss: 1.2659642696380615 accuracy: 41.0


 13%|█████▍                                    | 77/600 [00:13<01:14,  6.98it/s]

epoch: 75
avg loss: 1.281798243522644 accuracy: 38.0


 14%|█████▋                                    | 82/600 [00:13<01:13,  7.04it/s]

epoch: 80
avg loss: 1.2759076356887817 accuracy: 38.0


 14%|██████                                    | 86/600 [00:15<02:26,  3.51it/s]

epoch: 85
avg loss: 1.279822826385498 accuracy: 38.0


 15%|██████▍                                   | 92/600 [00:16<01:25,  5.94it/s]

epoch: 90
avg loss: 1.2701709270477295 accuracy: 41.0


 16%|██████▋                                   | 96/600 [00:17<02:03,  4.09it/s]

epoch: 95
avg loss: 1.2736297845840454 accuracy: 43.0


 17%|██████▉                                  | 101/600 [00:18<02:36,  3.19it/s]

epoch: 100
avg loss: 1.2548794746398926 accuracy: 44.0


 18%|███████▏                                 | 106/600 [00:20<02:38,  3.11it/s]

epoch: 105
avg loss: 1.2558616399765015 accuracy: 41.0


 18%|███████▌                                 | 111/600 [00:21<02:38,  3.08it/s]

epoch: 110
avg loss: 1.2488245964050293 accuracy: 40.0


 19%|███████▉                                 | 116/600 [00:23<02:19,  3.47it/s]

epoch: 115
avg loss: 1.2044368982315063 accuracy: 42.0


 20%|████████▎                                | 122/600 [00:24<01:20,  5.92it/s]

epoch: 120
avg loss: 1.1971265077590942 accuracy: 50.0


 21%|████████▋                                | 127/600 [00:25<01:09,  6.85it/s]

epoch: 125
avg loss: 1.23415207862854 accuracy: 46.0


 22%|████████▉                                | 131/600 [00:25<01:33,  5.04it/s]

epoch: 130
avg loss: 1.1870958805084229 accuracy: 41.0


 23%|█████████▎                               | 137/600 [00:27<01:46,  4.34it/s]

epoch: 135
avg loss: 1.1799806356430054 accuracy: 46.0


 24%|█████████▋                               | 141/600 [00:28<01:15,  6.07it/s]

epoch: 140
avg loss: 1.1739649772644043 accuracy: 49.0


 24%|██████████                               | 147/600 [00:29<01:24,  5.33it/s]

epoch: 145
avg loss: 1.1880673170089722 accuracy: 45.0


 25%|██████████▍                              | 152/600 [00:30<01:06,  6.70it/s]

epoch: 150
avg loss: 1.1831810474395752 accuracy: 53.0


 26%|██████████▋                              | 157/600 [00:31<01:29,  4.94it/s]

epoch: 155
avg loss: 1.1270793676376343 accuracy: 50.0


 27%|███████████                              | 162/600 [00:32<01:05,  6.64it/s]

epoch: 160
avg loss: 1.1462008953094482 accuracy: 48.0


 28%|███████████▎                             | 166/600 [00:32<01:02,  6.95it/s]

epoch: 165
avg loss: 1.1707470417022705 accuracy: 45.0


 28%|███████████▋                             | 171/600 [00:34<02:11,  3.27it/s]

epoch: 170
avg loss: 1.077985405921936 accuracy: 48.0


 30%|████████████                             | 177/600 [00:36<01:36,  4.36it/s]

epoch: 175
avg loss: 1.1419662237167358 accuracy: 50.0


 30%|████████████▍                            | 182/600 [00:36<01:04,  6.47it/s]

epoch: 180
avg loss: 1.144460678100586 accuracy: 51.0


 31%|████████████▊                            | 187/600 [00:37<00:58,  7.03it/s]

epoch: 185
avg loss: 1.1688612699508667 accuracy: 44.0


 32%|█████████████                            | 192/600 [00:38<00:56,  7.17it/s]

epoch: 190
avg loss: 1.103809118270874 accuracy: 50.0


 33%|█████████████▍                           | 196/600 [00:38<01:23,  4.86it/s]

epoch: 195
avg loss: 1.115091323852539 accuracy: 49.0


 34%|█████████████▊                           | 202/600 [00:40<01:14,  5.32it/s]

epoch: 200
avg loss: 1.0746572017669678 accuracy: 57.0


 34%|██████████████▏                          | 207/600 [00:41<00:58,  6.70it/s]

epoch: 205
avg loss: 1.1027384996414185 accuracy: 48.0


 35%|██████████████▍                          | 212/600 [00:41<00:54,  7.06it/s]

epoch: 210
avg loss: 1.0325841903686523 accuracy: 59.0


 36%|██████████████▊                          | 217/600 [00:42<00:53,  7.16it/s]

epoch: 215
avg loss: 1.093115210533142 accuracy: 52.0


 37%|███████████████                          | 221/600 [00:43<00:52,  7.16it/s]

epoch: 220
avg loss: 1.1790440082550049 accuracy: 43.0


 38%|███████████████▍                         | 226/600 [00:44<01:51,  3.36it/s]

epoch: 225
avg loss: 1.1730895042419434 accuracy: 49.0


 39%|███████████████▊                         | 232/600 [00:46<01:39,  3.69it/s]

epoch: 230
avg loss: 1.1265952587127686 accuracy: 52.0


 40%|████████████████▏                        | 237/600 [00:47<00:59,  6.12it/s]

epoch: 235
avg loss: 1.1131348609924316 accuracy: 53.0


 40%|████████████████▌                        | 242/600 [00:47<00:51,  6.93it/s]

epoch: 240
avg loss: 1.1782243251800537 accuracy: 48.0


 41%|████████████████▊                        | 246/600 [00:48<01:10,  4.99it/s]

epoch: 245
avg loss: 1.0937623977661133 accuracy: 52.0


 42%|█████████████████▏                       | 251/600 [00:50<01:47,  3.25it/s]

epoch: 250
avg loss: 1.0877245664596558 accuracy: 50.0


 43%|█████████████████▍                       | 256/600 [00:51<01:50,  3.12it/s]

epoch: 255
avg loss: 1.164635419845581 accuracy: 51.0


 44%|█████████████████▊                       | 261/600 [00:53<01:50,  3.08it/s]

epoch: 260
avg loss: 1.092973232269287 accuracy: 53.0


 44%|██████████████████▏                      | 266/600 [00:55<01:48,  3.08it/s]

epoch: 265
avg loss: 1.1347428560256958 accuracy: 51.0


 45%|██████████████████▌                      | 271/600 [00:56<01:43,  3.17it/s]

epoch: 270
avg loss: 1.0954760313034058 accuracy: 52.0


 46%|██████████████████▊                      | 276/600 [00:57<01:09,  4.65it/s]

epoch: 275
avg loss: 1.051997423171997 accuracy: 54.0


 47%|███████████████████▎                     | 282/600 [00:58<00:55,  5.75it/s]

epoch: 280
avg loss: 1.054595708847046 accuracy: 52.0


 48%|███████████████████▌                     | 287/600 [00:59<00:45,  6.83it/s]

epoch: 285
avg loss: 1.0293818712234497 accuracy: 52.0


 49%|███████████████████▉                     | 292/600 [01:00<00:43,  7.04it/s]

epoch: 290
avg loss: 1.010118842124939 accuracy: 56.0


 50%|████████████████████▎                    | 297/600 [01:01<00:42,  7.07it/s]

epoch: 295
avg loss: 1.054264783859253 accuracy: 54.0


 50%|████████████████████▌                    | 301/600 [01:01<00:41,  7.16it/s]

epoch: 300
avg loss: 1.005678653717041 accuracy: 58.0


 51%|████████████████████▉                    | 307/600 [01:02<00:53,  5.45it/s]

epoch: 305
avg loss: 1.0025187730789185 accuracy: 61.0


 52%|█████████████████████▎                   | 312/600 [01:03<00:42,  6.78it/s]

epoch: 310
avg loss: 1.0324794054031372 accuracy: 55.0


 53%|█████████████████████▋                   | 317/600 [01:04<00:39,  7.19it/s]

epoch: 315
avg loss: 1.027179479598999 accuracy: 51.0


 54%|██████████████████████                   | 322/600 [01:05<00:40,  6.95it/s]

epoch: 320
avg loss: 0.9974989295005798 accuracy: 61.0


 55%|██████████████████████▎                  | 327/600 [01:05<00:38,  7.03it/s]

epoch: 325
avg loss: 0.9491453766822815 accuracy: 58.0


 55%|██████████████████████▋                  | 332/600 [01:06<00:43,  6.15it/s]

epoch: 330
avg loss: 1.0478346347808838 accuracy: 52.0


 56%|██████████████████████▉                  | 336/600 [01:07<00:53,  4.92it/s]

epoch: 335
avg loss: 0.9456179141998291 accuracy: 56.0


 57%|███████████████████████▎                 | 342/600 [01:08<00:57,  4.49it/s]

epoch: 340
avg loss: 0.8786041736602783 accuracy: 67.0


 58%|███████████████████████▋                 | 347/600 [01:10<00:54,  4.66it/s]

epoch: 345
avg loss: 0.9452252984046936 accuracy: 63.0


 58%|███████████████████████▉                 | 351/600 [01:11<01:15,  3.31it/s]

epoch: 350
avg loss: 0.8861430287361145 accuracy: 65.0


 59%|████████████████████████▎                | 356/600 [01:13<01:18,  3.11it/s]

epoch: 355
avg loss: 0.9281129240989685 accuracy: 60.0


 60%|████████████████████████▋                | 361/600 [01:14<01:17,  3.07it/s]

epoch: 360
avg loss: 0.8788821697235107 accuracy: 61.0


 61%|█████████████████████████                | 366/600 [01:16<01:09,  3.36it/s]

epoch: 365
avg loss: 0.8767220973968506 accuracy: 65.0


 62%|█████████████████████████▎               | 371/600 [01:17<01:13,  3.10it/s]

epoch: 370
avg loss: 0.8571144938468933 accuracy: 63.0


 63%|█████████████████████████▋               | 376/600 [01:19<01:12,  3.09it/s]

epoch: 375
avg loss: 0.8612250685691833 accuracy: 62.0


 64%|██████████████████████████               | 382/600 [01:21<00:59,  3.66it/s]

epoch: 380
avg loss: 0.9050018787384033 accuracy: 62.0


 64%|██████████████████████████▍              | 387/600 [01:21<00:34,  6.18it/s]

epoch: 385
avg loss: 0.8824501633644104 accuracy: 62.0


 65%|██████████████████████████▊              | 392/600 [01:22<00:29,  6.97it/s]

epoch: 390
avg loss: 0.8231688141822815 accuracy: 60.0


 66%|███████████████████████████              | 396/600 [01:23<00:51,  3.94it/s]

epoch: 395
avg loss: 0.9229439496994019 accuracy: 57.0


 67%|███████████████████████████▍             | 401/600 [01:25<01:01,  3.23it/s]

epoch: 400
avg loss: 0.8714409470558167 accuracy: 64.0


 68%|███████████████████████████▋             | 406/600 [01:26<01:04,  3.03it/s]

epoch: 405
avg loss: 0.8388002514839172 accuracy: 67.0


 68%|████████████████████████████             | 411/600 [01:28<01:01,  3.09it/s]

epoch: 410
avg loss: 0.858518660068512 accuracy: 62.0


 69%|████████████████████████████▍            | 416/600 [01:30<00:59,  3.08it/s]

epoch: 415
avg loss: 0.8606392741203308 accuracy: 65.0


 70%|████████████████████████████▊            | 422/600 [01:32<00:49,  3.62it/s]

epoch: 420
avg loss: 0.9074292182922363 accuracy: 66.0


 71%|█████████████████████████████▏           | 427/600 [01:32<00:28,  6.15it/s]

epoch: 425
avg loss: 0.7904000878334045 accuracy: 65.0


 72%|█████████████████████████████▌           | 432/600 [01:33<00:24,  6.99it/s]

epoch: 430
avg loss: 0.9044851660728455 accuracy: 65.0


 73%|█████████████████████████████▊           | 437/600 [01:34<00:23,  7.08it/s]

epoch: 435
avg loss: 0.8339964151382446 accuracy: 66.0


 74%|██████████████████████████████▏          | 442/600 [01:34<00:22,  7.15it/s]

epoch: 440
avg loss: 0.7939798831939697 accuracy: 63.0


 74%|██████████████████████████████▍          | 446/600 [01:35<00:21,  7.09it/s]

epoch: 445
avg loss: 0.8780575394630432 accuracy: 55.0


 75%|██████████████████████████████▊          | 451/600 [01:37<00:44,  3.37it/s]

epoch: 450
avg loss: 0.9870428442955017 accuracy: 54.0


 76%|███████████████████████████████▏         | 456/600 [01:38<00:46,  3.12it/s]

epoch: 455
avg loss: 0.9165701270103455 accuracy: 52.0


 77%|███████████████████████████████▌         | 461/600 [01:40<00:44,  3.14it/s]

epoch: 460
avg loss: 0.8253042697906494 accuracy: 59.0


 78%|███████████████████████████████▊         | 466/600 [01:41<00:42,  3.13it/s]

epoch: 465
avg loss: 0.8652363419532776 accuracy: 58.0


 78%|████████████████████████████████▏        | 471/600 [01:43<00:41,  3.09it/s]

epoch: 470
avg loss: 0.8790467977523804 accuracy: 56.0


 80%|████████████████████████████████▌        | 477/600 [01:44<00:21,  5.83it/s]

epoch: 475
avg loss: 0.906903088092804 accuracy: 56.0


 80%|████████████████████████████████▊        | 481/600 [01:45<00:28,  4.14it/s]

epoch: 480
avg loss: 0.8562385439872742 accuracy: 61.0


 81%|█████████████████████████████████▏       | 486/600 [01:47<00:34,  3.33it/s]

epoch: 485
avg loss: 0.8423547148704529 accuracy: 59.0


 82%|█████████████████████████████████▌       | 491/600 [01:48<00:35,  3.10it/s]

epoch: 490
avg loss: 0.8060154914855957 accuracy: 64.0


 83%|█████████████████████████████████▉       | 496/600 [01:50<00:33,  3.08it/s]

epoch: 495
avg loss: 0.8087769150733948 accuracy: 62.0


 84%|██████████████████████████████████▏      | 501/600 [01:51<00:32,  3.05it/s]

epoch: 500
avg loss: 0.8371729850769043 accuracy: 60.0


 84%|██████████████████████████████████▌      | 506/600 [01:53<00:30,  3.09it/s]

epoch: 505
avg loss: 0.8683082461357117 accuracy: 60.0


 85%|██████████████████████████████████▉      | 512/600 [01:54<00:19,  4.54it/s]

epoch: 510
avg loss: 0.8993200659751892 accuracy: 58.0


 86%|███████████████████████████████████▎     | 517/600 [01:55<00:12,  6.57it/s]

epoch: 515
avg loss: 0.9263912439346313 accuracy: 60.0


 87%|███████████████████████████████████▋     | 522/600 [01:56<00:10,  7.16it/s]

epoch: 520
avg loss: 0.8035109043121338 accuracy: 62.0


 88%|████████████████████████████████████     | 527/600 [01:57<00:10,  7.27it/s]

epoch: 525
avg loss: 0.7605675458908081 accuracy: 64.0


 89%|████████████████████████████████████▎    | 532/600 [01:57<00:09,  7.28it/s]

epoch: 530
avg loss: 0.8301758766174316 accuracy: 61.0


 90%|████████████████████████████████████▋    | 537/600 [01:58<00:09,  6.69it/s]

epoch: 535
avg loss: 0.8696222901344299 accuracy: 59.0


 90%|█████████████████████████████████████    | 542/600 [01:59<00:08,  7.03it/s]

epoch: 540
avg loss: 0.8617804050445557 accuracy: 57.0


 91%|█████████████████████████████████████▍   | 547/600 [01:59<00:07,  7.11it/s]

epoch: 545
avg loss: 0.8542235493659973 accuracy: 55.0


 92%|█████████████████████████████████████▋   | 552/600 [02:00<00:06,  7.22it/s]

epoch: 550
avg loss: 0.8299148678779602 accuracy: 55.0


 93%|██████████████████████████████████████   | 557/600 [02:01<00:05,  7.17it/s]

epoch: 555
avg loss: 0.8399718999862671 accuracy: 62.0


 94%|██████████████████████████████████████▍  | 562/600 [02:02<00:05,  7.24it/s]

epoch: 560
avg loss: 0.763938844203949 accuracy: 64.0


 94%|██████████████████████████████████████▋  | 567/600 [02:02<00:04,  7.14it/s]

epoch: 565
avg loss: 0.7647873759269714 accuracy: 64.0


 95%|███████████████████████████████████████  | 571/600 [02:03<00:04,  7.03it/s]

epoch: 570
avg loss: 0.7708765268325806 accuracy: 64.0


 96%|███████████████████████████████████████▍ | 577/600 [02:04<00:03,  6.92it/s]

epoch: 575
avg loss: 0.7578765153884888 accuracy: 65.0


 97%|███████████████████████████████████████▊ | 582/600 [02:04<00:02,  6.48it/s]

epoch: 580
avg loss: 0.7962588667869568 accuracy: 62.0


 98%|████████████████████████████████████████ | 587/600 [02:05<00:01,  6.61it/s]

epoch: 585
avg loss: 0.7253901958465576 accuracy: 67.0


 99%|████████████████████████████████████████▍| 592/600 [02:06<00:01,  7.01it/s]

epoch: 590
avg loss: 0.7674355506896973 accuracy: 61.0


 99%|████████████████████████████████████████▋| 596/600 [02:07<00:00,  6.92it/s]

epoch: 595
avg loss: 0.8034726977348328 accuracy: 62.0


100%|█████████████████████████████████████████| 600/600 [02:07<00:00,  4.69it/s]


# I'm confused about the rdf stuff. I don't see any multimodal data.

In [53]:
true_labels

tensor([3, 3, 0, 3, 2, 3, 2, 0, 1, 1, 2, 3, 1, 1, 2, 0, 1, 1, 3, 0, 2, 2, 1, 3,
        1, 1, 3, 2, 0, 3, 1, 1, 3, 0, 2, 2, 0, 2, 2, 1, 2, 2, 3, 1, 3, 1, 3, 2,
        1, 2, 1, 2, 2, 3, 2, 1, 0, 1, 3, 0, 1, 3, 1, 2, 3, 0, 2, 2, 0, 1, 1, 0,
        3, 2, 1, 2, 0, 0, 0, 1, 2, 0, 0, 2, 3, 1, 1, 2, 3, 1, 3, 3, 3, 3, 3, 0,
        0, 1, 1, 0])

In [56]:
predicted

tensor([3, 3, 0, 3, 0, 3, 1, 0, 1, 1, 1, 3, 1, 3, 1, 0, 1, 3, 3, 1, 1, 1, 1, 3,
        1, 1, 3, 1, 0, 3, 3, 1, 3, 1, 1, 1, 0, 0, 1, 1, 1, 1, 3, 1, 1, 1, 3, 3,
        1, 1, 0, 1, 1, 3, 1, 1, 0, 1, 3, 1, 1, 3, 0, 1, 3, 1, 3, 0, 1, 1, 0, 0,
        3, 3, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 3, 1, 0, 0, 3, 0, 3, 3, 1, 3, 3, 0,
        0, 1, 1, 0])

# Create graph from n-triple file:

In [1]:
import rdflib
from rdflib import Graph
import logging
import os
from IPython.display import clear_output


data_loc = './Downloads/ml4g/dmg/mmkg/dmg/scripts/dmg777k_stripped.nt'
folder = './Downloads/ml4g'

def create_new_graph(path, batch_size = 1000, test=True):
    logging.basicConfig(
    filename='rdf_parsing_errors.log',
    filemode='w')
    
    graph = Graph()
    # graph = graph.parse(, format='nt')
    batch_num = 0
    i = 0
    with open(path, 'r', encoding='utf-8') as f:
        while True:
            batch = []
            try:
                [batch.append(next(f)) for j in range(batch_size)]
                i += j
            except:
                pass
            if not batch:
                break
            batch_num += 1
            nt_string = ''.join(batch)
            try:
                graph.parse(data=nt_string, format='nt')
                if test:
                    graph = Graph()
            except ParseError as e:
                logging.error(f"in batch: {batch_num}:\npproblematic data:\n\n{batch}\n\n")
                check(batch, batch_num, test=test)

            if batch_num == 5:
                # print(batch)
                pass
            if batch_num % 10 == 0:
                # clear_output()
                
                pass
                # print(f"{batch_num}/:o?")

    return graph

def check(batch, batch_num, test = True):
    graph = Graph()
    for i, line in enumerate(batch):
        try:
            graph.parse(line)
        except Exception as e:
            logging.error(f'in line: {i}:\n{line}\n{e}')
graph=create_new_graph(data_loc, test=False)
# graph = Graph()
# graph.parse(data_loc)

In [155]:
from datetime import datetime
node_set = set()
edge_set = set()
string_set = set()
image_set = set()
num_set = set()
poly_set = set()
date_set = set()
point_set = set()
i = 0

def is_date(date_string):
    try:
        datetime.strptime(date_string, '%Y-%m-%d')
        return True
    except ValueError:
        return False

for s,p,o in graph:
    i+=1
    pi = p.identifier
    for node in [s,o]:
        ni = node.identifier
        # if 'http' in ni[:200] and 'geonames' not in ni: #just add geonames to node set I think...
        if 'http' in ni[:200]: #200, because images sometimes have kgbench url attached
            node_set.add(''.join(ni.split(':')[1].split('/')[:-1]))
        else:
            if node.isalnum():
                if node.isnumeric():
                    if node.isdigit():
                        num_set.add(int(node.identifier))
                    else:
                        num_set.add(float(node.identifier))
            elif ni.startswith('POINT') or ni.startswith('Point'): #didn't see any points, but according to the paper they can be included.
                point_set.add(ni)
            elif node.isalpha(): #maybe elif maybe not dunno if it filters out strings with numbers
                string_set.add(ni)
            elif ni.startswith('_9j_'):
                image_set.add(ni) #might want to load this to hard drive if memory becomes an issue.
            elif ni.startswith('POLYGON') or ni.startswith('Polygon'):
                poly_set.add(ni)
            elif is_date(ni):
                date_set.add(ni)
            elif ni.isascii():
                string_set.add(ni)
            elif ni.isprintable():
                string_set.add(ascii(ni)) #don't know if it's necessary, but it probably can't hurt
            else: #all that's left seems to be monument stories and property description related text. 
                            #If there's an error later it's probably from here
                string_set.add(ascii(ni))
                
                # print(ni.isalpha(),ni)
                


    edge_set.add(p)

    
i

777124

In [189]:
from string import printable

character_map = {i:char for i,char in enumerate(printable[:95])} #exclude \t and such specific characters


# Convert raw values to consistent feature vectors for the encoders:

In [254]:
import math
import torch
from string import printable
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

node_map = {i:node for i,node in enumerate(node_set)}
inv_node_map = {node:i for i,node in enumerate(node_set)}

edge_map = {i:edge for i,edge in enumerate(edge_set)}
inv_edge_map = {edge:i for i,edge in enumerate(edge_set)}

def tokenize_string(s, character_map):
    #preprocess for text encoder[1](? layer temporal cnn -->embedding_dim)

    #character level embeddings
    tokens = [character_map[char] for char in s]
    return tokens #

def encode_image(img_data):
    #preprocess for encoder[2](2 layer cnn -->embedding_dim)
    return #maybe save tokens & image representation to hard drive?

def encode_num(n, max_num):
    # read how it's done exactly.. cat(e(num),e(bool))
    # maybe add log scale if max_num is really high...
    v = n/max_num
    return torch.tensor(v,dtype=torch.float32, device=device)

def encode_polygon(poly, global_mean_x, global_mean_y, x_max, y_max):
    #preprocess for spatial encoder[3](? layer temporal cnn -->embedding_dim) 
    poly_tensor_x = (global_mean_x-torch.tensor([ point[0] for point in poly], dtype=torch.float32, device=device))/x_max
    poly_tensor_y = (global_mean_y-torch.tensor([ point[1] for point in poly], dtype=torch.float32, device=device))/y_max
    # print(poly_tensor_x)
    return torch.stack((poly_tensor_x,poly_tensor_y),dim=0)

def encode_point(point, max_x, max_y):
    #preprocess for spatial encoder[3](? layer temporal cnn -->embedding_dim)
    
    point = torch.tensor(point,dtype=torch.float32, device=device)
    div = torch.tensor((max_x,max_y),dtype=torch.float32, device=device)
    return point/div

def encode_date(date):
    #preprocess for temporal encoder[4](? layer ffnn -->embedding_dim)
    def cyclical(num, max_num, epsilon = 1e-8):
        # cyclical: [sine((2pi * X)/max_num_of_cycle) cos((2pi * X)/max_num_of_cycle)]
        return torch.tensor([math.sin((2 * math.pi * num)/max_num)+epsilon, math.cos((2 * math.pi * num)/max_num)+epsilon],dtype=torch.float32, device=device)
        
    def norm_cent(num):
        # non-cyclical only centuries: normalized from -99 to 99 (-9999 bc to 9999 ac)
        return torch.tensor((num + 99)/198,dtype=torch.float32, device=device).unsqueeze(0)
        
    split_str = date.split('-')
    years_str = split_str[0]
    month_str = split_str[1]
    day_str = split_str[2]
    centuries = norm_cent(int(years_str[:-2]))
    decades = cyclical(int(years_str[-2]), 10)
    years = cyclical(int(years_str[-1]), 10)
    months = cyclical(int(month_str), 12)
    days = cyclical(int(day_str), 31)
    # print((centuries, decades, years, months, days))
    # print(decades,years)
    return torch.cat((centuries, decades, years, months, days), dim=0)



def process_point(point_str, highest_x, highest_y):
    if 'POINT' in point_str:
        point_str = point_str.split('POINT(')[1].split('))')[0]
    elif 'Point' in point_str:
        point_str = point_str.split('Point(')[1].split('))')[0]
    point_list = point_str.strip(')').strip('(').split()
    point = tuple([float(coord) for coord in point_list])
    point_x,point_y = point
    highest_x = point_x if point_x > highest_x else highest_x
    highest_y = point_y if point_y > highest_y else highest_y
    return point, highest_x, highest_y
    

def get_num_data(poly_str, max_x, y_max):
    if 'POLYGON' in poly_str:
        poly_str = poly_str.split('POLYGON ((')[1].split('))')[0]
    elif 'Polygon' in poly_str:
        poly_str = poly_str.split('Polygon ((')[1].split('))')[0]
    poly_combi_str_list = [poly for poly in poly_str.split(',')]# refactor this later. someone is talking rlly loud and I can't think
    try:
        poly_tupled = [(float(poly.split()[0]),float(poly.split()[1])) for poly in poly_combi_str_list]
    except ValueError:
        poly_tupled = [(float(poly.split()[0].strip(')').strip('(')),
                     float(poly.split()[1].strip(')').strip('('))) 
                    for poly in poly_combi_str_list]
    x_max, y_max = max([x for x,y in poly_tupled]), max([y for x, y in poly_tupled])
    x_max =  x_max if x_max > max_x else max_x
    y_max =  y_max if y_max > max_y else max_y
    x_mean = sum([tup[0] for tup in poly_tupled])/len(poly_tupled)
    y_mean = sum([tup[1] for tup in poly_tupled])/len(poly_tupled)
    return poly_tupled, x_mean, y_mean, x_max, y_max


global_mean_x = 0
global_mean_y = 0

print(point)
max_x = 0
max_y = 0

points_tupled = []
for i, point in enumerate(point_set):
    point_tupled,max_x,max_y = process_point(point,max_x,max_y)
    points_tupled.append(point_tupled)


polys_tupled = []
x_max, y_max = 0,0
for i, poly in enumerate(poly_set):
    # if i < 100:
    poly_tupled, x_mean, y_mean, x_max, y_max = get_num_data(poly, x_max, y_max)
    global_mean_x += 1
    global_mean_y += 1
    polys_tupled.append(poly_tupled)
i += 1
print(poly_tupled[0])
global_mean_x, global_mean_y = global_mean_x/i, global_mean_y/i


character_map = {char:i for i,char in enumerate(printable)}
character_map['\x7f'] = 101
strings = []
for s in string_set:
    str_feature = tokenize_string(s, character_map)
    strings.append(str_feature)

imgs = []
for img in image_set:
    img_feature = encode_image(img)
    imgs.append(img_feature)

nums = []
for n in num_set:
    norm_fac = max(num_set)
    num_feature = encode_num(n, norm_fac)
    nums.append(num_feature)

poly_features = []
for i,poly in enumerate(polys_tupled):
    poly_feature = encode_polygon(poly, global_mean_x, global_mean_y, x_max, y_max)
    poly_features.append(poly_feature)
    if i<10:
        # print(poly)
        # print(embedding)
        pass

point_features = []
for point_tup in points_tupled:
    # print(point)
    point_feature = encode_point(point_tup, max_x, max_y)
    point_features.append(point_feature)
        # print(point)
        # print(embedding)
print(point)

date_features = []
for d in date_set:
    date_feature = encode_date(d)
    date_features.append(date_feature)



Point(5.948392853085229 52.659852423418606)
(7.161238424712663, 53.11030897180887)
Point(5.948392853085229 52.659852423418606)


# 

images.memorix.nlrcedownloadfullsize
bag.basisregistraties.overheid.nlbagidgeometry

In [100]:
import base64
from PIL import Image

def decode_base64_jpg(encoded_str,log_note='pass values to decode_base_64_jpg'): # - to +, _ to /
    """
    encoded_str: url safe base 64 jpg string --> image bytes string
    """

    try:
        image_bytes = base64.urlsafe_b64decode(encoded_str)
        return image_bytes
    except Exception as e:
        logging.error(f"{e} error encoding image at {log_note}")
        return None

def save_bytes_to_jpg(image_bytes, item_num=0,folder='Downloads/ml4g/image_data/',name='decoded'):
    filename = f'{folder}{name}_{item_num}.jpg'
    with open(filename, 'wb') as img_file:
        img_file.write(image_bytes)


image_bytes = decode_base64_jpg(image_data)
save_bytes_to_jpg(image_bytes)

 'isalnum',
 'isalpha',
 'isascii',
 'isdecimal',
 'isdigit',
 'isidentifier',
 'islower',
 'isnumeric',
 'isprintable',
 'isspace',
 'istitle',
 'isupper',