# 追踪信息流

如果只是不正确，却没有触发crash该怎么办？

In [None]:
import ast
import pdb
import string
import re

## 在内存中构建一个不靠谱的数据库

In [None]:
# Raw inventory text: one "year,kind,company,model" record per line.
# The backslash before the closing quotes suppresses the final newline,
# so the text does not end with an empty line.
INVENTORY = """\
1997,van,Ford,E350
2000,car,Mercury,Cougar
1999,car,Chevy,Venture\
"""

# One string per record, e.g. '1997,van,Ford,E350'.
VEHICLES = INVENTORY.splitlines()

In [None]:
class SQLException(Exception):
    """Error raised by the toy in-memory DB.

    Used for malformed or unknown queries, unknown tables or columns, and
    values or clauses that fail to convert or evaluate.
    """

In [None]:
class DB:
    """A tiny in-memory "SQL" database used as a fuzzing target.

    Tables live in ``self.db`` as ``{table_name: (column_decls, rows)}``.
    ``column_decls`` maps each column name to a conversion callable (such as
    ``int`` or ``str``), and ``rows`` is a list of ``{column: value}`` dicts.
    Select lists and ``where`` conditions are evaluated with Python's
    ``eval`` against each row.  This class must therefore never be used on
    untrusted queries outside of a fuzzing or taint-tracking experiment.
    """

    # ------------------------------------------------------------- set-up
    def __init__(self, db=None):
        """Create a database, optionally seeded from the mapping ``db``.

        Only the outer mapping is copied.  The ``(decls, rows)`` tuples are
        shared, so two ``DB`` objects built from the same mapping see each
        other's row inserts, updates and deletes.
        """
        self.db = {} if db is None else dict(db)

    # ---------------------------------------------------- table structure
    def create_table(self, table, defs):
        """Create (or replace) ``table`` with column declarations ``defs``."""
        self.db[table] = (defs, [])

    def table(self, t_name):
        """Return the ``(decls, rows)`` pair of table ``t_name``.

        Raises:
            SQLException: if the table does not exist.
        """
        if t_name in self.db:
            return self.db[t_name]
        raise SQLException('Table (%s) was not found' % repr(t_name))

    def column(self, table_decl, c_name):
        """Return the declaration (conversion callable) of column ``c_name``.

        Raises:
            SQLException: if ``table_decl`` has no such column.
        """
        if c_name in table_decl:
            return table_decl[c_name]
        raise SQLException('Column (%s) was not found' % repr(c_name))

    # ------------------------------------------------------------ queries
    def sql(self, query):
        """Dispatch ``query`` to the ``do_*`` handler for its leading keyword.

        Raises:
            SQLException: if the query starts with no known keyword.
        """
        # Every keyword includes its trailing space, so that input such as
        # 'delete frominventory ...' is rejected instead of accepted.
        methods = [('select ', self.do_select),
                   ('update ', self.do_update),
                   ('insert into ', self.do_insert),
                   ('delete from ', self.do_delete)]
        for key, method in methods:
            if query.startswith(key):
                return method(query[len(key):])
        raise SQLException('Unknown SQL (%s)' % query)

    def do_select(self, query):
        """Handle ``<columns> from <table> [where <cond>]``.

        Returns the list of values of ``<columns>`` (evaluated as a Python
        expression), one entry per selected row.
        """
        FROM, WHERE = ' from ', ' where '
        table_start = query.find(FROM)
        if table_start < 0:
            raise SQLException('no table specified')

        where_start = query.find(WHERE)
        select = query[:table_start]

        if where_start >= 0:
            t_name = query[table_start + len(FROM):where_start]
            where = query[where_start + len(WHERE):]
        else:
            t_name = query[table_start + len(FROM):]
            where = ''
        # strip(): surrounding blanks must not make the table lookup fail.
        _, table = self.table(t_name.strip())  # column decls are not needed

        if where:
            selected = self.expression_clause(table, "(%s)" % where)
            selected_rows = [hm for i, data, hm in selected if data]
        else:
            selected_rows = table

        rows = self.expression_clause(selected_rows, "(%s)" % select)
        return [data for i, data, hm in rows]

    def expression_clause(self, table, statement):
        """Evaluate ``statement`` once per row of ``table``.

        Returns:
            A list of ``(row_index, value, row)`` triples.
        """
        return [(i, self.my_eval(statement, {}, hm), hm)
                for i, hm in enumerate(table)]

    def my_eval(self, statement, g, l):
        """Evaluate ``statement`` with globals ``g`` and the row ``l`` as locals.

        Subclasses override this hook to add taint checks before evaluation.

        Raises:
            SQLException: wrapping any error raised during evaluation.
        """
        # SECURITY: this runs arbitrary Python.  An empty globals dict is not
        # a sandbox, because eval() still makes the builtins available.
        try:
            return eval(statement, g, l)
        except Exception as err:
            raise SQLException('Invalid WHERE (%s)' % repr(statement)) from err

    def do_insert(self, query):
        """Handle ``<table> (<names>) values (<literals>)`` and append one row.

        Each literal is parsed with ``ast.literal_eval`` and then converted
        with the column's declared callable.
        """
        VALUES = ' values '
        table_end = query.find('(')
        names_end = query.find(')')
        if table_end < 0 or names_end < table_end:
            raise SQLException('Invalid INSERT (%s)' % repr(query))
        t_name = query[:table_end].strip()
        decls, table = self.table(t_name)
        names = [i.strip() for i in query[table_end + 1:names_end].split(',')]

        # verify columns exist
        for k in names:
            self.column(decls, k)

        values_start = query.find(VALUES)
        if values_start < 0:
            raise SQLException('Invalid INSERT (%s)' % repr(query))

        # The +1 / -1 skip the parentheses around the value list.
        values = [
            i.strip() for i in query[values_start + len(VALUES) + 1:-1].split(',')
        ]

        if len(names) != len(values):
            raise SQLException(
                'names(%s) != values(%s)' % (repr(names), repr(values)))

        # Columns are matched by iterating over the declarations and comparing
        # with ==, instead of indexing decls[k]; the original note says C-level
        # dict lookups "can't be used" here (presumably so that comparisons
        # stay visible to instrumented str subclasses -- TODO confirm).
        # Each row is a {column: value} dict: bulky, but easy to eval against.
        kvs = {}
        for k, v in zip(names, values):
            for key, kval in decls.items():
                if k == key:
                    kvs[key] = self.convert(kval, v)
        table.append(kvs)

    def convert(self, cast, value):
        """Parse the literal ``value`` and convert it with ``cast``.

        Raises:
            SQLException: if ``value`` is not a Python literal or ``cast`` fails.
        """
        try:
            return cast(ast.literal_eval(value))
        except Exception as err:
            raise SQLException('Invalid Conversion %s(%s)' % (cast, value)) from err

    def do_update(self, query):
        """Handle ``<table> set <col>=<val>[, ...] [where <cond>]``.

        Returns:
            A message with the number of updated rows.
        """
        SET, WHERE = ' set ', ' where '
        table_end = query.find(SET)

        if table_end < 0:
            raise SQLException('Invalid UPDATE (%s)' % repr(query))

        set_end = table_end + len(SET)
        t_name = query[:table_end].strip()
        decls, table = self.table(t_name)
        names_end = query.find(WHERE)

        if names_end >= 0:
            names = query[set_end:names_end]
            where = query[names_end + len(WHERE):]
        else:
            names = query[set_end:]
            where = ''

        sets = [[i.strip() for i in name.split('=')]
                for name in names.split(',')]

        # Verify every assignment has the form 'column = value' and that the
        # column exists.
        for assignment in sets:
            if len(assignment) != 2:
                raise SQLException('Invalid UPDATE (%s)' % repr(query))
            self.column(decls, assignment[0])

        if where:
            selected = self.expression_clause(table, "(%s)" % where)
            updated = [hm for i, d, hm in selected if d]
        else:
            updated = table

        for hm in updated:
            for k, v in sets:
                # Match by iteration instead of decls[k], as in do_insert.
                for key, kval in decls.items():
                    if key == k:
                        hm[key] = self.convert(kval, v)

        return "%d records were updated" % len(updated)

    def do_delete(self, query):
        """Handle ``<table> where <cond>`` and remove every matching row.

        Returns:
            A message with the number of deleted rows.
        """
        WHERE = ' where '
        table_end = query.find(WHERE)
        if table_end < 0:
            raise SQLException('Invalid DELETE (%s)' % query)
        t_name = query[:table_end].strip()
        _, table = self.table(t_name)
        where = query[table_end + len(WHERE):]
        selected = self.expression_clause(table, "%s" % where)
        deleted = [i for i, d, hm in selected if d]
        # Delete from the back so that the remaining indices stay valid.
        for i in sorted(deleted, reverse=True):
            del table[i]
        return "%d records were deleted" % len(deleted)

In [None]:
# Quick smoke test of the DB class.
# Uncomment the next line to step through it in the debugger:
# pdb.set_trace()


def sample_db():
    """Return a fresh ``DB`` containing one empty ``inventory`` table."""
    database = DB()
    database.create_table(
        'inventory',
        {'year': int, 'kind': str, 'company': str, 'model': str},
    )
    return database


db = sample_db()
db.table('inventory')
for statement in (
    'insert into inventory (year, kind, company, model) values (1997, "van", "Ford", "E350")',
    'select year from inventory where year == 1997',
    'update inventory set year = 1998 where year == 1997',
):
    db.sql(statement)
db.sql('delete from inventory where company == "Ford"')

In [None]:
# Build an inventory table from the VEHICLES records.
inventory_def = {'year': int, 'kind': str, 'company': str, 'model': str}
db = DB()
db.create_table('inventory', inventory_def)


def update_inventory(sqldb, vehicle):
    """Insert the comma-separated record ``vehicle`` into ``sqldb``.

    Fields are converted with the declared column types of the
    ``inventory`` table (in declaration order) and rendered with ``repr``,
    so that they form literals in the generated INSERT statement.
    """
    column_names, casts = zip(*sqldb.db['inventory'][0].items())
    fields = vehicle.split(',')
    literals = [repr(cast_fn(field)) for cast_fn, field in zip(casts, fields)]
    sqldb.sql('insert into inventory (%s) values (%s)'
              % (','.join(column_names), ','.join(literals)))


for V in VEHICLES:
    update_inventory(db, V)

db.db

## Fuzzing SQL

In [None]:
# Grammar of Python-like expressions, used for select lists and where clauses.
EXPR_GRAMMAR = {
    "<start>": ["<expr>"],
    "<expr>": ["<bexpr>", "<aexpr>", "(<expr>)", "<term>"],
    "<bexpr>": [
        "<aexpr><lt><aexpr>",
        "<aexpr><gt><aexpr>",
        "<expr>==<expr>",
        "<expr>!=<expr>",
    ],
    "<aexpr>": [
        "<aexpr>+<aexpr>",
        "<aexpr>-<aexpr>",
        "<aexpr>*<aexpr>",
        "<aexpr>/<aexpr>",
        "<word>(<exprs>)",
        "<expr>",
    ],
    "<exprs>": ["<expr>,<exprs>", "<expr>"],
    "<lt>": ["<"],
    "<gt>": [">"],
    "<term>": ["<number>", "<word>"],
    "<number>": ["<integer>.<integer>", "<integer>", "-<number>"],
    "<integer>": ["<digit><integer>", "<digit>"],
    "<word>": ["<word><letter>", "<word><digit>", "<letter>"],
    "<digit>": list(string.digits),
    # Identifier characters also include '_', ':' and '.'.
    "<letter>": list(string.ascii_letters + '_:.'),
}

# The SQL grammar keeps every expression rule, replaces the start symbol
# (in place, so key order is unchanged) and appends the statement rules.
INVENTORY_GRAMMAR = {
    **EXPR_GRAMMAR,
    '<start>': ['<query>'],
    '<query>': [
        'select <exprs> from <table>',
        'select <exprs> from <table> where <bexpr>',
        'insert into <table> (<names>) values (<literals>)',
        'update <table> set <assignments> where <bexpr>',
        'delete from <table> where <bexpr>',
    ],
    '<table>': ['<word>'],
    '<names>': ['<column>,<names>', '<column>'],
    '<column>': ['<word>'],
    '<literals>': ['<literal>', '<literal>,<literals>'],
    '<literal>': ['<number>', "'<chars>'"],
    '<assignments>': ['<kvp>,<assignments>', '<kvp>'],
    '<kvp>': ['<column>=<value>'],
    '<value>': ['<word>'],
    '<chars>': ['<char>', '<char><chars>'],
    # Printable characters without quotes, angle brackets or non-space
    # whitespace; '<' and '>' remain reachable through <lt> / <gt>.
    '<char>': [ch for ch in string.printable
               if ch not in "<>'\"\t\n\r\x0b\x0c\x00"] + ['<lt>', '<gt>'],
}

In [None]:
INVENTORY_GRAMMAR_F = dict(INVENTORY_GRAMMAR, **{'<table>': ['inventory']})
from fuzzingbook.fuzzingbook_utils.GrammarFuzzer import GrammarFuzzer
# Needed by the handler below; otherwise `traceback` is only imported in a
# later cell, which would make this cell fail with a NameError.
import traceback

# Fuzz the plain DB.  Expected rejections surface as SQLException; anything
# else is an unexpected failure, so print its traceback and stop.
gf = GrammarFuzzer(INVENTORY_GRAMMAR_F)
for _ in range(10):
    query = gf.fuzz()
    print(repr(query))
    try:
        res = db.sql(query)
        print(repr(res))
    except SQLException as e:
        print("> ", e)
    except Exception:
        traceback.print_exc()
        break
    print()

In [None]:
# A crash is not the only indicator of an error.
# Here nothing crashes: the select list is simply a Python expression that is
# evaluated for every row, so even a conditional expression is accepted.
conditional_select = 'select year - 1900 if year < 2000 else year - 2000 from inventory'
db.sql(conditional_select)

In [None]:
# Reproduce the query above without the DB: each row's select list is just
# eval()-ed Python.  That same mechanism would also run something like
# os.system(...) if it appeared in a query.
years = [1997, 2000, 1999]
ans = [eval('year - 1900 if year < 2000 else year - 2000', {}, {"year": y})
       for y in years]
print(ans)

One method that allows such differentiation is that of dynamic taint analysis. The idea is to identify the functions that accept user input as sources that taint any string that comes in through them, and those functions that perform dangerous operations as sinks. Finally we bless certain functions as taint sanitizers. The idea is that an input from the source should never reach the sink without undergoing sanitization first. This allows us to use a stronger oracle than simply checking for crashes.（大概意思是：用户输入的数据会被打上污点标记，由这些数据计算得到的数据也同样会被标记。当这些污点数据到达sinks时，可能会造成安全问题。可以通过无害处理（sanitization），避免污点数据直接到达sinks。如果污点数据在未经无害处理的情况下无法到达sinks，则说明是安全的。）


污点分析可以抽象成一个三元组<sources,sinks,sanitizers>的形式,其中,source 即污点源,代表直接引入不受信任的数据或者机密数据到系统中;sink即污点汇聚点,代表直接产生安全敏感操作(违反数据完整性)或者泄露隐私数据到外界(违反数据保密性);sanitizer即无害处理,代表通过数据加密或者移除危害操作等手段使数据传播不再对软件系统的信息安全产生危害.污点分析就是分析程序中由污点源引入的数据是否能够不经无害处理,而直接传播到污点汇聚点.如果不能,说明系统是信息流安全的;否则,说明系统产生了隐私数据泄露或危险数据操作等安全问题。


[简单理解污点分析技术](https://www.k0rz3n.com/2019/03/01/%E7%AE%80%E5%8D%95%E7%90%86%E8%A7%A3%E6%B1%A1%E7%82%B9%E5%88%86%E6%9E%90%E6%8A%80%E6%9C%AF/#0X03-%E6%B1%A1%E7%82%B9%E5%88%86%E6%9E%90%E5%9C%A8%E5%AE%9E%E9%99%85%E5%BA%94%E7%94%A8%E4%B8%AD%E7%9A%84%E5%85%B3%E9%94%AE%E6%8A%80%E6%9C%AF)

## 字符串污点标记

In [None]:
class tstr(str):
    """A ``str`` subclass that carries a ``taint`` label.

    ``str`` is immutable, so its text must be supplied to ``str.__new__``
    when the object is created.  ``__new__`` therefore forwards only
    ``value`` and ignores the remaining arguments.  Python then calls
    ``__init__`` with the same arguments, and that is where ``taint`` is
    stored.
    """

    def __new__(cls, value, *args, **kw):
        # Only the text goes to str.__new__; the taint is set in __init__.
        return super().__new__(cls, value)

    def __init__(self, value, taint=None, **kwargs):
        self.taint = taint

    def __repr__(self):
        # The repr keeps the taint, so it is not lost when printed/embedded.
        return self.create(super().__repr__())

    def __str__(self):
        return super().__str__()

    def __radd__(self, s):
        # Called for `plain_str + tstr`; the result keeps this taint.
        joined = s + str(self)
        return self.create(joined)

    # -- taint helpers -----------------------------------------------------
    def clear_taint(self):
        """Drop the taint in place and return ``self``."""
        self.taint = None
        return self

    def has_taint(self):
        """Return True if a taint label is set."""
        return self.taint is not None

    def create(self, s):
        """Return ``s`` as a new ``tstr`` carrying this string's taint."""
        return tstr(s, taint=self.taint)

def make_str_wrapper(fun):
    """Wrap the ``str`` method ``fun`` so that string results keep the taint.

    The returned ``proxy`` calls ``fun`` and passes any ``str`` result
    through ``self.create``.  Other results are returned unchanged.
    ``encode`` yields ``bytes``, and reflected dunders such as ``__rmod__``
    return ``NotImplemented`` for non-string operands; wrapping those would
    turn them into the strings ``"b'...'"`` and ``"NotImplemented"``.
    """
    import functools

    @functools.wraps(fun)
    def proxy(self, *args, **kwargs):
        res = fun(self, *args, **kwargs)
        if isinstance(res, str):
            return self.create(res)
        return res
    return proxy


def informationflow_init_1():
    """Patch ``tstr`` so that the listed ``str`` methods keep the taint.

    The patching is applied to the class itself, and only once.  That is why
    it is not done in ``tstr.__init__``, which would redo it for every new
    instance.
    """
    wrapped_names = (
        '__format__', '__mod__', '__rmod__', '__getitem__', '__add__',
        '__mul__', '__rmul__',
        'capitalize', 'casefold', 'center', 'encode',
        'expandtabs', 'format', 'format_map', 'join', 'ljust', 'lower',
        'lstrip', 'replace', 'rjust', 'rstrip', 'strip', 'swapcase', 'title',
        'translate', 'upper',
    )
    for method_name in wrapped_names:
        setattr(tstr, method_name, make_str_wrapper(getattr(str, method_name)))


# Install the wrappers right away, so every later cell sees a complete tstr.
informationflow_init_1()

In [None]:
# Each result below comes from a tstr method (wrapped or overridden), so it
# keeps the taint of its tainted operand.
thello = tstr('hello', taint='LOW')

for derived in (
    thello[0],
    thello[1:3],
    tstr('foo', taint='HIGH') + 'bar',
    'foo' + tstr('bar', taint='HIGH'),
    thello * 5,
    'hw %s' % thello,
    tstr('hello %s', taint='HIGH') % 'world',
):
    print(derived.taint)
# Note: `thello += ', world'` is a statement, not an expression, so it cannot
# appear inside print(); assign first and then inspect thello.taint.

## 跟踪不信任的输入

In [None]:
class TrustedDB(DB):
    """A ``DB`` that only executes queries explicitly tainted ``'TRUSTED'``."""

    def sql(self, s):
        """Run ``s`` only if it is a ``tstr`` whose taint is ``'TRUSTED'``.

        Raises:
            AssertionError: if ``s`` is not a ``tstr`` or is not trusted.
        """
        # Explicit raises rather than `assert`: asserts are removed under
        # `python -O`, which would silently disable this security check.
        # AssertionError is kept so callers still see the same exception type.
        if not isinstance(s, tstr):
            raise AssertionError("Need a tainted string")
        if s.taint != 'TRUSTED':
            raise AssertionError("Need a string with trusted taint")
        return super().sql(s)

In [None]:
# Reuse the tables of the earlier `db`; DB.__init__ copies only the outer dict.
bdb = TrustedDB(db=db.db)

In [None]:
# An untrusted query is rejected, for example:
#     bdb.sql(tstr("select year from INVENTORY"))   # fails: no TRUSTED taint
trusted_query = tstr("select year from inventory", taint="TRUSTED")
bdb.sql(trusted_query)

所以我们需要进行消毒处理。将不信任的输入，转换成信任的输入。

这个转换过程，要求可以判断合法性。

In [None]:
# Whitelist for the only statement shape the sanitizer accepts.
_SELECT_WHITELIST = re.compile(
    r'select +[-a-zA-Z0-9_, ()]+ from +[-a-zA-Z0-9_, ()]+')


def sanitize(user_input):
    """Return a TRUSTED copy of ``user_input`` if it looks like a plain select.

    The whole input must have the form ``select <columns> from <table>``,
    built only from letters, digits, ``-``, ``_``, ``,``, spaces and
    parentheses.  Otherwise an empty ``tstr`` tainted ``'UNTRUSTED'`` is
    returned.

    NOTE(review): parentheses are allowed, so calls such as ``exit()`` still
    pass this whitelist and are later ``eval``-ed by the DB.  This is a
    demonstration sanitizer, not a secure one.

    Raises:
        AssertionError: if ``user_input`` is not a ``tstr``.
    """
    if not isinstance(user_input, tstr):
        raise AssertionError('sanitize() expects a tstr')
    # fullmatch rather than match(r'^...$'): '$' also matches just before a
    # trailing '\n', which would let such input be marked TRUSTED.
    if _SELECT_WHITELIST.fullmatch(user_input):
        return tstr(user_input, taint='TRUSTED')
    return tstr('', taint='UNTRUSTED')

In [None]:
# A well-formed query from an untrusted source passes the sanitizer, is
# re-tagged as TRUSTED, and is therefore accepted by TrustedDB.
good_user_input = tstr("select year,model from inventory", taint='UNTRUSTED')
sanitized_input = sanitize(good_user_input)
for shown in (sanitized_input, sanitized_input.taint):
    print(shown)
bdb.sql(sanitized_input)

## Taint Aware Fuzzing

污点导向的模糊测试。这些模糊测试采用语法生成输入，生成的输入可能是危险的输入。核心思想是生成的输入导致不可信的执行，我们可以留意这些输入。

In [None]:
class Tainted(Exception):
    """Signals that an untrusted statement was about to be evaluated.

    Raised by the taint-checking ``my_eval`` overrides; ``v`` holds the
    offending statement.
    """

    def __init__(self, v):
        super().__init__(v)
        self.v = v

    def __str__(self):
        return 'Tainted[%s]' % self.v


class TaintedDB(DB):
    """A ``DB`` that refuses to evaluate any clause not tainted ``'TRUSTED'``."""

    def my_eval(self, statement, g, l):
        """Evaluate ``statement`` only if its taint is ``'TRUSTED'``.

        Raises:
            Tainted: if the statement is not trusted.  Strings without any
                taint information (e.g. plain ``str``) count as untrusted.
            SQLException: if evaluating a trusted statement fails.
        """
        # A plain str has no `.taint`; treat it as the worst case rather than
        # failing with AttributeError.
        if getattr(statement, 'taint', None) != 'TRUSTED':
            raise Tainted(statement)
        try:
            return eval(statement, g, l)
        except Exception as err:
            raise SQLException('Invalid SQL (%s)' % repr(statement)) from err

import traceback

# Reuse the tables of `db`; DB.__init__ copies only the outer dict.
tdb = TaintedDB(db.db)

# Every fuzzed query is tagged UNTRUSTED, so any clause that reaches
# my_eval() is reported as Tainted instead of being evaluated.
for _ in range(10):
    query = gf.fuzz()
    print(repr(query))
    try:
        res = tdb.sql(tstr(query, taint='UNTRUSTED'))
        print(repr(res))
    except SQLException as e:
        print(">> ", e)
    except Tainted as e:
        print("> ", e)
    except Exception:
        traceback.print_exc()
        break
    print()

污点标记可以有下面功能：避免隐私数据泄露(对于一些隐私关键数据进行标记。如果检查出存在该标记，则说明可能隐私泄露，阻止该操作)。

但是，当两个不同标记的字符串相遇的时候，该如何处理？（用优先级处理？）

更进一步，我们希望对于给定的字符串，能够知道其中每一个字符的来源。

## 追踪字符起源

让我们引入一个类ostr：它和tstr一样为每个字符串携带一个污点，另外还为每个字符携带一个表示其来源的origin。origin是某个范围内的连续整数（默认从零开始），表示该字符在其来源中的位置。

我复制了一些代码过来。要实现给一个类，添加附加信息，似乎并不是一件容易的事情。

In [None]:
class ostr(str):
    """A ``str`` subclass that tracks where each of its characters came from.

    Like ``tstr``, it carries a string-level ``taint``.  In addition,
    ``origin[i]`` is an integer describing the provenance of ``self[i]``.
    Characters of unknown provenance carry ``UNKNOWN_ORIGIN``.  The string
    methods overridden below compute the origins of their results; other
    inherited ``str`` methods lose the origin information.
    """
    DEFAULT_ORIGIN = 0
    UNKNOWN_ORIGIN = -1

    class TaintException(Exception):
        """Raised when origin information is requested but unavailable."""

    def __new__(cls, value, *args, **kw):
        # str is immutable, so only the text can be passed to str.__new__;
        # taint and origin are attached afterwards by __init__.
        return str.__new__(cls, value)

    def __init__(self, value, taint=None, origin=None, **kwargs):
        """Attach ``taint`` and the per-character ``origin`` list.

        ``origin`` may be ``None`` (start at ``DEFAULT_ORIGIN``), an ``int``
        ``n`` (characters get ``n, n + 1, ...``), or a sequence with exactly
        one entry per character.  Sequences are copied, so that strings never
        share (and accidentally mutate) each other's origin lists.
        """
        self.taint = taint

        if origin is None:
            origin = ostr.DEFAULT_ORIGIN
        if isinstance(origin, int):
            self.origin = list(range(origin, origin + len(self)))
        else:
            self.origin = list(origin)
        assert len(self.origin) == len(self)

    def __repr__(self):
        """Return the repr as an ``ostr``.

        The characters of an escape sequence inherit the origin of the
        character they encode; the surrounding quotes get UNKNOWN_ORIGIN.
        """
        text = str.__repr__(self)
        quote = text[0]
        origin = [ostr.UNKNOWN_ORIGIN]
        for ch, o in zip(str(self), self.origin):
            # repr(ch)[1:-1] is ch's escaped form, e.g. '\\n' (2 chars) for
            # '\n'.  One exception: the quote character enclosing the full
            # repr is escaped there (\'), but repr(ch) alone picks the other
            # quote and leaves it unescaped.
            escaped = repr(ch)[1:-1]
            width = len(escaped) + (1 if ch == quote and escaped == ch else 0)
            origin.extend([o] * width)
        origin.append(ostr.UNKNOWN_ORIGIN)
        return ostr(text, taint=self.taint, origin=origin)

    def __str__(self):
        return str.__str__(self)

    # ------------------------------------------------------ taint / origin
    def clear_taint(self):
        """Remove the taint in place and return ``self``."""
        self.taint = None
        return self

    def has_taint(self):
        """Return True if a taint is set."""
        return self.taint is not None

    def clear_origin(self):
        """Mark every character as UNKNOWN_ORIGIN in place; return ``self``."""
        self.origin = [self.UNKNOWN_ORIGIN] * len(self)
        return self

    def has_origin(self):
        """Return True if at least one character has a known origin."""
        return any(origin != self.UNKNOWN_ORIGIN for origin in self.origin)

    def create(self, res, origin=None):
        """Return ``res`` as an ``ostr`` with this taint and the given ``origin``."""
        return ostr(res, taint=self.taint, origin=origin)

    def x(self, i=0):
        """Return the characters whose origin is ``i`` (an int) or lies in
        the range described by the slice ``i``.

        Raises:
            TaintException: if this string carries no origins (it is empty).
        """
        if not self.origin:
            raise self.TaintException('Invalid request idx')
        if isinstance(i, int):
            return [self[p] for p, o in enumerate(self.origin) if o == i]
        elif isinstance(i, slice):
            r = range(i.start or 0, i.stop or len(self), i.step or 1)
            return [self[p] for p, o in enumerate(self.origin) if o in r]

    # ------------------------------------------ indexing and iteration
    def __getitem__(self, key):
        """Index or slice, keeping the origins of the selected characters."""
        res = super().__getitem__(key)
        if isinstance(key, slice):
            return self.create(res, self.origin[key])
        # Any non-slice key accepted by str indexing supports __index__.
        index = key.__index__()
        if index < 0:
            index += len(self)
        return self.create(res, [self.origin[index]])

    def __iter__(self):
        """Iterate character by character, yielding 1-char ``ostr`` objects."""
        return ostr.ostr_iterator(self)

    class ostr_iterator():
        """Iterator over an ``ostr`` that keeps per-character origins."""

        def __init__(self, ostr):
            # (the parameter shadows the class name only inside __init__)
            self._ostr = ostr
            self._str_idx = 0

        def __iter__(self):
            return self

        def __next__(self):
            if self._str_idx == len(self._ostr):
                raise StopIteration
            # ostr.__getitem__ returns an ostr carrying the char's origin
            c = self._ostr[self._str_idx]
            assert isinstance(c, ostr)
            self._str_idx += 1
            return c

    # ------------------------------------------------------ concatenation
    def __add__(self, other):
        """``self + other``; characters of a plain ``str`` get UNKNOWN_ORIGIN."""
        if not isinstance(other, str):
            return NotImplemented
        res = str.__add__(self, other)
        other_origin = (other.origin if isinstance(other, ostr)
                        else [self.UNKNOWN_ORIGIN] * len(other))
        return self.create(res, self.origin + other_origin)

    def __radd__(self, other):
        """``other + self`` for a plain ``str`` ``other``."""
        if not isinstance(other, str):
            return NotImplemented
        other_origin = (other.origin if isinstance(other, ostr)
                        else [self.UNKNOWN_ORIGIN] * len(other))
        return self.create(str.__add__(other, self), other_origin + self.origin)

    def join(self, iterable):
        """Like ``str.join``; each piece keeps its own origins and every
        separator gets this string's origins."""
        items = list(iterable)
        # Join the materialized list: `iterable` may be a one-shot iterator.
        res = super().join(items)
        myorigin = []
        for i, s in enumerate(items):
            if i:
                myorigin.extend(self.origin)
            myorigin.extend(s.origin if isinstance(s, ostr)
                            else [self.UNKNOWN_ORIGIN] * len(s))
        return self.create(res, myorigin)

    # ---------------------------------------------------------- replace
    def replace(self, a, b, n=None):
        """Replace occurrences of ``a`` with ``b``, like ``str.replace``.

        ``n`` limits the number of replacements; ``None`` or a negative value
        means "all" (``str.replace`` uses -1 for that).  Characters taken
        from ``b`` carry ``b``'s origins (UNKNOWN_ORIGIN for a plain str);
        all other characters keep their own origins.
        """
        count = -1 if n is None else n
        b_origin = (b.origin if isinstance(b, ostr)
                    else [self.UNKNOWN_ORIGIN] * len(b))
        text = str(self)
        res_origin = []
        pos = 0
        done = 0
        while count < 0 or done < count:
            if a:
                idx = text.find(a, pos)
                if idx < 0:
                    break
            elif pos <= len(text):
                # An empty pattern matches before every character and at the end.
                idx = pos
            else:
                break
            res_origin.extend(self.origin[pos:idx])
            res_origin.extend(b_origin)
            if a:
                # Continue after the match, so text inserted from b is never
                # searched again (str.replace semantics; no infinite loop).
                pos = idx + len(a)
            else:
                # Copy the character that follows the insertion point.
                res_origin.extend(self.origin[idx:idx + 1])
                pos = idx + 1
            done += 1
        res_origin.extend(self.origin[pos:])
        return self.create(str.replace(text, a, b, count), res_origin)

    # ------------------------------------------------------------- split
    def _split_helper(self, sep, splitted):
        """Map the pieces of a split on ``sep`` back to origin-carrying slices."""
        result_list = []
        first_idx = 0
        for s in splitted:
            last_idx = first_idx + len(s)
            result_list.append(self[first_idx:last_idx])
            first_idx = last_idx + len(sep)
        return result_list

    def _split_space(self, splitted):
        """Map the pieces of a whitespace split back to origin-carrying slices.

        Each piece is located with ``find``, starting at the end of the
        previous piece.  This skips any run of whitespace (spaces, tabs,
        newlines, leading blanks), not just spaces.
        """
        text = str(self)
        result_list = []
        pos = 0
        for s in splitted:
            idx = text.find(s, pos)
            result_list.append(self[idx:idx + len(s)])
            pos = idx + len(s)
        return result_list

    def rsplit(self, sep=None, maxsplit=-1):
        splitted = super().rsplit(sep, maxsplit)
        if not sep:
            return self._split_space(splitted)
        return self._split_helper(sep, splitted)

    def split(self, sep=None, maxsplit=-1):
        splitted = super().split(sep, maxsplit)
        if not sep:
            return self._split_space(splitted)
        return self._split_helper(sep, splitted)

    def partition(self, sep):
        partA, sep, partB = super().partition(sep)
        return (self.create(partA, self.origin[0:len(partA)]),
                self.create(sep,
                            self.origin[len(partA):len(partA) + len(sep)]),
                self.create(partB, self.origin[len(partA) + len(sep):]))

    def rpartition(self, sep):
        partA, sep, partB = super().rpartition(sep)
        return (self.create(partA, self.origin[0:len(partA)]),
                self.create(sep,
                            self.origin[len(partA):len(partA) + len(sep)]),
                self.create(partB, self.origin[len(partA) + len(sep):]))

    # ------------------------------------------------------------- strip
    def strip(self, cl=None):
        return self.lstrip(cl).rstrip(cl)

    def lstrip(self, cl=None):
        res = super().lstrip(cl)
        # res is a suffix of self, so its start follows from the lengths.
        # (find(res) would return 0 for an empty result and keep everything.)
        return self[len(self) - len(res):]

    def rstrip(self, cl=None):
        res = super().rstrip(cl)
        return self[0:len(res)]

    # ------------------------------------------------------ padding / tabs
    def expandtabs(self, n=8):
        """Like ``str.expandtabs(n)``; the spaces produced for a tab carry the
        origin of that tab."""
        res = super().expandtabs(n)
        origin = []
        column = 0
        for ch, o in zip(str(self), self.origin):
            if ch == '\t':
                # Same rule as str.expandtabs: pad to the next tab stop;
                # with n <= 0, tabs are removed.
                width = n - column % n if n > 0 else 0
                origin.extend([o] * width)
                column += width
            else:
                origin.append(o)
                column = 0 if ch in '\n\r' else column + 1
        return self.create(res, origin)

    def _fill_origin(self, fillchar):
        # Padding characters inherit the fill character's origin, if known.
        if isinstance(fillchar, ostr):
            return fillchar.origin[0]
        return self.UNKNOWN_ORIGIN

    def ljust(self, width, fillchar=' '):
        """Like ``str.ljust``: the padding is appended on the right."""
        res = super().ljust(width, fillchar)
        pad = len(res) - len(self)
        return self.create(res, self.origin + [self._fill_origin(fillchar)] * pad)

    def rjust(self, width, fillchar=' '):
        """Like ``str.rjust``: the padding is prepended on the left."""
        res = super().rjust(width, fillchar)
        pad = len(res) - len(self)
        return self.create(res, [self._fill_origin(fillchar)] * pad + self.origin)

    # -------------------------------------------------------- formatting
    def __mod__(self, s):
        """``self % s``: only a single ``%s`` with a ``str`` argument is
        implemented (anything else fails an assertion)."""
        assert isinstance(s, str)
        s_origin = s.origin if isinstance(
            s, ostr) else [self.UNKNOWN_ORIGIN] * len(s)
        i = self.find('%s')
        assert i >= 0
        res = super().__mod__(s)
        r_origin = self.origin[:]
        r_origin[i:i + 2] = s_origin
        return self.create(res, origin=r_origin)

    def __rmod__(self, s):
        """``s % self`` for a ``str`` format ``s`` containing a single ``%s``
        (nothing else is implemented)."""
        assert isinstance(s, str)
        # Copy: the slice assignment below must not modify an ostr s's origin.
        r_origin = list(s.origin) if isinstance(
            s, ostr) else [self.UNKNOWN_ORIGIN] * len(s)
        i = s.find('%s')
        assert i >= 0
        res = super().__rmod__(s)
        r_origin[i:i + 2] = self.origin
        return self.create(res, origin=r_origin)

    # ---------------------------------------------------------- case maps
    # These keep one origin per character.  Case mappings that change the
    # length (e.g. 'ß'.upper() == 'SS') therefore fail __init__'s assertion.
    def swapcase(self):
        return self.create(str(self).swapcase(), self.origin)

    def upper(self):
        return self.create(str(self).upper(), self.origin)

    def lower(self):
        return self.create(str(self).lower(), self.origin)

    def capitalize(self):
        return self.create(str(self).capitalize(), self.origin)

    def title(self):
        return self.create(str(self).title(), self.origin)

In [None]:
thello = ostr('hello', taint='HIGH')
assert thello.origin == [0, 1, 2, 3, 4]

tworld = thello.create('world', origin=6)
assert (thello.origin, tworld.origin) == ([0, 1, 2, 3, 4], [6, 7, 8, 9, 10])

## Taint-Directed Fuzzing

In [None]:
import random
from fuzzingbook.fuzzingbook_utils.Grammars import START_SYMBOL
from fuzzingbook.fuzzingbook_utils.GrammarFuzzer import GrammarFuzzer
from fuzzingbook.fuzzingbook_utils.Parser import canonical

In [None]:
class TaintedGrammarFuzzer(GrammarFuzzer):
    """Grammar fuzzer whose grammar tokens are ``ostr`` objects.

    Every token gets its own distinct origins, so each character of a
    produced input can be traced back to the grammar alternative that
    generated it.
    """

    def __init__(self,
                 grammar,
                 start_symbol=START_SYMBOL,
                 expansion_switch=1,
                 log=False):
        """Set up the fuzzer.

        Args:
            grammar: the (non-canonical) grammar to fuzz.
            start_symbol: symbol every derivation starts from.
            expansion_switch: threshold on the number of expansions.  Below
                it, ``fuzz_tree`` picks alternatives at random; from then on
                it picks a minimum-cost alternative.
            log: print every expansion when true.
        """
        self.tainted_start_symbol = ostr(
            start_symbol, origin=[1] * len(start_symbol))
        self.expansion_switch = expansion_switch
        self.log = log
        self.grammar = grammar
        self.c_grammar = canonical(grammar)
        self.init_tainted_grammar()

    def init_tainted_grammar(self):
        """Build ``ct_grammar`` and ``ctp_grammar``.

        ``ct_grammar`` holds the canonical alternatives, with every token
        turned into an ``ostr`` with its own origins.  ``ctp_grammar`` pairs
        each of those alternatives with a ``{'use': count}`` dict.

        Origins come from one ever-increasing counter: +1000 per symbol,
        +100 per alternative, and +max(10, len(token)) per token.  Advancing
        by at least the token length guarantees that no two tokens share an
        origin, even tokens longer than 10 characters such as
        ``'<assignments>'`` or ``'insert into '``.
        """
        key_increment, alt_increment, token_increment = 1000, 100, 10
        key_origin = key_increment
        self.ct_grammar = {}
        for key, val in self.c_grammar.items():
            key_origin += key_increment
            alts = []
            for v in val:
                ts = []
                key_origin += alt_increment
                for t in v:
                    nt = ostr(t, origin=key_origin)
                    key_origin += max(token_increment, len(t))
                    ts.append(nt)
                alts.append(ts)
            self.ct_grammar[key] = alts

        # a use-tracking grammar
        self.ctp_grammar = {}
        for key, val in self.ct_grammar.items():
            self.ctp_grammar[key] = [(v, dict(use=0)) for v in val]

    def symbol_cost(self, symbol, seen=None):
        """Minimum cost of expanding ``symbol``, computed over ``c_grammar``.

        Defined here so that ``expansion_cost`` (below) always receives
        canonical token lists, the form it expects, regardless of how the
        base class implements ``symbol_cost``.
        """
        seen = set() if seen is None else seen
        return min(self.expansion_cost(e, seen | {symbol})
                   for e in self.c_grammar[symbol])

    def expansion_cost(self, expansion, seen=None):
        """Cost of the token list ``expansion``.

        The cost is 1 plus the costs of its nonterminals, or infinity if the
        expansion re-enters a symbol in ``seen`` (i.e. recurses).
        """
        seen = set() if seen is None else seen
        symbols = [e for e in expansion if e in self.c_grammar]
        if len(symbols) == 0:
            return 1

        if any(s in seen for s in symbols):
            return float('inf')

        return sum(self.symbol_cost(s, seen) for s in symbols) + 1

    def fuzz_tree(self):
        """Derive a tree from the tainted start symbol.

        The first ``expansion_switch`` expansions pick an alternative at
        random; after that, a cheapest alternative is used (ties broken at
        random).  Nodes are ``(ostr_symbol, children)`` pairs.
        """
        tree = (self.tainted_start_symbol, [])
        nt_leaves = [tree]  # nonterminal leaves still waiting for expansion
        expansion_trials = 0
        while nt_leaves:
            idx = random.randrange(len(nt_leaves))
            key, children = nt_leaves[idx]
            expansions = self.ct_grammar[key]
            if expansion_trials < self.expansion_switch:
                expansion = random.choice(expansions)
            else:
                costs = [self.expansion_cost(e) for e in expansions]
                m = min(costs)
                all_min = [i for i, c in enumerate(costs) if c == m]
                expansion = expansions[random.choice(all_min)]

            new_leaves = [(token, []) for token in expansion]
            new_nt_leaves = [e for e in new_leaves if e[0] in self.ct_grammar]
            # `children` is the same list object stored in the parent's node
            # tuple, so filling it in place writes the expansion into the tree.
            children[:] = new_leaves
            # The expanded leaf is replaced by its new nonterminal leaves.
            nt_leaves[idx:idx + 1] = new_nt_leaves
            if self.log:
                # str(key): formatting an ostr with '%-40s' would go through
                # ostr.__rmod__, which only supports a plain '%s'.
                print("%-40s" % (str(key) + " -> " + str(expansion)))
            expansion_trials += 1
        return tree

    def fuzz(self):
        """Produce one input; its derivation tree is kept in ``derivation_tree``."""
        self.derivation_tree = self.fuzz_tree()
        return self.tree_to_string(self.derivation_tree)

    def tree_to_string(self, tree):
        """Concatenate the leaves of ``tree`` into an ``ostr``.

        The origin of every character is kept; unexpanded nonterminals
        contribute an empty string.
        """
        symbol, children, *_ = tree
        e = ostr('')
        if children:
            return e.join([self.tree_to_string(c) for c in children])
        else:
            return e if symbol in self.c_grammar else symbol

    def update_grammar(self, origin, dtree):
        """Increment ``use`` for the alternatives involved in ``dtree``.

        The origin set of the tree is built as follows: terminal leaves keep
        only the origins that also appear in ``origin``, and every expanded
        nonterminal adds its own origins.  An alternative is counted once if
        any of its tokens' origins is in that set.
        """
        def update_tree(dtree, origin):
            key, children = dtree
            if children:
                updated_children = [update_tree(c, origin) for c in children]
                corigin = set.union(
                    *[o for (key, children, o) in updated_children])
                corigin = corigin.union(set(key.origin))
                return (key, children, corigin)
            else:
                my_origin = set(key.origin).intersection(origin)
                return (key, [], my_origin)

        key, children, oset = update_tree(dtree, set(origin))
        for key, alts in self.ctp_grammar.items():
            for alt, o in alts:
                alt_origins = set([i for token in alt for i in token.origin])
                if alt_origins.intersection(oset):
                    o['use'] += 1

In [None]:
class TrackingDB(TaintedDB):
    """A ``TaintedDB`` that flags every statement carrying origin information."""

    def my_eval(self, statement, g, l):
        """Raise ``Tainted`` if ``statement`` has any origins; otherwise eval.

        Raises:
            Tainted: if ``statement.origin`` is non-empty, i.e. for every
                non-empty ``ostr`` statement, even when some of its entries
                are UNKNOWN_ORIGIN.
            SQLException: if evaluation fails.
        """
        if statement.origin:
            raise Tainted(statement)
        try:
            return eval(statement, g, l)
        except Exception as err:
            raise SQLException('Invalid SQL (%s)' % repr(statement)) from err

In [None]:
def tree_type(tree):
    """Return ``(type(symbol), symbol, [child results...])`` recursively.

    Not called elsewhere in this file; presumably a debugging aid for
    checking the symbol types (e.g. ``ostr``) in a derivation tree.
    """
    symbol, subtrees = tree
    child_types = [tree_type(sub) for sub in subtrees]
    return (type(symbol), symbol, child_types)

import traceback

trdb = TrackingDB(db.db)
tgf = TaintedGrammarFuzzer(INVENTORY_GRAMMAR_F)
for _ in range(10):
    qtree = tgf.fuzz_tree()
    query = tgf.tree_to_string(qtree)
    assert isinstance(query, ostr)
    try:
        print(repr(query))
        res = trdb.sql(query)
        print(repr(res))
    except SQLException as e:
        print(e)
    except Tainted as e:
        print(e)
        # The per-character origins of the statement that reached eval
        # identify the grammar alternatives involved; count them.
        origin = e.args[0].origin
        tgf.update_grammar(origin, qtree)
    except Exception:
        traceback.print_exc()
        break
    print()

In [None]:
# Show the per-alternative {'use': n} counters accumulated by update_grammar().
tgf.ctp_grammar



* String-based and character-based taints allow to dynamically track the information flow from input to the internals of a system and back to the output.

    基于字符串和基于字符的污染允许动态跟踪从输入到系统内部并返回到输出的信息流。

* Checking taints allows to discover untrusted inputs and information leakage at runtime.

    检查污染可以在运行时发现不可信的输入和信息泄漏。

* Data conversions and implicit data flow may strip taint information; the resulting untainted strings should be treated as having the worst possible taint.

    数据转换和隐式数据流可能会丢失污点信息；由此得到的不带污点标记的字符串，应被视为带有最严重级别的污点。

* Taints can be used in conjunction with fuzzing to provide a more robust indication of incorrect behavior than to simply rely on program crashes.

    污点可以与fuzzing一起使用，以提供不正确行为的更可靠指示，而不是简单地依赖于程序崩溃。