In [None]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql import functions as F
import matplotlib.pyplot as plt
import seaborn as sns

# Obtain the Spark session for this notebook (an existing one is reused).
spark = (
    SparkSession.builder
    .appName("ArticleCount")
    .getOrCreate()
)

# Source data: a semicolon-delimited CSV whose first line holds the column names.
# No schema inference is requested, so every column is loaded as a string.
# NOTE(review): machine-specific absolute path; point it at your own copy of the file.
file_path = r"C:\Users\96659\Desktop\python_ws\article_count_per_publisher_per_year.csv"
sample_df = spark.read.options(header="true", delimiter=";").csv(file_path)

# Assuming the publisher is the 2nd column and the year is the 1st column
publisher_column_name = "publisher0"
year_column_name = "year"
article_count_column = "article_count"

# Convert article count to integer for proper aggregation
sample_df = sample_df.withColumn(article_count_column, F.col(article_count_column).cast("int"))

# Group by publisher and year, then count articles
articles_per_publisher = sample_df.groupBy(year_column_name, publisher_column_name).agg(
    F.sum(article_count_column).alias("article_count")
)

# Find the publisher with the highest number of articles each year
window_spec = pyspark.sql.Window.partitionBy(year_column_name).orderBy(F.desc("article_count"))
highest_articles_per_year = articles_per_publisher.withColumn(
    "rank", F.rank().over(window_spec)
).filter(F.col("rank") == 1).drop("rank")

# Collect the per-year winners to the driver for plotting. Spark does not guarantee
# row order after the window/filter step, and seaborn infers the x-axis category
# order from the data. Sort first so the years appear in a stable ascending order,
# with publisher as a secondary key so tied publishers are also ordered consistently.
# NOTE(review): years are loaded as strings (no schema inference); this assumes they
# are fixed-width (e.g. four digits) so string order is chronological. TODO confirm.
highest_articles_per_year_pd = highest_articles_per_year.orderBy(
    year_column_name, publisher_column_name
).toPandas()

# Bar chart of the winning total per year, coloured by the winning publisher.
plt.figure(figsize=(12, 6))
sns.barplot(
    data=highest_articles_per_year_pd,
    x=year_column_name,
    y='article_count',
    hue=publisher_column_name,
)
plt.title('Publisher with Highest Number of Articles Per Year')
plt.xlabel('Year')
plt.ylabel('Number of Articles')
plt.xticks(rotation=45)
plt.legend(title=publisher_column_name)
# Leave room for the rotated tick labels so they are not clipped.
plt.tight_layout()
plt.show()
