In [None]:
from orphics import sehgal, maps
import healpy as hp
from pixell import utils, enmap, curvedsky, enplot, wcsutils
import os
import numpy as np

import matplotlib.pyplot as plt
import lmdb
from cosmikyu import datasets, transforms, config, stats
from cosmikyu import utils as cutils

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
data_dir = config.default_data_dir
sehgal_dir = os.path.join(data_dir, 'sehgal')


def data_path(x):
    """Return the path of file ``x`` inside the Sehgal data directory (``sehgal_dir``)."""
    return os.path.join(sehgal_dir, x)


# Channel order of each dataset sample; the normalization cell pairs data[i] with compts[i].
compts = ["kappa", "ksz", "tsz", "ir_pts", "rad_pts"]
# Side length in pixels of each cutout. Assumes every sample is
# (len(compts), stamp_size, stamp_size) — TODO confirm against the dataset.
stamp_size = 128

SDS_validation = datasets.SehgalDataSet(sehgal_dir, "validation281220_fromcat", transforms=[], dummy_label=False)
# Concatenate every validation cutout along the last axis into one
# (channel, stamp_size, stamp_size * N) array.
data = np.zeros((len(compts), stamp_size, stamp_size*len(SDS_validation)))
for i in range(len(SDS_validation)):
    sidx = stamp_size*i
    data[..., sidx:sidx+stamp_size] = SDS_validation[i]
    
def sehgal_path(x):
    """Return the path of file ``x`` inside the Sehgal data directory.

    Identical in result to ``data_path``, to which it delegates instead of
    duplicating the join. Both names are kept because later cells call both.
    """
    return data_path(x)

In [None]:
# Extra divisor applied on top of the std in z_normalize (it is that function's
# default `zfact`); a value of 1 gives a plain z-score.
zfact = 1
def log_normalize(data):
    """Apply the signed log compression ``sign(x) * log(|x| / std + 1)`` to a map.

    Parameters
    ----------
    data : dict
        Must hold ``"emap"`` (numpy array, overwritten in place) and ``"info"``
        (dict that receives ``"lognorm_std"``).

    Returns
    -------
    dict
        The same ``data`` dict, with ``data["emap"]`` log-compressed.
    """
    emap = data["emap"]
    info = data["info"]
    # Scale by the std over ALL pixels, zeros included. This is the value the
    # transform has always used; a nonzero-only std used to be computed and then
    # discarded. NOTE(review): confirm the all-pixel std is the intended one.
    std = np.std(emap)
    info["lognorm_std"] = std.copy()

    scaled = emap / std
    # sign(s) * log1p(|s|) equals log(s + 1) for s >= 0 and -log(|s| + 1) for s < 0,
    # so one vectorized expression replaces the two masked passes.
    emap[...] = np.sign(scaled) * np.log1p(np.abs(scaled))
    data["emap"] = emap
    return data

def meansub(data):
    """Subtract the mean of ``data["emap"]`` and record it in ``data["info"]["meansub_mean"]``.

    A new array is stored under ``"emap"``; the input array itself is not modified.
    Returns the same ``data`` dict.
    """
    field = data["emap"]
    meta = data["info"]
    offset = np.mean(field)
    meta["meansub_mean"] = offset.copy()
    data["emap"] = field - offset
    return data


def z_normalize(data, zfact=zfact, ignore_zero=False):
    """Z-score ``data["emap"]`` as ``(x - mean) / (std * zfact)``.

    Parameters
    ----------
    data : dict
        Holds ``"emap"`` (numpy array) and ``"info"`` (dict). ``"info"`` receives
        ``znorm_mean``, ``znorm_std`` and ``znorm_zfact``.
    zfact : number
        Extra divisor on top of the std.
    ignore_zero : bool
        If True, mean and std are estimated from the nonzero pixels only; the
        transform is still applied to every pixel.

    Returns
    -------
    dict
        The same ``data`` dict with a new normalized array under ``"emap"``.
    """
    emap = data["emap"]
    info = data["info"]
    sample = emap[emap != 0] if ignore_zero else emap
    mean = sample.mean()
    std = np.std(sample)
    info.update(znorm_mean=mean, znorm_std=std, znorm_zfact=zfact)
    data["emap"] = (emap - mean) / (std * zfact)
    return data


def shrink(data):
    """Divide ``data["emap"]`` by 1.1 times its largest absolute value.

    The divisor is stored in ``data["info"]["shrink_fact"]``; returns the same
    ``data`` dict with a new scaled array under ``"emap"``.
    """
    emap = data["emap"]
    info = data["info"]
    # The largest |x| is always attained at either the min or the max.
    peak = np.abs(emap).max()
    factor = peak * 1.1
    info["shrink_fact"] = factor
    data["emap"] = emap / factor
    return data

def minmax(data):
    """Affinely map ``data["emap"]`` from ``[min, max]`` onto ``[-1, 1]``.

    Records ``minmax_min``, ``minmax_max`` and the midpoint ``minmax_mean`` in
    ``data["info"]``; returns the same ``data`` dict with a new array under ``"emap"``.
    """
    emap = data["emap"]
    info = data["info"]
    lo = emap.min()
    hi = emap.max()
    info["minmax_min"] = lo
    info["minmax_max"] = hi
    center = (hi + lo) / 2
    info["minmax_mean"] = center
    data["emap"] = (emap - center) / (hi - lo) * 2
    return data


freq_idx = 148  # NOTE(review): not referenced in the visible cells — confirm before removing
# Normalization recipe per component, each applied to a {"emap", "info"} dict:
# kappa/ksz are mean-subtracted then z-scored; tsz and the point-source maps are
# log-compressed then z-scored using statistics from nonzero pixels only.
ns = {"kappa": lambda x: (z_normalize(meansub(x))),
      "ksz": lambda x: (z_normalize(meansub(x))),
      "ir_pts": lambda x: (z_normalize(log_normalize(x),ignore_zero=True)),
      "rad_pts": lambda x: (z_normalize(log_normalize(x),ignore_zero=True)),
      "tsz": lambda x: (z_normalize(log_normalize(x),ignore_zero=True)),
     }
nbins = 10000
norm_range = (-30, 30)  # fixed binning range for the normalized maps

norm_info_validation = {}


def _plot_log_hist(bin_center, hist, label, xlabel, xlim=None, hline=None):
    """Plot ``hist`` normalized to unit sum on a log y-axis, with dashed guides at x = +/-1.

    ``xlim`` optionally restricts the x-range; ``hline`` optionally adds a dashed
    horizontal guide at that y value. The figure is closed after display so that
    looping over components does not accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(bin_center, hist/np.sum(hist), label=label)
    ax.axvline(x=1, ls="--", color="k")
    ax.axvline(x=-1, ls="--", color="k")
    if hline is not None:
        ax.axhline(y=hline, ls="--", color="k")
    if xlim is not None:
        ax.set_xlim(*xlim)
    ax.set_yscale("log")
    ax.set_xlabel(xlabel)
    ax.set_ylabel("normalized count")
    ax.legend()
    plt.show()
    plt.close(fig)


for i, compt_idx in enumerate(compts):
    print(compt_idx)
    storage = {"emap": data[i].copy(), "info": {}}
    minval, maxval, mean = storage["emap"].min(), storage["emap"].max(), storage["emap"].mean()
    print(minval, maxval, mean)

    # Histogram of the raw map over its full value range.
    MB = stats.FastBINNER(minval, maxval, nbins)
    bin_center, hist = MB.bin(data[i].copy())
    _plot_log_hist(bin_center, hist, compt_idx, xlabel="raw pixel value")

    # Normalize, keep the statistics the transforms recorded, and histogram the result.
    storage = ns[compt_idx](storage)
    norm_info_validation[compt_idx] = storage["info"]
    print(np.min(storage["emap"]), np.max(storage["emap"]))
    MB = stats.FastBINNER(norm_range[0], norm_range[1], nbins)
    bin_center, hist = MB.bin(storage["emap"])
    _plot_log_hist(bin_center, hist, compt_idx, xlabel="normalized pixel value",
                   xlim=(-5, 5), hline=1e-5)



In [None]:
for idx, compt_info in norm_info_validation.items():
    print(idx, compt_info)

# Persist the per-component normalization statistics; the next cell reads this file back.
# Each value is a dict, so numpy stores it as a 0-d object array (loading needs allow_pickle=True).
np.savez(data_path("281220_logz_normalization_info_validation.npz"), **norm_info_validation)

In [None]:
norm_info_file = data_path("281220_logz_normalization_info_validation.npz")
SDN = transforms.SehgalDataNormalizerScaledLogZShrink(norm_info_file)
SDS_test = datasets.SehgalDataSet(sehgal_dir, "test281220_fromcat", transforms=[SDN], dummy_label=False)

# Assumes each normalized sample is (nchannel, stamp_size, stamp_size) — TODO confirm.
nchannel = 5
stamp_size = 128
nbins = 10000
nsample = len(SDS_test)
# NOTE: rebinding `data` replaces the validation array built earlier; re-run the
# validation loading cell before re-running the normalization cell.
data = np.zeros((nchannel, stamp_size, stamp_size*nsample))
for i in range(nsample):
    if i % 5000 == 0: print(i)
    sidx = stamp_size*i
    data[..., sidx:sidx+stamp_size] = SDS_test[i]
print(data.min(), data.max(), data.mean())

print("start binning")
MB = stats.FastMultBinner((-15, 15), nbins, data.shape[0])
MB.bin(data)
ret = MB.get_info()
# Re-key the per-channel results by the labels in SDN.channel_idxes (they must be
# strings, since they become np.savez keyword names).
out = {SDN.channel_idxes[key]: ret[key].copy() for key in range(data.shape[0])}
ret = out
np.savez(sehgal_path("281220_normalized_histogram_test_{}.npz".format(nbins)), **out)