In [1]:
import pyreadr
# Read every object stored in the R data file into a mapping keyed by object name.
rdata_path = 'DataforR50JSDPlot.Rdata'
result = pyreadr.read_r(rdata_path)
# Show which objects were loaded (the recorded output lists 'Test.df' and 'mixprop').
print(result.keys())

odict_keys(['Test.df', 'mixprop'])


In [2]:
import numpy as np
# Pull the 'Test.df' object out of the loaded R data; the cells below read its
# 'Trans_h', 'label', 'mass' and 'h' columns.
df = result["Test.df"]


In [3]:
# Copy the score column ('Trans_h') and the label column into plain numpy arrays.
preds, targets = (np.array(df[column]) for column in ("Trans_h", "label"))

In [4]:
# Quick look at the labels; the recorded output shows they are the strings '0' and '1',
# which is why later cells compare against "0" / "1" rather than integers.
print(targets)

['0' '0' '0' ... '1' '1' '1']


In [5]:
# Quick look at the scores taken from the 'Trans_h' column (floating-point values).
print(preds)

[0.84791559 0.85341697 0.80439193 ... 0.4250999  0.04426433 0.13619038]


In [6]:
def find_threshold(L,mask,x_frac):
    """
    Calculate c such that x_frac of the masked array is less than c.

    Parameters
    ----------
    L : Array
        The array where the cutoff is to be found.
    mask : Array
        Boolean (or 0/1) mask; the threshold is calculated over L[mask] only.
    x_frac : float
        Fraction (between 0 and 1) of the selected values that should lie
        below c.

    Returns
    -------
    c : L.dtype
        The element at position round(x_frac * n) of the sorted selection,
        where n is the number of selected values. The position is clamped to
        [0, n - 1], so x_frac = 1 returns the largest selected value instead
        of indexing past the end.

    Raises
    ------
    IndexError
        If the mask selects no elements.
    """
    # Convert the mask to boolean once so the count and the selection always agree.
    selected = np.sort(np.asarray(L)[np.asarray(mask).astype(bool)])
    n_selected = selected.size
    if n_selected == 0:
        raise IndexError("find_threshold: mask selects no elements of L")
    x = int(np.round(x_frac * n_selected))
    # Clamp: x_frac == 1 would otherwise give x == n_selected (out of range),
    # and a negative x_frac would silently wrap around to the top of the array.
    x = min(max(x, 0), n_selected - 1)
    return selected[x]

In [7]:
# Boolean selections for the two label values (labels are stored as strings).
label0_mask = targets == "0"
label1_mask = targets == "1"

# Score cut taken at the 0.5 fraction of the label-"0" scores.
c = find_threshold(preds, label0_mask, 0.5)

# R50: reciprocal of the fraction of label-"1" samples whose score exceeds the cut.
n_label1_above_cut = (preds[label1_mask] > c).sum()
R50 = 1/(n_label1_above_cut/label1_mask.sum())
print(R50)

10.199891377780132


In [8]:
from scipy.stats import entropy
# Mass values of the label-"1" samples, split by which side of the cut c the score falls on.
m = np.array(df["mass"])
label1_mask = targets == "1"
mass_above_cut = m[label1_mask & (preds > c)]
# NOTE(review): the strict '<' means label-"1" samples with preds == c enter neither histogram.
mass_below_cut = m[label1_mask & (preds < c)]

# The 50 bin edges are derived from the above-cut sample and reused for the below-cut
# sample so both densities share one binning.
# NOTE(review): below-cut masses outside that range are therefore not counted - confirm intended.
hist1, bins = np.histogram(mass_above_cut, bins = 50, density = True)
hist2, _ = np.histogram(mass_below_cut, bins = bins, density = True)

# Jensen-Shannon divergence: mean of the two relative entropies against the average
# histogram. The value stored in JSD is the reciprocal of that divergence.
mixture = 0.5*(hist1 + hist2)
JSD = 1/(0.5*(entropy(hist1, mixture) + entropy(hist2, mixture)))
print(JSD)

7358.641914341908


In [9]:
# 0-d arrays wrapping the baseline R50 and 1/JSD values obtained with the unmixed score.
R50list, JSDlist = np.array(R50), np.array(JSD)

# Mixing proportions stored in the R file.
prop = np.array(result["mixprop"])

# The two score columns that get blended: 'Trans_h' and 'h'.
OTh, h = (np.array(df[column]) for column in ("Trans_h", "h"))

In [10]:
def _r50_and_inverse_jsd(scores, label0_mask, label1_mask, mass, n_bins=50):
    """
    Compute R50 and 1/JSD for one score array.

    Parameters
    ----------
    scores : Array
        Score per sample.
    label0_mask, label1_mask : Array of bool
        Selections of the label-"0" and label-"1" samples.
    mass : Array
        Mass per sample, histogrammed for the divergence.
    n_bins : int
        Number of histogram bins (edges are taken from the above-cut sample).

    Returns
    -------
    (r50, inverse_jsd) : tuple of float
        r50 is the reciprocal of the fraction of label-"1" samples scoring above
        the cut placed at the 0.5 fraction of the label-"0" scores; inverse_jsd is
        the reciprocal of the Jensen-Shannon divergence between the mass densities
        of label-"1" samples above and below that cut.
    """
    cut = find_threshold(scores, label0_mask, 0.5)
    r50 = 1/((scores[label1_mask] > cut).sum()/label1_mask.sum())
    hist_above, edges = np.histogram(mass[label1_mask & (scores > cut)], bins = n_bins, density = True)
    hist_below, _ = np.histogram(mass[label1_mask & (scores < cut)], bins = edges, density = True)
    mixture = 0.5*(hist_above + hist_below)
    inverse_jsd = 1/(0.5*(entropy(hist_above, mixture) + entropy(hist_below, mixture)))
    return r50, inverse_jsd


# The label selections do not depend on the mixing proportion, so build them once.
label0_mask = targets == "0"
label1_mask = targets == "1"

# Collect per-proportion results in plain lists and build the arrays once at the end.
# All intermediate values live inside the helper, so this cell no longer overwrites
# preds, c, hist1, hist2 or bins from the earlier cells.
r50_mixed = []
inverse_jsd_mixed = []
for p in prop.flatten():
    r50_p, inverse_jsd_p = _r50_and_inverse_jsd((1 - p)*OTh + p*h, label0_mask, label1_mask, m)
    r50_mixed.append(r50_p)
    inverse_jsd_mixed.append(inverse_jsd_p)

# Start from the baseline scalars rather than from R50list/JSDlist themselves, so
# re-running this cell gives the same arrays instead of appending a second copy.
R50list = np.append(R50, r50_mixed)
JSDlist = np.append(JSD, inverse_jsd_mixed)

In [11]:
# Baseline 1/JSD value followed by one 1/JSD value per mixing proportion.
print(JSDlist)

[7.35864191e+03 7.35864191e+03 7.35864191e+03 7.35864191e+03
 7.36527166e+03 7.36527166e+03 7.36413791e+03 7.36026017e+03
 7.36013534e+03 7.36013534e+03 7.35630006e+03 7.27298549e+03
 7.28696405e+03 7.29529664e+03 7.30425559e+03 7.23567166e+03
 7.26396352e+03 7.25891420e+03 6.90078127e+03 5.62251290e+03
 3.23604998e+03 1.14952927e+03 1.06308806e+03 8.12602290e+02
 6.27506009e+02 4.96866650e+02 3.93532425e+02 3.25898282e+02
 2.71340889e+02 2.29009288e+02 1.94568072e+02 1.80933407e+02
 9.64303827e+01 5.76802785e+01 3.82669792e+01 2.68295413e+01
 1.98773704e+01 1.51677934e+01 1.20141088e+01 1.02174835e+01
 8.99726778e+00 7.91265418e+00 7.11607331e+00 6.47880076e+00
 6.01388385e+00 5.81425242e+00 5.53893943e+00 5.32088115e+00
 4.95698996e+00]


In [12]:
# Baseline R50 value followed by one R50 value per mixing proportion.
print(R50list)

[10.19989138 10.19989138 10.19989138 10.19989138 10.20002649 10.20002649
 10.19989138 10.19962116 10.19948605 10.19948605 10.20029674 10.20083726
 10.20259437 10.2047578  10.20624569 10.21382713 10.22209832 10.24508369
 10.29177861 10.37903704 10.53942704 10.89833411 10.93237545 11.05162689
 11.16978067 11.29611971 11.427045   11.55652944 11.70177198 11.83650254
 11.96915998 12.0205442  12.67781876 13.38594997 14.10902428 14.82536871
 15.58767561 16.33363031 16.99591657 17.66460197 18.26722338 18.83838137
 19.33216169 19.80401739 20.32573978 20.74856512 21.18818965 21.62374681
 22.01950299]


In [13]:
# Mixing proportions loaded from 'mixprop', flattened to 1-D for display.
print(prop.flatten())

[5.96046448e-08 1.19209290e-07 2.38418579e-07 4.76837158e-07
 9.53674316e-07 1.90734863e-06 3.81469727e-06 7.62939453e-06
 1.52587891e-05 3.05175781e-05 6.10351562e-05 1.22070312e-04
 2.44140625e-04 4.88281250e-04 9.76562500e-04 1.95312500e-03
 3.90625000e-03 7.81250000e-03 1.56250000e-02 3.12500000e-02
 6.25000000e-02 6.50000000e-02 7.50000000e-02 8.50000000e-02
 9.50000000e-02 1.05000000e-01 1.15000000e-01 1.25000000e-01
 1.35000000e-01 1.45000000e-01 1.50000000e-01 2.00000000e-01
 2.50000000e-01 3.00000000e-01 3.50000000e-01 4.00000000e-01
 4.50000000e-01 5.00000000e-01 5.50000000e-01 6.00000000e-01
 6.50000000e-01 7.00000000e-01 7.50000000e-01 8.00000000e-01
 8.50000000e-01 9.00000000e-01 9.50000000e-01 1.00000000e+00]


In [14]:
import pandas as pd
# Prepend 0 for the baseline (unmixed) entry that sits first in R50list / JSDlist.
# The proportions are re-read from the loaded data rather than taken from `prop`
# itself, so re-running this cell does not prepend a second 0 and break the
# equal-length requirement of the DataFrame constructor.
prop = np.append(0, np.array(result["mixprop"]))
dfR50JSD = pd.DataFrame({'prop':prop, 'R50list':R50list, 'JSDlist':JSDlist})

In [15]:
# Proportions after the baseline 0 has been prepended; this is the 'prop' column of dfR50JSD.
print(prop)

[0.00000000e+00 5.96046448e-08 1.19209290e-07 2.38418579e-07
 4.76837158e-07 9.53674316e-07 1.90734863e-06 3.81469727e-06
 7.62939453e-06 1.52587891e-05 3.05175781e-05 6.10351562e-05
 1.22070312e-04 2.44140625e-04 4.88281250e-04 9.76562500e-04
 1.95312500e-03 3.90625000e-03 7.81250000e-03 1.56250000e-02
 3.12500000e-02 6.25000000e-02 6.50000000e-02 7.50000000e-02
 8.50000000e-02 9.50000000e-02 1.05000000e-01 1.15000000e-01
 1.25000000e-01 1.35000000e-01 1.45000000e-01 1.50000000e-01
 2.00000000e-01 2.50000000e-01 3.00000000e-01 3.50000000e-01
 4.00000000e-01 4.50000000e-01 5.00000000e-01 5.50000000e-01
 6.00000000e-01 6.50000000e-01 7.00000000e-01 7.50000000e-01
 8.00000000e-01 8.50000000e-01 9.00000000e-01 9.50000000e-01
 1.00000000e+00]


In [16]:
# Export the prop / R50 / 1-over-JSD table to an .RData file for plotting in R.
# NOTE(review): df_name is presumably the name the data frame gets inside the R file - confirm in the pyreadr docs.
pyreadr.write_rdata("PythonR50JSDPlotData.RData", dfR50JSD, df_name="R50JSDPython")