In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas
import torch
import torchvision.transforms as transforms
from globsML.utils.eval import get_test_metrics
from globsML.utils.draw import plot_source
from globsML.utils.imageloader import load_data, CustomGCDataset
from globsML.utils.training import train_CNN as train
from globsML.utils.training import test_CNN as test
from globsML.models.CNN import CNN
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm, trange

# list of thresholds used to calculate AUC ROC
thresh = np.arange(0,1.02,0.01)

In [3]:
# create random seeds for random splits
np.random.seed(123124)
seeds = list(map(int, np.random.random(10)*21321456))

# parameters for data split
BATCHSIZE = 500
TEST_SIZE = 0.2
EVAL_SIZE = 0.05

# choose method
method = 'forest'

# run experiment for all random splits
final_res = pandas.DataFrame()
for SEED in seeds:
    # load labels
    data_path = '../data/ACS_sources_original.csv'
    data = pandas.read_csv(data_path)
    galaxies_to_test = set([])

    # load image data and create a data set
    available_galaxies = set(data['galaxy'].unique())
    available_galaxies = available_galaxies.difference(set(['VCC538']))
    images, labels, probabilities, galaxies, IDs = load_data(data, available_galaxies)

    # create data splits
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    train_images, test_images, train_labels, test_labels = train_test_split(images, probabilities, test_size=TEST_SIZE, random_state=SEED)
    test_labels = np.array(test_labels>=0.5, dtype=int)
    train_labels = np.array(train_labels>=0.5, dtype=int)
    
    # create model
    if method == 'forest':
        model = RandomForestClassifier(n_estimators = 200, random_state = 42424)
    elif method == '12NN':
        model = KNeighborsClassifier(12)
    elif method == 'NN':
        model = KNeighborsClassifier(1)
    # fit model
    model.fit(np.reshape(train_images, (len(train_images), 2*20*20)), train_labels)
    # get test prediction
    pred = model.predict(np.reshape(test_images, (len(test_images), 2*20*20)))
    probs = model.predict_proba(np.reshape(test_images, (len(test_images), 2*20*20)))[:,1]
    
    # evaluate performance metrics
    stats_gal, stats_all, _, _, _, _, _ = get_test_metrics(list(galaxies_to_test), [], [], test_labels, pred, probs=probs, thresh=thresh)
    final_res = final_res.append(stats_all)
final_res['seed'] = seeds 
final_res.to_csv('All2All--RF-results')

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [12]:
final_res.mean()

  final_res.mean()


TPR             8.380882e-01
FDR             1.003682e-01
FPR             2.634964e-02
AUC(FDR,TPR)    9.324322e-01
AUC(FPR,TPR)    9.808841e-01
# found GCs     3.107900e+03
# total GCs     3.708300e+03
# fake GCs      3.468000e+02
# sources       1.687000e+04
seed            6.988626e+06
dtype: float64

In [14]:
final_res

Unnamed: 0,Galaxy,TPR,FDR,FPR,"AUC(FDR,TPR)","AUC(FPR,TPR)",# found GCs,# total GCs,# fake GCs,# sources,seed
0,ALL,0.830199,0.098389,0.026117,0.930826,0.980207,3134,3775,342,16870,2780441
0,ALL,0.840639,0.10306,0.027099,0.934473,0.982035,3107,3696,357,16870,6488346
0,ALL,0.841767,0.09914,0.026342,0.93462,0.981428,3144,3735,346,16870,12269029
0,ALL,0.836373,0.093341,0.024426,0.935493,0.981136,3118,3728,321,16870,2166833
0,ALL,0.836646,0.100936,0.026142,0.931078,0.97958,3073,3673,345,16870,6621300
0,ALL,0.830698,0.105665,0.027269,0.925465,0.97932,3047,3668,360,16870,13698722
0,ALL,0.841863,0.095432,0.024892,0.937085,0.982309,3109,3693,328,16870,2249589
0,ALL,0.847467,0.106299,0.028811,0.93062,0.980912,3178,3750,378,16870,9815943
0,ALL,0.840272,0.096019,0.024858,0.932335,0.981089,3088,3675,328,16870,625702
0,ALL,0.834959,0.105401,0.027542,0.932328,0.980826,3081,3690,363,16870,13170350
