# Splitting dataframes based on column values

- toc: true 
- badges: true
- comments: true
- categories: [python, pandas]

In [1]:
import pandas as pd
import numpy as np

## The problem

In [2]:
df = pd.DataFrame(data={'case': ['A', 'A', 'A', 'B', 'A', 'A', 'B', 'A','A'],
                        'data': np.random.randn(9)})
df

  case      data
0    A  0.684978
1    A  0.000269
2    A -1.040497
3    B  0.451358
4    A  0.448596
5    A  0.222168
6    B  1.031011
7    A -2.208787
8    A -0.440758


You want to split the dataframe after every row where `case` equals `B`, so that each `B` row ends a piece, and store the resulting dataframes in a list.

## Understanding the cookbook solution

From the [cookbook](https://pandas.pydata.org/pandas-docs/stable/user_guide/cookbook.html#id2):

In [3]:
dfs = list(zip(*df.groupby((1 * (df['case'] == 'B')).cumsum().rolling(window=3, min_periods=1).median())))[-1]
dfs

(  case      data
 0    A  0.684978
 1    A  0.000269
 2    A -1.040497
 3    B  0.451358,
   case      data
 4    A  0.448596
 5    A  0.222168
 6    B  1.031011,
   case      data
 7    A -2.208787
 8    A -0.440758)

This works. But because it is so heavily nested and uses `rolling()` and `median()` for something they were not really designed for, the code is hard to interpret at a glance.

Let's break this down into separate pieces.

First, the code creates a grouping variable that increases by one on the row after each row where *case* equals *B*. The code below builds it step by step.

In [4]:
# Creating grouping variable

a = (df.case == 'B')
b = 1 * (df.case == 'B')
c = 1 * (df.case == 'B').cumsum()
d = 1 * (df.case == 'B').cumsum().rolling(window=3, min_periods=1).median()

a, b, c, d

(0    False
 1    False
 2    False
 3     True
 4    False
 5    False
 6     True
 7    False
 8    False
 Name: case, dtype: bool,
 0    0
 1    0
 2    0
 3    1
 4    0
 5    0
 6    1
 7    0
 8    0
 Name: case, dtype: int64,
 0    0
 1    0
 2    0
 3    1
 4    1
 5    1
 6    2
 7    2
 8    2
 Name: case, dtype: int64,
 0    0.0
 1    0.0
 2    0.0
 3    0.0
 4    1.0
 5    1.0
 6    1.0
 7    2.0
 8    2.0
 Name: case, dtype: float64)

Series *d* above is the argument passed to `groupby()` in the solution. This works, but is a very roundabout way to create such a series. I'll use a different approach below.
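For intuition about why the rolling median produces this series: a cumulative sum of 0/1 flags never decreases, so the median of three consecutive values is the middle one, which is the value from the previous row. The sketch below checks this on a hard-coded copy of the `case` column (the names `cases`, `counter`, and `early` are made up for illustration) by comparing the median against the counter moved down one row. It also shows an input where the trick seems to break: when the first windows hold only two values, the median is their average.

```python
import pandas as pd

# Same A/B pattern as the example dataframe, written out by hand.
cases = pd.Series(['A', 'A', 'A', 'B', 'A', 'A', 'B', 'A', 'A'])
counter = cases.eq('B').cumsum()

# Cookbook-style grouping variable versus the counter moved down one row.
via_median = counter.rolling(window=3, min_periods=1).median()
via_shift = counter.shift(fill_value=0)
same = bool(via_median.eq(via_shift).all())
print(same)

# Hypothetical input with B in the second row: the window for that row
# holds only two values (0 and 1), so its median is their average.
early = pd.Series(['A', 'B', 'A', 'A'])
early_median = early.eq('B').cumsum().rolling(window=3, min_periods=1).median()
print(early_median.tolist())
```

On the second input, the first two rows get different group labels, so the leading `A` would be separated from the `B` that follows it.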

Next, the code uses `list()`, `zip()`, and argument expansion to pack the data for each group into a single list of dataframes. Let's look at these one by one.

First, a quick review of how argument expansion works:

In [5]:
def printer(*args, **kwargs):    
    print('Printing args:')
    for arg in args:
        print(arg)
    print('Printing kwargs:')
    for kwarg in kwargs.items():
        print(kwarg)
    
mylist = ['a', 2, 'k', 3]
mydict = {'first': 1, 'second': 2}

printer(*mylist, **mydict)

Printing args:
a
2
k
3
Printing kwargs:
('first', 1)
('second', 2)


Now, iterating over the result of `groupby()` yields the grouped data as (label, dataframe) tuples, like so:

In [6]:
groups = df.groupby('case')
for g in groups:
    print(g)
    print(type(g))

('A',   case      data
0    A  0.684978
1    A  0.000269
2    A -1.040497
4    A  0.448596
5    A  0.222168
7    A -2.208787
8    A -0.440758)
<class 'tuple'>
('B',   case      data
3    B  0.451358
6    B  1.031011)
<class 'tuple'>
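Because each item is a two-element tuple, it can also be unpacked directly in the loop header. A minimal sketch on a small made-up dataframe (`frame` and `sizes` are illustrative names, not part of the example above):

```python
import pandas as pd

# Small illustrative dataframe, not the one used in the rest of this post.
frame = pd.DataFrame({'case': ['A', 'B', 'A'], 'data': [1, 2, 3]})

# Unpack each (label, dataframe) tuple while iterating over the groups.
sizes = {}
for label, group in frame.groupby('case'):
    sizes[label] = len(group)

print(sizes)  # {'A': 2, 'B': 1}
```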


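So `zip()` is used to separate the group labels from the data: with the groups unpacked as arguments, it returns one tuple containing all labels and one containing all dataframes. `list()` consumes the iterator created by `zip()` so that its content can be inspected and indexed.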

In [7]:
list(zip(*groups))

[('A', 'B'),
 (  case      data
  0    A  0.684978
  1    A  0.000269
  2    A -1.040497
  4    A  0.448596
  5    A  0.222168
  7    A -2.208787
  8    A -0.440758,
    case      data
  3    B  0.451358
  6    B  1.031011)]
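The same transposition is easier to see with plain tuples. In this sketch, the strings `'frame_a'` and `'frame_b'` are placeholders standing in for the dataframes:

```python
# (label, data) pairs, shaped like the tuples that groupby yields.
pairs = [('A', 'frame_a'), ('B', 'frame_b')]

# zip(*pairs) is zip(('A', 'frame_a'), ('B', 'frame_b')): it collects all
# first elements into one tuple and all second elements into another.
labels, frames = zip(*pairs)

print(labels)  # ('A', 'B')
print(frames)  # ('frame_a', 'frame_b')
```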

Because we only want the data, we select the last element from the list: 

In [8]:
list(zip(*groups))[-1]

(  case      data
 0    A  0.684978
 1    A  0.000269
 2    A -1.040497
 4    A  0.448596
 5    A  0.222168
 7    A -2.208787
 8    A -0.440758,
   case      data
 3    B  0.451358
 6    B  1.031011)

Now we're basically done. What remains is to apply the `list(zip(*groups))[-1]` procedure to the more complicated grouping variable to reproduce the original result.

In [9]:
d = 1 * (df.case == 'B').cumsum().rolling(window=3, min_periods=1).median()
groups = df.groupby(d)
list(zip(*groups))[-1]

(  case      data
 0    A  0.684978
 1    A  0.000269
 2    A -1.040497
 3    B  0.451358,
   case      data
 4    A  0.448596
 5    A  0.222168
 6    B  1.031011,
   case      data
 7    A -2.208787
 8    A -0.440758)

## Simplifying the code

I think this can be made much more readable like so:

In [10]:
df

  case      data
0    A  0.684978
1    A  0.000269
2    A -1.040497
3    B  0.451358
4    A  0.448596
5    A  0.222168
6    B  1.031011
7    A -2.208787
8    A -0.440758


In [11]:
grouper = df.case.eq('B').cumsum().shift().fillna(0)
grouper

0    0.0
1    0.0
2    0.0
3    0.0
4    1.0
5    1.0
6    1.0
7    2.0
8    2.0
Name: case, dtype: float64
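The grouper comes out as floats because `shift()` introduces a `NaN` in the first row before `fillna(0)` replaces it. As an optional variation (not what the cell above uses), `Series.shift` accepts a `fill_value` argument that fills the gap directly, which should keep the integer dtype. A minimal sketch on a hard-coded copy of the `case` column:

```python
import pandas as pd

# Hard-coded copy of the case column from the example dataframe.
case = pd.Series(['A', 'A', 'A', 'B', 'A', 'A', 'B', 'A', 'A'], name='case')

# Fill the slot vacated by the shift with 0 instead of NaN.
grouper_int = case.eq('B').cumsum().shift(fill_value=0)

print(grouper_int.tolist())  # [0, 0, 0, 0, 1, 1, 1, 2, 2]
print(pd.api.types.is_integer_dtype(grouper_int))  # True
```

Either version yields the same groups; the integer labels are just a little tidier to read.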

In [12]:
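# Keep only the dataframe from each (label, dataframe) tuple;
# `group` avoids shadowing the outer `df` inside the comprehension.
dfs = [group for _, group in df.groupby(grouper)]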
dfs

[  case      data
 0    A  0.684978
 1    A  0.000269
 2    A -1.040497
 3    B  0.451358,
   case      data
 4    A  0.448596
 5    A  0.222168
 6    B  1.031011,
   case      data
 7    A -2.208787
 8    A -0.440758]

In case the logic of this isn't immediately obvious, the step-by-step version below shows what's going on.

In [13]:
dd = df.set_index('case', drop=False)   # Use case as index for clarity
a = dd.case.eq('B')                     # Boolean logic
b = a.cumsum()                          # Create groups
c = b.shift()                           # Shift so B included in previous group
d = c.fillna(0)                         # Replace 0th element emptied by shift
a, b, c, d

(case
 A    False
 A    False
 A    False
 B     True
 A    False
 A    False
 B     True
 A    False
 A    False
 Name: case, dtype: bool,
 case
 A    0
 A    0
 A    0
 B    1
 A    1
 A    1
 B    2
 A    2
 A    2
 Name: case, dtype: int64,
 case
 A    NaN
 A    0.0
 A    0.0
 B    0.0
 A    1.0
 A    1.0
 B    1.0
 A    2.0
 A    2.0
 Name: case, dtype: float64,
 case
 A    0.0
 A    0.0
 A    0.0
 B    0.0
 A    1.0
 A    1.0
 B    1.0
 A    2.0
 A    2.0
 Name: case, dtype: float64)
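The steps above can be wrapped in a small reusable function. This is only a sketch: `split_after` is a hypothetical helper name, it assumes the boolean mask shares the dataframe's index, and it uses `shift(fill_value=0)` in place of `shift().fillna(0)`.

```python
import pandas as pd


def split_after(frame, mask):
    """Split `frame` after every row where the boolean Series `mask` is True."""
    # Count the True rows seen so far, then move the counter down one row
    # so that each True row stays with the rows that precede it.
    grouper = mask.cumsum().shift(fill_value=0)
    return [group for _, group in frame.groupby(grouper)]


# Illustrative data with the same A/B pattern as the example dataframe.
example = pd.DataFrame({'case': list('AAABAABAA'), 'data': range(9)})
parts = split_after(example, example['case'].eq('B'))

print([part.index.tolist() for part in parts])
# [[0, 1, 2, 3], [4, 5, 6], [7, 8]]
```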