In [5]:
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions
import mysql.connector
import json
import re
from datetime import datetime
from typing import Dict, List, Optional
from dataclasses import dataclass
import logging

# Configure the root logger to emit INFO-level (and higher) messages, then
# create the module-level logger used by the code below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class SMSRecord:
    """Structured view of one SMS log row as it flows through the pipeline.

    Instances are built by ``ReadFromMySQL._convert_to_sms_record`` and read
    by ``SMSAnalyzer.process``.
    """
    # Identifier taken from the first column of the database row.
    message_id: int
    # Status value taken from the row; used as the grouping key downstream.
    status: str
    # Numeric cost parsed from the row's response JSON ("KES " prefix removed);
    # 0.0 when the cost could not be parsed.
    cost: float
    sender: str
    recipient: str
    # Message body; tokenized downstream to compute a word count.
    message: str
    # Creation timestamp of the row, parsed into a datetime.
    created_date: datetime
    # Blast identifier from the row; may be None.
    blast_id: Optional[str]

class ReadFromMySQL(beam.DoFn):
    """Read SMS rows from MySQL and emit them as ``SMSRecord`` objects.

    The database connection is opened in ``setup`` (with a bounded number of
    attempts) and released in ``teardown``. Every input element triggers one
    execution of ``query``; the element's own value is not used.
    """

    def __init__(self, connection_config: Dict[str, str], query: str, batch_size: int = 1000):
        """Store the settings needed to run the query.

        Args:
            connection_config: Keyword arguments forwarded to
                ``mysql.connector.connect``.
            query: SQL statement executed once per input element.
            batch_size: Number of rows requested per ``fetchmany`` call.
        """
        super().__init__()
        self.connection_config = connection_config
        self.query = query
        self.batch_size = batch_size
        self.records_read = beam.metrics.Metrics.counter('main', 'records_read')
        # Populated in setup(); kept as None so teardown() is safe to call
        # even when setup() never succeeded.
        self.connection = None
        self.cursor = None

    def setup(self):
        """Open the database connection and cursor, retrying on failure.

        Raises:
            mysql.connector.Error: If the final attempt also fails.
        """
        max_retries = 3
        for attempt in range(1, max_retries + 1):
            try:
                self.connection = mysql.connector.connect(**self.connection_config)
                self.cursor = self.connection.cursor(buffered=True)
                return
            except mysql.connector.Error as err:
                if attempt == max_retries:
                    raise
                logger.warning("Database connection attempt %d failed: %s", attempt, err)

    def teardown(self):
        """Close the cursor and connection opened in ``setup``.

        Close errors are logged rather than raised so that cleanup of one
        resource does not prevent cleanup of the other.
        """
        for resource in (self.cursor, self.connection):
            if resource is None:
                continue
            try:
                resource.close()
            except mysql.connector.Error as err:
                logger.warning("Error while closing database resource: %s", err)
        self.cursor = None
        self.connection = None

    def process(self, element) -> List[SMSRecord]:
        """Run the query and yield one ``SMSRecord`` per returned row.

        Args:
            element: Trigger element; its value is ignored.

        Yields:
            SMSRecord: One record per database row, in fetch order.

        Raises:
            Exception: Any error raised while querying or converting a row is
                logged and re-raised.
        """
        # NOTE(review): the cursor is created with buffered=True; per the
        # MySQL Connector/Python docs a buffered cursor retrieves the whole
        # result set on execute, so batch_size only bounds each fetchmany
        # slice -- confirm this is the intended memory behaviour.
        try:
            self.cursor.execute(self.query)
            while True:
                rows = self.cursor.fetchmany(self.batch_size)
                if not rows:
                    break
                for row in rows:
                    self.records_read.inc()
                    yield self._convert_to_sms_record(row)
        except Exception as e:
            logger.error("Error reading from database: %s", e)
            raise

    @staticmethod
    def _parse_datetime(date_str):
        """Return ``date_str`` as a ``datetime``.

        ``datetime`` instances are returned unchanged. Any other value is
        converted with ``str()`` and parsed as ``YYYY-MM-DD HH:MM:SS`` with
        optional fractional seconds.

        Returns:
            datetime: The parsed value, or the current local time when the
            value cannot be parsed (a warning is logged in that case).
        """
        if isinstance(date_str, datetime):
            return date_str

        text = str(date_str)
        for fmt in ('%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M:%S.%f'):
            try:
                return datetime.strptime(text, fmt)
            except ValueError:
                continue

        logger.warning(
            "Date parsing failed for %s: Unable to parse date: %s", date_str, date_str
        )
        return datetime.now()  # Fallback to current time

    @staticmethod
    def _convert_to_sms_record(row) -> SMSRecord:
        """Convert one database row into an ``SMSRecord``.

        Args:
            row: A 13-item sequence. The positional unpacking below ties this
                method to the column order returned by the query.

        Returns:
            SMSRecord: The structured record. ``cost`` is 0.0 when it cannot
            be extracted from the response JSON.

        Raises:
            ValueError: If ``row`` does not contain exactly 13 items.
        """
        (message_id, message, sender, recipient, _, _, response,
         created_dt, _, _, status, blast_id, _) = row

        # Parse cost from response JSON. The response may be NULL, malformed,
        # or shaped differently than expected, and none of those should abort
        # the whole read. json.JSONDecodeError is a ValueError subclass.
        cost = 0.0
        try:
            response_data = json.loads(response)
            raw_cost = response_data["Recipients"][0]["cost"]
            # str() lets an already-numeric cost pass through float().
            cost = float(str(raw_cost).replace("KES ", ""))
        except (KeyError, IndexError, TypeError, ValueError):
            logger.warning("Failed to parse cost for message %s", message_id)

        return SMSRecord(
            message_id=message_id,
            status=status,
            cost=cost,
            sender=sender,
            recipient=recipient,
            message=message,
            created_date=ReadFromMySQL._parse_datetime(created_dt),
            blast_id=blast_id
        )

class SMSAnalyzer(beam.DoFn):
    """Derive per-message analysis fields from ``SMSRecord`` elements."""

    def process(self, record: SMSRecord):
        """Yield one analysis dict for ``record``.

        Args:
            record: The SMS record to analyze.

        Yields:
            dict: Keys ``status``, ``cost``, ``word_count``, ``hour_of_day``
            and ``blast_id``. ``hour_of_day`` is 0 when the record has no
            creation date.
        """
        word_count = len(self._tokenize_message(record.message))

        yield {
            'status': record.status,
            'cost': record.cost,
            'word_count': word_count,
            'hour_of_day': record.created_date.hour if record.created_date else 0,
            'blast_id': record.blast_id
        }

    @staticmethod
    def _tokenize_message(message: Optional[str]) -> List[str]:
        """Split ``message`` into lower-cased word tokens.

        Args:
            message: The message text; may be ``None`` or empty.

        Returns:
            List[str]: Runs of word characters, lower-cased. Empty when the
            message is ``None`` or empty.
        """
        # The message is passed through from the database row unchanged, so
        # it can be None; treat that as a message with no words.
        if not message:
            return []
        return re.findall(r'\w+', message.lower())

def format_metrics(metrics):
    """Render one status-metrics dict as a multi-line text report.

    Reads the keys ``status``, ``total_cost``, ``message_count``,
    ``avg_cost``, ``avg_word_count`` and ``most_common_hour`` and returns a
    string that starts with a newline and ends with a 50-dash divider.
    """
    divider = f"{'-' * 50}"
    report_parts = [
        f"\nStatus: {metrics['status']}\n",
        f"Total Cost: KES {metrics['total_cost']:.2f}\n",
        f"Message Count: {metrics['message_count']}\n",
        f"Average Cost per Message: KES {metrics['avg_cost']:.2f}\n",
        f"Average Word Count: {metrics['avg_word_count']:.1f}\n",
        f"Most Common Hour: {metrics['most_common_hour']:02d}:00\n",
        divider,
    ]
    return "".join(report_parts)

def calculate_status_metrics(status_group):
    """Aggregate the analysis dicts of one status group into summary metrics.

    Args:
        status_group: A ``(status, items)`` pair where ``items`` is an
            iterable of dicts with ``cost``, ``word_count`` and
            ``hour_of_day`` keys.

    Returns:
        dict: Keys ``status``, ``total_cost``, ``message_count``,
        ``avg_cost``, ``avg_word_count`` and ``most_common_hour``. When
        several hours share the highest count, the earliest hour is returned.
        An empty group yields zero for every numeric field.
    """
    # Imported locally so this helper does not depend on module-level state.
    from collections import Counter

    status, grouped_items = status_group
    # Materialize once: the grouped values are needed for several aggregates.
    items = list(grouped_items)
    message_count = len(items)

    if message_count == 0:
        return {
            'status': status,
            'total_cost': 0,
            'message_count': 0,
            'avg_cost': 0.0,
            'avg_word_count': 0.0,
            'most_common_hour': 0,
        }

    total_cost = sum(item['cost'] for item in items)
    total_words = sum(item['word_count'] for item in items)
    hour_counts = Counter(item['hour_of_day'] for item in items)
    # max() returns the first maximal element, so iterating the hours in
    # sorted order makes ties resolve to the earliest hour.
    most_common_hour = max(sorted(hour_counts), key=hour_counts.__getitem__)

    return {
        'status': status,
        'total_cost': total_cost,
        'message_count': message_count,
        'avg_cost': total_cost / message_count,
        'avg_word_count': total_words / message_count,
        'most_common_hour': most_common_hour,
    }


def run_pipeline(connection_config: Dict[str, str], query: str):
    """Build and run the SMS analysis pipeline.

    Reads rows with ``ReadFromMySQL``, analyzes each with ``SMSAnalyzer``,
    groups the results by status, and prints one formatted metrics report
    per status.

    Args:
        connection_config: Keyword arguments for the MySQL connection.
        query: SQL statement that selects the SMS rows to analyze.
    """
    pipeline_options = PipelineOptions()
    pipeline_options.view_as(SetupOptions).save_main_session = True

    with beam.Pipeline(options=pipeline_options) as pipeline:
        # Read and analyze SMS data. The single None element only triggers
        # one execution of the query.
        analysis_results = (
            pipeline
            | 'Create Initial' >> beam.Create([None])
            | 'Read SMS Data' >> beam.ParDo(ReadFromMySQL(connection_config, query))
            | 'Analyze SMS' >> beam.ParDo(SMSAnalyzer())
        )

        # Calculate detailed metrics by status
        _ = (
            analysis_results
            | 'Key By Status' >> beam.Map(lambda x: (x['status'], x))
            | 'Group By Status' >> beam.GroupByKey()
            | 'Calculate Status Metrics' >> beam.Map(calculate_status_metrics)
            | 'Format Output' >> beam.Map(format_metrics)
            | 'Print Results' >> beam.Map(print)
        )

if __name__ == '__main__':
    import os

    # Connection settings can be overridden through environment variables so
    # credentials do not have to be edited into the source. The defaults are
    # the previously hard-coded values.
    connection_config = {
        'host': os.environ.get('SMS_DB_HOST', '127.0.0.1'),
        'database': os.environ.get('SMS_DB_NAME', 'defaultdb'),
        'user': os.environ.get('SMS_DB_USER', 'root'),
        'password': os.environ.get('SMS_DB_PASSWORD', 'cypher'),
    }

    query = 'SELECT * FROM smslog LIMIT 20;'
    run_pipeline(connection_config, query)

INFO:root:Missing pipeline option (runner). Executing pipeline using the default runner: DirectRunner.
INFO:apache_beam.runners.worker.statecache:Creating state cache with size 104857600



Status: 
Total Cost: KES 6.80
Message Count: 17
Average Cost per Message: KES 0.40
Average Word Count: 18.4
Most Common Hour: 00:00
--------------------------------------------------

Status: Success
Total Cost: KES 1.40
Message Count: 3
Average Cost per Message: KES 0.47
Average Word Count: 16.3
Most Common Hour: 00:00
--------------------------------------------------
