In [None]:
from pathlib import Path
from collections import defaultdict, Counter
from typing import Any
from pprint import pprint
from sklearn.preprocessing import LabelEncoder

import scipy.stats as sp_stats
import pickle as pkl
import json
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Root directory of the training data; the next cell searches its "0" and "1" subfolders.
# NOTE(review): absolute, machine-specific path — adjust it when running elsewhere.
DATASETS_PATH = Path('/data/Datasets/usg-kaggle/train/')

In [None]:
# Recursively collect every entry named "regression*" under each class folder
# ("0" first, then "1"), then show how many were found per class.
zeroes, ones = (
    list((DATASETS_PATH / class_dir).rglob("regression*"))
    for class_dir in ("0", "1")
)
len(zeroes), len(ones)

In [None]:
def get_stats(paths: list) -> dict:
    """Aggregate the per-file JSON statistics stored at ``paths``.

    Each path is read with ``read_text()`` and parsed as a JSON object; the
    values are grouped by key across all files.

    Side effects: opens a new figure with a 30-bin histogram of the collected
    "mean" values and prints their median.

    Returns:
        A dict with, in this order:
        "sd"   - average of the per-file "sd" values,
        "min"  - smallest per-file "min",
        "max"  - largest per-file "max",
        "mean" - average of the per-file "mean" values.
    """
    collected = defaultdict(list)
    records = (json.loads(source.read_text()) for source in paths)
    for record in records:
        for field in record:
            collected[field].append(record[field])

    means = collected["mean"]
    plt.figure()
    plt.hist(means, bins=30)
    print(np.median(means))

    summary = {}
    summary["sd"] = np.mean(collected["sd"])
    summary["min"] = np.min(collected["min"])
    summary["max"] = np.max(collected["max"])
    summary["mean"] = np.mean(means)
    return summary

def get_mean_stats(paths: list) -> dict:
    """Summarise the distribution of the per-file "mean" values.

    Each path is read with ``read_text()``, parsed as JSON, and its "mean"
    entry is collected.

    Returns:
        A dict with the standard deviation ("sd"), minimum ("min"),
        maximum ("max") and average ("mean") of the collected values,
        in that order.
    """
    per_file_means = [json.loads(source.read_text())["mean"] for source in paths]
    reducers = (
        ("sd", np.std),
        ("min", np.min),
        ("max", np.max),
        ("mean", np.mean),
    )
    return {label: reduce_fn(per_file_means) for label, reduce_fn in reducers}

def get_unique_decimals(paths: list) -> set:
    """Return the distinct integer parts of the per-file "mean" values.

    Each path is read with ``read_text()`` and parsed as JSON; the "mean"
    entry is converted with ``int()``, which truncates toward zero.
    """
    return {int(json.loads(source.read_text())["mean"]) for source in paths}

def get_decimals(paths: list) -> list:
    """Return the integer part of every per-file "mean", in input order.

    Duplicates are kept; ``int()`` truncates toward zero.
    """
    def integer_part(source) -> int:
        # One file -> its "mean" entry truncated to an int.
        return int(json.loads(source.read_text())["mean"])

    return list(map(integer_part, paths))


def get_floats(paths: list) -> list:
    """Return the raw "mean" entry of every file, in input order.

    Each path is read with ``read_text()`` and parsed as JSON; the value is
    returned exactly as decoded, with no conversion.
    """
    return [json.loads(source.read_text())["mean"] for source in paths]

def get_unique_floats(paths: list) -> set:
    """Return the distinct values of ``int(mean * 10) % 10`` over all files.

    For non-negative means this is the first digit after the decimal point.
    For negative means Python's floor-modulo applies, so the result is not
    the literal digit (e.g. -1.25 gives 8).
    """
    tenths = set()
    for source in paths:
        mean_value = json.loads(source.read_text())["mean"]
        tenths.add(int(mean_value * 10) % 10)
    return tenths

In [None]:
# Aggregated statistics per class (zeroes first, then ones); each call also
# draws a histogram and prints a median.
tuple(get_stats(class_paths) for class_paths in (zeroes, ones))

In [None]:
# Distribution summary of the per-file means, zeroes first and ones second.
tuple(map(get_mean_stats, (zeroes, ones)))

In [None]:
# Smallest and largest integer part of the means across both classes.
decimals = [*get_unique_decimals([*zeroes, *ones])]
np.min(decimals), np.max(decimals)

In [None]:
# Which int(mean * 10) % 10 values occur across both classes.
get_unique_floats([*zeroes, *ones])

In [None]:
def get_threshold_split(ones, zeroes, bins=20, n_thresholds=200):
    """Find the integer threshold that best separates ``ones`` from ``zeroes``.

    Candidate thresholds are ``n_thresholds`` integer-valued points spread
    between ``min + 1`` and ``max - 1`` of the pooled values. For every
    candidate, values below it are scored 0.01 and values at or above it are
    scored 0.99; the candidate with the lowest summed binary cross-entropy
    against the true labels wins (the first one on ties).

    Side effect: draws one figure with a histogram per class and a vertical
    line at the chosen threshold.

    Args:
        ones: sequence (list or array) of values belonging to class 1.
        zeroes: sequence (list or array) of values belonging to class 0.
        bins: number of histogram bins used for the plot.
        n_thresholds: number of candidate thresholds to evaluate.

    Returns:
        The best threshold as a NumPy integer, or -1 if no candidate
        produced a finite-comparable score.

    Raises:
        ValueError: if both ``ones`` and ``zeroes`` are empty.
    """
    ones_values = np.asarray(ones, dtype=float).ravel()
    zeroes_values = np.asarray(zeroes, dtype=float).ravel()

    # Concatenate explicitly: `ones + zeroes` would add element-wise if the
    # caller passed arrays instead of lists.
    together = np.concatenate([ones_values, zeroes_values])
    if together.size == 0:
        raise ValueError("get_threshold_split needs at least one value")

    true_labels = np.concatenate([
        np.ones(ones_values.size, dtype=int),
        np.zeros(zeroes_values.size, dtype=int),
    ])

    # Builtin `int` instead of the `np.int` alias, which newer NumPy
    # releases no longer provide.
    thresholds = np.linspace(
        np.min(together) + 1,
        np.max(together) - 1,
        n_thresholds,
        dtype=int,
    )

    best_thr = -1
    best_entr = np.inf

    for thr in thresholds:
        # Left of the threshold -> confident "0", right of it -> confident "1".
        preds = np.where(together >= thr, 0.99, 0.01).astype(np.float32)

        entropy = -np.sum(
            true_labels * np.log(preds)
            + (1 - true_labels) * np.log(1 - preds)
        )

        # Strict comparison keeps the first (lowest-index) candidate on ties.
        if entropy < best_entr:
            best_entr = entropy
            best_thr = thr

    # Plot the data that was actually passed in (not notebook globals) and
    # honour the `bins` argument.
    plt.figure(figsize=(18, 6))
    plt.hist(ones_values, bins=bins, alpha=0.6, label="ones")
    plt.hist(zeroes_values, bins=bins, alpha=0.6, label="zeroes")
    plt.axvline(x=best_thr, color="k", linestyle="--", label="best threshold")
    plt.legend()

    return best_thr

In [None]:
# Per-file "mean" values for each class; later cells reuse these two lists.
ones_decims = get_floats(ones)
zeroes_decims = get_floats(zeroes)

# Translucent bars plus a legend keep the two overlapping distributions
# readable; plt.show() renders the figure without echoing the hist() return
# value (replaces the former bare print() used to hide it).
plt.hist(ones_decims, bins=20, alpha=0.6, label="ones")
plt.hist(zeroes_decims, bins=20, alpha=0.6, label="zeroes")
plt.xlabel("mean")
plt.legend()
plt.show()

In [None]:
# Search for the separating threshold (also draws the annotated histogram).
best_thr = get_threshold_split(ones_decims, zeroes_decims)
print("Best threshold: ", best_thr, sep="")

In [None]:
# Most frequent value per class: ones first, zeroes second.
tuple(sp_stats.mode(values) for values in (ones_decims, zeroes_decims))

In [None]:
# Median per class: ones first, zeroes second.
tuple(map(np.median, (ones_decims, zeroes_decims)))

In [None]:
# Average per class: ones first, zeroes second.
tuple(np.mean(values) for values in (ones_decims, zeroes_decims))

In [None]:
# Distance from the smallest zeroes value up to the largest ones value.
np.subtract(np.max(ones_decims), np.min(zeroes_decims))

In [None]:
# Pool both classes (ones first) and plot the rescaled values.
# NOTE(review): the divisor 4 is not explained in this notebook — confirm
# which scale it is meant to convert to.
total = np.asarray([*ones_decims, *zeroes_decims])
plt.hist(np.divide(total, 4), bins=20)
print()