<img src="../img/python-logo-no-text.png"
     style="display:block;margin:auto;width:10%"/>
<br>
<div style="text-align:center; font-size:200%;">
  <b>SOLID: Single Responsibility Principle</b>
</div>
<br/>
<div style="text-align:center;">Dr. Matthias Hölzl</div>
<br/>
<!-- <div style="text-align:center;">module_280_solid/topic_120_a3_solid_srp</div> -->


# SOLID: Single Responsibility Principle

What do you think of the following code?

In [None]:
from typing import NamedTuple

In [None]:
class Point2D(NamedTuple):
    """An immutable point in the plane."""

    x: float  # horizontal coordinate
    y: float  # vertical coordinate

    def move(self, dx: float, dy: float) -> "Point2D":
        """Return a new point shifted by ``(dx, dy)``; ``self`` is unchanged."""
        new_x = self.x + dx
        new_y = self.y + dy
        return Point2D(new_x, new_y)

In [None]:
class Figure2d(NamedTuple):
    """A game figure that mixes several responsibilities in one class.

    It combines geometry (``pivot``, ``move``), sprite data (``sprite``),
    game state (``health``), GUI handling (``health_bar``, ``update_gui``)
    and persistence (``database``, ``save_to_db``).
    """

    pivot: Point2D
    sprite: list[int]
    health: float
    health_bar: object  # In reality a widget...
    database: object  # The connection to the DB

    def move(self, dx: float, dy: float) -> "Figure2d":
        """Return a copy of this figure whose pivot is shifted by ``(dx, dy)``.

        ``NamedTuple`` instances are immutable, so the figure cannot be moved
        in place; like ``Point2D.move``, this returns a new instance and
        leaves ``self`` unchanged.
        """
        return self._replace(pivot=self.pivot.move(dx, dy))

    def is_player_avatar(self) -> bool:
        """Return whether this figure is the player's avatar.

        Placeholder: currently always returns ``True``.
        """
        # Check whether the figure is the player's avatar...
        return True

    def update_gui(self) -> None:
        """Update the health bar widget (not implemented in this example)."""
        ...

    def save_to_db(self) -> None:
        """Write the figure data to the database (not implemented here)."""
        ...

In [None]:
class GameEditor:
    """Editor for game figures; it saves each figure after editing it."""

    def edit_figure(self, figure: Figure2d):
        """Edit *figure*, then write it to the database via ``save_to_db``."""
        # The actual editing steps are elided in this example.
        figure.save_to_db()

In [None]:
class Game(NamedTuple):
    """The running game; holds the player's avatar figure."""

    player_avatar: Figure2d

    def run_game_loop(self):
        """Run the game loop (elided here) and refresh the avatar's GUI."""
        # ... lots of other game logic would go here ...
        avatar = self.player_avatar
        avatar.update_gui()


## Single-Responsibility Principle

- Each class should have only a single reason to change
- The name is somewhat misleading: the SRP does not state that each class may
  have only a single responsibility


## Single Responsibility?

<img src="img/book_01.png"
     style="display:block;margin:auto;width:35%"/>



## Violation of SRP

<img src="img/book_02.png"
     style="display:block;margin:auto;width:60%"/>


In [None]:
from dataclasses import dataclass  # noqa: E402

In [None]:
@dataclass
class Book:
    """A book that also knows how to print and store itself (an SRP violation)."""

    title: str
    author: str
    pages: int

    def print(self):
        """Send the book to the printer (stubbed out with a message)."""
        # Inside the method body, ``print`` still refers to the builtin.
        message = "Printing to printer."
        print(message)

    def save(self):
        """Persist the book to the database (stubbed out with a message)."""
        message = "Saving to database."
        print(message)


## Resolution of the SRP violation (version 1)

<img src="img/book_resolution_1.png"
     style="display:block;margin:auto;width:50%"/>


## Resolution of the SRP violation (version 2)

<img src="img/book_resolution_2.png"
     style="display:block;margin:auto;width:80%"/>


## Comparison

<div>
<img src="img/book_resolution_1.png"
     style="float:left;padding:5px;width:40%"/>
<img src="img/book_resolution_2.png"
     style="float:right;padding:5px;width:50%"/>
</div>


## Workshop: Employee

The following implementation of a personnel management system contains
several SRP violations. Implement a version that no longer has them.

<img src="img/employee_01.png"
     style="display:block;margin:auto;width:40%"/>

In [None]:
from enum import IntEnum

In [None]:
class EmployeeType(IntEnum):
    """Kind of employee; determines how pay and hours are computed."""

    # Salary plus paid overtime.
    REGULAR = 0
    # Paid per billed hour (the billed hours are kept in ``overtime``).
    HOURED = 1
    # Paid a share of the project's assets.
    COMMISSIONED = 2

In [None]:
from dataclasses import dataclass

In [None]:
@dataclass
class Project:
    """A project to which employees are assigned."""

    name: str
    assets: float  # commissioned pay is computed as a share of this value

In [None]:
from augurdb import AugurDatabase

In [None]:
@dataclass
class EmployeeV0:
    """Employee record that also computes pay, reports hours, prints and saves.

    Payroll rules, time reporting, presentation and persistence all live in
    this one class, so it has several unrelated reasons to change.
    """

    id: int
    name: str
    salary: float
    overtime: int
    employee_type: EmployeeType
    project: Project
    database: AugurDatabase

    def calculate_pay(self) -> float:
        """Return this employee's pay according to ``employee_type``.

        Raises:
            ValueError: if ``employee_type`` matches none of the known kinds.
        """
        kind = self.employee_type
        if kind == EmployeeType.REGULAR:
            # Salary plus 60.0 per overtime hour.
            return self.salary + 60.0 * self.overtime
        if kind == EmployeeType.COMMISSIONED:
            # 10 % of the project's assets.
            return self.project.assets * 0.1
        if kind == EmployeeType.HOURED:
            # 50.0 per billed hour.
            return 50.0 * self.overtime
        raise ValueError(f"{kind} is not valid.")

    def report_hours(self) -> int:
        """Return the number of hours this employee worked.

        Raises:
            ValueError: if ``employee_type`` matches none of the known kinds.
        """
        kind = self.employee_type
        if kind == EmployeeType.REGULAR:
            # 40 base hours plus overtime.
            return 40 + self.overtime
        if kind == EmployeeType.COMMISSIONED:
            # Commissioned employees always work 40 hours
            return 40
        if kind == EmployeeType.HOURED:
            # We use overtime for the billed hours
            return self.overtime
        raise ValueError(f"{kind} is not valid.")

    def print_report(self) -> None:
        """Print a one-line summary of the hours this employee worked."""
        name = self.name
        hours = self.report_hours()
        print(f"{name} worked {hours} hours.")

    def save_employee(self) -> None:
        """Write the employee's data under ``id`` inside one DB transaction."""
        db = self.database
        stored_fields = ("name", "salary", "overtime", "employee_type", "project")
        db.start_transaction()
        for field_name in stored_fields:
            db.store_field(self.id, field_name, getattr(self, field_name))
        db.commit_transaction()

In [None]:
from pprint import pprint

In [None]:
# Two sample projects; commissioned pay is derived from a project's assets.
p1 = Project(name="Project 1", assets=10_000.0)
p2 = Project(name="Project 2", assets=12_000.0)

In [None]:
# Database shared by all employees below; save_employee() writes into it.
db = AugurDatabase()

In [None]:
# Regular employee: pay is salary plus 60.0 per overtime hour; reported
# hours are 40 plus overtime.
e1 = EmployeeV0(
    id=123,
    name="Joe Random",
    salary=1000.0,
    overtime=5,
    employee_type=EmployeeType.REGULAR,
    project=p1,
    database=db,
)

In [None]:
# Hourly ("HOURED") employee: paid 50.0 per billed hour. ``overtime`` holds
# the billed hours; ``salary`` does not affect the computed pay.
e2 = EmployeeV0(
    id=124,
    name="Jane Ransom",
    salary=1500.0,
    overtime=43,
    employee_type=EmployeeType.HOURED,
    project=p1,
    database=db,
)

In [None]:
# Commissioned employee: paid 10 % of the project's assets and always
# reported as working 40 hours; ``salary`` and ``overtime`` are unused.
e3 = EmployeeV0(
    id=125,
    name="Jill Chance",
    salary=2500.0,
    overtime=2,
    employee_type=EmployeeType.COMMISSIONED,
    project=p2,
    database=db,
)

In [None]:
# Collect the sample employees so they can be processed in a single loop.
employees = [e1, e2, e3]

In [None]:
# For each employee: show the pay, print the hour report, then persist it.
divider = "=" * 35
for e in employees:
    print(divider)
    print(f"{e.name} has a salary of {e.calculate_pay():.2f}")
    e.print_report()
    e.save_employee()
print(divider)

In [None]:
# Inspect what save_employee() stored in the database.
pprint(db.records)

In [None]:
class PaymentCalculator:
    """Computes an employee's pay (the logic of ``EmployeeV0.calculate_pay``)."""

    def calculate_pay(self, employee: "EmployeeV1") -> float:
        """Return the pay for *employee* based on its ``employee_type``.

        Raises:
            ValueError: if the type matches none of the known kinds.
        """
        kind = employee.employee_type
        if kind == EmployeeType.REGULAR:
            # Salary plus 60.0 per overtime hour.
            return employee.salary + 60.0 * employee.overtime
        if kind == EmployeeType.COMMISSIONED:
            # 10 % of the project's assets.
            return employee.project.assets * 0.1
        if kind == EmployeeType.HOURED:
            # 50.0 per billed hour.
            return 50.0 * employee.overtime
        raise ValueError(f"{kind} is not valid.")

In [None]:
class HourReporter:
    """Reports hours worked (the logic of ``EmployeeV0.report_hours``)."""

    def report_hours(self, employee: "EmployeeV1") -> int:
        """Return the number of hours *employee* worked.

        Raises:
            ValueError: if the type matches none of the known kinds.
        """
        kind = employee.employee_type
        if kind == EmployeeType.REGULAR:
            # 40 base hours plus overtime.
            return 40 + employee.overtime
        if kind == EmployeeType.COMMISSIONED:
            # Commissioned employees always work 40 hours
            return 40
        if kind == EmployeeType.HOURED:
            # We use overtime for the billed hours
            return employee.overtime
        raise ValueError(f"{kind} is not valid.")

In [None]:
class ReportPrinter:
    """Formats and prints an employee's hour report."""

    def print_report(self, employee: "EmployeeV1") -> None:
        """Print a one-line summary of the hours *employee* worked.

        The hour count is obtained from ``employee.report_hours()``.
        """
        name = employee.name
        hours_worked = employee.report_hours()
        print(f"{name} worked {hours_worked} hours.")

In [None]:
@dataclass
class EmployeeDao:
    """Data-access object that saves employees to an ``AugurDatabase``."""

    database: AugurDatabase

    def save_employee(self, employee: "EmployeeV1") -> None:
        """Write *employee*'s data under ``employee.id`` in one transaction."""
        key = employee.id
        stored_fields = ("name", "salary", "overtime", "employee_type", "project")
        self.database.start_transaction()
        for attribute in stored_fields:
            self.database.store_field(key, attribute, getattr(employee, attribute))
        self.database.commit_transaction()

In [None]:
@dataclass
class EmployeeV1:
    """Employee record that delegates each responsibility to a collaborator.

    Pay calculation, hour reporting, report printing and persistence are
    handled by separate, injected objects, so each of those concerns can
    change without editing this class. The public methods keep the same
    interface as ``EmployeeV0``.
    """

    id: int
    name: str
    salary: float
    overtime: int
    employee_type: EmployeeType
    project: Project
    payment_calculator: PaymentCalculator
    hour_reporter: HourReporter
    report_printer: ReportPrinter
    dao: EmployeeDao

    def calculate_pay(self) -> float:
        """Return this employee's pay as computed by ``payment_calculator``."""
        return self.payment_calculator.calculate_pay(self)

    def report_hours(self) -> int:
        """Return the hours worked as computed by ``hour_reporter``."""
        return self.hour_reporter.report_hours(self)

    def print_report(self) -> None:
        """Print the hour report via ``report_printer``."""
        # Pass through the collaborator's result so custom printers still work.
        return self.report_printer.print_report(self)

    def save_employee(self) -> None:
        """Persist this employee via ``dao``."""
        return self.dao.save_employee(self)

In [None]:
# Two sample projects; commissioned pay is derived from a project's assets.
p1 = Project(name="Project 1", assets=10_000.0)
p2 = Project(name="Project 2", assets=12_000.0)

In [None]:
# Fresh database for the EmployeeV1 run; the employee DAO below writes into it.
db = AugurDatabase()

In [None]:
# One instance of each collaborator, shared by every EmployeeV1 below.
# The DAO stores employees in ``db``.
default_payment_calculator = PaymentCalculator()
default_hour_reporter = HourReporter()
default_report_printer = ReportPrinter()
default_employee_dao = EmployeeDao(db)

In [None]:
# Regular employee (same data as in the EmployeeV0 example); all behavior
# is delegated to the shared default collaborators.
e1 = EmployeeV1(
    id=123,
    name="Joe Random",
    salary=1000.0,
    overtime=5,
    employee_type=EmployeeType.REGULAR,
    project=p1,
    payment_calculator=default_payment_calculator,
    hour_reporter=default_hour_reporter,
    report_printer=default_report_printer,
    dao=default_employee_dao,
)

In [None]:
# Hourly ("HOURED") employee (same data as in the EmployeeV0 example);
# ``overtime`` holds the billed hours.
e2 = EmployeeV1(
    id=124,
    name="Jane Ransom",
    salary=1500.0,
    overtime=43,
    employee_type=EmployeeType.HOURED,
    project=p1,
    payment_calculator=default_payment_calculator,
    hour_reporter=default_hour_reporter,
    report_printer=default_report_printer,
    dao=default_employee_dao,
)

In [None]:
# Commissioned employee (same data as in the EmployeeV0 example); pay is
# based on the assets of project ``p2``.
e3 = EmployeeV1(
    id=125,
    name="Jill Chance",
    salary=2500.0,
    overtime=2,
    employee_type=EmployeeType.COMMISSIONED,
    project=p2,
    payment_calculator=default_payment_calculator,
    hour_reporter=default_hour_reporter,
    report_printer=default_report_printer,
    dao=default_employee_dao,
)

In [None]:
# Collect the EmployeeV1 sample employees so they can be processed in one loop.
employees = [e1, e2, e3]

In [None]:
# Same loop as for EmployeeV0: the refactored class keeps its interface.
divider = "=" * 35
for e in employees:
    print(divider)
    print(f"{e.name} has a salary of {e.calculate_pay():.2f}")
    e.print_report()
    e.save_employee()
print(divider)

In [None]:
# Inspect what the EmployeeDao stored in the database.
pprint(db.records)