# Dataset weights

Notebook to get the weight of the different classes present in the dataset.

This is computed only in the trainingset.

In [1]:
from sflizard import LizardDataModule

from tqdm import tqdm
import torch



In [2]:
IMG_SIZE=540
TRAIN_DATA_PATH="../data/Lizard_dataset_extraction/data_0.9_split_train.pkl"
VALID_DATA_PATH="../data/Lizard_dataset_extraction/data_0.9_split_valid.pkl"
TEST_DATA_PATH="../data/Lizard_dataset_extraction/data_0.9_split_test.pkl"
SEED=303
BATCH_SIZE=4
N_RAYS = 32

ANNOTATION_TARGET = "stardist_class"

In [3]:
aditional_args={"n_rays":N_RAYS}

# create the datamodule
dm = LizardDataModule(
    train_data_path=TRAIN_DATA_PATH,
    valid_data_path=VALID_DATA_PATH,
    test_data_path=TEST_DATA_PATH,
    annotation_target=ANNOTATION_TARGET,
    batch_size=BATCH_SIZE,
    num_workers = 1,
    input_size=IMG_SIZE,
    seed=SEED,
    aditional_args=aditional_args,
)
dm.setup()

Training with 2976 examples


In [4]:
classes_representation = {}
for i in range(7):
    classes_representation[i] = 0
train_iter = iter(dm.train_dataloader())
for b in tqdm(range(len(train_iter))):
    image, obj_probabilities, distances, classes = next(train_iter)
    for i in range(len(image)):
        class_map = torch.flatten(classes[i].int())
        occurences = torch.bincount(class_map)
        for o in range(len(occurences)):
            classes_representation[o] += occurences[o]

100%|██████████| 744/744 [21:30<00:00,  1.73s/it]


In [5]:
print(classes_representation)

{0: tensor(730841977), 1: tensor(1233446), 2: tensor(86053998), 3: tensor(15836149), 4: tensor(5157280), 5: tensor(1065106), 6: tensor(27613644)}


In [6]:
total = 0
for c in classes_representation.keys():
    classes_representation[c] = classes_representation[c].item()
    total += classes_representation[c]
for c in classes_representation.keys():
    classes_representation[c] /= total
print(classes_representation)

{0: 0.8421763419196278, 1: 0.0014213456163252062, 2: 0.09916321656931723, 3: 0.01824858239486998, 4: 0.005942925203180082, 5: 0.001227361184860687, 6: 0.03182022711181911}


In [None]:
print(classes_representation)

In [None]:
{0: 0.8435234983048621, 1: 0.0015844697497448515, 2: 0.09702835179125052, 3: 0.018770678077839286, 4: 0.005716505874930195, 5: 0.0011799091886332306, 6: 0.03219658701273987}

## HoverNet Graph class balance

In [19]:
import numpy as np
from tqdm import tqdm
import glob
import scipy.io as sio

true_dir = "../data/Lizard_dataset_split/patches/Lizard_Labels_train/"

file_list = glob.glob("%s/*mat" % (true_dir))

classes = {}
for i in range(7):
    classes[i] = 0
for idx, file in tqdm(enumerate(file_list)):
    mat = sio.loadmat(file)
    for i in range(len(mat["classes"])):
        classes[mat["classes"][i][0]] += 1

2976it [10:25,  4.76it/s]


In [21]:
class_bp = classes.copy()

sum_c = sum(classes.values())
for k, v in classes.items():
    if k == 0:
        classes[k] = 0
    else:
        classes[k] = 1/(v/sum_c)

sum_c = sum(classes.values())
for k, v in classes.items():
    classes[k] = v/sum_c

for k, v in classes.items():
    print(f"{k}: {(v/sum_c)}")


0: 0.0
1: 0.0015289419440833998
2: 3.543263598039262e-05
3: 7.946214852753756e-05
4: 0.0002770799022561251
5: 0.0021221063687575435
6: 7.437612919581149e-05


In [22]:
for k, v in classes.items():
    print(f"{k}: {(v)}")

0: 0.0
1: 0.3713368309107073
2: 0.008605586894052789
3: 0.01929911238667816
4: 0.06729488533622548
5: 0.515399722585458
6: 0.018063861886878453


In [23]:
print(sum(classes.values()))

1.0
