## Python classes for segment tree
* Sum of elements in a range [l,r] in logn time.
* Add a value to each element in the range [l,r] in logn time. Lazy propagation. 
* Initialization takes nlogn time. 
* Top down implementation.
* Support modification of an item in logn time.
* Can easily be modified for other functions such as min or max.


In [35]:
class Node:
    def __init__(self,v,left=None,right=None):
        self.item=v;
        self.left=left;
        self.right=right;
        #lazy propagation
        self.scaleAdd=0;
    def getValue(self):
        return self.item+self.scaleAdd;
        
class SegmentTree:
    def __init__(self,A):
        self.A=A;
        self.N=len(A);
        self.root=self.build(0,self.N-1);

    def build(self,i,j):
        if i==j: return Node(self.A[i]);
        mid=(i+j)//2;
        left=self.build(i,mid);
        right=self.build(mid+1,j);
        
        return Node(left.item+right.item,left,right); #modify based on function

    def rangeSum(self,curRoot,i,j,l,r):
        
        if i>r or j<l: return 0; # modify based on funtion
        if l<=i and j<=r: return curRoot.getValue(); 
        mid=(i+j)//2;
        
        s1=self.rangeSum(curRoot.left,i,mid,l,r);
        s2=self.rangeSum(curRoot.right,mid+1,j,l,r);
        return s1+s2+curRoot.scaleAdd; #function: sum
        # for min: min(s1,s2)+scaleAdd
        
    def getRangeSum(self,l,r):
        return self.rangeSum(self.root,0,self.N-1,l,r);

    def _updateSingle(self,curRoot,i,j,idx,value):
        if i==j and i==idx: 
            curRoot.item=value;
            
            return ;
        
        if i>idx or j<idx: return ; #unnessary
        mid=(i+j)//2;
        if idx<=mid:
            self._updateSingle(curRoot.left,i,mid,idx,value);
        else:
            self._updateSingle(curRoot.right,mid+1,j,idx,value);
        
        curRoot.item=curRoot.left.item+curRoot.right.item; # function: sum
        
    def updateSingle(self,idx,value):
        self._updateSingle(self.root,0,self.N-1,idx,value);
    def _updateRange(self,curRoot,i,j,l,r,v):
        #Lazy propagation 
        if i>r or j<l: return ; # modify based on funtion
        if l<=i and j<=r:
            curRoot.scaleAdd+=v;
            return;
        mid=(i+j)//2;
        self._updateRange(curRoot.left,i,mid,l,r,v);
        self._updateRange(curRoot.right,mid+1,j,l,r,v);
        curRoot.item=curRoot.left.getValue()+curRoot.right.getValue(); # function: sum
        #return;
    def updateRange(self,l,r,v):
        #add v to each element in range [l,r]
        self._updateRange(self.root,0,self.N-1,l,r,v);
        

## Test 

In [36]:
#test
A=[2,3,5,7,1,7,9,5,7,1];
st=SegmentTree(A);

print st.getRangeSum(2,4)
st.updateSingle(3,9); # A[3]=9
st.updateRange(2,5,2); # A[2:5]+=2

print st.getRangeSum(2,4)



13
19


In [37]:
def showUpdatedArray(N):
    for i in xrange(N):
        v= st.getRangeSum(i,i)
        print 'A[',i,']',v;
        
showUpdatedArray(st.N);    

A[ 0 ] 2
A[ 1 ] 3
A[ 2 ] 7
A[ 3 ] 11
A[ 4 ] 3
A[ 5 ] 9
A[ 6 ] 9
A[ 7 ] 5
A[ 8 ] 7
A[ 9 ] 1


## Print whole tree and node values

In [38]:
#print Whole tree 
def traverse(root,i,j):

    print "[",i,j,"]",root.item,root.scaleAdd;
    if i==j: 
        return
    mid=(i+j)//2;
    traverse(root.left,i,mid);
    traverse(root.right,mid+1,j);

    
traverse(st.root,0,st.N-1);    

[ 0 9 ] 55 0
[ 0 4 ] 24 0
[ 0 2 ] 12 0
[ 0 1 ] 5 0
[ 0 0 ] 2 0
[ 1 1 ] 3 0
[ 2 2 ] 5 2
[ 3 4 ] 10 2
[ 3 3 ] 9 0
[ 4 4 ] 1 0
[ 5 9 ] 31 0
[ 5 7 ] 23 0
[ 5 6 ] 18 0
[ 5 5 ] 7 2
[ 6 6 ] 9 0
[ 7 7 ] 5 0
[ 8 9 ] 8 0
[ 8 8 ] 7 0
[ 9 9 ] 1 0
