In [1]:
import ast
import json
from collections import Counter

import pandas as pd

import sankey as sk

# read in the data and parse each stored word-count string into a dict

def _parse_counter_repr(text):
    """Parse one stored word-count string into a ``dict`` of word -> count.

    Cells are expected to hold the text form of a ``collections.Counter``,
    e.g. ``"Counter({'data': 14, 'transfer': 7})"``. Plain ``{...}`` literals
    are accepted as well. NOTE(review): confirm this format against whatever
    wrote clean_articles.csv.

    ``ast.literal_eval`` is used rather than swapping quotes and calling
    ``json.loads`` for two reasons:

    - Quote swapping corrupts words that contain an apostrophe
      (``"don't"`` would become ``"don"t"``).
    - An empty ``Counter()`` leaves an empty string, which JSON cannot parse.

    Missing cells (NaN) and empty counters both yield ``{}``.
    """
    if not isinstance(text, str):  # NaN from an empty CSV cell
        return {}
    body = text.strip()
    if body.startswith('Counter(') and body.endswith(')'):
        body = body[len('Counter('):-1]
    body = body.strip()
    return ast.literal_eval(body) if body else {}


df_art = pd.read_csv('clean_articles.csv')
df_art['word_counts'] = df_art['word_counts'].map(_parse_counter_repr)
df_art.head()


Unnamed: 0,title,subreddit,word_counts
0,Meta's threat to close down Facebook and Insta...,nottheonion,"{'data': 14, 'transfer': 7, 'european': 6, 'wi..."
1,Pregnant Texas woman driving in HOV lane told ...,nottheonion,"{'said': 7, 'officer': 6, 'according': 5, 'cit..."
2,Mark Zuckerberg Says Meta Employees “Lovingly”...,nottheonion,"{'given': 2, 'employee': 2, 'energy': 2, 'news..."
3,Police didn't immediately confront the gunman ...,nottheonion,"{'school': 5, 'shooting': 5, 'gunman': 5, 'off..."
4,Shaquille O'Neal says gorillas freak out when ...,nottheonion,"{'gorilla': 9, 'zoo': 5, 'look': 5, 'story': 4..."


In [2]:
def get_total_counts(df, subreddit):
    """ Sum the per-article word counts for one subreddit.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have a ``subreddit`` column and a ``word_counts`` column whose
        values are mappings of word -> count.
    subreddit : str
        Name matched exactly against ``df['subreddit']``.

    Returns
    -------
    collections.Counter
        Total count of each word across the matching rows. The Counter is
        empty if no rows match. As with Counter addition, words whose total
        is not positive are omitted.
    """
    total_count = Counter()
    # A boolean mask avoids splicing `subreddit` into a query string, which
    # fails on names containing quote characters.
    for word_counts in df.loc[df['subreddit'] == subreddit, 'word_counts']:
        # update() adds in place instead of allocating a new Counter per row.
        total_count.update(word_counts)
    # Unary plus drops zero/negative totals, matching `Counter + Counter`.
    return +total_count

# Aggregate word counts per subreddit, keyed by subreddit name
# (insertion order: 'nottheonion' first, then 'TheOnion').
counts = {
    name: get_total_counts(df_art, name)
    for name in ('nottheonion', 'TheOnion')
}

In [3]:
# Sankey diagram linking each subreddit to its most frequent words, drawn
# with the `sankey` helper module from DS3500 (imported as `sk`).

TOP_N_WORDS = 15  # words kept per subreddit
# Palette passed to sk.make_sankey. It is repeated so there are enough entries
# for every element. NOTE(review): confirm whether make_sankey colors nodes or links.
SANKEY_COLORS = [
    'skyblue', 'slateblue', 'slategray', 'snow',
    'springgreen', 'steelblue', 'tan', 'teal', 'thistle', 'tomato',
    'turquoise', 'violet', 'wheat', 'whitesmoke', 'yellow',
] * 8

# convert each subreddit's word Counter into a dict of its top TOP_N_WORDS words
file_dict = {}
for sub, sub_counts in counts.items():
    file_dict[sub] = dict(sub_counts.most_common(TOP_N_WORDS))

# One row per (subreddit -> word) link. The frame is built once, after every
# subreddit has been collected, so it exists even when `counts` is empty.
df_wordcount = pd.DataFrame(
    [(src, word, n) for src, top in file_dict.items() for word, n in top.items()],
    columns=['files', 'words', 'word_count'],
)

fig = sk.make_sankey(df_wordcount, 'files', 'words', vals='word_count',
                     pad=10, color=SANKEY_COLORS)
fig.update_layout(autosize=False, width=800, height=800)
fig.show()
fig.write_image("topwords-sankey.png")