In [150]:
import warnings

import pandas as pd, numpy as np
import patsy

from pre_process_did import pre_process_did

fml = patsy.dmatrices

In [151]:
def makeBalancedPanel(data, idname, tname):
  """Return only the units that appear in every time period.

  The frame is sorted by ``idname`` then ``tname`` and given a fresh
  RangeIndex. A unit is kept when its number of rows equals the number of
  distinct ``tname`` values in the whole frame. Rows are counted, so a unit
  with a duplicated period and a missing one would still be kept.

  Args:
    data: long-format pandas DataFrame.
    idname: column identifying units.
    tname: column identifying time periods.

  Returns:
    The sorted, filtered DataFrame. Index labels come from the reset
    RangeIndex, so they may have gaps after filtering.
  """
  ordered = data.sort_values(by=[idname, tname]).reset_index(drop=True)
  n_periods = len(ordered[tname].unique())

  def _observed_every_period(unit_rows):
    return len(unit_rows) == n_periods

  return ordered.groupby(idname).filter(_observed_every_period)

In [152]:
# Load the simulated data set and preview its first rows.
DATA_PATH = "../../data/sim_data.csv"
data = pd.read_csv(DATA_PATH)
data.head()

Unnamed: 0.1,Unnamed: 0,G,X,id,period,Y
0,1,4,-0.876233,1,1,3.248701
1,2,4,-0.876233,1,2,2.266837
2,3,4,-0.876233,1,3,3.990885
3,4,4,-0.876233,1,4,4.653489
4,5,4,0.974974,2,1,5.880457


In [153]:
# Column roles in the simulated data.
yname, tname, idname, gname = "Y", "period", "id", "G"
# Patsy formula; its right-hand side supplies the covariate columns.
xformla = "Y~X"

### Additional params
clustervar = None    # optional cluster-variable column
weights_name = None  # optional weight column (unit weights are used when None)


In [154]:
# Comparison-group options; the first entry is selected in the next cell.
control_group = ["nevertreated", "notyettreated"]
# Anticipation offset (in periods) applied to group/period cutoffs below.
anticipation = 0
panel = True                   # observations are repeated measurements of idname units
allow_unbalanced_panel = True  # when True, skip forcing a balanced panel

In [155]:
# n = number of rows in the raw data; t (column count) is not used in the cells shown.
n, t = data.shape
# Use the first listed comparison group. The guard keeps the cell idempotent:
# re-running it must not index into the already-selected string (which gives 'n').
if not isinstance(control_group, str):
  control_group = control_group[0]

In [156]:

# Columns carried into the estimation data, plus optional cluster / weight columns.
columns = [idname, tname, yname, gname]
if clustervar is not None:
  columns += [clustervar]
if weights_name is not None:
  columns += [weights_name]
  w = data[weights_name]
else:
  # No weight column supplied: every row gets weight 1.
  w = np.ones(n)

# Covariate design matrix from the formula; fall back to an intercept-only model.
if xformla is None:
  xformla = f'{yname} ~ 1'
_, x_cov = fml(xformla, data=data, return_type='dataframe')
_, n_cov = x_cov.shape

# Attach covariates and weights, then drop incomplete rows. pd.concat aligns on
# the index, so rows the formula step did not return become NaN and are removed
# by dropna too. NOTE(review): patsy's default NA handling drops rows with
# missing values; confirm.
data = pd.concat([data[columns], x_cov], axis=1)
data = data.assign(w=w)
data = data.dropna()
ndiff = n - len(data)
if ndiff != 0:
  print(f'Dropped {ndiff} rows from original data due to missing data')
if len(data) == 0:
  raise ValueError("All observations dropped due to missing data problems.")

# Sorted distinct time periods.
tlist = np.sort(data[tname].unique())

# Units first treated after the last observed period are never treated within
# the sample, so recode them to group 0 ...
asif_nev_treated = data[gname] > np.max(tlist)
data.loc[asif_nev_treated, gname] = 0

# ... and only then build the sorted group list. This way those units count as
# never-treated (group 0) in the never-treated check that follows.
glist = np.sort(data[gname].unique())

In [157]:
# With no never-treated units (group 0), only a not-yet-treated comparison is
# possible. Keep periods before the latest cohort is (anticipated to be) treated
# and drop that cohort from glist; its units stay in the data for those periods.
if not np.any(glist == 0):
  if control_group == "nevertreated":
    raise ValueError("There is no available never-treated group")
  value = np.max(glist) - anticipation
  data = data.query(f'{tname} < @value')
  # Recompute periods and groups from the trimmed data.
  tlist = np.sort(data[tname].unique())
  glist = np.sort(data[gname].unique())
  glist = glist[glist < np.max(glist)]

In [158]:
# Only positive group values are treated cohorts (0 marks never-treated).
glist = glist[glist > 0]
# First observed period. No ATT(g, t) is computed for cohorts treated at or
# before fp + anticipation.
fp = tlist[0]
glist = glist[glist > fp + anticipation]

# Units already treated in (or before) the first period, excluding never-treated (G == 0).
treated_fp = (data[gname] <= fp) & (data[gname] != 0)

# In a panel, each unit spans several rows, so count distinct ids; otherwise
# each row is its own unit.
nfirst_period = (
  len(data.loc[treated_fp, idname].unique()) if panel
  else int(treated_fp.sum())
)

if nfirst_period > 0:
  warning_message = f"Dropped {nfirst_period} units that were already treated in the first period."
  print(warning_message)
  # Keep only never-treated units and cohorts still in glist, then recompute.
  glist_in = np.append(glist, [0])
  data = data.query(f'{gname} in @glist_in')
  tlist = np.sort(data[tname].unique())
  glist = np.sort(data[gname].unique())
  glist = glist[glist > 0]
  fp = tlist[0]
  glist = glist[glist > fp + anticipation]

#todo: idname must be numeric
true_rep_cross_section = not panel

In [159]:
if panel:
  if allow_unbalanced_panel:
    # Handle the unbalanced panel through the non-panel branch, keeping the
    # original unit ids (not true repeated cross sections).
    panel = False
    true_rep_cross_section = False
  else:
    # Drop rows with missing values, reporting how many units disappear.
    keepers = data.dropna().index
    n = len(data[idname].unique())
    n_keep = len(data.loc[keepers, idname].unique())

    if len(data.loc[keepers]) < len(data):
      print(f"Dropped {n-n_keep} observations that had missing data.")
      data = data.loc[keepers]
    # make balanced data set
    n_old = len(data[idname].unique())
    data = makeBalancedPanel(data, idname=idname, tname=tname)
    n = len(data[idname].unique())
    if len(data) == 0:
      raise ValueError("All observations dropped to convert data to balanced panel. Consider setting `panel=False` and/or revisit 'idname'.")
    if n < n_old:
      warnings.warn(f"Dropped {n_old-n} observations while converting to balanced panel.")
    # In a balanced panel, n is the number of rows in the first period.
    tn = tlist[0]
    n = len(data.query(f'{tname} == @tn'))


In [160]:
# add rowid: outside the balanced-panel path, drop incomplete rows and give
# every row a 'rowid' column.
if not panel:
  keepers = data.dropna().index
  if len(keepers) == 0:
    raise ValueError("All observations dropped due to missing data problems.")
  # Number of rows removed (positive), reported to the user.
  ndiff = len(data) - len(keepers)
  if ndiff > 0:
    mssg = f"Dropped {ndiff} observations that had missing data."
    print(mssg)
    data = data.loc[keepers]
  if true_rep_cross_section:
    # Repeated cross sections: each row is its own unit.
    data = data.assign(rowid=range(len(data)))
    idname = 'rowid'
  else:
    # Unbalanced panel: rowid mirrors the unit id.
    data = data.assign(rowid=np.array(data[idname]))

  n = len(data[idname].unique())

data = data.sort_values([idname, tname])


In [161]:
# Raising a plain string is a TypeError in Python 3; raise a real exception.
if len(glist) == 0:
  raise ValueError(f"No valid groups. The variable in '{gname}' should be expressed as the time a unit is first treated (0 if never-treated).")
# NOTE(review): presumably disables uniform confidence bands with only two
# periods; cband is not defined in the cells shown for other cases. TODO confirm.
if len(tlist) == 2:
  cband = False

In [162]:
# Approximate group sizes: rows per group divided by the number of periods
# (in a balanced panel this is the number of units per group).
gsize = data.groupby(data[gname]).size().reset_index(name="count")
gsize["count"] /= len(tlist)

# Flag groups smaller than the number of covariate columns plus 5.
reqsize = n_cov + 5
gsize = gsize[gsize["count"] < reqsize]

if len(gsize) > 0:
  gpaste = ",".join(map(str, gsize[gname]))
  warnings.warn(f"Be aware that there are some small groups in your dataset.\n  Check groups: {gpaste}.")

  if 0 in gsize[gname].tolist() and control_group == "nevertreated":
    # Raising a plain string is a TypeError in Python 3; raise a real exception.
    raise ValueError("Never-treated group is too small, try setting control_group='notyettreated'.")
nT, nG = map(len, [tlist, glist])

In [163]:
# Collect the pre-processed settings and data into one dict, then display it.
did_params = dict(
  yname=yname, tname=tname,
  idname=idname, gname=gname,
  xformla=xformla, data=data,
  tlist=tlist, glist=glist,
  n=n, nG=nG, nT=nT,
  control_group=control_group, anticipation=anticipation,
  weights_name=weights_name, panel=panel,
  true_rep_cross_section=true_rep_cross_section,
)
did_params

{'yname': 'Y',
 'tname': 'period',
 'idname': 'id',
 'gname': 'G',
 'xformla': 'Y~X',
 'data':       id  period         Y  G  Intercept         X    w  rowid
 0      1       1  3.248701  4        1.0 -0.876233  1.0      1
 1      1       2  2.266837  4        1.0 -0.876233  1.0      1
 2      1       3  3.990885  4        1.0 -0.876233  1.0      1
 3      1       4  4.653489  4        1.0 -0.876233  1.0      1
 4      2       1  5.880457  4        1.0  0.974974  1.0      2
 ..   ...     ...       ... ..        ...       ...  ...    ...
 315   99       4  4.843891  0        1.0  0.055808  1.0     99
 316  100       1 -1.069496  0        1.0  0.666871  1.0    100
 317  100       2  1.359492  0        1.0  0.666871  1.0    100
 318  100       3  3.421405  0        1.0  0.666871  1.0    100
 319  100       4  6.059621  0        1.0  0.666871  1.0    100
 
 [320 rows x 8 columns],
 'tlist': array([1, 2, 3, 4], dtype=int64),
 'glist': array([2, 3, 4], dtype=int64),
 'n': 80,
 'nG': 3,
 'nT':