In [77]:
import pandas as pd

def order_col(dataframe):
    print(dataframe.sum().sort_values(ascending=False))

def total_count(dataframe, show_count = True, delete=False):
    dataframe['total'] = dataframe.sum(axis=1,numeric_only = True)
    if show_count == True:
        print(dataframe.groupby(by = 'total').count())
    if delete == True:
        print("Before length :", len(dataframe))
        return dataframe.drop(dataframe[dataframe.total == 0].index)
        print("After length :", len(dataframe))
        
directory = '/home/joel/Downloads/'
df_upper = pd.read_csv(directory + 'anno/df_upper.csv')
df_lower = pd.read_csv(directory + 'anno/df_lower.csv')
df_upper = total_count(df_upper, show_count = False, delete = True)
df_lower = total_count(df_lower, show_count = False, delete = True)

def drop_lower(dataframe, min_attributes):
    del dataframe['total']
    dataframe['total'] = dataframe.sum(axis=1)
    return dataframe.drop(dataframe[dataframe.total < min_attributes].index)

def oversample(df_undersampled, df_full, col_name, min_num):
    # Add one-hot rows only. Failing to do this will result in >3000 samples.
    all_extra_rows = df_full[(df_full[col_name]==1) & (df_full['total']==1)] 
    current_count = len(df_undersampled[df_undersampled[col_name] == 1].index)
    num_rows_to_add = min_num - current_count
    if num_rows_to_add < 1:
        return df_undersampled
    all_extra_rows = all_extra_rows[:num_rows_to_add]
    df_undersampled = df_undersampled.append(all_extra_rows)
    print('{}: length {}, oversample appending {} for a final of {}'.format(col_name,current_count,len(all_extra_rows.index),len(df_undersampled.index)))
    return df_undersampled

def undersample(df_oversampled, col_name, target_num):
    oversampled_rows = df_oversampled[df_oversampled[col_name] == 1]
    undersampled = oversampled_rows.sample(n=target_num)
    print('{}: length {}, undersampling for a final of {}'.format(col_name,len(oversampled_rows.index),len(undersampled.index)))
    return undersampled

def balance_dataframe(dataframe, num_samples, dataset):
    full_set = df_upper if (dataset == 'upper') else df_lower
    for column in dataframe:
        if 'total' not in column and 'image_name' not in column:
            if len(dataframe[dataframe[column] == 1]) < num_samples:
                dataframe = oversample(dataframe, full_set, column, num_samples)
            else:
                dataframe = undersample(dataframe, column, num_samples)
    return dataframe

df_upper_trunc = drop_lower(df_upper, 3)
df_lower_trunc = drop_lower(df_lower, 2)

Before length : 139709
Before length : 58963


In [78]:
df_upper_trunc = balance_dataframe(df_upper_trunc, 2500, dataset='upper')

paisley: length 426, oversample appending 171 for a final of 12202
graphic: length 4692, undersampling for a final of 2500
palm: length 34, oversample appending 192 for a final of 2692
zigzag: length 14, oversample appending 45 for a final of 2737
floral: length 947, oversample appending 1553 for a final of 4290
stripe: length 276, oversample appending 2224 for a final of 6514
abstract: length 378, oversample appending 374 for a final of 6888
animal: length 208, oversample appending 159 for a final of 7047
tribal: length 248, oversample appending 275 for a final of 7322
dot: length 157, oversample appending 903 for a final of 8225
knit: length 187, oversample appending 2313 for a final of 10538
denim: length 68, oversample appending 1420 for a final of 11958
leather: length 43, oversample appending 1908 for a final of 13866
lace: length 172, oversample appending 2328 for a final of 16194
pleated: length 80, oversample appending 401 for a final of 16595
fur: length 16, oversample append

In [79]:
df_lower_trunc = balance_dataframe(df_lower_trunc, 2000, dataset='lower')

paisley: length 552, oversample appending 113 for a final of 20387
graphic: length 6208, undersampling for a final of 2000
palm: length 65, oversample appending 29 for a final of 2029
zigzag: length 10, oversample appending 51 for a final of 2080
floral: length 593, oversample appending 1222 for a final of 3302
stripe: length 82, oversample appending 1263 for a final of 4565
abstract: length 201, oversample appending 86 for a final of 4651
animal: length 165, oversample appending 68 for a final of 4719
tribal: length 283, oversample appending 158 for a final of 4877
dot: length 35, oversample appending 571 for a final of 5448
wash: length 21, oversample appending 1368 for a final of 6816
pleated: length 167, oversample appending 1011 for a final of 7827
ripped: length 25, oversample appending 998 for a final of 8825
knit: length 53, oversample appending 1003 for a final of 9828
denim: length 159, oversample appending 1841 for a final of 11669
leather: length 29, oversample appending 10

In [80]:
df_upper_trunc.drop(columns=['abstract','fur','collarless','tribal','pleated','long-sleeve','animal','palm','paisley','zigzag','total'], inplace=True)

In [81]:
df_lower_trunc.drop(columns=['zipper','midi','tribal','high-waist','cuffed','pocket','abstract','animal','button','paisley','palm','zigzag','total'], inplace=True)

In [82]:
total_count(df_upper_trunc)
total_count(df_lower_trunc)
df_upper_trunc = df_upper_trunc[df_upper_trunc['total']!=0]
df_lower_trunc = df_lower_trunc[df_lower_trunc['total']!=0]
del df_upper_trunc['total']
del df_lower_trunc['total']

       image_name  graphic  floral  stripe    dot   knit  denim  leather  \
total                                                                      
0            2578     2578    2578    2578   2578   2578   2578     2578   
1           29824    29824   29824   29824  29824  29824  29824    29824   
2             818      818     818     818    818    818    818      818   
3            1205     1205    1205    1205   1205   1205   1205     1205   
4             326      326     326     326    326    326    326      326   
5              68       68      68      68     68     68     68       68   
6              13       13      13      13     13     13     13       13   
7               2        2       2       2      2      2      2        2   

        lace  chiffon   ...    sheer  cotton  sleeve  sleeveless  collar  \
total                   ...                                                
0       2578     2578   ...     2578    2578    2578        2578    2578   
1      2982

In [83]:
df_upper_trunc.sort_index(inplace=True)
df_lower_trunc.sort_index(inplace=True)

In [84]:
df_upper_trunc.to_csv(directory + 'anno/final_upper.csv', index=False)
df_lower_trunc.to_csv(directory + 'anno/final_lower.csv', index=False)