# Dataset weights

Notebook to get the weight of the different classes present in the dataset.

This is computed only in the trainingset.

In [3]:
from sflizard import LizardDataModule

from tqdm import tqdm
import torch



In [6]:
IMG_SIZE=540
TRAIN_DATA_PATH="../data/Lizard_dataset_extraction/data_final_split_train.pkl"
VALID_DATA_PATH="../data/Lizard_dataset_extraction/data_final_split_valid.pkl"
TEST_DATA_PATH="../data/Lizard_dataset_extraction/data_final_split_test.pkl"
SEED=303
BATCH_SIZE=4
N_RAYS = 32

ANNOTATION_TARGET = "stardist_class"

In [7]:
aditional_args={"n_rays":N_RAYS}

# create the datamodule
dm = LizardDataModule(
    train_data_path=TRAIN_DATA_PATH,
    valid_data_path=VALID_DATA_PATH,
    test_data_path=TEST_DATA_PATH,
    annotation_target=ANNOTATION_TARGET,
    batch_size=BATCH_SIZE,
    num_workers = 1,
    input_size=IMG_SIZE,
    seed=SEED,
    aditional_args=aditional_args,
)
dm.setup()

KeyboardInterrupt: 

## class map in train set

In [4]:
classes_representation = {}
for i in range(7):
    classes_representation[i] = 0
train_iter = iter(dm.train_dataloader())
for b in tqdm(range(len(train_iter))):
    image, obj_probabilities, distances, classes = next(train_iter)
    for i in range(len(image)):
        class_map = torch.flatten(classes[i].int())
        occurences = torch.bincount(class_map)
        for o in range(len(occurences)):
            classes_representation[o] += occurences[o]

100%|██████████| 744/744 [21:30<00:00,  1.73s/it]


In [5]:
print(classes_representation)

{0: tensor(730841977), 1: tensor(1233446), 2: tensor(86053998), 3: tensor(15836149), 4: tensor(5157280), 5: tensor(1065106), 6: tensor(27613644)}


In [6]:
total = 0
for c in classes_representation.keys():
    classes_representation[c] = classes_representation[c].item()
    total += classes_representation[c]
for c in classes_representation.keys():
    classes_representation[c] /= total
print(classes_representation)

{0: 0.8421763419196278, 1: 0.0014213456163252062, 2: 0.09916321656931723, 3: 0.01824858239486998, 4: 0.005942925203180082, 5: 0.001227361184860687, 6: 0.03182022711181911}


In [None]:
print(classes_representation)

In [None]:
{0: 0.8421763419196278, 1: 0.0014213456163252062, 2: 0.09916321656931723, 3: 0.01824858239486998, 4: 0.005942925203180082, 5: 0.001227361184860687, 6: 0.03182022711181911}

## HoverNet Graph class balance

In [19]:
import numpy as np
from tqdm import tqdm
import glob
import scipy.io as sio

true_dir = "../data/Lizard_dataset_split/patches/Lizard_Labels_train/"

file_list = glob.glob("%s/*mat" % (true_dir))

classes = {}
for i in range(7):
    classes[i] = 0
for idx, file in tqdm(enumerate(file_list)):
    mat = sio.loadmat(file)
    for i in range(len(mat["classes"])):
        classes[mat["classes"][i][0]] += 1

2976it [10:25,  4.76it/s]


In [21]:
class_bp = classes.copy()

sum_c = sum(classes.values())
for k, v in classes.items():
    if k == 0:
        classes[k] = 0
    else:
        classes[k] = 1/(v/sum_c)

sum_c = sum(classes.values())
for k, v in classes.items():
    classes[k] = v/sum_c

for k, v in classes.items():
    print(f"{k}: {(v/sum_c)}")


0: 0.0
1: 0.0015289419440833998
2: 3.543263598039262e-05
3: 7.946214852753756e-05
4: 0.0002770799022561251
5: 0.0021221063687575435
6: 7.437612919581149e-05


In [22]:
for k, v in classes.items():
    print(f"{k}: {(v)}")

0: 0.0
1: 0.3713368309107073
2: 0.008605586894052789
3: 0.01929911238667816
4: 0.06729488533622548
5: 0.515399722585458
6: 0.018063861886878453


In [23]:
print(sum(classes.values()))

1.0


In [None]:
{
    0: 0.0,
    1: 0.3713368309107073,
    2: 0.008605586894052789,
    3: 0.01929911238667816,
    4: 0.06729488533622548,
    5: 0.515399722585458,
    6: 0.018063861886878453,
}

## Report cell count

In [1]:
import numpy as np
from tqdm import tqdm
import glob
import scipy.io as sio


# splited dataset

classes = {
    "train": {},
    "valid": {},
    "test": {},
}
for c in classes.keys():
    true_dir = f"../data/Lizard_dataset_split/patches/Lizard_Labels_{c}/"
    file_list = glob.glob("%s/*mat" % (true_dir))
    for i in range(7):
        classes[c][i] = 0
    for idx, file in tqdm(enumerate(file_list)):
        mat = sio.loadmat(file)
        for i in range(len(mat["classes"])):
            classes[c][mat["classes"][i][0]] += 1

2941it [03:19, 14.71it/s]
453it [00:30, 15.04it/s]
349it [00:23, 14.96it/s]


In [9]:
for c in classes.keys():
    print(classes[c])
    print("".join([f"{cc} & " for cc in classes[c].values()]))

{0: 0, 1: 15398, 2: 664434, 3: 296275, 4: 84967, 5: 11094, 6: 316535}
0 & 15398 & 664434 & 296275 & 84967 & 11094 & 316535 & 
{0: 0, 1: 3619, 2: 65409, 3: 18573, 4: 4009, 5: 682, 6: 44961}
0 & 3619 & 65409 & 18573 & 4009 & 682 & 44961 & 
{0: 0, 1: 705, 2: 86274, 3: 59580, 4: 13853, 5: 1784, 6: 39161}
0 & 705 & 86274 & 59580 & 13853 & 1784 & 39161 & 


In [4]:
for c in classes.keys():
    print("\n"+c)
    sum_c = sum(classes[c].values())
    print(sum_c)
    for k, v in classes[c].items():
        print(v/sum_c)


train
1388703
0.0
0.011088044023812147
0.47845651662018446
0.2133465543028279
0.061184428923967187
0.007988749214194828
0.2279357069150135

valid
137253
0.0
0.026367365376348786
0.47655788944503946
0.1353194465694739
0.02920883332240461
0.004968925997974543
0.3275775392887587

test
201357
0.0
0.0035012440590592828
0.4284628793635185
0.2958923702677334
0.06879820418460744
0.00885988567569044
0.19448541644939088


In [5]:
print(244563 - 34191)
print(101413 - 9175)
print(28466 - 3605)
print(4824 - 708)
print(3604 - 625)
print(112309 - 14962)

210372
92238
24861
4116
2979
97347


In [15]:
sum_tot = 0
sums = {
    0:0,
    1:0,
    2:0,
    3:0,
    4:0,
    5:0,
    6:0,
}
for c in classes.keys():
    sum_tot += sum(classes[c].values())
    for k, v in classes[c].items():
        sums[k] += v

print(sum_tot)
print(sums)

1727313
{0: 0, 1: 19722, 2: 816117, 3: 374428, 4: 102829, 5: 13560, 6: 400657}


In [16]:
1388703+137253 + 201357

1727313

In [17]:
print(1388703/1727313)
print(137253/1727313)
print(201357/1727313)

0.8039672022383899
0.07946041047569259
0.1165723872859175
