In [2]:
def preprocess_photo(dset):
    dropped = dset.drop(columns=['ra', 'dec', 'type', 'mjd'])
    renamed = dropped.rename(columns={'objType': 'type'})
    del dropped
    cols = renamed.columns
    cols_rearranged = [[cols[1]], [cols[0]], cols[7:12], cols[17:], cols[2:7], cols[12:17]]
    cols_rearranged = pd.Index([item for sublist in cols_rearranged for item in sublist])
    renamed = renamed[cols_rearranged]
    renamed['id'] = renamed.index
    final_cols = [renamed.columns[-1]]
    final_cols.extend(renamed.columns[:-1])
    final_cols = pd.Index(final_cols)
    reordered = renamed[final_cols]
    del renamed
    filtered = reordered[~reordered['type'].isin(['NONLEGACY', 'QA', 'HOT_STD'])]
    del reordered
    final_cols2 = ['id', 'type', 'fiberID']
    final_cols2.extend(filtered.columns[3:])
    final_dset = filtered[final_cols2]
    del filtered
    return final_dset

def get_ratios(dataset, target_col):
    cnt_dict = {type_key: 0 for type_key in dataset[target_col].unique()}
    for t in dataset[target_col]:
        cnt_dict[t] += 1
    total = len(dataset)
    ratio_dict = dict()
    for k in cnt_dict.keys():
        cnt = cnt_dict[k]
        ratio_dict[k] = cnt / total
    return ratio_dict

In [9]:
import pandas as pd
import math
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

In [5]:
p5 = pd.read_csv('./SDSSPhotoObjDR8Attempt5_etture_0.csv')

In [6]:
p5_df = preprocess_photo(p5)

In [8]:
sample_nums = {'QSO': 50000,
 'STAR_RED_DWARF': 10635,
 'SERENDIPITY_BLUE': 22869,
 'STAR_BHB': 14198,
 'STAR_CATY_VAR': 6832,
 'SERENDIPITY_DISTANT': 4897,
 'GALAXY': 40000,
 'SPECTROPHOTO_STD': 15380,
 'REDDEN_STD': 15355,
 'ROSAT_D': 6539,
 'STAR_WHITE_DWARF': 2272,
 'SERENDIPITY_RED': 1543,
 'STAR_CARBON': 3382,
 'SERENDIPITY_FIRST': 6726,
 'STAR_BROWN_DWARF': 258,
 'STAR_SUB_DWARF': 1149,
 'SKY': 200,
 'SERENDIPITY_MANUAL': 54,
 'STAR_PN': 14}

sample_divs = dict()
for k, v in sample_nums.items():
    train_cnt = math.ceil(v * 0.85)
    test_cnt = v - train_cnt
    sample_divs[k] = {
        'train': train_cnt,
        'test': test_cnt
    }

## Create train dataset

In [10]:
train_df = pd.DataFrame()
test_df = pd.DataFrame()
for k, v in sample_divs.items():
    train_tmp, test_tmp = train_test_split(
        p5_df[p5_df['type'] == k],
        train_size=v['train'], test_size=v['test'],
        random_state=123
    )
    train_df = train_df.append(train_tmp)
    test_df = test_df.append(test_tmp)
    del train_tmp
    del test_tmp

In [11]:
train_df.to_csv('ybigta_sdss_train.csv', index=False)

## Create test dataset

In [12]:
test_count_dict = {type_key: 0 for type_key in test_df['type'].unique()}
for t in test_df['type']:
    test_count_dict[t] += 1

sorted_test_count = sorted([(k, v) for k, v in test_count_dict.items()], key=lambda x: x[1])

validate_df = pd.DataFrame()
for k, v in test_count_dict.items():
    if v < 800:
        to_add = test_df[test_df['type'] == k]
    else:
        to_add = test_df[test_df['type'] == k].sample(n=800, replace=False, random_state=123)
    validate_df = validate_df.append(to_add)
    del to_add

In [39]:
test_df

Unnamed: 0,id,type,fiberID,psfMag_u,psfMag_g,psfMag_r,psfMag_i,psfMag_z,fiberMag_u,fiberMag_g,...,petroMag_u,petroMag_g,petroMag_r,petroMag_i,petroMag_z,modelMag_u,modelMag_g,modelMag_r,modelMag_i,modelMag_z
70379,70379,QSO,182,18.81052,18.32911,17.90038,17.70225,17.65288,19.15704,18.62413,...,18.85941,18.33458,17.93811,17.74766,17.68350,18.80268,18.32140,17.89795,17.69311,17.66096
863365,863365,QSO,207,22.99911,21.23738,19.89757,19.26474,19.88944,23.23582,21.56254,...,22.60821,21.30784,19.90107,19.31741,19.59066,23.08329,21.22320,19.89089,19.25760,19.78871
40381,40381,QSO,496,18.91604,18.69714,18.59921,18.68769,18.63109,19.21263,18.99961,...,18.92110,18.72866,18.63251,18.71346,18.47264,18.91443,18.68466,18.58225,18.67094,18.58222
1322598,1322598,QSO,441,19.22888,18.64079,18.68171,18.52602,18.22359,19.02119,18.28669,...,17.71678,16.83449,16.67583,16.49454,16.33526,17.69108,16.90666,16.79139,16.57578,16.47379
1201659,1201659,QSO,320,19.57746,19.57045,19.23320,19.02790,18.62246,19.87532,19.80541,...,19.15427,18.84896,18.01861,17.66363,17.29827,18.74826,18.57701,17.93407,17.64736,17.20101
1335771,1335771,QSO,57,19.26550,18.64677,18.32241,17.88121,17.75130,19.45834,18.73669,...,18.99858,18.26249,17.82439,17.41028,17.36499,18.91461,18.20422,17.77503,17.34980,17.29262
1564469,1564469,QSO,590,19.45684,19.12321,19.00831,18.98338,19.03146,19.77078,19.39687,...,19.50100,19.11691,19.02535,18.98158,19.04087,19.45539,19.08589,18.98612,18.94333,19.01933
376106,376106,QSO,127,23.59964,19.95862,19.47769,19.35825,19.10950,23.44202,20.27695,...,22.78849,19.92928,19.43834,19.29493,19.30431,23.64885,19.95266,19.47823,19.35339,19.12055
197635,197635,QSO,258,18.00722,16.17424,15.52764,15.29504,15.19117,18.29322,16.49043,...,18.08923,16.23658,15.57622,15.35958,15.26615,18.00967,16.17774,15.51632,15.29850,15.18514
311010,311010,QSO,273,19.18171,18.86552,18.63319,18.68903,18.09233,19.53741,19.22089,...,19.12451,18.86324,18.61445,18.68444,18.13084,19.14034,18.82939,18.59932,18.65805,18.05879


In [16]:
validate_df.head()

Unnamed: 0,id,type,fiberID,psfMag_u,psfMag_g,psfMag_r,psfMag_i,psfMag_z,fiberMag_u,fiberMag_g,...,petroMag_u,petroMag_g,petroMag_r,petroMag_i,petroMag_z,modelMag_u,modelMag_g,modelMag_r,modelMag_i,modelMag_z
22121,22121,QSO,524,20.25184,19.21797,18.81325,18.63427,18.57789,20.55951,19.53614,...,20.31346,19.24276,18.84517,18.65059,18.63834,20.24235,19.21357,18.80395,18.62018,18.57158
28799,28799,QSO,281,19.63154,19.48602,19.19663,18.75605,18.82092,19.9481,19.73573,...,19.6106,19.32184,18.93773,18.49994,18.31607,19.52231,19.29745,18.95885,18.51906,18.56078
1416396,1416396,QSO,166,19.16294,18.94821,18.77507,18.92679,18.89543,19.4714,19.30156,...,19.14995,19.01328,18.78268,18.94584,18.95756,19.13103,18.94504,18.72925,18.89673,18.877
229542,229542,QSO,38,19.28713,19.02645,18.71246,18.3028,18.24522,19.63397,19.34047,...,19.30062,19.02987,18.73833,18.31396,18.30046,19.25876,19.01132,18.70494,18.29297,18.2314
482490,482490,QSO,304,19.43952,19.22812,19.03674,18.79651,18.71855,19.7895,19.51578,...,19.48426,19.23326,19.1245,18.84784,18.83571,19.43872,19.20541,19.04823,18.79752,18.73385


In [34]:
type_dropped = validate_df.drop(columns=['type'])

In [35]:
type_dropped.head()

Unnamed: 0,id,fiberID,psfMag_u,psfMag_g,psfMag_r,psfMag_i,psfMag_z,fiberMag_u,fiberMag_g,fiberMag_r,...,petroMag_u,petroMag_g,petroMag_r,petroMag_i,petroMag_z,modelMag_u,modelMag_g,modelMag_r,modelMag_i,modelMag_z
22121,22121,524,20.25184,19.21797,18.81325,18.63427,18.57789,20.55951,19.53614,19.14778,...,20.31346,19.24276,18.84517,18.65059,18.63834,20.24235,19.21357,18.80395,18.62018,18.57158
28799,28799,281,19.63154,19.48602,19.19663,18.75605,18.82092,19.9481,19.73573,19.3936,...,19.6106,19.32184,18.93773,18.49994,18.31607,19.52231,19.29745,18.95885,18.51906,18.56078
1416396,1416396,166,19.16294,18.94821,18.77507,18.92679,18.89543,19.4714,19.30156,19.08113,...,19.14995,19.01328,18.78268,18.94584,18.95756,19.13103,18.94504,18.72925,18.89673,18.877
229542,229542,38,19.28713,19.02645,18.71246,18.3028,18.24522,19.63397,19.34047,19.07622,...,19.30062,19.02987,18.73833,18.31396,18.30046,19.25876,19.01132,18.70494,18.29297,18.2314
482490,482490,304,19.43952,19.22812,19.03674,18.79651,18.71855,19.7895,19.51578,19.38709,...,19.48426,19.23326,19.1245,18.84784,18.83571,19.43872,19.20541,19.04823,18.79752,18.73385


In [36]:
type_dropped.to_csv('ybigta_sdss_test.csv', index=False)

In [37]:
mytest = pd.read_csv('./ybigta_sdss_test.csv')

In [40]:
test_all_dropped = test_df.drop(columns=['type'])

In [52]:
test_all_dropped.to_csv('ybigta_sdss_test_full.csv', index=True)

In [53]:
mytest2 = pd.read_csv('./ybigta_sdss_test_full.csv')

In [54]:
mytest2.head()

Unnamed: 0.1,Unnamed: 0,id,fiberID,psfMag_u,psfMag_g,psfMag_r,psfMag_i,psfMag_z,fiberMag_u,fiberMag_g,...,petroMag_u,petroMag_g,petroMag_r,petroMag_i,petroMag_z,modelMag_u,modelMag_g,modelMag_r,modelMag_i,modelMag_z
0,70379,70379,182,18.81052,18.32911,17.90038,17.70225,17.65288,19.15704,18.62413,...,18.85941,18.33458,17.93811,17.74766,17.6835,18.80268,18.3214,17.89795,17.69311,17.66096
1,863365,863365,207,22.99911,21.23738,19.89757,19.26474,19.88944,23.23582,21.56254,...,22.60821,21.30784,19.90107,19.31741,19.59066,23.08329,21.2232,19.89089,19.2576,19.78871
2,40381,40381,496,18.91604,18.69714,18.59921,18.68769,18.63109,19.21263,18.99961,...,18.9211,18.72866,18.63251,18.71346,18.47264,18.91443,18.68466,18.58225,18.67094,18.58222
3,1322598,1322598,441,19.22888,18.64079,18.68171,18.52602,18.22359,19.02119,18.28669,...,17.71678,16.83449,16.67583,16.49454,16.33526,17.69108,16.90666,16.79139,16.57578,16.47379
4,1201659,1201659,320,19.57746,19.57045,19.2332,19.0279,18.62246,19.87532,19.80541,...,19.15427,18.84896,18.01861,17.66363,17.29827,18.74826,18.57701,17.93407,17.64736,17.20101


In [55]:
basic_test =  pd.read_csv('./test.csv')

In [56]:
basic_test.head()

Unnamed: 0,id,fiberID,psfMag_u,psfMag_g,psfMag_r,psfMag_i,psfMag_z,fiberMag_u,fiberMag_g,fiberMag_r,...,petroMag_u,petroMag_g,petroMag_r,petroMag_i,petroMag_z,modelMag_u,modelMag_g,modelMag_r,modelMag_i,modelMag_z
0,199991,251,23.817399,22.508963,20.981106,18.517316,17.076079,25.05389,23.167848,21.335901,...,22.246697,22.796239,21.195315,18.584486,17.154284,25.391534,22.499435,21.011918,18.499341,17.091474
1,199992,386,22.806983,21.937111,20.33577,20.000512,19.527369,22.498565,22.186,20.618879,...,21.729831,21.837511,20.196128,19.967204,19.683671,22.475338,21.853442,20.173169,19.796757,19.567372
2,199993,232,21.02425,19.235669,18.304061,17.808608,17.380113,21.205546,19.439533,18.344433,...,20.722629,18.710223,17.611851,17.158519,16.843986,20.579314,18.653338,17.562108,17.120529,16.708748
3,199994,557,20.503424,20.286261,20.197204,20.162419,20.059832,20.976132,20.611498,20.567262,...,20.329269,20.385262,20.129157,20.206574,20.212342,20.479879,20.280943,20.150499,20.206221,20.092909
4,199995,75,24.244851,22.668237,21.239333,19.284777,18.235939,25.68186,22.935289,21.642456,...,22.308298,22.957496,21.285033,19.29912,18.307526,25.48936,22.85729,21.191862,19.237964,18.280368


## Create submission sample

In [14]:
sample_submission = pd.DataFrame(columns=['id', 'STAR_WHITE_DWARF', 'STAR_CATY_VAR', 'STAR_BROWN_DWARF',
       'SERENDIPITY_RED', 'REDDEN_STD', 'STAR_BHB', 'GALAXY',
       'SERENDIPITY_DISTANT', 'QSO', 'SKY', 'STAR_RED_DWARF', 'ROSAT_D',
       'STAR_PN', 'SERENDIPITY_FIRST', 'STAR_CARBON', 'SPECTROPHOTO_STD',
       'STAR_SUB_DWARF', 'SERENDIPITY_MANUAL', 'SERENDIPITY_BLUE'])

sample_submission['id'] = validate_df.index

sample_submission = sample_submission.fillna(0.0)

In [15]:
sample_submission.to_csv('ybigta_sdss_sample_submission.csv', index=False)

In [17]:
validate_df.head()

Unnamed: 0,id,type,fiberID,psfMag_u,psfMag_g,psfMag_r,psfMag_i,psfMag_z,fiberMag_u,fiberMag_g,...,petroMag_u,petroMag_g,petroMag_r,petroMag_i,petroMag_z,modelMag_u,modelMag_g,modelMag_r,modelMag_i,modelMag_z
22121,22121,QSO,524,20.25184,19.21797,18.81325,18.63427,18.57789,20.55951,19.53614,...,20.31346,19.24276,18.84517,18.65059,18.63834,20.24235,19.21357,18.80395,18.62018,18.57158
28799,28799,QSO,281,19.63154,19.48602,19.19663,18.75605,18.82092,19.9481,19.73573,...,19.6106,19.32184,18.93773,18.49994,18.31607,19.52231,19.29745,18.95885,18.51906,18.56078
1416396,1416396,QSO,166,19.16294,18.94821,18.77507,18.92679,18.89543,19.4714,19.30156,...,19.14995,19.01328,18.78268,18.94584,18.95756,19.13103,18.94504,18.72925,18.89673,18.877
229542,229542,QSO,38,19.28713,19.02645,18.71246,18.3028,18.24522,19.63397,19.34047,...,19.30062,19.02987,18.73833,18.31396,18.30046,19.25876,19.01132,18.70494,18.29297,18.2314
482490,482490,QSO,304,19.43952,19.22812,19.03674,18.79651,18.71855,19.7895,19.51578,...,19.48426,19.23326,19.1245,18.84784,18.83571,19.43872,19.20541,19.04823,18.79752,18.73385


In [33]:
import pickle

pickle_out = open("ybigta_validate2.pickle","wb")
pickle.dump(list(validate_df.type), pickle_out)
pickle_out.close()

In [32]:
len(list(validate_df.type))

10062

### Full Test Set

In [48]:
sample_test_submission = pd.DataFrame(columns=['id', 'STAR_WHITE_DWARF', 'STAR_CATY_VAR', 'STAR_BROWN_DWARF',
       'SERENDIPITY_RED', 'REDDEN_STD', 'STAR_BHB', 'GALAXY',
       'SERENDIPITY_DISTANT', 'QSO', 'SKY', 'STAR_RED_DWARF', 'ROSAT_D',
       'STAR_PN', 'SERENDIPITY_FIRST', 'STAR_CARBON', 'SPECTROPHOTO_STD',
       'STAR_SUB_DWARF', 'SERENDIPITY_MANUAL', 'SERENDIPITY_BLUE'])

sample_test_submission['id'] = test_df.index

sample_test_submission = sample_test_submission.fillna(0.0)

In [49]:
sample_test_submission.to_csv('ybigta_sdss_sample_submission_full.csv', index=False)

In [45]:
test_df.type.to_pickle('ybigta_validate_full.pickle')

In [25]:
validate_df.type.to_pickle('ybigta_validate.pickle')

In [26]:
vdf = pd.read_pickle('ybigta_validate.pickle')

In [28]:
vdf

22121                     QSO
28799                     QSO
1416396                   QSO
229542                    QSO
482490                    QSO
1433774                   QSO
1609859                   QSO
952073                    QSO
681644                    QSO
192906                    QSO
16137                     QSO
888078                    QSO
1263328                   QSO
1102813                   QSO
1699042                   QSO
127544                    QSO
374876                    QSO
336590                    QSO
825384                    QSO
1474708                   QSO
702682                    QSO
116475                    QSO
887159                    QSO
576815                    QSO
893980                    QSO
1068044                   QSO
70825                     QSO
4179                      QSO
1137572                   QSO
1267209                   QSO
                  ...        
748590                    SKY
1489874                   SKY
764437    