In [164]:
import warnings

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import polars as pl
from scipy.optimize import minimize
from sklearn.linear_model import LinearRegression

In [47]:
# Toy example: the treated unit's untreated outcome is an exact convex
# combination of N control units; after period T0 it is replaced by draws
# centred at TREATED_MEAN, so the dashed counterfactual shows the effect.
N = 10  # number of control units
T = 10  # number of observed periods
T0 = 5  # last pre-intervention period
TREATED_MEAN = 5  # mean of the treated unit's post-intervention outcome

# Controls: Y_control[i, t] ~ N(0, 1) independently.
# Untreated outcome of the treated unit: Y(0) = weight @ Y_control.
# Observed post-intervention outcome: Y(1) ~ N(TREATED_MEAN, 1).
np.random.seed(20240209)
weight = np.array(
    [0.4, 0.0, 0.3, 0.0, 0.0, 0.0, 0.3, 0.0, 0.0, 0.0]
)
Y_control = np.random.normal(0, 1, size=(N, T))
Y0 = weight @ Y_control
Y0_observed = np.concatenate(
    [
        Y0[:T0], np.random.normal(TREATED_MEAN, 1, T - T0)
    ]
)
Y0_counterfactual = Y0

periods = np.arange(1, T + 1)

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=periods,
        y=Y0_observed,
        mode='lines',
        name='observed'
    )
)
# Start the counterfactual at period T0 so the dashed line joins the observed path.
fig.add_trace(
    go.Scatter(
        x=np.arange(T0, T + 1),
        y=Y0_counterfactual[T0 - 1:],
        name='counterfactual',
        line=dict(
            dash='dash'
        ),
        mode='lines'
    )
)
for i in range(N):
    fig.add_trace(
        go.Scatter(
            x=periods,
            y=Y_control[i, :],
            name='control units',
            # A single legend entry (on the first line) toggles every control line.
            legendgroup='control units',
            showlegend=(i == 0),
            opacity=0.3,
            line=dict(
                color='gray'
            ),
            mode='lines',
        )
    )
fig.add_vline(
    x=T0
)
fig.update_layout(
    xaxis_title='time',
    yaxis_title='outcome'
)
fig.show()

In [192]:
# Load the smoking panel from the Stata file and convert it to polars.
smoking_raw = pl.from_pandas(pd.read_stata('./data/smoking.dta'))
# Wide view for inspection only: one row per year and one column per
# (variable, state) pair. This cell does not split treatment/control groups;
# the next cell reloads the long panel for the analysis.
smoking_wide = smoking_raw.select(
    "state", "year", "cigsale", "retprice"
).pivot(
    values=['cigsale', 'retprice'],
    index='year',
    columns='state'
)
smoking_wide

year,cigsale_state_Alabama,cigsale_state_Arkansas,cigsale_state_California,cigsale_state_Colorado,cigsale_state_Connecticut,cigsale_state_Delaware,cigsale_state_Georgia,cigsale_state_Idaho,cigsale_state_Illinois,cigsale_state_Indiana,cigsale_state_Iowa,cigsale_state_Kansas,cigsale_state_Kentucky,cigsale_state_Louisiana,cigsale_state_Maine,cigsale_state_Minnesota,cigsale_state_Mississippi,cigsale_state_Missouri,cigsale_state_Montana,cigsale_state_Nebraska,cigsale_state_Nevada,cigsale_state_New Hampshire,cigsale_state_New Mexico,cigsale_state_North Carolina,cigsale_state_North Dakota,cigsale_state_Ohio,cigsale_state_Oklahoma,cigsale_state_Pennsylvania,cigsale_state_Rhode Island,cigsale_state_South Carolina,cigsale_state_South Dakota,cigsale_state_Tennessee,cigsale_state_Texas,cigsale_state_Utah,cigsale_state_Vermont,cigsale_state_Virginia,…,retprice_state_California,retprice_state_Colorado,retprice_state_Connecticut,retprice_state_Delaware,retprice_state_Georgia,retprice_state_Idaho,retprice_state_Illinois,retprice_state_Indiana,retprice_state_Iowa,retprice_state_Kansas,retprice_state_Kentucky,retprice_state_Louisiana,retprice_state_Maine,retprice_state_Minnesota,retprice_state_Mississippi,retprice_state_Missouri,retprice_state_Montana,retprice_state_Nebraska,retprice_state_Nevada,retprice_state_New Hampshire,retprice_state_New Mexico,retprice_state_North Carolina,retprice_state_North Dakota,retprice_state_Ohio,retprice_state_Oklahoma,retprice_state_Pennsylvania,retprice_state_Rhode Island,retprice_state_South Carolina,retprice_state_South Dakota,retprice_state_Tennessee,retprice_state_Texas,retprice_state_Utah,retprice_state_Vermont,retprice_state_Virginia,retprice_state_West Virginia,retprice_state_Wisconsin,retprice_state_Wyoming
f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
1970.0,89.800003,100.300003,123.0,124.800003,120.0,155.0,109.900002,102.400002,124.800003,134.600006,108.5,114.0,155.800003,115.900002,128.5,104.300003,93.400002,121.300003,111.199997,108.099998,189.5,265.700012,90.0,172.399994,93.800003,121.599998,108.400002,107.300003,123.900002,103.599998,92.699997,99.800003,106.400002,65.5,122.599998,124.300003,…,38.799999,29.4,42.200001,39.0,34.299999,33.799999,41.400002,30.6,37.700001,34.200001,28.299999,34.299999,38.0,39.099998,36.200001,36.0,34.0,33.900002,38.900002,31.4,39.700001,27.299999,37.299999,36.599998,38.400002,38.400002,39.299999,32.5,38.5,39.900002,40.400002,34.599998,37.700001,28.799999,33.700001,38.5,34.099998
1971.0,95.400002,104.099998,121.0,125.5,117.599998,161.100006,115.699997,108.5,125.599998,139.300003,108.400002,102.800003,163.5,119.800003,133.199997,116.400002,105.400002,127.599998,115.599998,108.599998,190.5,278.0,92.599998,187.600006,98.5,124.599998,115.400002,106.300003,123.199997,115.0,96.699997,106.300003,108.900002,67.699997,124.400002,128.399994,…,39.700001,31.1,45.5,41.299999,35.799999,33.599998,41.400002,32.200001,38.5,38.900002,30.1,39.299999,38.799999,40.099998,37.5,36.799999,34.700001,34.700001,44.0,34.099998,41.700001,29.4,38.900002,38.099998,39.799999,44.700001,40.200001,34.299999,38.5,41.599998,42.0,36.599998,39.5,30.200001,41.599998,40.200001,34.400002
1972.0,101.099998,103.900002,123.5,134.300003,110.800003,156.300003,117.0,126.099998,126.599998,149.199997,109.400002,111.0,179.399994,125.300003,136.5,96.800003,112.099998,130.0,122.199997,104.900002,198.600006,296.200012,99.300003,214.100006,103.800003,124.400002,121.699997,109.0,134.399994,118.699997,103.0,111.5,108.599998,71.300003,138.0,137.0,…,39.900002,31.200001,51.299999,44.700001,40.900002,33.700001,41.900002,32.5,41.900002,38.799999,30.6,40.0,41.5,45.200001,37.400002,37.700001,40.099998,41.099998,40.599998,36.099998,41.099998,28.700001,38.900002,38.400002,39.799999,44.700001,41.599998,34.099998,39.099998,41.599998,46.900002,37.200001,40.0,29.9,41.299999,40.299999,34.400002
1973.0,102.900002,108.0,124.400002,137.899994,109.300003,154.699997,119.800003,121.800003,124.400002,156.0,110.599998,115.199997,201.899994,126.699997,138.0,106.800003,115.0,132.100006,119.900002,106.599998,201.5,279.0,98.900002,226.5,108.699997,120.5,124.099998,110.699997,142.0,125.5,103.5,109.699997,110.400002,72.699997,146.800003,143.100006,…,39.900002,32.700001,50.599998,44.0,42.400002,36.299999,41.0,32.900002,41.900002,39.299999,30.6,39.900002,41.0,45.599998,37.299999,37.700001,40.900002,41.200001,40.299999,36.900002,41.799999,28.9,39.400002,42.0,40.400002,44.900002,40.599998,33.5,39.599998,40.799999,46.400002,36.5,39.799999,30.1,39.900002,42.599998,34.400002
1974.0,108.199997,109.699997,126.699997,132.800003,112.400002,151.300003,123.699997,125.599998,131.899994,159.600006,116.099998,118.599998,212.399994,129.899994,142.100006,110.599998,117.099998,135.399994,121.900002,110.5,204.699997,269.799988,100.300003,227.300003,110.5,122.099998,130.5,114.199997,146.100006,129.699997,108.400002,114.800003,114.699997,75.599998,151.800003,149.600006,…,41.900002,38.099998,52.5,44.200001,42.400002,38.0,41.900002,34.5,43.200001,40.200001,31.5,41.599998,41.799999,47.0,41.400002,38.0,41.799999,42.0,41.900002,37.900002,43.700001,30.1,39.900002,42.900002,41.0,46.599998,41.299999,35.200001,40.400002,42.5,47.5,37.799999,41.299999,31.299999,42.0,43.900002,35.799999
1975.0,111.699997,114.800003,127.099998,131.0,110.199997,147.600006,122.900002,123.300003,131.800003,162.399994,120.5,123.400002,223.0,133.600006,140.699997,111.5,116.800003,135.600006,123.699997,114.099998,205.199997,269.100006,103.099998,226.0,117.900002,122.5,132.899994,114.599998,154.699997,130.5,113.5,117.400002,116.0,75.800003,155.5,152.699997,…,45.0,41.700001,54.5,45.900002,44.5,40.299999,45.200001,36.700001,45.400002,42.700001,33.299999,44.299999,46.700001,49.400002,43.0,43.5,43.700001,44.599998,44.5,40.799999,46.299999,32.900002,42.599998,46.0,43.599998,49.799999,44.299999,38.099998,42.799999,45.299999,50.599998,40.5,41.799999,33.599998,45.200001,46.599998,38.599998
1976.0,116.199997,119.099998,128.0,134.199997,113.400002,153.0,125.900002,125.099998,134.399994,166.600006,124.400002,127.699997,230.899994,139.600006,144.899994,116.699997,120.900002,139.5,124.900002,118.099998,201.399994,290.5,102.400002,230.199997,125.400002,124.599998,138.600006,118.800003,150.199997,136.800003,116.699997,121.699997,121.400002,77.900002,171.100006,158.100006,…,48.299999,44.799999,57.599998,50.099998,47.900002,42.5,48.400002,38.700001,47.799999,46.599998,36.0,48.099998,49.900002,52.099998,46.400002,44.700001,45.299999,46.799999,44.900002,43.900002,49.5,35.799999,45.900002,48.5,46.400002,52.299999,52.200001,41.0,45.0,48.299999,53.299999,43.400002,47.099998,37.900002,48.400002,51.299999,42.599998
1977.0,117.099998,122.599998,126.400002,132.0,117.300003,153.300003,127.900002,125.0,134.0,173.0,125.5,127.900002,229.399994,140.0,145.600006,117.199997,122.099998,140.800003,127.0,117.699997,190.800003,278.799988,102.400002,217.0,122.199997,127.300003,140.399994,120.099998,148.800003,137.199997,115.599998,124.599998,124.199997,78.0,169.399994,157.699997,…,49.0,44.700001,58.400002,51.700001,49.5,45.599998,49.400002,40.599998,49.400002,48.099998,36.900002,48.900002,50.900002,53.099998,48.799999,45.900002,47.599998,48.099998,49.299999,45.0,51.599998,36.599998,47.400002,49.799999,47.900002,53.299999,52.299999,42.200001,46.400002,49.599998,53.299999,44.700001,47.0,38.400002,48.900002,52.099998,43.400002
1978.0,123.0,127.300003,126.099998,129.199997,117.5,155.5,130.600006,122.800003,136.699997,150.899994,127.099998,127.099998,224.699997,142.699997,143.899994,118.900002,124.900002,141.800003,127.199997,117.400002,187.0,269.600006,103.099998,205.5,121.900002,131.300003,143.600006,122.300003,146.800003,140.399994,116.900002,127.300003,126.599998,79.599998,162.399994,155.899994,…,58.700001,57.400002,61.700001,58.700001,54.700001,51.5,54.599998,50.0,54.599998,52.599998,41.400002,54.200001,55.0,57.900002,53.599998,49.900002,51.900002,53.599998,54.299999,49.700001,56.0,41.799999,53.200001,53.900002,53.099998,57.400002,56.299999,49.200001,53.200001,54.799999,59.099998,49.5,52.5,42.799999,53.900002,57.099998,49.799999
1979.0,121.400002,126.5,121.900002,131.5,117.400002,150.199997,131.0,117.5,135.300003,148.899994,124.199997,126.400002,214.899994,140.100006,138.5,118.300003,123.900002,140.199997,120.300003,116.099998,183.300003,254.600006,101.0,197.300003,121.300003,130.899994,141.600006,122.599998,145.800003,135.699997,117.400002,127.199997,126.400002,79.099998,160.899994,151.800003,…,60.099998,52.799999,64.400002,60.0,56.599998,55.400002,56.799999,52.5,56.400002,54.799999,43.400002,57.099998,54.5,60.900002,56.5,52.200001,53.700001,55.400002,57.099998,53.200001,57.599998,43.700001,55.0,56.299999,55.5,60.599998,58.700001,50.200001,54.099998,57.299999,62.200001,53.700001,54.799999,45.799999,62.400002,58.700001,51.700001


In [307]:
# Load the long panel (one row per state-year).
df = pd.read_stata('./data/smoking.dta')
df = pl.from_pandas(df)

TREATED_STATE = "California"
LAST_PRE_YEAR = 1988  # last year treated as pre-intervention

pre_period = pl.col("year") <= LAST_PRE_YEAR
T = df['year'].n_unique()  # number of observed periods
N = df['state'].n_unique() - 1  # number of control units (all states but the treated one)
T0 = df.filter(pre_period)['year'].n_unique()  # number of pre-intervention periods
year = df['year'].unique().sort().to_numpy()  # x-axis values for the plots below

df = df.select(
    "state", "year", "cigsale", "retprice"
)

# X: (2*T0) x N predictor matrix. Each column is one control state (in order of
# first appearance in df): its pre-period cigsale by year, followed by its
# pre-period retprice by year. Building it per state avoids depending on
# sort tie-order and on the coincidence N == 2*T0.
controls_pre = df.filter((pl.col("state") != TREATED_STATE) & pre_period)
control_states = controls_pre['state'].unique(maintain_order=True)
X = np.column_stack([
    controls_pre.filter(pl.col("state") == s)
    .sort('year')[['cigsale', 'retprice']]
    .to_numpy().T.reshape(-1)
    for s in control_states
])

# Y: length-2*T0 vector for the treated state, in the same row layout as X.
Y = df.filter(
    (pl.col("state") == TREATED_STATE) & pre_period
).sort('year')[['cigsale', 'retprice']].to_numpy().T.reshape(-1)

assert X.shape == (2 * T0, N), X.shape
assert Y.shape == (2 * T0,), Y.shape

print(X.shape)

(38, 38)


In [367]:
# OLS benchmark: unconstrained least-squares weights that reproduce the treated
# state's pre-period predictors Y from the control states' predictors X.
# Weights may be negative and need not sum to 1. X is square here (printed
# shape (38, 38)), so if it is full rank the pre-period fit is exact.
# lstsq solves the system directly instead of inverting X.T @ X, which squares
# the condition number and breaks when X is rank deficient.
alpha_OLS = np.linalg.lstsq(X, Y, rcond=None)[0]
print(alpha_OLS.round(3))
print(alpha_OLS.sum())

[-0.618 -0.918  0.633  0.276  0.263  0.958  0.168  0.303 -0.331 -0.099
 -0.578  0.957 -0.356  0.278 -0.03  -0.455 -0.164 -0.621 -0.175 -0.077
  0.079  0.382 -0.326  0.419  0.216 -0.125 -0.1    0.21  -0.398  0.734
  0.315 -0.058  0.504 -0.21  -0.767  0.737  0.026 -0.099]
0.9510519


In [336]:
# Observed cigarette sales for the treated state, ordered by year: shape (T,).
treatment_outcome = df.filter(
    pl.col("state") == "California"
).sort('year')['cigsale'].to_numpy()

# Control-state cigarette sales as a (T, N) matrix. There is one column per
# control state, in order of first appearance in df (intended to match the
# column order of X), and each column is sorted by year. Building it per state
# avoids relying on the raw row order of df.
control_states = df.filter(pl.col("state") != "California")['state'].unique(maintain_order=True)
control_outcome = np.column_stack([
    df.filter(pl.col("state") == s).sort('year')['cigsale'].to_numpy()
    for s in control_states
])

print(control_outcome.shape)

# OLS-weighted counterfactual for every year (pre- and post-intervention).
pred_OLS = control_outcome @ alpha_OLS

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=year,
        y=treatment_outcome,
        name='California'
    )
)

fig.add_trace(
    go.Scatter(
        x=year,
        y=pred_OLS,
        name='Synthetic California (OLS)'
    )
)


fig.add_vline(
    x=1988
)

fig.update_layout(
    template='plotly_white',
    xaxis_title='year',
    yaxis_title='cigarette sales (in packs)',
    legend=dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="right",
    x=1
    ),
    width=1200,
    height=600,
)
fig.show()

(31, 38)


In [368]:
# SCM
def loss_function(w, X, y):
    return np.mean((y - X @ w)**2)

# Synthetic-control weights: minimise the pre-period fit error subject to the
# SCM constraints w_i >= 0 and sum(w) == 1, starting from uniform weights.
cons = (
    {'type': 'eq', 'fun': lambda x: np.sum(x) - 1},
)
bounds = [(0, None)] * N

alpha0 = np.ones(N) / N
scm = minimize(
    loss_function,
    x0=alpha0,
    args=(X, Y),
    constraints=cons,
    bounds=bounds,
    method="SLSQP"
)
# minimize() does not raise when SLSQP fails; surface it rather than silently
# using the last iterate as the weights.
if not scm.success:
    warnings.warn(f"SCM optimisation did not converge: {scm.message}")
alpha_scm = scm.x
print(alpha_scm.round(4))
print(alpha_scm.sum())

[0.     0.     0.     0.0853 0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.113
 0.1051 0.4566 0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.2401 0.     0.     0.     0.     0.    ]
1.000000000008864


In [335]:
# Overlay observed California, the SCM synthetic control and the OLS fit.
pred_scm = control_outcome @ alpha_scm

fig = go.Figure(
    data=[
        go.Scatter(x=year, y=treatment_outcome, name='California'),
        go.Scatter(x=year, y=pred_scm, name='Synthetic California'),
        go.Scatter(x=year, y=pred_OLS, name='OLS', opacity=0.5),
    ]
)

fig.add_vline(x=1988)

# Horizontal legend anchored above the top-right corner of the plot.
legend_top_right = dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="right",
    x=1,
)
fig.update_layout(
    template='plotly_white',
    xaxis_title='year',
    yaxis_title='cigarette sales (in packs)',
    legend=legend_top_right,
    width=1200,
    height=600,
)

In [339]:
# Estimated treatment effect per year: observed California minus each
# synthetic counterfactual (SCM and OLS). Values before 1988 are pre-period fit gaps.
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=year,
        y=treatment_outcome - pred_scm,
        name='SCM'
    )
)

fig.add_trace(
    go.Scatter(
        x=year,
        y=treatment_outcome - pred_OLS,
        opacity=0.5,
        name='OLS'
    )
)

fig.add_vline(
    x=1988
)

# Zero-effect reference line.
fig.add_hline(
    y=0,
    opacity=0.3
)

fig.update_layout(
    template='plotly_white',
    xaxis_title='year',
    yaxis_title='Treatment Effect',
    legend=dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="right",
    x=1
    ),
    width=1200,
    height=600,
    title='Estimated Result of Treatment Effect'
)

fig.show()

In [350]:
# Compare SCM and OLS weights per control state. The labels must follow the
# same order as the weight vectors (control states in order of first appearance
# in df). Series.unique() does not guarantee any order unless
# maintain_order=True is passed.
state = df.filter(pl.col("state") != "California")['state'].unique(maintain_order=True).to_list()
fig = go.Figure()

fig.add_trace(
    go.Bar(
        x=state,
        y=alpha_scm,
        name="SCM"
    )
)

fig.add_trace(
    go.Bar(
        x=state,
        y=alpha_OLS,
        name="OLS"
    )
)
fig.update_layout(
    template='plotly_white',
    legend=dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="right",
    x=1
    ),
    xaxis_title="state",
    yaxis_title="weight",
    xaxis_tickangle=270,
    width=1200,
    height=600,
)

In [364]:
# Donor states: the states with non-zero weight in the printed alpha_scm,
# plotted against California's observed series.
state_used = ["Connecticut", "Nevada", "New Hampshire", "New Mexico", "Utah"]
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=year,
        y=df.filter(pl.col("state") == "California").sort('year')['cigsale'],
        name='California'
    )
)
for c in state_used:
    fig.add_trace(
        go.Scatter(
            x=year,
            y=df.filter(pl.col("state") == c).sort('year')['cigsale'],
            name=c,
            opacity=0.5
        )
    )

fig.update_layout(
    template='plotly_white',
    xaxis_title='year',
    yaxis_title='cigarette sales (in packs)',
    legend=dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="right",
    x=1
    ),
    width=1200,
    height=600,
)
fig.show()