In [35]:
import pandas as pd
import networkx as nx

In [36]:
# Raw news dataset; a relative path, so it is resolved against the notebook's working directory.
NEWS_CSV_PATH = 'news.csv'
df = pd.read_csv(NEWS_CSV_PATH)

In [37]:
# Name the leading 'Unnamed: 0' column (presumably an index saved into the CSV — confirm),
# call the target column 'label', and drop the 'date' column.
# Reassigning instead of `inplace=True` keeps this cell re-runnable: `rename` silently skips
# keys that are already gone, and `errors='ignore'` stops a second run failing on the
# already-dropped 'date' column (the in-place version raised KeyError on re-run).
df = (
    df.rename(columns={'Unnamed: 0': 'ID', 'class': 'label'})
      .drop(columns=['date'], errors='ignore')
)
# `rename` does not complain about missing source columns, so check the result explicitly.
assert {'ID', 'label'} <= set(df.columns), "expected 'ID' and 'label' columns after rename"

In [38]:
# Confirm the rename/drop above: 'ID' and 'label' should be present and 'date' gone.
df.columns

Index(['ID', 'news', 'subject', 'month', 'day', 'year', 'label'], dtype='object')

In [39]:
# Number of leading rows (by position) used for the graph experiments below.
SAMPLE_SIZE = 50

# `.iloc` makes the positional intent explicit; plain `df[0:50]` becomes label-based
# on a float index and relies on pandas' implicit slice-as-position rule otherwise.
sample = df.iloc[:SAMPLE_SIZE]

In [40]:
# A row slice cannot change the schema; assert it rather than only eyeballing the columns.
assert sample.columns.equals(df.columns), "sample should keep every column of df"
sample.columns

Index(['ID', 'news', 'subject', 'month', 'day', 'year', 'label'], dtype='object')

In [41]:
# Preview only the first rows instead of rendering the entire sample frame.
sample.head(10)

Unnamed: 0,ID,news,subject,month,day,year,label
0,0,white house presses congress on bill allowing ...,politics,9,21,2016,1
1,1,china urges cooperation after us brands it a c...,politics,12,19,2017,1
2,2,cleveland school officer placed on leave afte...,politics,1,8,2016,0
3,3,syrian democratic forces say reach deir alzor ...,politics,9,10,2017,1
4,4,senator talks all night as democrats fight tru...,politics,4,5,2017,1
5,5,san francisco just told trump in no uncertain...,politics,12,3,2016,0
6,6,this is clinton’s supreme court plan and it c...,politics,3,28,2016,0
7,7,uk's johnson raises hackles over ww2 'punishme...,politics,1,18,2017,1
8,8,senate blocks democratic plan to expand gun ba...,politics,6,20,2016,1
9,9,mike rowe a lesson on liberty that everyone sh...,politics,7,4,2016,0


In [42]:
# Class balance of the sampled articles: number of rows per 'label' value.
label_counts = sample['label'].value_counts()
label_counts

label
1    28
0    22
Name: count, dtype: int64

In [43]:
# Empty undirected graph; the TF-IDF cell below adds one node per sample row, keyed by 'ID'.
graph = nx.Graph()

In [44]:
from sklearn.feature_extraction.text import TfidfVectorizer

# Maximum number of highest-weighted TF-IDF terms stored on each node.
TOP_K_FEATURES = 200

# Fit TF-IDF (English stop words removed) on the sampled article texts.
tfidf_vectorizer = TfidfVectorizer(stop_words="english")
tfidf_features = tfidf_vectorizer.fit_transform(sample['news'])

# Rows of `tfidf_features` are positional, so pair each DataFrame row with its
# position. Indexing by the DataFrame label (as before) is only correct when the
# index happens to be 0..n-1.
for pos, (_, row) in enumerate(sample.iterrows()):
    row_vec = tfidf_features[pos]

    # The sparse row stores only the terms present in this article, so absent
    # terms (score 0) are never recorded as "features" of the node.
    features_with_scores = [
        (int(feature_idx), float(score))
        for feature_idx, score in zip(row_vec.indices, row_vec.data)
        if score > 0
    ]

    # Highest score first; ties broken by ascending feature index (the same order
    # the original stable sort over the dense vector produced).
    top_features = sorted(features_with_scores, key=lambda pair: (-pair[1], pair[0]))

    # Keep at most TOP_K_FEATURES (feature_index, tfidf_score) pairs.
    necessary_features = top_features[:TOP_K_FEATURES]

    node_attrs = {
        'content': necessary_features,
        'month': row['month'],
        'day': row['day'],
        'year': row['year'],
    }
    graph.add_node(row['ID'], label=row['label'], **node_attrs)

In [45]:
# Print each node followed by one "name : value" line per attribute.
for node_id, node_data in graph.nodes(data=True):
    lines = [f"Node {node_id}:"]
    lines.extend(f"{key} : {value}" for key, value in node_data.items())
    print("\n".join(lines))

Node 0:
label : 1
content : [(4353, 0.32049218887911557), (1398, 0.24036914165933665), (980, 0.23742463348459217), (2045, 0.23126855153717768), (4465, 0.18501484122974216), (391, 0.16024609443955778), (3158, 0.16024609443955778), (3600, 0.16024609443955778), (2047, 0.1449172128697211), (2150, 0.1449172128697211), (3971, 0.1449172128697211), (3584, 0.14384272329010772), (2830, 0.1387611309223066), (2598, 0.1256051060251096), (1612, 0.11871231674229608), (2856, 0.11871231674229608), (3693, 0.11871231674229608), (4543, 0.10783630218470776), (2608, 0.1033834351724594), (3679, 0.09250742061487108), (3154, 0.09071213128019688), (5, 0.08012304721977889), (60, 0.08012304721977889), (145, 0.08012304721977889), (364, 0.08012304721977889), (1124, 0.08012304721977889), (1248, 0.08012304721977889), (1421, 0.08012304721977889), (1466, 0.08012304721977889), (2293, 0.08012304721977889), (2299, 0.08012304721977889), (2423, 0.08012304721977889), (2536, 0.08012304721977889), (2919, 0.08012304721977889), 

In [46]:
import matplotlib.pyplot as plt
import numpy as np

In [73]:
# Four independent copies of `graph` (presumably one per later experiment — confirm),
# so changes to one copy leave `graph` and the others untouched.
G1, G2, G3, G4 = (graph.copy() for _ in range(4))

In [None]:
def display_graph(graph, seed=None, title="Graph Visualization"):
    """Draw ``graph`` with a spring layout, labelling nodes and edge weights.

    Parameters
    ----------
    graph : networkx.Graph
        Graph to draw. Edge labels come from each edge's ``'weight'``
        attribute; edges without one are drawn unlabelled.
    seed : int or None, optional
        Random seed passed to ``nx.spring_layout``. Pass an int to get the
        same layout on every run (Restart & Run All reproducibility);
        ``None`` (default) keeps the previous non-deterministic layout.
    title : str, optional
        Figure title (defaults to the previously hard-coded title).
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    # spring_layout starts from random positions; seeding makes the figure reproducible.
    pos = nx.spring_layout(graph, seed=seed)

    nx.draw_networkx_nodes(graph, pos, node_size=500, node_color='skyblue', ax=ax)
    nx.draw_networkx_edges(graph, pos, ax=ax)

    # Full-precision float weights clutter the plot; show two decimals, leave other values as-is.
    edge_labels = {
        edge: f"{weight:.2f}" if isinstance(weight, float) else weight
        for edge, weight in nx.get_edge_attributes(graph, 'weight').items()
    }
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, ax=ax)

    nx.draw_networkx_labels(graph, pos, font_size=10, ax=ax)

    ax.set_title(title)
    ax.axis('off')
    plt.show()