In [None]:
import copy
import collections
import functools
import os
import json
import csv

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from game.header import *

from data.toy_dataset import ToyDataset
from torch_geometric.loader import DataLoader

In [21]:
def string_to_action_triple(action_string, video_id):
    """Parse one action annotation into a ``(video_id, class_id, start, end)`` tuple.

    Despite the name, the result has four elements. ``action_string`` is expected to
    hold three whitespace-separated tokens, e.g. ``'c092 11.90 21.20'``: a class token
    whose first character is dropped and the remainder parsed as an integer id,
    followed by a start and an end value parsed as floats.

    Args:
        action_string: one ``;``-separated entry of the ``actions`` CSV column.
        video_id: identifier stored as the first element of the result.

    Returns:
        ``(video_id, class_id, start, end)``, or ``None`` when the entry is empty/blank
        or malformed (a warning naming the offending string is printed for malformed
        entries).
    """
    # str.split() with no argument tolerates repeated, leading and trailing whitespace.
    parts = action_string.split()
    if not parts:
        # Empty entries occur e.g. for videos whose 'actions' field is blank.
        return None
    if len(parts) != 3:
        print(f'invalid string: {action_string!r}')
        return None
    class_token, start, end = parts
    try:
        return (video_id, int(class_token[1:]), float(start), float(end))
    except ValueError:
        # Same policy as wrong-arity entries: report and skip rather than crash.
        print(f'invalid string: {action_string!r}')
        return None

# Dataset root; set the AG_ROOT environment variable to point elsewhere.
root = os.environ.get('AG_ROOT', '/data/Datasets/ag/')
train_csv = os.path.join(root, 'annotations', 'Charades', 'Charades_v1_train.csv')

# One list per CSV row (video), holding the parsed action tuples returned by
# string_to_action_triple. Row order is preserved, so index i matches the i-th video.
actions = []

# newline='' is what the csv module expects for files passed to its readers.
with open(train_csv, newline='', encoding='utf-8') as f:
    for row in csv.DictReader(f):
        video_id = row['id']
        parsed = (string_to_action_triple(entry, video_id) for entry in row['actions'].split(';'))
        # An empty 'actions' field splits to [''] and parses to None; drop such
        # placeholders (and malformed entries) so every element is a real tuple.
        actions.append([action for action in parsed if action is not None])

In [20]:
# Preview only the first few videos; displaying the full nested list floods the output.
actions[:5]

In [None]:
# NOTE(review): placeholder loop — the body is `pass`, so the only effect of this cell
# is leaving the last element of `actions` bound to the global name `action`.
# TODO: fill in the intended per-video processing or delete this cell.
for action in actions:
    pass

In [65]:
def pyg_to_pred_tensors(data, num_rels=None):
    """Convert a PyG graph (or a batch of graphs) into dense predicate tensors.

    For each graph three tensors are built, each with a leading dimension of size 1:

    * nullary: zeros of shape ``(1, num_rels)``;
    * unary:   ``data.x`` with a leading dimension added;
    * binary:  zeros of shape ``(1, N, N, num_rels)`` with ``binary[0, s, t, r] = 1``
      for every edge ``s -> t`` whose ``edge_type`` is ``r``.

    Args:
        data: a PyG ``Data``-like object exposing ``x``, ``edge_index``, ``edge_type``,
            ``num_nodes`` and ``batch``. When ``batch`` is not ``None`` it must also
            provide ``to_data_list()``; the per-graph tensors are then concatenated
            along dim 0.
        num_rels: size of the relation dimension of the nullary/binary tensors.
            Defaults to ``len(RELS)``.

    Returns:
        ``[nullary, unary, binary]``.

    NOTE(review): in the batched case every graph must have the same number of nodes
    (and the same ``x`` trailing shape), otherwise the concatenation raises.
    """
    if num_rels is None:
        num_rels = len(RELS)

    def convert_graph(graph):
        # Allocate on the graph's device so GPU inputs don't yield CPU-resident tensors.
        device = graph.edge_index.device
        nullary = torch.zeros(num_rels, device=device)
        unary = graph.x

        binary = torch.zeros(graph.num_nodes, graph.num_nodes, num_rels, device=device)
        # Scatter all edges at once; advanced indexing requires integral indices.
        src, dst = graph.edge_index[0], graph.edge_index[1]
        binary[src, dst, graph.edge_type.long()] = 1

        return [tensor.unsqueeze(0) for tensor in (nullary, unary, binary)]

    if data.batch is None:
        return convert_graph(data)

    per_graph = [convert_graph(graph) for graph in data.to_data_list()]
    # Stack nullary, unary and binary tensors separately across graphs.
    return [torch.cat([tensors[k] for tensors in per_graph], dim=0) for k in range(3)]

def show_pyg_graph(graph, ax=None):
    """Draw a PyG graph with networkx, labelling nodes by type and edges by relation.

    Args:
        graph: PyG-style graph exposing ``num_nodes``, ``node_type`` and ``edge_index``
            (and optionally ``edge_type``). Node types index into ``NODES``; edge types
            index into ``RELS``.
        ax: optional matplotlib Axes to draw into; ``None`` draws on the current axes.

    When several edges connect the same (source, target) pair, their relation labels
    are joined with ``', '`` into a single edge label.
    """
    import networkx as nx  # local import: only needed for visualisation

    G = nx.DiGraph()

    # Nodes, labelled by their type name.
    for i in range(graph.num_nodes):
        G.add_node(i, label=NODES[graph.node_type[i].item()])

    edge_index = graph.edge_index
    edge_type = graph.edge_type if 'edge_type' in graph else None

    # A DiGraph stores a single edge per (source, target) pair, so gather every
    # relation on a pair first; adding edges one by one would let later relations
    # overwrite the labels of earlier ones.
    pair_labels = {}
    for i in range(edge_index.size(1)):
        source, target = edge_index[:, i].tolist()
        labels_for_pair = pair_labels.setdefault((source, target), [])
        if edge_type is not None:
            labels_for_pair.append(RELS[int(edge_type[i].item())])

    for (source, target), rel_labels in pair_labels.items():
        if rel_labels:
            G.add_edge(source, target, label=', '.join(rel_labels))
        else:
            G.add_edge(source, target)

    # Draw the graph.
    pos = nx.circular_layout(G)
    labels = nx.get_node_attributes(G, 'label')
    edge_labels = nx.get_edge_attributes(G, 'label')

    nx.draw(G, pos, ax=ax, with_labels=True, labels=labels, node_color='lightblue', node_size=500, font_size=10, font_color='black', font_weight='bold', arrows=True)
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color='red', ax=ax)
