In [8]:
import wandb
from condgen.models.score_matching import ConditionalScoreMatcher
from condgen.data_utils.data_utils_MNIST import MNISTDataModule
from condgen.data_utils.data_utils_cf_traj import SimpleTrajDataModule
from condgen.models import samplers
from condgen.models.CFGAN import CFGAN
import pytorch_lightning as pl
import torch
import numpy as np
import os

from torchvision.utils import make_grid
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import tqdm 
import pandas as pd
import torch.nn as nn

In [9]:
#gpu = 0
#api = wandb.Api()
#run = api.run(f"edebrouwer/counterfactuals/ce1wo23n")


## Functions

In [19]:
class SC(nn.Module):
    """Synthetic-control combiner: a learned convex combination over the last axis.

    One unconstrained logit is kept per training (donor) unit. The logits are
    passed through a softmax over the donor axis, so the effective weights are
    positive and sum to one.
    """

    def __init__(self,Ntrain):
        """Create ``Ntrain`` logits initialised to zero (i.e. uniform weights)."""
        super().__init__()
        initial_logits = torch.zeros(Ntrain,1)
        self.weights = nn.Parameter(initial_logits)

    def forward(self,X):
        """Return ``X @ softmax(weights)``.

        The last axis of ``X`` is contracted against the ``(Ntrain, 1)`` weight
        column, so the result keeps a trailing axis of size one.
        """
        # Normalise across donors (dim 0), not across the singleton column.
        convex_weights = torch.softmax(self.weights, dim=0)
        return X @ convex_weights
    
def permute_last(X):
    """Move the leading axis of ``X`` to the last position.

    All other axes keep their relative order, e.g. ``(N, A, B) -> (A, B, N)``.
    This puts the first axis last so that it can be contracted by a trailing
    matrix product (see ``SC.forward``).

    Works for tensors of any rank >= 1 (the previous implementation returned
    ``None`` for every rank other than 3 or 4, which only failed later with an
    unrelated-looking error).

    Args:
        X: a ``torch.Tensor`` with at least one dimension.

    Returns:
        A permuted view of ``X`` with the first axis moved to the end.

    Raises:
        ValueError: if ``X`` is a 0-dimensional tensor.
    """
    ndim = len(X.shape)
    if ndim == 0:
        raise ValueError("permute_last expects a tensor with at least one dimension")
    # Axes 1..ndim-1 first, then the original leading axis.
    return X.permute(*range(1, ndim), 0)
    
def train(Xtrain, xtest, mod, epochs = 50, lr = 0.1):
    """Fit ``mod`` so that its combination of ``Xtrain`` reproduces ``xtest``.

    Runs ``epochs`` full-batch Adam steps on the mean squared error between
    ``mod(permute_last(Xtrain))[..., 0]`` and ``xtest``.

    Args:
        Xtrain: stacked training units, unit axis first.
        xtest: the target the combination should match.
        mod: module to optimise in place (e.g. an ``SC`` instance).
        epochs: number of optimisation steps.
        lr: Adam learning rate (previously hard-coded to 0.1).

    Returns:
        ``(mod, loss_history)`` where ``loss_history`` is a list with one
        detached CPU loss tensor per step.
    """
    optimizer = torch.optim.Adam(mod.parameters(), lr = lr)
    # The permuted view does not depend on the parameters, so build it once
    # instead of once per epoch.
    Xtrain_last = permute_last(Xtrain)
    loss_history = []
    for _ in range(epochs):
        optimizer.zero_grad()
        x_pred = mod(Xtrain_last)[...,0]
        loss = (x_pred - xtest).pow(2).mean()
        loss.backward()
        optimizer.step()
        loss_history.append(loss.detach().cpu())
    return mod, loss_history

def pred(mod,Ytrain):
    """Apply the fitted module ``mod`` to ``Ytrain``.

    ``Ytrain`` is first rearranged with ``permute_last`` (leading axis moved to
    the end); the trailing singleton axis of the module output is then dropped.
    """
    outcomes_last = permute_last(Ytrain)
    combined = mod(outcomes_last)
    return combined[...,0]

def _fit_and_predict_sc(Xtrain, Ytrain, Ttrain, x_target, t_target, tol = 0.01):
    """Fit a synthetic control for ``x_target`` on donors treated like ``t_target``.

    Donors are the training rows whose squared treatment difference to
    ``t_target`` is below ``tol``. An ``SC`` module is trained to reproduce
    ``x_target`` from the donors' ``X`` and is then applied to the donors' ``Y``.
    """
    donor_mask = (Ttrain - t_target) ** 2 < tol
    X_donors = Xtrain[donor_mask]
    Y_donors = Ytrain[donor_mask]
    # NOTE(review): an empty donor pool is not handled explicitly here.
    sc_mod = SC(X_donors.shape[0])
    fitted, _ = train(X_donors, x_target, sc_mod)
    return pred(fitted, Y_donors)


def evaluate_sc(model_cls,dataset_cls, run, config = None):
    """Evaluate the synthetic-control baseline on the counterfactual test set.

    The checkpoint attached to ``run`` is downloaded and loaded with
    ``model_cls`` only to recover the hyper-parameters, which are then used to
    rebuild the data module ``dataset_cls(**hparams)``. For every counterfactual
    test sample a fresh ``SC`` model is fitted on the training units that
    received (approximately) the counterfactual treatment.

    Args:
        model_cls: class exposing ``load_from_checkpoint``.
        dataset_cls: data-module class, instantiated from the model hparams.
        run: wandb run exposing ``files()`` and ``file(name).download(...)``.
        config: optional dict. If ``config["ite_mode"]`` is truthy, the keys
            ``"treatment2"`` and ``"treatment3"`` are required and the error is
            computed on the difference between two counterfactual outcomes.
            ``None`` or a dict without ``"ite_mode"`` selects the plain mode.

    Returns:
        The mean over test samples of the per-sample mean squared error.

    Raises:
        IndexError: if the run has no file whose name contains ``"ckpt"``.
        KeyError: if ``ite_mode`` is enabled but a treatment key is missing.
    """
    # `config` defaults to None and not every experiment config defines
    # "ite_mode", so never subscript it blindly.
    config = {} if config is None else config
    ite_mode = bool(config.get("ite_mode", False))
    if ite_mode:
        treatment2 = config["treatment2"]
        treatment3 = config["treatment3"]
    else:
        treatment2 = None
        treatment3 = None

    ckpt_names = [f.name for f in run.files() if "ckpt" in f.name]
    if not ckpt_names:
        raise IndexError("no checkpoint ('ckpt') file found among the files of the given run")
    fname = ckpt_names[0]
    run.file(fname).download(replace = True, root = ".")
    try:
        model = model_cls.load_from_checkpoint(fname)
    finally:
        # Always clean up the downloaded checkpoint, even if loading fails.
        os.remove(fname)

    hparams = dict(model.hparams)
    dataset = dataset_cls(**hparams)
    dataset.prepare_data(ite_mode = ite_mode, treatment2 = treatment2, treatment3 = treatment3)

    cf_dl = dataset.test_cf_dataloader()
    train_dl = dataset.train_dataloader()

    # Materialise the training set: only X, Y and T of each batch are needed.
    Xtrain, Ytrain, Ttrain = [], [], []
    for X, Y, _, _, T in train_dl:
        Xtrain.append(X)
        Ytrain.append(Y)
        Ttrain.append(T)
    Xtrain = torch.cat(Xtrain)
    Ytrain = torch.cat(Ytrain)
    Ttrain = torch.cat(Ttrain)

    # Materialise the counterfactual test set. The observed outcome/treatment
    # entries of each batch are not used by this baseline and are skipped.
    XCF, YnewCF, TnewCF = [], [], []
    Ynew2CF, Tnew2CF = [], []
    for b in cf_dl:
        if ite_mode:
            Xcf, _, _, _, Tnew, Ynew, Tnew2, Ynew2 = b
            Ynew2CF.append(Ynew2)
            Tnew2CF.append(Tnew2)
        else:
            Xcf, _, _, _, Tnew, Ynew = b
        XCF.append(Xcf)
        YnewCF.append(Ynew)
        TnewCF.append(Tnew)

    XCF = torch.cat(XCF)
    YnewCF = torch.cat(YnewCF)
    TnewCF = torch.cat(TnewCF)
    if ite_mode:
        # These lists are only filled in ITE mode; torch.cat rejects an empty
        # list, so concatenating them unconditionally broke the plain mode.
        Ynew2CF = torch.cat(Ynew2CF)
        Tnew2CF = torch.cat(Tnew2CF)

    mses = []
    for idx in tqdm.tqdm(range(XCF.shape[0])):
        xcf = XCF[idx]
        ycf = YnewCF[idx]
        ypred = _fit_and_predict_sc(Xtrain, Ytrain, Ttrain, xcf, TnewCF[idx])

        if ite_mode:
            ycf2 = Ynew2CF[idx]
            ypred2 = _fit_and_predict_sc(Xtrain, Ytrain, Ttrain, xcf, Tnew2CF[idx])
            # Error on the treatment effect: predicted vs. true difference
            # between the two counterfactual outcomes.
            mse = ((ypred2-ypred)-(ycf2-ycf)).pow(2).mean()
        else:
            mse = (ypred-ycf).pow(2).mean()
        mses.append(mse.detach().cpu())

    mse = np.array(mses).mean()
    return mse

## Experiments

In [20]:
# Experiment definitions. Each entry names the wandb sweep(s) to read runs
# from, the model class used to load the checkpoint, the data module to
# rebuild, and the labels written to the results table.
# "ite_mode" is read by evaluate_sc; it is spelled out for every config so
# that the lookup never depends on a missing key.
# NOTE(review): "model_name" is "Synthetic Controls" for MNIST but "CFGAN" for
# the other two configs - confirm which label the results table should carry.
config_MNIST = [{
    "sweep_id": ["cchpo2kd"],
    "model_cls": CFGAN,
    "data_cls": MNISTDataModule,
    "model_name": "Synthetic Controls",
    "data_name": "MNIST",
    "ite_mode": False,
    "config_name": "Synthetic Controls MNIST",
}]

config_CV = [{
    "sweep_id": ["qw11zu2e"],
    "model_cls": CFGAN,
    "data_cls": SimpleTrajDataModule,
    "fold_name": "random_seed",
    "model_name": "CFGAN",
    "data_name": "CV",
    "ite_mode": True,
    "treatment2": 0.5,
    "treatment3": 0.8,
    "config_name": "Synthetic Controls CV",
}]

config_Traj = [{
    "sweep_id": ["yavzrkz7"],
    "model_cls": CFGAN,
    "data_cls": SimpleTrajDataModule,
    "fold_name": "random_seed",
    "model_name": "CFGAN",
    "data_name": "SimpleTraj",
    "ite_mode": False,
    # Was "Synthetic Controls MNIST" (copy-paste from config_MNIST), which
    # mislabelled the SimpleTraj row in the results table.
    "config_name": "Synthetic Controls SimpleTraj",
}]

#configs = config_MNIST + config_CV + config_Traj
configs = config_CV

In [21]:
# For every experiment config: pick, per fold, the sweep run with the lowest
# "restored_val_loss", evaluate the synthetic-control baseline on it, and
# report mean +/- std of the MSE across folds as one row of `df`.
fold_name = "random_seed"  # default run-config key holding the fold id
api = wandb.Api()

records = []
for config in configs:

    pd_dict = {"Model":config["model_name"],"Data":config["data_name"], "Name":config["config_name"]}
    model_cls = config["model_cls"]
    # Honour a per-config fold key when given, else fall back to the default.
    config_fold_name = config.get("fold_name", fold_name)

    # Gathering runs from sweeps -----
    sweep_runs = []
    for sweep_name in config["sweep_id"]:
        sweep = api.sweep(f"edebrouwer/counterfactuals/{sweep_name}")
        sweep_runs.extend(sweep.runs)

    best_runs = []
    for fold in [421,422,423,424,425]:

        runs_fold = [r for r in sweep_runs if (r.config.get(config_fold_name)==fold) and (r.config.get("data_type")==config["data_name"])]
        for group_key, group_value in config.get("groups", {}).items():
            runs_fold = [r for r in runs_fold if (r.config.get(group_key)==group_value)]

        # Runs without the metric would make the comparison fail on None.
        runs_scored = [r for r in runs_fold if r.summary.get("restored_val_loss") is not None]
        if not runs_scored:
            raise ValueError(
                f"No run with a 'restored_val_loss' for fold {fold} of config '{config['config_name']}'"
            )
        best_runs.append(min(runs_scored, key = lambda run: run.summary.get("restored_val_loss")))

    mses = []
    for run in best_runs:
        mse = evaluate_sc(run = run, model_cls = model_cls, dataset_cls = config["data_cls"], config = config)
        mses.append(mse)

    mses = np.array(mses)
    mse_mu = mses.mean()
    mse_std = mses.std()

    # Raw string: "\p" is not a valid escape sequence in a normal literal.
    mse_str = "$ " + str(mse_mu.round(3)) + r"\pm" + str(mse_std.round(3)) + " $"
    pd_dict["MSE"] = mse_str

    records.append(pd_dict)

# Build the table once; DataFrame.append was removed in pandas 2.0.
df = pd.DataFrame(records)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:10<00:00, 14.24it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:23<00:00, 11.91it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:22<00:00, 12.08it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:23<00:00, 11.92it/s]
100%|███████████████████████████████████████

In [22]:
# Plain-text dump of the results table (one row per evaluated config).
print(df)

   Model Data                   Name                MSE
0  CFGAN   CV  Synthetic Controls CV  $ 0.258\pm0.016 $


In [None]:
# Show the results table as this cell's output.
df