In [1]:
import glob
import os
import pickle as pkl
import re
import sys
from collections import defaultdict

from igraph import *
import databricks.koalas as ks
import numpy as np
import pandas as pd
from scipy.spatial.distance import euclidean

In [2]:
# Directory prefixes (globbed as "<prefix>*") holding the train / test feature pickles.
train_features, test_features = "./train_features/", "./test_features/"

In [3]:
class BasicGraph(Graph):
    """igraph ``Graph`` subclass adding attribute-preserving concatenation and an MST helper."""

    def __init__(self, *args, **kwds):
        super(BasicGraph, self).__init__(*args, **kwds)

    def concat(self, graph2):
        """Return ``self + graph2`` with every vertex and edge attribute carried over.

        The attribute values of ``self`` come first, followed by those of
        ``graph2``. This assumes ``self + graph2`` lays out ``self``'s
        vertices/edges before ``graph2``'s. An attribute present on only one of
        the two graphs is padded with ``None`` for the other graph instead of
        raising ``KeyError``.

        Parameters
        ----------
        graph2 : Graph
            Graph appended after ``self``.

        Returns
        -------
        The combined graph, with attributes written through
        ``set_vs_attributes`` / ``set_es_attributes``.
        """
        vs_attrs = self._merge_attributes(self.vs, graph2.vs)
        es_attrs = self._merge_attributes(self.es, graph2.es)

        aux = self + graph2

        aux.set_vs_attributes(vs_attrs)
        aux.set_es_attributes(es_attrs)

        return aux

    @staticmethod
    def _merge_attributes(first, second):
        """Map every attribute name of either sequence to ``first`` values followed by ``second`` values."""
        first_names = first.attributes()
        second_names = second.attributes()
        names = list(first_names) + [name for name in second_names if name not in first_names]

        merged = {}
        for name in names:
            # Pad with None where one side lacks the attribute so every list matches the union size.
            left = list(first[name]) if name in first_names else [None] * len(first)
            right = list(second[name]) if name in second_names else [None] * len(second)
            merged[name] = left + right
        return merged

    def set_vs_attributes(self, attr):
        """Assign each ``{name: values}`` entry of ``attr`` as a vertex attribute."""
        for key, value in attr.items():
            self.vs[key] = value

    def set_es_attributes(self, attr):
        """Assign each ``{name: values}`` entry of ``attr`` as an edge attribute."""
        for key, value in attr.items():
            self.es[key] = value

    def mst(self, weights = None):
        """Return a minimum spanning tree (a spanning forest if the graph is disconnected).

        Parameters
        ----------
        weights : list of float or str, optional
            Forwarded to ``spanning_tree``. When omitted, the ``"weight"`` edge
            attribute is used if the graph has one. Otherwise the tree is
            computed unweighted.
        """
        # The previous version had these branches swapped: it ignored explicit
        # weights and built an unweighted tree by default.
        if weights is None and "weight" in self.es.attributes():
            weights = self.es["weight"]
        return self.spanning_tree(weights=weights)

In [4]:
class ReadDataframe():
    """Read pickled feature DataFrames and derive vertex names and features from them.

    Every file matching ``features_file + "*"`` is loaded with ``pd.read_pickle``,
    ordered by the first integer in its base name. The methods read a
    ``global_class`` column and a ``global_feature`` column, and treat every
    cell holding an ``np.ndarray`` as a feature.

    NOTE(review): ``bboxes_counts`` / ``get_bbox_features`` scan *all* columns,
    so ndarray cells of ``global_feature`` are counted/collected as bbox
    features as well. Confirm this is intended.
    """

    def __init__(self, features_file = "./train_features/"):
        # Prefix (normally a directory with trailing separator) that gets globbed.
        self.features_file = features_file

    def _sortKeyFunc(self, s):
        """Sort key for a feature file: the first integer in its base name.

        Only the base name is searched, so digits in the directory part of the
        path cannot take over the ordering. Raises IndexError if the base name
        contains no digits.
        """
        return int(re.findall("[0-9]+", os.path.basename(s))[0])

    def _print_while_reads(self, indice, total):
        """Overwrite the current terminal line with the percentage of files read."""
        sys.stdout.write("Dataframes lidos: %f %%  \r" % ((indice/total)*100) )
        sys.stdout.flush()

    def _get_batch_dfs(self):
        """Yield each feature DataFrame, one file at a time, in ``_sortKeyFunc`` order."""
        files = sorted(glob.glob(self.features_file+"*"), key=self._sortKeyFunc)
        for file in files:
            yield pd.read_pickle(file)

    @staticmethod
    def _elementwise(df, func):
        """Apply ``func`` to every cell of ``df``.

        Uses ``DataFrame.map`` when pandas provides it (2.1+, where ``applymap``
        is deprecated) and falls back to ``applymap`` on older versions.
        """
        # Check on the class: hasattr(df, "map") would also match a column named "map".
        if hasattr(pd.DataFrame, "map"):
            return df.map(func)
        return df.applymap(func)

    def classes_counts(self):
        """Count rows per ``global_class`` over all files.

        Returns a dict ``{class: count}``. Keys appear in the order classes are
        first met, sorted within each file (``np.unique`` order). The naming and
        feature methods below rely on this order.
        """
        unique_counts = defaultdict(int)
        for df in self._get_batch_dfs():
            keys, values = np.unique(df.global_class, return_counts = True)
            for key, value in zip(keys, values):
                unique_counts[key] += value

        return dict(unique_counts)

    def bboxes_counts(self):
        """Return an int array with, per row (file order), the number of ndarray cells."""
        lengths = np.empty((0)).astype(int)
        for df in self._get_batch_dfs():
            per_row = self._elementwise(df, self._count_features).sum(axis=1)
            lengths = np.append(lengths, per_row.to_numpy().astype(int))

        return lengths

    def global_classes(self):
        """Return the ``global_class`` of every row, in file order, as one flat array."""
        global_classes = np.empty((0))
        for df in self._get_batch_dfs():
            global_classes = np.append(global_classes, df.global_class.to_numpy())

        return global_classes

    def get_global_names(self):
        """Return ``"<class>_<i>"`` for i in 0..count-1, per class in ``classes_counts`` order.

        Side effect: stores the class counts in ``self.classes``.
        """
        self.classes = self.classes_counts()
        return ["{}_{}".format(key, i) for key, value in self.classes.items() for i in range(value)]

    def _count_features(self, data):
        """Return 1 if a cell holds an ``np.ndarray`` feature, else 0."""
        if isinstance(data, np.ndarray):
            return 1
        return 0

    def get_bbox_names(self):
        """Return ``"<class>_<k>_<j>"`` for every bbox feature, in row (file) order.

        ``k`` is the row's occurrence index among rows of the same class (so it
        matches the ``"<class>_<k>"`` global names). ``j`` runs over the row's
        ndarray cells. A running per-class counter is used, so the names stay
        unique even when rows are not grouped by class.
        """
        lengths = self.bboxes_counts()
        global_classes = self.global_classes()

        seen = defaultdict(int)
        classes_list = []
        for global_class, length in zip(global_classes, lengths):
            relative_id = seen[global_class]
            seen[global_class] += 1
            classes_list.extend(
                "{}_{}_{}".format(global_class, relative_id, i) for i in range(length)
            )

        return classes_list

    def get_names(self):
        """Return all vertex names: global names first, then bbox names."""
        names = []
        names.extend(self.get_global_names())
        names.extend(self.get_bbox_names())

        return np.asarray(names)

    def _return_only_features(self, data):
        """Return the cell if it is an ``np.ndarray``, else ``None``."""
        if isinstance(data, np.ndarray):
            return data
        return None

    def get_global_features(self):
        """Return the global features grouped by class, in ``classes_counts`` order.

        Within a class the rows keep file order, so entry k of class c belongs to
        the global name ``"<c>_<k>"``. This matches ``get_global_names`` even
        when the rows are not already grouped by class. Each ``global_feature``
        cell is unpacked with ``list.extend`` (one entry per element along its
        first axis).
        """
        grouped = {}
        for df in self._get_batch_dfs():
            # Register classes exactly as classes_counts() orders its keys.
            for key in np.unique(df.global_class):
                grouped.setdefault(key, [])
            for global_class, feature in zip(df.global_class.to_numpy(), df.global_feature.to_numpy()):
                grouped[global_class].extend(feature)

        return [feature for class_features in grouped.values() for feature in class_features]

    def get_bbox_features(self):
        """Return every ndarray cell's contents, row by row (file order), column by column."""
        features = []
        for df in self._get_batch_dfs():
            ff = self._elementwise(df, self._return_only_features).to_numpy()
            for f in ff:
                for feature in f:
                    if feature is not None:
                        features.extend(feature)

        return features

    def get_features(self):
        """Return all vertex features (global first, then bbox) as one array."""
        features = []
        features.extend(self.get_global_features())
        features.extend(self.get_bbox_features())

        return np.asarray(features)

In [5]:
class CreateGraph(ReadDataframe, BasicGraph):
    """Build a feature graph from the pickled DataFrames under ``features_file``.

    Intended call order: ``build()`` -> ``auto_connect()`` -> ``set_weights()``
    -> ``mst()``. The assembled graph is stored in ``self.g``. The
    ``BasicGraph`` base of this object is only initialised as an empty graph.
    """

    def __init__(self, features_file = "./train_features/"):
        # ReadDataframe.__init__ stores features_file; no need to set it twice.
        ReadDataframe.__init__(self, features_file)
        BasicGraph.__init__(self)

    def build(self):
        """Lay out the vertices and attach their names and features.

        Vertices form disjoint complete subgraphs: one per class (its global
        vertices, in ``classes_counts`` order), followed by one per row (that
        row's bbox vertices). Sets ``self.classes``, ``self.lengths`` and
        ``self.g``.
        """
        g = BasicGraph.Full(0)
        self.classes = self.classes_counts()
        for classe_count in self.classes.values():
            g = g + BasicGraph.Full(classe_count)

        self.lengths = self.bboxes_counts()
        for length in self.lengths:
            g = g + BasicGraph.Full(length)

        g.vs["name"] = self.get_names()
        g.vs["feature"] = self.get_features()

        self.g = g

    def auto_connect(self):
        """Add an edge from each global vertex to every bbox vertex of the same row.

        Rows are walked in file order. The k-th row of class ``c`` is linked,
        by name, from ``"c_k"`` to the contiguous bbox block starting at
        ``"c_k_0"`` whose size is that row's ``self.lengths`` entry. Rows with
        no bbox feature are skipped instead of failing the ``"_0"`` lookup.
        Requires ``build()`` to have run.
        """
        occurrences = defaultdict(int)
        edges = []
        for global_class, length in zip(self.global_classes(), self.lengths):
            name = "{}_{}".format(global_class, occurrences[global_class])
            occurrences[global_class] += 1
            if length == 0:
                continue
            global_indice = self.g.vs.find(name=name).index
            bbox_indice = self.g.vs.find(name=name + "_0").index
            edges.extend((global_indice, bbox) for bbox in range(bbox_indice, bbox_indice + length))

        self.g.add_edges(edges)

    def mst(self, weights = None):
        """Return the minimum spanning tree of ``self.g``.

        When ``weights`` is omitted, the ``"weight"`` edge attribute written by
        ``set_weights`` is used if present.
        """
        if weights is None and "weight" in self.g.es.attributes():
            weights = self.g.es["weight"]
        return self.g.mst(weights)

    def set_weights(self):
        """Set each edge's ``weight`` to the Euclidean distance between its endpoints' flattened features."""
        features = self.g.vs["feature"]
        weights = [
            euclidean(features[source].reshape(-1), features[target].reshape(-1))
            for source, target in self.g.get_edgelist()
        ]
        # One bulk attribute write instead of one write per edge. Like the
        # original loop, nothing is written when the graph has no edges.
        if weights:
            self.g.es["weight"] = weights
    

In [6]:
# Build the train graph, weight its edges and keep only the minimum spanning tree.
# Use the configured `train_features` prefix (as the test cell does with `test_features`)
# and a separate builder name, deleted at the end, so `train_graph` only ever names the
# tree and the denser full graph is not kept alive.
train_builder = CreateGraph(train_features)
train_builder.build()
train_builder.auto_connect()
train_builder.set_weights()
train_graph = train_builder.mst()
del train_builder

In [11]:
# summary() returns a multi-line string; print it so the line break renders instead of "\n".
print(train_graph.summary())

'IGRAPH UNW- 45275 45222 -- \n+ attr: feature (v), name (v), weight (e)'

In [7]:
# Same pipeline for the test split: build, connect, weight, then keep the spanning tree.
# The builder gets its own name and is deleted, so only the tree remains as `test_graph`.
test_builder = CreateGraph(test_features)
test_builder.build()
test_builder.auto_connect()
test_builder.set_weights()
test_graph = test_builder.mst()
del test_builder

In [8]:
# summary() returns a multi-line string; print it so the line break renders instead of "\n".
print(test_graph.summary())

'IGRAPH UNW- 10807 10754 -- \n+ attr: feature (v), name (v), weight (e)'

In [9]:
# Combine the train and test trees; concat lays out train's vertices/edges first.
# NOTE(review): both splits are named with the same "<class>_<k>" scheme, so if they share
# class labels, vertex names repeat in `graph`; use indices (train first) to tell them apart.
graph = train_graph.concat(test_graph)
assert graph.vcount() == train_graph.vcount() + test_graph.vcount()
assert graph.ecount() == train_graph.ecount() + test_graph.ecount()

In [10]:
# summary() returns a multi-line string; print it so the line break renders instead of "\n".
print(graph.summary())

'IGRAPH UNW- 56082 55976 -- \n+ attr: feature (v), name (v), weight (e)'

In [12]:
# Serialise the merged graph into graph.pkl through igraph's pickle writer (binary handle).
# NOTE(review): the pickle presumably references this notebook's BasicGraph class, so
# loading it elsewhere may require that class to be defined — confirm.
with open("graph.pkl", mode="wb") as pickle_file:
    graph.write_pickle(pickle_file)