## Neural Network Method for Black Hole Imaging

### Import libraries and modules

In [1]:
import sys
import csv
import os
import numpy as np
import pandas as pd
import random
from numpy.random import randint
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
from PIL import Image

import tensorflow as tf
#from tf.keras.backend.tensorflow_backend import set_session
from keras.callbacks import ModelCheckpoint
from keras.initializers import RandomUniform, Constant
import keras.models
from keras.models import Sequential
import keras.layers
from keras.layers import Layer, Activation, LeakyReLU
from keras.layers import Input, InputLayer, AveragePooling2D, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Sequential, Model, load_model
from keras.layers.normalization import BatchNormalization
from keras.layers.merge import Concatenate, Add
from keras.layers import Dense, Lambda, Reshape
import keras.initializers
import keras.regularizers
import keras.callbacks
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import backend as K
from keras import losses
from keras.datasets import fashion_mnist
from keras.datasets import mnist

# Import eht imaging package
import ehtim as eh 
import ehtim.const_def as ehc
import ehtim.observing.obs_helpers as obsh
from ehtim.observing.obs_helpers import *
# Import helpers from cosense
#import helpers_posci as hp
# Import utilities for computing data terms, losses, and gradients
#from data_term_functions import *
##from models_posci import IsingVisNet, IsingCpAmpNet, IsingMutipleVisNet, IsingMutipleCpAmpNet, IsingVisFeatureNet, IsingCpAmpFeatureNet
#from losses_posci import site_sparsity, energy, Lambda_similarity, Lambda_angle_diff

Using TensorFlow backend.


Welcome to eht-imaging! v  1.1.1


### Define Imaging Variables

In [2]:
# Define observation parameters.
eht_array = 'EHT2019'
target = 'sgrA'

nsamp = 10000      # number of images get_data() draws from each MNIST-style dataset
npix = 32          # images are npix x npix pixels
fov_param = 100.0  # field of view; converted below by multiplying with eh.RADPERUAS
flux_label = 1
sefd_param = 1

tint_sec = 5    # integration time in seconds
tadv_sec = 600  # advance time between scans
tstart_hr = 0   # GMST time of the start of the observation
tstop_hr = 24   # GMST time of the end
bw_hz = 4e9     # bandwidth in Hz

stabilize_scan_phase = False # if true then add a single phase error for each scan to act similar to adhoc phasing
stabilize_scan_amp = False # if true then add a single gain error at each scan
jones = False # apply jones matrix for including noise in the measurements (including leakage)
inv_jones = False # do not invert the jones matrix
frcal = True # True if you do not include effects of field rotation
dcal = True # True if you do not include the effects of leakage
dterm_offset = 0 # a random offset of the D terms is given at each site with this standard deviation away from 1
dtermp = 0

# Directory holding the ehtim array definition files. It can be overridden
# with the EHT_ARRAY_DIR environment variable; the default is the original
# hard-coded location, so existing setups keep working unchanged.
ARRAY_DIR = os.environ.get('EHT_ARRAY_DIR',
                           '/Users/Johanna/Desktop/Proximal Gradient Descent/arrays/')
array = os.path.join(ARRAY_DIR, eht_array + '.txt')
eht = eh.array.load_txt(array)

# Define observation field of view
fov = fov_param * eh.RADPERUAS

# Define scientific target (source coordinates passed to eh.image.make_empty).
if target == 'm87':
    ra = 12.513728717168174
    dec = 12.39112323919932
elif target == 'sgrA':
    ra = 19.414182210498385
    dec = -29.24170032236311
else:
    # Fail here rather than with a NameError on ra/dec much later.
    raise ValueError("Unknown target %r; expected 'm87' or 'sgrA'" % (target,))

rf = 230e9   # passed as `rf` to eh.image.make_empty (presumably observing frequency in Hz)
mjd = 57853 # day of observation
# Blur scale used by get_data(); presumably the nominal array resolution
# (obs.res()) in radians -- TODO confirm.
fwhm = 1.117609542559987e-10

### Define Helper Functions

In [10]:
def _blur_images(images, npix, blur_fwhm):
    """Blur every flattened image with a circular Gaussian of width ``blur_fwhm``.

    Each image is loaded into an ehtim image built from the notebook-level
    ``fov``/``ra``/``dec``/``rf``/``mjd`` settings and blurred with ``blur_circ``.
    Returns a list with one blurred ``imvec`` per input image.
    """
    blurred = []
    for X in images:
        img = eh.image.make_empty(npix, fov, ra, dec, rf=rf, source='random', mjd=mjd)
        img.imvec = X.flatten()
        img = img.blur_circ(fwhm_i=blur_fwhm, fwhm_pol=blur_fwhm)
        blurred.append(img.imvec)
    return blurred


def _prepare_mnist_style(x_train, npix, pad_width):
    """Select ``nsamp`` images (cycling through ``x_train``), pad, rescale and blur them."""
    # Wrap around the training set when nsamp exceeds its length.
    indices = [k % len(x_train) for k in range(int(nsamp))]
    images = 1.0 * x_train[indices]
    # Zero-pad each image by pad_width on every side (28x28 -> 32x32 for pad_width=2).
    images = np.pad(images, ((0, 0), (pad_width, pad_width), (pad_width, pad_width)), 'constant')
    images = images / 255
    images = images.reshape((-1, npix * npix))
    # Blur images by 0.1*fwhm
    return _blur_images(images, npix, 0.1 * fwhm)


def get_data(dataset='fashion', flux=1):
    """Prepare and return a dataset of flattened images for RML.

    Fashion MNIST and Digits MNIST images are zero-padded to 32x32, scaled by
    1/255 and blurred by 0.1*fwhm.

    Parameters:
        -dataset: 'fashion' (Fashion MNIST), 'mnist' (MNIST digits),
                  'all' (Fashion MNIST images followed by MNIST digits), or
                  'bh_data' (simulated black hole images)
        -flux: intended sum of pixels per image. NOTE: currently unused; the
               images are returned without flux normalisation.

    Returns:
        For 'fashion', 'mnist' and 'all': a list of 1-D arrays (one per image).
        For 'bh_data': an array of shape (n_images, 32*32).
        For any other value: an empty list.
    """
    xdata = []
    pad_width = 2
    npix = 32

    if dataset in ('fashion', 'all'):
        (x_train, _), _ = fashion_mnist.load_data()
        xdata = xdata + _prepare_mnist_style(x_train, npix, pad_width)

    if dataset in ('mnist', 'all'):
        (x_train_mnist, _), _ = mnist.load_data()
        # Append rather than assign: previously 'all' silently discarded the
        # Fashion MNIST images prepared above.
        xdata = xdata + _prepare_mnist_style(x_train_mnist, npix, pad_width)

    if dataset == 'bh_data':
        # Location of the simulated images; override with BH_SIM_DATA_PATH.
        bh_path = os.environ.get('BH_SIM_DATA_PATH',
                                 '/Users/Johanna/Desktop/Neural Network/bh_sim_data.npy')
        # NOTE(review): allow_pickle=True executes arbitrary pickled code;
        # only load files from a trusted source.
        bh_sim_data = np.load(bh_path, allow_pickle=True).item()
        bh_data = np.array(bh_sim_data['image'])

        # resize images to 32 x 32 and fov = 100
        # NOTE(review): 160 and 100 are passed as the field of view without the
        # eh.RADPERUAS factor used elsewhere in this notebook -- confirm that
        # only the 160 -> 100 ratio matters for the regridded pixel values.
        bh_data_reshape = []
        for i in range(len(bh_data)):
            bh_img = eh.image.make_empty(160, 160, ra, dec, rf=rf, source='random', mjd=mjd)
            bh_img.imvec = bh_data[i].flatten()
            bh_img_reshape = bh_img.regrid_image(100, 32)
            bh_data_reshape.append(bh_img_reshape.imvec)
        xdata = np.array(bh_data_reshape).reshape((-1, 32 * 32))

    return xdata

'''
    Generate measurement data for image X using specified parameters. 
    
    -----------------------------------------------------------------------------------------------------
    Parameters:
        -X: target image
        -th_noise: If True, include thermal noise in measurements
        -amp_error: If True, include amplitude error in measurements
        -phase_error: If True, include phase error in measurements
        -gainp: Amount of site-wise standard deviation in gain error to include in measurements
        -gain_offset: Amount of fixed gain error to include in measurements
    -----------------------------------------------------------------------------------------------------
'''
def get_measurements(X, th_noise=False, amp_error=False, phase_error=False, gainp=0.1, gain_offset=0.1):
    """Simulate an observation of image X and return the data products and forward models.

    X is assigned to the ``imvec`` of an npix x npix ehtim image, which is then
    observed with the notebook-level array/schedule settings (``eht``,
    ``tint_sec``, ``tadv_sec``, ...).

    Returns a 12-tuple:
        obs          -- the ehtim observation object
        vis          -- 'vis' column unpacked from obs
        cphase       -- obs.cphase['cphase'] after obs.add_cphase(count='max')
        camp         -- 'camp' column of obs.c_amplitudes(...)
        F_vis        -- ftmatrix evaluated at the visibility (u, v) points
        F_cphase     -- complex64 array of shape (n_cphase, npix*npix, 3): for each
                        closure phase, the F_vis row (or its conjugate) of each of
                        the three baselines
        F_camp       -- tuple of four ftmatrix results, one per closure-amplitude baseline
        sigma_vis    -- 'sigma' column unpacked from obs
        sigma_cphase -- obs.cphase['sigmacp']
        sigma_camp   -- 'sigmaca' column of obs.c_amplitudes(...)
        t1, t2       -- obs.data['t1'] and obs.data['t2']

    Side effect: prints a progress message once the visibilities are extracted.
    """
    # Define noise parameters
    add_th_noise = th_noise # False if you *don't* want to add thermal error. If there are no sefds in obs_orig it will use the sigma for each data point
    phasecal = not phase_error # True if you don't want to add atmospheric phase error. if False then it adds random phases to simulate atmosphere
    ampcal = not amp_error # True if you don't want to add atmospheric amplitude error. if False then add random gain errors 

    simim = eh.image.make_empty(npix, fov, ra, dec, rf=rf, source='random', mjd=mjd)
    simim.imvec = X
    
    # Simulate the observation; the complex visibilities are unpacked from it below
    obs = simim.observe(eht, tint_sec, tadv_sec, tstart_hr, tstop_hr, bw_hz, add_th_noise=add_th_noise, ampcal=ampcal, phasecal=phasecal, 
                    stabilize_scan_phase=stabilize_scan_phase, stabilize_scan_amp=stabilize_scan_amp,
                    jones=jones,inv_jones=inv_jones,dcal=dcal, frcal=frcal, dterm_offset=dterm_offset, 
                    gainp=gainp, gain_offset=gain_offset)
    obs_data = obs.unpack(['u', 'v', 'vis', 'sigma'])
    
    # (n_vis, 2) array of (u, v) coordinates
    uv = np.hstack((obs_data['u'].reshape(-1,1), obs_data['v'].reshape(-1,1)))
    
    # Extract forward model (Discrete Fourier Transform matrix)
    # `ftmatrix` is not imported by name; presumably it comes from the
    # `from ehtim.observing.obs_helpers import *` star import -- verify.
    F_vis = ftmatrix(simim.psize, simim.xdim, simim.ydim, uv, pulse=simim.pulse)
    vis = obs_data['vis']
    sigma_vis = obs_data['sigma']
    t1 = obs.data['t1']
    t2 = obs.data['t2']
    
    print("Finished computing visibilities...")
    
    # Compute closure phases on the observation (stored in obs.cphase)
    obs.add_cphase(count='max')
    # Station names of the three telescopes forming each closure triangle
    # (tc1/tc2/tc3 are not used further in this function)
    tc1 = obs.cphase['t1']
    tc2 = obs.cphase['t2']
    tc3 = obs.cphase['t3']
    
    cphase = obs.cphase['cphase']
    sigma_cphase = obs.cphase['sigmacp']
    # cphase_map[k, j] holds the signed row index into obs.data of the visibility
    # for baseline j of closure phase k: column 0 is (t1, t2), column 1 is
    # (t2, t3), column 2 is (t3, t1). A negative value means the visibility is
    # stored with the two stations swapped.
    cphase_map = np.zeros((len(obs.cphase['time']), 3))

    # Row index 0 cannot carry a sign, so it is encoded as +/- zero_symbol.
    # NOTE(review): this collides with a genuine row index of 10000 if an
    # observation ever has that many visibilities.
    zero_symbol = 10000
    for k1 in range(cphase_map.shape[0]):
        # Only visibilities with the same timestamp as this closure phase
        for k2 in list(np.where(obs.data['time']==obs.cphase['time'][k1])[0]):
            if obs.data['t1'][k2] == obs.cphase['t1'][k1] and obs.data['t2'][k2] == obs.cphase['t2'][k1]:
                cphase_map[k1, 0] = k2
                if k2 == 0:
                    cphase_map[k1, 0] = zero_symbol
            elif obs.data['t2'][k2] == obs.cphase['t1'][k1] and obs.data['t1'][k2] == obs.cphase['t2'][k1]:
                cphase_map[k1, 0] = -k2
                if k2 == 0:
                    cphase_map[k1, 0] = -zero_symbol
            elif obs.data['t1'][k2] == obs.cphase['t2'][k1] and obs.data['t2'][k2] == obs.cphase['t3'][k1]:
                cphase_map[k1, 1] = k2
                if k2 == 0:
                    cphase_map[k1, 1] = zero_symbol
            elif obs.data['t2'][k2] == obs.cphase['t2'][k1] and obs.data['t1'][k2] == obs.cphase['t3'][k1]:
                cphase_map[k1, 1] = -k2
                if k2 == 0:
                    cphase_map[k1, 1] = -zero_symbol
            elif obs.data['t1'][k2] == obs.cphase['t3'][k1] and obs.data['t2'][k2] == obs.cphase['t1'][k1]:
                cphase_map[k1, 2] = k2
                if k2 == 0:
                    cphase_map[k1, 2] = zero_symbol
            elif obs.data['t2'][k2] == obs.cphase['t3'][k1] and obs.data['t1'][k2] == obs.cphase['t1'][k1]:
                cphase_map[k1, 2] = -k2
                if k2 == 0:
                    cphase_map[k1, 2] = -zero_symbol

    # Build the closure-phase forward model from the signed index map.
    # cphase_proj records, per closure phase, +1/-1 at the contributing
    # visibility rows; it is built here but not returned.
    F_cphase = np.zeros((cphase_map.shape[0], npix*npix, 3), dtype=np.complex64)
    cphase_proj = np.zeros((cphase_map.shape[0], F_vis.shape[0]), dtype=np.float32)
    for k in range(cphase_map.shape[0]):
        for j in range(cphase_map.shape[1]):
            if cphase_map[k][j] > 0:
                # Forward-ordered baseline: use the F_vis row as is
                if int(cphase_map[k][j]) == zero_symbol:
                    cphase_map[k][j] = 0
                F_cphase[k, :, j] = F_vis[int(cphase_map[k][j]), :]
                cphase_proj[k, int(cphase_map[k][j])] = 1
            else:
                # Swapped baseline: use the complex conjugate of the F_vis row.
                # NOTE(review): an entry still equal to 0 (no matching
                # visibility found above) also lands here and is treated as
                # the conjugate of row 0 -- confirm this cannot happen.
                if np.abs(int(cphase_map[k][j])) == zero_symbol:
                    cphase_map[k][j] = 0
                F_cphase[k, :, j] = np.conj(F_vis[int(-cphase_map[k][j]), :])
                cphase_proj[k, int(-cphase_map[k][j])] = -1

    # Closure amplitudes and the (u, v) points of their four baselines
    clamparr = obs.c_amplitudes(mode='all', count='max',
                                        vtype='vis', ctype='camp', debias=True, snrcut=0.0)
    

    uv1 = np.hstack((clamparr['u1'].reshape(-1, 1), clamparr['v1'].reshape(-1, 1)))
    uv2 = np.hstack((clamparr['u2'].reshape(-1, 1), clamparr['v2'].reshape(-1, 1)))
    uv3 = np.hstack((clamparr['u3'].reshape(-1, 1), clamparr['v3'].reshape(-1, 1)))
    uv4 = np.hstack((clamparr['u4'].reshape(-1, 1), clamparr['v4'].reshape(-1, 1)))
    camp = clamparr['camp']
    sigma_camp = clamparr['sigmaca']
    
    # An empty mask is passed to ftmatrix
    mask = []
    # Tuple of four matrices, one per baseline; the original note gives the
    # shape as (4, 2022, npix**2) -- the 2022 presumably depends on the schedule
    F_camp = (ftmatrix(simim.psize, simim.xdim, simim.ydim, uv1, pulse=simim.pulse, mask=mask),
          ftmatrix(simim.psize, simim.xdim, simim.ydim, uv2, pulse=simim.pulse, mask=mask),
          ftmatrix(simim.psize, simim.xdim, simim.ydim, uv3, pulse=simim.pulse, mask=mask),
          ftmatrix(simim.psize, simim.xdim, simim.ydim, uv4, pulse=simim.pulse, mask=mask)
          )
    
    return obs, vis, cphase, camp, F_vis, F_cphase, F_camp, sigma_vis, sigma_cphase, sigma_camp, t1, t2

# Split into Training and Testing Sets (Complex Visibilities)
def split_data(X, vis):
    """Split images and their visibilities into 90% training / 10% test sets.

    Parameters:
        -X: images, one per row (any sequence convertible to an array whose
            rows hold npix*npix pixels)
        -vis: measurements, one entry per image in X

    Returns:
        X_train, y_train, X_test, y_test, where the X_* values are slices of
        ``vis`` (the network inputs) and the y_* values are the matching images
        reshaped to (n, npix, npix, 1) (the network targets).
    """
    # Use the argument instead of the notebook-global `xdata`, which made the
    # X parameter a no-op; asarray also accepts the list returned by get_data().
    images = np.asarray(X)
    split = int(0.9 * len(images))

    X_train = vis[:split]
    y_train = images[:split].reshape((len(X_train), npix, npix, 1))
    X_test = vis[split:]
    y_test = images[split:].reshape((len(X_test), npix, npix, 1))

    return X_train, y_train, X_test, y_test

# Post Processing
def post_process(Z, img):
    """Rescale |Z| to the total flux of ``img`` and clip it to [0, max(img)].

    Parameters:
        -Z: reconstructed image values
        -img: reference image supplying the target flux and the upper clip value

    Returns the rescaled, clipped array (same shape as Z).
    """
    target_flux = np.sum(img)
    peak = np.max(img)

    # Match the total flux of the reference image
    magnitude_sum = np.sum(np.abs(Z))
    rescaled = (target_flux / magnitude_sum) * np.abs(Z)

    # Clip to the interval [0, peak]
    shape = np.shape(rescaled)
    floored = np.maximum(np.zeros(shape), rescaled)
    return np.minimum(peak * np.ones(shape), floored)

# Compute normalized cross-correlation between images X and Z
def compute_xcorr(X, Z):
    """Return element [0][0] of ``compare_images`` between target X and reconstruction Z.

    Both arrays are flattened and loaded into ehtim images built from the
    notebook-level ``npix``/``fov``/``ra``/``dec``/``rf``/``mjd`` settings.
    """
    def _as_image(pixels):
        # Wrap a pixel array in an empty ehtim image of the notebook geometry
        image = eh.image.make_empty(npix, fov, ra, dec, rf=rf, source='random', mjd=mjd)
        image.imvec = pixels.flatten()
        return image

    reference = _as_image(X)
    reconstruction = _as_image(Z)
    return reference.compare_images(reconstruction)[0][0]
    
# Compute MAE loss 
from sklearn.metrics import mean_absolute_error
def compute_mae(X, Z):
    """Return the mean absolute error between X and Z, compared after flattening."""
    truth, estimate = X.flatten(), Z.flatten()
    return mean_absolute_error(truth, estimate)

# Compute SSIM
from skimage.metrics import structural_similarity as ssim
def compute_ssim(X, Z):
    """Return the structural similarity (SSIM) between images X and Z.

    Arrays holding exactly npix*npix values are reshaped to (npix, npix) so
    the SSIM windows see the 2-D image structure; previously both inputs were
    flattened, which compared 1-D pixel strips. Inputs of any other size fall
    back to the flattened comparison.

    ``data_range`` is passed explicitly as the joint value range of the two
    images, because for floating-point input the library default does not
    reflect the tiny pixel values of flux-normalised images.
    """
    X = np.asarray(X, dtype=float)
    Z = np.asarray(Z, dtype=float)

    if X.size == npix * npix and Z.size == npix * npix:
        X = X.reshape(npix, npix)
        Z = Z.reshape(npix, npix)
    else:
        X = X.flatten()
        Z = Z.flatten()

    data_range = max(X.max(), Z.max()) - min(X.min(), Z.min())
    if data_range == 0:
        # Two constant, equal-valued images: any positive range avoids 0/0.
        data_range = 1.0

    return ssim(X, Z, data_range=data_range)
    
def _blurred_pair(pixels, fwhm):
    """Return ``pixels`` blurred by 0.3*fwhm and by 0.7*fwhm (as flattened imvecs)."""
    base = eh.image.make_empty(npix, fov, ra, dec, rf=rf, source='random', mjd=mjd)
    base.imvec = pixels.flatten()
    blur1 = base.blur_circ(fwhm_i=0.3*fwhm, fwhm_pol=0.3*fwhm)
    blur2 = base.blur_circ(fwhm_i=0.7*fwhm, fwhm_pol=0.7*fwhm)
    return blur1.imvec, blur2.imvec


def _show_image_row(panels, max_color):
    """Show one row of npix x npix images, each with a title and its own colorbar.

    ``panels`` is a list of (title, pixels) pairs. All panels share the colour
    scale [0, max_color]; minor ticks are enabled on the last colorbar.
    """
    fig, axes = plt.subplots(figsize=(13, 3), ncols=len(panels))
    cbar = None
    for ax, (title, pixels) in zip(np.atleast_1d(axes), panels):
        shown = ax.imshow(pixels.reshape(npix, npix), vmin=0, vmax=max_color)
        ax.title.set_text(title)
        cbar = fig.colorbar(shown, ax=ax)
    if cbar is not None:
        # Applied to THIS figure's last colorbar (previously the second and
        # third figures re-used the colorbar object of the first one).
        cbar.minorticks_on()
    plt.show()


def visualize(result, obs):
    """Show results: target and prediction (plain and blurred), uncertainty, error, metrics.

    Parameters:
        -result: sequence of nine items
                 [target, pred, uncertainty, error, vis_chisq, cphase_chisq,
                  camp_chisq, mae, ssim]
        -obs: observation whose ``res()`` sets the blur scale

    Displays three rows of images and prints the visibility chi^2, MAE and SSIM.
    cphase_chisq and camp_chisq are accepted but not printed.
    """
    # Unpack values
    [target, pred, uncertainty, error,
     vis_chisq, cphase_chisq, camp_chisq, mae, ssim_value] = result

    # The colour scale depends on the notebook-level `dataset` name when one is
    # defined; look it up defensively so a fresh kernel does not raise NameError.
    max_color = 0.01
    if globals().get('dataset') == 'fashion':
        max_color = 0.005

    fwhm = obs.res()

    # Ground truth next to its nominally blurred versions
    target_blur1, target_blur2 = _blurred_pair(target, fwhm)
    _show_image_row([('Ground Truth', target),
                     ('0.3 * fwhm Blurred Truth', target_blur1),
                     ('0.7 * fwhm Blurred Truth', target_blur2)], max_color)

    # Prediction next to its nominally blurred versions
    pred_blur1, pred_blur2 = _blurred_pair(pred, fwhm)
    _show_image_row([('Predicted Image', pred),
                     ('0.3 * fwhm blur', pred_blur1),
                     ('0.7 * fwhm blur', pred_blur2)], max_color)

    # Prediction, uncertainty and error
    _show_image_row([('Predicted Image', pred),
                     ('Pixel-Wise Uncertainty', uncertainty),
                     ('Absolute Error', error)], max_color)

    print("Vis Chi^2 = ", round(vis_chisq, 10))
    print("MAE = ", mae)
    print("SSIM = ", ssim_value)

    print()
    print()

def get_res(recon, target, obs, simim1, simim2):
    """Estimate the resolution of a reconstruction: blurred target versus predicted image.

    The target is blurred with a range of circular Gaussians (0 to 50
    micro-arcseconds) and each blurred version is compared with the
    reconstruction; the blur giving the highest cross-correlation is reported
    as the reconstruction resolution and plotted next to ``obs.res()``.

    Parameters:
        -recon, target: pixel arrays (flattened before use)
        -obs: observation providing the nominal resolution via ``obs.res()``
        -simim1, simim2: ehtim image objects used as containers; their
                         ``imvec`` is overwritten with target and recon

    Returns the cross-correlation between the unblurred target and the
    reconstruction.
    """
    # Create target and recon image objects
    target_img = simim1
    target_img.imvec = target.flatten()
    recon_img = simim2
    recon_img.imvec = recon.flatten()

    # Compute cross-correlation b/w recon and target images
    recon_xc = target_img.compare_images(recon_img)[0][0]

    alphas = np.linspace(0, 50, 1000)
    xcorr = []
    for alpha in alphas:
        target_blur = target_img.blur_circ(alpha*eh.RADPERUAS) # Add gaussian blur in micro-arcsecs
        xcorr.append(target_blur.compare_images(recon_img)[0][0]) # normalized cross-correlation

    # Get nominal resolution
    nominal_res = obs.res()

    # Blurring parameter with the highest cross-correlation. Read it from the
    # grid itself: linspace(0, 50, 1000) has a step of 50/999, so the previous
    # `argmax*50/1000` was slightly off from the value actually evaluated.
    recon_alpha = alphas[int(np.argmax(xcorr))]

    # Show Plot
    plt.figure()
    plt.plot(alphas, xcorr, label="XCorr")
    plt.axvline(x=nominal_res/eh.RADPERUAS, label = "Nominal Resolution (uas)", color='m')
    plt.axvline(x=recon_alpha, label = "Reconstruction Resolution (uas)", color='g')
    plt.ylim(0.6, 1)
    plt.xlim(0, 50)
    plt.xlabel("Blurring Parameter, alpha (uas)")
    plt.ylabel("Normalized XCorr")
    plt.title("XCorr of Blurred Target/Reconstruction vs. Alpha")
    plt.legend(loc='lower left')
    plt.show()

    print("Reconstruction XC = ", recon_xc)
    print("Reconstruction Alpha = ", recon_alpha)

    return recon_xc

def compute_snr(X, sigma):
    """Return the power signal-to-noise ratio of the visibilities of image X.

    X is observed without thermal noise and with ampcal/phasecal enabled
    (no amplitude or phase errors); the SNR is mean(|vis|^2) / mean(sigma^2).

    Parameters:
        -X: flattened image assigned to the ``imvec`` of the simulated image
        -sigma: per-visibility noise standard deviations

    Returns a real-valued scalar.
    """
    # Define noise parameters
    add_th_noise = False # False if you *don't* want to add thermal error. If there are no sefds in obs_orig it will use the sigma for each data point
    phasecal = True # True if you don't want to add atmospheric phase error. if False then it adds random phases to simulate atmosphere
    ampcal = True # True if you don't want to add atmospheric amplitude error. if False then add random gain errors

    simim = eh.image.make_empty(npix, fov, ra, dec, rf=rf, source='random', mjd=mjd)
    simim.imvec = X

    # Simulate the observation and extract its complex visibilities
    obs = simim.observe(eht, tint_sec, tadv_sec, tstart_hr, tstop_hr, bw_hz, add_th_noise=add_th_noise, ampcal=ampcal, phasecal=phasecal,
                    stabilize_scan_phase=stabilize_scan_phase, stabilize_scan_amp=stabilize_scan_amp,
                    jones=jones,inv_jones=inv_jones,dcal=dcal, frcal=frcal, dterm_offset=dterm_offset)
    vis = obs.unpack(['vis'])['vis']

    # Signal power is |vis|^2. Squaring the complex values directly (vis**2)
    # gives a complex number whose mean is not a power.
    signal_power = np.mean(np.abs(vis) ** 2)
    noise_power = np.mean(np.asarray(sigma) ** 2)

    return signal_power / noise_power

def chisq_loss(x_true0, x_pred0):
    """Chi^2-style loss between the visibilities of the true and predicted images.

    Both batches are reshaped to rows of 1024 pixels, divided by the
    hard-coded constant 500, and projected with the notebook-global matrix
    ``global_F``. The absolute visibility residuals are divided by the first
    column of ``global_S``, squared, summed over the whole batch and divided
    by twice the number of visibilities of one sample.
    """
    fourier = tf.cast(tf.constant(global_F), tf.complex64)
    sigma = tf.cast(tf.constant(global_S), tf.float32)

    def _scaled_rows(batch):
        # Flatten each image to a row of 1024 pixels and apply the fixed scaling
        rows = tf.cast(tf.reshape(batch, [-1, 1024]), tf.complex64)
        return tf.math.divide(rows, [500])

    x_true0 = _scaled_rows(x_true0)
    x_pred0 = _scaled_rows(x_pred0)

    # Compute visibilities
    fourier_t = tf.transpose(fourier)
    vis_true = tf.matmul(x_true0, fourier_t)
    vis_pred = tf.matmul(x_pred0, fourier_t)

    # Sum of squared, sigma-weighted residual magnitudes
    residual = tf.abs(tf.subtract(vis_pred, vis_true))
    weighted = tf.divide(residual, sigma[:, 0])
    num = tf.reduce_sum(tf.square(weighted))

    # Twice the number of visibilities per sample
    denom = tf.cast(tf.multiply(2, tf.size(vis_true[0])), tf.float32)

    return tf.divide(num, denom)

def plotall(obs, field1, field2, conj=False, debias=True, tag_bl=False, 
            ang_unit='deg', timetype=False,  axis=False, rangex=False, 
            rangey=False, snrcut=0., color='b', marker='o', 
            markersize=ehc.MARKERSIZE, label=None, grid=True, ebar=True, 
            axislabels=True, legend=False, show=False):
    """Plot ``field2`` against ``field1`` for every data point of ``obs``.

    Parameters:
        -obs: observation whose ``unpack`` provides the two fields
        -field1, field2: names of the x and y fields
        -conj, debias, ang_unit, timetype: forwarded to ``obs.unpack``
        -axis: existing axes to draw on; a new figure is created when falsy
        -rangex, rangey: explicit [min, max] limits; when falsy the limits are
                         the data range widened by 20%
        -snrcut: when > 0, points failing the signal-to-noise cut are dropped
        -color, marker, markersize: point style
        -grid: whether to draw the grid
        -ebar: draw error bars when sigmas are available
        -tag_bl, label, axislabels, legend, show: accepted for signature
          compatibility but not used

    Returns:
        (axes, x values, y values) with the x/y values being the plotted
        (masked) data.
    """
    # unpack data
    data = obs.unpack([field1, field2],
                      conj=conj, ang_unit=ang_unit, debias=debias, timetype=timetype)

    def _sigma_for(field):
        # Error values of `field`, or None when it has no associated sigma.
        signame = obsh.sigtype(field)
        if not signame:
            return None
        return obs.unpack(signame, conj=conj, ang_unit=ang_unit)[signame]

    # X / Y error bars. The X sigma was previously unpacked with the sigma
    # name of field2 and then indexed with that of field1.
    sigx = _sigma_for(field1)
    sigy = _sigma_for(field2)

    # make plot(s)
    if axis:
        x = axis
    else:
        fig = plt.figure()
        x = fig.add_subplot(1, 1, 1)

    # Flag out nans (to avoid problems determining plotting limits)
    mask = ~(np.isnan(data[field1]) + np.isnan(data[field2]))

    # Flag out due to snrcut
    if snrcut > 0.:
        sigs = [sigx, sigy]
        for jj, field in enumerate([field1, field2]):
            if field in ehc.FIELDS_AMPS:
                fmask = data[field] / sigs[jj] > snrcut
            elif field in ehc.FIELDS_PHASE:
                fmask = sigs[jj] < (180. / np.pi / snrcut)
            elif field in ehc.FIELDS_SNRS:
                fmask = data[field] > snrcut
            else:
                fmask = np.ones(mask.shape).astype(bool)
            mask = mask & fmask

    data = data[mask]
    if sigy is not None:
        sigy = sigy[mask]
    if sigx is not None:
        sigx = sigx[mask]

    have_data = len(data) > 0

    # Plot the data once, in the requested colour
    if have_data:
        tolerance = len(data[field2])
        if ebar and (np.any(sigy) or np.any(sigx)):
            x.errorbar(data[field1], data[field2], xerr=sigx, yerr=sigy, label='',
                       fmt=marker, markersize=markersize, color=color, picker=tolerance)
        else:
            x.plot(data[field1], data[field2], marker, markersize=markersize,
                   color=color, picker=tolerance)

    # Data ranges (only derivable when at least one point survived the masks)
    if not rangex and have_data:
        xmin, xmax = np.min(data[field1]), np.max(data[field1])
        rangex = [xmin - 0.2 * np.abs(xmin), xmax + 0.2 * np.abs(xmax)]
        if np.any(np.isnan(np.array(rangex))):
            print("Warning: NaN in data x range: specifying rangex to default")
            rangex = [-100, 100]

    if not rangey and have_data:
        ymin, ymax = np.min(data[field2]), np.max(data[field2])
        rangey = [ymin - 0.2 * np.abs(ymin), ymax + 0.2 * np.abs(ymax)]
        if np.any(np.isnan(np.array(rangey))):
            print("Warning: NaN in data y range: specifying rangey to default")
            rangey = [-100, 100]

    if rangex:
        x.set_xlim(rangex)
    if rangey:
        x.set_ylim(rangey)
    x.set_xlabel(field1)
    x.set_ylabel(field2)
    x.set_title(field1 + " versus " + field2)
    # `x.grid` without parentheses was a no-op; honour the `grid` argument.
    x.grid(grid)

    return x, data[field1], data[field2]
    
def predict_and_run_test(model, target, ALPHA, th_noise, amp_err, phase_err, gainp=0.1, gain_offset=0.1, th_noise_factor=0, blur_param=0, savefile=None, pred_img=None):
    """Predict an image with ``model`` from simulated measurements and compute metrics.

    The target is normalised to unit flux, blurred by 0.3 * obs.res(), and
    observed; the visibilities (scaled by ALPHA) are fed to the model 16 times
    and the predictions are averaged. Chi^2 values, cross-correlation, MAE and
    SSIM are computed and the results are shown with ``visualize``.

    Parameters:
        -model: object with a Keras-style ``predict`` method
        -target: target image (not modified; a normalised copy is used)
        -ALPHA: scale factor applied to the visibilities before prediction and
                removed from the prediction afterwards
        -th_noise, amp_err, phase_err, gainp, gain_offset: forwarded to
          ``get_measurements``; th_noise is forced to False when
          th_noise_factor is 0
        -th_noise_factor: when > 0, extra Gaussian noise of this many sigma is
          added to the visibilities
        -blur_param: unused
        -savefile: when given, the observation is saved to savefile+'_obs' and
          the blurred target to savefile+'_img'
        -pred_img: unused; kept for signature compatibility

    Returns:
        vis_chisq, cphase_chisq, camp_chisq, recon_xc
    """
    # Normalize flux on a copy: the caller's array is left untouched and
    # integer input is handled. (`pred_img` used to be normalised here as well,
    # which raised a TypeError for the default None although the value was
    # never used.)
    target = np.asarray(target)
    target = target / np.sum(target)

    if th_noise_factor == 0:
        th_noise = False

    # A first simulated observation is needed only for its nominal resolution.
    obs = get_measurements(target.flatten(), th_noise, amp_err, phase_err, gainp, gain_offset)[0]

    # Blur target to 0.3*fwhm
    fwhm = obs.res()
    target_blur = eh.image.make_empty(npix, fov, ra, dec, rf=rf, source='random', mjd=mjd)
    target_blur.imvec = target.flatten()
    target_blur = target_blur.blur_circ(fwhm_i=0.3*fwhm, fwhm_pol=0.3*fwhm)
    target = target_blur.imvec

    # Measurements of the blurred target (data terms)
    obs, visibility, _, _, _, _, _, sigma_vis, _, _, _, _ = get_measurements(
        target.flatten(), th_noise, amp_err, phase_err, gainp, gain_offset)

    if savefile is not None:
        obs.save_txt(savefile+'_obs')

    # add additional layers  of thermal noise
    # NOTE(review): the noise is real-valued, so only the real part of the
    # complex visibilities is perturbed -- confirm this is intended.
    if th_noise_factor > 0:
        noise_arr = [th_noise_factor*np.random.normal(0, sigma) for sigma in sigma_vis]
        visibility += noise_arr
        sigma_vis *= th_noise_factor
        print("Adding additional thermal noise.... mean = ", (th_noise_factor + 1)*np.mean(np.abs(noise_arr)))

    print("SNR = ", compute_snr(target, sigma_vis))

    # Pre-process data for UNet
    visibility *= ALPHA
    sigma_vis *= ALPHA

    # Repeated predictions; their spread is reported as the uncertainty
    num_tests = 16
    net_input = visibility.reshape((-1, visibility.shape[0]))
    recon_imgs = [model.predict(net_input) for _ in range(num_tests)]

    imgs = np.array(recon_imgs)
    pred = np.mean(imgs, axis=0)

    # Post-process: undo the ALPHA scaling
    visibility /= ALPHA
    imgs /= ALPHA
    pred /= ALPHA

    uncertainty = np.sqrt(2*(np.mean(imgs ** 2, axis=0)) + (np.std(imgs) ** 2))
    std = np.sqrt(np.var(imgs, axis=0))
    error = np.subtract(pred.flatten(), target.flatten())

    print("Total Uncertainty: ", np.sum(uncertainty))
    print("Total STD: ", np.sum(std))
    print("Total Error: ", np.sum(error))

    print("PRED: ", np.sum(pred))
    print("TARGET: ", np.sum(target))

    # Compute chi^2 values with observations
    pred_image = eh.image.make_empty(npix, fov, ra, dec, rf=rf, source='random', mjd=mjd)
    pred_image.imvec = (pred).flatten()
    vis_chisq = obs.chisq(pred_image, dtype='vis')
    cphase_chisq = obs.chisq(pred_image, dtype='cphase')
    camp_chisq = obs.chisq(pred_image, dtype='camp')

    if savefile is not None:
        target_blur.save_fits(savefile+'_img')

    # Compute cross-correlation (pred is already back in target units; it was
    # previously divided by ALPHA a second time here)
    recon_xc = compute_xcorr(target, pred)

    # visualize() expects nine entries ending in MAE and SSIM; the former
    # eight-entry list made its unpacking fail.
    recon_mae = compute_mae(target, pred)
    recon_ssim = compute_ssim(target, pred)

    # Visualize results
    result = [target, pred, std, error, vis_chisq, cphase_chisq, camp_chisq, recon_mae, recon_ssim]
    visualize(result, obs)

    return vis_chisq, cphase_chisq, camp_chisq, recon_xc

def run_test(target, pred, fwhm, th_noise=False, amp_err=False, phase_err=False, gainp=0.1, gain_offset=0.1, blur_param=0.0):
    """Compute metrics for a pre-computed image ``pred`` against ``target``.

    Both images are normalised to unit flux and blurred by ``blur_param * fwhm``;
    each is then observed with ``get_measurements``. Amplitude and phase versus
    uv-distance are plotted for both observations, chi^2 values of the
    prediction against the target observation are computed, and the results
    are shown with ``visualize``.

    Parameters:
        -target, pred: images (not modified; normalised copies are used)
        -fwhm: blur scale, used for both the target and the prediction
        -th_noise, amp_err, phase_err, gainp, gain_offset: forwarded to
          ``get_measurements``
        -blur_param: fraction of fwhm used as the blur width

    Returns:
        vis_chisq, cphase_chisq, camp_chisq, recon_mae, recon_ssim
    """
    # Normalize fluxes for EHTIM functions (on copies, so the caller's arrays
    # are not modified in place)
    pred = np.asarray(pred)
    target = np.asarray(target)
    pred = pred / np.sum(pred)
    target = target / np.sum(target)

    print("Flux of Target = ", np.sum(target))
    print("Flux of Pred = ", np.sum(pred))

    nonblur_target = target
    # One blur width for both images. Previously the prediction was blurred
    # with obs.res() while the target used the `fwhm` argument.
    blur_fwhm = blur_param * fwhm

    # Blur target by blur_param
    target_blur = eh.image.make_empty(npix, fov, ra, dec, rf=rf, source='random', mjd=mjd)
    target_blur.imvec = target.flatten()
    target_blur = target_blur.blur_circ(fwhm_i=blur_fwhm, fwhm_pol=blur_fwhm)
    target = target_blur.imvec

    obs = get_measurements(target.flatten(), th_noise, amp_err, phase_err, gainp, gain_offset)[0]

    # Apply blur to predicted image
    pred_unblur = pred
    pred_blur = eh.image.make_empty(npix, fov, ra, dec, rf=rf, source='random', mjd=mjd)
    pred_blur.imvec = pred.flatten()
    pred_blur = pred_blur.blur_circ(fwhm_i=blur_fwhm, fwhm_pol=blur_fwhm)
    pred = pred_blur.imvec

    obs_pred = get_measurements(pred.flatten(), th_noise, amp_err, phase_err, gainp, gain_offset)[0]

    # Amplitude versus uv-distance for target and prediction.
    # NOTE(review): plotall opens its own figure when no axis is passed, so
    # the figure created here stays empty and the labelled points land on the
    # figure of the second plotall call.
    plt.figure()
    x1, uv1, amp1 = plotall(obs, 'uvdist', 'amp', show=False, color='green')
    x2, uv2, amp2 = plotall(obs_pred, 'uvdist', 'amp', show=False, color='blue')
    plt.plot(uv1, amp1,'o', color='green', markersize=6, label='True Amp.')
    plt.plot(uv2, amp2,'o', color='blue', markersize=6, label='Pred Amp.')
    plt.legend()
    plt.show()

    # Phase versus uv-distance for target and prediction
    plt.figure()
    x1, uv1, phase1 = plotall(obs, 'uvdist', 'phase', show=False, color='green')
    x2, uv2, phase2 = plotall(obs_pred, 'uvdist', 'phase', show=False, color='blue')
    plt.plot(uv1, phase1, 'o', color='green', markersize=6, label='True Phase')
    plt.plot(uv2, phase2, 'o', color='blue', markersize=6, label='Pred Phase')
    plt.legend()
    plt.show()

    # Compute chi^2 values with observations
    pred_img = eh.image.make_empty(npix, fov, ra, dec, rf=rf, source='random', mjd=mjd)
    pred_img.imvec = pred.flatten()
    vis_chisq = obs.chisq(pred_img,dtype='vis')
    cphase_chisq = obs.chisq(pred_img,dtype='cphase')
    camp_chisq = obs.chisq(pred_img, dtype='camp')

    # Compute MAE, SSIM, and error
    recon_mae = compute_mae(target, pred)
    recon_ssim = compute_ssim(target, pred)
    error = np.subtract(pred.flatten(), target.flatten())

    # Visualize results. The error panel now receives the computed error
    # (it used to receive the blurred prediction). A single pre-computed image
    # has no uncertainty estimate, so the blurred prediction still fills the
    # uncertainty slot as a placeholder.
    result = [nonblur_target, pred_unblur, pred, error, vis_chisq, cphase_chisq, camp_chisq, recon_mae, recon_ssim]
    visualize(result, obs)

    return vis_chisq, cphase_chisq, camp_chisq, recon_mae, recon_ssim


In [4]:
''' Plot training and validation loss from .txt file (terminal output) '''
from itertools import groupby 
import re  
from scipy.optimize import curve_fit
    
''' Plot training and validation loss from .txt file (terminal output) '''
# Unsigned decimal number, optionally followed by an exponent.  The exponent
# part is required because Keras prints small losses as e.g. "1.2345e-04".
_NUMBER_PATTERN = r"\d+\.*\d*(?:[eE][-+]?\d+)?"

# Tags of the per-dataset evaluation lines -> history name.
_DATA_LINE_TAGS = (
    ('fashion_data', 'fashion'),
    ('mnist_data', 'mnist'),
    ('bh_data', 'bh'),
)

# Keras metric names looked up in each ' - '-separated log item, mapped to
# (history name, metric).  Order matters: the 'val_*' names must be tested
# before the plain ones because 'xc_loss' and 'pred_vis_loss' are substrings
# of 'val_xc_loss' and 'val_pred_vis_loss'.
_KERAS_LOSS_KEYS = (
    ('val_xc_mnist_loss', 'mnist', 'xc'),
    ('val_pred_vis_mnist_loss', 'mnist', 'chisq'),
    ('val_xc_bh_loss', 'bh', 'xc'),
    ('val_pred_vis_bh_loss', 'bh', 'chisq'),
    ('val_xc_loss', 'valid', 'xc'),
    ('val_pred_vis_loss', 'valid', 'chisq'),
    ('pred_vis_loss', 'train', 'chisq'),
    ('xc_loss', 'train', 'xc'),
)


def _parse_loss_log(filename, chisq_norm_count):
    ''' Read a training log and return {history name: {'xc': [...], 'chisq': [...]}} '''
    histories = {name: {'xc': [], 'chisq': []}
                 for name in ('train', 'valid', 'fashion', 'mnist', 'bh')}

    with open(filename) as f:
        for line in f:
            for tag, name in _DATA_LINE_TAGS:
                if tag not in line:
                    continue
                # A data line must hold exactly two numbers: MAE then chi^2.
                # Both are converted before either is stored so that the two
                # lists can never get out of step on a malformed line.
                try:
                    mae, chisq = (float(v) for v in re.findall(_NUMBER_PATTERN, line))
                except ValueError:
                    print("Error line: ", line)
                else:
                    histories[name]['xc'].append(mae)
                    histories[name]['chisq'].append(chisq / (2 * chisq_norm_count))
                break
            else:
                # Not a data line: treat it as a Keras progress line if it
                # mentions a loss, and scan its ' - '-separated items.
                if 'loss' in line:
                    for item in line.strip().split(' - '):
                        _record_keras_item(item, histories)
    return histories


def _record_keras_item(item, histories):
    ''' Store the first number of a Keras log item under the first matching loss key '''
    numbers = re.findall(_NUMBER_PATTERN, item)
    if not numbers:
        return
    try:
        value = float(numbers[0])
    except ValueError:
        # e.g. "1..5" matches the pattern but is not a valid float
        return
    for key, name, metric in _KERAS_LOSS_KEYS:
        if key in item:
            histories[name][metric].append(value)
            return


def _plot_log_curves(curves, ylabel):
    ''' Plot log10 of each (label, values) curve against its index in one new figure '''
    plt.figure()
    for label, values in curves:
        plt.plot(np.log10(values), label=label)
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(ylabel + ' vs. Epoch')
    plt.show()


def _report_minimum(label, values):
    ''' Print the smallest value of a history and the index at which it first occurs '''
    if not values:
        print("No " + label + " values found")
        return
    min_loss = min(values)
    print("Min " + label + " = ", min_loss, " at epoch ", values.index(min_loss))


def plot_loss(filename, metric='', chisq_norm_count=1691):
    ''' Plot per-dataset losses parsed from a training log (.txt terminal output)

    Two kinds of lines are read from the log:
      * lines containing 'fashion_data', 'mnist_data' or 'bh_data', which must
        contain exactly two numbers (MAE, chi^2); other such lines are reported
        as "Error line" and skipped,
      * lines containing 'loss', which are split on ' - ' and searched for the
        Keras metric names (xc_loss, pred_vis_loss and their val_* variants).

    Three figures are shown (log10 of total loss, MAE and chi^2 for the
    Fashion, MNIST and black-hole histories) and the minimum validation,
    MNIST and black-hole MAE / chi^2 are printed with their index.

    Parameters
    ----------
    filename : str
        Path of the log file to parse.
    metric : str
        Suffix appended to the legend labels of the MAE and chi^2 figures.
    chisq_norm_count : int
        chi^2 values on the '*_data' lines are divided by 2*chisq_norm_count.
        The default reproduces the previously hard-coded 2*1691
        (presumably the number of visibilities -- TODO confirm).

    Returns
    -------
    None
    '''
    print(filename)

    histories = _parse_loss_log(filename, chisq_norm_count)
    fashion, mnist, bh = histories['fashion'], histories['mnist'], histories['bh']
    valid = histories['valid']

    # Number of leading entries dropped from the total-loss curves
    cutoff = 2

    def total_loss(history):
        # zip() stops at the shorter list, so unequal lengths cannot raise
        return [xc + chisq for xc, chisq in zip(history['xc'], history['chisq'])]

    # Plot total loss
    _plot_log_curves(
        [('Fashion Loss ', total_loss(fashion)[cutoff:]),
         ('MNIST Loss ', total_loss(mnist)[cutoff:]),
         ('Black Hole Loss ', total_loss(bh)[cutoff:])],
        'Log(Total Loss)')

    # Plot MAE
    _plot_log_curves(
        [('Fashion MAE ' + metric, fashion['xc']),
         ('MNIST MAE ' + metric, mnist['xc']),
         ('BH MAE ' + metric, bh['xc'])],
        'Log(MAE)')
    _report_minimum('Validation MAE', valid['xc'])
    _report_minimum('MNIST MAE', mnist['xc'])
    _report_minimum('BH MAE', bh['xc'])

    # Plot chi^2
    _plot_log_curves(
        [('Fashion Chi^2 ' + metric, fashion['chisq']),
         ('MNIST Chi^2 ' + metric, mnist['chisq']),
         ('BH Chi^2 ' + metric, bh['chisq'])],
        'Log(Chi^2)')
    _report_minimum('Validation Chi^2', valid['chisq'])
    _report_minimum('MNIST Chi^2', mnist['chisq'])
    _report_minimum('BH Chi^2', bh['chisq'])

### Load dataset and select target image

In [8]:
# Name of the dataset to read; get_data() returns its images
dataset = 'bh_data'

print('Loading dataset...')
xdata = get_data(dataset)

# The first image of the dataset (index 0) is used as the target
X = xdata[0]
print('Done.')

Loading dataset...
Done.


### Generate observations from target image

In [9]:
# Simulate an observation of the flattened target image and unpack the
# returned data terms
(
    obs, visibility, cphase, camp,
    F_vis, F_cphase, F_camp,
    sigma_vis, sigma_cphase, sigma_camp,
    t1, t2,
) = get_measurements(X.flatten())

# presumably the nominal resolution of the observation, passed to run_test
# as the blur scale -- TODO confirm against ehtim's Obsdata.res()
fwhm = obs.res()
print("Done.")

Generating empty observation file . . . 
Producing clean visibilities from image with nfft FT . . . 
Adding gain + phase errors to data and applying a priori calibration . . . 
Finished computing visibilities...
Getting bispectra:: type vis, count max, scan 106/106 

Updated self.cphase: no averaging
updated self.cphase: avg_time 0.000000 s

Getting closure amps:: type vis camp , count max, scan 106/106

Done.


In [None]:
# Load a predicted image from disk and evaluate it against the target image
pred_img_path = '/path/to/predicted/image'
pred = np.load(pred_img_path)

# Compare against the target image X selected from the dataset; the name
# `target` is the source-name string ('sgrA') from the configuration cell,
# not an image, so it must not be passed here.
vis_chisq, cphase_chisq, camp_chisq, mae, ssim = run_test(X, pred, fwhm, blur_param=0.5)

print('Visibility Chi^2 = ', vis_chisq)
print('Closure Phase Chi^2 = ', cphase_chisq)
print('Closure Amplitude Chi^2 = ', camp_chisq)
print('MAE = ', mae)
print('SSIM = ', ssim)