In [None]:
import os
import shutil
import warnings
import csv
import yaml
import json
import torch

from PIL import Image
import pandas as pd
from typing import Any, Sequence

from os import PathLike
from torch.utils.data import Dataset

from megadetector.detection.run_detector import load_detector, model_string_to_model_version
from megadetector.detection.run_detector_batch import process_images, write_results_to_file


from sklearn.model_selection import train_test_split

In [1]:
import random
import torch

import numpy as np
from pathlib import Path
from PIL import Image

from torchvision.transforms import v2
from torch.utils.data import DataLoader

from ba_dev.dataset import MammaliaDataSequence, MammaliaDataImage
from ba_dev.transform import ImagePipeline, BatchImagePipeline
from ba_dev.utils import load_config_yaml

paths = load_config_yaml('../path_config.yml')


### Running Tests

In [2]:
path_to_dataset = paths['dataset']
path_labelfiles = paths['test_labels']
path_to_detector_output = paths['md_output']
detector_model='mdv5a'
mode='train'

transform = ImagePipeline(
                path_to_dataset=path_to_dataset,
                pre_ops = [
                    ('to_rgb', {}),
                    ('crop_by_bb', {})
                ],
                transform = None
                )
                
dataset = MammaliaDataImage(
    path_labelfiles=path_labelfiles,
    path_to_dataset=path_to_dataset,
    path_to_detector_output=path_to_detector_output,
    detector_model=detector_model,
    mode=mode,
    transform=transform
)

8 sequences had no detections and will be excluded.
Excluded sequences: [6000161, 6000163, 6000293, 6000530, 6000691, 6000372, 6000953, 6000186]


In [3]:
dataset.ds

Unnamed: 0,seq_id,class_id,label2,SerialNumber,file_path,bbox,conf
0,4007156,0,apodemus_sp,H550HG09194945,sessions/session_04/W2-WK02/IMG_6165.JPG,"[0.0, 0.4687, 0.4111, 0.3671]",0.983
1,4007156,0,apodemus_sp,H550HG09194945,sessions/session_04/W2-WK02/IMG_6164.JPG,"[0.0004, 0.3203, 0.4101, 0.5136]",0.974
2,4007156,0,apodemus_sp,H550HG09194945,sessions/session_04/W2-WK02/IMG_6162.JPG,"[0.0639, 0.5429, 0.1953, 0.3268]",0.972
3,4007156,0,apodemus_sp,H550HG09194945,sessions/session_04/W2-WK02/IMG_6168.JPG,"[0.0, 0.5279, 0.4008, 0.3001]",0.972
4,4007156,0,apodemus_sp,H550HG09194945,sessions/session_04/W2-WK02/IMG_6178.JPG,"[0.1171, 0.4915, 0.5185, 0.2434]",0.970
...,...,...,...,...,...,...,...
930,4010196,3,soricidae,H550HF07158878,sessions/session_04/W4-M8/IMG_0290.JPG,"[0.6191, 0.5618, 0.1914, 0.1171]",0.962
931,4010196,3,soricidae,H550HF07158878,sessions/session_04/W4-M8/IMG_0287.JPG,"[0.5415, 0.5852, 0.2587, 0.1002]",0.951
932,4010196,3,soricidae,H550HF07158878,sessions/session_04/W4-M8/IMG_0286.JPG,"[0.5375, 0.5852, 0.2661, 0.1009]",0.945
933,4010196,3,soricidae,H550HF07158878,sessions/session_04/W4-M8/IMG_0289.JPG,"[0.6816, 0.5839, 0.1118, 0.2154]",0.944


In [3]:
dataset[2]

{'x': <PIL.Image.Image image mode=RGB size=400x502>,
 'y': 0,
 'class_name': 'apodemus_sp',
 'bbox': [0.0639, 0.5429, 0.1953, 0.3268],
 'conf': 0.972,
 'seq_id': 4007156,
 'file_path': PosixPath('sessions/session_04/W2-WK02/IMG_6162.JPG')}

In [5]:
row = dataset[77]

image = pipline(row['file_path'], row['bbox'])

print(image.shape)

torch.Size([3, 224, 224])


In [5]:
batch_pipline = BatchImagePipeline(
                path_to_dataset=path_to_dataset,
                num_workers=4,
                pre_ops = [
                    ('to_rgb', {}),
                    ('crop_by_bb', {}),
                ],
                transform = None
                )

In [7]:
list_of_paths = []
list_of_bboxes = []

samples = [random.randint(0, len(dataset)) for _ in range(100)]

for i in samples:
    row = dataset[i]

    list_of_paths.append(row['file_path'])
    list_of_bboxes.append(row['bbox'])

images = batch_pipline(list_of_paths, list_of_bboxes)

for image in images:
    print(image.shape)

IndexError: list index out of range

### Tests Feature Stats

In [4]:
path_to_dataset = paths['dataset']
path_labelfiles = paths['test_labels']
path_to_detector_output = paths['md_output']
detector_model=None
mode='init'

transform = ImagePipeline(
                pre_ops=[
                    ('to_rgb', {}),
                    ('crop_by_bb', {})
                ],
                transform=v2.Compose([
                                v2.ToImage(),
                                v2.ToDtype(torch.float32, scale=True),
                                ])
                )

dataset = MammaliaDataImage(
    path_labelfiles=path_labelfiles,
    path_to_dataset=path_to_dataset,
    path_to_detector_output=path_to_detector_output,
    detector_model=detector_model,
    mode=mode,
    transform=transform
)

def collate_fn(batch):
    return batch

loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=1,
    shuffle=False,
    collate_fn=collate_fn
    )

channel_sum = torch.zeros(3)
pixel_count = 0

batches_flat = []

print('Calculating mean...')
for batch in loader:
    for item in batch:
        img = item['x']
        pixel_count += img.shape[1] * img.shape[2]
        for c in range(img.shape[0]):
            channel_sum[c] += img[c].sum()

        batches_flat.append(img.flatten(start_dim=1))

mean = channel_sum / pixel_count

print('Calculating std...')
channel_diff_squared_sum = torch.zeros(3)
for batch in loader:
    for item in batch:
        img = item['x']
        img_centered_squared = (img - mean[:, None, None]) ** 2
        for c in range(img_centered_squared.shape[0]):
            channel_diff_squared_sum[c] += img_centered_squared[c].sum()

std = torch.sqrt(channel_diff_squared_sum / pixel_count)

print("Mean:", mean)
print("Std:", std)

8 sequences had no detections and will be excluded.
Excluded sequences: [6000161, 6000163, 6000293, 6000530, 6000691, 6000372, 6000953, 6000186]


Calculating mean...
Calculating std...
Mean: tensor([0.3245, 0.3008, 0.2264])
Std: tensor([0.2232, 0.2075, 0.1725])


In [20]:
torch.cat(batches_flat, -1).mean(-1)

tensor([0.3198, 0.2960, 0.2227])

In [21]:
torch.cat(batches_flat, -1).std(-1)

tensor([0.2246, 0.2083, 0.1730])