In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from imet.dataset import DATA_ROOT
from imet.utils import mean_df, binarize_prediction
from imet.make_submission import get_classes

In [61]:
ZOO_ROOT = Path('zoo')
threshold = 0.10

In [81]:
train_df = pd.read_csv(DATA_ROOT / 'train.csv')

In [82]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 109237 entries, 0 to 109236
Data columns (total 2 columns):
id               109237 non-null object
attribute_ids    109237 non-null object
dtypes: object(2)
memory usage: 1.7+ MB


In [83]:
model = 'se_resnext50_32x4d'

In [101]:
df = None
for fold in range(5):
    tmp_df = pd.read_hdf(ZOO_ROOT / f'model_{model}_fold_{fold}' / 'val.h5', index_col='id')
    if df is None:
        df = tmp_df
    else:
        df = df.append(tmp_df)

In [102]:
df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 109237 entries, 10041eb49b297c08 to fffdae8164c9cfff
Columns: 1103 entries, 0 to 1102
dtypes: float32(1103)
memory usage: 460.5+ MB


In [103]:
df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1093,1094,1095,1096,1097,1098,1099,1100,1101,1102
10041eb49b297c08,1.723427e-10,1.704131e-08,1.454852e-09,5.83591e-12,9.421575e-09,1.872938e-09,5.125733e-10,5.414047e-13,3.235597e-11,3.278363e-10,...,0.011062,5.909484e-07,0.000786,7.322479e-05,0.002841,0.0001,0.001887,2.247792e-08,2.991703e-08,2.68025e-05
1007057734dba6df,1.455984e-10,4.148521e-07,1.922179e-10,6.008871e-10,1.873635e-09,6.875585e-10,9.077812e-10,3.69594e-10,3.435334e-09,2.566108e-07,...,0.026436,3.095358e-07,0.00207,5.612769e-05,0.005297,0.000116,0.000678,1.204266e-06,1.692224e-08,1.100435e-07
100a58282c6584bf,2.503627e-11,5.098246e-08,4.157186e-11,1.213419e-13,4.982188e-11,1.634059e-10,4.242457e-11,4.214602e-13,3.827386e-11,5.455774e-11,...,0.000564,5.482831e-07,0.013106,0.0004580785,0.002063,1.7e-05,0.002073,7.649877e-08,2.538395e-07,5.30874e-07
100b45b7c4020f5d,1.589207e-09,5.966633e-05,7.230917e-05,3.190559e-10,7.631556e-08,3.042795e-06,1.421423e-08,5.688424e-08,1.241836e-06,8.693595e-07,...,0.00059,2.509743e-07,2.6e-05,0.000819398,0.000122,0.000641,0.003939,1.450188e-05,0.0003727955,0.000130341
100e1e65a6d7850e,1.662035e-12,2.997671e-09,2.298329e-10,3.267391e-14,3.370566e-14,2.734924e-13,6.473043e-15,4.096836e-15,1.670104e-11,4.519342e-10,...,0.001607,1.528915e-07,3e-06,2.354377e-08,6e-06,6e-06,6e-05,1.586041e-10,6.811095e-08,2.727214e-10


In [104]:
df.rename_axis(['id'], inplace=True)

In [88]:
df[:] = binarize_prediction(df.values, threshold=0.25)
df = df.apply(get_classes, axis=1)
df.name = 'attribute_ids'

In [89]:
df.to_csv('pseudo_train.csv', header=True)

In [105]:
df

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,1093,1094,1095,1096,1097,1098,1099,1100,1101,1102
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
10041eb49b297c08,1.723427e-10,1.704131e-08,1.454852e-09,5.835910e-12,9.421575e-09,1.872938e-09,5.125733e-10,5.414047e-13,3.235597e-11,3.278363e-10,...,1.106209e-02,5.909484e-07,7.862974e-04,7.322479e-05,2.841242e-03,1.004343e-04,1.886506e-03,2.247792e-08,2.991703e-08,2.680250e-05
1007057734dba6df,1.455984e-10,4.148521e-07,1.922179e-10,6.008871e-10,1.873635e-09,6.875585e-10,9.077812e-10,3.695940e-10,3.435334e-09,2.566108e-07,...,2.643636e-02,3.095358e-07,2.069922e-03,5.612769e-05,5.296701e-03,1.160702e-04,6.779028e-04,1.204266e-06,1.692224e-08,1.100435e-07
100a58282c6584bf,2.503627e-11,5.098246e-08,4.157186e-11,1.213419e-13,4.982188e-11,1.634059e-10,4.242457e-11,4.214602e-13,3.827386e-11,5.455774e-11,...,5.644767e-04,5.482831e-07,1.310554e-02,4.580785e-04,2.062830e-03,1.667958e-05,2.073072e-03,7.649877e-08,2.538395e-07,5.308740e-07
100b45b7c4020f5d,1.589207e-09,5.966633e-05,7.230917e-05,3.190559e-10,7.631556e-08,3.042795e-06,1.421423e-08,5.688424e-08,1.241836e-06,8.693595e-07,...,5.898916e-04,2.509743e-07,2.644986e-05,8.193980e-04,1.219429e-04,6.411059e-04,3.939425e-03,1.450188e-05,3.727955e-04,1.303410e-04
100e1e65a6d7850e,1.662035e-12,2.997671e-09,2.298329e-10,3.267391e-14,3.370566e-14,2.734924e-13,6.473043e-15,4.096836e-15,1.670104e-11,4.519342e-10,...,1.607454e-03,1.528915e-07,3.048929e-06,2.354377e-08,5.848477e-06,5.961645e-06,6.035583e-05,1.586041e-10,6.811095e-08,2.727214e-10
101534933c122e23,1.306989e-09,1.937224e-07,2.520219e-10,2.695739e-12,8.352945e-09,2.477446e-08,1.733528e-08,1.282597e-10,3.272811e-09,8.531349e-09,...,2.756706e-03,3.261326e-07,1.346484e-05,6.816162e-04,5.103823e-05,1.166278e-05,9.545936e-05,1.562538e-04,1.027488e-05,4.649299e-07
1015ddcd27215ca6,1.608890e-08,4.574718e-05,1.364199e-05,4.684250e-10,2.310416e-06,1.642629e-06,4.902663e-08,2.264966e-09,3.356564e-05,3.775984e-06,...,6.139348e-05,7.240926e-08,6.525640e-06,1.044897e-05,2.896699e-05,7.345367e-03,8.156691e-04,8.880237e-06,3.251761e-05,1.162731e-04
101695e8cefdc9c4,8.273845e-10,1.070131e-06,2.743984e-10,2.115489e-12,7.414654e-13,2.959939e-11,3.547805e-12,8.357928e-12,4.977346e-11,4.320344e-09,...,6.667275e-04,3.747841e-06,3.547445e-04,1.097397e-05,4.458516e-06,3.981258e-06,9.409302e-05,5.823569e-07,2.141962e-07,2.939831e-09
1016dc1b22073bb3,9.223753e-10,2.533087e-05,2.166093e-10,1.303950e-09,1.040513e-05,7.454384e-06,1.454086e-05,2.412928e-06,2.002835e-06,1.411544e-07,...,1.081820e-04,4.880739e-07,1.122141e-03,6.575591e-05,1.223102e-03,1.447054e-04,2.127226e-03,1.381132e-04,5.491436e-07,7.716962e-06
101eb6f95ea25ab0,2.020162e-13,8.720494e-07,6.626315e-11,1.263933e-13,8.243809e-14,1.206380e-14,9.099496e-15,2.683244e-14,1.558644e-12,4.579922e-11,...,5.026710e-08,6.015639e-12,1.092218e-09,7.798313e-07,2.325759e-08,1.631642e-05,1.044889e-05,7.977787e-12,5.897929e-07,8.536344e-08


In [90]:
df = pd.read_csv('pseudo_train.csv')

In [91]:
pseudo_df = pd.merge(train_df, df, on='id')

In [92]:
pseudo_df

Unnamed: 0,id,attribute_ids_x,attribute_ids_y
0,1000483014d91860,147 420 813 1093 952 616,147 813 952
1,1000fe2e667721fe,501 156 734 51 813 776 573 616,51 519 616 734 813
2,1001614cb89646ee,776 483 1046 690,483 1046
3,10041eb49b297c08,51 698 671 492 813 1092 616,51 616 813 1092
4,100501c227f8beea,405 1092 404 896 492 903 1093 13,13 813 903 1092
5,10050ed12fbad46d,189 953 800 279 378 721 774 1051,189 279 378 800
6,100543a032517972,369 188 1034,188 1034
7,1006665c0aad488,1059 194 1034 1053 557 179 1010 335 253,1010
8,1007057734dba6df,189 70 1012 542 993 906 541 813 1092,189 259 541 813 1092
9,1008abd71f3ed5bc,70 1046 676 794 111 813 1092 776,70 111 776 1046


In [93]:
def merge_attributes(row):
    row['attribute_ids'] = ' '.join(set(row['attribute_ids'].split()))
    return row

In [94]:
pseudo_df['attribute_ids'] = pseudo_df['attribute_ids_x'] + ' ' + pseudo_df['attribute_ids_y']

In [95]:
pseudo_df = pseudo_df.apply(merge_attributes, axis=1)

In [96]:
pseudo_df.drop(['attribute_ids_x', 'attribute_ids_y'], inplace=True, axis=1)

In [97]:
pseudo_df

Unnamed: 0,id,attribute_ids
0,1000483014d91860,147 616 952 1093 813 420
1,1000fe2e667721fe,573 616 519 813 51 156 734 501 776
2,1001614cb89646ee,483 776 690 1046
3,10041eb49b297c08,616 1092 698 813 492 51 671
4,100501c227f8beea,404 1092 1093 896 492 903 405 813 13
5,10050ed12fbad46d,800 721 1051 774 953 279 189 378
6,100543a032517972,188 369 1034
7,1006665c0aad488,179 1053 253 1059 335 194 557 1010 1034
8,1007057734dba6df,993 542 1092 259 541 70 1012 813 906 189
9,1008abd71f3ed5bc,1046 1092 794 70 776 813 676 111


In [98]:
pseudo_df.to_csv('pseudo_train_0.25.csv', index=None)

In [None]:
pd.read_csv('pseudo_train_0.25.csv')