Refactor CITs in oop way #62
tofuwen
left a comment
Thanks for the awesome work!
The code looks much better!!
Only a few nit comments :)
| class CIT(object): | ||
| def __init__(self, data, method='fisherz', **kwargs): | ||
|
|
||
| def CIT(data, method='fisherz', **kwargs): |
Hmm, I don't like this kwargs design, because users don't know what to input.
I think the current way is fine (for backward compatibility), and later maybe we can change it to
cit = FisherZ(data, args) for all callers?
I see your point! And I am also confused about this...
Now **kwargs is placed here to support some user-defined parameters (just at the algorithm call level, e.g., pc, so that users don't need to edit code inside pc). E.g.,
pc(data, 0.05, kci, est_width='median')
if the user wants to use pc+kci with another kernel width (which is not supported in the old func kci, where all parameters are set by default), or
pc(data, 0.05, kci, True, 0, -1, cache_path='/my/path/to/cache.json')
to save and load the citest cache.
Now I use kwargs because the additional arguments that different methods (e.g., FisherZ, KCI) can take are different, and we have to use func CIT as the entrance for all callers (just for backward compatibility).
args won't work because users still wouldn't know what to input (or even the order). A perfect way would be to declare cit outside the algorithm call:
kci_obj = KCI(data, kernelZ='Polynomial', est_width='median', cache_path='/my/path/to/cache.json')
pc(data, 0.05, kci_obj, True, 0, -1) # only take the algorithm-related parameters
But this is not backward compatible (for user input). So a compromise might be to still use func CIT as the entrance, and just put more instructions (on which parameters are allowed) in CIT's comment and in the documentation of CIT and pc?
Yeah, I totally agree. The final example is the perfect solution we should pursue eventually.
I think the design is fine for now, but maybe later we may want to change the design and make it better.
I think in order to make our package, we have to do some backward incompatible things --- the current input / output for each algorithm is not even consistent, which is bad... So we need to change things anyway.
How about we add a TODO here to remind us what the good design looks like, so that we can do it later when we do a big refactor? cc @kunwuz if you'd like to share some feedback.
ic. Cool!
@MarkDana how about adding a TODO here to remind us that we want to remove the kwargs argument later?
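A minimal sketch of the design discussed in this thread, not the code in cit.py. The classes below are hypothetical stand-ins (they store parameters and run no tests), and the KCI defaults are assumed. The point is that CIT can stay as the backward-compatible entrance and forward **kwargs, while each test class declares its keyword arguments explicitly, so an unsupported argument fails loudly instead of being silently ignored.

```python
class CIT_Base:
    """Hypothetical stand-in for the shared base class."""

    def __init__(self, data, cache_path=None):
        self.data = data
        self.cache_path = cache_path


class FisherZ(CIT_Base):
    # FisherZ takes no test-specific parameters in this sketch.
    def __init__(self, data, cache_path=None):
        super().__init__(data, cache_path=cache_path)


class KCI(CIT_Base):
    # Parameter names follow the example in this thread; the defaults are assumed.
    def __init__(self, data, kernelZ="Gaussian", est_width="empirical", cache_path=None):
        super().__init__(data, cache_path=cache_path)
        self.kernelZ = kernelZ
        self.est_width = est_width


_METHODS = {"fisherz": FisherZ, "kci": KCI}


def CIT(data, method="fisherz", **kwargs):
    # TODO: remove **kwargs once callers construct the test object themselves.
    if method not in _METHODS:
        raise ValueError(f"Unknown method: {method!r}")
    # Each class validates its own keyword arguments via its signature.
    return _METHODS[method](data, **kwargs)


kci_obj = CIT([[0.0, 1.0]], "kci", est_width="median")
print(type(kci_obj).__name__, kci_obj.est_width)
```

With explicit signatures, CIT(data, "fisherz", est_width="median") raises a TypeError, which already tells the user that FisherZ does not accept that parameter.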
| assert isinstance(data, np.ndarray), "Input data must be a numpy array." | ||
| self.data = data | ||
| self.data_hash = hash(str(data)) | ||
| self.data_hash = hashlib.md5(str(data).encode('utf-8')).hexdigest() |
Won't this be slow when data is huge?
And when path is None, we don't need to compute data_hash, right?
It will be fast (and used only once). np.ndarray.__str__ only returns a preview:
In [41]: data = np.random.randn(10000, 100)
In [42]: str(data)
Out[42]: '[[-1.39552534 1.37974053 -1.1619043 ... 0.13616104 -0.12120668\n -1.00001339]\n [-0.25197878 -2.00971912 0.63008704 ... -0.97997436 -1.21297862\n 1.42272323]\n [-1.22421999 0.90022162 -1.33748472 ... 1.32908047 -1.37618144\n -0.28312766]\n ...\n [ 1.71461535 0.10882434 0.08604805 ... 1.34678215 -2.30936746\n 0.76045509]\n [ 0.55727436 0.2203048 0.41242777 ... 0.95881301 0.58538315\n 1.26002782]\n [-0.77753666 0.53018912 0.70592259 ... 0.14847539 -0.60861808\n -0.36093896]]'
In [43]: len(str(data))
Out[43]: 491
In [44]: timeit hashlib.md5(str(data).encode('utf-8')).hexdigest()
125 µs ± 392 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
cool
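One consequence of hashing the preview is worth keeping in mind: two large arrays that differ only outside the printed corners get the same hash. The sketch below illustrates this, assuming NumPy's default print options. The helper names preview_hash and full_hash are made up for illustration; full_hash is only one possible alternative, not what this PR does.

```python
import hashlib

import numpy as np


def preview_hash(data):
    # Same idea as in the diff: hash the (abbreviated) string form of the array.
    return hashlib.md5(str(data).encode("utf-8")).hexdigest()


def full_hash(data):
    # Hypothetical alternative: hash the raw bytes of every element.
    # Cost grows with the array size, and shape/dtype are not included.
    return hashlib.md5(np.ascontiguousarray(data).tobytes()).hexdigest()


a = np.zeros((2000, 10))
b = a.copy()
b[1000, 5] = 1.0  # change one entry that the preview does not display

same_preview = preview_hash(a) == preview_hash(b)
same_full = full_hash(a) == full_hash(b)
print(same_preview, same_full)
```

So the preview hash is fast but only a weak fingerprint of the data; whether that matters depends on how likely it is that a stale cache file gets reused with a different dataset.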
| ---------- | ||
| X: int, or np.*int* | ||
| Y: int, or np.*int* | ||
| condition_set: Iterable<int | np.*int*> |
why not force it to be List?
causallearn/utils/cit.py
Outdated
| self.SAVE_CACHE_CYCLE_SECONDS = 30 | ||
| self.last_time_cache_saved = time.time() | ||
| self.pvalue_cache = {'data_hash': self.data_hash} | ||
| if not cache_path is None: |
| if not cache_path is None: | |
| if cache_path is not None: |
check https://stackoverflow.com/questions/2710940/python-if-x-is-not-none-or-if-not-x-is-none
Wow, thanks for this!!! Marked!
I always get lost when writing something like this...
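For reference, a small sketch of why the two spellings behave the same but read differently: not binds more loosely than is, so not x is None parses as not (x is None), while is not is a single comparison operator. The helper name show is made up for illustration.

```python
import ast


def show(expr):
    # Dump the parsed expression so the grouping is visible.
    return ast.dump(ast.parse(expr, mode="eval").body)


print(show("not x is None"))  # a UnaryOp(Not) wrapping a Compare with Is
print(show("x is not None"))  # a single Compare with IsNot

# Both forms evaluate to the same truth value for any x.
for x in (None, 0, "", [], "/my/path/to/cache.json"):
    assert (not x is None) == (x is not None)
```

The second form is the one PEP 8 recommends, mainly for readability.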
| self.save_to_local_cache() | ||
|
|
||
| METHODS_SUPPORTING_MULTIDIM_DATA = ["kci"] | ||
| if condition_set is None: condition_set = [] |
Might be cleaner to require that condition_set is never None?
If the user doesn't want a condition_set, just use []?
Oh I missed this.
Sometimes users might want to input FisherZ(X, Y) to test just unconditional independence.
From the users' end, I feel that this looks better than FisherZ(X, Y, []), or FisherZ(X, Y, ()) ... (which of course also works).
Another reason is that I'm not sure whether usages like FisherZ(X, Y) already exist in the current code. lol
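A minimal sketch of the None-default idiom being defended here. The function name is hypothetical, and sorting the result is an assumption borrowed from how Xs is built elsewhere in the diff. Using None as the default (rather than condition_set=[]) also avoids Python's shared mutable default pitfall.

```python
def normalize_condition_set(condition_set=None):
    # None means "unconditional test"; callers may also pass [] or ().
    if condition_set is None:
        return []
    # Deduplicate and sort so equal sets map to the same cache key (assumed behaviour).
    return sorted(set(map(int, condition_set)))


print(normalize_condition_set())           # unconditional
print(normalize_condition_set((3, 1, 3)))  # duplicates removed, sorted
```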
| return [X], [Y], condition_set, _stringize([X], [Y], condition_set) | ||
|
|
||
| # also to support multi-dimensional unconditional X, Y (usually in kernel-based tests) | ||
| Xs = sorted(set(map(int, X))) if isinstance(X, Iterable) else [int(X)] # sorted for comparison |
Why not force X to be of type List?
Personally I always prefer to make variables typed --- it can remove lots of potential bugs and make the code much cleaner.
In all of our constraint-based methods, X and Y are assumed to be integers.
If we force X to be of type List, all the code related to cit calls in pc, fci, ... will need to be changed.
Overall, integers X and Y should always be first-class citizens in CITests. Multi-dim X and Y only work in KCI, and not in constraint-based methods, but somewhere else (e.g., GIN, you name it).
ic, sounds good.
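The normalization line from the diff can be tried in isolation. This sketch wraps it in a hypothetical helper, as_index_list, to show that a plain integer and an iterable of integers both end up as a sorted list of Python ints.

```python
from collections.abc import Iterable


def as_index_list(X):
    # Same expression as in the diff: accept a single index or several indices.
    return sorted(set(map(int, X))) if isinstance(X, Iterable) else [int(X)]


print(as_index_list(3))          # single variable index
print(as_index_list([2, 0, 2]))  # multi-dimensional X, deduplicated and sorted
```

One caveat of the isinstance(X, Iterable) check is that strings are iterable too, so as_index_list("12") would be read as the indices 1 and 2 rather than rejected.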
| var = Xs + Ys + condition_set | ||
| sub_corr_matrix = self.correlation_matrix[np.ix_(var, var)] | ||
| try: | ||
| inv = np.linalg.inv(sub_corr_matrix) |
Hmm, curious: what if an exception is thrown here? You didn't catch it?
Oh yes I didn't catch exceptions. But what would be the possible exceptions here? If it's about the type of X, Y, condition_set, a built-in error message seems to be informative enough.
why do you "try" then?
Oh you mean which lines? L155?
This is from the original code at Fixed fisherz test (#58).
ha, ic. I missed the later "except".... nvm
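For readers who only see the try in the snippet, here is a self-contained sketch of how a Fisher-z test around that inversion could look, including the except branch. It is not the code in cit.py: the function name, the error message, and the clipping of r are assumptions, and the p-value uses math.erfc instead of a SciPy normal CDF.

```python
import math

import numpy as np


def fisherz_pvalue(correlation_matrix, sample_size, X, Y, condition_set=()):
    var = [X, Y] + list(condition_set)
    sub_corr_matrix = correlation_matrix[np.ix_(var, var)]
    try:
        inv = np.linalg.inv(sub_corr_matrix)
    except np.linalg.LinAlgError as err:
        # Translate the low-level error into a message about the data.
        raise ValueError("Correlation matrix is singular; cannot run the Fisher-z test.") from err
    # Partial correlation of X and Y given the condition set.
    r = -inv[0, 1] / math.sqrt(abs(inv[0, 0] * inv[1, 1]))
    r = min(max(r, -0.999999), 0.999999)  # assumed safeguard: keep the log finite
    z = 0.5 * math.log((1 + r) / (1 - r))
    stat = math.sqrt(sample_size - len(condition_set) - 3) * abs(z)
    # Two-sided p-value under a standard normal: 2 * (1 - Phi(stat)).
    return math.erfc(stat / math.sqrt(2))


identity = np.eye(3)
print(fisherz_pvalue(identity, 100, 0, 1, [2]))
```

Catching np.linalg.LinAlgError specifically, rather than using a bare except, keeps unrelated errors (for example, a wrong index type) visible with their built-in messages.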
causallearn/utils/cit.py
Outdated
| class FisherZ(CIT_Base): | ||
| def __init__(self, data, **kwargs): | ||
| super().__init__(data, **kwargs) | ||
| self.check_cache_method_consistent('fisherz', -1) # -1: no parameters can be specified for fisherz |
Hmm, -1 here looks ugly. Maybe we can have a better design, e.g., use None?
And in the code below, it seems you never write "parameters_hash" to json? Why doesn't the assertion fail?
Yeah, -1 is ugly, lol. A message string ("NO SPECIFIED PARAMETERS") might also be ok.
"parameters_hash" is written to cache (and json). See check_cache_method_consistent.
Yeah, a (const) string is much better.
Try to avoid using magic numbers like -1 in your code. :)
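A sketch of what the named constant could look like. The constant name and this free-function version of check_cache_method_consistent are hypothetical (in the PR it is a method on the base class, and its body is not shown in this thread); the sketch only illustrates that the first call writes "parameters_hash" into the cache and later calls assert against it.

```python
NO_SPECIFIED_PARAMETERS_MSG = "NO SPECIFIED PARAMETERS"  # replaces the magic -1


def check_cache_method_consistent(pvalue_cache, method_name, parameters_hash):
    if "method_name" not in pvalue_cache:
        # Fresh cache: record which method and parameters it belongs to.
        pvalue_cache["method_name"] = method_name
        pvalue_cache["parameters_hash"] = parameters_hash
        return
    # Existing cache: refuse to mix results from different tests or settings.
    assert pvalue_cache["method_name"] == method_name, "Cache was built with a different method."
    assert pvalue_cache["parameters_hash"] == parameters_hash, "Cache was built with different parameters."


cache = {"data_hash": "abc"}
check_cache_method_consistent(cache, "fisherz", NO_SPECIFIED_PARAMETERS_MSG)
print(cache)
```

A string constant also serializes to JSON in a self-describing way, which a bare -1 does not.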
causallearn/utils/cit.py
Outdated
| return np.unique(column, return_inverse=True)[1] | ||
| assert method_name in ['chisq', 'gsq'] | ||
| super().__init__(np.apply_along_axis(_unique, 0, data).astype(np.int64), **kwargs) | ||
| self.check_cache_method_consistent(method_name, -1) # -1: no parameters can be specified for chisq/gsq |
same
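As a side note on the snippet above, the np.unique(column, return_inverse=True)[1] call re-encodes each column of discrete data as integer codes 0..k-1, which is what chisq/gsq style tests typically need for building contingency tables. A small sketch with made-up data:

```python
import numpy as np


def _unique(column):
    # Index of each value within the sorted unique values of its column.
    return np.unique(column, return_inverse=True)[1]


data = np.array([[5, 2],
                 [7, 2],
                 [5, 9]])
encoded = np.apply_along_axis(_unique, 0, data).astype(np.int64)
print(encoded)
```

Each column is encoded independently, so the value 2 in the second column and the value 5 in the first both become code 0.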
causallearn/utils/cit.py
Outdated
| # result = [[]] | ||
| # for pool in lists: | ||
| # result = [x + [y] for x in result for y in pool] | ||
| # return result | ||
| # return result |
remove this to make the code cleaner
tofuwen
left a comment
great, I think this PR is ready to be merged! cc @kunwuz
Updated files:
cit.py: Last time we rewrote all cit functions into one CIT class, with all methods in one class. This time we further separate each test method into a subclass inherited from a base class CIT_Base.
How to use the new class(es):
Though in the code before, CIT was a class, while now CIT is a function API that returns the respective class. So an alternative way of writing the code above is:
The issue with MVPC's inaccurate fisherz result is solved. It was due to a change in sample size (my fault). The code logic is consistent with before.
Functions for cit's resume-from-break-point are added in the CIT_Base class. I will create a new PR for reference.
Test plan:
Same as #46 Rewrite CITests as a class && re-use covariance matrix for fisherz: