In [1]:
import numpy as np
import pandas as pd
import plotly.express as px

In [19]:
# Default simulation parameters.

# Group-specific intercepts: baseline outcome level before trends are applied.
b0_treat, b0_control = 10, 40

# Group-specific slopes: change in outcome per time period.
b1_treat, b1_control = 4, 4

# Added to treated units' outcomes in post-treatment periods.
treatment_effect = 7

# Standard deviation of the Gaussian noise added to every outcome.
noise = 3

# Number of units in the simulated panel.
N = 500

# NOTE(review): not referenced in the visible cells; presumably a number of
# simulation replications -- confirm before relying on it.
R = 100

# Seed NumPy's legacy global RNG so the shuffle and noise draws are reproducible.
np.random.seed(42)

In [20]:
# simulate data

# Panel layout: every unit is observed once per period 0..n_periods-1, and
# periods from treatment_start onward count as post-treatment.
n_periods = 9
treatment_start = 4

# units and treatment assignment
units = np.arange(N)
# Integer repeat counts that always sum to N: N // 2 treated units and the
# rest control. This matches the previous split for even N; for odd N the
# extra unit goes to control so `treat` stays the same length as `units`.
n_treated = N // 2
treat = np.repeat([0, 1], [N - n_treated, n_treated])
np.random.shuffle(treat)

df = pd.DataFrame({
    'unit' : units,
    'treat' : treat
})

# add time: cross join so each unit gets one row per time period
df = df.merge(pd.DataFrame({'time_period' : np.arange(n_periods)}), how = 'cross')

# add time indicator (1 if post-treatment else 0)
df['time_indicator'] = (df['time_period'] >= treatment_start).astype(int)

# add outcomes

treat_mask = df['treat'] == 1
control_mask = df['treat'] == 0

# baseline outcomes (group-specific intercept)
df['outcome'] = np.where(treat_mask, b0_treat, b0_control)

# apply trends (group-specific slope per time period)
df.loc[treat_mask, 'outcome'] += df.loc[treat_mask, 'time_period'] * b1_treat
df.loc[control_mask, 'outcome'] += df.loc[control_mask, 'time_period'] * b1_control

# apply treatment effect (treated units in post-treatment periods only)
df['outcome'] = df['outcome'] + treatment_effect * df['treat'] * df['time_indicator']

# add noise: one normal draw per row with mean 0 and standard deviation
# `noise`; sizing by len(df) keeps the draw count tied to the panel size
df['outcome'] = df['outcome'] + np.random.normal(0, noise, len(df))


In [21]:
# Plot the simulated panel: one point per unit-period, coloured by group.

# Plotly builds a discrete colour scale and legend when the colour column
# holds strings, so relabel the 0/1 treatment flag on a copy and leave `df`
# untouched.
df_plot = df.copy()
df_plot['treat'] = np.where(df_plot['treat'] == 1, 'Treated', 'Control')

# First post-treatment period, read from the data so the marker below cannot
# drift from the `time_indicator` definition used when simulating.
first_post_period = int(df.loc[df['time_indicator'] == 1, 'time_period'].min())

# 1. Create the interactive scatter plot (no trendlines are requested)
fig = px.scatter(
    df_plot,
    x='time_period',
    y='outcome',
    color='treat',  # one colour per treatment group
    title="Simulated Difference-in-Differences Data",
    opacity = 0.4,
    labels={
        "time_period": "Time Period",
        "outcome": "Outcome",
        "treat": "Group"
    },
    color_discrete_map={ # Optional: set custom colors
        'Treated': 'red',
        'Control': 'blue'
    },
    template='plotly_white'
)

# 2. Mark the start of treatment with a dashed vertical line
fig.add_vline(
    x=first_post_period,
    line_dash="dash",
    line_color="black",
    annotation_text="Treatment Start",
    annotation_position="top right"
)

# 3. Show the figure
fig.show()

To do:
- turn the data simulation into a function
- turn the plotting into functions
- use DoWhy to estimate the difference-in-differences (DiD) effect: https://www.pywhy.org/dowhy/v0.8/example_notebooks/dowhy_simple_example.html
