In [None]:
# Stratified Sampling

In [131]:
import sys
import os
from pathlib import Path
import importlib
sys.path.append('..')

import pandas as pd

import data.dataframe_preparation as preparation
from data.custom_widgets import ReportsLabeler
import data
importlib.reload(data.custom_widgets)
importlib.reload(data.dataframe_preparation)

############### CONFIG ###############
FIRM_METADATA = os.path.abspath("../input_files/Firm_Metadata.csv")
DATA_INPUT_PATH = os.path.abspath("../input_files/annual_reports/")
MASTER_DATA_PATH = os.path.abspath("../input_files/annual_reports/Firm_AnnualReport_DF.csv")
LABEL_OUTPUT_FN = 'Firm_AnnualReport_Labels_DF.pkl'

HOLD_OUT_YEAR = 2019
HOLD_OUT_COMPANY = 'gb_unilever_plc'

SEED = 99
OVERRIDE_FILE = True
######################################

# Create master output file if not exits
master_file = Path(MASTER_DATA_PATH)
if not master_file.is_file() or OVERRIDE_FILE:
    df = preparation.get_df(input_path=DATA_INPUT_PATH, report_type_mappings={"20F": "AR"}, selected_report_types={"AR"}, include_text=False, include_page_no=False, include_toc=False)
    df = df.set_index("id")
    # Add additional labelling columns
    df['should_label'] = False
    df['is_labelled'] = False
    
    # Load the meta data
    df_meta = pd.read_csv(FIRM_METADATA)
    df_meta = df_meta.set_index('id')
    df['company_id'] = df['country'] + "_" + df['company']
    df = df.drop(columns=['country'])
    df = df.merge(df_meta, left_on='company_id', right_index=True)
    
    # TODO: Raise/Alert or remove if all outputs are there!
    df = df[df['output_file'].notna()]
    
    # Select all reports from year and company
    df_to_label = df[(df.year == HOLD_OUT_YEAR) | (df.company_id == HOLD_OUT_COMPANY)].copy()
    df_to_label['should_label'] = True
    nr_held_out_reports = len(df_to_label)
    df.update(df_to_label)
    
    # Remaining reports: Take sample with stratification along variables. n = amount of reports in each year AND industry...
    df_to_label = df[df.should_label == False]
    # Note: Do not use apply here, as otherwise the random state is equal for each group, possibly leading to non-random sampling!
    df_to_label = df_to_label.groupby(['year', 'icb_industry']).sample(n=1, random_state=SEED)
    # df_to_label = df_to_label.reset_index(2).reset_index(drop=True).set_index("id")
    
    df_to_label['should_label'] = True
    df.update(df_to_label)
    df.to_csv(MASTER_DATA_PATH)
    
df_to_label.groupby(['company']).count()

HBox(children=(FloatProgress(value=0.0, max=49.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=18.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=19.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=19.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=19.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=18.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))




In [132]:
df_to_label.groupby(['company']).count()

Unnamed: 0_level_0,orig_report_type,report_type,year,input_file,output_file,should_label,is_labelled,company_id,firm_name,ticker,country,icb_industry,icb_supersector
company,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
abb,1,1,1,1,1,1,1,1,1,0,1,1,1
air_liquide,3,3,3,3,3,3,3,3,3,0,3,3,3
airbus,7,7,7,7,7,7,7,7,7,0,7,7,7
anheuser_busch_inbev,4,4,4,4,4,4,4,4,4,0,4,4,4
asml_hldg,6,6,6,6,6,6,6,6,6,0,6,6,6
astrazeneca,2,2,2,2,2,2,2,2,2,0,2,2,2
axa,3,3,3,3,3,3,3,3,3,0,3,3,3
barclays,3,3,3,3,3,3,3,3,3,0,3,3,3
basf,9,9,9,9,9,9,9,9,9,0,9,9,9
bayer,4,4,4,4,4,4,4,4,4,0,4,4,4


In [133]:
df_to_label.groupby(['icb_industry']).count()

Unnamed: 0_level_0,company,orig_report_type,report_type,year,input_file,output_file,should_label,is_labelled,company_id,firm_name,ticker,country,icb_supersector
icb_industry,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
10 Technology,20,20,20,20,20,20,20,20,20,20,0,20,20
15 Telecommunications,20,20,20,20,20,20,20,20,20,20,0,20,20
20 Health Care,20,20,20,20,20,20,20,20,20,20,0,20,20
30 Financials,20,20,20,20,20,20,20,20,20,20,0,20,20
40 Consumer Discretionary,20,20,20,20,20,20,20,20,20,20,0,20,20
45 Consumer Staples,18,18,18,18,18,18,18,18,18,18,0,18,18
50 Industrials,20,20,20,20,20,20,20,20,20,20,0,20,20
55 Basic Materials,16,16,16,16,16,16,16,16,16,16,0,16,16
60 Energy,16,16,16,16,16,16,16,16,16,16,0,16,16
65 Utilities,15,15,15,15,15,15,15,15,15,15,0,15,15


In [134]:
df_to_label.groupby(['year']).count()

Unnamed: 0_level_0,company,orig_report_type,report_type,input_file,output_file,should_label,is_labelled,company_id,firm_name,ticker,country,icb_industry,icb_supersector
year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
1999.0,6,6,6,6,6,6,6,6,6,0,6,6,6
2000.0,6,6,6,6,6,6,6,6,6,0,6,6,6
2001.0,7,7,7,7,7,7,7,7,7,0,7,7,7
2002.0,7,7,7,7,7,7,7,7,7,0,7,7,7
2003.0,9,9,9,9,9,9,9,9,9,0,9,9,9
2004.0,10,10,10,10,10,10,10,10,10,0,10,10,10
2005.0,10,10,10,10,10,10,10,10,10,0,10,10,10
2006.0,10,10,10,10,10,10,10,10,10,0,10,10,10
2007.0,10,10,10,10,10,10,10,10,10,0,10,10,10
2008.0,10,10,10,10,10,10,10,10,10,0,10,10,10
