In [2]:
import seaborn as sns
import pandas as pd
import missingno
import numpy as np
import matplotlib.pyplot as plt
# Load the dataset named 'titanic' through seaborn; every later cell reads from `df`.
df = sns.load_dataset('titanic')

In [3]:
# Print per-column dtypes and non-null counts to spot which columns contain missing values.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 15 columns):
 #   Column       Non-Null Count  Dtype   
---  ------       --------------  -----   
 0   survived     891 non-null    int64   
 1   pclass       891 non-null    int64   
 2   sex          891 non-null    object  
 3   age          714 non-null    float64 
 4   sibsp        891 non-null    int64   
 5   parch        891 non-null    int64   
 6   fare         891 non-null    float64 
 7   embarked     889 non-null    object  
 8   class        891 non-null    category
 9   who          891 non-null    object  
 10  adult_male   891 non-null    bool    
 11  deck         203 non-null    category
 12  embark_town  889 non-null    object  
 13  alive        891 non-null    object  
 14  alone        891 non-null    bool    
dtypes: bool(2), category(2), float64(2), int64(4), object(5)
memory usage: 80.7+ KB


## 1. Identify numeric columns and fill NaN with an aggregate statistic (mean by default, optionally computed per group)

In [4]:
# Count how many `age` entries are missing (True) versus present (False).
df['age'].isna().value_counts()

Unnamed: 0_level_0,count
age,Unnamed: 1_level_1
False,714
True,177


In [51]:
def numeric_fillna(df,groupcols=None,stat_func='mean'):
  """
  Return a copy of `df` in which missing values of numeric columns are replaced by an aggregate statistic.

  Only cells that are NaN in the input are changed; every other value is left untouched, and `df` itself is never modified.

  df: DataFrame containing all columns. Columns of dtype float64/int64 are candidates for filling.
  groupcols: column label, list of column labels, or None. When given, the statistic is computed
    separately within each group. Columns used as group keys are not filled, and rows whose group
    key is missing keep their original values because no group statistic exists for them.
  stat_func: 'mean', 'median', 'min', 'max', 'std', 'var', 'sum' or 'mode'. For 'mode' the first
    (smallest) most frequent non-missing value is used.

  Returns the filled copy. If the input is invalid or the computation fails, a message is printed
  and None is returned.
  """
  valid_stats = ('mean', 'median', 'min', 'max', 'std', 'var', 'sum', 'mode')
  if not isinstance(df, pd.DataFrame) or stat_func not in valid_stats:
    print("Please provide a valid input")
    return None

  def _first_mode(values):
    # Series.mode() returns a Series (possibly empty or with ties); reduce it to one scalar
    # so that fillna does not try to align it by index label.
    modes = values.mode()
    return modes.iloc[0] if len(modes) else float('nan')

  # Built-in statistics are passed by name; 'mode' needs the scalar-reducing helper.
  reducer = _first_mode if stat_func == 'mode' else stat_func

  df_copy = df.copy()
  cols = df.select_dtypes(include=['float64', 'int64']).columns
  try:
    if groupcols:
      group_keys = [groupcols] if isinstance(groupcols, str) else list(groupcols)
      # Group the untouched input so that writes to df_copy cannot affect the grouping.
      grouped = df.groupby(group_keys)
    else:
      group_keys = []
      grouped = None

    for col in cols:
      column = df[col]
      # Nothing to do for complete columns; group keys cannot be filled from their own groups.
      if col in group_keys or not column.isna().any():
        continue
      if grouped is not None:
        # One statistic per row, aligned on the original index.
        fill = grouped[col].transform(reducer)
      elif stat_func == 'mode':
        fill = _first_mode(column)
      else:
        fill = getattr(column, stat_func)()
      # fillna only touches missing cells, so existing values are never overwritten.
      df_copy[col] = column.fillna(fill)
  except Exception as e:
    # Keep the original best-effort contract: report the problem and return None.
    print(e)
    return None
  return df_copy

In [47]:
# Reference values for checking the imputation below: overall vs. male-only mean and median age.
male_age = df.loc[df['sex'] == 'male', 'age']
df['age'].mean(), male_age.mean(), df['age'].median(), male_age.median()

(np.float64(29.69911764705882), np.float64(30.72664459161148), 28.0, 29.0)

In [48]:
# Fill missing numeric values with the overall column median, then compare the
# first ten ages before and after the fill.
df_1 = numeric_fillna(df, stat_func='median')
display(df['age'].head(10), df_1['age'].head(10))

Unnamed: 0,age
0,22.0
1,38.0
2,26.0
3,35.0
4,35.0
5,
6,54.0
7,2.0
8,27.0
9,14.0


Unnamed: 0,age
0,22.0
1,38.0
2,26.0
3,35.0
4,35.0
5,28.0
6,54.0
7,2.0
8,27.0
9,14.0


In [49]:
# Same fill, but the median is computed separately for each value of `sex`.
df_2 = numeric_fillna(df, groupcols=['sex'], stat_func='median')
display(df_2['age'].head(10))

Unnamed: 0,age
0,22.0
1,38.0
2,26.0
3,35.0
4,35.0
5,29.0
6,54.0
7,2.0
8,27.0
9,14.0


In [15]:
# Missing/present counts for `age`: the original frame first, then the two filled copies.
for frame in (df, df_1, df_2):
  display(frame['age'].isnull().value_counts())

Unnamed: 0_level_0,count
age,Unnamed: 1_level_1
False,714
True,177


Unnamed: 0_level_0,count
age,Unnamed: 1_level_1
False,891


Unnamed: 0_level_0,count
age,Unnamed: 1_level_1
False,891
