In [1]:
import os
from os.path import join as opj

import numpy as np
from omegaconf import OmegaConf
from copy import deepcopy
import torch
from datetime import datetime
from einops import rearrange
from matplotlib import pyplot as plt


from cuts_main import CUTS
from utils.cuts_parts import *
from utils.gumbel_softmax import gumbel_softmax
from utils.misc import plot_causal_matrix, reproduc, prepross_data
from utils.data_interpolate import interp_multivar_data
from utils.load_data import simulate_var_from_links
from utils.logger import MyLogger


ModuleNotFoundError: No module named 'numpy'

In [None]:
# Load the experiment configuration from the YAML file next to this notebook.
opt = OmegaConf.load("./cuts_example.yaml")
# Use the GPU when one is visible to torch; otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Project helper configured from the `reproduc` section of the config
# (presumably seeds the RNGs -- confirm in utils.misc).
reproduc(**opt.reproduc)

# Append a microsecond-resolution timestamp to the task name so that every run
# gets its own output directory under `opt.dir_name`.
timestamp = datetime.now().strftime("_%Y_%m%d_%H%M%S_%f")
opt.task_name += timestamp
proj_path = opj(opt.dir_name, opt.task_name)

# Create the run logger in the per-run directory and record the options in use.
log = MyLogger(log_dir=proj_path, **opt.log)
log.log_opt(opt)

: 

In [None]:
# Simulate the time series together with its ground-truth causal matrix.
data, true_cm = simulate_var_from_links(**opt.data.param)

# The simulated array is three-dimensional; it is unpacked here as (T, N, D).
T, N, D = data.shape
print("Data shape: ", data.shape)
data = prepross_data(data)

# Settings for the random-missing mask, read from the config once.
missing_cfg = opt.data.pre_sample.random_missing
mask = np.ones_like(data)
np.random.seed(missing_cfg.seed)
p = missing_cfg.missing_prob
missing_var = missing_cfg.missing_var

# Each entry is drawn as 0 (dropped) with probability p and 1 (kept) otherwise.
drop_keep_probs = [p, 1-p]
drop_everywhere = isinstance(missing_var, str) and missing_var=="all"
if drop_everywhere:
    # NOTE: np.random.choice returns an integer array, so in this branch `mask`
    # is replaced by an integer array rather than filled in place.
    mask = np.random.choice([0,1], size=mask.shape, p=drop_keep_probs)
else:
    # Only the listed indices along axis 1 receive random missing entries; the
    # remaining entries keep the value 1 from np.ones_like.
    for var_i in missing_var:
        var_shape = mask[:,var_i].shape
        mask[:,var_i] = np.random.choice([0,1], size=var_shape, p=drop_keep_probs)
print(f"Generated random missing with missing_prob: {p:.4f}")

: 

In [None]:
# Zero out the entries whose mask value is 0 (element-wise product with the mask).
sampled_data = data * mask
# Fill the masked entries using the method named by `opt.data.init_fill`
# (project helper; the exact interpolation behaviour lives in utils.data_interpolate).
interp_data = interp_multivar_data(sampled_data, mask, interp=opt.data.init_fill)

: 

In [None]:
# Compare the fully observed series with the points that survive the random
# missing mask, for index 1 on axis 1 and index 0 on axis 2.
# Clamp the window to the series length: a fixed 100-point x-axis would not
# match the y-data when the series has fewer than 100 time steps.
n_show = min(100, data.shape[0])
time_idx = np.arange(n_show)
# Time steps inside the window whose mask entry is non-zero (i.e. kept).
kept_idx = np.flatnonzero(mask[:n_show, 1, 0])

fig, ax = plt.subplots(figsize=[10, 10])
ax.plot(time_idx, data[:n_show, 1, 0], label="original", alpha=0.5)
# ax.plot(time_idx, interp_data[:n_show, 1, 0], label="interp", c="red")
ax.scatter(kept_idx, data[kept_idx, 1, 0], label="sampled points")
ax.set_xlabel("time step")
ax.set_ylabel("value")
ax.set_title(f"Original series vs. sampled points (first {n_show} steps)")
ax.legend()
plt.show()

: 

In [None]:
# Draw the ground-truth causal matrix returned by the simulator, with the
# colour range fixed to [0, 1] (project plotting helper from utils.misc).
sub_cg = plot_causal_matrix(true_cm, figsize=[4, 3], vmin=0, vmax=1)
plt.show()

: 

In [None]:
# Build the CUTS model from the `cuts` section of the config, sharing the run
# logger and the selected device.
multicad = CUTS(opt.cuts, log, device=device)
# Train on the interpolated data and its mask; the original data and the true
# causal matrix are passed as well (presumably for evaluation during
# training -- confirm in cuts_main).
multicad.train(interp_data, mask, data, true_cm)


: 

In [None]:
# Squash the learned graph parameters through a sigmoid and move the result
# to a NumPy array on the host.
graph_logits = multicad.graph
time_prob_mat = torch.sigmoid(graph_logits).detach().cpu().numpy()

# Collapse the third axis by taking the maximum value for every pair
# (presumably the time-lag axis -- confirm against cuts_main).
max_prob_mat = time_prob_mat.max(axis=2)
print(max_prob_mat)

# An edge is kept when its maximum value exceeds 0.5.
causal_graph = max_prob_mat > 0.5
sub_cg = plot_causal_matrix(causal_graph, figsize=[4, 3], vmin=0, vmax=1)
plt.show()

: 

: 