In [35]:
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
pd.set_option('display.max_rows', 1000) 
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
import warnings
from sklearn.metrics import mean_absolute_error
warnings.filterwarnings("ignore")

In [23]:
sperm_data = pd.read_csv(r"all_features.csv")
sperm_data = sperm_data.replace("-", np.nan)
sperm_data = sperm_data.dropna()
blas_rate = sperm_data["BLAST_D8"].copy()
blas_rate = pd.to_numeric(blas_rate)
sperm_data = sperm_data.drop(columns=['BLAST_D8'])

In [28]:
correlated_features = set()
correlation_matrix = sperm_data.corr()
for i in range(len(correlation_matrix.columns)):
    for j in range(i):
        if abs(correlation_matrix.iloc[i, j]) > 0.8:
            colname = correlation_matrix.columns[i]
            correlated_features.add(colname)
sperm_data.drop(labels=correlated_features, axis=1, inplace=True)
sperm_data = sperm_data.apply(pd.to_numeric)

In [29]:
# FIVE DATASETS TO TEST ON
# - all non-correlated features
# - ss_cells
# - ss_motility
# - ss_pct
# - pca (done on training set, then applied on testing)
ss_cells = sperm_data[["AI", "PI", "ALTO", "FRAG_CRO"]].copy()
ss_motility = sperm_data[["VAP", "VCL", "ALH", "BCF", "STR"]].copy()
ss_pct = sperm_data[["MOTILE_PCT", "MEDIUM_PCT", "SLOW_PCT", "STATIC_PCT"]].copy()

In [72]:
def accuracy(correct, predicted, classes):
    accuracy=[]
    elements = []
    for label in classes:
        accuracy.append(100*np.mean([(x == y)  for x, y in zip(correct,predicted) if x==label]))  
        elements.append(len([(x == y)  for x, y in zip(correct,predicted) if x==label]))
    accuracy.append(100*np.mean([(x == y)  for x, y in zip(correct,predicted)] ) )
    elements.append(len([(x == y)  for x, y in zip(correct,predicted)]))
    accuracy = np.nan_to_num(accuracy) 
    elements = np.nan_to_num(elements)
    return accuracy, elements

def error_calculator(correct, predicted, base):
    error = mean_absolute_error(output_test, predicted)
    near_batches = np.zeros((1,5))
    for i in range(5):
        near_batches[:,i] = sum([(abs(x-y)/base == i + 1) for x, y in zip(correct, predicted)])
    near_batches = 100*near_batches/len(predicted)
    return error, near_batches
    

In [100]:
def split_and_classify_continuous(train_data, test_data, train_labels, test_labels, base, group_labels):
    clf = DecisionTreeClassifier(min_samples_leaf = 6, criterion = "entropy")
    tree_class = clf.fit(train_data, train_labels)
    results = clf.predict(test_data)
    acc, element_num = accuracy(test_labels, results, group_labels)
    error, batches = error_calculator(test_labels, results, base)
    return acc, element_num, error, batches

In [109]:
def sandc_average_accuracy(data, labels, base, iterations):
    group_labels = [n for n in range(1, 60) if n % base == 0]
    group_labels.insert(0,0)
    train, test, output_train, output_test = train_test_split(data, labels, test_size=0.20)
    mean_acc, mean_elements, mean_error, mean_batches = split_and_classify_continuous(train, test, output_train, output_test, 10, group_labels)
    for it in range(iterations-1):
        train, test, output_train, output_test = train_test_split(data, labels, test_size=0.20)
        new_acc, new_elements, new_error, new_batches = split_and_classify_continuous(train, test, output_train, output_test, 10, group_labels)
        mean_acc, mean_elements, mean_error, mean_batches = mean_acc + new_acc, mean_elements + new_elements, mean_error + new_error, mean_batches + new_batches
    mean_acc, mean_elements, mean_error, mean_batches =  mean_acc/iterations, mean_elements/iterations, mean_error/iterations, mean_batches/iterations
    print("acc", mean_acc[-1])
    print("elements", mean_elements[-1])
    print("error", mean_error)
    print("batches", mean_batches)
    return mean_acc, mean_elements, mean_error, mean_batches

In [71]:
split_and_classify_continuous(train, test, output_train, output_test, 10)

[[0.22413793 0.29310345 0.03448276 0.         0.        ]]


(array([ 0.        , 45.83333333, 44.44444444, 33.33333333, 57.14285714,
         0.        , 44.82758621]), array([ 0, 24, 18,  9,  7,  0, 58]))

In [111]:
round_bases = [20,10,5,1]
for base in round_bases :
    blasrate_base = base * round(blas_rate / base)
    print("ROUNDING TO ", base)
    sandc_average_accuracy(sperm_data, blasrate_base, base, 20)

ROUNDING TO  20
acc 63.70689655172415
elements 58.0
error 12.965517241379313
batches [[ 0.         28.44827586  0.          7.84482759  0.        ]]
ROUNDING TO  10
acc 51.29310344827587
elements 58.0
error 10.956896551724137
batches [[25.25862069 17.5862069   5.43103448  0.43103448  0.        ]]
ROUNDING TO  5
acc 34.74137931034483
elements 58.0
error 11.159482758620687
batches [[17.67241379  7.67241379  3.87931034  0.34482759  0.        ]]
ROUNDING TO  1
acc 14.310344827586206
elements 58.0
error 11.131896551724138
batches [[4.56896552 1.46551724 0.60344828 0.         0.        ]]


In [106]:

test

[0, 20, 40]

In [90]:
sandc_average_accuracy(sperm_data, blasrate_10, 10, 20)

[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
[ 0. 10. 20. 30. 40. 50.]
acc [15.83333333 60.94243873 52.12991491 44.12232539 40.03986291  0.
 49.39655172]
elements [ 1.4  18.7  17.25 12.7   7.1   0.85 58.  ]
error 11.21551724137931
batches [[24.65517241 17.32758621  7.4137931   1.20689655  0.        ]]


(array([15.83333333, 60.94243873, 52.12991491, 44.12232539, 40.03986291,
         0.        , 49.39655172]),
 array([ 1.4 , 18.7 , 17.25, 12.7 ,  7.1 ,  0.85, 58.  ]),
 11.21551724137931,
 array([[24.65517241, 17.32758621,  7.4137931 ,  1.20689655,  0.        ]]))