<a href="https://colab.research.google.com/github/marioaloam-00/practica3-cloudrun/blob/main/Patentspy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
Patent Data Analysis on Dataproc Cluster cluster-aa01
Run with: spark-submit --master yarn patent_analysis.py
"""

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend for cluster
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import os
import gc
warnings.filterwarnings('ignore')

# Set up plotting style.
# The 'seaborn-v0_8-*' style names only exist in matplotlib >= 3.6; older
# releases ship the same style sheet as 'seaborn-darkgrid'. Fall back instead
# of failing at import time on a cluster image with an older matplotlib.
try:
    plt.style.use('seaborn-v0_8-darkgrid')
except OSError:
    plt.style.use('seaborn-darkgrid')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)

class PatentDataAnalyzerSpark:
    def __init__(self):
        """
        Initialize Spark session for Dataproc cluster
        """
        self.spark = None
        self.df = None
        self.total_rows = 0

    def create_spark_session(self):
        """Build (or reuse) the SparkSession used by every analysis step.

        The session is stored on ``self.spark`` and its log level is set to
        WARN before the version and master URL are printed.

        Returns:
            The SparkSession obtained from ``getOrCreate()``.
        """
        print("Initializing Spark session on cluster-aa01...")

        # Settings are applied to the builder in this order.
        # NOTE(review): per the Spark configuration docs, spark.driver.memory
        # set from application code may have no effect in client mode because
        # the driver JVM is already running -- confirm how this job is launched.
        settings = (
            ("spark.sql.execution.arrow.pyspark.enabled", "true"),
            ("spark.sql.adaptive.enabled", "true"),
            ("spark.sql.shuffle.partitions", "200"),
            ("spark.dynamicAllocation.enabled", "true"),
            ("spark.dynamicAllocation.minExecutors", "2"),
            ("spark.dynamicAllocation.maxExecutors", "10"),
            ("spark.executor.memory", "8g"),
            ("spark.driver.memory", "4g"),
            ("spark.yarn.queue", "default"),
        )

        builder = SparkSession.builder.appName("PatentDataAnalysis")
        for key, value in settings:
            builder = builder.config(key, value)

        session = builder.getOrCreate()
        self.spark = session

        # Keep Spark's own logging down to warnings and errors.
        session.sparkContext.setLogLevel("WARN")

        print("Spark session created successfully")
        print(f"Spark version: {session.version}")
        print(f"Master URL: {session.sparkContext.master}")

        return session

    def load_all_data(self):
        """Read every ``data-*.csv`` file from the GCS bucket into ``self.df``.

        The files are parsed with an explicit schema (header row expected, no
        schema inference). The row count is stored in ``self.total_rows`` and
        the DataFrame is registered as the temp view ``patents``.

        Returns:
            The loaded Spark DataFrame.
        """
        bucket_path = "gs://patentbucket-maam/Data"
        file_pattern = "data-*.csv"
        source = f"{bucket_path}/{file_pattern}"

        print(f"Loading data from: {source}")

        # (column name, Spark type) pairs in file order; every column is
        # nullable. An explicit schema avoids an extra inference pass.
        column_types = [
            ("patent_id", StringType),
            ("type", StringType),
            ("patent_number", StringType),
            ("patent_country", StringType),
            ("patent_date", DateType),
            ("abstract", StringType),
            ("title", StringType),
            ("kind", StringType),
            ("num_claims", IntegerType),
            ("filename", StringType),
            ("withdrawn", StringType),
            ("classification_uuid", StringType),
            ("classification_level", StringType),
            ("section", StringType),
            ("ipc_class", StringType),
            ("subclass", StringType),
            ("main_group", StringType),
            ("subgroup", StringType),
            ("symbol_position", StringType),
            ("classification_value", StringType),
            ("classification_status", StringType),
            ("classification_data_source", StringType),
            ("action_date", DateType),
            ("ipc_version_indicator", StringType),
            ("sequence", IntegerType),
            ("assignee_id", StringType),
            ("location_id", StringType),
            ("city", StringType),
            ("state", StringType),
            ("assignee_country", StringType),
            ("latitude", DoubleType),
            ("longitude", DoubleType),
            ("county", StringType),
            ("state_fips", StringType),
            ("county_fips", StringType),
        ]
        schema = StructType(
            [StructField(name, dtype(), True) for name, dtype in column_types]
        )

        # Read all CSV files matching the pattern.
        reader = self.spark.read
        reader = reader.option("header", "true")
        reader = reader.option("inferSchema", "false")
        reader = reader.schema(schema)
        self.df = reader.csv(source)

        # count() triggers a full scan of the input files.
        self.total_rows = self.df.count()

        print("✓ Successfully loaded data")
        print(f"✓ Total records: {self.total_rows:,}")
        print(f"✓ Total partitions: {self.df.rdd.getNumPartitions()}")
        print(f"✓ Total columns: {len(self.df.columns)}")

        # Expose the data to Spark SQL under the name "patents".
        self.df.createOrReplaceTempView("patents")

        return self.df

    def basic_info(self):
        """Print dataset dimensions, column-type counts and missing-value rates.

        Side effects:
            * Marks ``self.df`` for caching so later operations can reuse it.
            * Writes ``missing_values_summary.csv`` and, through
              ``_plot_missing_values``, ``missing_values_plot.png`` to the
              current working directory.

        If ``self.total_rows`` is zero the missing-value section is skipped,
        because percentages are undefined without rows.
        """
        print("\n" + "=" * 80)
        print("BASIC DATASET INFORMATION")
        print("=" * 80)

        # Cache for multiple operations
        self.df.cache()

        print(f"\nDataset Dimensions:")
        print(f"Rows: {self.total_rows:,}")
        print(f"Columns: {len(self.df.columns)}")

        # Data types: number of columns per Spark data type.
        print(f"\nData Types:")
        columns_per_type = {}
        for field in self.df.schema.fields:
            dtype = str(field.dataType)
            columns_per_type[dtype] = columns_per_type.get(dtype, 0) + 1

        # The loop variable must NOT be named ``count``: assigning to it would
        # make ``count`` a local int for the whole method and break the Spark
        # ``count()`` aggregate used further down.
        for dtype, n_columns in columns_per_type.items():
            print(f"  {dtype}: {n_columns} columns")

        if not self.total_rows:
            print("\nDataset is empty; skipping missing values analysis.")
            return

        # Missing values analysis
        print(f"\nMissing Values Analysis (Top 20 columns):")

        # One aggregate per column: count the rows where the column is NULL
        # and express it as a percentage of all rows. Everything is computed
        # in a single Spark job.
        missing_expr = [
            (count(when(col(c).isNull(), c)) / self.total_rows * 100).alias(c)
            for c in self.df.columns
        ]

        missing_row = self.df.select(missing_expr).collect()[0]
        missing_dict = missing_row.asDict()

        # Convert to pandas for sorting and display
        missing_pd = pd.DataFrame.from_dict(
            missing_dict, orient='index', columns=['Missing_Percentage']
        )
        missing_pd = missing_pd.sort_values('Missing_Percentage', ascending=False)

        print(missing_pd.head(20))

        # Save missing values to CSV
        missing_pd.to_csv('missing_values_summary.csv')
        print(f"\n✓ Missing values summary saved to 'missing_values_summary.csv'")

        # Plot missing values
        self._plot_missing_values(missing_pd.head(20))

    def _plot_missing_values(self, missing_pd):
        """Draw a horizontal bar chart of missing-value percentages.

        Args:
            missing_pd: pandas DataFrame indexed by column name with a
                ``Missing_Percentage`` column.

        The chart is saved as ``missing_values_plot.png`` in the current
        working directory.
        """
        plt.figure(figsize=(12, 8))

        # Ascending order puts the largest bar at the top of a barh chart.
        ordered = missing_pd.sort_values('Missing_Percentage', ascending=True)
        percentages = ordered['Missing_Percentage']
        positions = range(len(ordered))

        plt.barh(positions, percentages)
        plt.yticks(positions, ordered.index)
        plt.xlabel('Missing Percentage (%)')
        plt.title('Top 20 Columns with Missing Values')
        plt.xlim(0, 100)

        # Annotate each bar with its value, just right of the bar end.
        for position, pct in enumerate(percentages):
            plt.text(pct + 1, position, f'{pct:.1f}%', va='center')

        plt.tight_layout()
        plt.savefig('missing_values_plot.png', dpi=300, bbox_inches='tight')
        print("✓ Missing values plot saved to 'missing_values_plot.png'")
        plt.close()

    def patent_date_analysis(self):
        """Report the range, completeness and yearly distribution of patent_date.

        Prints the min/max date and the share of non-NULL dates, plots the
        yearly distribution through ``_plot_year_distribution`` and writes
        ``yearly_patent_distribution.csv`` to the current working directory.

        Note: ``_plot_year_distribution`` adds ``cumulative`` and ``decade``
        columns to the frame in place, so they appear in the CSV as well.
        """
        print("\n" + "=" * 80)
        print("PATENT DATE ANALYSIS")
        print("=" * 80)

        # Date range and completeness. ``min``/``max``/``count`` here are the
        # Spark aggregates brought in by the pyspark.sql.functions import.
        date_stats = self.df.select(
            min("patent_date").alias("min_date"),
            max("patent_date").alias("max_date"),
            count("patent_date").alias("non_null_dates"),
            count("*").alias("total_records")
        ).collect()[0]

        non_null_dates = date_stats['non_null_dates']
        total_records = date_stats['total_records']

        print(f"Patent Date Range:")
        print(f"Min Date: {date_stats['min_date']}")
        print(f"Max Date: {date_stats['max_date']}")

        # Guard against an empty dataset, which would otherwise raise
        # ZeroDivisionError.
        completeness = (non_null_dates / total_records) * 100 if total_records else 0.0
        print(f"\nPatent Date Completeness:")
        print(f"Non-NULL records: {non_null_dates:,}")
        print(f"NULL records: {total_records - non_null_dates:,}")
        print(f"Completeness: {completeness:.2f}%")

        # Year distribution
        yearly_stats = self.df.filter(col("patent_date").isNotNull()) \
            .withColumn("year", year("patent_date")) \
            .groupBy("year") \
            .agg(count("*").alias("patent_count")) \
            .orderBy("year")

        # Convert to pandas for plotting (one row per year).
        yearly_pd = yearly_stats.toPandas()

        # Plot year distribution
        self._plot_year_distribution(yearly_pd)

        # Save yearly stats
        yearly_pd.to_csv('yearly_patent_distribution.csv', index=False)
        print(f"\n✓ Yearly distribution saved to 'yearly_patent_distribution.csv'")

    def _plot_year_distribution(self, yearly_pd):
        """Plot a 2x2 overview of patent counts over time and print top years.

        Panels: per-year bar chart, cumulative count, share per decade, and
        patent count per calendar month (summed over all years).

        Args:
            yearly_pd: pandas DataFrame with ``year`` and ``patent_count``
                columns. It is modified in place: ``cumulative`` and
                ``decade`` columns are added.

        The figure is saved as ``patent_date_analysis.png`` in the current
        working directory. The monthly panel runs an extra Spark aggregation
        on ``self.df``.
        """
        _, axes = plt.subplots(2, 2, figsize=(15, 12))

        # Year distribution bar plot
        axes[0, 0].bar(yearly_pd['year'], yearly_pd['patent_count'], alpha=0.7)
        axes[0, 0].set_xlabel('Year')
        axes[0, 0].set_ylabel('Number of Patents')
        axes[0, 0].set_title('Patent Distribution by Year')
        axes[0, 0].tick_params(axis='x', rotation=45)
        axes[0, 0].grid(True, alpha=0.3)

        # Cumulative distribution (rows arrive ordered by year).
        yearly_pd['cumulative'] = yearly_pd['patent_count'].cumsum()
        axes[0, 1].plot(yearly_pd['year'], yearly_pd['cumulative'], 'b-', linewidth=2)
        axes[0, 1].set_xlabel('Year')
        axes[0, 1].set_ylabel('Cumulative Patents')
        axes[0, 1].set_title('Cumulative Patent Count')
        axes[0, 1].grid(True, alpha=0.3)

        # Decade analysis: floor each year to its decade and sum the counts.
        yearly_pd['decade'] = (yearly_pd['year'] // 10) * 10
        decade_counts = yearly_pd.groupby('decade')['patent_count'].sum()
        axes[1, 0].pie(decade_counts.values, labels=decade_counts.index, autopct='%1.1f%%')
        axes[1, 0].set_title('Patent Distribution by Decade')

        # Monthly trend: total patents per calendar month across all years.
        monthly_df = self.df.filter(col("patent_date").isNotNull()) \
            .withColumn("month", month("patent_date")) \
            .groupBy("month") \
            .agg(count("*").alias("patent_count")) \
            .orderBy("month")

        monthly_pd = monthly_df.toPandas()
        axes[1, 1].plot(monthly_pd['month'], monthly_pd['patent_count'], 'g-', linewidth=2, marker='o')
        axes[1, 1].set_xlabel('Month')
        # These values are totals, not averages, so the labels say so.
        axes[1, 1].set_ylabel('Number of Patents')
        axes[1, 1].set_title('Patent Count by Calendar Month (All Years)')
        axes[1, 1].set_xticks(range(1, 13))
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('patent_date_analysis.png', dpi=300, bbox_inches='tight')
        print(f"✓ Patent date analysis plots saved to 'patent_date_analysis.png'")
        plt.close()

        # Print top years; the percentage is relative to all loaded rows.
        top_years = yearly_pd.nlargest(10, 'patent_count')
        print(f"\nTop 10 Years by Patent Count:")
        for _, row in top_years.iterrows():
            percentage = (row['patent_count'] / self.total_rows) * 100
            print(f"  {row['year']}: {row['patent_count']:,} patents ({percentage:.2f}%)")

    def geographic_analysis(self):
        """Summarise where assignees are located.

        Counts rows per ``assignee_country``, plots the 20 largest through
        ``_plot_top_countries``, writes the full per-country table as CSV to
        the ``country_stats`` path, and prints the ten most frequent values
        of ``city``, ``state`` and ``county``.
        """
        print("\n" + "=" * 80)
        print("GEOGRAPHIC ANALYSIS")
        print("=" * 80)

        # Rows per assignee country, largest first.
        by_country = (
            self.df.groupBy("assignee_country")
            .agg(count("*").alias("patent_count"))
            .orderBy(col("patent_count").desc())
        )

        top_countries_pd = by_country.limit(20).toPandas()

        # Chart of the 20 largest countries.
        self._plot_top_countries(top_countries_pd)

        # Full table as a single CSV part file with a header row.
        writer = by_country.coalesce(1).write
        writer = writer.mode("overwrite").option("header", "true")
        writer.csv("country_stats")

        print("\n✓ Country statistics saved to 'country_stats/' directory")

        # Ten most frequent values for each finer-grained location column.
        for field_name in ('city', 'state', 'county'):
            if field_name not in self.df.columns:
                continue

            top_values = (
                self.df.groupBy(field_name)
                .agg(count("*").alias("count"))
                .orderBy(col("count").desc())
                .limit(10)
            )

            print(f"\nTop 10 {field_name.upper()} values:")
            for record in top_values.collect():
                print(f"  {record[field_name]}: {record['count']:,}")

    def _plot_top_countries(self, country_pd):
        """Plot a horizontal bar chart of patent counts per country.

        Args:
            country_pd: pandas DataFrame with ``assignee_country`` and
                ``patent_count`` columns, already sorted largest first.

        The chart is saved as ``top_countries.png`` in the current working
        directory, and the first ten rows are printed with their share of
        ``self.total_rows``.
        """
        counts = country_pd['patent_count']
        positions = range(len(country_pd))

        plt.figure(figsize=(14, 8))
        plt.barh(positions, counts, alpha=0.7)
        plt.yticks(positions, country_pd['assignee_country'])
        plt.xlabel('Number of Patents')
        plt.title('Top 20 Countries by Patent Count')
        plt.gca().invert_yaxis()  # largest country at the top
        plt.grid(True, alpha=0.3, axis='x')

        # Add value labels slightly right of each bar end.
        # Use Series.max() rather than max(...): the wildcard import from
        # pyspark.sql.functions rebinds ``max`` to the Spark aggregate, which
        # cannot take a pandas Series. Computed once, outside the loop.
        label_offset = counts.max() * 0.01
        for position, n_patents in enumerate(counts):
            plt.text(n_patents + label_offset, position,
                     f'{n_patents:,}', va='center', fontsize=9)

        plt.tight_layout()
        plt.savefig('top_countries.png', dpi=300, bbox_inches='tight')
        print(f"✓ Top countries plot saved to 'top_countries.png'")
        plt.close()

        # Print top countries
        print(f"\nTop 10 Countries by Patent Count:")
        top_ten = country_pd.head(10)
        for country, n_patents in zip(top_ten['assignee_country'], top_ten['patent_count']):
            percentage = (n_patents / self.total_rows) * 100
            print(f"  {country}: {n_patents:,} ({percentage:.2f}%)")

    def classification_analysis(self):
        """Summarise the IPC classification columns.

        Prints the ten most frequent values of each classification column
        present in ``self.df``. If a ``section`` column exists, it also plots
        the per-section counts to ``ipc_section_distribution.png`` and writes
        them to ``ipc_section_stats.csv`` in the current working directory.
        """
        print("\n" + "=" * 80)
        print("CLASSIFICATION ANALYSIS")
        print("=" * 80)

        classification_cols = ['section', 'ipc_class', 'subclass', 'main_group', 'subgroup']

        for col_name in classification_cols:
            if col_name not in self.df.columns:
                continue

            top_values = self.df.groupBy(col_name) \
                .agg(count("*").alias("count")) \
                .orderBy(col("count").desc()) \
                .limit(10)

            print(f"\nTop 10 {col_name.upper()} values:")
            for row in top_values.collect():
                print(f"  {row[col_name]}: {row['count']:,}")

        # IPC section analysis
        if 'section' not in self.df.columns:
            return

        section_stats = self.df.groupBy("section") \
            .agg(count("*").alias("patent_count")) \
            .orderBy(col("patent_count").desc())

        section_pd = section_stats.toPandas()
        counts = section_pd['patent_count']
        positions = range(len(section_pd))

        # Plot section distribution
        plt.figure(figsize=(10, 6))
        plt.bar(positions, counts, alpha=0.7)
        plt.xticks(positions, section_pd['section'])
        plt.xlabel('IPC Section')
        plt.ylabel('Number of Patents')
        plt.title('IPC Section Distribution')
        plt.grid(True, alpha=0.3, axis='y')

        # Add value labels just above each bar.
        # Use Series.max() rather than max(...): the wildcard import from
        # pyspark.sql.functions rebinds ``max`` to the Spark aggregate, which
        # cannot take a pandas Series. Computed once, outside the loop.
        label_offset = counts.max() * 0.01
        for position, n_patents in enumerate(counts):
            plt.text(position, n_patents + label_offset,
                     f'{n_patents:,}', ha='center', fontsize=9)

        plt.tight_layout()
        plt.savefig('ipc_section_distribution.png', dpi=300, bbox_inches='tight')
        print(f"\n✓ IPC section distribution saved to 'ipc_section_distribution.png'")
        plt.close()

        # Save section stats
        section_pd.to_csv('ipc_section_stats.csv', index=False)
        print(f"✓ IPC section statistics saved to 'ipc_section_stats.csv'")

    def numerical_analysis(self):
        """Print summary statistics for the known numerical columns.

        For each of ``num_claims``, ``latitude``, ``longitude`` and
        ``sequence`` that exists in ``self.df``, prints count, mean, standard
        deviation, min, max, approximate quartiles and the number of missing
        values. Statistics Spark returns as NULL (for example the mean of an
        all-NULL column) are printed as ``N/A`` instead of raising.
        """
        print("\n" + "=" * 80)
        print("NUMERICAL ANALYSIS")
        print("=" * 80)

        numerical_cols = ['num_claims', 'latitude', 'longitude', 'sequence']
        present = set(self.df.columns)
        available_cols = [name for name in numerical_cols if name in present]

        def _fmt(value):
            # Spark aggregates yield None when there is nothing to aggregate.
            return "N/A" if value is None else f"{value:.2f}"

        for col_name in available_cols:
            print(f"\n{col_name.upper()} Statistics:")

            # Calculate statistics using Spark. ``min``/``max``/``count`` are
            # the Spark aggregates from pyspark.sql.functions; count(column)
            # ignores NULLs.
            stats = self.df.select(
                count(col_name).alias("count"),
                mean(col_name).alias("mean"),
                stddev(col_name).alias("stddev"),
                min(col_name).alias("min"),
                max(col_name).alias("max"),
                percentile_approx(col_name, 0.5).alias("median"),
                percentile_approx(col_name, 0.25).alias("q1"),
                percentile_approx(col_name, 0.75).alias("q3")
            ).collect()[0]

            print(f"  Count: {stats['count']:,}")
            print(f"  Mean: {_fmt(stats['mean'])}")
            print(f"  Std Dev: {_fmt(stats['stddev'])}")
            print(f"  Min: {stats['min']}")
            print(f"  Max: {stats['max']}")
            print(f"  Median: {stats['median']}")
            print(f"  Q1: {stats['q1']}")
            print(f"  Q3: {stats['q3']}")

            # Calculate null percentage (guarding against an empty dataset).
            null_count = self.total_rows - stats['count']
            if self.total_rows:
                null_percentage = (null_count / self.total_rows) * 100
            else:
                null_percentage = 0.0
            print(f"  Missing values: {null_count:,} ({null_percentage:.2f}%)")

    def correlation_analysis(self):
        """Compute pairwise correlations between the numerical columns.

        Steps: find columns with a numeric Spark type, print their missing
        rates, sample the data, convert the sample to pandas, drop rows with
        missing values, compute the correlation matrix, print pairs with
        |r| > 0.1, then plot and save the results through
        ``_plot_correlation_heatmap`` and ``_save_correlation_results``.

        Returns early (after printing a message) when the dataset is empty,
        when fewer than two numerical columns exist, or when fewer than 100
        complete rows remain in the sample.
        """
        # The wildcard import from pyspark.sql.functions rebinds ``min`` and
        # ``abs`` to Spark column functions; fetch the real builtins for the
        # plain-Python arithmetic below.
        import builtins
        py_min = builtins.min
        py_abs = builtins.abs

        print("\n" + "=" * 80)
        print("CORRELATION ANALYSIS")
        print("=" * 80)

        if not self.total_rows:
            # Percentages and the sampling fraction divide by the row count.
            print("No rows available for correlation analysis.")
            return

        # Identify numerical columns. isinstance() needs the type classes,
        # not instances of them.
        numeric_types = (IntegerType, LongType, FloatType, DoubleType, DecimalType)
        numerical_cols = [
            field.name
            for field in self.df.schema.fields
            if isinstance(field.dataType, numeric_types)
        ]

        print(f"Found {len(numerical_cols)} numerical columns:")
        # The loop variable must not be named ``col``; that would shadow the
        # Spark ``col()`` function called inside the loop.
        for i, column_name in enumerate(numerical_cols, 1):
            null_count = self.df.filter(col(column_name).isNull()).count()
            null_pct = (null_count / self.total_rows) * 100
            print(f"  {i:2d}. {column_name}: {null_pct:.1f}% missing")

        if len(numerical_cols) < 2:
            print("Not enough numerical columns for correlation analysis.")
            return

        # Sample data for correlation (to avoid memory issues): at most 10%
        # of the rows, and roughly 100,000 rows for large datasets.
        sample_fraction = py_min(0.1, 100000 / self.total_rows)
        print(f"\nSampling {sample_fraction:.1%} of data for correlation analysis...")

        df_sample = self.df.select(numerical_cols).sample(withReplacement=False,
                                                         fraction=sample_fraction,
                                                         seed=42)

        # Convert to pandas for correlation calculation
        print("Converting to pandas for correlation calculation...")
        df_pandas = df_sample.toPandas()

        # Drop rows with NaN for correlation
        df_clean = df_pandas.dropna()
        print(f"Rows for correlation: {len(df_clean):,}")

        if len(df_clean) < 100:
            print("Not enough data after cleaning for correlation analysis.")
            return

        # Calculate correlation matrix
        print("Calculating correlation matrix...")
        corr_matrix = df_clean.corr()

        # Display correlation matrix
        print(f"\nCorrelation Matrix (showing |r| > 0.1):")
        print("-" * 60)

        # Collect each unordered column pair (upper triangle) with |r| > 0.1.
        # NaN correlations fail the comparison and are skipped.
        matrix_columns = list(corr_matrix.columns)
        strong_correlations = []
        for i, col1 in enumerate(matrix_columns):
            for j in range(i + 1, len(matrix_columns)):
                corr_value = corr_matrix.iloc[i, j]
                if py_abs(corr_value) > 0.1:
                    strong_correlations.append((col1, matrix_columns[j], corr_value))

        if strong_correlations:
            # Sort by absolute correlation strength
            strong_correlations.sort(key=lambda pair: py_abs(pair[2]), reverse=True)

            print(f"\nStrong Correlations (|r| > 0.1):")
            print(f"{'Column 1':<20} {'Column 2':<20} {'Correlation':<12} {'Strength'}")
            print("-" * 60)

            for col1, col2, corr_value in strong_correlations[:20]:  # Show top 20
                strength = self._get_correlation_strength(corr_value)
                print(f"{col1:<20} {col2:<20} {corr_value:+.3f}       {strength}")
        else:
            print("No strong correlations found (|r| > 0.1)")

        # Plot correlation heatmap
        self._plot_correlation_heatmap(corr_matrix)

        # Save correlation results
        self._save_correlation_results(corr_matrix, strong_correlations)

    def _get_correlation_strength(self, r):
        """Get descriptive strength of correlation"""
        abs_r = abs(r)
        if abs_r >= 0.8:
            return "Very Strong"
        elif abs_r >= 0.6:
            return "Strong"
        elif abs_r >= 0.4:
            return "Moderate"
        elif abs_r >= 0.2:
            return "Weak"
        else:
            return "Very Weak"

    def _plot_correlation_heatmap(self, corr_matrix):
        """Render the correlation matrix as an annotated heatmap.

        Args:
            corr_matrix: square pandas DataFrame of correlation coefficients.

        Only the lower triangle is drawn; the diagonal and upper triangle are
        masked out. The figure is saved as ``correlation_matrix.png`` in the
        current working directory.
        """
        print("\nPlotting correlation heatmap...")

        plt.figure(figsize=(12, 10))

        # True on and above the diagonal: those cells are hidden.
        hidden_cells = np.triu(np.ones(corr_matrix.shape, dtype=bool))

        heatmap_options = {
            "mask": hidden_cells,
            "annot": True,
            "fmt": '.2f',
            "cmap": 'RdBu_r',
            "center": 0,
            "square": True,
            "linewidths": 0.5,
            "cbar_kws": {"shrink": 0.8},
        }
        sns.heatmap(corr_matrix, **heatmap_options)

        plt.title('Correlation Matrix Heatmap', fontsize=16, pad=20)
        plt.xticks(rotation=45, ha='right', fontsize=10)
        plt.yticks(fontsize=10)
        plt.tight_layout()
        plt.savefig('correlation_matrix.png', dpi=300, bbox_inches='tight')
        print("✓ Correlation heatmap saved to 'correlation_matrix.png'")
        plt.close()

    def _save_correlation_results(self, corr_matrix, strong_correlations):
        """Save correlation results to file"""
        # Save full correlation matrix
        corr_matrix.to_csv('correlation_matrix.csv')

        # Save strong correlations
        if strong_correlations:
            strong_corr_df = pd.DataFrame(strong_correlations,
                                         columns=['Column1', 'Column2', 'Correlation'])
            strong_corr_df['Strength'] = strong_corr_df['Correlation'].apply(self._get_correlation_strength)
            strong_corr_df.to_csv('strong_correlations.csv', index=False)

        print(f"\n✓ Correlation results saved:")
        print(f"  - correlation_matrix.csv")
        if strong_correlations:
            print(f"  - strong_correlations.csv")

    def run_full_analysis(self):
        """Run the complete exploratory analysis from start to finish.

        Creates the Spark session, loads the data, then runs the basic,
        date, geographic, classification, numerical and correlation analyses
        in that order. The Spark session is always stopped afterwards, even
        if loading or one of the analyses raises; the exception is then
        propagated to the caller.
        """
        print("=" * 80)
        print("PATENT DATA ANALYSIS ON DATAPROC CLUSTER-AA01")
        print("=" * 80)

        # Create Spark session
        self.create_spark_session()

        try:
            # Load ALL data
            self.load_all_data()

            if self.df is None:
                print("No data loaded!")
                return

            # Run analyses
            self.basic_info()
            self.patent_date_analysis()
            self.geographic_analysis()
            self.classification_analysis()
            self.numerical_analysis()
            self.correlation_analysis()

            print("\n" + "=" * 80)
            print("ANALYSIS COMPLETE")
            print("=" * 80)
        finally:
            # Stop the Spark session on every exit path so a failed run does
            # not leave the session running.
            self.spark.stop()
            print("Spark session stopped.")

# Main execution
# Runs only when the file is executed directly (for example through the
# spark-submit command shown in the module docstring), not when imported.
if __name__ == "__main__":
    # ``analyzer`` stays a module-level name, so it remains available for
    # inspection after the run when this is executed interactively.
    analyzer = PatentDataAnalyzerSpark()
    # Creates the Spark session, loads the data, runs every analysis step in
    # turn and stops the session at the end.
    analyzer.run_full_analysis()