# OOP 5: Subclasses and Inheritance in Python

In this notebook, we'll explore subclasses and inheritance in Python: how inheritance works, how to define subclasses, how to override methods, and how to call superclass methods with the `super()` function.

## Table of Contents

1. [Introduction to Inheritance](#1)
2. [Defining Subclasses](#2)
3. [Overriding Methods](#3)
4. [Using the super() Function](#4)
5. [Step-by-Step Example](#5)
6. [Exercise: Extending a Data Analysis Class](#6)

---
## 1. Introduction to Inheritance <a id="1"></a>

Inheritance is a fundamental concept in object-oriented programming that allows a class (the subclass) to inherit attributes and methods from another class (the superclass). This promotes code reuse and lets you organize related classes into a hierarchy.

In this example, `Dog` and `Cat` inherit `__init__()` unchanged from the `Animal` class, and each overrides the `speak()` method, which `Animal` deliberately leaves unimplemented.

In [5]:
class Animal:
    def __init__(self, name):
        self.name = name
    
    def speak(self):
        raise NotImplementedError("Subclass must implement abstract method")

class Dog(Animal):
    def speak(self):
        return f"{self.name} says Woof!"

class Cat(Animal):
    def speak(self):
        return f"{self.name} says Meow!"

my_dog = Dog("Frank")
print(my_dog.speak())

my_cat = Cat("Bob")
print(my_cat.speak())

Frank says Woof!
Bob says Meow!
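Because `Animal.speak()` only raises `NotImplementedError`, the base class can still be instantiated; the error appears only when `speak()` is called. As an alternative that this notebook does not use, the standard-library `abc` module can enforce the contract earlier, at instantiation time. A minimal sketch, where `AbstractAnimal` and `Duck` are hypothetical names chosen for illustration:

```python
from abc import ABC, abstractmethod

class Animal:
    """Base class as defined above: speak() raises only when it is called."""
    def __init__(self, name):
        self.name = name

    def speak(self):
        raise NotImplementedError("Subclass must implement abstract method")

class AbstractAnimal(ABC):
    """Hypothetical alternative: abstract methods are enforced at instantiation."""
    def __init__(self, name):
        self.name = name

    @abstractmethod
    def speak(self):
        """Concrete subclasses must override this."""

class Duck(AbstractAnimal):
    """Hypothetical concrete subclass that implements speak()."""
    def speak(self):
        return f"{self.name} says Quack!"

generic = Animal("Generic")          # instantiation succeeds...
try:
    generic.speak()                  # ...but calling speak() raises
except NotImplementedError as err:
    print("NotImplementedError:", err)

try:
    AbstractAnimal("Nobody")         # fails immediately: speak() is still abstract
except TypeError as err:
    print("TypeError:", err)

print(Duck("Donald").speak())        # Donald says Quack!
```

The trade-off: the `NotImplementedError` style is simpler, while `abc` catches a missing override as soon as someone tries to create an object.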


In [6]:
Animal.__bases__

(object,)
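`__bases__` lists only a class's direct parents; `object` appears for `Animal` because every Python 3 class ultimately inherits from it. To see the full chain Python searches when it resolves a method such as `speak()`, you can inspect `__mro__` (the method resolution order), and `isinstance()` / `issubclass()` test the relationship directly. A short sketch that redefines `Animal` and `Dog` so it runs on its own:

```python
class Animal:
    def __init__(self, name):
        self.name = name

    def speak(self):
        raise NotImplementedError("Subclass must implement abstract method")

class Dog(Animal):
    def speak(self):
        return f"{self.name} says Woof!"

print(Dog.__bases__)                       # direct parents only: a 1-tuple holding Animal
print(Dog.__mro__)                         # full lookup order: Dog, then Animal, then object
print(issubclass(Dog, Animal))             # True
print(isinstance(Dog("Frank"), Animal))    # True: a Dog is also an Animal
print(isinstance(Animal("Generic"), Dog))  # False: the relationship is one-way
```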

---
## 2. Defining Subclasses <a id="2"></a>

To define a subclass, use the following syntax:

```python
class SubclassName(SuperclassName):
    # subclass body
```
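The subclass body can even be just `pass`. The subclass then inherits every attribute and method unchanged, including `__init__()`. A minimal sketch with a hypothetical `Truck` subclass:

```python
class Vehicle:
    def __init__(self, brand, model):
        self.brand = brand
        self.model = model

    def info(self):
        return f"Vehicle: {self.brand} {self.model}"

class Truck(Vehicle):
    """Hypothetical subclass that adds nothing: everything comes from Vehicle."""
    pass

my_truck = Truck("Volvo", "FH16")    # runs Vehicle.__init__
print(my_truck.info())               # Vehicle: Volvo FH16
```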

In this example, `Car` is a subclass of `Vehicle` that adds an extra attribute, `doors`. Its `__init__()` stores the brand and model itself, duplicating the work of `Vehicle.__init__()`.

In [None]:
class Vehicle:
    def __init__(self, brand, model):
        self.brand = brand
        self.model = model
    
    def info(self):
        return f"Vehicle: {self.brand} {self.model}"

class Car(Vehicle):
    def __init__(self, brand, model, doors):
        # Repeats Vehicle.__init__ by hand; the next cell avoids this with super()
        self.brand = brand
        self.model = model
        self.doors = doors

my_new_car = Car(brand="Ferrari", model="LaFerrari", doors=2)
my_new_car.info()

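Rather than repeating those assignments, `Car` can delegate them to the superclass by calling `super().__init__()`, which section 4 covers in more detail:

In [None]: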
class Vehicle:
    def __init__(self, brand, model):
        self.brand = brand
        self.model = model
    
    def info(self):
        return f"Vehicle: {self.brand} {self.model}"

class Car(Vehicle):
    def __init__(self, brand, model, doors):
        super().__init__(brand, model)  # Vehicle.__init__ sets brand and model
        self.doors = doors

my_new_car = Car(brand="Ferrari", model="LaFerrari", doors=2)
my_new_car.info()
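Once a subclass defines its own `__init__()`, Python no longer runs the superclass's `__init__()` automatically. The sketch below uses a hypothetical `BrokenCar` that neither assigns `brand` and `model` nor calls `super().__init__()`, so the inherited `info()` fails:

```python
class Vehicle:
    def __init__(self, brand, model):
        self.brand = brand
        self.model = model

    def info(self):
        return f"Vehicle: {self.brand} {self.model}"

class BrokenCar(Vehicle):
    """Hypothetical subclass whose __init__ never initializes the Vehicle part."""
    def __init__(self, brand, model, doors):
        self.doors = doors               # brand and model are never stored

class Car(Vehicle):
    def __init__(self, brand, model, doors):
        super().__init__(brand, model)   # let Vehicle set brand and model
        self.doors = doors

try:
    BrokenCar("Ferrari", "LaFerrari", 2).info()
except AttributeError as err:
    print("AttributeError:", err)        # the object has no attribute 'brand'

print(Car("Ferrari", "LaFerrari", 2).info())  # Vehicle: Ferrari LaFerrari
```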

---
## 3. Overriding Methods <a id="3"></a>
A subclass can override methods of its superclass to provide specific implementations.

In this example, the `info()` method is overridden in the `Car` class to provide more specific information.

In [None]:
class Vehicle:
    def __init__(self, brand, model):
        self.brand = brand
        self.model = model
    
    def info(self):
        return f"Vehicle: {self.brand} {self.model}"

class Car(Vehicle):
    def __init__(self, brand, model, doors):
        super().__init__(brand, model)
        self.doors = doors
    
    def info(self):
        return f"Car: {self.brand} {self.model} with {self.doors} doors"
    
my_new_car = Car(brand="Ferrari", model="LaFerrari", doors=2)
my_new_car.info()
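Overriding pairs naturally with polymorphism: code that only knows it holds "some kind of `Vehicle`" can call `info()`, and Python picks the version defined by each object's actual class. The overridden version is still reachable through the superclass itself. A short sketch reusing the classes above (the `fleet` list and example models are illustrative):

```python
class Vehicle:
    def __init__(self, brand, model):
        self.brand = brand
        self.model = model

    def info(self):
        return f"Vehicle: {self.brand} {self.model}"

class Car(Vehicle):
    def __init__(self, brand, model, doors):
        super().__init__(brand, model)
        self.doors = doors

    def info(self):
        return f"Car: {self.brand} {self.model} with {self.doors} doors"

# The loop calls info() without checking types; each object uses its own class's method.
fleet = [Vehicle("Fiat", "Panda"), Car("Ferrari", "LaFerrari", 2)]
for vehicle in fleet:
    print(vehicle.info())
# Vehicle: Fiat Panda
# Car: Ferrari LaFerrari with 2 doors

# Calling the method on the superclass explicitly bypasses the override.
print(Vehicle.info(fleet[1]))        # Vehicle: Ferrari LaFerrari
```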

---
## 4. Using the super() Function <a id="4"></a>

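Inside a subclass method, `super()` returns a proxy object that forwards method calls to the superclass (more precisely, to the next class in the method resolution order). This lets an overriding method extend the inherited behavior instead of replacing it entirely.
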
In this example, `Car.info()` calls `super().info()` to get the string built by `Vehicle.info()` and then appends the number of doors.

In [None]:
class Vehicle:
    def __init__(self, brand, model):
        self.brand = brand
        self.model = model
    
    def info(self):
        return f"Vehicle: {self.brand} {self.model}"

class Car(Vehicle):
    def __init__(self, brand, model, doors):
        super().__init__(brand, model)
        self.doors = doors
    
    def info(self):
        return super().info() + f" with {self.doors} doors"
    
my_new_car = Car(brand="Ferrari", model="LaFerrari", doors=2)
my_new_car.info()
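`super()` is not limited to a single level: each class can extend whatever the class above it returns, so the result is built up along the inheritance chain. A sketch with a hypothetical third level, `ElectricCar`, and a hypothetical `battery_kwh` attribute:

```python
class Vehicle:
    def __init__(self, brand, model):
        self.brand = brand
        self.model = model

    def info(self):
        return f"Vehicle: {self.brand} {self.model}"

class Car(Vehicle):
    def __init__(self, brand, model, doors):
        super().__init__(brand, model)
        self.doors = doors

    def info(self):
        return super().info() + f" with {self.doors} doors"

class ElectricCar(Car):
    """Hypothetical third level: extends Car.info(), which extends Vehicle.info()."""
    def __init__(self, brand, model, doors, battery_kwh):
        super().__init__(brand, model, doors)   # Car.__init__ in turn calls Vehicle.__init__
        self.battery_kwh = battery_kwh

    def info(self):
        return super().info() + f" and a {self.battery_kwh} kWh battery"

my_ev = ElectricCar("Tesla", "Model 3", 4, 75)
print(my_ev.info())
# Vehicle: Tesla Model 3 with 4 doors and a 75 kWh battery
```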

---
## 5. Step-by-Step Example <a id="5"></a>

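Let's put these ideas together in a class hierarchy for data analysis tasks: a `DataAnalysis` base class that wraps a pandas DataFrame, and an `AdvancedDataAnalysis` subclass that adds more statistics.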

### 5.1. DataAnalysis Class

In [None]:
import pandas as pd

class DataAnalysis:
    """
    Performs basic data analysis operations on a given dataset.

    Attributes:
    data (dict): A dictionary of data where keys are column names and values are lists of data.
    _df (DataFrame): A pandas DataFrame created from the data dictionary.
    """

    def __init__(self, data):
        """
        Initializes the DataAnalysis object with a dictionary of data.

        Parameters:
        data (dict): A dictionary of data where keys are column names and values are lists of data.
        """
        self.data = data
        self._df = pd.DataFrame(data)
    
    def summary(self):
        """
        Returns a summary of the data, including the number of rows and columns.

        Returns:
        str: A string summarizing the number of rows and columns in the dataset.
        """
        return f"Dataset with {len(self._df)} rows and {len(self._df.columns)} columns"
    
    def mean(self, column):
        """
        Returns the mean of a specified column.

        Parameters:
        column (str): The column for which to calculate the mean.

        Returns:
        float: The mean of the specified column.
        """
        return self._df[column].mean()
    
    def median(self, column):
        """
        Returns the median of a specified column.

        Parameters:
        column (str): The column for which to calculate the median.

        Returns:
        float: The median of the specified column.
        """
        return self._df[column].median()


### 5.2. AdvancedDataAnalysis Class

In [None]:
class AdvancedDataAnalysis(DataAnalysis):
    """
    Performs advanced data analysis operations on a given dataset.

    Inherits from the DataAnalysis class.

    Attributes:
    data (dict): A dictionary of data where keys are column names and values are lists of data.
    _df (DataFrame): A pandas DataFrame created from the data dictionary.
    """
    
    def __init__(self, data):
        """
        Initializes the AdvancedDataAnalysis object with a dictionary of data.

        Parameters:
        data (dict): A dictionary of data where keys are column names and values are lists of data.
        """
        super().__init__(data)
    
    def variance(self, column):
        """
        Returns the variance of a specified column.

        Parameters:
        column (str): The column for which to calculate the variance.

        Returns:
        float: The variance of the specified column.
        """
        return self._df[column].var()
    
    def standard_deviation(self, column):
        """
        Returns the standard deviation of a specified column.

        Parameters:
        column (str): The column for which to calculate the standard deviation.

        Returns:
        float: The standard deviation of the specified column.
        """
        return self._df[column].std()
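The `__init__()` in `AdvancedDataAnalysis` does nothing except forward to `super().__init__(data)`, so it could be omitted: when a subclass has no `__init__()`, Python finds the parent's. The sketch below checks this with hypothetical, stdlib-only stand-ins (`BasicStats` and `ExtendedStats`, built on the `statistics` module instead of pandas) so it runs without pandas installed:

```python
import statistics

class BasicStats:
    """Hypothetical stand-in for DataAnalysis, backed by plain lists instead of a DataFrame."""
    def __init__(self, data):
        self.data = data

    def mean(self, column):
        return statistics.mean(self.data[column])

class ExtendedStats(BasicStats):
    """No __init__ here: BasicStats.__init__ is inherited and runs automatically."""
    def variance(self, column):
        # Sample variance (n - 1 denominator), the same convention as pandas' Series.var() default
        return statistics.variance(self.data[column])

data = {"age": [25, 30, 35, 40, 45]}
stats = ExtendedStats(data)
print(stats.mean("age"))                      # arithmetic mean of the column
print(stats.variance("age"))                  # 62.5
print(hasattr(BasicStats(data), "variance"))  # False: the parent does not gain subclass methods
```

Keeping an explicit `__init__()` that only calls `super()` is harmless and can serve as documentation; it becomes necessary once the subclass needs extra setup of its own.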

### 5.3. Testing

In [None]:
# Usage
data = {
    'age': [25, 30, 35, 40, 45],
    'salary': [50000, 60000, 70000, 80000, 90000]
}
analysis = DataAnalysis(data)
advanced_analysis = AdvancedDataAnalysis(data)

print(advanced_analysis.summary())  # Inherited method
print(advanced_analysis.mean('age'))  # Inherited method
print(advanced_analysis.median('salary'))  # Inherited method
print(advanced_analysis.variance('age'))  # New method
print(advanced_analysis.standard_deviation('salary'))  # New method

---
## 6. Exercise: Extending a Data Analysis Class <a id="6"></a>

In this exercise, you will write a `StatisticalAnalysis` class that extends `DataAnalysis` with additional statistical methods. It should provide the following methods:

- `mode(self, column)`: Returns the mode of a specified column.
- `range(self, column)`: Returns the range (difference between max and min) of a specified column.
- `interquartile_range(self, column)`: Returns the interquartile range (IQR) of a specified column.
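As a worked reference for these definitions, here is the arithmetic for the data used in the test cell at the end of this section. It assumes pandas' default linear interpolation in `Series.quantile()`: for sorted values $x_{(0)} \le \dots \le x_{(n-1)}$, the $p$-quantile sits at position $h = (n-1)p$ and is interpolated between its two neighbors.

```latex
\begin{aligned}
\text{salary range} &= \max - \min = 90000 - 50000 = 40000 \\[4pt]
\text{age (sorted)} &= [25, 30, 30, 35, 35, 40, 45], \quad n = 7 \\
Q_1:\; h &= 6 \times 0.25 = 1.5 \;\Rightarrow\; Q_1 = x_{(1)} + 0.5\,(x_{(2)} - x_{(1)}) = 30 + 0.5\,(30 - 30) = 30 \\
Q_3:\; h &= 6 \times 0.75 = 4.5 \;\Rightarrow\; Q_3 = x_{(4)} + 0.5\,(x_{(5)} - x_{(4)}) = 35 + 0.5\,(40 - 35) = 37.5 \\
\mathrm{IQR} &= Q_3 - Q_1 = 37.5 - 30 = 7.5
\end{aligned}
```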

### 6.1. DataAnalysis Class Definition

In [None]:
import pandas as pd

class DataAnalysis:
    """
    Performs basic data analysis operations on a given dataset.

    Attributes:
    data (dict): A dictionary of data where keys are column names and values are lists of data.
    _df (DataFrame): A pandas DataFrame created from the data dictionary.
    """

    def __init__(self, data):
        """
        Initializes the DataAnalysis object with a dictionary of data.

        Parameters:
        data (dict): A dictionary of data where keys are column names and values are lists of data.
        """
        self.data = data
        self._df = pd.DataFrame(data)
    
    def summary(self):
        """
        Returns a summary of the data, including the number of rows and columns.

        Returns:
        str: A string summarizing the number of rows and columns in the dataset.
        """
        return f"Dataset with {len(self._df)} rows and {len(self._df.columns)} columns"
    
    def mean(self, column):
        """
        Returns the mean of a specified column.

        Parameters:
        column (str): The column for which to calculate the mean.

        Returns:
        float: The mean of the specified column.
        """
        return self._df[column].mean()
    
    def median(self, column):
        """
        Returns the median of a specified column.

        Parameters:
        column (str): The column for which to calculate the median.

        Returns:
        float: The median of the specified column.
        """
        return self._df[column].median()

### 6.2. StatisticalAnalysis Class Definition

In [None]:
# Extend the DataAnalysis class with the StatisticalAnalysis class


><details>
><summary>Do you need some help?</summary>
>
>Tips:
>
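>- Follow the pattern used for `AdvancedDataAnalysis`: subclass `DataAnalysis`, call `super().__init__(data)` if you define `__init__`, and work with the inherited `_df` attribute.
>- Test each method to ensure it behaves as expected.
>- Use pandas' built-in `Series` methods (for example `mode()`, `max()`, `min()` and `quantile()`) to keep each method short.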
>
> Here is a working solution:
>
>
>```python
>class StatisticalAnalysis(DataAnalysis):
>    """
>    Performs advanced statistical analysis operations on a given dataset.
>
>    Inherits from the DataAnalysis class.
>
>    Attributes:
>    data (dict): A dictionary of data where keys are column names and values are lists of data.
>    _df (DataFrame): A pandas DataFrame created from the data dictionary.
>    """
>    
>    def __init__(self, data):
>        """
>        Initializes the StatisticalAnalysis object with a dictionary of data.
>
>        Parameters:
>        data (dict): A dictionary of data where keys are column names and values are lists of data.
>        """
>        super().__init__(data)
>    
>    def mode(self, column):
>        """
>        Returns the mode of a specified column.
>
>        Parameters:
>        column (str): The column for which to calculate the mode.
>
>        Returns:
>        list: The mode(s) of the specified column.
>        """
>        return self._df[column].mode().tolist()
>    
>    def range(self, column):
>        """
>        Returns the range (difference between max and min) of a specified column.
>
>        Parameters:
>        column (str): The column for which to calculate the range.
>
>        Returns:
>        float: The range of the specified column.
>        """
>        return self._df[column].max() - self._df[column].min()
>    
>    def interquartile_range(self, column):
>        """
>        Returns the interquartile range (IQR) of a specified column.
>
>        Parameters:
>        column (str): The column for which to calculate the IQR.
>
>        Returns:
>        float: The interquartile range of the specified column.
>        """
>        q1 = self._df[column].quantile(0.25)
>        q3 = self._df[column].quantile(0.75)
>        return q3 - q1
>```

Now check whether your code works as expected by running the following cell:

In [None]:
# Test your implementation with the example usage provided
data = {
    'age': [25, 30, 35, 40, 45, 30, 35],
    'salary': [50000, 60000, 70000, 80000, 90000, 60000, 70000]
}
stat_analysis = StatisticalAnalysis(data)

print(stat_analysis.mode('age'))                 # reference solution: [30, 35]
print(stat_analysis.range('salary'))             # reference solution: 40000
print(stat_analysis.interquartile_range('age'))  # reference solution: 7.5
print(stat_analysis.summary())                   # Dataset with 7 rows and 2 columns
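If you want to check these numbers independently of pandas, the standard-library `statistics` module (Python 3.8+) can compute the same quantities. This is a cross-check sketch, not part of the exercise; `statistics.quantiles(..., method="inclusive")` places quartiles at position $(n-1)p$ like pandas' default linear interpolation, which is why the IQR agrees here:

```python
import statistics

data = {
    'age': [25, 30, 35, 40, 45, 30, 35],
    'salary': [50000, 60000, 70000, 80000, 90000, 60000, 70000],
}

age_modes = statistics.multimode(data['age'])             # every most-common value
salary_range = max(data['salary']) - min(data['salary'])  # max minus min
# With n=4, quantiles() returns the three quartile cut points Q1, Q2, Q3.
q1, _, q3 = statistics.quantiles(data['age'], n=4, method='inclusive')
age_iqr = q3 - q1

print(age_modes)      # [30, 35]
print(salary_range)   # 40000
print(age_iqr)        # 7.5
```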
