In [194]:
import findspark
findspark.init()

from pyspark.sql import  SparkSession, Row
from pyspark.sql.functions import \
	lit, col, when, expr, countDistinct

In [223]:
# Local Spark session used by every cell below.
# Available settings: https://spark.apache.org/docs/latest/configuration.html
spark = (
    SparkSession.builder
    .appName('preprocesdid')
    .config('spark.master', 'local[4]')
    .config('spark.executor.memory', '1g')
    .config("spark.sql.shuffle.partitions", 1)
    .config('spark.driver.memory', '1g')
    .getOrCreate()
)

In [224]:
import os

# Input CSV: the first row is read as the header and column types are inferred.
name = r'D:\Workflow\work\csdid_r\R\5g10t.csv'
# os.path.exists(name)
data = (
	spark.read
	.option('header', True)
	.option('inferSchema', True)
	.csv(name)
)
data.show(5)

+---+-----------------+---+-------+------+----------------+-----+
|  G|                X| id|cluster|period|               Y|treat|
+---+-----------------+---+-------+------+----------------+-----+
|  2|0.214945979570717|  1|     37|     1|4.39776569093816|    1|
|  2|0.214945979570717|  1|     37|     2|7.66172212986391|    1|
|  2|0.214945979570717|  1|     37|     3|6.83282829318896|    1|
|  2|0.214945979570717|  1|     37|     4|7.07668186230603|    1|
|  2|0.214945979570717|  1|     37|     5|8.99244332186859|    1|
+---+-----------------+---+-------+------+----------------+-----+
only showing top 5 rows



In [226]:
# Names of the input columns referenced as yname / gname / idname / tname.
yname = 'Y'
gname = 'G'
idname = 'id'
tname = 'period'

# Processing options.
control_group = ['nevertreated', 'notyettreated']
anticipation = 0
panel = True
allow_unbalanced_panel = True

# Optional inputs; later cells test these against None.
weights_name = None
clustervar = None
xfmla = None

In [163]:
# Row count before any filtering; compared against the post-filter count
# further down to report how many rows were dropped.
n_pre = data.count()
# Columns carried forward into the working DataFrame.
columns = [idname, tname, yname, gname]

if clustervar is not None:
	columns += [clustervar]
if weights_name is not None:
	# User-supplied weights: keep the source column and mirror it into '_w'.
	columns += [weights_name]
	data = data.withColumn('_w', data[weights_name])
else:
	# No weights column given: every row gets weight 1.
	data = data.withColumn('_w', lit(1))
# Always keep '_w'. It used to be listed only in the no-weights branch, so
# the later `data[columns]` subset discarded '_w' whenever weights were given.
columns += ['_w']

def form_to_strings(fmla : str = 'y ~ x + 1'):
	"""Return the right-hand-side terms of a ``lhs ~ a + b`` formula string.

	The text after the ``~`` is split on ``+`` and every term is stripped of
	surrounding whitespace. A term equal to ``'1'`` is returned as
	``'_intercept'``. The left-hand side is ignored.

	Raises:
		ValueError: from tuple unpacking, unless ``fmla`` contains exactly
			one ``~``.
	"""
	_lhs, rhs = fmla.split('~')
	terms = []
	for raw in rhs.strip().split('+'):
		term = raw.strip()
		terms.append('_intercept' if term == '1' else term)
	return terms
# Constant column that stands in for the '1' (intercept) term of a formula.
data = data.withColumn('_intercept', lit(1))

# xfmla = 'y ~ X + 1'
if xfmla is None:
	# No covariate formula supplied: intercept only.
	x_var = ['_intercept']
	x_cov = data[x_var]
	n_cov = 1
else:
	x_var = form_to_strings(fmla=xfmla)
	n_cov = len(x_var)

columns += x_var

# Drop rows that have a null in ANY kept column. The previous how='all' only
# removes rows in which every column is null, which cannot happen here because
# '_intercept' is a non-null constant, so incomplete rows were silently kept.
data = data[columns].na.drop('any')
ndiff = n_pre - data.count()

if ndiff != 0:
	print(f'Dropped, {ndiff}, rows from original data due to missing data')


In [164]:
def tlist_glist(data, tname, gname, _filter = False):
	"""Return the sorted distinct values of the ``tname`` and ``gname`` columns.

	Each result is built with ``select(...).distinct().orderBy(...)`` on
	``data``, so it holds one column in ascending order.

	Args:
		data: DataFrame to read from.
		tname: name of the period column.
		gname: name of the group column.
		_filter: when true, keep only group values strictly smaller than the
			maximum ``gname`` found in ``data``.

	Returns:
		Tuple ``(tlist, glist)``.
	"""
	def _distinct_sorted(name):
		return data.select(name).distinct().orderBy(col(name))

	periods = _distinct_sorted(tname)
	groups = _distinct_sorted(gname)
	if not _filter:
		return periods, groups

	largest = data.select(expr(f'max({gname})').alias('value')).first()['value']
	return periods, groups.filter(groups[gname] < largest)

tlist, glist = tlist_glist(data, tname, gname, False)

# TODO: step not ported to Spark yet. The pandas reference below sets gname
# to 0 for rows whose gname is larger than the largest period in tlist.
#   asif_nev_treated = data[gname] > np.max(tlist)
#   asif_nev_treated.fillna(False, inplace=True)
#   data.loc[asif_nev_treated, gname] = 0

# Unfinished Spark draft of the same step:
# data = data.\
# 	withColumn(
# 		gname, 
# 		when(col(gname) > data.\
# 			selectExpr("max(tlist) as max_tlist")\
# 			.first()["max_tlist"], 0)\
# 			.otherwise(col(gname)
# 		)
# 	)

# control_group may be given as a list of candidates; use its first entry.
# Comparing the list itself with a string is always False, which made the
# 'nevertreated' branch below unreachable.
selected_control = control_group if isinstance(control_group, str) else control_group[0]

n_glist0 = glist.filter(glist[gname] == 0).count()
if n_glist0 == 0:
	if selected_control == 'nevertreated':
		# Raising a bare string is itself a TypeError in Python 3, so raise
		# a real exception carrying the intended message.
		raise ValueError('There is no avaible never-treated group')
	else:
		# No group is coded 0: keep only rows whose period is before
		# max(gname) - anticipation, then rebuild the lists with the largest
		# remaining group removed from glist.
		value_expr = expr(f'max({gname}) - {anticipation}')
		value = data.select(value_expr.alias('value')).first()['value']
		data = data.filter(data[tname] < value)
		tlist, glist = tlist_glist(data, tname, gname, True)

In [166]:
# Display the distinct gname values currently held in glist.
glist.show()

+---+
|  G|
+---+
|  0|
|  2|
|  3|
|  4|
|  5|
|  6|
+---+



In [168]:
# first_period: tlist is ordered ascending, so its first row is the earliest.
fp = tlist.first()[tname]
# Keep groups whose value is positive and strictly greater than the
# first period plus anticipation.
glist = (
	glist
	.filter(col(gname) > 0)
	.filter(col(gname) > fp + anticipation)
)


In [220]:
# Flag rows whose group is already treated in the first period
# (gname <= fp and gname != 0). A null comparison result becomes False.
data = (
	data
	.withColumn("treated_fp", (col(gname) <= fp) & ~(col(gname) == 0))
	.fillna({'treated_fp': False})
)

already_treated = data.filter(col('treated_fp'))
# The message below reports *units*. In a panel each unit has several rows,
# so distinct ids must be counted; without a panel every row is counted.
# The two branches were previously swapped, so panels reported the row count.
if panel:
	nfirst_period = already_treated.select(idname).distinct().count()
else:
	nfirst_period = already_treated.count()
# 93 - 102
if nfirst_period > 0:
	warning_message = f"Dropped {nfirst_period} units that were already treated in the first period."
	print(warning_message)
	# Keep only rows whose group is still in glist, plus the group coded 0.
	glist_in = [row[gname] for row in glist.collect()] + [0]
	data = data.filter(col(gname).isin(glist_in))
	# Rebuild the period and group lists from the reduced data.
	tlist, glist = tlist_glist(data, tname=tname, gname=gname)
	glist = glist.filter(col(gname) > 0)
	fp = tlist.first()[tname]
	glist = glist.filter(col(gname) > fp + anticipation)



In [236]:
# The repeated-cross-section flag is set exactly when the data is not a panel.
true_rep_cross_section = not panel

# if panel: 
# 	if allow_unbalanced_panel: 
# 		true_rep_cross_section = False

# Distinct-id counts before and after removing rows whose columns are all null.
n_id = data.select(idname).distinct().count()
keep = data.na.drop('all')
n_keep = keep.select(idname).distinct().count()
n_old_data = data.count()
# NOTE(review): makeBalancedPanel is not defined in the visible cells;
# presumably it is defined elsewhere in the notebook. Confirm before running.
data = makeBalancedPanel(data, idname=idname, tname=tname)

# 119 - 121, repetitive code 
# if n_keep < data.count():
# 	print(f"Dropped {n_id-n_keep} observations that had missing data.")

Dropped 0 observations that had missing data.


In [235]:
# Display the current value of the panel flag.
panel

True