Skip to content

Commit

Permalink
Add CSRF protection MasoniteFramework#27
Browse files Browse the repository at this point in the history
  • Loading branch information
mapeveri committed Mar 4, 2018
1 parent febf6d8 commit cb89002
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 10 deletions.
1 change: 1 addition & 0 deletions config/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
PROVIDERS = [
# Framework Providers
'masonite.providers.AppProvider.AppProvider',
'masonite.providers.CsrfProvider.CsrfProvider',
'masonite.providers.RouteProvider.RouteProvider',
'masonite.providers.ApiProvider.ApiProvider',
'masonite.providers.RedirectionProvider.RedirectionProvider',
Expand Down
29 changes: 29 additions & 0 deletions masonite/auth/Csrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import binascii
import os


class Csrf(object):
"""
Class for csrf protection
"""

def __init__(self, request):
self.request = request

def generate_csrf_token(self):
"""
Generate token for csrf protection
"""

token = binascii.b2a_hex(os.urandom(15))
self.request.cookie('csrftoken', token)

def verify_csrf_token(self, token):
"""
Verify if csrf token is valid
"""

if self.request.get_cookie('csrftoken') == token:
return True
else:
return False
11 changes: 10 additions & 1 deletion masonite/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,17 @@ class FileTypeException(Exception):
"""
pass


class RequiredContainerBindingNotFound(Exception):
pass


class MissingContainerBindingNotFound(Exception):
pass
pass


class InvalidCSRFToken(Exception):
"""
For exceptions that return error when verifying the csrf token
"""
pass
36 changes: 36 additions & 0 deletions masonite/middlewares/CsrfMiddleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from masonite.exceptions import InvalidCSRFToken


class CsrfMiddleware(object):
"""
Verify csrf token middleware
"""

exempt = []

def __init__(self, Request, CSRF):
self.request = Request
self.csrf = CSRF

def before(self):
if self.request.is_post():
token = self.request.input('csrftoken')
if (not self.csrf.verify_csrf_token(token)
and not self.__in_except()):
raise InvalidCSRFToken("Invalid CSRF token.")
else:
self.csrf.generate_token()

def after(self):
pass

def __in_except(self):
"""
Determine if the request has a URI that should pass
through CSRF verification.
"""

if self.request.path in self.exempt:
return True
else:
return False
Empty file.
19 changes: 19 additions & 0 deletions masonite/providers/CsrfProvider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
""" A Csrf Service Provider """
from masonite.provider import ServiceProvider
from masonite.auth.Csrf import Csrf


class CsrfProvider(ServiceProvider):

wsgi = True

def register(self):
request = self.app.make('Request')
self.app.bind('Request', request)
self.app.bind('CSRF', Csrf(request))

def boot(self, View, ViewClass, Request):
# Share token csrf
token = Request.get_cookie('csrftoken')

ViewClass.share({'csrf_field': "<input type='hidden' name='csrf_token' value='{0}' />".format(token)})
52 changes: 43 additions & 9 deletions masonite/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Request(object):
This is the object passed through to the controllers
as a request paramter
"""

def __init__(self, environ=None):
self.cookies = []
self.url_params = {}
Expand All @@ -42,6 +43,7 @@ def input(self, param):
Returns either the FORM_PARAMS during a POST request
or QUERY_STRING during a GET request
"""

# Post Request Input
if self.is_post():
if isinstance(self.params, str):
Expand Down Expand Up @@ -69,12 +71,18 @@ def is_post(self):
return False

def key(self, key):
""" Sets encryption key """
"""
Sets encryption key
"""

self.encryption_key = key
return self

def all(self):
""" Returns all the params """
"""
Returns all the params
"""

if isinstance(self.params, str):
return parse_qs(self.params)

Expand All @@ -96,7 +104,10 @@ def app(self):
return self.container

def has(self, param):
""" Check if a param exists """
"""
Check if a param exists
"""

if param in self.params:
return True

Expand All @@ -108,6 +119,7 @@ def set_params(self, params):
These parameters are where the developer can retrieve the
/url/@variable:string/ from the url.
"""

self.url_params.update(params)
return self

Expand All @@ -117,6 +129,7 @@ def param(self, parameter):
The "parameter" parameter in this method should be the name of the
@variable passed into the url in web.py
"""

if parameter in self.url_params:
return self.url_params[parameter]
return False
Expand All @@ -125,6 +138,7 @@ def cookie(self, key, value, encrypt=True):
"""
Sets a cookie in the browser
"""

if encrypt:
value = Sign(self.encryption_key).sign(value)
else:
Expand All @@ -138,12 +152,14 @@ def get_cookies(self):
"""
Retrieve all cookies from the browser
"""

return self.cookies

def get_cookie(self, provided_cookie, decrypt=True):
"""
Retrieves a specific cookie from the browser
"""

if 'HTTP_COOKIE' in self.environ:
grab_cookie = cookies.SimpleCookie(self.environ['HTTP_COOKIE'])
if provided_cookie in grab_cookie:
Expand All @@ -163,21 +179,33 @@ def append_cookie(self, key, value):
key, value)

def set_user(self, user_model):
""" Loads the user into the class """
"""
Loads the user into the class
"""

self.user_model = user_model
return self

def user(self):
""" Retreives the user model """
"""
Retreives the user model
"""

return self.user_model

def redirect(self, route):
""" Redirect the user based on the route specified """
"""
Redirect the user based on the route specified
"""

self.redirect_url = route
return self

def redirectTo(self, route):
""" Redirect to a named route """
"""
Redirect to a named route
"""

self.redirect_route = route
return self

Expand All @@ -186,7 +214,10 @@ def reset_redirections(self):
self.redirect_route = False

def back(self, input_parameter='back'):
""" Go to a named route with the back parameter """
"""
Go to a named route with the back parameter
"""

self.redirectTo(self.input(input_parameter))
return self

Expand Down Expand Up @@ -236,7 +267,10 @@ def has_subdomain(self):
return False

def send(self, params):
""" With """
"""
With
"""

self.set_params(params)
return self

Expand Down

0 comments on commit cb89002

Please sign in to comment.