In [1]:
import numpy as np
import networkx as nx

In [2]:
def get_validation_folds(df, nfolds=5, random_state=42):
    """
    Function to create validation folds. Split not only by label group, but also by title, image, phash
    """
    np.random.seed(random_state)
    G = nx.Graph()

    for col in ['label_group', 'title', 'image_phash', 'image']:

        agg = df.groupby(col)['posting_id'].agg(list).tolist()
        for p in agg:
            nx.add_path(G, p)

    cc = {}
    for n, c in enumerate(nx.connected_components(G)):
        val = min(c)
        for x in c:
            cc[x] = val

    group_idx = df['posting_id'].map(cc).values
    groups = np.unique(group_idx)
    np.random.shuffle(groups)

    split = np.array_split(groups, nfolds)

    folds = np.zeros(df.shape[0], dtype=np.int32)

    for n, s in enumerate(split):
        folds[np.isin(group_idx, s)] = n

    return folds


In [3]:
import pandas as pd

# DATA_PATH = "../input/shopee-product-matching/"
DATA_PATH = "../input/"

train = pd.read_csv(DATA_PATH + "train.csv")
# train["matches"] = train.label_group.map(
#     train.groupby("label_group").posting_id.agg("unique").to_dict()
# )
# train["matches"] = train["matches"].apply(lambda x: " ".join(x))
# train["image"] = DATA_PATH + "train_images/" + train["image"]

train.head()

Unnamed: 0,posting_id,image,image_phash,title,label_group,matches
0,train_129225211,../input/train_images/0000a68812bc7e98c42888df...,94974f937d4c2433,Paper Bag Victoria Secret,249114794,train_129225211 train_2278313361
1,train_3386243561,../input/train_images/00039780dfc94d01db8676fe...,af3f9460c2838f0f,"Double Tape 3M VHB 12 mm x 4,5 m ORIGINAL / DO...",2937985045,train_3386243561 train_3423213080
2,train_2288590299,../input/train_images/000a190fdd715a2a36faed16...,b94cb00ed3e50f78,Maling TTS Canned Pork Luncheon Meat 397 gr,2395904891,train_2288590299 train_3803689425
3,train_2406599165,../input/train_images/00117e4fc239b1b641ff0834...,8514fc58eafea283,Daster Batik Lengan pendek - Motif Acak / Camp...,4093212188,train_2406599165 train_3342059966
4,train_3369186413,../input/train_images/00136d1cf4edede0203f32f0...,a6f319f924ad708c,Nescafe \xc3\x89clair Latte 220ml,3648931069,train_3369186413 train_921438619


In [7]:
a=get_validation_folds(train)
print(type(a),a[:100],a.shape)

<class 'numpy.ndarray'> [1 1 0 2 4 4 2 4 4 1 0 2 2 1 0 0 4 4 0 3 3 2 0 3 1 4 3 3 2 2 3 0 4 4 3 3 4
 4 3 1 0 2 4 0 4 1 2 0 4 4 4 1 2 0 3 4 4 4 2 2 1 4 0 1 4 1 1 3 4 2 1 3 2 2
 0 1 3 4 2 1 1 2 2 0 1 4 3 2 0 0 1 0 2 3 3 2 2 4 1 1] (34250,)


In [13]:
idx = np.where(a == 0)[0]
train_min = train.loc[idx]
train_min.to_csv(DATA_PATH+"train_min.csv")

[    2    10    14 ... 34241 34247 34248]


KeyError: 2