In [1]:
import numpy as np
import random

In [4]:
class Scheme(object):
    def __init__(self, rule, times, cens, feat_means):
        self.rule = rule
        self.times = times
        self.cens = cens
        self.feat_means = feat_means
        self.shape = times.shape[0]
        
    def join(self, sch):
        self.rule = f"({self.rule})|({sch.rule})"
        self.times = np.hstack([self.times, sch.times])
        self.cens = np.hstack([self.cens, sch.cens])
        self.feat_means = (self.feat_means + sch.feat_means)/2
        return self
    
    def copy(self):
        return Scheme(self.rule, self.times, self.cens, self.feat_means)
        
    def get_str_rules(self):
        return "("+"&".join(self.rule)+")"
        
    def get_subschemes(self, min_size=5, top=3):
        ret = {self.get_str_rules(): self.copy()}
        return ret
    
    def visualize(self):
        pass

class FilledSchemeStrategy(object):
    def __init__(self, schemes_list):
        self.schemes_dict = {sch.get_str_rules(): sch for sch in schemes_list}
    
    def join(self, fss):
        for k, v in fss.schemes_dict.items():
            if k in self.schemes_dict.keys():
                self.schemes_dict[k].join(v)
            else:
                self.schemes_dict[k] = v
    
    def join_nearest_leaves(self, sign_thres=0.05, diff_func=random.random):
        def delete_k_from_dict(d, del_k):
            d_ = dict()
            for k in d.keys():
                if not(del_k in k.split("#")):
                    d_[k] = d[k]
            return d_
            
        base = self.schemes_dict
        diff_dict = dict()
        for i1, l1 in enumerate(base.keys()):
            for i2, l2 in enumerate(base.keys()):
                if i2 > i1:
                    diff_dict[l1+'#'+l2] = diff_func()
                    #scrit.logrank_fast(base[l1].times, base[l2].times, base[l1].cens, base[l2].cens)
        while len(base) > 1:
            max_pair_key, max_p_val = max(diff_dict.items(), key=lambda x: x[1])
            print('Максимальное P-value:', max_p_val)
            if max_p_val < sign_thres:
                break
            f_l, s_l = max_pair_key.split('#')
            new_sch_name = f_l + '|' + s_l
            new_sch = base[f_l].copy()
            new_sch.join(base[s_l])
            for k in [f_l, s_l]:
                diff_dict = delete_k_from_dict(diff_dict, k)
                del base[k]
            for k in base.keys():
                diff_dict[new_sch_name+'#'+k] = diff_func()
            base[new_sch_name] = new_sch
            print('Цепочки схем:', f_l, s_l)
            print('Заменяются на:', new_sch_name)
        self.schemes_dict = base

In [3]:
a = Scheme(rule=["(a > 0.5)", "(b <= 0.5)"], 
           times=np.array([1, 2, 3, 4, 5]), 
           cens=np.array([1, 0, 1, 0, 1]), 
           feat_means = np.array([1, 10, 100, 1000]))
b = Scheme(rule=["(a > 0.5)", "(b > 0.5)"], 
           times=np.array([6, 7, 8, 9, 10]), 
           cens=np.array([1, 1, 1, 1, 1]), 
           feat_means = np.array([10, 10, 10, 10]))
c = Scheme(rule=["(a <= 0.5)", "(c <= 1000)"], 
           times=np.array([11, 12, 13, 14, 15]), 
           cens=np.array([1, 1, 1, 1, 1]), 
           feat_means = np.array([1, 1, 1, 1]))
d = Scheme(rule=["(a <= 0.5)", "(c > 1000)"], 
           times=np.array([16, 17, 18, 19, 20]), 
           cens=np.array([0, 1, 0, 0, 1]), 
           feat_means = np.array([100, 100, 100, 100]))

fss1 = FilledSchemeStrategy([a, b])
fss2 = FilledSchemeStrategy([c, d])
fss1.join(fss2)
fss1.join_nearest_leaves(sign_thres=0.5)

{'((a > 0.5)&(b <= 0.5))#((a > 0.5)&(b > 0.5))': 0.8831176566142859, '((a > 0.5)&(b <= 0.5))#((a <= 0.5)&(c <= 1000))': 0.19999369617108564, '((a > 0.5)&(b <= 0.5))#((a <= 0.5)&(c > 1000))': 0.3800999247821, '((a > 0.5)&(b > 0.5))#((a <= 0.5)&(c <= 1000))': 0.40705070905695206, '((a > 0.5)&(b > 0.5))#((a <= 0.5)&(c > 1000))': 0.7170625637112442, '((a <= 0.5)&(c <= 1000))#((a <= 0.5)&(c > 1000))': 0.4758376523947492}
{'((a > 0.5)&(b <= 0.5))': <__main__.Scheme object at 0x000001D3211134F0>, '((a > 0.5)&(b > 0.5))': <__main__.Scheme object at 0x000001D321112B00>, '((a <= 0.5)&(c <= 1000))': <__main__.Scheme object at 0x000001D321111270>, '((a <= 0.5)&(c > 1000))': <__main__.Scheme object at 0x000001D321113430>}
Максимальное P-value: 0.8831176566142859
Цепочки схем: ((a > 0.5)&(b <= 0.5)) ((a > 0.5)&(b > 0.5))
Заменяются на: ((a > 0.5)&(b <= 0.5))|((a > 0.5)&(b > 0.5))
{'((a <= 0.5)&(c <= 1000))#((a <= 0.5)&(c > 1000))': 0.4758376523947492, '((a > 0.5)&(b <= 0.5))|((a > 0.5)&(b > 0.5))#((