In [14]:
import copy
import collections
import functools
import os
import json
import csv

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

from game.header import *

from data.toy_dataset import ToyDataset
from torch_geometric.loader import DataLoader

In [29]:
def find_last_frame_idx(directory):
    """Return the largest frame number found among the file names in `directory`.

    A file contributes a frame number when the part of its name before the
    first '.' parses as an integer; all other entries are ignored.

    Returns None (after printing a notice) if `directory` does not exist, and
    None if no entry has a numeric name.
    """
    if not os.path.exists(directory):
        print('folder %s does not exist' % directory)
        return None

    def _leading_number(file_name):
        # Text before the first '.' is taken to be the frame number.
        try:
            return int(file_name.split('.')[0])
        except ValueError:
            # Not a numerically named file.
            return None

    parsed = (_leading_number(entry) for entry in os.listdir(directory))
    frame_numbers = [value for value in parsed if value is not None]
    return max(frame_numbers) if frame_numbers else None

def find_closest_frame_idx(directory, frame_number):
    """Return the frame number in `directory` that is nearest to `frame_number`.

    A file contributes a frame number when the part of its name before the
    first '.' parses as an integer; all other entries are ignored.

    When two frames are equally far from `frame_number`, the lower frame
    number is returned, so the result does not depend on the (arbitrary)
    order in which os.listdir yields entries.

    Returns None (after printing a notice) if `directory` does not exist, and
    None if no entry has a numeric name.
    """
    if not os.path.exists(directory):
        print('folder %s does not exist' % directory)
        return None

    frame_numbers = []
    for file_name in os.listdir(directory):
        try:
            # Text before the first '.' is taken to be the frame number.
            frame_numbers.append(int(file_name.split('.')[0]))
        except ValueError:
            # Skip files that don't have a number as their name.
            continue

    if not frame_numbers:
        return None

    # Primary key: distance to the requested frame; secondary key: the frame
    # number itself, which makes tie-breaking deterministic.
    return min(frame_numbers, key=lambda number: (abs(number - frame_number), number))

# Quick sanity check of both helpers against one video's frame folder.
directory_path = "/data/Datasets/ag/frames/C93HZ.mp4"
# Largest numeric file name in the folder (None if there is none).
highest = find_last_frame_idx(directory_path)
print(highest)

# Frame number in the folder that is nearest to frame 100.
closest = find_closest_frame_idx(directory_path, 100)
print(closest)

713
102


In [32]:
def string_to_action_triple(action_string, video_id):
    """Parse one whitespace-separated action entry into a tuple.

    `action_string` is expected to hold three fields: a class token whose
    first character is dropped before integer conversion, a start time and an
    end time (both converted to float).

    Returns (video_id, action_class, start_time, end_time) on success.
    Returns None silently for an empty or whitespace-only string.
    Prints 'invalid string' and returns None when the entry does not have
    exactly three fields or a field cannot be converted to a number.
    """
    # split() without an argument tolerates leading/trailing/repeated
    # whitespace, which split(' ') would turn into spurious empty fields.
    fields = action_string.split()
    if not fields:
        return None
    if len(fields) != 3:
        print('invalid string')
        return None

    class_token, start_token, end_token = fields
    try:
        return (video_id, int(class_token[1:]), float(start_token), float(end_token))
    except ValueError:
        # Treat unparsable numbers like any other malformed entry instead of
        # aborting the caller's whole loading loop.
        print('invalid string')
        return None

root = '/data/Datasets/ag/'
actions = []

# Build `actions`: one tuple per successfully parsed entry in the annotation CSV.
with open(root + 'annotations/Charades/Charades_v1_train.csv') as annotation_file:
    for record in csv.DictReader(annotation_file):
        clip_id = record['id']
        # The 'actions' column holds ';'-separated entries; each one is parsed
        # on its own and entries that yield no tuple are dropped.
        parsed_entries = (
            string_to_action_triple(entry, clip_id)
            for entry in record['actions'].split(';')
        )
        actions.extend(parsed for parsed in parsed_entries if parsed)



In [40]:
def get_frame_from_time(video_id, time, fps=24):
    """Return the stored frame number closest to `time` seconds in a video.

    The target frame is int(time * fps) (truncated). Frames are looked up in
    the folder <root>/frames/<video_id>.mp4 via find_closest_frame_idx, whose
    result (possibly None) is returned unchanged.
    """
    return find_closest_frame_idx(
        os.path.join(root, 'frames', video_id + '.mp4'),
        int(time * fps),
    )


In [33]:
# Display the parsed (video_id, action_class, start_time, end_time) tuples.
actions

[('46GP8', 92, 11.9, 21.2),
 ('46GP8', 147, 0.0, 12.6),
 ('N11GT', 98, 8.6, 14.2),
 ('N11GT', 75, 0.0, 11.7),
 ('N11GT', 127, 0.0, 15.2),
 ('N11GT', 153, 6.4, 12.1),
 ('KRF68', 18, 22.6, 27.8),
 ('KRF68', 141, 4.1, 9.6),
 ('KRF68', 148, 10.3, 25.0),
 ('KRF68', 6, 4.2, 10.9),
 ('KRF68', 2, 6.8, 14.1),
 ('KRF68', 150, 0.0, 9.9),
 ('KRF68', 0, 6.2, 15.3),
 ('MJO7C', 15, 0.0, 32.0),
 ('MJO7C', 107, 0.0, 32.0),
 ('S6MPZ', 9, 0.0, 4.3),
 ('S6MPZ', 11, 0.0, 39.0),
 ('S6MPZ', 15, 0.0, 39.0),
 ('S6MPZ', 19, 0.0, 39.0),
 ('S6MPZ', 156, 0.0, 30.7),
 ('S6MPZ', 59, 0.0, 39.0),
 ('S6MPZ', 61, 0.0, 8.0),
 ('S6MPZ', 61, 5.0, 11.5),
 ('S6MPZ', 17, 0.0, 35.5),
 ('S6MPZ', 63, 0.0, 3.7),
 ('7HVU8', 20, 0.0, 5.6),
 ('7HVU8', 4, 12.4, 31.0),
 ('7HVU8', 1, 0.0, 10.0),
 ('7HVU8', 144, 13.5, 18.9),
 ('7HVU8', 127, 0.0, 5.0),
 ('7HVU8', 127, 1.6, 7.8),
 ('7HVU8', 127, 6.3, 12.3),
 ('MCQO5', 148, 4.2, 13.4),
 ('MCQO5', 106, 29.6, 33.0),
 ('MCQO5', 0, 2.1, 23.5),
 ('MCQO5', 107, 21.4, 33.0),
 ('MCQO5', 2, 1.5, 6.

In [1]:
def extract_usable_frames(threshold, fps=24, plot=False):
    '''
    Gets all usable frame-action pairs, where the frame should be the very beginning of the action.

    Iterates over the module-level `actions` list of
    (video_id, action_class, start_time, end_time) tuples and, for each one,
    looks up the stored frame closest to the action's start time.

    threshold: the maximum deviation in seconds between the start time of the action and the frame time
    fps: frame rate used both for the frame lookup and for converting the frame index back to seconds
    plot: if True, draw a histogram (100 bins over 0-5 s) of the deviations of all found frames

    returns: list of (video_id, frame_idx, action_class) tuples whose deviation is below `threshold`
    '''
    distribution = []
    data_list = []

    for video_id, action_class, start_time, end_time in actions:
        # Forward fps so the lookup and the deviation below use the same rate.
        frame_idx = get_frame_from_time(video_id, start_time, fps=fps)
        # Frame index 0 is a valid hit, so compare against None rather than
        # relying on truthiness.
        if frame_idx is None:
            print('no frame found')
            continue

        deviation = abs((frame_idx / fps) - start_time)
        if plot:
            distribution.append(deviation)
        if deviation < threshold:
            # list.append takes a single item: store the triple as one tuple.
            data_list.append((video_id, frame_idx, action_class))

    if plot:
        plt.hist(distribution, bins=100, range=(0, 5))

    return data_list

# Keep only actions whose nearest stored frame lies within 1 second of the action start.
# NOTE: extract_usable_frames reads the global `actions`, so the annotation-loading
# cell must have been executed first; otherwise this cell raises NameError.
data_list = extract_usable_frames(1)
print(data_list)


NameError: name 'actions' is not defined

In [65]:
def pyg_to_pred_tensors(data):
    """Convert a graph object (or a batch of them) into [nullary, unary, binary] tensors.

    For a single graph the result is a list of three tensors, each with a
    leading dimension of size 1:
      * nullary: zeros of length len(RELS)
      * unary:   data.x
      * binary:  a (num_nodes, num_nodes, len(RELS)) tensor holding 1 at
                 [source, target, edge_type] for every edge and 0 elsewhere

    When `data.batch` is not None, the batch is split with to_data_list(),
    every graph is converted separately and the per-graph tensors are stacked
    along the leading dimension with torch.vstack.
    NOTE(review): stacking presumably requires all graphs in the batch to have
    the same number of nodes - confirm against the data that is fed in.
    """
    def convert_graph(graph):
        num_relations = len(RELS)
        empty_nullary = torch.zeros(num_relations)
        node_features = graph.x

        adjacency = torch.zeros(graph.num_nodes, graph.num_nodes, num_relations)
        for edge_pos, relation in enumerate(graph.edge_type):
            source = graph.edge_index[0][edge_pos]
            target = graph.edge_index[1][edge_pos]
            adjacency[source, target, relation] = 1

        # Prepend a batch dimension of size 1 to each tensor.
        return [part.unsqueeze(0) for part in (empty_nullary, node_features, adjacency)]

    # Single graph: nothing to stack.
    if data.batch is None:
        return convert_graph(data)

    converted = [convert_graph(single) for single in data.to_data_list()]
    # Positions 0, 1, 2 are the nullary, unary and binary tensors respectively.
    return [
        torch.vstack([per_graph[position] for per_graph in converted])
        for position in range(3)
    ]

def show_pyg_graph(graph):
    """Draw `graph` as a labelled directed graph using networkx.

    Each node is labelled with NODES[node_type]; when the graph carries an
    'edge_type' attribute each edge is labelled with RELS[edge_type],
    otherwise edges are drawn without labels. Nodes are placed on a circle.
    """
    import networkx as nx

    nx_graph = nx.DiGraph()

    # One networkx node per graph node, labelled by its node type.
    for node_id in range(graph.num_nodes):
        node_label = NODES[graph.node_type[node_id].item()]
        nx_graph.add_node(node_id, label=node_label)

    # One directed edge per column of edge_index; the label is optional.
    edge_index = graph.edge_index
    edge_type = graph.edge_type if 'edge_type' in graph else None
    for column in range(edge_index.size(1)):
        source, target = edge_index[:, column].tolist()
        if edge_type is None:
            edge_attrs = {}
        else:
            edge_attrs = {'label': RELS[int(edge_type[column].item())]}
        nx_graph.add_edge(source, target, **edge_attrs)

    # Render nodes first, then overlay the edge labels.
    layout = nx.circular_layout(nx_graph)
    node_labels = nx.get_node_attributes(nx_graph, 'label')
    relation_labels = nx.get_edge_attributes(nx_graph, 'label')

    nx.draw(
        nx_graph,
        layout,
        with_labels=True,
        labels=node_labels,
        node_color='lightblue',
        node_size=500,
        font_size=10,
        font_color='black',
        font_weight='bold',
        arrows=True,
    )
    nx.draw_networkx_edge_labels(nx_graph, layout, edge_labels=relation_labels, font_color='red')
