In [1]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

In [2]:
# Input annotations. NOTE: despite the name, data_dir is the path of a CSV
# file (it is passed straight to pd.read_csv), not a directory.
data_dir = "/data6/lekevin/hab-master/phytoplankton-db/csv/hab_in_situ_raw_v2_workshop2019.csv"
# Directory that receives the generated train.csv / val.csv split files.
dest_dir = "/data6/SuryaKrishnan/raw_data_9_class"

In [3]:
def hab_map(value):
    """Collapse a raw annotation label into one of the nine target classes.

    A label that already names a target class is returned unchanged. Known
    variant labels (pair / single / side / chain forms, or shorter names such
    as "Lingulodinium") are folded into their parent class. Any other value
    yields "Other".
    """
    target_classes = (
        "Akashiwo",
        "Ceratium falcatiforme or fusus",
        "Ceratium furca",
        "Chattonella",
        "Cochlodinium",
        "Gyrodinium",
        "Lingulodinium polyedra",
        "Prorocentrum micans",
        "Pseudo nitzschia chain",
    )
    if value in target_classes:
        return value

    # Each rule is (variant labels, class those variants are folded into).
    merge_rules = (
        (
            ("Ceratium falcatiforme fusus pair", "Ceratium falcatiforme fusus single"),
            "Ceratium falcatiforme or fusus",
        ),
        (
            ("Ceratium furca pair", "Ceratium furca side", "Ceratium furca single"),
            "Ceratium furca",
        ),
        (
            (
                "Cochlodinium Alexandrium Gonyaulax Gymnodinium chain",
                "Cochlodinium Alexandrium Gonyaulax Gymnodinium pair",
            ),
            "Cochlodinium",
        ),
        (("Lingulodinium",), "Lingulodinium polyedra"),
        (("Prorocentrum",), "Prorocentrum micans"),
    )
    for variants, merged_label in merge_rules:
        if value in variants:
            return merged_label

    return "Other"

In [4]:
# Load the annotation CSV, relabel every row through hab_map, and keep only
# the rows that mapped onto one of the target classes (i.e. not "Other").
data = (
    pd.read_csv(data_dir)
    .assign(transformed_y=lambda frame: frame["label"].apply(hab_map))
    .loc[lambda frame: frame["transformed_y"] != "Other"]
)

# Per-class sample counts after relabelling and filtering.
print(data["transformed_y"].value_counts())

# Inputs and stratification targets for the train/val split.
data_x, data_y = data["images"], data["transformed_y"]

Prorocentrum micans               3147
Ceratium furca                    2519
Ceratium falcatiforme or fusus    1586
Cochlodinium                      1493
Lingulodinium polyedra            1433
Akashiwo                           777
Gyrodinium                         671
Pseudo nitzschia chain             618
Chattonella                        558
Name: transformed_y, dtype: int64


In [5]:
# One stratified shuffle split holding out 20% of the samples; random_state is
# fixed so re-running the notebook reproduces the same partition.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)

In [6]:
# Materialise the stratified split and persist both halves as CSV files.
# `sss` is configured with n_splits=1, so take that single split directly
# rather than looping over a one-element generator.
train_index, val_index = next(sss.split(data_x, data_y))

train_x, val_x = data_x.iloc[train_index], data_x.iloc[val_index]
train_y, val_y = data_y.iloc[train_index], data_y.iloc[val_index]

# Output schema: "images" holds the values from the source "images" column,
# "label" holds the relabelled class.
train_dict = {"images": train_x, "label": train_y}
val_dict = {"images": val_x, "label": val_y}

train_df = pd.DataFrame(train_dict)
val_df = pd.DataFrame(val_dict)

# Sanity check: the two halves together must account for every sample.
assert len(train_df) + len(val_df) == len(data_x)

# DataFrame.to_csv does not create missing directories, so make sure the
# destination exists before writing.
os.makedirs(dest_dir, exist_ok=True)
train_df.to_csv(os.path.join(dest_dir, "train.csv"), index=False)
val_df.to_csv(os.path.join(dest_dir, "val.csv"), index=False)

In [7]:
# Read the training split back from disk (rebinding train_df to the on-disk
# version) and display its per-class sample counts.
train_df =  pd.read_csv(os.path.join(dest_dir, "train.csv"))

train_df["label"].value_counts()

Prorocentrum micans               2518
Ceratium furca                    2015
Ceratium falcatiforme or fusus    1269
Cochlodinium                      1194
Lingulodinium polyedra            1146
Akashiwo                           622
Gyrodinium                         537
Pseudo nitzschia chain             494
Chattonella                        446
Name: label, dtype: int64

In [8]:
# Read the validation split back from disk (rebinding val_df to the on-disk
# version) and display its per-class sample counts.
val_df =  pd.read_csv(os.path.join(dest_dir, "val.csv"))

val_df["label"].value_counts()

Prorocentrum micans               629
Ceratium furca                    504
Ceratium falcatiforme or fusus    317
Cochlodinium                      299
Lingulodinium polyedra            287
Akashiwo                          155
Gyrodinium                        134
Pseudo nitzschia chain            124
Chattonella                       112
Name: label, dtype: int64