# Factory Pattern: Step-by-Step Refactoring Guide

## Learning Objectives
By the end of this walkthrough, you'll understand:
1. **How to identify** when you need the Factory Pattern
2. **The step-by-step process** of refactoring to use the Factory Pattern
3. **How to handle** different constructor parameters
4. **Why each step** improves the code

---

## Context
You're working on a database connection system. It started simple with just MySQL, but now supports PostgreSQL and MongoDB. The code is becoming difficult to maintain. Let's refactor it systematically.

## Step 0: Understanding the Problem

Let's start with our problematic code and understand what's wrong:

In [None]:
# Our current problematic implementation

class MySQLConnection:
    def __init__(self, host, port, username, password):
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.connection_string = f"mysql://{username}:{password}@{host}:{port}"
    
    def connect(self):
        return f"MySQL: Connected to {self.host}:{self.port}"
    
    def execute_query(self, query):
        return f"MySQL: Executing {query}"


class PostgreSQLConnection:
    def __init__(self, host, port, username, password, database):  # Note: Extra parameter
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.database = database  # PostgreSQL needs database name
        self.connection_string = f"postgresql://{username}:{password}@{host}:{port}/{database}"
    
    def connect(self):
        return f"PostgreSQL: Connected to {self.database} at {self.host}:{self.port}"
    
    def execute_query(self, query):
        return f"PostgreSQL: Executing {query}"


class MongoDBConnection:
    def __init__(self, host, port, username, password, auth_source="admin"):  # Note: Different parameter
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.auth_source = auth_source  # MongoDB specific
        self.connection_string = f"mongodb://{username}:{password}@{host}:{port}/{auth_source}"
    
    def connect(self):
        return f"MongoDB: Connected to {self.host}:{self.port}"
    
    def execute_query(self, query):
        return f"MongoDB: Executing {query}"

### The Problematic Client Code

In [None]:
# Here's how we currently create connections

def create_connection(db_type, host, port, username, password, **kwargs):
    """Works, but has all the problems analyzed below."""
    
    if db_type == "mysql":
        return MySQLConnection(host, port, username, password)
    
    elif db_type == "postgresql":
        # PostgreSQL needs an extra 'database' parameter
        if 'database' not in kwargs:
            raise ValueError("PostgreSQL requires 'database' parameter")
        return PostgreSQLConnection(host, port, username, password, kwargs['database'])
    
    elif db_type == "mongodb":
        # MongoDB has optional auth_source parameter
        auth_source = kwargs.get('auth_source', 'admin')
        return MongoDBConnection(host, port, username, password, auth_source)
    
    else:
        raise ValueError(f"Unknown database type: {db_type}")

# Test our problematic code
mysql = create_connection("mysql", "localhost", 3306, "user", "pass")
print(mysql.connect())

postgres = create_connection("postgresql", "localhost", 5432, "user", "pass", database="mydb")
print(postgres.connect())

mongodb = create_connection("mongodb", "localhost", 27017, "user", "pass")
print(mongodb.connect())

## Problem Analysis

### Problem 1: Different Constructor Signatures
```python
MySQLConnection(host, port, username, password)
PostgreSQLConnection(host, port, username, password, database)          # extra required parameter
MongoDBConnection(host, port, username, password, auth_source="admin")  # different, optional parameter
```

### Problem 2: The `create_connection` Function Knows Too Much
- It knows every database class
- It knows each class's specific parameters
- It has complex if/elif logic
- Adding a new database type means modifying this function

### Problem 3: No Common Interface
- Nothing enforces that all connections have `connect()` and `execute_query()`
- A new class could silently omit a method or use a different signature, and nothing would catch it until that method is called
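To make Problem 3 concrete, here is a minimal sketch using a hypothetical `SQLiteConnection` class (not part of the original code) that forgets `execute_query()`. Without a shared interface, Python accepts the class without complaint, and the mistake only surfaces when someone calls the missing method:

```python
class SQLiteConnection:
    """Hypothetical connection class that forgot to implement execute_query()."""

    def __init__(self, path):
        self.path = path

    def connect(self):
        return f"SQLite: Connected to {self.path}"


conn = SQLiteConnection(":memory:")
print(conn.connect())  # creating and connecting work fine

try:
    conn.execute_query("SELECT 1")  # the gap is only discovered here, at call time
except AttributeError as exc:
    print(f"Discovered too late: {exc}")
```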

## Step 1: Find Commonalities

Before refactoring, let's identify what's common across all database connections:

In [None]:
# Let's analyze our classes programmatically

def analyze_class_structure():
    # Create instances
    mysql = MySQLConnection("localhost", 3306, "user", "pass")
    postgres = PostgreSQLConnection("localhost", 5432, "user", "pass", "mydb")
    mongodb = MongoDBConnection("localhost", 27017, "user", "pass")
    
    connections = {
        "MySQL": mysql,
        "PostgreSQL": postgres,
        "MongoDB": mongodb
    }
    
    print("ANALYZING CLASS STRUCTURES\n")
    
    # Analyze methods
    all_methods = {}
    for name, conn in connections.items():
        methods = [m for m in dir(conn) if not m.startswith('_') and callable(getattr(conn, m))]
        all_methods[name] = set(methods)
        print(f"{name} methods: {methods}")
    
    # Find common methods
    common_methods = set.intersection(*all_methods.values())
    print(f"\nCommon methods: {common_methods}")
    
    # Analyze attributes
    print("\nANALYZING ATTRIBUTES\n")
    
    all_attrs = {}
    for name, conn in connections.items():
        attrs = list(vars(conn).keys())
        all_attrs[name] = set(attrs)
        print(f"{name} attributes: {attrs}")
    
    # Find common attributes
    common_attrs = set.intersection(*all_attrs.values())
    print(f"\nCommon attributes: {common_attrs}")
    
    # Find unique attributes
    print("\nUNIQUE ATTRIBUTES:")
    for name, attrs in all_attrs.items():
        unique = attrs - common_attrs
        if unique:
            print(f"{name}: {unique}")

analyze_class_structure()

## Key Insights from Analysis

### Commonalities:
1. **Methods**: All have `connect()` and `execute_query()`
2. **Attributes**: All have `host`, `port`, `username`, `password`, `connection_string`

### Differences:
1. **PostgreSQL**: Has extra `database` attribute
2. **MongoDB**: Has extra `auth_source` attribute
3. **Constructor parameters**: Each has different requirements

### Strategy:
We need to:
1. Create a common interface (abstract base class)
2. Find a way to handle different constructor parameters
3. Move creation logic to a dedicated factory

## Step 2: Create an Abstract Base Class

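First, let's define a common interface that all database connections must follow. It also adds a `disconnect()` method, which none of the original classes had, so every connection can be closed the same way: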

In [None]:
from abc import ABC, abstractmethod

class DatabaseConnection(ABC):
    """Abstract base class for all database connections.
    
    This ensures all database connections have the same interface.
    """
    
    @abstractmethod
    def connect(self) -> str:
        """Establish connection to the database"""
        pass
    
    @abstractmethod
    def execute_query(self, query: str) -> str:
        """Execute a query on the database"""
        pass
    
    @abstractmethod
    def disconnect(self) -> str:
        """Close the database connection"""
        pass

print("Created abstract base class")
print("\nThis ensures:")
print("1. All connections MUST implement connect(), execute_query(), and disconnect()")
print("2. We have a common type to work with")
print("3. A subclass that forgets a method raises TypeError when you instantiate it")
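The abstract base class catches a missing method when you try to *instantiate* the subclass, not when the subclass is defined. The sketch below uses a trimmed copy of the interface, renamed `ConnectionInterface` so running it does not redefine `DatabaseConnection`, plus a hypothetical `IncompleteConnection` that omits `disconnect()`:

```python
from abc import ABC, abstractmethod


class ConnectionInterface(ABC):
    """Trimmed copy of DatabaseConnection, renamed so this cell doesn't redefine it."""

    @abstractmethod
    def connect(self) -> str: ...

    @abstractmethod
    def execute_query(self, query: str) -> str: ...

    @abstractmethod
    def disconnect(self) -> str: ...


class IncompleteConnection(ConnectionInterface):
    """Hypothetical subclass that forgot to implement disconnect()."""

    def connect(self) -> str:
        return "Connected"

    def execute_query(self, query: str) -> str:
        return f"Executing {query}"


# Defining the subclass succeeded; instantiating it is what fails.
try:
    IncompleteConnection()
except TypeError as exc:
    print(f"Caught at instantiation: {exc}")
```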

## Step 3: Refactor Classes to Use Abstract Base

Now let's update our connection classes to inherit from the abstract base:

In [None]:
class MySQLConnectionV2(DatabaseConnection):
    def __init__(self, host, port, username, password):
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.connection_string = f"mysql://{username}:{password}@{host}:{port}"
    
    def connect(self) -> str:
        return f"MySQL: Connected to {self.host}:{self.port}"
    
    def execute_query(self, query: str) -> str:
        return f"MySQL: Executing {query}"
    
    def disconnect(self) -> str:
        return f"MySQL: Disconnected from {self.host}:{self.port}"


class PostgreSQLConnectionV2(DatabaseConnection):
    def __init__(self, host, port, username, password, database):
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.database = database
        self.connection_string = f"postgresql://{username}:{password}@{host}:{port}/{database}"
    
    def connect(self) -> str:
        return f"PostgreSQL: Connected to {self.database} at {self.host}:{self.port}"
    
    def execute_query(self, query: str) -> str:
        return f"PostgreSQL: Executing {query}"
    
    def disconnect(self) -> str:
        return f"PostgreSQL: Disconnected from {self.database}"


class MongoDBConnectionV2(DatabaseConnection):
    def __init__(self, host, port, username, password, auth_source="admin"):
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.auth_source = auth_source
        self.connection_string = f"mongodb://{username}:{password}@{host}:{port}/{auth_source}"
    
    def connect(self) -> str:
        return f"MongoDB: Connected to {self.host}:{self.port}"
    
    def execute_query(self, query: str) -> str:
        return f"MongoDB: Executing {query}"
    
    def disconnect(self) -> str:
        return f"MongoDB: Disconnected from {self.host}:{self.port}"

print("All classes now inherit from DatabaseConnection")
print("All classes are guaranteed to have the same interface")

## Step 4: Handle Different Parameters - The Key Insight

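The biggest remaining problem is that each database class has a different constructor signature, so a factory cannot create them all the same way. There are several ways to solve this; the one used here is to give every class the same constructor.

### Solution: Use a Configuration Dictionary
Instead of passing individual parameters, pass a single configuration dictionary. Each class reads the keys it needs, so database-specific settings such as `database` or `auth_source` simply sit alongside the common ones.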

In [None]:
# Let's refactor to use configuration dictionaries

class MySQLConnectionV3(DatabaseConnection):
    def __init__(self, config: dict):
        """Initialize with a configuration dictionary"""
        self.host = config['host']
        self.port = config['port']
        self.username = config['username']
        self.password = config['password']
        self.connection_string = f"mysql://{self.username}:{self.password}@{self.host}:{self.port}"
    
    def connect(self) -> str:
        return f"MySQL: Connected to {self.host}:{self.port}"
    
    def execute_query(self, query: str) -> str:
        return f"MySQL: Executing {query}"
    
    def disconnect(self) -> str:
        return "MySQL: Disconnected"


class PostgreSQLConnectionV3(DatabaseConnection):
    def __init__(self, config: dict):
        """Initialize with a configuration dictionary"""
        self.host = config['host']
        self.port = config['port']
        self.username = config['username']
        self.password = config['password']
        self.database = config['database']  # PostgreSQL specific
        self.connection_string = f"postgresql://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"
    
    def connect(self) -> str:
        return f"PostgreSQL: Connected to {self.database} at {self.host}:{self.port}"
    
    def execute_query(self, query: str) -> str:
        return f"PostgreSQL: Executing {query}"
    
    def disconnect(self) -> str:
        return "PostgreSQL: Disconnected"


class MongoDBConnectionV3(DatabaseConnection):
    def __init__(self, config: dict):
        """Initialize with a configuration dictionary"""
        self.host = config['host']
        self.port = config['port']
        self.username = config['username']
        self.password = config['password']
        self.auth_source = config.get('auth_source', 'admin')  # MongoDB specific with default
        self.connection_string = f"mongodb://{self.username}:{self.password}@{self.host}:{self.port}/{self.auth_source}"
    
    def connect(self) -> str:
        return f"MongoDB: Connected to {self.host}:{self.port}"
    
    def execute_query(self, query: str) -> str:
        return f"MongoDB: Executing {query}"
    
    def disconnect(self) -> str:
        return "MongoDB: Disconnected"

print("NOW ALL CLASSES HAVE THE SAME CONSTRUCTOR SIGNATURE")
print("   __init__(self, config: dict)")
print("\nThis solves the different parameters problem!")

In [None]:
# Test our new uniform interface
mysql_config = {
    'host': 'localhost',
    'port': 3306,
    'username': 'root',
    'password': 'password'
}

postgres_config = {
    'host': 'localhost',
    'port': 5432,
    'username': 'postgres',
    'password': 'password',
    'database': 'myapp'  # PostgreSQL specific
}

mongodb_config = {
    'host': 'localhost',
    'port': 27017,
    'username': 'mongo',
    'password': 'password',
    'auth_source': 'admin'  # MongoDB specific
}

# Now we can create them all the same way!
mysql = MySQLConnectionV3(mysql_config)
postgres = PostgreSQLConnectionV3(postgres_config)
mongodb = MongoDBConnectionV3(mongodb_config)

print(mysql.connect())
print(postgres.connect())
print(mongodb.connect())
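The V3 classes still repeat the same four lines that copy `host`, `port`, `username`, and `password`. One way to remove that duplication, sketched here as an assumption rather than part of the original design, is a hypothetical intermediate base class `ConfiguredConnection` that reads the shared keys, so each subclass only handles its own extras. The interface is trimmed to a single method and renamed `ConnectionInterface` to keep the sketch short and avoid redefining `DatabaseConnection`:

```python
from abc import ABC, abstractmethod


class ConnectionInterface(ABC):
    """Trimmed stand-in for DatabaseConnection (one method, renamed)."""

    @abstractmethod
    def connect(self) -> str: ...


class ConfiguredConnection(ConnectionInterface):
    """Hypothetical base class that stores the fields every database shares."""

    def __init__(self, config: dict):
        self.host = config['host']
        self.port = config['port']
        self.username = config['username']
        self.password = config['password']


class PostgreSQLConnectionV4(ConfiguredConnection):
    def __init__(self, config: dict):
        super().__init__(config)            # common fields handled once, in the base class
        self.database = config['database']  # only the PostgreSQL-specific part remains here
        self.connection_string = (
            f"postgresql://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"
        )

    def connect(self) -> str:
        return f"PostgreSQL: Connected to {self.database} at {self.host}:{self.port}"


pg = PostgreSQLConnectionV4({
    'host': 'localhost', 'port': 5432,
    'username': 'postgres', 'password': 'password', 'database': 'myapp',
})
print(pg.connect())  # PostgreSQL: Connected to myapp at localhost:5432
```

Databases whose configuration lacks one of the shared keys (for example, no username) can still inherit directly from the interface instead of from the intermediate class.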

## Step 5: Create the Factory

Now we can create a factory that handles all the creation logic:

In [None]:
class DatabaseFactory:
    """Factory class for creating database connections.
    
    This is the ONLY place that knows about all database types.
    """
    
    # Registry of available database types
    _connection_classes = {
        'mysql': MySQLConnectionV3,
        'postgresql': PostgreSQLConnectionV3,
        'mongodb': MongoDBConnectionV3
    }
    
    @classmethod
    def create_connection(cls, db_type: str, config: dict) -> DatabaseConnection:
        """Create a database connection based on type.
        
        Args:
            db_type: Type of database ('mysql', 'postgresql', 'mongodb')
            config: Configuration dictionary with connection parameters
            
        Returns:
            DatabaseConnection: The appropriate connection instance
            
        Raises:
            ValueError: If db_type is not supported
        """
        if db_type not in cls._connection_classes:
            available = list(cls._connection_classes.keys())
            raise ValueError(
                f"Unknown database type: {db_type}. "
                f"Available types: {available}"
            )
        
        # Get the appropriate class
        connection_class = cls._connection_classes[db_type]
        
        # Create and return an instance
        return connection_class(config)
    
    @classmethod
    def register_connection_type(cls, db_type: str, connection_class: type):
        """Register a new database connection type.
        
        This makes the factory extensible.
        """
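        # issubclass() raises its own TypeError for non-classes, so check that first
        if not (isinstance(connection_class, type)
                and issubclass(connection_class, DatabaseConnection)):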
            raise TypeError(
                f"{connection_class} must inherit from DatabaseConnection"
            )
        
        cls._connection_classes[db_type] = connection_class
        print(f"Registered new connection type: {db_type}")

print("Factory created")
print("\nBenefits:")
print("1. Single place for all creation logic")
print("2. Easy to add new database types")
print("3. Client code doesn't need to know about specific classes")
print("4. One natural place to validate configurations before creating instances")
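As written, `DatabaseFactory` does not validate anything yet: a PostgreSQL config without `database` fails with a bare `KeyError` from inside the constructor, which is less helpful than the original `create_connection`'s "PostgreSQL requires 'database' parameter" error. Here is a minimal sketch of validation, assuming a hypothetical `ValidatingFactory` with a per-type table of required keys kept next to the registry:

```python
class ValidatingFactory:
    """Hypothetical factory variant that checks required keys before construction."""

    # Assumed table of required keys per database type (illustrative, not exhaustive).
    _required_keys = {
        'mysql': {'host', 'port', 'username', 'password'},
        'postgresql': {'host', 'port', 'username', 'password', 'database'},
        'mongodb': {'host', 'port', 'username', 'password'},  # auth_source is optional
    }

    @classmethod
    def validate(cls, db_type: str, config: dict) -> None:
        """Raise ValueError if db_type is unknown or config lacks a required key."""
        if db_type not in cls._required_keys:
            raise ValueError(f"Unknown database type: {db_type}")
        missing = cls._required_keys[db_type] - set(config)
        if missing:
            # sorted() gives a stable, readable message regardless of set ordering
            raise ValueError(f"{db_type} config is missing required keys: {sorted(missing)}")


# A complete MySQL config passes silently.
ValidatingFactory.validate('mysql', {'host': 'localhost', 'port': 3306,
                                     'username': 'root', 'password': 'password'})

# A PostgreSQL config without 'database' is rejected with a clear message.
try:
    ValidatingFactory.validate('postgresql', {'host': 'localhost', 'port': 5432,
                                              'username': 'postgres', 'password': 'password'})
except ValueError as exc:
    print(exc)  # postgresql config is missing required keys: ['database']
```

In `DatabaseFactory.create_connection`, a check like this would run just before `connection_class(config)`.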

## Step 6: Use the Factory

Now let's see how clean our client code becomes:

In [None]:
# BEFORE: Complex if/elif logic
# def create_connection(db_type, host, port, username, password, **kwargs):
#     if db_type == "mysql":
#         return MySQLConnection(host, port, username, password)
#     elif db_type == "postgresql":
#         if 'database' not in kwargs:
#             raise ValueError("PostgreSQL requires 'database' parameter")
#         return PostgreSQLConnection(host, port, username, password, kwargs['database'])
#     elif db_type == "mongodb":
#         auth_source = kwargs.get('auth_source', 'admin')
#         return MongoDBConnection(host, port, username, password, auth_source)
#     else:
#         raise ValueError(f"Unknown database type: {db_type}")

# AFTER: Clean factory usage
def create_connection_v2(db_type: str, config: dict) -> DatabaseConnection:
    """Create a database connection using the factory."""
    return DatabaseFactory.create_connection(db_type, config)

# Even better - just use the factory directly!
mysql = DatabaseFactory.create_connection('mysql', mysql_config)
postgres = DatabaseFactory.create_connection('postgresql', postgres_config)
mongodb = DatabaseFactory.create_connection('mongodb', mongodb_config)

print("Creating connections with the factory:\n")
print(mysql.connect())
print(postgres.connect())
print(mongodb.connect())

print("\nNotice how clean this is: a dictionary lookup replaces the if/elif chain!")

## Step 7: Adding New Database Types is Now Easy

Let's add support for Redis to see how easy extension has become:

In [None]:
# Create a new Redis connection class
class RedisConnection(DatabaseConnection):
    def __init__(self, config: dict):
        self.host = config['host']
        self.port = config['port']
        self.password = config.get('password', None)  # Optional password
        self.db = config.get('db', 0)  # Redis specific
        
        if self.password:
            self.connection_string = f"redis://:{self.password}@{self.host}:{self.port}/{self.db}"
        else:
            self.connection_string = f"redis://{self.host}:{self.port}/{self.db}"
    
    def connect(self) -> str:
        return f"Redis: Connected to {self.host}:{self.port} (db={self.db})"
    
    def execute_query(self, query: str) -> str:
        return f"Redis: Executing {query}"
    
    def disconnect(self) -> str:
        return "Redis: Disconnected"

# Register it with the factory
DatabaseFactory.register_connection_type('redis', RedisConnection)

# Now we can use it immediately!
redis_config = {
    'host': 'localhost',
    'port': 6379,
    'db': 1
}

redis = DatabaseFactory.create_connection('redis', redis_config)
print(redis.connect())

print("\nNO CHANGES to existing code were needed")
print("We just added a new class and registered it")
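A common variation, sketched here rather than taken from the code above, lets each class register itself with a class decorator, so the registration sits right next to the class definition. `ConnectionRegistry`, `register`, and `RedisSketchConnection` are hypothetical names:

```python
class ConnectionRegistry:
    """Hypothetical registry that classes join via a decorator."""

    _classes: dict = {}

    @classmethod
    def register(cls, db_type: str):
        def decorator(connection_class):
            cls._classes[db_type] = connection_class
            return connection_class  # return the class unchanged so its name still binds to it
        return decorator

    @classmethod
    def create(cls, db_type: str, config: dict):
        if db_type not in cls._classes:
            raise ValueError(
                f"Unknown database type: {db_type}. Available types: {sorted(cls._classes)}"
            )
        return cls._classes[db_type](config)


@ConnectionRegistry.register('redis')
class RedisSketchConnection:
    def __init__(self, config: dict):
        self.host = config['host']
        self.port = config['port']
        self.db = config.get('db', 0)

    def connect(self) -> str:
        return f"Redis: Connected to {self.host}:{self.port} (db={self.db})"


conn = ConnectionRegistry.create('redis', {'host': 'localhost', 'port': 6379, 'db': 1})
print(conn.connect())  # Redis: Connected to localhost:6379 (db=1)
```

The trade-off is that registration becomes a side effect of importing the module that defines the class, so a type is only available once its module has been imported.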

## Step 8: Advanced - Configuration Management

Finally, let's move the per-environment connection settings into a dedicated configuration manager:

In [None]:
class ConfigurationManager:
    """Manages database configurations for different environments."""
    
    _configurations = {
        'development': {
            'mysql': {
                'host': 'localhost',
                'port': 3306,
                'username': 'dev_user',
                'password': 'dev_pass'
            },
            'postgresql': {
                'host': 'localhost',
                'port': 5432,
                'username': 'dev_user',
                'password': 'dev_pass',
                'database': 'dev_db'
            },
            'mongodb': {
                'host': 'localhost',
                'port': 27017,
                'username': 'dev_user',
                'password': 'dev_pass',
                'auth_source': 'admin'
            },
            'redis': {
                'host': 'localhost',
                'port': 6379,
                'db': 0
            }
        },
        'production': {
            'mysql': {
                'host': 'prod-mysql.company.com',
                'port': 3306,
                'username': 'prod_user',
                'password': 'strong_password_123'
            },
            'postgresql': {
                'host': 'prod-postgres.company.com',
                'port': 5432,
                'username': 'prod_user',
                'password': 'strong_password_456',
                'database': 'prod_db'
            }
            # ... more production configs
        }
    }
    
    @classmethod
    def get_config(cls, environment: str, db_type: str) -> dict:
        """Get configuration for specific environment and database type."""
        if environment not in cls._configurations:
            raise ValueError(f"Unknown environment: {environment}")
        
        env_config = cls._configurations[environment]
        if db_type not in env_config:
            raise ValueError(
                f"No {db_type} configuration for {environment} environment"
            )
        
        # Return a copy so callers can't accidentally modify the stored configuration
        return dict(env_config[db_type])


# Now creating connections is even simpler!
def get_database_connection(environment: str, db_type: str) -> DatabaseConnection:
    """Get a database connection for the specified environment."""
    config = ConfigurationManager.get_config(environment, db_type)
    return DatabaseFactory.create_connection(db_type, config)

# Usage is very clean
dev_mysql = get_database_connection('development', 'mysql')
dev_postgres = get_database_connection('development', 'postgresql')
dev_redis = get_database_connection('development', 'redis')

print("Development connections:")
print(dev_mysql.connect())
print(dev_postgres.connect())
print(dev_redis.connect())

print("\nConfiguration is now centralized and clean")
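`ConfigurationManager` still keeps credentials, including production passwords, in source code. A common alternative is to read them from environment variables at runtime. The sketch below assumes a hypothetical naming scheme `<ENV>_<DB>_<KEY>` (for example `PRODUCTION_MYSQL_PASSWORD`) and a hypothetical helper `load_config_from_env`; adapt both to however your deployment injects settings:

```python
import os


def load_config_from_env(environment: str, db_type: str, keys: list[str]) -> dict:
    """Build a config dict from environment variables (hypothetical naming scheme)."""
    prefix = f"{environment}_{db_type}_".upper()
    config = {}
    missing = []
    for key in keys:
        value = os.environ.get(prefix + key.upper())
        if value is None:
            missing.append(prefix + key.upper())
        else:
            config[key] = value
    if missing:
        raise ValueError(f"Missing environment variables: {missing}")
    # Environment variables are always strings, so convert numeric fields explicitly.
    if 'port' in config:
        config['port'] = int(config['port'])
    return config


# Simulate what a deployment would set (normally done outside Python).
os.environ.update({
    'PRODUCTION_MYSQL_HOST': 'prod-mysql.company.com',
    'PRODUCTION_MYSQL_PORT': '3306',
    'PRODUCTION_MYSQL_USERNAME': 'prod_user',
    'PRODUCTION_MYSQL_PASSWORD': 'from-secret-store',
})

prod_mysql_config = load_config_from_env('production', 'mysql',
                                         ['host', 'port', 'username', 'password'])
print(prod_mysql_config['host'], prod_mysql_config['port'])  # prod-mysql.company.com 3306
```

The resulting dictionary has the same shape as the hard-coded ones, so it can be passed straight to `DatabaseFactory.create_connection('mysql', ...)`.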

## Summary: The Complete Refactoring Process

### What We Did:

1. **Analyzed the Problem**
   - Different constructor parameters
   - Complex if/elif chains
   - No common interface

2. **Found Commonalities**
   - All have `connect()` and `execute_query()`
   - All have basic connection parameters

3. **Created Abstract Base Class**
   - Defined common interface
   - Enforced consistency

4. **Solved Parameter Problem**
   - Used configuration dictionaries
   - Uniform constructor signature

5. **Built the Factory**
   - Centralized creation logic
   - Made it extensible

6. **Added Configuration Management**
   - Moved connection settings out of the classes into one place
   - Environment-specific settings

### Benefits Achieved:

- **Open/Closed Principle**: Can add new databases without modifying existing code
- **Single Responsibility**: Each class has one job
- **DRY**: No duplicate if/elif chains
- **Testability**: Tests can register a fake connection class with the factory, or mock specific connections
- **Maintainability**: Changes are localized
- **Flexibility**: Easy to add new database types
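To make the testability benefit concrete: because the factory looks classes up in a registry, a test can register a fake connection class and exercise code that depends on the factory without a real database. This sketch uses a trimmed stand-in factory (`SketchFactory`), a hypothetical `FakeConnection` test double, and a hypothetical piece of application code, `count_users`:

```python
class SketchFactory:
    """Trimmed stand-in for DatabaseFactory: a registry plus register/create."""

    _connection_classes: dict = {}

    @classmethod
    def register_connection_type(cls, db_type: str, connection_class: type) -> None:
        cls._connection_classes[db_type] = connection_class

    @classmethod
    def create_connection(cls, db_type: str, config: dict):
        return cls._connection_classes[db_type](config)


class FakeConnection:
    """Hypothetical test double: returns canned strings instead of talking to a server."""

    def __init__(self, config: dict):
        self.config = config

    def connect(self) -> str:
        return "Fake: Connected"

    def execute_query(self, query: str) -> str:
        return f"Fake: Executing {query}"

    def disconnect(self) -> str:
        return "Fake: Disconnected"


def count_users(factory, db_type: str, config: dict) -> str:
    """Hypothetical application code that depends only on the factory."""
    conn = factory.create_connection(db_type, config)
    conn.connect()
    result = conn.execute_query("SELECT COUNT(*) FROM users")
    conn.disconnect()
    return result


SketchFactory.register_connection_type('fake', FakeConnection)
result = count_users(SketchFactory, 'fake', {})
print(result)  # Fake: Executing SELECT COUNT(*) FROM users
```

With the real `DatabaseFactory`, `FakeConnection` would also need to subclass `DatabaseConnection` to pass the `issubclass` check in `register_connection_type`. Since the registry is a class-level dictionary shared by all callers, a test should remove its fake entry afterwards.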

## Exercise: Apply What You Learned

Now it's your turn! Try adding a new database type:

1. Create a `CassandraConnection` class
2. Register it with the factory
3. Write its configuration dictionary
4. Test creating a connection

Remember: You should NOT need to modify any existing code!

In [None]:
# YOUR CODE HERE
# Create CassandraConnection class
# class CassandraConnection(DatabaseConnection):
#     def __init__(self, config: dict):
#         # Your implementation
#         pass
#     
#     def connect(self) -> str:
#         # Your implementation
#         pass
#     
#     def execute_query(self, query: str) -> str:
#         # Your implementation
#         pass
#     
#     def disconnect(self) -> str:
#         # Your implementation
#         pass

# Register it with the factory
# DatabaseFactory.register_connection_type('cassandra', CassandraConnection)

# Test it!
# cassandra_config = {
#     'host': 'localhost',
#     'port': 9042,
#     'username': 'cassandra',
#     'password': 'cassandra',
#     'keyspace': 'my_keyspace'  # Cassandra specific!
# }
# 
# cassandra = DatabaseFactory.create_connection('cassandra', cassandra_config)
# print(cassandra.connect())

## Key Takeaways

1. **Start by analyzing commonalities and differences**
2. **Create abstractions for common behavior**
3. **Use configuration objects to handle parameter differences**
4. **Centralize creation logic in a factory**
5. **Make the factory extensible with registration**
6. **Separate configuration from implementation**

The Factory Pattern is not just about hiding object creation; it's about building a flexible, maintainable system for managing the complexity of object creation.