In [67]:
import os
import json
import random
import pandas as pd
import shutil

In [68]:
!mkdir prepare_data/img_clf_multilabel_lst/train_imgs
!mkdir prepare_data/img_clf_multilabel_lst/valid_imgs
!mkdir prepare_data/img_clf_multilabel_lst/train_annots
!mkdir prepare_data/img_clf_multilabel_lst/valid_annots

In [69]:
gt_df = pd.read_csv(
    "prepare_data/img_clf_multilabel_lst/all_images_gt/clf_labels.csv", header=None
)
gt_df.columns = ["image", "class"]
gt_df.head()

Unnamed: 0,image,class
0,1.jpg,"[human,hair,interior]"
1,2.jpg,"[human,interior]"
2,3.jpg,[interior]
3,4.jpg,"[human,hair,interior]"
4,5.jpg,"[human,interior]"


In [70]:
train_valid_split = 0.7
nimages = gt_df["image"].nunique()
ntrain = int(train_valid_split * nimages)
nvalid = nimages - ntrain
print(nimages, ntrain, nvalid)

11 7 4


In [71]:
gt_df["class"] = gt_df["class"].apply(lambda x: x[1:-1])
gt_df.head()

Unnamed: 0,image,class
0,1.jpg,"human,hair,interior"
1,2.jpg,"human,interior"
2,3.jpg,interior
3,4.jpg,"human,hair,interior"
4,5.jpg,"human,interior"


In [72]:
labels = ["human", "hair", "interior"]
labels_map = {k: v for v, k in enumerate(labels)}
with open("prepare_data/img_clf_multilabel_lst/labels_map.json", "w") as fp:
    json.dump(labels_map, fp)
labels_map

{'human': 0, 'hair': 1, 'interior': 2}

In [73]:
def function(x):
    xl = x.split(",")
    xll = []
    for lbl in labels:
        if lbl in xl:
            xll.append(1)
        else:
            xll.append(0)
    return xll

In [74]:
gt_df["class2"] = gt_df["class"].apply(function)
gt_df.head()

Unnamed: 0,image,class,class2
0,1.jpg,"human,hair,interior","[1, 1, 1]"
1,2.jpg,"human,interior","[1, 0, 1]"
2,3.jpg,interior,"[0, 0, 1]"
3,4.jpg,"human,hair,interior","[1, 1, 1]"
4,5.jpg,"human,interior","[1, 0, 1]"


In [75]:
gt_df[labels] = gt_df["class2"].apply(pd.Series)
gt_df.head()

Unnamed: 0,image,class,class2,human,hair,interior
0,1.jpg,"human,hair,interior","[1, 1, 1]",1,1,1
1,2.jpg,"human,interior","[1, 0, 1]",1,0,1
2,3.jpg,interior,"[0, 0, 1]",0,0,1
3,4.jpg,"human,hair,interior","[1, 1, 1]",1,1,1
4,5.jpg,"human,interior","[1, 0, 1]",1,0,1


In [76]:
gt_df = gt_df.sample(frac=1).reset_index(drop=True)
gt_df.head()

Unnamed: 0,image,class,class2,human,hair,interior
0,5.jpg,"human,interior","[1, 0, 1]",1,0,1
1,3.jpg,interior,"[0, 0, 1]",0,0,1
2,6.jpg,"human,hair,interior","[1, 1, 1]",1,1,1
3,1.jpg,"human,hair,interior","[1, 1, 1]",1,1,1
4,8.jpg,"human,hair,interior","[1, 1, 1]",1,1,1


In [77]:
gt_df["index"] = gt_df.index + 1
gt_df.head()

Unnamed: 0,image,class,class2,human,hair,interior,index
0,5.jpg,"human,interior","[1, 0, 1]",1,0,1,1
1,3.jpg,interior,"[0, 0, 1]",0,0,1,2
2,6.jpg,"human,hair,interior","[1, 1, 1]",1,1,1,3
3,1.jpg,"human,hair,interior","[1, 1, 1]",1,1,1,4
4,8.jpg,"human,hair,interior","[1, 1, 1]",1,1,1,5


In [78]:
sel_cols = ["index"] + labels + ["image"]
print(sel_cols)

['index', 'human', 'hair', 'interior', 'image']


In [79]:
gt_df[sel_cols].head(ntrain).to_csv(
    "prepare_data/img_clf_multilabel_lst/train_annots/train.lst",
    sep="\t",
    index=False,
    header=False,
)
!head -n 5 prepare_data/img_clf_multilabel_lst/train_annots/train.lst

1	1	0	1	5.jpg
2	0	0	1	3.jpg
3	1	1	1	6.jpg
4	1	1	1	1.jpg
5	1	1	1	8.jpg


In [80]:
gt_df[sel_cols].tail(nvalid).to_csv(
    "prepare_data/img_clf_multilabel_lst/valid_annots/valid.lst",
    sep="\t",
    index=False,
    header=False,
)
!head -n 5 prepare_data/img_clf_multilabel_lst/valid_annots/valid.lst

8	1	0	1	2.jpg
9	1	1	1	9.jpg
10	1	1	1	11.jpg
11	0	0	1	7.jpg


In [81]:
train_df = pd.read_csv(
    "prepare_data/img_clf_multilabel_lst/train_annots/train.lst", sep="\t", header=None
)
images = list(train_df[train_df.columns[-1]].values)
for image in images:
    shutil.copy(
        "prepare_data/img_clf_multilabel_lst/all_images_gt/" + image,
        "prepare_data/img_clf_multilabel_lst/train_imgs/",
    )

In [82]:
valid_df = pd.read_csv(
    "prepare_data/img_clf_multilabel_lst/valid_annots/valid.lst", sep="\t", header=None
)
print(valid_df)
images = list(valid_df[valid_df.columns[-1]].values)
for image in images:
    # print image
    print(image)
    shutil.copy(
        "prepare_data/img_clf_multilabel_lst/all_images_gt/" + image,
        "prepare_data/img_clf_multilabel_lst/valid_imgs/",
    )

    0  1  2  3       4
0   8  1  0  1   2.jpg
1   9  1  1  1   9.jpg
2  10  1  1  1  11.jpg
3  11  0  0  1   7.jpg
2.jpg
9.jpg
11.jpg
7.jpg
