## Import Required Packages:

In [None]:
import os
import datetime
from collections import Counter
from pprint import pprint
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql import DataFrame as spark_DataFrame
from pyspark.sql.types import *
from functools import reduce
import matplotlib.pyplot as plt
import requests
import re
from itertools import product
import time
from io import BytesIO
import requests
import numpy as np
from pyspark.sql import Window
import sys

### Utility Functions

In [None]:
def also_call_as(c):
    """
    Make a list of aliases for a given company using a few predefined
    rules. Try to encompass as many options as possible keeping in mind
    there will be aliases left out.
    """
    c = c.lower()

    # URLs (e.g. ".com")
    if len(c) > 3 and c[-4] == ".":
        a = c
        c = c.rsplit(".", 1)[0].replace(".", " ")
        aliases = set([c, a])
    else:
        aliases = set([c])

    # Single letter endings
    if len(c.split()[-1]) == 1:
        c = c.rsplit(" ", 1)[0]
        aliases.add(c)

    # Company legal endings
    
    endings = ["inc",
               "corp",
               "plc", 
               "reit", 
               "co", 
               "cor", 
               "group", 
               "company",
               "trust",
               "energy",
               "international",
               "of america",
               "pharmaceuticals",
               "clas",
               "in", "nv",
               "sa", 
               "re", 
               "pvt ltd",
               "private limited" ,
               "india private limited"
               "Co",
               "CO.",
               "Companies",
               "Company",
               "Corp",
               "CORP.",
               "Corporation",
               "Inc",
               "INC.",
               "Incorporated",
               "Limited",
               "Ltd",
               "Professional Corporation",
               "Chartered",
               "Limited",
               "Ltd",
               "Ltd.",
               "pa",
               "p.c ",
               "Professional Association",
               "Professional Corporation",
               "Professional Service Corporation",
               "psc",
               "sc.",
               "Service Corporation"]
    
    n_endings = 3  # Can have up to 3 of these endings
    for _ in range(n_endings):
        aliases.update([a.rsplit(" ", 1)[0] for a in aliases if
                        any([a.endswith(" " + e) for e in endings])])
        c = c.rsplit(" ", 1)[0] if any([c.endswith(" " + e) for e in endings]) else c

    # Alias any dashes and replace in company name
    aliases.update([a.replace("-", "") for a in aliases] +
                   [a.replace("-", " ") for a in aliases])
    c = c.replace("-", " ")

    # If '&' stands on its own, add alias of 'and'
    aliases.update([a.replace(" & ", " and ") for a in aliases])

    return {c: list(aliases)}


def generate_common_org_names(companies):
    """
    Map every company's canonical name to its alias list.

    Each name in `companies` is passed through `also_call_as`; when two inputs
    share a canonical name, the later one wins.
    """
    return {
        canonical: aliases
        for company in companies
        for canonical, aliases in also_call_as(company).items()
    }


def combine_spark_dfs(sdf_list):
    """Union all Spark DataFrames in `sdf_list` (columns matched by position)."""
    return reduce(spark_DataFrame.union, sdf_list)

class Common_MetaData:
    """Column names and shared lookup state used by the Spark helpers."""

    # Raw GKG columns selected before reformatting.
    keep = ["DATE", "SourceCommonName", "DocumentIdentifier",
            "Themes", "Organizations", "V2Tone"]

    # Names assigned, by position, to the comma-separated V2Tone fields.
    tone = ["Tone", "PositiveTone", "NegativeTone", "Polarity",
            "ActivityDensity", "SelfDensity", "WordCount"]

    # {canonical organization name: [aliases]}; filled in at runtime from
    # generate_common_org_names() before the organization UDF runs.
    organizations = None



# `udf` is never imported in this file (`pyspark.sql.types import *` does not
# export it), so the decorator is taken from pyspark.sql.functions.
@F.udf(ArrayType(StringType(), True))
def simple_expand_spark(x):
    """
    Split a semicolon-separated string into a list, dropping empty entries.

    Null or empty input yields an empty list.
    """
    if not x:
        return []
    return [item for item in x.split(";") if item]



# `udf` is never imported in this file; use pyspark.sql.functions.udf.
@F.udf(MapType(StringType(), DoubleType()))
def tone_expand_spark(x):
    """
    Parse the comma-separated V2Tone field into a {name: float} map.

    Values are matched by position to the names in Common_MetaData.tone.
    Extra trailing fields are ignored instead of raising IndexError, and empty
    fields map to None. Null or empty input maps every name to None.
    """
    if not x:
        return {t: None for t in Common_MetaData.tone}
    return {
        name: (float(value) if value.strip() else None)
        for name, value in zip(Common_MetaData.tone, x.split(","))
    }



# `udf` is never imported in this file; use pyspark.sql.functions.udf.
@F.udf(BooleanType())
def has_theme_spark(x, theme):
    """
    Return True if `theme` equals an underscore-separated token of any entry in `x`.

    For example, "ENV" matches "ENV_CLIMATECHANGE". A null or empty list
    returns False instead of raising.
    """
    if not x:
        return False
    return any(theme in entry.split("_") for entry in x)


# `udf` is never imported in this file; use pyspark.sql.functions.udf.
@F.udf(StringType())
def clean_organization(s):
    """
    Standardize a raw organization name to its canonical key.

    Returns the first key of Common_MetaData.organizations (in insertion
    order) whose canonical name or any of its aliases appears in `s` as a
    whole-word phrase, compared case-insensitively. Otherwise returns
    `s.lower()`. Null input is returned as None.

    The original compared only `v[0]`, an arbitrary element of a list built
    from a set, against single tokens of `s`. That made the result depend on
    set ordering, and multi-word aliases could never match.
    """
    if s is None:
        return None
    # Pad with spaces so " name " matches only on word boundaries.
    padded = f" {' '.join(s.lower().split())} "
    for canonical, aliases in Common_MetaData.organizations.items():
        if any(f" {name} " in padded for name in (canonical, *aliases)):
            return canonical
    return s.lower()


def also_call_create(c):
    """
    Return the basic alias set for a company name.

    The name is lower-cased. If it looks like a URL with a 3-letter TLD
    (e.g. "apple.com"), both the URL and its stem (dots turned into spaces)
    are returned. Otherwise the set holds just the lower-cased name.
    """
    lowered = c.lower()
    looks_like_url = len(lowered) > 3 and lowered[-4] == "."
    if not looks_like_url:
        return {lowered}
    stem = lowered.rsplit(".", 1)[0].replace(".", " ")
    return {stem, lowered}

## Downloading, Filtering and Storing GDELT Data:

In [0]:
class Redesign_data_format:
    """
    Download GDELT GKG files, cache them as Delta tables, and reshape them into
    one row per (article, organization) for the organizations of interest.
    """

    # Seconds to wait for each GDELT HTTP request. Without a timeout a stalled
    # connection blocks the notebook indefinitely.
    _REQUEST_TIMEOUT = 60

    # Column layout of a downloaded ".gkg.csv" file, in file order.
    _GKG_COLUMNS = [
        "GKGRECORDID", "DATE", "SourceCollectionIdentifier", "SourceCommonName",
        "DocumentIdentifier", "Counts", "V2Counts", "Themes", "V2Themes",
        "Locations", "V2Locations", "Persons", "V2Persons", "Organizations",
        "V2Organizations", "V2Tone", "Dates", "GCAM", "SharingImage",
        "RelatedImages", "SocialImageEmbeds", "SocialVideoEmbeds", "Quotations",
        "AllNames", "Amounts", "TranslationInfo", "Extras",
    ]

    # Subset of the columns kept and cached after download.
    _GKG_KEEP = [
        "DATE", "SourceCollectionIdentifier", "SourceCommonName",
        "DocumentIdentifier", "Counts", "V2Counts", "Themes", "V2Themes",
        "Locations", "V2Locations", "Organizations", "V2Organizations",
        "V2Tone", "Dates",
    ]

    @staticmethod
    def _to_local_path(path):
        """Translate a "dbfs:/..." URI into the "/dbfs/..." path used with os.path."""
        return path.replace("dbfs:/", "/dbfs/")

    @staticmethod
    def _gdelt_time_slots():
        """Return the 96 "HHMM" stamps at 15-minute steps, "0000" .. "2345"."""
        return [f"{hour:02d}{minute:02d}" for hour in range(24) for minute in range(0, 60, 15)]

    def redesign_sdf(self, sdf, file_path_refined, org_name):
        """
        Reformat a raw GKG DataFrame into one row per (article, organization).

        Populates Common_MetaData.organizations from `org_name`, an iterable of
        company names, via generate_common_org_names. It then applies
        `reformat_sdf` and expands the V2Tone map into one column per name in
        Common_MetaData.tone. The V2Tone map column is dropped.
        """
        Common_MetaData.organizations = generate_common_org_names(org_name)
        sdf = self.reformat_sdf(sdf, file_path_refined, org_name)

        # Expand tone columns
        tone_columns = [F.col("V2Tone").getItem(name).alias(name) for name in Common_MetaData.tone]
        return sdf.select(*sdf.columns, *tone_columns).drop("V2Tone")

    def download_and_generate_gdelt_table1(self, date, gd):
        """
        Download one day's GKG table with the `gdelt` package client `gd`.

        Returns a Spark DataFrame with DATE parsed to datetime.
        """
        pdf = gd.Search([date], table="gkg", coverage=True, output="df")
        pdf["DATE"] = pd.to_datetime(pdf["DATE"], format="%Y%m%d%H%M%S")

        sdf = spark.createDataFrame(pdf)
        print("   * loaded *  ", date)
        return sdf

    def _fetch_gkg_frame(self, url):
        """Download one zipped GKG file and return its `_GKG_KEEP` columns as pandas."""
        response = requests.get(url, timeout=self._REQUEST_TIMEOUT)
        # Fail on HTTP errors (e.g. a missing slot) instead of unzipping an error page.
        response.raise_for_status()
        # `warn_bad_lines` is intentionally not passed. It only took effect
        # together with error_bad_lines=False, and pandas 2.0 removed it.
        frame = pd.read_csv(BytesIO(response.content), compression="zip", sep="\t",
                            header=None, encoding="latin")
        frame.columns = self._GKG_COLUMNS
        frame["DATE"] = pd.to_datetime(frame["DATE"], format="%Y%m%d%H%M%S")
        return frame[self._GKG_KEEP]

    def download_and_generate_gdelt_table(self, date, file_path):
        """
        Return one day's raw GKG data as a Spark DataFrame, cached at `file_path`.

        If a Delta table already exists at `file_path`, it is loaded. Otherwise
        every 15-minute GKG file for `date` ("YYYY-MM-DD") is downloaded. Slots
        that fail to download or parse are reported and skipped. The combined
        data is then saved as Delta at `file_path`.

        Raises ValueError if none of the day's files could be loaded.
        """
        if os.path.exists(self._to_local_path(file_path)):
            return spark.read.format("delta").option("header", "true").load(file_path)

        day = date.replace("-", "")
        frames = []
        for slot in self._gdelt_time_slots():
            print(day, slot)
            url = f"http://data.gdeltproject.org/gdeltv2/{day}{slot}00.gkg.csv.zip"
            try:
                frame = self._fetch_gkg_frame(url)
            except Exception as err:
                # Best effort: one missing or bad slot must not abort the day,
                # but the failure is reported instead of silently swallowed.
                print(f"   skipped {url}: {err}")
                continue
            print(frame.shape)
            frames.append(frame)

        if not frames:
            raise ValueError(f"No GKG files could be loaded for {date}")

        # pd.concat replaces DataFrame.append, which pandas 2.0 removed.
        sdf = spark.createDataFrame(pd.concat(frames, ignore_index=True))
        sdf.write.format("delta").option("header", "true").mode("overwrite").save(file_path)
        print("   * loaded *  ", date)
        return sdf

    def getting_all_org_data(self, start_date, end_date, main_dir, refined_dir, organisation_name=None, override_save=False):
        """
        Download (or load cached) and reformat each day from start_date to end_date.

        Raw days are cached under `main_dir/<date>` and reformatted days under
        `refined_dir/<date>`. Days that fail are reported and skipped.
        `override_save` is currently unused.

        Returns the union of all successfully processed days.
        Raises ValueError if no day could be processed.
        """
        print("Loading and cleaning all data")
        data_list = []

        for i, date in enumerate(pd.date_range(start_date, end_date).astype(str)):
            if i % 7 == 1:
                # Pause periodically; downloads were observed to hang otherwise.
                time.sleep(60)

            try:
                file_path = os.path.join(main_dir, date)
                file_path_refined = os.path.join(refined_dir, date)
                raw = self.download_and_generate_gdelt_table(date, file_path)
                data_list.append(self.redesign_sdf(raw, file_path_refined, organisation_name))
                spark.catalog.clearCache()
            except Exception as e:
                print(f"!!! Failed to complete {date}!")
                print("  ****   Reason:\n" + str(e) + "\n\n")

        if not data_list:
            raise ValueError(f"No GDELT data could be prepared between {start_date} and {end_date}")
        return combine_spark_dfs(data_list)

    def reformat_sdf(self, sdf, file_path_refined, org_name):
        """
        Parse a raw GKG DataFrame and keep rows about known organizations.

        On first use of `file_path_refined`, the steps are:
        - rename DocumentIdentifier to URL;
        - split Themes and Organizations into lists and parse V2Tone into a map;
        - add boolean E/S/G columns (theme tokens ENV/UNGP/ECON);
        - cache the result as Delta at `file_path_refined`.
        Later calls load that table instead.

        The Organizations list is then exploded to one row per organization.
        Names are standardized with clean_organization, and only keys of
        Common_MetaData.organizations are kept. That dict must already be
        populated (see `redesign_sdf`); `org_name` is unused here.
        """
        sdf = sdf.select(*Common_MetaData.keep)

        if not os.path.exists(self._to_local_path(file_path_refined)):
            # Reformat existing columns
            sdf = (sdf.withColumnRenamed("DocumentIdentifier", "URL")
                      .withColumn("Themes", simple_expand_spark("Themes"))
                      .withColumn("Organizations", simple_expand_spark("Organizations"))
                      .withColumn("V2Tone", tone_expand_spark("V2Tone")))

            # Create ESG flag columns
            sdf = (sdf.withColumn("E", has_theme_spark("Themes", F.lit("ENV")))
                      .withColumn("S", has_theme_spark("Themes", F.lit("UNGP")))
                      .withColumn("G", has_theme_spark("Themes", F.lit("ECON"))))

            sdf.write.format("delta").option("header", "true").mode("overwrite").save(file_path_refined)
            print("sdf created")
        else:
            sdf = spark.read.format("delta").option("header", "true").load(file_path_refined)

        return (sdf.withColumn("Organization", F.explode("Organizations"))
                   .withColumn("Organization", clean_organization("organization"))
                   .filter(F.col("organization").isin(list(Common_MetaData.organizations.keys()))))

    def preprocess_gdelt_data(self, start_date, end_date, org_name=None, save_csv=True):
        """
        Build and save the single-organization GDELT dataset for a date range.

        Steps:
        - clear the previous ".../GDELT_data_single_org/data_single_org" output;
        - make sure the working directories exist;
        - download and reformat each day;
        - save the combined result as a Delta table at that path.

        `save_csv` is unused. Returns the combined Spark DataFrame.
        """
        dbutils.fs.rm('/mnt/esg/financial_report_data/GDELT_data_single_org/data_single_org', True)

        base_dir = "dbfs:/mnt/esg/financial_report_data/GDELT_data_single_org"
        data_save_path = f"{base_dir}/data_single_org"
        base_data_dir = f"{base_dir}/data"
        base_data_dir_refined = f"{base_dir}/data_refined"
        for directory in (base_dir, data_save_path, base_data_dir, base_data_dir_refined):
            if not os.path.exists(self._to_local_path(directory)):
                dbutils.fs.mkdirs(directory)

        # Download and reformat the data
        print('')
        data = self.getting_all_org_data(start_date, end_date, base_data_dir, base_data_dir_refined, org_name)
        print(f"There are {data.count():,d} data points for {len(Common_MetaData.organizations)} "
              f"organizations from {start_date} to {end_date}")

        # Save the data
        print("Saving Data...")
        data.write.format("delta").mode("overwrite").save(data_save_path)
        print(f"Saved to {data_save_path}")

        return data



    # COMMAND ----------

def download_gdelt_data(start_date, end_date, org_name = None, save_csv=True):
    """
    Build and save the single-organization GDELT dataset for a date range.

    This is the module-level entry point. Its steps were an exact copy of
    Redesign_data_format.preprocess_gdelt_data, so it delegates there:
    - clear ".../GDELT_data_single_org/data_single_org";
    - create the working directories;
    - download and reformat each day;
    - save the combined Delta table at that path.

    `save_csv` is unused. Returns the combined Spark DataFrame.
    """
    return Redesign_data_format().preprocess_gdelt_data(start_date, end_date, org_name, save_csv)

# COMMAND ----------


## ESG Computation Functions:

In [0]:
def singleSource(filtered_df):
    """
    Average the tone metrics per (calendar day, news source).

    Groups `filtered_df` by the day of its DATE column and by SourceCommonName.
    Returns the columns `date` (DateType) and SourceCommonName, followed by
    the mean of each tone metric under its original name. Rows are ordered by
    date ascending.
    """
    metrics = ["Tone", "PositiveTone", "NegativeTone", "Polarity",
               "ActivityDensity", "SelfDensity", "WordCount"]

    day = F.date_format("DATE", "y-MM-dd").alias("date")
    averaged = filtered_df.groupby(day, 'SourceCommonName').agg(*[F.mean(m) for m in metrics])
    averaged = averaged.withColumn("date", F.to_date("date", format="y-MM-dd"))
    averaged = averaged.withColumn("date", F.col("date").cast("date"))
    averaged = averaged.orderBy(F.col("date").asc())

    # F.mean names its outputs "avg(<col>)"; restore the plain metric names.
    renamed = [F.col(f"avg({m})").alias(m) for m in metrics]
    return averaged.select("date", "SourceCommonName", *renamed)


def avg_day_tone(filtered_df, name):
    """
    Compute the daily ratio of summed Tone to summed WordCount.

    Returns two columns, ordered by date ascending:
    - `date` (DateType);
    - "<name, with spaces replaced by underscores>_tone".
    """
    tone_col = name.replace(' ', '_') + "_tone"
    per_day = filtered_df.groupby(F.date_format("DATE", "y-MM-dd").alias("date"))
    ratio = (F.sum("Tone") / F.sum("WordCount")).alias(tone_col)

    result = per_day.agg(ratio).select("date", tone_col)
    result = result.withColumn("date", F.to_date("date", format="y-MM-dd"))
    result = result.withColumn("date", F.col("date").cast("date"))
    return result.orderBy(F.col("date").asc())

def subtract_cols(df, col1, col2):
    """
    Replace `col1` with `col1 - col2`, then rename it from "*_tone" to "*_diff".

    Every "_tone" occurrence in the column name is replaced.
    """
    diff_name = col1.replace("_tone", "_diff")
    with_diff = df.withColumn(col1, df[col1] - df[col2])
    return with_diff.withColumnRenamed(col1, diff_name)


def get_col_avgs(df):
    """
    Return {column: mean} for every numeric column of `df`.

    Date, timestamp, string and other non-numeric columns (booleans, arrays,
    maps, ...) are skipped. Spark's avg rejects e.g. boolean input, and the
    previous exclusion list contained a column name where a dtype belonged.
    Returns {} when `df` has no numeric columns.
    """
    numeric_types = ("tinyint", "smallint", "int", "bigint", "float", "double")
    columns = [name for name, dtype in df.dtypes
               if dtype in numeric_types or dtype.startswith("decimal")]
    if not columns:
        return {}
    # A global aggregate always yields exactly one row.
    row = df.select([F.avg(c).alias(c) for c in columns]).collect()[0]
    return {c: row[c] for c in columns}

# COMMAND ----------

def show_df_rounded(df, places=4, rows=20):
    """
    Print up to `rows` rows of `df` with readable formatting.

    - Dates and timestamps are shown as "y-MM-dd".
    - Integer columns use thousands separators and no decimals.
    - Floating and decimal columns use `places` decimals.
    - Strings and any other column types (booleans, arrays, ...) are shown
      unchanged instead of making format_number fail.
    """
    integer_types = ("tinyint", "smallint", "int", "bigint")
    float_types = ("float", "double")

    show_cols = []
    for name, dtype in df.dtypes:
        if dtype in ("date", "timestamp"):
            show_cols.append(F.date_format(name, "y-MM-dd").alias(name))
        elif dtype in integer_types:
            show_cols.append(F.format_number(name, 0).alias(name))
        elif dtype in float_types or dtype.startswith("decimal"):
            show_cols.append(F.format_number(name, places).alias(name))
        else:
            show_cols.append(F.col(name).alias(name))
    df.select(*show_cols).limit(rows).show()

# DBTITLE 1,Load Data from Delta Table
def load_data(save_path, file_name):
    """
    Read the Delta table at os.path.join(save_path, file_name).

    Returns it as a pandas DataFrame, collected to the driver.
    """
    table_path = os.path.join(save_path, file_name)
    reader = (spark.read.format("delta")
                   .option("header", "true")
                   .option("inferSchema", "true"))
    return reader.load(table_path).toPandas()


def filter_non_esg(df):
    """
    Keep only rows flagged environmental, social or governance.

    A row is kept when at least one of its E, S or G columns equals True.
    The original index is preserved.
    """
    flagged = df[['E', 'S', 'G']].eq(True).any(axis=1)
    return df[flagged]

# COMMAND ----------

class graph_creator:
    """
    Build an undirected organization co-mention graph from article rows.

    `df` is a pandas DataFrame with URL and Organization columns. Two distinct
    organizations are connected when they appear under the same URL. The edge
    weight is the number of URLs they share.
    """

    def __init__(self, df):
        self.df = df

    @staticmethod
    def _weighted_edges(df):
        """
        Return a DataFrame with columns Source, Dest and weight.

        Each URL adds one count to every unordered pair of distinct
        organizations listed under it. Pairs are stored with Source < Dest,
        so (a, b) and (b, a) accumulate into one undirected edge instead of
        overwriting each other. Repeated mentions of one organization under a
        URL create no self-loops.
        """
        import itertools  # only `product` is imported from itertools at file level

        orgs_per_url = df.groupby("URL")["Organization"].apply(list)
        pairs = [
            pair
            for orgs in orgs_per_url
            for pair in itertools.combinations(sorted(set(orgs)), 2)
        ]
        if not pairs:
            return pd.DataFrame(columns=["Source", "Dest", "weight"])
        edges = pd.DataFrame(pairs, columns=["Source", "Dest"])
        return edges.groupby(["Source", "Dest"]).size().reset_index(name="weight")

    def create_graph(self):
        """
        Build and return the weighted graph.

        Also stores the graph on `self.G` and the organizations that have at
        least one edge on `self.organizations`.
        """
        source_dest = self._weighted_edges(self.df)
        self.organizations = set(source_dest["Source"]).union(source_dest["Dest"])
        # NOTE(review): `nx` (networkx) is not imported anywhere in this file;
        # it must come from the notebook environment — confirm.
        self.G = nx.from_pandas_edgelist(source_dest, source="Source",
                                         target="Dest", edge_attr="weight",
                                         create_using=nx.Graph)
        return self.G

# COMMAND ----------

def get_embeddings(G, organizations):
    """
    Fit a node-embedding model on graph `G` and reduce its vectors with PCA.

    Returns
    -------
    d_e : pandas.DataFrame
        One row per embedding vector, holding its 3 PCA components.
        Presumably these land in columns 0, 1, 2, assuming PCA is
        sklearn's — TODO confirm. The "company" column is `organizations`.
    g2v
        The fitted embedding model.

    NOTE(review): `NVVV` and `PCA` are not imported anywhere in this file.
    They are presumably a graph-embedding class (its `.model.wv` looks like
    gensim KeyedVectors) and sklearn.decomposition.PCA — confirm.
    NOTE(review): the rows of `g2v.model.wv.vectors` follow the model's own
    vocabulary order, which may not match the order of `organizations`.
    The "company" labels may therefore be misaligned, and the lengths must
    match — verify.
    """
    # Fit graph
    g2v = NVVV()
    g2v.fit(G)
    
    # Embeddings
    embeddings = g2v.model.wv.vectors
    # Project to 3 dimensions (e.g. for plotting)
    pca = PCA(n_components=3)
    principalComponents = pca.fit_transform(embeddings)
    d_e = pd.DataFrame(principalComponents)
    # Labels are assigned by position — see the ordering note above
    d_e["company"] = organizations
    return d_e, g2v

# COMMAND ----------

def get_connections(organizations, topn=25, g2v=None):
    """
    Tabulate the `topn` most similar organizations for each organization.

    Parameters
    ----------
    organizations : iterable of str
        Organizations to look up; each must be in the model's vocabulary.
    topn : int
        Number of neighbours per organization.
    g2v : optional
        Fitted model exposing `model.wv.most_similar(key, topn=...)`. If
        omitted, a module-level `g2v` is used, as the original code did, and
        NameError is raised if there is none.

    Returns
    -------
    pandas.DataFrame
        Column "company", then "n{i}_rec" and "n{i}_conf" for i = 0 .. topn-1,
        taken from the (key, score) pairs returned by most_similar.
    """
    if g2v is None:
        # Backward compatibility: the original read a global `g2v`.
        g2v = globals().get("g2v")
        if g2v is None:
            raise NameError("get_connections: no fitted model available; pass g2v=...")

    columns = ["company"] + [f"n{i}_{kind}" for i in range(topn) for kind in ("rec", "conf")]
    records = []
    for org in organizations:
        row = {"company": org}
        for i, (neighbour, score) in enumerate(g2v.model.wv.most_similar(org, topn=topn)):
            row[f"n{i}_rec"] = neighbour
            row[f"n{i}_conf"] = score
        records.append(row)
    return pd.DataFrame(records, columns=columns)

# COMMAND ----------

def make_embeddings_and_connections(start, end):
    """
    Build the co-mention graph, embeddings and neighbour table for a saved range.

    Reads ".../GDELT_data_Russell_top_300/<start>__to__<end>/data_as_csv.csv"
    and keeps ESG-flagged rows. Writes into the same directory:
    organization_graph.pkl, pca_embeddings.csv, connections.csv and a
    CONNECTIONS Delta table.
    """
    import pickle  # not imported at file level

    base_dir = "dbfs:/mnt/esg/financial_report_data/GDELT_data_Russell_top_300"
    save_dir = os.path.join(base_dir, f"{start}__to__{end}")
    csv_file = "data_as_csv.csv"

    # Load data
    print("Loading Data")
    df = pd.read_csv(os.path.join(save_dir, csv_file).replace("dbfs:/", "/dbfs/"))
    df = filter_non_esg(df)

    # Create graph
    print("Creating Graph")
    creator = graph_creator(df)
    G = creator.create_graph()
    organizations = list(creator.organizations)

    # Save graph as pkl
    fp = os.path.join(save_dir, "organization_graph.pkl").replace("dbfs:/", "/dbfs/")
    with open(fp, "wb") as f:
        pickle.dump(G, f)

    # Create embeddings
    print("Creating embeddings")
    emb_path = os.path.join(save_dir, "pca_embeddings.csv").replace("dbfs:/", "/dbfs/")
    d_e, g2v = get_embeddings(G, organizations)
    d_e.to_csv(emb_path, index=False)

    # Create connections; the freshly fitted model is passed explicitly
    # because the original relied on a global `g2v` that did not exist.
    print("Creating connections")
    df_sim = get_connections(organizations, g2v=g2v)
    sim_path = os.path.join(save_dir, "connections.csv")
    df_sim.to_csv(sim_path.replace("dbfs:/", "/dbfs/"))

    # Save organizations as delta
    conn_path = os.path.join(save_dir, "CONNECTIONS")
    conn_data = spark.createDataFrame(df_sim)
    conn_data.write.format("delta").mode("overwrite").option("overwriteSchema", "true").save(conn_path)
    
def make_tables(start_date, end_date):
    """
    Compute and save daily and average ESG tone tables for every organization.

    Loads ".../GDELT_data_Russell_top_<n>/<start>__to__<end>/data_as_delta".
    Computes the industry-wide daily tone, both overall and for E/S/G-flagged
    rows. Then, for each organization, computes the daily difference between
    its tone and the industry tone.

    CSVs written under ".../esg_scores":
    - overall_daily_esg_scores.csv and daily_{E,S,G}_score.csv, forward-filled
      to daily frequency;
    - average_esg_scores.csv, the mean difference per organization for
      E, S, G and overall ("T").

    Returns None. Prints a message and returns early if the data cannot be
    loaded.

    NOTE(review): `Fields` and `daily_tone` are not defined in this file.
    `daily_tone` is presumably an older name for `avg_day_tone` — confirm.
    """
    # Directories
    org_types = f"Russell_top_{Fields.n_orgs}"
    base_dir = f"dbfs:/mnt/esg/financial_report_data/GDELT_data_{org_types}"
    range_save_dir = os.path.join(base_dir, f"{start_date}__to__{end_date}")
    esg_dir = os.path.join(range_save_dir, "esg_scores")
    dbutils.fs.mkdirs(esg_dir)

    # Load data
    data_path = os.path.join(range_save_dir, "data_as_delta")
    try:
        data = (spark.read.format("delta")
                     .option("header", "true")
                     .option("inferSchema", "true")
                     .load(data_path))
        print("Data Loaded!")
    except Exception:
        print("Data for these dates hasn't been generated!!!")
        return

    # Get all organizations
    print("Finding all Organizations")
    organizations = [row.Organization for row in data.select("Organization").distinct().collect()]

    # Get the overall tone
    print("Calculating Tones Over Time")
    overall_tone = daily_tone(data, "industry")
    esg_tones = {L: daily_tone(data.filter(f"{L} == True"), "industry")
                 for L in ["E", "S", "G"]}

    # Per-organization tone minus industry tone. Progress is printed about
    # every 10%; max() avoids a zero range step (ValueError) with < 10 orgs.
    n_orgs = len(organizations)
    progress_step = max(1, n_orgs // 10)
    for i, org in enumerate(organizations):
        if i % progress_step == 0:
            print(f"{100 * i // n_orgs}%")
        tone_label = f"{org.replace(' ', '_')}_tone"

        # Column comparison rather than an SQL string, so quotes in names are safe.
        overall_org_df = data.filter(F.col("Organization") == org)
        org_tone = daily_tone(overall_org_df, org)
        overall_tone = subtract_cols(overall_tone.join(org_tone, on="date", how="left"),
                                     tone_label, "industry_tone")

        for L, tdf in esg_tones.items():
            esg_org_tone = daily_tone(overall_org_df.filter(f"{L} == True"), org)
            esg_tones[L] = subtract_cols(tdf.join(esg_org_tone, on="date", how="left"),
                                         tone_label, "industry_tone")
    del data

    # Average to get overall scores
    print("Computing Overall Scores")
    overall_scores = get_col_avgs(overall_tone)
    esg_scores = {L: get_col_avgs(tdf) for L, tdf in esg_tones.items()}

    scores = {}
    for org in organizations:
        diff_label = f"{org.replace(' ', '_')}_diff"
        scores[org] = {L: averages[diff_label] for L, averages in esg_scores.items()}
        scores[org]["T"] = overall_scores[diff_label]

    def _daily_frame(sdf):
        """Collect `sdf` to pandas, indexed by date and forward-filled to daily frequency."""
        pd_df = sdf.toPandas().set_index("date").sort_index()
        # Spark DateType arrives as Python date objects; asfreq needs a DatetimeIndex.
        pd_df.index = pd.to_datetime(pd_df.index)
        return pd_df.asfreq(freq="D", method="ffill")

    # Save all the tables
    print("Saving Tables")

    # Overall ESG
    print("    Daily Overall ESG")
    path = os.path.join(esg_dir, "overall_daily_esg_scores.csv").replace("dbfs:/", "/dbfs/")
    _daily_frame(overall_tone).to_csv(path, index=True)

    # E, S, and G
    for L, tdf in esg_tones.items():
        print("    Daily " + L)
        path = os.path.join(esg_dir, f"daily_{L}_score.csv").replace("dbfs:/", "/dbfs/")
        _daily_frame(tdf).to_csv(path, index=True)

    # Averaged scores
    print("    Average Scores")
    score_path = os.path.join(esg_dir, "average_esg_scores.csv").replace("dbfs:/", "/dbfs/")
    pd.DataFrame(scores).to_csv(score_path, index=True)
    print("DONE!")

In [None]:
def calculating_esg_values(start_date, end_date):
    """
    Compute average daily tone scores from the saved single-organization data.

    Loads the Delta table written by download_gdelt_data
    (".../GDELT_data_single_org/data_single_org"). Tone is first averaged per
    (day, news source) with singleSource, then reduced per day with
    avg_day_tone. This runs for the whole dataset ("industry_tone") and for
    each organization ("<org>_tone"), both overall and for E/S/G-flagged
    rows. Each organization's daily series is left-joined onto the industry
    series by date.

    `start_date` and `end_date` are unused: whatever data was last saved is
    used.

    Returns
    -------
    (esg_scores, overall_tone)
        - esg_scores maps "E"/"S"/"G" to {column: mean over days}, as
          computed by get_col_avgs.
        - overall_tone is the joined daily Spark DataFrame.
        Returns None if the data cannot be loaded.
    """
    print("calulation started")
    data_path = "dbfs:/mnt/esg/financial_report_data/GDELT_data_single_org/data_single_org"
    try:
        data = (spark.read.format("delta")
                     .option("header", "true")
                     .option("inferSchema", "true")
                     .load(data_path))
        print("Data Loaded!")
    except Exception:
        print("Data for these dates hasn't been generated!!!")
        return

    # Get all organizations
    print("Finding all Organizations")
    organizations = [row.Organization for row in data.select("Organization").distinct().collect()]

    # Get the overall tone
    print(organizations)
    print("Calculating Tones Over Time")
    overall_tone = avg_day_tone(singleSource(data), "industry")
    esg_tones = {L: avg_day_tone(singleSource(data.filter(f"{L} == True")), "industry")
                 for L in ["E", "S", "G"]}

    for org in organizations:
        # Column comparison rather than an SQL string, so quotes in names are safe.
        overall_org_df = data.filter(F.col("Organization") == org)
        org_tone = avg_day_tone(singleSource(overall_org_df), org)
        overall_tone = overall_tone.join(org_tone, on="date", how="left")

        for L, tdf in esg_tones.items():
            esg_org_tone = avg_day_tone(singleSource(overall_org_df.filter(f"{L} == True")), org)
            esg_tones[L] = tdf.join(esg_org_tone, on="date", how="left")

    del data

    # Average to get the E/S/G scores
    print(organizations)
    print("Computing Overall Scores")
    print("    Calculating esg tone")
    esg_scores = {L: get_col_avgs(tdf) for L, tdf in esg_tones.items()}

    print(esg_scores)
    print("DONE!")
    return esg_scores, overall_tone




In [0]:
def download_company_data(base_dir, start_date, end_date, org_name):
    """
    Download GDELT data for `org_name` over the date range.

    Returns the per-pillar (E/S/G) average scores from calculating_esg_values.
    `base_dir` is accepted for interface compatibility but is not used.
    """
    download_gdelt_data(start_date, end_date, org_name, save_csv=True)
    esg_scores, _overall_tone = calculating_esg_values(start_date, end_date)
    return esg_scores

### Calculating ESG Scores for input Companies:

In [None]:
# Notebook parameters, exposed as Databricks text widgets with defaults.
base_dir = "dbfs:/mnt/esg/financial_report_data"
dbutils.widgets.text("myinput", "microsoft;apple")
dbutils.widgets.text("startdate", "2022-05-01")
dbutils.widgets.text("enddate", "2022-05-02")

# Semicolon-separated company names. Surrounding whitespace and empty entries
# (e.g. "microsoft; apple;") are dropped so they do not become bogus names.
var_a = [name.strip() for name in dbutils.widgets.get("myinput").split(';') if name.strip()]
print(var_a)
start_date = dbutils.widgets.get("startdate")
end_date = dbutils.widgets.get("enddate")

In [0]:
# Run the pipeline for the selected companies and hand the E/S/G scores back
# to the caller of this notebook.
# NOTE(review): dbutils.notebook.exit documents a string argument, but
# `output` is a dict — confirm how it is serialized.
output = download_company_data(base_dir, start_date=start_date, end_date=end_date, org_name=var_a)
dbutils.notebook.exit(output)
