Predicting species

In [7]:
from pathlib import Path

import numpy as np

from PIL import Image

import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch.utils.data.dataset import Dataset
from torchvision.transforms.v2 import TrivialAugmentWide, ToImage

from nn_training import train_and_eval, select_device


Define helpers

In [None]:
def split_train_val(df_img_catalog, val_proportion: float, seed=None):
    """Randomly partition a catalog into training and validation subsets.

    Rows are selected by *position*, so the split is correct for any index
    (the catalog does not need a 0..n-1 RangeIndex).

    Args:
        df_img_catalog: DataFrame to split.
        val_proportion: fraction of rows placed in the validation subset; the
            validation size is ``int(len(df_img_catalog) * val_proportion)``.
        seed: optional seed for ``np.random.default_rng``. ``None`` (default)
            uses fresh OS entropy, so every call produces a different split.

    Returns:
        ``(df_train, df_val)``: disjoint DataFrames that together contain every
        row of ``df_img_catalog``, each keeping the original row order and index.

    Raises:
        ValueError: if ``val_proportion`` is not within [0, 1].
    """
    if not 0 <= val_proportion <= 1:
        raise ValueError(f"val_proportion must be in [0, 1], got {val_proportion}")

    n_rows = len(df_img_catalog)
    rng = np.random.default_rng(seed)
    val_positions = rng.choice(n_rows, size=int(n_rows * val_proportion), replace=False)

    # Mask by row position: matching sampled positions against df.index would
    # silently mis-split any catalog whose index is not exactly 0..n-1.
    is_val = np.zeros(n_rows, dtype=bool)
    is_val[val_positions] = True

    df_train = df_img_catalog.iloc[~is_val]
    df_val = df_img_catalog.iloc[is_val]
    return df_train, df_val


def aggregate_images(df_img_catalog, img_dir_path):
    """Load every image listed in a catalog and stack them into one tensor.

    Args:
        df_img_catalog: DataFrame with a ``"filename"`` column; each name is
            resolved relative to ``img_dir_path``.
        img_dir_path: directory containing the image files (str or Path).

    Returns:
        A tensor of the converted images stacked along a new leading dimension,
        in catalog order. ``torch.stack`` requires every image to decode to the
        same shape — NOTE(review): assumes the dataset's images are uniform.
    """
    img_dir_path = Path(img_dir_path)
    to_torch_img = ToImage()
    tensor_imgs = []
    for filename in df_img_catalog["filename"]:
        # Convert inside the context manager so each file handle is released
        # immediately; opening thousands of lazy PIL images up front can exhaust
        # the process's open-file limit.
        with Image.open(img_dir_path / filename) as pil_img:
            tensor_imgs.append(to_torch_img(pil_img))
    return torch.stack(tensor_imgs)


class Image_Dataset(Dataset):
    """Map-style dataset over pre-aggregated image tensors and their labels.

    Typical use: call ``generate`` once to write the train/val split catalogs
    and the aggregated image tensors to disk, then ``load(split, device)`` to
    populate the instance, and optionally ``set_transform`` to transform each
    sample as it is accessed. All paths are relative to the working directory.
    """

    def __init__(self):
        # CSV catalogs. "original_train" is the full labelled catalog;
        # "train"/"val" are written by generate(); "test" is only read.
        self.img_catalog_file_paths = {
            "original_train": Path("train_catalog.csv"),
            "train": Path("train_catalog_train.csv"),
            "val": Path("train_catalog_val.csv"),
            "test": Path("test_catalog.csv"),
        }

        # Directories holding the raw image files named in the catalogs.
        self.img_dir_paths = {
            "train": Path("train/train"),
            "test": Path("test/test"),
        }

        # Files where generate() saves, and load() reads, the stacked tensors.
        self.tensor_imgs_file_paths = {
            "train": Path("train_imgs_train.pt"),
            "val": Path("train_imgs_val.pt"),
            "test": Path("test_imgs.pt"),
        }

        self.set_transform(None)

    def generate(self, validation_split_proportion):
        """Split the labelled catalog and save aggregated image tensors.

        Writes the train/val catalog CSVs and the train, val and test image
        tensors to the paths configured in ``__init__``. The split is drawn
        without a seed, so each call overwrites the files with a new split.

        Args:
            validation_split_proportion: fraction of the original training
                catalog assigned to the validation split.
        """
        df_img_catalog_original_train = pd.read_csv(self.img_catalog_file_paths["original_train"]).drop(columns="id")
        df_img_catalog_train, df_img_catalog_val = split_train_val(df_img_catalog_original_train, validation_split_proportion)
        df_img_catalog_train.to_csv(self.img_catalog_file_paths["train"], index=False)
        df_img_catalog_val.to_csv(self.img_catalog_file_paths["val"], index=False)

        # Train and val images both live in the original training directory.
        tensor_imgs_train = aggregate_images(df_img_catalog_train, self.img_dir_paths["train"])
        torch.save(tensor_imgs_train, self.tensor_imgs_file_paths["train"])
        tensor_imgs_val = aggregate_images(df_img_catalog_val, self.img_dir_paths["train"])
        torch.save(tensor_imgs_val, self.tensor_imgs_file_paths["val"])
        tensor_imgs_test = aggregate_images(pd.read_csv(self.img_catalog_file_paths["test"]), self.img_dir_paths["test"])
        torch.save(tensor_imgs_test, self.tensor_imgs_file_paths["test"])

    def load(self, split, device):
        """Load one split's image tensor (and labels, if labelled) onto ``device``.

        Args:
            split: one of ``"train"``, ``"val"`` or ``"test"``.
            device: target passed to ``Tensor.to``.

        Raises:
            ValueError: if ``split`` is not a known split; the instance's
                previously loaded state is left unchanged in that case.
        """
        if split not in {"train", "val", "test"}:
            raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")
        self.split = split

        self.tensor_imgs = torch.load(self.tensor_imgs_file_paths[split]).to(device)
        if split != "test":
            # Only train/val catalogs carry a "label" column aligned with the images.
            labels = pd.read_csv(self.img_catalog_file_paths[split])["label"].to_numpy()
            self.tensor_img_labels = torch.from_numpy(labels).to(device)
            assert len(self.tensor_imgs) == len(self.tensor_img_labels)

    def set_transform(self, transform):
        """Set a callable applied to each image in ``__getitem__`` (``None`` disables).

        The transform is applied for every split, so clear it (``None``) before
        evaluating if it is a training-only augmentation.
        """
        self.transform = transform

    def __len__(self):
        return len(self.tensor_imgs)

    def __getitem__(self, index):
        """Return the (optionally transformed) image, plus its label for train/val."""
        x = self.tensor_imgs[index]
        if self.transform is not None:
            x = self.transform(x)
        if self.split == "test":
            return x
        return x, self.tensor_img_labels[index]



SyntaxError: invalid syntax. Perhaps you forgot a comma? (1633466478.py, line 55)

Create training and validation splits

In [None]:
# Hold out 25% of the labelled images as a validation split and persist both
# splits to CSV so later cells can reload them.
# NOTE(review): split_train_val is called without a seed here, so re-running
# this cell overwrites both CSVs with a different random split.
df_img_catalog_original_train = pd.read_csv("train.csv").drop(columns=["id"])
df_img_catalog_train, df_img_catalog_val = split_train_val(
    df_img_catalog_original_train,
    val_proportion=0.25,
)
df_img_catalog_train.to_csv("train_split_catalog_train.csv", index=False)
df_img_catalog_val.to_csv("train_split_catalog_val.csv", index=False)

In [None]:
# Reload the saved training split from disk and display it.
df_img_catalog_train = pd.read_csv(Path("train_split_catalog_train.csv"))
df_img_catalog_train

Unnamed: 0,filename,label
0,MV1012-BC-12_obj00001.jpg,0
1,MV1012-BC-12_obj00003.jpg,2
2,MV1012-BC-12_obj00004.jpg,3
3,MV1012-BC-12_obj00005.jpg,0
4,MV1012-BC-12_obj00008.jpg,4
...,...,...
6485,MV1012-BC-8_obj01909.jpg,0
6486,MV1012-BC-8_obj01910.jpg,3
6487,MV1012-BC-8_obj01911.jpg,3
6488,MV1012-BC-8_obj01913.jpg,9


In [13]:
# Reload the saved validation split from disk and display it.
df_img_catalog_val = pd.read_csv(Path("train_split_catalog_val.csv"))
df_img_catalog_val

Unnamed: 0,filename,label
0,MV1012-BC-12_obj00002.jpg,1
1,MV1012-BC-12_obj00006.jpg,1
2,MV1012-BC-12_obj00015.jpg,1
3,MV1012-BC-12_obj00023.jpg,5
4,MV1012-BC-12_obj00024.jpg,3
...,...,...
2158,MV1012-BC-8_obj01890.jpg,16
2159,MV1012-BC-8_obj01891.jpg,19
2160,MV1012-BC-8_obj01897.jpg,21
2161,MV1012-BC-8_obj01904.jpg,1


Aggregate images for easy loading