In [1]:
import os
import shutil

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

np.random.seed(257103)

In [2]:
# Get class names: each sub-directory of images/original is treated as one class.
# Plain files (for example OS metadata such as .DS_Store) are skipped, because
# later cells call os.listdir() on every class path and would fail on a file.
class_names = sorted(
    entry
    for entry in os.listdir("images/original")
    if os.path.isdir(os.path.join("images/original", entry))
)

In [3]:
# Tally how many directory entries each class holds and print one line per class.
n_instances = {}
for name in class_names:
    class_dir = f"./images/original/{name}"
    count = len(os.listdir(class_dir))
    n_instances[name] = count
    print(f"{name}: {count}")

Akashiwo: 5
Amphidinium_sp: 210
Asterionellopsis: 1853
Bacillaria: 13
Bidulphia: 31
Cerataulina: 21686
Cerataulina_flagellate: 142
Ceratium: 784
Chaetoceros: 46957
Chaetoceros_didymus: 502
Chaetoceros_didymus_flagellate: 1427
Chaetoceros_flagellate: 149
Chaetoceros_pennate: 970
Chrysochromulina: 515
Ciliate_mix: 15197
Cochlodinium: 14
Corethron: 6386
Coscinodiscus: 711
Cylindrotheca: 20832
DactFragCerataul: 5906
Dactyliosolen: 13020
Delphineis: 401
Dictyocha: 2358
Didinium_sp: 24
Dinobryon: 7543
Dinophysis: 357
Ditylum: 5419
Ditylum_parasite: 0
Emiliania_huxleyi: 148
Ephemera: 633
Eucampia: 2191
Euglena: 764
Euplotes_sp: 25
G_delicatula_parasite: 3398
Gonyaulax: 564
Guinardia_delicatula: 38268
Guinardia_flaccida: 1290
Guinardia_striata: 3234
Gyrodinium: 677
Hemiaulus: 18
Heterocapsa_triquetra: 1413
Karenia: 4
Katodinium_or_Torodinium: 395
Laboea_strobila: 1470
Lauderia: 295
Leegaardiella_ovalis: 245
Leptocylindrus: 125690
Leptocylindrus_mediterraneus: 392
Licmophora: 313
Mesodinium_sp:

In [4]:
def move_to_dir(names, from_dir, to_dir):
    """Copy the files listed in ``names`` from ``from_dir`` into ``to_dir``.

    Despite the function name, the source files are copied (not moved) and
    remain in ``from_dir``. Each copy is named after its position in
    ``names`` (``0.png``, ``1.png``, ...), so a file name that appears more
    than once in ``names`` yields one distinct copy per occurrence.

    Args:
        names: Iterable of file names relative to ``from_dir``.
        from_dir: Directory holding the source files.
        to_dir: Existing directory that receives the copies.
    """
    index = 0
    for filename in names:
        source = os.path.join(from_dir, filename)
        target = os.path.join(to_dir, f"{index}.png")
        shutil.copyfile(source, target)
        index += 1

In [16]:
def sample_to(samples, size, random_state=None):
    """Resample ``samples`` to exactly ``size`` elements.

    Sampling is done without replacement when ``samples`` already holds at
    least ``size`` elements (downsampling) and with replacement when it holds
    fewer (oversampling, so elements repeat).

    Relies on ``resample`` from ``sklearn.utils`` being imported at the top
    of the notebook.

    Args:
        samples: Sized collection to draw from (lists of file names in this
            notebook).
        size: Number of elements to return.
        random_state: Optional seed or ``RandomState`` forwarded to
            ``resample``. The default ``None`` keeps the previous behavior;
            per the scikit-learn documentation this draws from NumPy's
            global random state.

    Returns:
        The resampled collection, as returned by ``resample``.
    """
    # Replacement is only needed when there are too few elements to draw from.
    oversample = len(samples) < size
    return resample(
        samples,
        replace=oversample,
        n_samples=size,
        random_state=random_state,
    )

In [7]:
# Split train, test, validation and use sampling scheme 1 for train set:
# a class whose train split holds more than 5000 images is downsampled to 5000.
train_dir = "./images/train"
val_dir = "./images/val"
test_dir = "./images/test"

# Rebuild the three output directories from scratch so files from a previous
# run cannot linger.
for split_dir in (train_dir, val_dir, test_dir):
    shutil.rmtree(split_dir, ignore_errors=True)
for split_dir in (train_dir, val_dir, test_dir):
    os.makedirs(split_dir)

for name in class_names:
    # Classes with 40 or fewer images are skipped and dropped from n_instances.
    # .get()/.pop() (instead of indexing/del) keep this cell re-runnable after
    # those entries have already been removed by an earlier execution.
    if n_instances.get(name, 0) <= 40:
        n_instances.pop(name, None)
        continue

    ori_path = f"./images/original/{name}"
    # os.listdir() returns entries in arbitrary order; sorting makes the split
    # depend only on the random seed, not on the file system.
    instances = sorted(os.listdir(ori_path))

    # Half of the class goes to test; the other half is split 90/10 into
    # train and validation.
    train, test = train_test_split(instances, test_size=0.5)
    train, val = train_test_split(train, test_size=0.1)

    if len(train) > 5000:
        train = sample_to(train, 5000)

    train_dest = os.path.join(train_dir, name)
    val_dest = os.path.join(val_dir, name)
    test_dest = os.path.join(test_dir, name)
    os.makedirs(train_dest)
    os.makedirs(val_dest)
    os.makedirs(test_dest)

    move_to_dir(train, ori_path, train_dest)
    move_to_dir(val, ori_path, val_dest)
    move_to_dir(test, ori_path, test_dest)

In [None]:
# Split train, test, validation and use sampling scheme 2 for train set:
# every class's train split is resampled to exactly 1000 images (smaller
# classes are oversampled with replacement, larger ones downsampled).
train_dir = "./images/train"
val_dir = "./images/val"
test_dir = "./images/test"

# Rebuild the three output directories from scratch so files from a previous
# run cannot linger.
for split_dir in (train_dir, val_dir, test_dir):
    shutil.rmtree(split_dir, ignore_errors=True)
for split_dir in (train_dir, val_dir, test_dir):
    os.makedirs(split_dir)

for name in class_names:
    # .get() instead of indexing: another cell in this notebook deletes small
    # classes from n_instances, which would otherwise raise KeyError here.
    # Classes with 40 or fewer images are skipped.
    if n_instances.get(name, 0) <= 40:
        continue

    ori_path = f"./images/original/{name}"
    # os.listdir() returns entries in arbitrary order; sorting makes the split
    # depend only on the random seed, not on the file system.
    instances = sorted(os.listdir(ori_path))

    # Half of the class goes to test; the other half is split 90/10 into
    # train and validation.
    train, test = train_test_split(instances, test_size=0.5)
    train, val = train_test_split(train, test_size=0.1)

    train = sample_to(train, 1000)

    train_dest = os.path.join(train_dir, name)
    val_dest = os.path.join(val_dir, name)
    test_dest = os.path.join(test_dir, name)
    os.makedirs(train_dest)
    os.makedirs(val_dest)
    os.makedirs(test_dest)

    move_to_dir(train, ori_path, train_dest)
    move_to_dir(val, ori_path, val_dest)
    move_to_dir(test, ori_path, test_dest)

In [None]:
# Split train, test, validation and do not perform resampling
train_dir = "./images/train"
val_dir = "./images/val"
test_dir = "./images/test"

# Rebuild the three output directories from scratch so files from a previous
# run cannot linger.
for split_dir in (train_dir, val_dir, test_dir):
    shutil.rmtree(split_dir, ignore_errors=True)
for split_dir in (train_dir, val_dir, test_dir):
    os.makedirs(split_dir)

for name in class_names:
    # .get() instead of indexing: another cell in this notebook deletes small
    # classes from n_instances, which would otherwise raise KeyError here.
    # Classes with 40 or fewer images are skipped.
    if n_instances.get(name, 0) <= 40:
        continue

    ori_path = f"./images/original/{name}"
    # os.listdir() returns entries in arbitrary order; sorting makes the split
    # depend only on the random seed, not on the file system.
    instances = sorted(os.listdir(ori_path))

    # Half of the class goes to test; the other half is split 90/10 into
    # train and validation.
    train, test = train_test_split(instances, test_size=0.5)
    train, val = train_test_split(train, test_size=0.1)

    train_dest = os.path.join(train_dir, name)
    val_dest = os.path.join(val_dir, name)
    test_dest = os.path.join(test_dir, name)
    os.makedirs(train_dest)
    os.makedirs(val_dest)
    os.makedirs(test_dest)

    move_to_dir(train, ori_path, train_dest)
    move_to_dir(val, ori_path, val_dest)
    move_to_dir(test, ori_path, test_dest)