Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add directory checks on worker start up #230

Merged
merged 2 commits into from Aug 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 41 additions & 2 deletions turbinia/client.py
Expand Up @@ -20,6 +20,8 @@
from datetime import timedelta
import json
import logging
import os
import stat
import time

from turbinia import config
Expand All @@ -29,7 +31,6 @@

config.LoadConfig()
if config.TASK_MANAGER == 'PSQ':
# TODO(aarontp): Selectively load dependencies based on configured backends
import psq

from google.cloud import exceptions
Expand All @@ -44,6 +45,35 @@
logger.setup()


def check_directory(directory):
"""Checks directory to make sure it exists and is writable.

Args:
directory (string): Path to directory

Raises:
TurbiniaException: When directory cannot be created or used.
"""
if os.path.exists(directory) and not os.path.isdir(directory):
raise TurbiniaException(
'File {0:s} exists, but is not a directory'.format(directory))

if not os.path.exists(directory):
try:
os.makedirs(directory)
except OSError:
raise TurbiniaException(
'Can not create Directory {0:s}'.format(directory))

if not os.access(directory, os.W_OK):
try:
mode = os.stat(directory)[0]
os.chmod(directory, mode | stat.S_IWUSR)
except OSError:
raise TurbiniaException(
'Can not add write permissions to {0:s}'.format(directory))


class TurbiniaClient(object):
"""Client class for Turbinia.

Expand Down Expand Up @@ -289,6 +319,7 @@ def send_request(self, request):
"""
self.task_manager.kombu.send_request(request)

# pylint: disable=arguments-differ
def get_task_data(self,
instance,
_,
Expand Down Expand Up @@ -347,6 +378,8 @@ class TurbiniaCeleryWorker(TurbiniaClient):
def __init__(self, *args, **kwargs):
"""Initialization for Celery worker."""
super(TurbiniaCeleryWorker, self).__init__(*args, **kwargs)
check_directory(config.MOUNT_DIR_PREFIX)
check_directory(config.OUTPUT_DIR)
self.worker = self.task_manager.celery.app

def start(self):
Expand All @@ -362,9 +395,12 @@ class TurbiniaPsqWorker(object):
Attributes:
worker (psq.Worker): PSQ Worker object
psq (psq.Queue): A Task queue object

Raises:
TurbiniaException: When errors occur
"""

def __init__(self, *args, **kwargs):
def __init__(self, *_, **__):
"""Initialization for PSQ Worker."""
config.LoadConfig()
psq_publisher = pubsub.PublisherClient()
Expand All @@ -382,6 +418,9 @@ def __init__(self, *args, **kwargs):
log.error(msg)
raise TurbiniaException(msg)

check_directory(config.MOUNT_DIR_PREFIX)
check_directory(config.OUTPUT_DIR)

log.info('Starting PSQ listener on queue {0:s}'.format(self.psq.name))
self.worker = psq.Worker(queue=self.psq)

Expand Down
35 changes: 34 additions & 1 deletion turbinia/client_test.py
Expand Up @@ -17,6 +17,10 @@
from __future__ import unicode_literals

import unittest
import os
import shutil
import stat
import tempfile

import mock

Expand Down Expand Up @@ -63,7 +67,8 @@ def testTurbiniaClientGetTaskDataNoResults(self, _, __, mock_cloud_function):
@mock.patch('turbinia.client.GoogleCloudFunction.ExecuteFunction')
@mock.patch('turbinia.client.task_manager.PSQTaskManager._backend_setup')
@mock.patch('turbinia.state_manager.get_state_manager')
def testTurbiniaClientGetTaskDataInvalidJson(self, _, __, mock_cloud_function):
def testTurbiniaClientGetTaskDataInvalidJson(
self, _, __, mock_cloud_function):
"""Test for exception after bad json results from cloud functions."""
mock_cloud_function.return_value = {'result': None}
client = TurbiniaClient()
Expand All @@ -85,10 +90,38 @@ def testTurbiniaServerInit(self, _, __):
class TestTurbiniaPsqWorker(unittest.TestCase):
"""Test Turbinia PSQ Worker class."""

def setUp(self):
self.tmp_dir = tempfile.mkdtemp(prefix='turbinia-test')
config.LoadConfig()
config.OUTPUT_DIR = self.tmp_dir
config.MOUNT_DIR_PREFIX = self.tmp_dir

def tearDown(self):
if 'turbinia-test' in self.tmp_dir:
shutil.rmtree(self.tmp_dir)

@mock.patch('turbinia.client.pubsub')
@mock.patch('turbinia.client.datastore.Client')
@mock.patch('turbinia.client.psq.Worker')
def testTurbiniaPsqWorkerInit(self, _, __, ___):
"""Basic test for PSQ worker."""
worker = TurbiniaPsqWorker()
self.assertTrue(hasattr(worker, 'worker'))

@mock.patch('turbinia.client.pubsub')
@mock.patch('turbinia.client.datastore.Client')
@mock.patch('turbinia.client.psq.Worker')
def testTurbiniaClientNoDir(self, _, __, ___):
"""Test that OUTPUT_DIR path is created."""
config.OUTPUT_DIR = os.path.join(self.tmp_dir, 'no_such_dir')
TurbiniaPsqWorker()
self.assertTrue(os.path.exists(config.OUTPUT_DIR))

@mock.patch('turbinia.client.pubsub')
@mock.patch('turbinia.client.datastore.Client')
@mock.patch('turbinia.client.psq.Worker')
def testTurbiniaClientIsNonDir(self, _, __, ___):
"""Test that OUTPUT_DIR does not point to an existing non-directory."""
config.OUTPUT_DIR = os.path.join(self.tmp_dir, 'empty_file')
open(config.OUTPUT_DIR, 'a').close()
self.assertRaises(TurbiniaException, TurbiniaPsqWorker)