In [1]:
using Statistics, CSV, DataFrames

In [2]:
# Load project helpers; per this cell's output, data.jl defines a generic function `data`.
include("data.jl")

data (generic function with 1 method)

In [3]:
# Location of the CPS panel. Set the CPS_CSV_PATH environment variable to run the
# notebook on another machine; the original author's path is kept as the default.
cps_path = get(ENV, "CPS_CSV_PATH", "C:/Users/User/Documents/GitHub/synthdid/data/CPS.csv")
df = CSV.read(cps_path, DataFrame);

In [4]:
"""
    contract3(X, v = nothing)

Contract the third axis of the 3-d array `X` with the weight vector `v`: return the
`size(X, 1) × size(X, 2)` matrix `sum(v[k] * X[:, :, k] for k in eachindex(v))`.

When `v` is `nothing` (the default) or empty, a zero matrix of that size is returned.
No dimension check is performed: `v` must not have more entries than `size(X, 3)`.
"""
function contract3(X, v = nothing)
    out = zeros(size(X, 1), size(X, 2))
    isnothing(v) && return out
    for k in eachindex(v)
        # @views avoids copying each slice X[:, :, k] before accumulating it.
        @views out .+= v[k] .* X[:, :, k]
    end
    return out
end;

In [187]:
"""
    fw_step(A, x, b, eta, alpha = nothing)

Take one Frank–Wolfe step from `x` for the quadratic objective
`0.5 * (sum((A * x - b) .^ 2) + eta * sum(x .^ 2))`.

The step moves toward the vertex `e_i`, where `i` is the index of the smallest
gradient entry.
- If `alpha` is given, the result is the fixed-size step `(1 - alpha) * x + alpha * e_i`.
- Otherwise an exact line search is used, with the step clipped to `[0, 1]`. The result
  is therefore a convex combination of `x` and `e_i`.
- If `x` is already `e_i`, or the objective is flat along the search direction, `x`
  is returned unchanged.

Returns the updated weight vector.
"""
function fw_step(A, x, b, eta, alpha = nothing)
    Ax = A * x
    # Gradient of the objective above, as a column vector.
    half_grad = transpose(A) * (Ax .- b) .+ eta .* x
    i = argmin(half_grad)
    if !isnothing(alpha)
        x = x .* (1 - alpha)
        x[i] += alpha
        return x
    end
    # Search direction d_x = e_i - x.
    d_x = -x
    d_x[i] = 1 - x[i]
    all(iszero, d_x) && return x
    # A * d_x equals A[:, i] - A * x because d_x = e_i - x.
    d_err = A[:, i] .- Ax
    step_bot = sum(abs2, d_err) + eta * sum(abs2, d_x)
    # Zero curvature (with eta >= 0 this forces d_err == 0) means the numerator is 0
    # too: the objective is flat along d_x and the step would be 0/0 = NaN.
    iszero(step_bot) && return x
    step = -sum(half_grad .* d_x) / step_bot
    return x + clamp(step, 0, 1) * d_x
end;

In [None]:
# NOTE: this cell held what appears to be the R signature of `sc.weight.fw.covariates`
# (R syntax: `zeta.lambda`, `TRUE`, `NULL`, `c(...)`). It is not valid Julia and it
# left an unterminated `function`, so it is kept here only as a reference comment.
# The Julia port of the body is written as top-level statements in cell In [461].
#
# function sc_weight_fw_covariates(Y, X = array(0, dim = c(dim(Y), 0)), zeta.lambda = 0, zeta.omega = 0,
#                                 lambda.intercept =   TRUE, omega.intercept = TRUE,
#                                 min.decrease = 1e-3, max.iter = 1000,
#                                 lambda = NULL, omega = NULL, beta = NULL, update.lambda = TRUE, update.omega = TRUE)


In [460]:
# Location of the CPS panel. Set the CPS_CSV_PATH environment variable to run the
# notebook on another machine; the original author's path is kept as the default.
cps_path = get(ENV, "CPS_CSV_PATH", "C:/Users/User/Documents/GitHub/synthdid/data/CPS.csv")
df = CSV.read(cps_path, DataFrame);

In [461]:
Y = df[:,3:5]
# Covariate array: a single all-zero slice of Y's size, i.e. no effective covariates.
X = cat(zeros(size(Y)); dims = ndims(Y)+1);
zeta_lambda = 0; zeta_omega = 0;
lambda_intercept = true; omega_intercept = true;
min_decrease = 1e-3; max_iter = 1000;
lambda = nothing; omega = nothing; beta = nothing; update_lambda = true; update_omega = true;

# The last column and last row of Y are the fit targets in update_weights;
# columns 1:T0 and rows 1:N0 are the ones that receive weights.
T0 = size(Y, 2) - 1
N0 = size(Y, 1) - 1
# Promote a 2-d covariate matrix to a single-slice 3-d array (the result must be
# assigned back, otherwise the promotion is lost).
if ndims(X) == 2
    X = cat(X; dims = ndims(X)+1)
end
# Default to uniform weights.
if isnothing(lambda)
    lambda = fill(1 / T0, T0)
end
if isnothing(omega)
    omega = fill(1 / N0, N0)
end;
# One coefficient per covariate slice of X, as contract3(X, beta) expects.
if isnothing(beta)
    beta = zeros(size(X, 3))
end

"""
    update_weights1(val, lambda, omega, err_lambda, err_omega)

Result container returned by `update_weights`. All fields are untyped; the
positional constructor takes them in the order listed below.

- `val`: objective value computed after the update
- `lambda`: weights on columns 1:T0 of the outcome matrix
- `omega`: weights on rows 1:N0 of the outcome matrix
- `err_lambda`: residuals `Y_lambda * vcat(lambda, -1)`
- `err_omega`: residuals `Y_omega * vcat(omega, -1)`
"""
mutable struct update_weights1
    val
    lambda
    omega
    err_lambda
    err_omega
end

"""
    update_weights(Y, lambda, omega)

Run one `fw_step` update on each of the weight vectors `lambda` and `omega` for the
outcome matrix `Y`. Return an `update_weights1` holding the objective value, the
updated weights and both residual vectors.

`Y` is treated as an `(N0 + 1) × (T0 + 1)` matrix: the last column is the target of
the `lambda` fit and the last row is the target of the `omega` fit. A `DataFrame` is
accepted and converted with `Matrix`.

Behaviour is controlled by the notebook globals `lambda_intercept`, `omega_intercept`,
`update_lambda`, `update_omega`, `zeta_lambda` and `zeta_omega`.
"""
function update_weights(Y, lambda, omega)
    # Work on a plain matrix so `fw_step`'s `A * x` also works when no centering is done.
    Ymat = Matrix(Y)
    N0 = size(Ymat, 1) - 1
    T0 = size(Ymat, 2) - 1

    # lambda problem: N0 rows; columns 1:T0 are regressors, column T0 + 1 is the target.
    Y_lambda = Ymat[1:N0, :]
    if lambda_intercept
        Y_lambda = Y_lambda .- mean(Y_lambda; dims = 1)
    end
    if update_lambda
        lambda = fw_step(Y_lambda[:, 1:T0], lambda, Y_lambda[:, T0 + 1],
                         N0 * real(zeta_lambda^2))
    end
    err_lambda = Y_lambda * vcat(lambda, -1)

    # omega problem: T0 rows; columns 1:N0 are regressors, column N0 + 1 is the target.
    Y_omega = permutedims(Ymat[:, 1:T0])
    if omega_intercept
        Y_omega = Y_omega .- mean(Y_omega; dims = 1)
    end
    if update_omega
        # This design has T0 rows, so the ridge penalty scales with T0, matching the
        # `sum(err_omega .^ 2) / T0` term of `val` below.
        omega = fw_step(Y_omega[:, 1:N0], omega, Y_omega[:, N0 + 1],
                        T0 * real(zeta_omega^2))
    end
    err_omega = Y_omega * vcat(omega, -1)

    val = real(zeta_omega^2) * sum(omega .^ 2) + real(zeta_lambda^2) * sum(lambda .^ 2) +
          sum(err_omega .^ 2) / T0 + sum(err_lambda .^ 2) / N0
    return update_weights1(val, lambda, omega, err_lambda, err_omega)
end

# Per-iteration objective values, compared against `min_decrease^2` in the stopping
# rule. The vector is Float64 so real-valued objectives can be stored; an Int vector
# would throw InexactError on assignment.
vals = zeros(max_iter);
t = 0
# Outcome net of the covariate contribution.
Y_beta = Y .- contract3(X, beta);
weights = update_weights(Y_beta, lambda, omega);

In [462]:
# Keep iterating while below the iteration cap and either fewer than two objective
# values exist or the last decrease still exceeds min_decrease^2.
(t < max_iter) && (t <= 1 || (vals[t - 1] - vals[t]) > min_decrease^2)

true

In [463]:
# Display the covariate array to inspect its shape and contents.
X

2000×3×1 Array{Float64, 3}:
[:, :, 1] =
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 ⋮         
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0