# Evaluating Tikhonet Trained

In this Notebook we are going to evaluate the performance of a [Tikhonet](https://arxiv.org/pdf/1911.00443.pdf) trained.

## Required Libraries and Functions

In [2]:
%matplotlib inline
import sys

# Add library path to PYTHONPATH

lib_path = '/home/ShapeNetL1/'
path_alphatransform = lib_path+'alpha-transform'
path_score = lib_path+'score'
path_SUNet = "/home/SUNet/"

sys.path.insert(0, path_alphatransform)
sys.path.insert(0, path_score)
sys.path.insert(0, path_SUNet)

data_path = '/home/data_dir_optical/'

model_dir = '/home/models/'

import sys
sys.path.insert(0, lib_path)

# Libraries
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import fft
import cadmos_lib as cl
#force cpu-only to make sure video-mem is enough
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import tensorflow as tf
import galsim
from galsim import Image
import galsim.hsm
import pickle



## Load The Comparison Batch

In [2]:
f = open(data_path+"cfht_batch.pkl", "rb")
batch = pickle.load(f)
f.close()
tf.set_random_seed(2)




## Load and Apply Trained Model on Batch

In [3]:
model_name = 'L1_GAMMA0.0078125'

model_g05_name = 'L2_ShapeNet'

model = tf.keras.models.load_model(model_dir+model_name, compile=False)
model_g05 = tf.keras.models.load_model(model_dir+model_g05_name, compile=False)
res = model(np.expand_dims(batch['inputs_tikho'], axis=-1))
res_np = tf.keras.backend.eval(res)[...,0]
res_g05 = model_g05(np.expand_dims(batch['inputs_tikho'], axis=-1))
res_g05_np = tf.keras.backend.eval(res_g05)[...,0]

SUNet = np.load(path_SUNet+"SUNet.npy")

# generate the psfs in the spatial domain
psf_hst = np.fft.ifftshift(np.fft.irfft2(batch['psf_hst'][0]))
psf_tile_cfht = np.array([np.fft.ifftshift(np.fft.irfft2(p)) for p in batch['psf_cfht']])
# make psf tiles
psf_tile_hst = np.repeat(psf_hst[np.newaxis, :, :], batch['psf_hst'].shape[0], axis=0)
# psf_tile_cfht = np.repeat(psf_cfht[np.newaxis, :, :], k_batch*n_batch, axis=0)

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
If using Keras pass *_constraint arguments to layers.


2025-11-30 13:18:36.735692: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1
2025-11-30 13:18:36.785365: E tensorflow/stream_executor/cuda/cuda_driver.cc:318] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2025-11-30 13:18:36.785421: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: fedora
2025-11-30 13:18:36.785430: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: fedora
2025-11-30 13:18:36.785578: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 575.64.5
2025-11-30 13:18:36.785611: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: Not found: could not find kernel module information in driver version file contents: "NVRM version: NVIDIA UNIX Open Kernel Module for x86_64  575.64.05  Release Build  (dvs-builder@U22-A23-13-1)  Fri Jul 18 16

## Processing and Analyzing Results

### Define Error Metrics

In [4]:
im_size = 64
scale = 0.1

def EllipticalGaussian(e1, e2, sig, xc=im_size//2, yc=im_size//2, stamp_size=(im_size,im_size)):
    # compute centered grid
    ranges = np.array([np.arange(i) for i in stamp_size])
    x = np.outer(ranges[0] - xc, np.ones(stamp_size[1]))
    y = np.outer(np.ones(stamp_size[0]),ranges[1] - yc)
    # shift it to match centroid
    xx = (1-e1/2)*x - e2/2*y
    yy = (1+e1/2)*y - e2/2*x
    # compute elliptical gaussian
    return np.exp(-(xx ** 2 + yy ** 2) / (2 * sig ** 2))

def relative_mse(solution, ground_truth):
    relative_mse = ((solution-ground_truth)**2).mean()/ \
                         (ground_truth**2).mean()
    return relative_mse



def get_KSB_ell(image,psf):
    error_flag = True
    #create a galsim version of the data
    image_galsim = Image(image,scale=scale)
    psf_galsim = Image(psf,scale=scale)
    #estimate the moments of the observation image
    ell=galsim.hsm.EstimateShear(image_galsim
                                 ,psf_galsim,shear_est='KSB'
                                 ,guess_centroid=galsim.PositionD(im_size//2,im_size//2)
                                 ,strict=False)
    if ell.error_message != '':
        error_flag = False
    return ell#,error_flag

def get_KSB_g(images,psfs):
    g_list,error_flag_list=[],[]
    for image,psf in zip(images,psfs):
        error_flag = True
        #create a galsim version of the data
        image_galsim = galsim.Image(image,scale=scale)
        # CHECK ADAPTIVE MOMENTS
        psf_galsim = galsim.Image(psf,scale=scale)
        #estimate the moments of the observation image
        shape = galsim.hsm.EstimateShear(image_galsim
                                         ,psf_galsim,shear_est='KSB'
                                         ,guess_centroid=galsim.PositionD(im_size//2,im_size//2)
                                         ,strict=False)
        g = np.array([shape.corrected_g1, shape.corrected_g2])
#        g = np.array([shape.observed_shape.g1, shape.observed_shape.g2])
        if shape.error_message:# or np.linalg.norm(shape.corrected_g1+shape.corrected_g2*1j)>1:
            error_flag = False
        error_flag_list += [error_flag]
        g_list += [g]
    return np.array(g_list).T,np.array(error_flag_list)

def get_moments(images, bool_window=False):
    g_list,error_flag_list=[],[]
    if bool_window:
        window_list = []
        window_flag_list = []
    for image in images:
        error_flag = True
        #create a galsim version of the data
        image_galsim = galsim.Image(image,scale=scale)
        #estimate the moments of the observation image
        shape = galsim.hsm.FindAdaptiveMom(image_galsim
                                         ,guess_centroid=galsim.PositionD(im_size//2,im_size//2)
                                         ,strict=False)
        if bool_window:
            k_sigma = 1.2 #scale up the size of the Gaussian window to make it able to capture more useful signal
            window = EllipticalGaussian(-1.*shape.observed_shape.e1, shape.observed_shape.e2 #convention fix:
                                                                                             #e1 sign swap
                                 ,shape.moments_sigma*k_sigma # convention fix: swap x and y and origin at (0,0)
                                 ,shape.moments_centroid.y-1, shape.moments_centroid.x-1
                                 ,image.shape)
            window_flag = bool(shape.moments_status+1)
        g = np.array([shape.observed_shape.g1, shape.observed_shape.g2])
        if shape.error_message:# or np.linalg.norm(shape.corrected_g1+shape.corrected_g2*1j)>1:
            error_flag = False
        error_flag_list += [error_flag]
        g_list += [g]
        if bool_window:
            window_list += [window]
            window_flag_list += [window_flag]
    output = [np.array(g_list).T,np.array(error_flag_list)]
    if bool_window:
        output += [np.array([window_list])[0],np.array([window_flag_list])[0]]
    return output

def g_to_e(g1,g2):
    shear = galsim.Shear(g1=g1,g2=g2)
    ell = -shear.e1, shear.e2 #reverse the signe of e_1 to get our conventions
    return ell

def MSE(X1,X2,norm=False):
    #Computes the relative MSE
    temp = 1
    if norm:
        temp = np.mean(X2**2)
    return np.mean((X1-X2)**2)/temp

def MSE_obj(obj1,obj2,norm=False):
    return np.array([MSE(o1,o2,norm) for o1,o2 in zip(obj1,obj2)])

### Estimate Adaptive Moments

In [5]:
######################################################
# Check the type
print(type(res_np))

# Check the shape (dimensions of each axis)
print("Shape:", res_np.shape)

# Check the number of dimensions
print("Number of dimensions:", res_np.ndim)

# Check the total number of elements
print("Size:", res_np.size)

# Check the data type of elements
print("Data type:", res_np.dtype)
#$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$#

# Check the type
print(type(SUNet))

# Check the shape (dimensions of each axis)
print("Shape:", SUNet.shape)

# Check the number of dimensions
print("Number of dimensions:", SUNet.ndim)

# Check the total number of elements
print("Size:", SUNet.size)

# Check the data type of elements
print("Data type:", SUNet.dtype)
######################################################

# estimate adaptive moments
mom_g0,_ = get_moments(res_np)
mom_g05,_ = get_moments(res_g05_np)
mom_SUNet,_ = get_moments(SUNet)
mom_hst,_,windows, window_flags = get_moments(batch['targets'],bool_window=True)

# estimate flux
flux_g0 = np.array([gal.sum() for gal in res_np]).T
flux_g05 = np.array([gal.sum() for gal in res_g05_np]).T
flux_SUNet = np.array([gal.sum() for gal in SUNet]).T
flux_true = np.array([gal.sum()  for gal in batch['targets']]).T

<class 'numpy.ndarray'>
Shape: (768, 64, 64)
Number of dimensions: 3
Size: 3145728
Data type: float32
<class 'numpy.ndarray'>
Shape: (768, 64, 64)
Number of dimensions: 3
Size: 3145728
Data type: float32


### Estimate Moments and Absolute Pixel Errors

In [6]:
# compute relative pixel errors
mse_g0 = np.array([relative_mse(est,true) for true,est in zip(batch['targets'], res_np)])
mse_g05 = np.array([relative_mse(est,true) for true,est in zip(batch['targets'], res_g05_np)])
mse_SUNet = np.array([relative_mse(est,true) for true,est in zip(batch['targets'], SUNet)])

# compute winodwed pixel relative errors
mse_g0_w = np.array([relative_mse(est*w,true*w) for true,est,w in zip(batch['targets'], res_np,windows)])
mse_g05_w = np.array([relative_mse(est*w,true*w) for true,est,w in zip(batch['targets'], res_g05_np,windows)])
mse_SUNet_w = np.array([relative_mse(est*w,true*w) for true,est,w in zip(batch['targets'], SUNet,windows)])

# compute adapative moments errors
mom_err_g0 = mom_g0-mom_hst
mom_err_g05 = mom_g05-mom_hst
mom_err_SUNet = mom_SUNet-mom_hst

#compute flux relative errors
flux_err_g0 = np.abs(flux_g0 - flux_true) / flux_true
flux_err_g05 = np.abs(flux_g05 - flux_true) /flux_true
flux_err_SUNet = np.abs(flux_SUNet - flux_true) /flux_true

## Save measurements

In [7]:
flux = [flux_g0, flux_g05, flux_SUNet]
mse = [mse_g0, mse_g05, mse_SUNet]
mse_w = [mse_g0_w, mse_g05_w, mse_SUNet_w]
mom = [mom_g0, mom_g05, mom_SUNet]
measures = [flux, mse, mse_w, mom]
measure_names = ['flux', 'mse', 'mse_w', 'mom']
methods = ['tikhonet', 'tikhonet_sc', 'SUNet']

data = {}

# fill dictionnary
for i, measure in enumerate(measures):
    data[measure_names[i]] = {}
    for j, method in enumerate(methods):
        data[measure_names[i]][method] = measure[j] 

# add remaining keys
data['windows'] = windows
data['window_flags'] = window_flags
data['flux']['true'] = flux_true
data['mom']['true'] = mom_hst
data['mag_auto'] = batch['mag_auto']

# save dictionnary
f = open(data_path+"cfht_data.pkl","wb")
pickle.dump(data,f)
f.close()

### Compute Errors per Bin

In [8]:
label_s0 = r'Sparsity'
label_s1 = r'SCORE'
label_g0 = r'Tikhonet'
label_g05 = r'Tikhonet + MW'

color_g0 = 'green'
color_g05 = 'darkgreen'
color_s0 = 'blue'
color_s1 = 'darkblue'

In [9]:
print(measures)

[[array([ 0.9214536 ,  1.180429  ,  0.82571095, 40.56218   ,  2.420916  ,
       29.70805   ,  0.7486948 , 10.581886  ,  1.0386671 ,  0.83746266,
        0.5231446 ,  0.3701477 ,  0.88515174,  0.7774024 ,  1.7353473 ,
       29.189255  ,  1.7650466 ,  0.97815526,  3.4522104 ,  0.636476  ,
        0.8509892 ,  0.7964127 ,  0.69861054,  1.3720124 ,  0.50760615,
        0.5777789 ,  0.4382362 ,  0.51910436,  9.911483  ,  2.3077738 ,
        0.4064166 ,  0.5494172 ,  4.700004  ,  1.3723104 ,  1.2430788 ,
        1.5572551 ,  0.5849413 ,  0.44633394,  1.8154664 ,  0.6120109 ,
        1.3879569 ,  0.58655185,  8.248551  ,  0.8075146 ,  0.51262385,
        0.81237054,  0.532764  ,  0.73359334,  6.1258745 ,  0.4257208 ,
        0.7761417 ,  0.5152483 ,  6.739272  ,  0.46985424,  0.33432794,
        0.37586674,  0.70795333,  1.3501726 ,  6.749208  ,  0.43246943,
        2.2235112 ,  2.196278  ,  0.7120504 ,  0.6324666 ,  0.68746275,
        0.72098064,  3.6635997 ,  2.5741625 ,  1.2545608 ,  2.