In [1]:
import shutil
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import imageio.v3 as iio

from sklearn.preprocessing import LabelEncoder
from streaming import StreamingDataset, MDSWriter

from streaming.base.util import clean_stale_shared_memory

from tqdm import tqdm
from typing import Callable, Any

from dotenv import load_dotenv
load_dotenv();

In [2]:
# Root of the local Imagenette download; every other location hangs off it.
IMAGENETTE = Path.home().joinpath("datasets", "imagenette")

# Image split directories, plus the output location for the MDS shards
# (one subdirectory per split is created under SHARDS).
TRAIN_DIR, VAL_DIR, SHARDS = (IMAGENETTE / sub for sub in ("train", "val", "shards"))

In [13]:
def prepare_labels(path_to_noisy_labels: Path | None = None, path_to_labels_csv: Path | None = None) -> pd.DataFrame:
    if path_to_labels_csv:
        return pd.read_csv(path_to_labels_csv, index_col=0)

    assert (path_to_noisy_labels is Path) and path_to_noisy_labels.is_file(), "bruh?"
    df = pd.read_csv(path_to_noisy_labels)
    df["path"] = df["path"].apply(lambda x: IMAGENETTE/x)
    df = df[["path", "noisy_labels_0", "is_valid"]]
    df.columns = ["path", "label", "is_valid"]
    return df 

def prepare_splits(df: pd.DataFrame, is_valid: bool) -> pd.DataFrame:
    """Select one split of the label table and drop the split flag.

    Args:
        df: Label table containing an ``is_valid`` column. It is not modified.
        is_valid: ``True`` keeps the rows flagged as validation, ``False`` keeps
            the remaining (training) rows.

    Returns:
        A new DataFrame of the matching rows, without ``is_valid``, re-indexed
        from 0.
    """
    in_split = df["is_valid"] == is_valid
    return (
        df.loc[in_split]
        .drop(columns="is_valid")
        .reset_index(drop=True)
    )

In [21]:
def label_from_path(path: Path) -> str:
    """Return the class label, taken from the name of the file's parent directory.

    ``.name`` is used instead of ``.stem`` because a directory has no file
    extension. ``.stem`` would wrongly truncate a folder such as ``v1.2`` to ``v1``.

    Args:
        path: Path (or ``str``) to an image file laid out as ``.../<class>/<file>``.

    Returns:
        The full name of the parent directory.
    """
    return Path(path).parent.name

def reset_dir(dir_path: Path) -> None:
    """Leave ``dir_path`` as an empty directory.

    If the directory already exists, it is deleted together with everything
    inside it. Missing parent directories are created as needed.
    """
    if not dir_path.is_dir():
        # Nothing to wipe; just make sure it (and its parents) exist.
        dir_path.mkdir(parents=True, exist_ok=True)
        return

    shutil.rmtree(dir_path)
    dir_path.mkdir(parents=True, exist_ok=True)

# Full label table, then the split that its `is_valid` flag defines.
labels = prepare_labels(path_to_noisy_labels=IMAGENETTE / "noisy_imagenette.csv")
val = prepare_splits(labels, is_valid=True)
train = prepare_splits(labels, is_valid=False)

# Fit the label -> integer id mapping on every class seen in either split.
# `class_names` is kept sorted so position i lines up with encoded id i.
class_names = sorted(labels["label"].unique())
label_encoder = LabelEncoder()
label_encoder.fit(class_names)

In [27]:
# Write the validation split as MDS shards under SHARDS/val.
df = val
local_shards = SHARDS / "val"

# MDS column encodings: raw JPEG bytes for the image, integer class id for the label.
dtypes = {"image": "bytes", "label": "int"}

# Encode every label in one call instead of once per row.
label_ints = label_encoder.transform(df["label"])  # type: ignore

# (index, error) for every image that could not be read and was skipped.
failed: list[tuple[Any, Exception]] = []

reset_dir(local_shards)
with MDSWriter(out=local_shards.as_posix(), columns=dtypes) as out:
    for idx, image_path, label_int in tqdm(
        zip(df.index, df["path"], label_ints), total=len(df)
    ):
        try:
            # Decode and re-encode, so every stored sample is a JPEG produced by imageio.
            # NOTE(review): JPEG re-encoding is lossy. If the source files are
            # already valid JPEGs, `Path(image_path).read_bytes()` would avoid it.
            image = iio.imread(image_path, extension=".jpg")
            image_bytes = iio.imwrite("<bytes>", image, extension=".jpg")
        except Exception as err:
            # Skip the sample entirely. Falling through would write the previous
            # row's `image_bytes` under this row's label (or fail with NameError
            # on the first row).
            failed.append((idx, err))
            continue

        sample = {
            "image": image_bytes,
            "label": int(label_int),
        }
        out.write(sample)

if failed:
    print(f"Skipped {len(failed)} unreadable image(s):")
    for idx, err in failed:
        print(f"  {idx}: {df.at[idx, 'path']} ({err!r})")

100%|██████████| 3925/3925 [00:15<00:00, 245.79it/s]
