In [13]:
import numpy as np
import os
import pandas as pd

# Open-set split definition for FGVC-Aircraft; the next cell reads its
# 'known_classes' and 'unknown_classes' entries.
# NOTE: allow_pickle=True unpickles arbitrary objects -- only load this file
# from a trusted source.
OSR_SPLITS_PATH = "aircraft_osr_splits.pkl"
class_info = np.load(OSR_SPLITS_PATH, allow_pickle=True)

In [63]:
# Known (closed-set) classes, remapped to contiguous labels 0..K-1 in sorted order.
train_classes = class_info['known_classes']
train_classes_new_index = {
    cls_orig: cls_new for cls_new, cls_orig in enumerate(np.unique(train_classes))
}

# Unknown (open-set) classes grouped by difficulty. "all" is the union that the
# export loop uses to drop unknown classes from the clean train/val/test lists.
open_set_classes = class_info['unknown_classes']
open_set_classes_dict = {
    # list() makes `+` concatenate regardless of container type; if the pickle
    # stored numpy arrays, a bare `+` would add them element-wise instead.
    "all": (
        list(open_set_classes['Hard'])
        + list(open_set_classes['Medium'])
        + list(open_set_classes['Easy'])
    ),
    "easy": open_set_classes['Easy'],
    "medium": open_set_classes['Medium'],
    "hard": open_set_classes['Hard'],
}

In [69]:
# Seed numpy's global RNG: get_train_val_split draws validation indices with
# np.random.choice, so this fixes the train/val split across runs.
SPLIT_SEED = 2
np.random.seed(SPLIT_SEED)

def find_classes(classes_file):
    """Parse an annotation file of ``<image_id> <class name>`` lines.

    The class name is everything after the first space (it may itself contain
    spaces). Unique class names are sorted with ``np.unique`` and each is
    assigned its position in that order.

    Args:
        classes_file: path to the annotation text file.

    Returns:
        ``(image_ids, targets, classes, class_to_idx)``:
        ``image_ids`` -- list of id strings;
        ``targets`` -- list of int class indices aligned with ``image_ids``;
        ``classes`` -- sorted array of unique class names;
        ``class_to_idx`` -- dict mapping each class name to its index.
    """
    image_ids = []
    names = []
    with open(classes_file, 'r') as f:
        for line in f:
            # Strip the line terminator: otherwise every name ends in '\n' and a
            # final line without a trailing newline becomes a separate class.
            line = line.rstrip('\r\n')
            if not line.strip():
                # A blank line would otherwise add a bogus '' class that sorts
                # first and shifts every other class index by one.
                continue
            image_id, _, name = line.partition(' ')
            image_ids.append(image_id)
            names.append(name)

    # Index class names in sorted order (dropping the common '\n' suffix does
    # not change their relative order, so indices match the old scheme).
    classes = np.unique(names)
    class_to_idx = {name: i for i, name in enumerate(classes)}
    targets = [class_to_idx[name] for name in names]

    return (image_ids, targets, classes, class_to_idx)

def subsample_dataset(dataset, idxs):
    """Keep only the rows of an ``(imgs, targets)`` dataset whose positions are in ``idxs``.

    Rows keep their original order (the order of ``idxs`` is ignored) and a
    position listed several times selects its row once.

    Args:
        dataset: ``(imgs, targets)`` pair of equal-length sequences.
        idxs: iterable of integer positions (Python or numpy ints).

    Returns:
        ``(imgs_sub, targets_sub)`` as two lists.
    """
    imgs, targets = dataset
    # Set membership is O(1); scanning a list for every row was O(len(dataset) * len(idxs)),
    # and the original did that scan twice.
    keep = set(idxs)

    imgs_sub, targets_sub = [], []
    for i, (img, target) in enumerate(zip(imgs, targets)):
        if i in keep:
            imgs_sub.append(img)
            targets_sub.append(target)

    return (imgs_sub, targets_sub)

def get_train_val_split(image_ids, targets, val_split=0.2, verbose=True):
    """Split an image list into per-class (stratified) train/val subsets.

    For every class, ``int(val_split * n_class)`` positions are drawn without
    replacement with ``np.random.choice`` (numpy's global RNG, so the result
    depends on the seed set beforehand) and go to validation; the remaining
    positions go to training.

    Args:
        image_ids: sequence of image identifiers aligned with ``targets``.
        targets: sequence (list or array) of class indices.
        val_split: fraction of each class held out for validation.
        verbose: when True, print the number of examples per class.

    Returns:
        ``(train_dataset, val_dataset)``, each an ``(image_ids, targets)`` pair of
        lists produced by ``subsample_dataset`` (rows in original file order).
    """
    # Compare as an array: `list == scalar` only broadcasts when the scalar is a
    # numpy type and silently evaluates to False for a plain Python int.
    targets_arr = np.asarray(targets)

    train_idxs = []
    val_idxs = []
    for cls in np.unique(targets_arr):
        cls_idxs = np.where(targets_arr == cls)[0]
        if verbose:
            print(f"{cls} has {len(cls_idxs)} examples")

        # Keep this exact call and class order so a given seed reproduces the same split.
        v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
        t_ = cls_idxs[~np.isin(cls_idxs, v_)]

        train_idxs.extend(t_)
        val_idxs.extend(v_)

    dataset = (image_ids, targets)
    train_dataset = subsample_dataset(dataset, train_idxs)
    val_dataset = subsample_dataset(dataset, val_idxs)

    return train_dataset, val_dataset

# Label granularity of the FGVC-Aircraft annotation files to export.
class_type = 'variant'

# Train/val come from a seeded per-class split of the official trainval annotations.
classes_file = os.path.join('data', 'images_%s_%s.txt' % (class_type, "trainval"))
(image_ids, targets, classes, class_to_idx) = find_classes(classes_file)
datasets = dict()
datasets["train"], datasets["val"] = get_train_val_split(image_ids, targets)

# The official test annotations feed the clean test list and all three OOD
# lists; parse them once instead of once per split.
classes_file = os.path.join('data', 'images_%s_%s.txt' % (class_type, "test"))
(image_ids, targets, classes, class_to_idx) = find_classes(classes_file)
datasets["test"] = (image_ids, targets)

CLEAN_SPLITS = ["train", "val", "test"]
OOD_SPLITS = ["ood_easy", "ood_medium", "ood_hard"]
OUT_DIR = "../../benchmark_imglist/fgvc-aircraft"

for split in CLEAN_SPLITS + OOD_SPLITS:
    is_ood = split.startswith("ood")
    difficulty = split.split("_")[-1] if is_ood else None
    split_fgvc = "test" if is_ood else split
    split_ids, split_targets = datasets[split_fgvc]

    res_list = []
    for image_id, idx in zip(split_ids, split_targets):
        if is_ood:
            # OOD list: only unknown classes of this difficulty, labelled -1.
            if idx not in open_set_classes_dict[difficulty]:
                continue
            label = -1
        else:
            # Clean list: drop every unknown class, remap known ones to 0..K-1.
            if idx in open_set_classes_dict["all"]:
                continue
            label = train_classes_new_index[idx]
        res_list.append({
            'image_path': f"fgvc-aircraft-2013b/data/images/{image_id}.jpg",
            'label': label,
        })

    suffix = difficulty if is_ood else "clean"

    # Explicit columns keep the checks below working even if no rows survive the filter.
    df = pd.DataFrame(res_list, columns=['image_path', 'label'])
    print(f"{df['label'].nunique()} unique classes in {split}")
    assert df['image_path'].is_unique
    df.to_csv(f"{OUT_DIR}/{split.split('_')[0]}_fgvc-{class_type}_{suffix}.txt", sep=" ", header=False, index=False)

0 has 67 examples
1 has 67 examples
2 has 66 examples
3 has 67 examples
4 has 67 examples
5 has 66 examples
6 has 67 examples
7 has 67 examples
8 has 66 examples
9 has 67 examples
10 has 67 examples
11 has 66 examples
12 has 67 examples
13 has 67 examples
14 has 66 examples
15 has 67 examples
16 has 67 examples
17 has 66 examples
18 has 67 examples
19 has 67 examples
20 has 66 examples
21 has 67 examples
22 has 67 examples
23 has 66 examples
24 has 67 examples
25 has 67 examples
26 has 66 examples
27 has 67 examples
28 has 67 examples
29 has 66 examples
30 has 67 examples
31 has 67 examples
32 has 66 examples
33 has 67 examples
34 has 67 examples
35 has 66 examples
36 has 67 examples
37 has 67 examples
38 has 66 examples
39 has 67 examples
40 has 67 examples
41 has 66 examples
42 has 67 examples
43 has 67 examples
44 has 66 examples
45 has 67 examples
46 has 67 examples
47 has 66 examples
48 has 67 examples
49 has 67 examples
50 has 66 examples
51 has 67 examples
52 has 67 examples
53 