In [1]:
import torch
from functools import partial

from TRUNCLexer import *
from TRUNCParser import *
from TRUNCListener import *


In [2]:
def ineq_func(self,comp):
    """Truncate a Gaussian component by the linear inequality held in ``self``.

    The constraint is ``coeff . [x, aux] <type> const``, where ``x`` are the
    variables of ``comp`` and ``aux`` are the auxiliary mixture variables
    described by ``self.aux_pis`` / ``self.aux_means`` / ``self.aux_covs``.
    One output component is produced for every combination of auxiliary
    mixture components (the Cartesian product over ``self.aux_means``).

    Args:
        self: rule object (e.g. a ``TruncRule``) providing ``coeff``, ``const``,
            ``type`` (handled values: '>', '>=', '<', '<=') and the ``aux_*``
            lists. It is only read, never modified.
        comp: object exposing ``comp.gm.mu``, ``comp.gm.sigma`` and
            ``comp.var_list``. Only the first mixture entry (``mu[0]``,
            ``sigma[0]``) is used -- NOTE(review): assumes ``comp`` holds a
            single Gaussian; confirm with callers.

    Returns:
        ``GaussianMix(final_pi, final_mu, final_sigma)``. Each entry has weight
        ``prod(chosen aux weights) * new_P`` and the truncated mean/covariance
        of comp's variables. ``new_P`` is 1 or 0 when every variable with a
        nonzero coefficient is a delta. Otherwise it is the value returned by
        ``compute_moments`` (presumably the probability mass of the region).
    """
    mu = comp.gm.mu[0]
    sigma = comp.gm.sigma[0]
    final_pi = []
    final_mu = []
    final_sigma = []
    for part in product(*[range(len(mean)) for mean in self.aux_means]):
        # for a given combination of components of the auxiliary variables,
        # creates a new component extending comp
        aux_pi = 1
        aux_mean = list(copy(mu))
        aux_sigma = []
        ineq_coeff = np.array(copy(self.coeff))
        # Take a plain float copy. self.const may be a torch tensor, and the
        # in-place `-=` below would otherwise modify self.const itself. The
        # delta substitution would then pile up across iterations and calls.
        ineq_const = float(self.const)
        for p, q in enumerate(part):
            aux_pi = aux_pi*self.aux_pis[p][q]
            aux_mean.append(self.aux_means[p][q])
            aux_sigma.append(self.aux_covs[p][q])
        aux_mean = np.array(aux_mean)
        aux_sigma = np.diag(aux_sigma)
        # joint covariance of [comp variables, auxiliary variables]: block
        # diagonal, i.e. the auxiliary variables are independent of comp and of
        # each other
        aux_cov = np.block([[sigma, np.zeros((len(sigma), len(aux_sigma)))],
                            [np.zeros((len(aux_sigma), len(sigma))), aux_sigma]])
        # substitute deltas: coordinates whose variance is below delta_tol (a
        # tolerance defined outside this cell) are treated as constants, moved
        # to the right-hand side and dropped from the coefficients
        delta_idx = np.where(np.diag(aux_cov) < delta_tol)[0]
        ineq_const -= np.array(self.coeff)[delta_idx].dot(aux_mean[delta_idx])
        ineq_coeff[delta_idx] = 0
        # if all variables were deltas the constraint reduces to
        # `0 <type> ineq_const`, which either always or never holds
        if np.all(ineq_coeff == 0):
            if (self.type == '>' and ineq_const < 0) or (self.type == '>=' and ineq_const <= 0) or (self.type == '<' and ineq_const > 0) or (self.type == '<=' and ineq_const >= 0):
                new_P = 1.
            else:
                new_P = 0.
            new_mu = mu
            new_sigma = sigma
        # else compute truncated distribution
        else:
            # STEP 1: change variables; normalise so ineq_coeff is a unit vector
            norm = np.linalg.norm(ineq_coeff)
            ineq_coeff = ineq_coeff/norm
            ineq_const = ineq_const/norm
            A = find_basis(ineq_coeff)           # maybe instead of A a vector can be used to improve scalability (?)
            transl_mu = A.dot(aux_mean)
            transl_sigma = A.dot(aux_cov).dot(A.transpose())
            # STEP 2: finds the indices of the components that needs to be transformed.
            # The constrained direction is coordinate 0 (presumably find_basis
            # places ineq_coeff as the first row of A -- confirm).
            transl_alpha = np.zeros(len(transl_mu))
            transl_alpha[0] = 1
            indices = select_indices(transl_alpha, transl_sigma)
            # STEP 3: creates reduced vectors taking into account only the coordinates that need to be transformed
            red_transl_alpha = reduce_indices(transl_alpha, indices)
            red_transl_mu = reduce_indices(transl_mu, indices)
            red_transl_sigma = reduce_indices(transl_sigma, indices)
            # STEP 4: creates the hyper-rectangle to integrate on; only
            # coordinate 0 is bounded, on the side given by the operator
            a = np.ones(len(red_transl_alpha))*(-np.inf)
            b = np.ones(len(red_transl_alpha))*(np.inf)
            if self.type=='>' or self.type=='>=':
                a[0] = ineq_const
            if self.type=='<' or self.type=='<=':
                b[0] = ineq_const
            # STEP 5: compute moments in the transformed coordinates
            new_P, new_red_transl_mu, new_red_transl_sigma = compute_moments(red_transl_mu, red_transl_sigma, a, b)
            # STEP 6: recreates extended vectors
            new_transl_mu = extend_indices(new_red_transl_mu, transl_mu, indices)
            new_transl_sigma = extend_indices(new_red_transl_sigma, transl_sigma, indices)
            # STEP 7: goes back to older coordinates and drops the auxiliary ones
            d = len(comp.var_list)
            A_inv = np.linalg.inv(A)
            new_mu = A_inv.dot(new_transl_mu)[:d]
            new_sigma = A_inv.dot(new_transl_sigma).dot(A_inv.transpose())[:d,:d]
        # append new values
        final_pi.append(aux_pi*new_P)
        final_mu.append(new_mu)
        final_sigma.append(new_sigma)
    return GaussianMix(final_pi, final_mu, final_sigma)

In [9]:
class TruncRule(TRUNCListener):
    """ANTLR listener that turns a parsed TRUNC rule into a linear constraint.

    After walking a parse tree the instance holds:
        type:      operator text taken from ``inop`` (inequalities) or ``eqop``.
        coeff:     1-D tensor with one coefficient per entry of ``var_list``,
                   followed by one coefficient per auxiliary ``gm(...)`` term.
        const:     right-hand-side constant.
        aux_pis, aux_means, aux_covs: one tensor per ``gm(...)`` term, built
                   from its first, second and (squared) third list.
        func:      ``partial(ineq_func, self)`` after an inequality expression,
                   ``partial(eq_func, self)`` after an equality.
    """

    def __init__(self, var_list, data):
        self.var_list = var_list
        self.data = data
        self.type = None
        self.coeff = torch.zeros(len(var_list))
        self.const = torch.tensor(0.)
        self.func = None
        # sign applied to the next monomial; updated by enterLexpr/enterSum/enterSub
        self.flag_sign = torch.tensor(1.)

        self.aux_pis = []
        self.aux_means = []
        self.aux_covs = []

    def _monom_coeff(self, ctx):
        """Return the signed coefficient of a monomial ``const? '*' var``.

        With no explicit constant the coefficient is just the current sign.

        Raises:
            ValueError: if a constant is present but is neither NUM nor idd.
        """
        const = ctx.const()
        if const is None:
            return self.flag_sign
        if const.NUM() is not None:
            return self.flag_sign*torch.tensor(float(const.NUM().getText()))
        if const.idd() is not None:
            return self.flag_sign*torch.tensor(const.idd().getValue(self.data))
        raise ValueError('unsupported constant in monomial: ' + const.getText())

    def enterIneq(self, ctx):
        self.type = ctx.inop().getText()
        if not ctx.const().NUM() is None:
            self.const = torch.tensor(float(ctx.const().NUM().getText()))
        elif not ctx.const().idd() is None:
            # NOTE(review): unlike enterEq, the value is stored exactly as
            # returned by getValue (not wrapped in torch.tensor) -- confirm intent.
            self.const = ctx.const().idd().getValue(self.data)

    def enterLexpr(self, ctx):
        # the first monomial of a linear expression is positive
        self.flag_sign = torch.tensor(1.)

    def exitLexpr(self, ctx):
        self.func = partial(ineq_func,self)

    def enterMonom(self,ctx):
        gm = ctx.var().gm()
        if gm is None:
            # monom in the form const? '*' (IDV | idd)
            idx = self.var_list.index(ctx.var()._getText(self.data))
            # Accumulate so that a variable appearing in several monomials
            # (e.g. "2*x + 3*x") gets the sum of its coefficients.
            self.coeff[idx] += self._monom_coeff(ctx)
        # monom in the form const? '*' gm
        else:
            # NOTE(review): eval() executes arbitrary text from the rule string;
            # if rules can come from untrusted sources, prefer ast.literal_eval.
            lists = gm.list_()
            self.aux_pis.append(torch.tensor(eval(lists[0].getText())))
            self.aux_means.append(torch.tensor(eval(lists[1].getText())))
            # the third list is squared before storing (read as standard deviations)
            self.aux_covs.append(torch.pow(torch.tensor(eval(lists[2].getText())),2))
            # each gm term gets its own trailing coefficient slot
            self.coeff = torch.hstack([self.coeff, self._monom_coeff(ctx)])

    def enterSub(self, ctx):
        self.flag_sign = torch.tensor(-1.)

    def enterSum(self, ctx):
        self.flag_sign = torch.tensor(1.)

    def enterEq(self, ctx):
        self.type = ctx.eqop().getText()
        idx = self.var_list.index(ctx.var()._getText(self.data))
        self.coeff[idx] = torch.tensor(1.)
        if not ctx.const() is None:
            if not ctx.const().NUM() is None:
                self.const = torch.tensor(float(ctx.const().NUM().getText()))
            elif not ctx.const().idd() is None:
                self.const = torch.tensor(ctx.const().idd().getValue(self.data))
        self.func = partial(eq_func,self)

In [10]:
# Example rule: 0.5 times an inline auxiliary two-component Gaussian mixture
# gm(...) must be positive; x and y do not appear in it.
trunc = '0.5*gm([0.5, 0.5], [0.,1.], [1.,1.]) > 0'
var_list, data = ['x', 'y'], {}

# Lex and parse the rule text, starting from the grammar's `trunc` rule.
lexer = TRUNCLexer(InputStream(trunc))
stream = CommonTokenStream(lexer)
parser = TRUNCParser(stream)
tree = parser.trunc()

# Walk the tree with a fresh listener so it fills in type/coeff/const/func.
walker = ParseTreeWalker()
trunc_rule = TruncRule(var_list, data)
walker.walk(trunc_rule, tree)

In [11]:
# Only a cell's last expression is rendered, so bare attribute lookups placed
# before it have no visible effect. Show the coefficient vector: one entry per
# variable in var_list, followed by one entry per auxiliary gm(...) term.
trunc_rule.coeff

tensor([0.0000, 0.0000, 0.5000])