# GraphGuard

***Locate and find Classes in Apks with updated Obfuscation Mapping***


Processing Steps:
* Usage of Strings
* Method Signatures (Modifiers, Parameter Types, Number of Parameters...)
* Other methods in same class
* Analyze Method Calls (from and to) via Call Graph (Distance, Offsets, Graph Analysis)

In [None]:
%matplotlib notebook

In [None]:
import unittest
from collections import defaultdict, Counter
from os import path

from androguard.core.analysis.analysis import MethodAnalysis
from androguard.core.bytecode import FormatClassToJava
from androguard.misc import AnalyzeAPK
from androguard.session import Save, Session, Load

# Loading Androguard

The following code loads the files and starts Androguard

It should support multiprocessing; however, the Pipe communication seems to break when transmitting the processed Androguard objects. I suspect the objects are simply too big for Pickle to serialize, or that another component in the transmission chain fails.

In [None]:
# Session file that load_androguard loads from (if present) and saves to
AG_SESSION_FILE = "./Androguard.ag"
# Inclusive upper bound on the number of xrefs a string may have to count as a characteristic
MAX_USAGE_COUNT_STR = 20
# Only referenced by the multiprocessing code that is currently disabled (kept in a string literal)
MULTIPROCESS_FILES = True

# APKs to compare: index 0 is analyzed first as (a, d, dx), index 1 later as (a2, d2, dx2)
file_paths = (
    "../../../Downloads/com.snapchat.android_10.85.5.74-2067_minAPI19(arm64-v8a)(nodpi)_apkmirror.com.apk",
    "../../../Downloads/com.snapchat.android_10.86.5.61-2069_minAPI19(arm64-v8a)(nodpi)_apkmirror.com.apk"
)

In [None]:
def load_androguard(file_path, force_reload=False, write_session=True):
    """
    Analyze an APK with Androguard, optionally going through the cached session file.

    :param file_path: path of the APK to analyze
    :param force_reload: when True, ignore an existing AG_SESSION_FILE and re-analyze the APK
    :param write_session: when True, save the freshly created session to AG_SESSION_FILE
    :return: the ``(a, d, dx)`` triple Androguard produced for ``file_path``
    """
    # Writing and Loading sessions currently cause a Kernel Disconnect or an EOF Error
    if (not force_reload) and path.exists(AG_SESSION_FILE):
        print("Loading Existing Session")
        s = Load(AG_SESSION_FILE)
        # This branch used to fall through to `return a, d, dx` with none of the
        # names bound; fetch the objects from the restored session instead.
        # NOTE(review): get_objects_apk is the same accessor AnalyzeAPK uses when it
        # is handed a session; it presumably raises if file_path is not part of the
        # saved session - confirm against the installed Androguard version.
        return s.get_objects_apk(file_path)

    print("Loading Session from Apk")
    s = Session()
    a, d, dx = AnalyzeAPK(file_path, s)
    if write_session:
        print("Saving Loaded Session to", AG_SESSION_FILE)
        Save(s, AG_SESSION_FILE)
    return a, d, dx

In [None]:
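# Multiprocessing is not working, probably because of the same serialization issue that occurs
# when trying to write/load androguard sessions.
# NOTE: the disabled code below also reads results from `q`, which is never defined, and uses
# `multiprocessing` without importing it.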
"""
def multiprocess_files(file_paths):
    parent_conn, child_conn = multiprocessing.Pipe(False)


    def post_result(file_path, conn):
        value = load_androguard(file_path, True, False)
        conn.send((file_path, value))

    ps =  [multiprocessing.Process(target=post_result, args=(f, child_conn)) for f in file_paths]

    def apply_map(f, i):
        for x in i:
            f(x)
    
    assert len(file_paths) == 2
    print("Starting multiprocessing Files")
    
    # Serialization with Pickle requires higher recursion limit
    import sys
    previous_recursion = sys.getrecursionlimit()
    sys.setrecursionlimit(50000)
    
    
    apply_map(multiprocessing.Process.start, ps)

    values = (q.get(), q.get())
    r = tuple(map(lambda x: x[1], sorted(values, key=lambda x: file_paths.index(x[0]))))
    
    
    apply_map(multiprocessing.Process.join, ps)
    
    print("Finished all processes")
    sys.setrecursionlimit(previous_recursion)
    return r

if MULTIPROCESS_FILES:
    (a, d, dx), (a2, d2, dx2) = multiprocess_files(file_paths)
else:
    (a, d, dx), (a2, d2, dx2) = tuple(map(lambda x: load_androguard(x, True, False), file_paths))
"""
# Analyze the first APK; force_reload=True and write_session=False bypass the session file.
a, d, dx = load_androguard(file_paths[0], True, False)

### Utility Functions to work with Androguard and Java Representations

* Converting Parameter types to TypeDescriptor Format
* Strip return type (not used for hooking)
* Method Representation Format

Loaded with Unit Tests

In [None]:
# https://source.android.com/devices/tech/dalvik/dex-format#typedescriptor
# Java primitive type name -> single-letter TypeDescriptor
type_descriptors = dict(
    void="V",
    boolean="Z",
    byte="B",
    short="S",
    char="C",
    int="I",
    long="J",
    float="F",
    double="D",
)

# Single-letter TypeDescriptor -> Java primitive type name
type_ds_reversed = dict(zip(type_descriptors.values(), type_descriptors.keys()))


def get_as_type_descriptor(arg):
    """
    Convert a Java-style type name into its TypeDescriptor representation.

    Each trailing "[]" becomes one leading "[". Primitive names are looked up in
    ``type_descriptors``; every other name is handed to ``FormatClassToJava``.
    """
    dimensions = 0
    while arg.endswith("[]"):
        arg = arg[:-2]
        dimensions += 1

    primitive = type_descriptors.get(arg)
    element = primitive if primitive is not None else FormatClassToJava(arg)
    return "[" * dimensions + element

In [None]:
def strip_return_descriptor(descriptor):
    """
    Return the parameter section of a method descriptor.

    That is everything after the first character and before the first ")";
    a descriptor without ")" raises ValueError.
    """
    closing_paren = descriptor.index(")")
    return descriptor[1:closing_paren]

In [None]:
def pretty_format_class(class_name):
    """
    Convert a TypeDescriptor into a Java-style type name.

    Each leading "[" becomes one trailing "[]". Single-letter primitives are
    looked up in ``type_ds_reversed``; anything else has its first and last
    character dropped and "/" replaced by ".".
    """
    dimensions = 0
    while class_name.startswith("["):
        class_name = class_name[1:]
        dimensions += 1

    if class_name in type_ds_reversed:
        element = type_ds_reversed[class_name]
    else:
        element = class_name[1:-1].replace("/", ".")
    return element + "[]" * dimensions


def get_pretty_params(descriptor):
    """
    Return an iterator over the Java-style parameter type names of a method descriptor.

    The parameter section is split on whitespace, so a descriptor without
    parameters (e.g. "()V") yields nothing instead of a single empty name.
    """
    # str.split() without an argument returns [] for an empty string, whereas
    # split(" ") returned [""] and produced one bogus empty parameter.
    return map(pretty_format_class, strip_return_descriptor(descriptor).split())

In [None]:
def get_method_repr(class_name, method_name, param_types):
    """Format a method as ``<class>#<method>(<params>)``; ``param_types`` is inserted as given."""
    return "{}#{}({})".format(class_name, method_name, param_types)

def pretty_format_ma(ma):
    """
    Build the readable ``<class>#<method>(<params>)`` form of an Androguard method object.

    Reads ``ma.class_name``, ``ma.name`` and ``ma.descriptor``.
    """
    owner = pretty_format_class(ma.class_name)
    method_name = ma.name
    params = ", ".join(get_pretty_params(str(ma.descriptor)))
    return get_method_repr(owner, method_name, params)

In [None]:
# (Java-style type name, TypeDescriptor) pairs; TestFunction uses them in both directions
tests_1 = (
    ("java.lang.String", "Ljava/lang/String;"),
    ("java.lang.String[]", "[Ljava/lang/String;"),
    ("void", "V"),
    ("int[]", "[I"),
    ("char", "C"), 
    ("java.lang.Object[][]", "[[Ljava/lang/Object;"),
    ("ABC", "LABC;")
)

# (method descriptor, expected parameter section) pairs for strip_return_descriptor
tests_2 = (
    ("(I)I", "I"), 
    ("(C)Z", "C"),
    ("(Ljava/lang/CharSequence; I)I", "Ljava/lang/CharSequence; I")
)

class TestFunction(unittest.TestCase):
    """
    Unit tests for the type/descriptor conversion helpers.

    Every table entry runs in its own subTest, so a failure names the offending
    input and the remaining entries are still checked.
    """

    def test_type_descriptor(self):
        """Java-style type names convert to their TypeDescriptor form."""
        for java_name, descriptor in tests_1:
            with self.subTest(java_name=java_name):
                self.assertEqual(get_as_type_descriptor(java_name), descriptor)

    def test_strip_return(self):
        """Method descriptors are reduced to their parameter section."""
        for descriptor, params in tests_2:
            with self.subTest(descriptor=descriptor):
                self.assertEqual(strip_return_descriptor(descriptor), params)

    def test_pretty_class(self):
        """TypeDescriptors convert back to Java-style type names."""
        for java_name, descriptor in tests_1:
            with self.subTest(descriptor=descriptor):
                self.assertEqual(pretty_format_class(descriptor), java_name)


unittest.main(argv=[''], verbosity=2, exit=False)

# Method Declarations

Lightweight Method Declaration for internal representation of a Method / Hook.

Not keeping Androguard Objects in memory to avoid high memory usage.

In [None]:
class MethodDec:
    """
    Lightweight description of a method to look for.

    Stores only plain values (class name, method name, parameter type names).
    """

    def __init__(self, class_name, name, *param_types):
        """
        :param class_name: name of the declaring class (later passed to FormatClassToJava)
        :param name: name of the method
        :param param_types: Java-style parameter type names, in declaration order
        """
        self.name = name
        self.class_name = class_name
        self.param_types = param_types

    def get_formatted_param_types(self):
        """Return the parameter types as a list of TypeDescriptor strings."""
        return [get_as_type_descriptor(param) for param in self.param_types]

    def param_types_repr(self):
        """Return the TypeDescriptor parameter types joined by single spaces."""
        descriptors = self.get_formatted_param_types()
        return " ".join(descriptors)

    def get_formatted_class(self):
        """Return the class name converted by FormatClassToJava."""
        return FormatClassToJava(self.class_name)

    def pretty_format(self):
        """Return the readable ``<class>#<method>(<params>)`` form of this declaration."""
        params = ", ".join(self.param_types)
        return get_method_repr(self.class_name, self.name, params)

    def __repr__(self):
        return "MethodDec({})".format(self.pretty_format())

    def equals_ma(self, ma):
        """
        Tell whether ``ma`` has the same name and parameter types as this declaration.

        The class name is not compared; ``ma.name`` and ``ma.get_descriptor()`` are read.
        """
        if self.name != ma.name:
            return False
        descriptor = str(ma.get_descriptor())
        return self.param_types_repr() == strip_return_descriptor(descriptor)
        

### List of Methods

Defining the list of methods to find (obviously requires full class names)

In [None]:
# Methods to locate, each given as MethodDec(class_name, method_name, *param_types)
decs_to_find = [
    MethodDec("rD5", "a", "rD5", "qD5"),
    MethodDec("MSg", "j0", "SGd")
]

# Processing

## Strings as Characteristics

Extracting strings used either directly in the given methods or elsewhere in the classes that define those methods

In [None]:
# Maps the TypeDescriptor form of every target class name to whatever
# dx.get_class_analysis returns for it (the Androguard class analysis object).
resolved_classes = {
    descriptor: dx.get_class_analysis(descriptor)
    for descriptor in (FormatClassToJava(dec.class_name) for dec in decs_to_find)
}

In [None]:
resolved_methods = []

# Resolve every declaration against its own class. Looking the class up per
# declaration (instead of zipping the declaration list with the de-duplicated
# class dict) keeps the pairing correct when several targets share one class.
for method_dec in decs_to_find:
    class_name = method_dec.get_formatted_class()
    class_analysis = resolved_classes.get(class_name)
    if class_analysis is None:
        # NOTE(review): assumes dx.get_class_analysis yields None for unknown classes - confirm
        raise Exception(f"Class of method was not resolved: {method_dec}")

    # Loop through all methods in the resolved class
    for method in class_analysis.get_methods():
        if method_dec.equals_ma(method):
            # Matching Method Declaration found
            print("Found Class and Method", method_dec)
            resolved_methods.append(method)
            break
    else:
        raise Exception(f"One method was not resolved: {method_dec}")

### Utility functions for working with dx.get_strings()

Filters strings and the xrefs to those strings. Only strings with (#xrefs <= MAX_USAGE_COUNT_STR) are used as characteristics to locate classes

In [None]:
def get_filtered_strs(dx):
    """
    Return a generator of ``(string, xrefs)`` pairs for every string of ``dx`` that is
    referenced at most MAX_USAGE_COUNT_STR times and can therefore serve as a
    characteristic for finding methods or classes.
    """
    strings = dx.get_strings()
    with_xrefs = ((candidate, candidate.get_xref_from()) for candidate in strings)
    return (
        (candidate, references)
        for candidate, references in with_xrefs
        if len(references) <= MAX_USAGE_COUNT_STR
    )


def get_xrefs_if_usable(s):
    """
    Yield the xrefs of string ``s``, or nothing at all when it is referenced more
    than MAX_USAGE_COUNT_STR times.
    """
    references = s.get_xref_from()
    if len(references) <= MAX_USAGE_COUNT_STR:
        yield from references

Building Maps of MethodDec and ClassNames associated to lists containing strings used in them

In [None]:
# m_strs: MethodDec  -> strings referenced from the resolved method
# c_strs: class name -> strings referenced from anywhere inside a target class
m_strs, c_strs = defaultdict(list), defaultdict(list)

for s, xrefs in get_filtered_strs(dx):
    for x in xrefs:
        c_ref, m_ref = x

        if c_ref.name not in resolved_classes:
            # XReference not in a Class or method that we need to find
            continue

        # String is used in a class we need to find. Record it exactly once per
        # reference; doing this inside the loop below appended it once per target
        # method of that class and inflated the counts.
        c_strs[c_ref.name].append(s.value)

        # Find the target methods this reference originates from
        for r_m, m_dec in zip(resolved_methods, decs_to_find):
            if r_m.class_name == c_ref.name and m_ref == r_m:
                # String is used in this method
                m_strs[m_dec].append(s.value)

### Count occurrences of strings

Converting list of strings to a Counter object for faster comparisons

In [None]:
# Collapse the collected string lists into occurrence counts
m_strs = {dec: Counter(values) for dec, values in m_strs.items()}
c_strs = {cls: Counter(values) for cls, values in c_strs.items()}

# Every declaration gets an entry, even when none of its strings were found
for m in decs_to_find:
    m_strs.setdefault(m, Counter())

In [None]:
def flat_map(f, li):
    """
    Lazily apply ``f`` to every leaf of an arbitrarily nested iterable.

    Every element that is itself an Iterable (strings excepted) is descended
    into; all other elements are passed to ``f`` and the results are yielded
    in depth-first, left-to-right order.

    The nesting is walked with an explicit stack of iterators instead of
    recursive generator delegation, so deeply nested input does not hit the
    interpreter's recursion limit.
    """
    from collections.abc import Iterable

    pending = [iter(li)]
    while pending:
        try:
            item = next(pending[-1])
        except StopIteration:
            # Innermost iterable is exhausted; resume its parent
            pending.pop()
            continue

        # str must be treated as a leaf: iterating a str yields str again
        if isinstance(item, Iterable) and not isinstance(item, str):
            pending.append(iter(item))
        else:
            yield f(item)

### Searching for Found Strings

Tries to resolve Classes and methods with the strings previously found

* Loading second Apk File
* Find All Strings found previously, build Map of potential matches (ClassName/Method to Counter)
* Filter Potential Matches by comparing both Counter Objects

In [None]:
# TODO: Figure out how to clear memory for a, d, dx in IPython and Jupyter
# NOTE(review): dx is still read by the fallback step further down, so at most a and d
# could be released here - confirm nothing else uses them.


# Analyze the second APK; force_reload=True and write_session=False bypass the session file.
a2, d2, dx2 = load_androguard(file_paths[1], True, False)

In [None]:
# m_strs2: method reference -> characteristic strings referenced from that method
# c_strs2: class name       -> characteristic strings referenced from that class
m_strs2, c_strs2 = defaultdict(list), defaultdict(list)

# Strings that characterise at least one target method / one target class.
# Checking membership against these unions records every reference exactly
# once; looping over the targets appended the same reference once per target
# that shares the string, so the counts could never equal those of the first APK.
wanted_m = set().union(*m_strs.values())
wanted_c = set().union(*c_strs.values())

for s in dx2.get_strings():
    value = s.value
    in_method_set = value in wanted_m
    in_class_set = value in wanted_c
    if not (in_method_set or in_class_set):
        continue

    # The xrefs are enumerated once per string (and skipped entirely for strings
    # referenced more than MAX_USAGE_COUNT_STR times)
    for x in get_xrefs_if_usable(s):
        c_ref, m_ref = x[0], x[1]
        if in_method_set:
            m_strs2[m_ref].append(value)
        if in_class_set:
            c_strs2[c_ref.name].append(value)

In [None]:
# Collapse the string lists collected from the second APK into occurrence counts
m_strs2 = {method_ref: Counter(values) for method_ref, values in m_strs2.items()}
c_strs2 = {class_name: Counter(values) for class_name, values in c_strs2.items()}

In [None]:
# String forms of the Counters that already produced a match (to report repeats)
counters = set()
# matching_cs: target class name -> class name in the second APK
# matching_ms: MethodDec         -> method reference in the second APK
matching_cs, matching_ms = {}, {}

# Classes: pair up entries whose string Counters are identical
for k2, c2 in c_strs2.items():
    for k1, c1 in c_strs.items():
        if c1 != c2:
            continue
        fingerprint = str(c1)
        if fingerprint in counters:
            print("Found Matching Counter again...")
        else:
            counters.add(fingerprint)
        matching_cs[k1] = k2
        print("Class Pair:", k1, k2, sep="\n\t* ")

# Methods: same procedure, starting with a fresh set of seen Counters
counters.clear()
for k2, c2 in m_strs2.items():
    for k1, c1 in m_strs.items():
        if c1 != c2:
            continue
        fingerprint = str(c1)
        if fingerprint in counters:
            print("Found Matching Counter again...")
        else:
            counters.add(fingerprint)
        matching_ms[k1] = k2
        print("Method Pair:",
              k1.pretty_format(),
              pretty_format_ma(k2),
              sep="\n\t+ ")

### Fallback if Class was found

In case the class was found, but the method could not be resolved, check each method of the class for the following criteria:

* Matching #xrefs_to
* Matching #xrefs_from
* Matching Code length

All of these checks are currently strict/exact

In [None]:
# Target classes (plain names) for which no counterpart was found in the second APK
c_not_found = {dec.class_name for dec in decs_to_find} - {pretty_format_class(c) for c in matching_cs}
# Declarations without a matched method, kept in declaration order so the output is deterministic
m_not_found = [dec for dec in decs_to_find if dec not in matching_ms]

# Criteria that must all be equal for two methods to count as an exact match:
# code length, number of xrefs_from and number of xrefs_to
c_functions = (
    MethodAnalysis.get_length,
    lambda x: len(x.get_xref_from()),
    lambda x: len(x.get_xref_to()),
)

for m in m_not_found:
    if m.class_name in c_not_found:
        print("Could not find class of method", m.pretty_format())
        continue

    class_name1 = FormatClassToJava(m.class_name)
    # Look up the counterpart of *this* method's class. The previous code indexed
    # matching_cs with `class_name`, a variable left over from an earlier cell.
    class_name2 = matching_cs[class_name1]

    for ma1 in dx.get_class_analysis(class_name1).get_methods():
        if not m.equals_ma(ma1):
            continue

        already_found = False

        # Compare the target method with every method of the matched class
        for ma2 in dx2.get_class_analysis(class_name2).get_methods():
            if all((c_fun(ma1) == c_fun(ma2)) for c_fun in c_functions):
                if already_found:
                    print("Another Match detected...")
                already_found = True
                print("Found exact matching criteria: ", pretty_format_ma(ma1), pretty_format_ma(ma2), sep="\n\t* ")

        if not already_found:
            print("Could not find any match for function", pretty_format_ma(ma1))

        # Only the first method matching the declaration is examined
        break