Consider the quadratic equation and its canonical solution:

ax^2 + bx + c = 0

x = (-b ± sqrt(b^2 - 4ac)) / (2a)
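
As a quick sanity check of the formula, here is a minimal sketch that applies it directly. It is not part of the exercise, and the function name `quadratic_roots` is hypothetical. It uses `cmath.sqrt`, which returns a complex square root even when its argument is negative, so every input yields two complex roots.

```python
import cmath


def quadratic_roots(a, b, c):
    """Apply the quadratic formula; the '+' root comes first, the '-' root second."""
    root = cmath.sqrt(b ** 2 - 4 * a * c)  # complex sqrt, defined for negative inputs too
    return (-b + root) / (2 * a), (-b - root) / (2 * a)


# x^2 - 3x + 2 = (x - 1)(x - 2), so the roots are 2 and 1 (returned as complex numbers).
print(quadratic_roots(1, -3, 2))

# x^2 + 1 = 0 has no real roots; the result is the conjugate pair i and -i.
print(quadratic_roots(1, 0, 1))
```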

The expression b^2 - 4ac is called the discriminant. Suppose we want to provide an API with two different strategies for calculating it:

    In OrdinaryDiscriminantStrategy, if the discriminant is negative, we return it as-is. This is fine, since our main API returns complex numbers anyway.

    In RealDiscriminantStrategy, if the discriminant is negative, the return value is NaN (not a number). NaN propagates through the rest of the calculation, so the equation solver returns two NaN values. In Python, you create such a number with float('nan').
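
The claim that NaN propagates can be checked directly. The snippet below is an illustrative sketch, not part of the exercise: it shows that arithmetic involving float('nan') yields NaN, that NaN compares unequal even to itself (so it has to be detected with `math.isnan` or `cmath.isnan`), and that `cmath.sqrt` of a NaN discriminant produces a complex NaN that carries through the quadratic formula. The coefficients a = 5, b = 2 are arbitrary example values.

```python
import cmath
import math

nan = float("nan")

# Any arithmetic involving NaN yields NaN.
print(math.isnan(nan + 1), math.isnan(2 * nan))   # True True

# NaN is unequal to everything, including itself, so use isnan to detect it.
print(nan == nan)                                  # False

# The complex square root of a NaN discriminant is a complex NaN,
# and the quadratic formula then carries it through to the root.
root = cmath.sqrt(nan)
x1 = (-2 + root) / (2 * 5)
print(cmath.isnan(root), cmath.isnan(x1))          # True True
```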

Please implement both of these strategies as well as the equation solver itself. With regard to the plus-minus in the formula, return the + result as the first element and the - result as the second. Note that the solve() method is expected to return complex values.

In [48]:
from abc import ABC, abstractmethod
import cmath


class DiscriminantStrategy(ABC):
    """Interface for computing the discriminant b^2 - 4ac of ax^2 + bx + c = 0."""

    @abstractmethod
    def calculate_discriminant(self, a, b, c):
        raise NotImplementedError()


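class OrdinaryDiscriminantStrategy(DiscriminantStrategy):
    def calculate_discriminant(self, a, b, c):
        # Return b^2 - 4ac as-is, even if negative; the solver handles complex roots.
        return b ** 2 - 4 * a * c


class RealDiscriminantStrategy(OrdinaryDiscriminantStrategy):
    def calculate_discriminant(self, a, b, c):
        discriminant = super().calculate_discriminant(a, b, c)

        # A negative discriminant means there are no real roots: signal this
        # with NaN, which then propagates through the solver to both results.
        if discriminant < 0:
            return float("nan")

        return discriminant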


class QuadraticEquationSolver:
    def __init__(self, strategy: DiscriminantStrategy):
        self.strategy = strategy

    def solve(self, a, b, c):
        """Return the pair (x1, x2) as complex numbers: x1 uses +, x2 uses -."""
        discriminant = self.strategy.calculate_discriminant(a, b, c)
        # cmath.sqrt accepts real input and returns a complex result, so a
        # negative discriminant yields an imaginary root instead of an error.
        root = cmath.sqrt(discriminant)

        sol_1 = (-b + root) / (2 * a)
        sol_2 = (-b - root) / (2 * a)

        return sol_1, sol_2

In [50]:
solver = QuadraticEquationSolver(strategy=OrdinaryDiscriminantStrategy())

solver.solve(5, 2, 1)

((-0.2+0.4j), (-0.2-0.4j))
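
Because each strategy has a single method, Python also allows a lighter-weight variant of the same pattern in which plain functions serve as strategies. The sketch below is an alternative, not the exercise's required class-based API; the names `ordinary_discriminant`, `real_discriminant`, `solve` and the `Strategy` type alias are hypothetical. It also exercises the NaN path of the real strategy, which the cell above does not run.

```python
import cmath
from typing import Callable, Tuple

# A strategy is any callable mapping (a, b, c) to a discriminant value.
Strategy = Callable[[float, float, float], float]


def ordinary_discriminant(a: float, b: float, c: float) -> float:
    """Return b^2 - 4ac unchanged, even when it is negative."""
    return b ** 2 - 4 * a * c


def real_discriminant(a: float, b: float, c: float) -> float:
    """Return b^2 - 4ac, or NaN when it is negative (no real roots)."""
    d = ordinary_discriminant(a, b, c)
    return float("nan") if d < 0 else d


def solve(a: float, b: float, c: float, strategy: Strategy) -> Tuple[complex, complex]:
    """Return the (+, -) roots as complex numbers using the given strategy."""
    root = cmath.sqrt(strategy(a, b, c))  # complex sqrt; a NaN input gives a complex NaN
    return (-b + root) / (2 * a), (-b - root) / (2 * a)


print(solve(5, 2, 1, ordinary_discriminant))  # complex-conjugate pair
print(solve(5, 2, 1, real_discriminant))      # both roots are NaN
```

The class-based version makes the contract explicit through the abstract base class; the function-based one trades that explicitness for less boilerplate.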