# Data drift notebook

In [32]:
import copy
import os

import torch
import torchdrift
from omegaconf import OmegaConf

from FishEye.data.data_module import FishDataModule
from FishEye.models.model import FishNN

if "notebooks" in os.getcwd():
    os.chdir("..")

In [33]:
model = FishNN.load_from_checkpoint("models/epoch=379-step=4180.ckpt", cfg=OmegaConf.load("config/config.yaml"))
model.to("cuda")

FishNN(
  (accuracy): MulticlassAccuracy()
  (feature_extractor): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU()
    (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2))
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (classifier): Sequential(
    (0): Linear(in_features=124416, out_features=256, bias=True)
    (1): Dropout(p=0.2, inplace=False)
    (2): Linear(in_features=256, out_features=9, bias=True)
    (3): Dropout(p=0.2, inplace=False)
  )
  (criterion): CrossEntropyLoss()
)

In [34]:
feature_extractor = copy.deepcopy(model)
feature_extractor.classifier = torch.nn.Identity()
drift_detector = torchdrift.detectors.KernelMMDDriftDetector()

data_module_aug = FishDataModule(augment=True)
data_module = FishDataModule(augment=False)

data_module_aug.setup(stage="fit")
data_module.setup(stage="fit")

In [35]:
torchdrift.utils.fit(data_module.train_dataloader(), feature_extractor, drift_detector)

100%|██████████| 11/11 [00:02<00:00,  4.98it/s]


In [36]:
drift_detection_model = torch.nn.Sequential(
    feature_extractor,
    drift_detector
)

In [43]:
# get batch from augmented data
batch = next(iter(data_module.train_dataloader()))
data, target = batch

features = feature_extractor(data.to("cuda")[0].unsqueeze(0))
score = drift_detector(features)
p_val = drift_detector.compute_p_value(features)
score, p_val

(tensor(inf, device='cuda:0'), tensor(0., device='cuda:0'))