In [1]:
import glob
import os

import pandas as pd
import sklearn
import tqdm
import numpy as np
from sklearn.metrics import normalized_mutual_info_score as nmi_score
import networkx

import matplotlib
from plotnine import *



In [11]:
class BSCresults(object):

    def __init__(self, data_dir):

        self.data_dir = data_dir
        self.file_list_summary = glob.glob(os.path.join(self.data_dir,'*summary.csv'))
        self.file_list_users = glob.glob(os.path.join(self.data_dir,'*users.csv'))
        self.file_list_hashtags = glob.glob(os.path.join(self.data_dir,'*hashtags.csv'))

        # coi := cluster of interest
        self.coi = []

    def read_data(self, min_user=10):

        self.min_user = min_user

        # subset the values of the file list to only those of min user we want.
        self.eval_list_summary = [file for file in self.file_list_summary if int(file.split('_')[-3])==min_user]
        self.eval_list_users = [file for file in self.file_list_users if int(file.split('_')[-3])==min_user]
        self.eval_list_hashtags = [file for file in self.file_list_hashtags if int(file.split('_')[-3])==min_user]

        # sort in order of cluster number
        self.eval_list_summary.sort(key=lambda x: int(x.split('_')[-2]))
        self.eval_list_users.sort(key=lambda x: int(x.split('_')[-2]))
        self.eval_list_hashtags.sort(key=lambda x: int(x.split('_')[-2]))

        self.data_summary = []
        self.data_users = []
        self.data_hashtags = []

        for file in self.eval_list_summary:
            self.data_summary.append(pd.read_csv(file, usecols = ['cluster','count']))
        for file in self.eval_list_users:
            self.data_users.append(pd.read_csv(file, usecols = ['ID','degree','topic_cluster']))
        for file in self.eval_list_hashtags:
            self.data_hashtags.append(pd.read_csv(file, usecols = ['hashtag','degree','topic_cluster']))


    def eval_nmi(self, min_user=10, shift=10):

        self.min_user=min_user
        self.read_data(self.min_user)
        self.shift=shift

        res_users = []
        for index, value in enumerate(self.data_users[shift:]):
            for i in range(1,shift):
                temp = []
                temp.append(nmi_score(value['topic_cluster'],self.data_users[index+shift-i]['topic_cluster']))
            res_users.append(np.mean(temp))

        self.user_eval_res = res_users
        arr = np.array(res_users)
        self.max_index_best_users = np.where(arr == np.amax(arr))[0][0] + shift
        self.best_cluster_users = np.max(self.data_users[self.max_index_best_users]['topic_cluster'])

        res_hashtags = []
        for index, value in enumerate(self.data_hashtags[shift:]):
            for i in range(1,shift):
                temp = []
                temp.append(nmi_score(value['topic_cluster'],self.data_hashtags[index+shift-i]['topic_cluster']))
            res_hashtags.append(np.mean(temp))

        self.hashtag_eval_res = res_hashtags
        arr = np.array(res_hashtags)
        self.max_index_best_hashtags = np.where(arr == np.amax(arr))[0][0] + shift
        self.best_cluster_hashtags = np.max(self.data_hashtags[self.max_index_best_hashtags]['topic_cluster'])


        # merge summary and user data to get cluster numbers


        # omit one user clusters


        return ((res_users, self.max_index_best_users, self.best_cluster_users),(res_hashtags, self.max_index_best_hashtags, self.best_cluster_hashtags))

    def plot(self):

        start_val = int(self.file_list_users[0].split('_')[-3])+self.shift
        data_userplot = {
            'Cluster Number': list(range(start_val, len(self.data_users[self.shift:])+start_val)),
            'NMI Score': self.user_eval_res
        }
        data_userplot = pd.DataFrame(data_userplot, columns=['Cluster Number', 'NMI Score'])

        start_val = int(self.file_list_users[0].split('_')[-3])+self.shift
        data_hashplot = {
            'Cluster Number': list(range(start_val, len(self.data_hashtags[self.shift:])+start_val)),
            'NMI Score': self.hashtag_eval_res
        }
        data_hashplot = pd.DataFrame(data_hashplot, columns=['Cluster Number', 'NMI Score'])

        self.userplot = ggplot(data_userplot) \
                    + aes(x="Cluster Number", y="NMI Score") \
                    + geom_line() \
                    + labs(title = "User Clusters") 
        # self.userplot
        self.hashplot = ggplot(data_hashplot) \
                    + aes(x="Cluster Number", y="NMI Score") \
                    + geom_line() \
                    + labs(title = "Phrase Clusters")
        # self.hashplot

        self.userplot.save(os.path.join(self.data_dir,"bsc_user_eval.png"), dpi=600)
        self.hashplot.save(os.path.join(self.data_dir,"bsc_hashtags_eval.png"), dpi=600)

        # user number per cluster distribution
        self.best_summary = self.data_summary[self.max_index_best_users]
        self.best_users = self.data_users[self.max_index_best_users]
        self.best_hashtags = self.data_hashtags[self.max_index_best_users]
 
        self.userdistplot = ggplot(self.best_summary) \
            + aes(x="cluster", y="count") \
            + geom_bar(stat='identity') \
            + labs(title = "User Numbers per Cluster for best cluster".format(self.best_cluster_users))
        self.userdistplot.save(os.path.join(self.data_dir,"bsc_best_n_user_dist.png"), dpi=600)


    def add_coi(self, added):

        # coi = cluster of interest. to be manually entered after examination of clustering outputs.
        self.add_coi = 0


In [12]:
bscres = BSCresults('/Users/hubert/Nextcloud/DPhil/DPhil_Studies/2021-04_Study_A_Diffusion/collection_results_2021_05_04_16_22/bsc')

In [13]:
results = bscres.eval_nmi()
print(results[0][1:])
print(results[1][1:])

(59, 69)
(30, 40)


In [14]:
bscres.plot()



In [302]:
best_summary = bscres.data_summary[bscres.best_cluster_users]
best_summary[best_summary['count']==1]
print(np.max(best_summary['cluster']))

157


In [300]:
best_summary.describe()

Unnamed: 0,cluster,count
count,153.0,153.0
mean,79.437908,15.647059
std,45.662064,67.743175
min,1.0,1.0
25%,40.0,2.0
50%,80.0,4.0
75%,119.0,9.0
max,157.0,781.0
