In [1]:
import pickle

class CollectForestInfo:
    """Collect per-tree and forest-wide information from two pickled result dicts.

    Entry layouts, as read by the indexing in this class:

    * intermediate pickle: dict whose values look like
      ``(value0, (clusterName, commonMotifSeq), memberSet)``.
      ``value0`` is not used here.
    * residual pickle: dict whose values look like
      ``((clusterName, value01), members)``.
      ``value01`` is not used here.

    Node names whose first character is ``"G"`` are treated as group
    (internal) nodes; every other name is treated as a leaf.
    """

    intermediateDict = None      # dict loaded from the intermediate pickle
    residualDict = None          # dict loaded from the residual pickle
    descendant_dict = None       # key: clusterName; val: set of descendant names
    repCommMotifSeq_dict = None  # key: clusterName; val: common motif seq. list
    treeList = None              # sorted list of (rootName, members) tuples

    def __init__(self, intermidiatePicklePath, residualPicklePath, includePairwiseTree, forceMerge=False):
        """Load both pickle files and derive the descendant/motif maps and tree list.

        Args:
            intermidiatePicklePath: path of the intermediate pickle file
                (the parameter name's spelling is kept for keyword callers).
            residualPicklePath: path of the residual pickle file.
            includePairwiseTree: if truthy, every residual entry with more
                than one member is kept as a tree; otherwise 2-member entries
                made only of leaves are dropped.
            forceMerge: if truthy, each cluster's members are recorded as-is
                instead of expanding group members into their descendants.
        """
        # SECURITY: pickle.load can execute arbitrary code while unpickling.
        # Only pass paths to files produced by a trusted pipeline.
        with open(intermidiatePicklePath, 'rb') as handle:
            self.intermediateDict = pickle.load(handle)
        with open(residualPicklePath, 'rb') as handle:
            self.residualDict = pickle.load(handle)

        self._setForestOutputs(forceMerge)
        self._setTreeList(includePairwiseTree)

    @staticmethod
    def _isGroupNode(name):
        """Return True if `name` starts with "G"; an empty name counts as a leaf."""
        return len(name) > 0 and name[0] == "G"

    def _setForestOutputs(self, forceMerge):
        """Build `descendant_dict` and `repCommMotifSeq_dict` from `intermediateDict`.

        Entries are visited in ascending key order. Unless `forceMerge` is
        truthy, a group member is replaced by the descendants already recorded
        for it, so a group must be visited before any cluster that contains it
        (otherwise the lookup raises KeyError).
        """
        descendant_dict = dict()
        repCommMotifSeq_dict = dict()

        for key in sorted(self.intermediateDict):
            value = self.intermediateDict[key]
            clusterName = value[1][0]
            commonMotifSeq = value[1][1]  # list of common motif seq.
            memberSet = value[2]

            descendants = set()
            for member in memberSet:
                if not forceMerge and self._isGroupNode(member):
                    # expand the group into the descendants collected earlier
                    descendants.update(descendant_dict[member])
                else:
                    descendants.add(member)

            descendant_dict[clusterName] = descendants
            repCommMotifSeq_dict[clusterName] = commonMotifSeq

        self.descendant_dict = descendant_dict
        self.repCommMotifSeq_dict = repCommMotifSeq_dict

    def _setTreeList(self, includePairwiseTree):
        """Collect the residual entries that are not singular into `treeList`.

        An entry is kept when it has more than one member, except that with a
        falsy `includePairwiseTree` a 2-member entry is kept only if at least
        one of its members is a group node. The kept `(clusterName, members)`
        tuples are sorted by the integer that follows the first character of
        the cluster name.
        """
        notLonerList = []

        for value in self.residualDict.values():
            clusterName = value[0][0]
            members = value[1]
            memberCount = len(members)

            if memberCount <= 1:
                continue  # singular tree
            if (not includePairwiseTree and memberCount == 2
                    and not any(self._isGroupNode(member) for member in members)):
                continue  # pair of two leaves is dropped on request

            notLonerList.append((clusterName, members))

        notLonerList.sort(key=lambda tree: int(tree[0][1:]))
        self.treeList = notLonerList

    def getDescendant_dict(self):
        """Return the descendant map for every cluster (roots and middle nodes)."""
        return self.descendant_dict

    def getTreeList(self):
        """Return the list of `(rootName, members)` tuples, one per tree."""
        return self.treeList

    def getTreeRootNameList(self):
        """Return the root names of all trees, in `treeList` order."""
        # each treeRoot is (ParentNodeName, {children_Node_Names})
        return [treeRoot[0] for treeRoot in self.treeList]

    def getTreeRootCount(self):
        """Return how many trees are in the forest."""
        return len(self.treeList)

    def getForestMembers(self):
        """Return a new set holding the members of every tree in the forest."""
        return set().union(*self.getTreeMembers_dict().values())

    def getForestMemberCount(self):
        """Return the number of distinct members over all trees."""
        return len(self.getForestMembers())

    def getTreeMembers_dict(self):
        """Return a dict with key: treeRootName; val: that root's descendant set.

        The values are the same set objects stored in `descendant_dict`,
        not copies.
        """
        return {rootName: self.descendant_dict[rootName]
                for rootName in self.getTreeRootNameList()}

    def getTreeMembers(self, rootName):
        """Return the member set of the tree rooted at `rootName`.

        Raises:
            KeyError: if `rootName` is not the root of a tree in `treeList`.
        """
        return self.getTreeMembers_dict()[rootName]

    def getRepAPISeq_dict(self):
        """Return a dict with key: treeRootName; val: RepAPISeq <list>."""
        return {rootName: self.getRepAPISeq(rootName)
                for rootName in self.getTreeRootNameList()}

    def getRepAPISeq(self, rootName):
        """Return the APIs of all common motifs of `rootName`, merged into one list."""
        commMotifSeq = self.repCommMotifSeq_dict[rootName]  # get CMS list
        return [api for motifAPI in commMotifSeq for api in motifAPI]

    def getRepMotifCount(self, rootName):
        """Return how many motifs are in the common motif sequence of `rootName`."""
        return len(self.repCommMotifSeq_dict[rootName])

    def getRepMotifSequence(self, rootName):
        """Return the common motif sequence stored for `rootName`."""
        return self.repCommMotifSeq_dict[rootName]

In [2]:
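# ### Unit test (disabled) -- uncomment the lines below to run; the output kept under this cell is from an earlier run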

# pkl_dir_path = 'output/RasMMA-test/eggnog_0.8/pickle/'
# interPkl = pkl_dir_path + 'eggnog_0.8_intermediate.pickle'
# resPkl = pkl_dir_path + 'eggnog_0.8_residual.pickle'
# TreeUtil = CollectForestInfo
# testFamilyForest = TreeUtil(interPkl, resPkl, True)

# rootNames = testFamilyForest.getTreeRootNameList()
# for root in rootNames:
#     rootAPISeq = testFamilyForest.getRepAPISeq(root)
#     motifCount = testFamilyForest.getRepMotifCount(root)
#     print(len(rootAPISeq), motifCount)
    
#     motifSeq = testFamilyForest.getRepMotifSequence(root)
#     motifLenList = [len(motif) for motif in motifSeq]
#     print(motifLenList)

159 10
[50, 2, 12, 1, 6, 2, 5, 1, 13, 67]
