In [1]:
import time
from pathlib import Path

import pandas as pd
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score

from data_utils import DataLoader
from spectral_mix import SpectralMix

In [3]:
# Reproducibility run: fit SpectralMix on each dataset, then score the predicted
# labels against the ground truth with NMI / ARI and record the fit runtime.
dl = DataLoader()

# Per-dataset values passed to SpectralMix's `d` and `k` keyword arguments.
test_params = [
    {'dataset_name': 'acm', 'd': 9, 'k': 3},
    {'dataset_name': 'dblp', 'd': 2, 'k': 3},
    {'dataset_name': 'flickr', 'd': 11, 'k': 7},
    {'dataset_name': 'imdb', 'd': 2, 'k': 3},
]

RESULT_COLUMNS = ['dataset', 'nmi', 'ari', 'runtime']
RESULTS_PATH = Path('test_results/reproducability_test.csv')

# NOTE(review): no random seed is set here; if SpectralMix is stochastic the
# scores below may differ between runs — confirm whether it accepts a seed.
records = []
for reprod_test in test_params:
    dataset_name = reprod_test['dataset_name']
    d = reprod_test['d']
    k = reprod_test['k']

    dataset = dl.load_dataset(dataset_name)
    print(f'=== {dataset_name} ===')
    print(dataset['adjacency_matrix'].shape)
    # Some datasets come without node attributes (attribute_matrix is None).
    if dataset['attribute_matrix'] is not None:
        print(dataset['attribute_matrix'].shape)
    print(dataset['true_labels'].shape)

    sm = SpectralMix(adjacency_matrix=dataset['adjacency_matrix'], attribute_matrix=dataset['attribute_matrix'], d=d, k=k)
    # Only fit() is timed; predict() is deliberately excluded from `runtime`.
    # perf_counter is monotonic and meant for measuring elapsed durations.
    begin = time.perf_counter()
    sm.fit(run_clustering=False)
    runtime = time.perf_counter() - begin

    labels = sm.predict()
    nmi = normalized_mutual_info_score(dataset['true_labels'], labels)
    ari = adjusted_rand_score(dataset['true_labels'], labels)

    result = [dataset_name, nmi, ari, runtime]
    records.append(result)
    print(result)

# Build the frame once instead of appending rows one at a time.
test_results = pd.DataFrame(records, columns=RESULT_COLUMNS)

# Create the output directory so a fresh checkout does not fail on to_csv.
RESULTS_PATH.parent.mkdir(parents=True, exist_ok=True)
test_results.to_csv(RESULTS_PATH)
test_results

=== acm ===
(3025, 3025, 2)
(3025, 1870)
(3025,)


100%|██████████| 50/50 [01:18<00:00,  1.57s/it]


['acm', 0.40436245227291967, 0.34390319509953093, 78.53478527069092]
=== dblp ===
(8401, 8401, 4)
(8401,)


100%|██████████| 50/50 [01:58<00:00,  2.37s/it]


['dblp', 0.3501433169417014, 0.25891458888567026, 118.45364308357239]
=== flickr ===
(10364, 10364, 2)
(10364,)


100%|██████████| 50/50 [08:30<00:00, 10.20s/it]


['flickr', 0.4919756635370543, 0.35547623426560704, 510.2245271205902]
=== imdb ===
(3550, 3550, 2)
(3550, 2000)
(3550,)


100%|██████████| 50/50 [00:28<00:00,  1.77it/s]

['imdb', 0.0029286854956732415, 0.002837464313851337, 28.270281314849854]



