# core

> Robust lagging (and leading) of time-series and panel data indexed by period dates

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#|exports
from __future__ import annotations
import pandas as pd
import numpy as np

First, we set up an example dataset to showcase the functions in this module.

In [None]:
# Example panel: two firms (`permno`) x four monthly observations each.
# The dates deliberately contain a duplicate (2010-02) and a gap (2010-03 is absent).
# A seeded generator keeps the values identical across kernel restarts.
df = pd.DataFrame(np.random.default_rng(0).random((8, 2)),
                  columns=list('AB'),
                  index=pd.MultiIndex.from_product(
                      [[1,2],
                       pd.to_datetime(['2010-01','2010-02','2010-02','2010-04']
                                      ).to_period('M')],
                      names=['permno','Mdate']))
df

Unnamed: 0_level_0,Unnamed: 1_level_0,A,B
permno,Mdate,Unnamed: 2_level_1,Unnamed: 3_level_1
1,2010-01,0.555742,0.941069
1,2010-02,0.623244,0.047846
1,2010-02,0.879142,0.531396
1,2010-04,0.160305,0.790609
2,2010-01,0.091188,0.128479
2,2010-02,0.588596,0.320304
2,2010-02,0.389166,0.20489
2,2010-04,0.239609,0.105904


### Robust lagging

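Note how a plain `shift` gives wrong lags when we have (1) panel data, (2) duplicate dates, or (3) gaps in the time-series: it moves values by row position, so the first row of `permno` 2 receives the last value of `permno` 1, and each `2010-04` row receives a `2010-02` value even though `2010-03` is absent.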

In [None]:
# Naive positional shift, shown for contrast: it ignores panel boundaries, duplicate dates, and gaps.
df.shift()

Unnamed: 0_level_0,Unnamed: 1_level_0,A,B
permno,Mdate,Unnamed: 2_level_1,Unnamed: 3_level_1
1,2010-01,,
1,2010-02,0.555742,0.941069
1,2010-02,0.623244,0.047846
1,2010-04,0.879142,0.531396
2,2010-01,0.160305,0.790609
2,2010-02,0.091188,0.128479
2,2010-02,0.588596,0.320304
2,2010-04,0.389166,0.20489


In [None]:
#|export
def fast_lag(df: pd.Series|pd.DataFrame, # Index (or level 1 of MultiIndex) must be period date
        n: int=1, # Number of periods to lag based on frequency of df.index; Negative values means lead.
        ) -> pd.Series: # Series with lagged values; Name is taken from `df`, with _lag{n} or _lead{n} added
    """Lag data in 'df' by 'n' periods.
    ASSUMES DATA SORTED BY DATES AND NO DUPLICATE OR MISSING DATES.

    The data is shifted by `n` row positions and a shifted value is kept only
    when the row it came from lies exactly `n` periods earlier (and, for a
    two-level MultiIndex, belongs to the same level-0 panel). Every other row
    gets NaN, so rows that violate the assumptions above yield NaN rather than
    a value from the wrong period or panel.

    Raises `ValueError` if `df` does not have exactly one column, if a
    MultiIndex does not have exactly two levels, or if the date index/level is
    not of period dtype.
    """

    if isinstance(df, pd.Series): df = df.to_frame()
    if len(df.columns) != 1: raise ValueError("<df> must have a single column")

    # Select positionally so unnamed Series and non-string column labels work;
    # the label is only stringified when building the output name.
    values = df.iloc[:, 0]
    new_varname = f'{values.name}_lag{n}' if n >= 0 else f'{values.name}_lead{-n}'

    index = df.index
    if isinstance(index, pd.MultiIndex):
        if index.nlevels != 2:
            raise ValueError('MultiIndex must have exactly two levels: panel id and period date')
        if not f'{index.levels[1].dtype}'.startswith('period'):
            raise ValueError('Dimension 1 of multiindex must be period date')
        # Work on the level values directly so unnamed index levels are fine.
        panel = pd.Series(index.get_level_values(0))
        time = pd.Series(index.get_level_values(1))
        valid = (panel == panel.shift(n)) & (time == time.shift(n) + n)
    else:
        if not f'{index.dtype}'.startswith('period'):
            raise ValueError('Index must be period date')
        time = pd.Series(index)
        valid = time == time.shift(n) + n

    # Rows shifted in from outside the data carry NaT/NaN, which never compare
    # equal, so they are masked out together with cross-panel and gap rows.
    lagged = np.where(valid.to_numpy(), values.shift(n).to_numpy(), np.nan)

    # Build the Series explicitly (no `.squeeze()`), so a one-row input still
    # returns a Series instead of collapsing to a scalar.
    return pd.Series(lagged, index=index, name=new_varname)

In [None]:
#|export
def lag(df: pd.Series|pd.DataFrame, # Index (or level 1 of MultiIndex) must be period date with no missing values.
        n: int=1, # Number of periods to lag based on frequency of df.index; Negative values means lead.
        fast: bool=True, # Assumes data is sorted by date and no duplicate or missing dates
        ) -> pd.Series: # Series with lagged values; Name is taken from `df`, with _lag{n} or _lead{n} added
    """Lag data in 'df' by 'n' periods.

    With `fast=True` the work is delegated to `fast_lag`. Otherwise a copy of
    the data has its period index moved forward by `n` and is joined back onto
    the original index, so values are matched by date (and panel) rather than
    by row position. Because this is an index join, a row whose lagged date
    occurs several times in `df` appears once per match in the result.

    Raises `ValueError` if `df` does not have exactly one column or if the
    date index/level is not of period dtype.
    """

    if fast: return fast_lag(df,n)

    if isinstance(df, pd.Series): df = df.to_frame()
    if len(df.columns) != 1: raise ValueError("'df' parameter must have a single column")

    # Build the name the same way for lags and leads; stringifying the label
    # keeps unnamed Series and non-string column labels working.
    suffix = f'_lag{n}' if n >= 0 else f'_lead{-n}'
    new_varname = f'{df.columns[0]}{suffix}'
    dfl = df.copy()
    dfl.columns = [new_varname]

    if isinstance(df.index, pd.MultiIndex):
        if not f'{df.index.levels[1].dtype}'.startswith('period'):
            raise ValueError('Dimension 1 of multiindex must be period date')
        # Move every date n periods forward so it lines up with the row it is the lag of.
        dfl.index = dfl.index.set_levels(df.index.levels[1] + n, level=1)
    else:
        if not f'{df.index.dtype}'.startswith('period'):
            raise ValueError('Index must be period date')
        dfl.index = dfl.index + n

    # Select the column by name (no `.squeeze()`), so a one-row input still
    # returns a Series instead of collapsing to a scalar.
    return df.join(dfl)[new_varname]

The index of the `df` parameter cannot contain missing values.

In [None]:
# Default fast=True path: a row gets a value only when the previous row is the same permno exactly one month earlier.
lag(df['A'])

permno  Mdate  
1       2010-01         NaN
        2010-02    0.555742
        2010-02         NaN
        2010-04         NaN
2       2010-01         NaN
        2010-02    0.091188
        2010-02         NaN
        2010-04         NaN
Name: A_lag1, dtype: float64

In [None]:
# fast=False matches on the shifted (permno, Mdate) index instead of row position, so both duplicate 2010-02 rows receive the 2010-01 value.
lag(df['A'],fast=False)

permno  Mdate  
1       2010-01         NaN
        2010-02    0.555742
        2010-02    0.555742
        2010-04         NaN
2       2010-01         NaN
        2010-02    0.091188
        2010-02    0.091188
        2010-04         NaN
Name: A_lag1, dtype: float64

In [None]:
#| hide
# Run nbdev's export step, which writes the exported cells out as Python modules.
import nbdev; nbdev.nbdev_export()