# Train Test Label Drift

In [1]:
import numpy as np
import pandas as pd

from deepchecks.tabular import Dataset
from deepchecks.tabular.checks import TrainTestLabelDrift
import pprint

## Generate data - Classification label

In [2]:
# Seed NumPy's global RNG so the generated data is the same on every fresh run.
np.random.seed(42)

# Train set: two standard-normal feature columns plus a balanced 0/1 label column.
train_features = np.random.randn(1000, 2)
train_labels = np.random.choice(a=[1, 0], p=[0.5, 0.5], size=(1000, 1))
train_data = np.concatenate([train_features, train_labels], axis=1)

# Test set: same feature distribution, but label 1 is drawn with probability 0.35
# instead of 0.5 -- this is the label drift the check is expected to pick up.
test_features = np.random.randn(1000, 2)
test_labels = np.random.choice(a=[1, 0], p=[0.35, 0.65], size=(1000, 1))
test_data = np.concatenate([test_features, test_labels], axis=1)

column_names = ['col1', 'col2', 'target']
df_train = pd.DataFrame(train_data, columns=column_names)
df_test = pd.DataFrame(test_data, columns=column_names)

# Wrap both frames as deepchecks Datasets, marking 'target' as the label column.
train_dataset = Dataset(df_train, label='target')
test_dataset = Dataset(df_test, label='target')


In [3]:
# Preview the first five rows of the training frame.
df_train.head(n=5)

Unnamed: 0,col1,col2,target
0,0.496714,-0.138264,1.0
1,0.647689,1.52303,1.0
2,-0.234153,-0.234137,1.0
3,1.579213,0.767435,1.0
4,-0.469474,0.54256,0.0


## Run check

In [4]:
# Run the label-drift check on the train and test datasets built above.
check = TrainTestLabelDrift()
result = check.run(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
)
# Leave the result as the cell's last expression so the notebook displays it.
result

## Generate data - Regression label

In [5]:
# Train and test sets: two feature columns and a continuous label, all standard normal.
# No re-seeding here: the draws continue from the RNG state left by the earlier cells.
train_data = np.concatenate(
    [np.random.randn(1000, 2), np.random.randn(1000, 1)],
    axis=1,
)
test_data = np.concatenate(
    [np.random.randn(1000, 2), np.random.randn(1000, 1)],
    axis=1,
)

column_names = ['col1', 'col2', 'target']
df_train = pd.DataFrame(train_data, columns=column_names)
df_test = pd.DataFrame(test_data, columns=column_names)

# Create drift in test: shift the label by non-negative noise plus a ramp that
# rises from 0 to just under 4 across the rows.
positive_noise = abs(np.random.randn(1000))
ramp = np.arange(0, 1, 0.001) * 4
df_test['target'] = df_test['target'].astype('float') + positive_noise + ramp

# Wrap both frames as deepchecks Datasets, marking 'target' as the label column.
train_dataset = Dataset(df_train, label='target')
test_dataset = Dataset(df_test, label='target')

## Run check

In [6]:
# Run the label-drift check again, now on the regression-label datasets.
check = TrainTestLabelDrift()
result = check.run(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
)
# Leave the result as the cell's last expression so the notebook displays it.
result

### Add condition

In [7]:
# Build the check, then attach the drift-score condition. No thresholds are passed,
# so the condition uses the library's defaults.
check_cond = TrainTestLabelDrift()
check_cond = check_cond.add_condition_drift_score_not_greater_than()

# Running the conditioned check reports a pass/fail status alongside the drift result.
check_cond.run(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
)

Status,Condition,More Info
✖,PSI <= 0.2 and Earth Mover's Distance <= 0.1 for label drift,Label's Earth Mover's Distance above threshold: 0.25
