In [None]:
import configs.pyroot_config as pyroot_config
import os
from torch.utils.data import default_collate
import torch
from src.data.mice_dataset import MouseSniffingVideoDatasetMultipleFramesLabeled
from src.data.mice_dataset_factory import create_dataset, trails_split
import torchvision.transforms.v2 as tt
from models.lightning_module import DeepSniff
from models.models import MobileNetV3
from pytorch_lightning.callbacks import *
import random

# Project path configuration; later cells read `data_processed` (training data root) and
# `project_root` (checkpoint root) from it.
config_paths = pyroot_config.ConfigPaths()

In [None]:
### TRANSFORMS
# Passed to the datasets as `loading_transforms`: PIL image -> tensor, resized to 112x112.
loading_transforms = tt.Compose([
    tt.PILToTensor(),
    tt.Resize([112, 112], antialias=True),
])


def _make_sample_transforms():
    """Per-sample pipeline: cast to float16 (ConvertImageDtype rescales integer images),
    then normalise with a scalar mean/std."""
    return tt.Compose([
        tt.ConvertImageDtype(torch.float16),
        tt.Normalize(mean=[0.36], std=[0.2]),
    ])


# Train and validation currently share identical steps; they are separate objects so that
# either can be changed (e.g. augmentation for training only) without affecting the other.
transforms = _make_sample_transforms()
transforms_val = _make_sample_transforms()

### DATASET

data_dir = config_paths.data_processed / 'sniff-training-dataset'

# One sub-directory per recorded trail. Sort first: os.scandir yields entries in an
# arbitrary, filesystem-dependent order, so even a seeded shuffle would otherwise give a
# different split on another machine.
trails = sorted(f.path for f in os.scandir(data_dir) if f.is_dir())

# Split whole trails into train/val/test = 0.8 / 0.1 / 0.1.
# A private seeded RNG makes the split reproducible and keeps this cell idempotent on
# re-run, without touching the global `random` state.
SPLIT_SEED = 0
trails = random.Random(SPLIT_SEED).sample(trails, len(trails))
train_end = int(0.8 * len(trails))
val_end = int(0.9 * len(trails))

# Debug cap: only the first MAX_TRAIN_TRAILS training trails are kept (every dataset below
# is loaded fully into memory). Set to None to use the whole training split.
MAX_TRAIN_TRAILS = 2
train_trails = trails[:train_end][:MAX_TRAIN_TRAILS]
val_trails = trails[train_end:val_end]
test_trails = trails[val_end:]

print(len(train_trails), len(val_trails), len(test_trails))


window_size = 5         # Must be odd number
signal_window_size = 1  # Must be odd number


def _trail_dataset(trail_dir, sample_transforms):
    """Build the in-memory labelled multi-frame dataset for one trail directory.

    `cropped_frames` and `breathing_onsets.txt` are presumably resolved relative to
    `trail_dir` by the dataset class (TODO confirm).
    """
    return MouseSniffingVideoDatasetMultipleFramesLabeled(
        root_dir=trail_dir,
        video_path='cropped_frames',
        signal_path='breathing_onsets.txt',
        window_size=window_size,
        signal_window_size=signal_window_size,
        transforms=sample_transforms,
        loading_transforms=loading_transforms,
        load_in_memory=True,
    )


# NOTE: these loops leave `trail` bound at module level (to the last test trail);
# other cells should not rely on that leaked name.
train_datasets = []
for trail in train_trails:
    train_datasets.append(_trail_dataset(trail, transforms))
train_dataset = torch.utils.data.ConcatDataset(train_datasets)

val_datasets = []
for trail in val_trails:
    val_datasets.append(_trail_dataset(trail, transforms_val))
val_dataset = torch.utils.data.ConcatDataset(val_datasets)

test_datasets = []
for trail in test_trails:
    test_datasets.append(_trail_dataset(trail, transforms_val))
test_dataset = torch.utils.data.ConcatDataset(test_datasets)



def collate_fn(batch):
    """Collate samples with `default_collate`, then drop singleton axes 1 and 2 of the inputs.

    Args:
        batch: list of dataset samples; element 0 of each sample is the input tensor.

    Returns:
        The batch produced by `default_collate`, with its first entry squeezed on dims 1
        and 2. Dims whose size is not 1 are left untouched by `squeeze`.
    """
    collated = default_collate(batch)
    frames = collated[0]
    collated[0] = torch.squeeze(frames, dim=(1, 2))
    return collated



## MODEL

# DeepSniff checkpoint used for inference below.
# Earlier candidate: models/train/annotated/default_train/-step=840-epoch=14-val_loss=1.6483.ckpt
weights = config_paths.project_root / 'models/train/annotated/default_train/-step=1176-epoch=20-val_loss=0.8802.ckpt'

# load_from_checkpoint rebuilds the module from the checkpoint, so no separate network
# object is constructed here.
# eval() puts dropout/BatchNorm (if any) into inference mode. Lightning does not do this
# for us, and without it the predictions below are computed in training mode.
# .cuda() because the inputs are sent to CUDA in the inference loops.
model = DeepSniff.load_from_checkpoint(weights).cuda().eval()

# batch_size=1 with shuffle=False keeps predictions in dataset order (presumably temporal)
# so they can be plotted against the ground truth.
NUM_WORKERS = 10
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS, collate_fn=collate_fn)

# Sanity check of one collated batch: shapes plus dtype/value range instead of dumping
# whole tensors.
sample = next(iter(train_loader))
print(sample[0].shape, sample[1].shape)
print(sample[0].dtype, sample[0].float().min().item(), sample[0].float().max().item())

In [None]:
import matplotlib.pyplot as plt

# Validation-set inference with the loaded checkpoint.
# eval() is (re)applied here so this cell is correct on its own: in training mode,
# dropout/BatchNorm (if any) would distort the predictions.
model.cuda().eval()

outputs = []
gt = []
with torch.no_grad():
    for sample in val_loader:
        gt.append(sample[1])
        # Inputs come out of the transforms as float16; the model is run in float32 on CUDA.
        logits = model(sample[0].cuda().float())
        # Presumably the model emits logits. Sigmoid maps them to [0, 1] so the curve is on
        # the same scale as the labels, matching the post-processing of the other inference
        # loops in this notebook.
        outputs.append(torch.sigmoid(logits))

gt = torch.cat(gt, dim=0)
outputs = torch.cat(outputs, dim=0)
print(outputs.shape)


In [None]:
# Validation predictions vs. ground truth over the concatenated validation trails.
fig, ax = plt.subplots(figsize=(30, 10))
ax.plot(outputs.squeeze().detach().cpu().numpy(), label='model output')
ax.plot(gt.squeeze().detach().cpu().numpy(), label='ground truth')
ax.set(title='Validation set: model output vs. ground truth', xlabel='sample index', ylabel='value')
ax.legend();


In [None]:

# Sweep over every epoch checkpoint in models/train/annotated/default_train and plot its
# predictions on a single evaluation trail.
import matplotlib.pyplot as plt

weights_dir = config_paths.project_root / 'models/train/annotated/default_train'
# All .ckpt files except the `best.ckpt` alias, ordered by the epoch encoded in the
# filename (e.g. "-step=840-epoch=14-val_loss=1.6483.ckpt").
weights_list = [f.path for f in os.scandir(weights_dir) if f.is_file() and f.path.endswith('.ckpt')]
weights_list = [x for x in weights_list if 'best.ckpt' not in x]
weights_list = sorted(weights_list, key=lambda x: int(x.split("-epoch=")[1].split("-")[0]))

# Bind the evaluation trail explicitly instead of relying on a loop variable leaked from
# another cell, whose value depends on which cell ran last. This is the last test trail.
# Swap in a specific path (e.g. .../sniff-training-dataset/220221_RDP043_7) if needed.
eval_trail = test_trails[-1]

# Data does not depend on the checkpoint: build the in-memory dataset and loader once,
# outside the sweep.
dataset = MouseSniffingVideoDatasetMultipleFramesLabeled(root_dir=eval_trail,
                                                         video_path='cropped_frames',
                                                         signal_path='breathing_onsets.txt',
                                                         window_size=window_size,
                                                         signal_window_size=signal_window_size,
                                                         transforms=transforms_val,
                                                         loading_transforms=loading_transforms,
                                                         load_in_memory=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=10, collate_fn=collate_fn)

for weights in weights_list:
    # eval(): inference-mode dropout/BatchNorm; .cuda(): inputs below are sent to CUDA.
    model = DeepSniff.load_from_checkpoint(weights).cuda().eval()

    outputs = []
    gt = []
    with torch.no_grad():
        for sample in loader:
            gt.append(sample[1])
            # Presumably logits; sigmoid maps them to [0, 1] for comparison with the labels.
            outputs.append(torch.sigmoid(model(sample[0].cuda().float())))
    gt = torch.cat(gt, dim=0)
    outputs = torch.cat(outputs, dim=0)
    print(outputs.shape)

    fig, ax = plt.subplots(figsize=(30, 5))
    ax.plot(outputs.squeeze().cpu().numpy(), linestyle='solid', label='prediction (sigmoid)')
    ax.plot(gt.squeeze().cpu().numpy(), linestyle='dashed', label='ground truth')
    ax.set_title(weights)
    ax.legend()
    plt.show()
    # Close each figure explicitly; a long sweep would otherwise accumulate open figures.
    plt.close(fig)

In [None]:

# Evaluate the 3D-conv checkpoint on every trail directory (train, val and test alike)
# and plot each trail separately.
data_dir = config_paths.data_processed / 'sniff-training-dataset'

# Sorted so the per-trail plots come out in a stable order.
all_trails = sorted(f.path for f in os.scandir(data_dir) if f.is_dir())

### TRANSFORMS
# Same per-sample pipeline as the 2D model plus a swap of the first two axes of each
# sample for the 3D-conv input (presumably (T, C, H, W) -> (C, T, H, W); TODO confirm).
# A distinct name keeps the earlier 2D `transforms` / `transforms_val` intact, so
# re-running an earlier cell after this one does not silently pick up the permuted variant.
# `loading_transforms` from the earlier transforms cell is identical and is reused as-is.
transforms_3d = tt.Compose([
    tt.Lambda(lambda x: x.permute(1, 0, 2, 3)),
    tt.ConvertImageDtype(torch.float16),
    tt.Normalize(mean=[0.36], std=[0.2])])

### DATASET
# One in-memory dataset per trail.
datasets = [
    MouseSniffingVideoDatasetMultipleFramesLabeled(root_dir=trail_dir,
                                                   video_path='cropped_frames',
                                                   signal_path='breathing_onsets.txt',
                                                   window_size=window_size,
                                                   signal_window_size=signal_window_size,
                                                   transforms=transforms_3d,
                                                   loading_transforms=loading_transforms,
                                                   load_in_memory=True)
    for trail_dir in all_trails
]

### MODEL
# Earlier candidate checkpoint: models/train/annotated/default_train/-step=1288-epoch=22-val_loss=2.9251.ckpt
weights = config_paths.project_root / 'models/train/annotated/default_train_3d_conv/-step=612-epoch=05-val_loss=0.5442.ckpt'
# eval(): inference-mode dropout/BatchNorm; .cuda(): inputs below are sent to CUDA.
model = DeepSniff.load_from_checkpoint(weights).cuda().eval()

for dataset in datasets:
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=10, collate_fn=collate_fn)

    outputs = []
    gt = []
    with torch.no_grad():
        for sample in loader:
            gt.append(sample[1])
            # Presumably logits; sigmoid maps them to [0, 1] for comparison with the labels.
            outputs.append(torch.sigmoid(model(sample[0].cuda().float())))
    gt = torch.cat(gt, dim=0)
    outputs = torch.cat(outputs, dim=0)
    print(outputs.shape)

    fig, ax = plt.subplots(figsize=(30, 5))
    ax.plot(outputs.squeeze().cpu().numpy(), label='prediction (sigmoid)')
    ax.plot(gt.squeeze().cpu().numpy(), linestyle='dashed', label='ground truth')
    ax.set_title(dataset.root_dir)
    ax.legend()
    plt.show()
    plt.close(fig)