In [17]:
from pygraphblas import *
from pygraphblas.demo.gviz import draw, draw_op
import pygraphblas.descriptor
import csv
import sys



In [24]:
#Load data from CSV format
class DataLoader:
    """Reads node and edge CSV files that share a common path prefix.

    The prefix is prepended verbatim to every file name, so it must end
    with a path separator when it names a directory.
    """

    def __init__(self, path):
        """Remember the prefix that is prepended to each file name."""
        self.path = path

    def load_node(self, filename):
        """Read the ``id:ID`` column of a node CSV file.

        Returns a pair ``(original_ids, id_mapping)``: the ids in file order
        and a dict mapping each id to its zero-based position in that list.
        If an id occurs more than once, the mapping keeps its last position.
        """
        with open(self.path + filename, newline='') as node_file:
            records = csv.DictReader(node_file, delimiter=',', quotechar='"')
            original_ids = [record['id:ID'] for record in records]

        id_mapping = {node_id: position for position, node_id in enumerate(original_ids)}
        return original_ids, id_mapping

    def load_edge(self, filename, start_mapping, end_mapping, typ=BOOL, drop_dangling_edges=False):
        """Build a matrix from the ``id:START_ID`` / ``id:END_ID`` columns of an edge CSV file.

        Row indices are looked up in ``start_mapping`` and column indices in
        ``end_mapping``; every edge is stored with the value 1. The matrix has
        ``len(start_mapping)`` rows and ``len(end_mapping)`` columns and the
        element type ``typ``.

        With ``drop_dangling_edges=True`` an edge whose start or end id is
        missing from its mapping is skipped. Otherwise the lookup of a missing
        id is attempted and fails (``KeyError`` for dict mappings).
        """
        sources, targets = [], []
        with open(self.path + filename, newline='') as edge_file:
            for record in csv.DictReader(edge_file, delimiter=',', quotechar='"'):
                source_id = record['id:START_ID']
                target_id = record['id:END_ID']
                # Guard clause: skip edges that point outside the given node sets.
                if drop_dangling_edges and (source_id not in start_mapping or target_id not in end_mapping):
                    continue
                sources.append(start_mapping[source_id])
                targets.append(end_mapping[target_id])

        return Matrix.from_lists(
            sources,
            targets,
            [1] * len(sources),
            nrows=len(start_mapping),
            ncols=len(end_mapping),
            typ=typ)


In [105]:
'''
Check results here:
https://github.com/ftsrg/trainbenchmark/blob/master/trainbenchmark-tool/src/main/java/hu/bme/mit/trainbenchmark/benchmark/test/TrainBenchmarkTest.java#L139 
'''

# Dataset prefix and the size selector that appears in every file name below.
path = 'trainbenchmark-repair-models-csv/'
data_size = 2
loader = DataLoader(path)

# One node CSV per node type; the dict order is also the order of the printout.
node_files = {
    'Route': f'railway-repair-{data_size}-Route.csv',
    'SwitchPosition': f'railway-repair-{data_size}-SwitchPosition.csv',
    'Switch': f'railway-repair-{data_size}-Switch.csv',
    'Sensor': f'railway-repair-{data_size}-Sensor.csv',
    'Segment': f'railway-repair-{data_size}-Segment.csv',
}

# vertices: node type -> ids in file order; mapping: node type -> {id: index}.
vertices, mapping = {}, {}
for node_type, node_file in node_files.items():
    vertices[node_type], mapping[node_type] = loader.load_node(node_file)

for vertex in vertices:
    print(f"dimension of {vertex} is {len(vertices[vertex])}")

# (matrix name, edge file, row node type, column node type, drop dangling edges)
# monitoredBy.csv is read twice, once with Switch rows and once with Segment
# rows; edges whose endpoints are not in the chosen mappings are dropped.
edge_specs = [
    ('follows', f'railway-repair-{data_size}-follows.csv', 'Route', 'SwitchPosition', False),
    ('target', f'railway-repair-{data_size}-target.csv', 'SwitchPosition', 'Switch', False),
    ('monitoredBySwitch', f'railway-repair-{data_size}-monitoredBy.csv', 'Switch', 'Sensor', True),
    ('monitoredBySegment', f'railway-repair-{data_size}-monitoredBy.csv', 'Segment', 'Sensor', True),
    ('requires', f'railway-repair-{data_size}-requires.csv', 'Route', 'Sensor', False),
    ('connectsTo', f'railway-repair-{data_size}-connectsTo.csv', 'Segment', 'Segment', True),
]

matrices = {}
for edge_name, edge_file, row_type, col_type, drop_dangling in edge_specs:
    matrices[edge_name] = loader.load_edge(
        edge_file, mapping[row_type], mapping[col_type], drop_dangling_edges=drop_dangling)

# Matrix index of the sensor whose original id is '1692'.
selected_sensor_id = mapping['Sensor']['1692']

def route_sensor_violation_query(matrices):
    """Combine route-to-sensor reachability with the ``requires`` matrix.

    Reachability is the product ``follows @ target @ monitoredBySwitch``,
    evaluated left to right. It is then merged element-wise with
    ``requires`` through ``eadd`` using the ``MINUS`` operator.

    Zeros are included in the result as well.
    NOTE(review): eadd presumably operates on the union of both patterns, so
    entries present only in ``requires`` would also appear — confirm.
    """
    sensors_via_switches = (matrices['follows'] @ matrices['target']) @ matrices['monitoredBySwitch']
    difference = sensors_via_switches.eadd(matrices['requires'], add_op=MINUS)
    return difference

def connected_segments_query(matrices, vertices, hops=5):
    """Print the sensors found by repeatedly following ``connectsTo`` edges.

    Starting from the transposed ``monitoredBySegment`` matrix (sensor rows,
    segment columns), the current result is multiplied by ``connectsTo``
    ``hops`` times. Each product is masked with the sensor-to-segment matrix,
    so only segments monitored by the same sensor are kept at every step.

    Parameters:
        matrices: dict providing the 'monitoredBySegment' and 'connectsTo'
            matrices.
        vertices: dict whose 'Sensor' entry lists the original sensor ids by
            matrix index.
        hops: number of ``connectsTo`` steps to follow (default 5).

    Prints the original id of the row sensor for every entry left in the
    result, so a sensor appears once per remaining (sensor, segment) pair.
    Returns None.
    """
    sensor_to_segment = matrices['monitoredBySegment'].transpose()
    reached = sensor_to_segment.dup()
    for _ in range(hops):
        # The mask discards walks that leave the sensor's own segments.
        reached = reached.mxm(matrices['connectsTo'], mask=sensor_to_segment)

    sensor_rows, _, _ = reached.to_lists()
    print([vertices['Sensor'][row] for row in sensor_rows])
# Disabled draft of a further query, kept for reference:
# def switch_monitored_query(matrices):
#     return matrices['monitoredBy']

# Run the connected-segments query; the function prints its result itself.
connected_segments_query(matrices, vertices)
# Disabled: the route/sensor violation query is not executed in this cell.
# route_sensor_violations_result = route_sensor_violation_query(matrices)
# route_sensor_violations_result.to_string()
# result.to_string()



dimension of Route is 10
dimension of SwitchPosition is 67
dimension of Switch is 67
dimension of Sensor is 310
dimension of Segment is 1564
['6', '121', '128', '689', '934', '1095', '1208', '1319', '1338', '1692', '1705', '1775', '1922', '2019']


In [9]:
# NOTE(review): `load_edge`, `small_data_set` and the `*_mapping` names are not
# defined in any cell visible above (the loader there is `DataLoader.load_edge`).
# This cell presumably depends on earlier kernel state — confirm, or port it to
# `DataLoader`, before relying on Restart & Run All.
follows_mx = load_edge(small_data_set['follows'], route_mapping, swp_mapping)
target_mx = load_edge(small_data_set['target'], swp_mapping, sw_mapping)
# Edges whose endpoints are missing from the switch/sensor mappings are dropped.
monitored_by_mx = load_edge(small_data_set['monitoredBy'], sw_mapping, sen_mapping, drop_dangling_edges=True)
requires_mx = load_edge(small_data_set['requires'], route_mapping, sen_mapping)
# Last expression of the cell, so its string rendering becomes the cell output.
monitored_by_mx.to_string()
#print_matrix(monitored_by_mx)

    0 1 2 3 4 5 6 7 8 9101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
 0| 1 1 1 1   1 1                                                                                                                                                     