|
|
@@ -16,84 +16,104 @@ |
|
|
# program. If not, go to http://www.gnu.org/licenses/gpl.html
|
|
|
# or write to the Free Software Foundation, Inc.,
|
|
|
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
|
|
-import logging
|
|
|
+from collections import namedtuple
|
|
|
from threading import Lock
|
|
|
+from enum import Enum
|
|
|
+import logging
|
|
|
import sqlite3
|
|
|
|
|
|
-class Rule:
|
|
|
+
|
|
|
+Rule = namedtuple('Rule', ('app_path',
|
|
|
+ 'verdict',
|
|
|
+ 'address',
|
|
|
+ 'port',
|
|
|
+ 'proto'))
|
|
|
+
|
|
|
+
|
|
|
+class RuleVerdict(Enum):
|
|
|
+
|
|
|
ACCEPT = 0
|
|
|
- DROP = 1
|
|
|
+ DROP = 1
|
|
|
+
|
|
|
+
|
|
|
+class RuleSaveOption(Enum):
|
|
|
|
|
|
ONCE = 0
|
|
|
UNTIL_QUIT = 1
|
|
|
FOREVER = 2
|
|
|
|
|
|
- def __init__( self, app_path=None, verdict=ACCEPT, address=None, port=None, proto=None ):
|
|
|
- self.app_path = app_path
|
|
|
- self.verdict = verdict
|
|
|
- self.address = address
|
|
|
- self.port = port
|
|
|
- self.proto = proto
|
|
|
|
|
|
- def matches( self, c ):
|
|
|
- if self.app_path != c.app_path:
|
|
|
- return False
|
|
|
+def matches(rule, conn):
|
|
|
+ if rule.app_path != conn.app_path:
|
|
|
+ return False
|
|
|
|
|
|
- elif self.address is not None and self.address != c.dst_addr:
|
|
|
- return False
|
|
|
+ elif rule.address is not None and rule.address != conn.dst_addr:
|
|
|
+ return False
|
|
|
|
|
|
- elif self.port is not None and self.port != c.dst_port:
|
|
|
- return False
|
|
|
+ elif rule.port is not None and rule.port != conn.dst_port:
|
|
|
+ return False
|
|
|
|
|
|
- elif self.proto is not None and self.proto != c.proto:
|
|
|
- return False
|
|
|
+ elif rule.proto is not None and rule.proto != conn.proto:
|
|
|
+ return False
|
|
|
+
|
|
|
+ else:
|
|
|
+ return True
|
|
|
|
|
|
- else:
|
|
|
- return True
|
|
|
|
|
|
class Rules:
|
|
|
def __init__(self, database):
|
|
|
self.mutex = Lock()
|
|
|
- self.db = RulesDB(database)
|
|
|
- self.rules = self.db.load_rules()
|
|
|
+ db = self.db = RulesDB(database)
|
|
|
+ self.rules = {}
|
|
|
+
|
|
|
+ with db._lock:
|
|
|
+ for r in db._load_rules():
|
|
|
+ self._add_rule(r)
|
|
|
|
|
|
- def get_verdict( self, connection ):
|
|
|
+ def get_verdict(self, connection):
|
|
|
with self.mutex:
|
|
|
- for r in self.rules:
|
|
|
- if r.matches(connection):
|
|
|
+ for r in self.rules.get(connection.app_path, []):
|
|
|
+ if matches(r, connection):
|
|
|
return r.verdict
|
|
|
|
|
|
return None
|
|
|
|
|
|
- def _remove_rules_for_path( self, path, remove_from_db=False ):
|
|
|
- for rule in self.rules:
|
|
|
- if rule.app_path == path:
|
|
|
- self.rules.remove(rule)
|
|
|
+ def _remove_rules_for_path(self, path, remove_from_db=False):
|
|
|
+ try:
|
|
|
+ del self.rules[path]
|
|
|
+ except KeyError:
|
|
|
+ pass
|
|
|
|
|
|
if remove_from_db is True:
|
|
|
self.db.remove_all_app_rules(path)
|
|
|
|
|
|
- def add_rule( self, connection, verdict, apply_to_all=False, save_option=Rule.UNTIL_QUIT ):
|
|
|
+ def _add_rule(self, rule):
|
|
|
+ self.rules.setdefault(rule.app_path, set()).add(rule)
|
|
|
+
|
|
|
+ def add_rule(self, connection, verdict, apply_to_all=False,
|
|
|
+ save_option=RuleSaveOption.UNTIL_QUIT.value):
|
|
|
+
|
|
|
with self.mutex:
|
|
|
- logging.debug( "Adding %s rule for '%s' (all=%s)" % (
|
|
|
- "ALLOW" if verdict == Rule.ACCEPT else "DENY",
|
|
|
- connection,
|
|
|
- "true" if apply_to_all == True else "false" ) )
|
|
|
- r = Rule()
|
|
|
- r.verdict = verdict
|
|
|
- r.app_path = connection.app_path
|
|
|
+ logging.debug("Adding %s rule for '%s' (all=%s)",
|
|
|
+ "ALLOW" if RuleVerdict(verdict) == RuleVerdict.ACCEPT else "DENY", # noqa
|
|
|
+ connection,
|
|
|
+ "true" if apply_to_all is True else "false")
|
|
|
|
|
|
if apply_to_all is True:
|
|
|
- self._remove_rules_for_path( r.app_path, (save_option == Rule.FOREVER) )
|
|
|
+ self._remove_rules_for_path(
|
|
|
+ connection.app_path,
|
|
|
+ (RuleSaveOption(save_option) == RuleSaveOption.FOREVER))
|
|
|
|
|
|
- elif apply_to_all is False:
|
|
|
- r.address = connection.dst_addr
|
|
|
- r.port = connection.dst_port
|
|
|
- r.proto = connection.proto
|
|
|
+ r = Rule(
|
|
|
+ connection.app_path,
|
|
|
+ verdict,
|
|
|
+ connection.dst_addr if not apply_to_all else None,
|
|
|
+ connection.dst_port if not apply_to_all else None,
|
|
|
+ connection.proto if not apply_to_all else None)
|
|
|
|
|
|
- self.rules.append(r)
|
|
|
+ self._add_rule(r)
|
|
|
|
|
|
- if save_option == Rule.FOREVER:
|
|
|
+ if RuleSaveOption(save_option) == RuleSaveOption.FOREVER:
|
|
|
self.db.save_rule(r)
|
|
|
|
|
|
|
|
|
@@ -115,18 +135,18 @@ def _create_table(self): |
|
|
c = conn.cursor()
|
|
|
c.execute("CREATE TABLE IF NOT EXISTS rules (app_path TEXT, verdict INTEGER, address TEXT, port INTEGER, proto TEXT, UNIQUE (app_path, verdict, address, port, proto))") # noqa
|
|
|
|
|
|
- def load_rules(self):
|
|
|
- with self._lock:
|
|
|
- conn = self._get_conn()
|
|
|
- c = conn.cursor()
|
|
|
- c.execute("SELECT * FROM rules")
|
|
|
- return [Rule(*item) for item in c.fetchall()]
|
|
|
+ def _load_rules(self):
|
|
|
+ conn = self._get_conn()
|
|
|
+ c = conn.cursor()
|
|
|
+ c.execute("SELECT * FROM rules")
|
|
|
+ for item in c.fetchall():
|
|
|
+ yield Rule(*item)
|
|
|
|
|
|
def save_rule(self, rule):
|
|
|
with self._lock:
|
|
|
conn = self._get_conn()
|
|
|
c = conn.cursor()
|
|
|
- c.execute("INSERT INTO rules VALUES (?, ?, ?, ?, ?)", (rule.app_path, rule.verdict, rule.address, rule.port, rule.proto,)) # noqa
|
|
|
+ c.execute("INSERT INTO rules VALUES (?, ?, ?, ?, ?)", (rule.app_path, rule.verdict.value, rule.address, rule.port, rule.proto,)) # noqa
|
|
|
conn.commit()
|
|
|
|
|
|
def remove_all_app_rules(self, app_path):
|
|
|
|
0 comments on commit
eb3f96c