In [None]:
import unittest
from unittest.mock import Mock, patch, AsyncMock, MagicMock
import datetime as dt
from freezegun import freeze_time
import sys

# Stub out the DB2 driver, Spark and the internal ``pyjudo`` package *before*
# the code under test is imported, so its imports of these names resolve to
# MagicMock objects instead of requiring the real packages. Parent packages are
# listed alongside their submodules so dotted imports such as
# ``pyjudo.database.db2`` resolve as well.
mock_modules = {
    module_name: MagicMock()
    for module_name in (
        'ibm_db',
        'pyspark',
        'pyspark.sql',
        'pyjudo',
        'pyjudo.common',
        'pyjudo.common.secdb',
        'pyjudo.common.constants',
        'pyjudo.database',
        'pyjudo.database.db2',
    )
}

# Register every stub with the import system; subsequent imports of these
# names return the stub objects above.
sys.modules.update(mock_modules)

# Import your actual modules after setting up the mocks
from workspace.ai.libs.ingestion.vespa.ingest import (
    backfill_ingestion_earnings,
    daily_ingestion_earnings,
    get_s3_credentials,
    pull_ingest_update
)
from workspace.ai.libs.ingestion.vespa.vespa_manager import VespaManager

class TestIngest(unittest.TestCase):
    """Tests for the earnings-call ingestion entry points in ``ingest``.

    ``asyncio``, ``requests``, ``get_spark_context``,
    ``load_earnings_call_data_vespa`` and ``get_payload`` are patched in the
    ``ingest`` module for the whole class, so the entry points run against
    mocks.
    """

    @classmethod
    def setUpClass(cls):
        """Start the class-wide patches and configure default return values.

        Each patch's ``stop`` is registered with ``addClassCleanup`` as soon as
        the patch starts. ``tearDownClass`` is skipped when ``setUpClass``
        raises, but class cleanups still run. A failure part-way through this
        method therefore cannot leave module attributes patched for later
        tests.
        """
        cls.mock_patches = [
            patch('workspace.ai.libs.ingestion.vespa.ingest.asyncio'),
            patch('workspace.ai.libs.ingestion.vespa.ingest.requests'),
            patch('workspace.ai.libs.ingestion.vespa.ingest.get_spark_context'),
            patch('workspace.ai.libs.ingestion.vespa.ingest.load_earnings_call_data_vespa'),
            patch('workspace.ai.libs.ingestion.vespa.ingest.get_payload')
        ]

        # Mocks keyed by the patched attribute's name, e.g. 'asyncio'.
        cls.mocks = {}
        for p in cls.mock_patches:
            cls.mocks[p.attribute] = p.start()
            cls.addClassCleanup(p.stop)

        # Default return values shared by every test in this class.
        cls.mocks['get_spark_context'].return_value = MagicMock()
        cls.mocks['load_earnings_call_data_vespa'].return_value = MagicMock()
        cls.mocks['get_payload'].return_value = [{"nativeId": "test123"}]

    def setUp(self):
        """Isolate tests from each other and set per-test fixtures."""
        # The class-level mocks are shared by all tests. Clear their recorded
        # calls so one test's calls cannot leak into another test's
        # assertions. reset_mock() keeps the configured return values by
        # default.
        for shared_mock in self.mocks.values():
            shared_mock.reset_mock()
        self.schema_id = "test_schema"
        self.s3_bucket_id = "test_bucket"

    @patch('workspace.ai.libs.ingestion.vespa.ingest.my_dict')
    def test_backfill_ingestion_earnings(self, mock_my_dict):
        """Backfill over a single directory entry calls ``asyncio.run`` exactly once."""
        # Setup: one directory containing two dates.
        mock_my_dict.items.return_value = [
            ("test_dir", ["2025-01-01", "2025-01-02"])
        ]
        self.mocks['asyncio'].run = MagicMock()

        # Execute
        backfill_ingestion_earnings(self.schema_id, self.s3_bucket_id)

        # Assert
        self.mocks['asyncio'].run.assert_called_once()

    @freeze_time("2025-01-31")
    @patch('workspace.ai.libs.ingestion.vespa.ingest.job_date_if_none')
    def test_daily_ingestion_earnings(self, mock_job_date):
        """Daily ingestion for 2025-01-31 runs once with a ``2025_ECT/20250131`` path."""
        # Setup
        mock_job_date.return_value = "20250131"
        self.mocks['asyncio'].run = MagicMock()

        # Execute
        daily_ingestion_earnings(self.schema_id, self.s3_bucket_id)

        # Assert: the argument passed to asyncio.run mentions the dated path.
        self.mocks['asyncio'].run.assert_called_once()
        self.assertIn("2025_ECT/20250131", str(self.mocks['asyncio'].run.call_args))

class TestVespaManager(unittest.IsolatedAsyncioTestCase):
    """Tests for ``VespaManager`` with aiohttp, S3 and URL/header helpers patched.

    The class derives from ``IsolatedAsyncioTestCase`` so that ``async def``
    tests are actually awaited. Under a plain ``TestCase`` the test method is
    called, returns an un-awaited coroutine, and is reported as passing
    without its body ever running.
    """

    @classmethod
    def setUpClass(cls):
        """Start class-level patches; stops are registered as class cleanups.

        ``addClassCleanup`` runs even if ``setUpClass`` fails part-way, unlike
        ``tearDownClass``, so patches cannot leak into other test classes.
        """
        cls.mock_patches = [
            patch('workspace.ai.libs.ingestion.vespa.vespa_manager.aiohttp.ClientSession'),
            patch('workspace.ai.libs.ingestion.vespa.vespa_manager.S3'),
            patch('workspace.ai.libs.ingestion.vespa.vespa_manager.create_headers'),
            patch('workspace.ai.libs.ingestion.vespa.vespa_manager.create_url')
        ]
        # Mocks keyed by the patched attribute's name, e.g. 'ClientSession'.
        cls.mocks = {}
        for p in cls.mock_patches:
            cls.mocks[p.attribute] = p.start()
            cls.addClassCleanup(p.stop)

    def setUp(self):
        """Reset shared mocks, configure the HTTP session mock, build the manager."""
        # Clear call history on the shared class-level mocks so tests stay
        # independent. Configured return values are kept.
        for shared_mock in self.mocks.values():
            shared_mock.reset_mock()

        # Configure the session mock *before* constructing VespaManager, so a
        # constructor that creates the session also receives this mock.
        self.mock_session = AsyncMock()
        # If the code does `async with ClientSession() as session`, the bound
        # name is the result of __aenter__; make that the session mock itself.
        self.mock_session.__aenter__.return_value = self.mock_session
        # The response chain below (`post(...).__aenter__().json()`) targets
        # `async with session.post(...) as resp`. That pattern needs post(...)
        # to return the context manager synchronously, so post must be a
        # MagicMock. On an AsyncMock, post(...) would return a coroutine.
        self.mock_session.post = MagicMock()
        self.mock_session.post.return_value.__aenter__.return_value.json = AsyncMock(
            return_value={"successDocs": [{"documentId": "doc123"}]}
        )
        self.mocks['ClientSession'].return_value = self.mock_session

        self.vespa_manager = VespaManager(
            schema_id="test_schema",
            env="test_env",
            s3_bucket_id="test_bucket",
            creds={"access_key": "test_key", "secret_key": "test_secret"},
            logging_dir="test_logs"
        )
        self.sample_payload = {
            "nativeId": "test123",
            "fields": {
                "tickers_s": ["AAPL", "GOOGL"],
                "year_s": "2025",
                "quarter_s": "Q1",
                "event_time_s": "2025-01-15T10:00:00Z"
            }
        }

    @patch('workspace.ai.libs.ingestion.vespa.vespa_manager.AsyncLimiter')
    async def test_ingest_in_vespa(self, mock_limiter):
        """Ingesting one payload consults the tracker and the ingested-id list once."""
        with patch.object(self.vespa_manager, 'get_ingestion_tracker') as mock_tracker, \
             patch.object(self.vespa_manager, 'get_ingested_native_id_list') as mock_id_list, \
             patch.object(self.vespa_manager, 'get_failed_payloads') as mock_failed:

            # Setup: nothing ingested or failed yet.
            mock_tracker.return_value = {}
            mock_id_list.return_value = []
            mock_failed.return_value = {}

            # Execute
            await self.vespa_manager.ingest_in_vespa([self.sample_payload])

            # Assert
            mock_tracker.assert_called_once()
            mock_id_list.assert_called_once()

    def test_update_failed_payloads(self):
        """Recording a failure writes once to S3, and the write includes the error text."""
        with patch.object(self.vespa_manager, 'get_failed_payloads') as mock_failed, \
             patch.object(self.vespa_manager.s3, 'write_file') as mock_write:

            # Setup: no previously failed payloads.
            mock_failed.return_value = {}

            # Execute
            self.vespa_manager.update_failed_payloads(
                "test123",
                self.sample_payload,
                "test error"
            )

            # Assert
            mock_write.assert_called_once()
            self.assertIn("test error", str(mock_write.call_args))

if __name__ == '__main__':
    # In a Jupyter/IPython kernel, __name__ is also '__main__', but sys.argv
    # holds the kernel's launch arguments (e.g. `-f <connection-file>`).
    # unittest would parse those as options and test names. Its default
    # exit=True would also raise SystemExit inside the cell. In that case,
    # ignore the kernel arguments and do not exit. Script runs keep the
    # default behavior (argv=None means sys.argv; exit=True).
    _in_notebook = 'ipykernel' in sys.modules
    unittest.main(
        argv=sys.argv[:1] if _in_notebook else None,
        exit=not _in_notebook,
    )