In [None]:
import os
# Select the GPU before torch is imported in the next cell: number devices by
# PCI bus ID and expose only physical GPU 1 to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

In [None]:
import numpy as np
import glob
import torch
import torch.nn as nn
from patchify import patchify, unpatchify
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import trange


from csfm.dfc import dfc as f_c_network
from det_head.CNOModule import CNO
# from det_head.upernetconvnext import uperconvnext
# import segmentation_models_pytorch as smp



from config_dfc_unet_eval import config
from plotting_tools import *
import gc
import pickle

In [3]:

# Matplotlib text settings for figures. LaTeX rendering is switched off
# ("text.usetex": False), so matplotlib renders all text itself.
# Sizes: 9 pt for axis labels and general text, 8 pt for legends and ticks.
tex_fonts = dict([
    ("text.usetex", False),
    ("font.family", "sans-serif"),
    # ("font.sans-serif", "Helvetica"),
    ("axes.labelsize", 9),
    ("font.size", 9),
    # legend and tick labels one point smaller than the body text
    ("legend.fontsize", 8),
    ("xtick.labelsize", 8),
    ("ytick.labelsize", 8),
])

plt.rcParams.update(tex_fonts)


In [None]:
# some utils, will use these from the utils python files later
def load_checkpoint(checkpoint, dfc, optimizer=None):
    """Restore a model (and optionally its optimizer) from a checkpoint dict.

    Args:
        checkpoint: mapping with a ``"state_dict"`` entry for the model and,
            when ``optimizer`` is given, an ``"optimizer"`` entry.
        dfc: module whose weights are restored in place via ``load_state_dict``.
        optimizer: optional optimizer restored from ``checkpoint["optimizer"]``;
            ``None`` (the default) skips it.
    """
    print("==> Loading checkpoint")
    dfc.load_state_dict(checkpoint["state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
        # report the learning rate of the last parameter group
        lr = optimizer.param_groups[-1]["lr"]
        print("Loading Optimizer. Current learning rate = ", lr)

def count_params(model) -> float:
    """Return the memory occupied by ``model``'s parameters, in GiB.

    Despite the name this is a size, not a parameter count: it sums
    ``nelement() * element_size()`` (bytes) over ``model.parameters()`` and
    divides by ``1024 ** 3``.
    """
    total_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    return total_bytes / (1024 ** 3)  # bytes -> GiB
        
def gkern(l=5, sig=1.):
    """Build an ``l`` x ``l`` unnormalised 2-D Gaussian kernel.

    Samples are taken at ``l`` evenly spaced offsets symmetric about zero
    (integer offsets for odd ``l``, so the centre value is 1).

    Args:
        l: side length of the kernel in pixels.
        sig: standard deviation in pixels.

    Returns:
        ``np.ndarray`` of shape ``(l, l)``: outer product of the 1-D Gaussian
        profile with itself.
    """
    half_width = (l - 1) / 2.
    offsets = np.linspace(-half_width, half_width, l)
    profile = np.exp(-0.5 * np.square(offsets) / np.square(sig))
    return profile[:, np.newaxis] * profile[np.newaxis, :]

def mask_maker(X, Y, Z, R, N_x = 512, N_y = 512, kernel_size = 5):
    """Rasterise sparse particle annotations into a 3-channel target mask.

    For every particle a ``kernel_size`` x ``kernel_size`` patch is written into
    each channel: channel 0 receives a Gaussian blob (``gkern`` with sigma 1),
    channel 1 is filled with the particle's ``Z`` value and channel 2 with its
    ``R`` value. Where patches overlap, later particles overwrite earlier ones.

    Args:
        X, Y: particle pixel coordinates; rounded with ``np.round``.
        Z, R: per-particle values written into channels 1 and 2.
        N_x, N_y: width and height of the returned mask.
        kernel_size: side length of each stamped patch; must be odd.

    Returns:
        ``np.ndarray`` of shape ``(3, N_y, N_x)``.

    Raises:
        ValueError: if ``kernel_size`` is even (the stamped window would be one
            pixel wider than the kernel).
    """
    if kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be odd, got {kernel_size}")

    X = np.round(np.array(X)).astype(int)
    Y = np.round(np.array(Y)).astype(int)
    Z = np.array(Z)
    R = np.array(R)

    half = kernel_size // 2
    gk = gkern(l = kernel_size, sig = 1)
    # Canvas padded by kernel_size + 2 so patches of in-range particles never run off the edge
    mask = np.zeros((3, N_y + kernel_size + 2, N_x + kernel_size + 2))

    for x, y, z, r in zip(X, Y, Z, R):
        # a particle at (x, y) covers padded rows/cols y..y+2*half / x..x+2*half
        rows = slice(y, y + 2 * half + 1)
        cols = slice(x, x + 2 * half + 1)
        mask[0, rows, cols] = gk # Use this for synthetic
        mask[1, rows, cols] = z
        mask[2, rows, cols] = r

    # Crop the padding. With this offset a particle's centre lands at (y - 1, x - 1)
    # in the returned mask. NOTE(review): confirm the one-pixel shift is intended
    # (e.g. 1-based input coordinates).
    return mask[:, 1 + half:N_y + 1 + half, 1 + half:N_x + 1 + half]

def peak_local_max(input, threshold_abs=1, min_distance=1):
    '''
    Mark local maxima of an image-like tensor.

    A pixel counts as a maximum when it equals the maximum of the
    (2*min_distance+1) x (2*min_distance+1) window centred on it (stride 1,
    padded so the spatial size is unchanged) and its value is at least
    ``threshold_abs``. Every pixel of a flat plateau at the window maximum is
    marked.

        Parameters:
            input (torch.Tensor): tensor accepted by ``max_pool2d``, e.g.
                [batch_size, channels, H, W]; each image is processed independently.
            threshold_abs (float): maxima with a value below this are dropped.
            min_distance (int): half-width of the comparison window in pixels.
        Returns
            boolean mask tensor of the same shape as ``input``.
    '''
    window = 2 * min_distance + 1
    neighbourhood_max = torch.nn.functional.max_pool2d(
        input, kernel_size=window, stride=1, padding=min_distance
    )
    is_peak = neighbourhood_max == input
    above_threshold = input >= threshold_abs
    return is_peak & above_threshold

epsilon = 2.2204e-16  # approximately float64 machine epsilon; keeps the ratios below finite

def precision(hits, fp):
    """Precision ``hits / (hits + fp)``; 1.0 by convention when both counts are zero."""
    no_detections = hits == 0 and fp == 0
    return 1.0 if no_detections else hits / (hits + fp + epsilon)

def recall(hits, fn):
    """Recall ``hits / (hits + fn)``; 1.0 by convention when both counts are zero."""
    nothing_to_find = hits == 0 and fn == 0
    return 1.0 if nothing_to_find else hits / (hits + fn + epsilon)

def f1(precision, recall):
    """Harmonic mean of precision and recall; epsilon avoids 0/0."""
    numerator = 2 * (precision * recall)
    return numerator / (precision + recall + epsilon)


def get_grid(shape, device):
    """Build normalised coordinate channels for a batch of images.

    Args:
        shape: shape indexed as ``(batch, _, size_x, size_y)``; only entries 0,
            2 and 3 are read.
        device: device of the returned tensor.

    Returns:
        float32 tensor of shape ``(batch, 2, size_x, size_y)``. Channel 0 ramps
        linearly from 0 to 1 along dim 2; channel 1 ramps from 0 to 1 along dim 3.
    """
    batchsize = shape[0]
    size_x, size_y = shape[2], shape[3]
    full = (batchsize, 1, size_x, size_y)
    ramp_x = torch.from_numpy(np.linspace(0, 1, size_x)).float()
    ramp_y = torch.from_numpy(np.linspace(0, 1, size_y)).float()
    grid_x = ramp_x.view(1, 1, size_x, 1).expand(*full)
    grid_y = ramp_y.view(1, 1, 1, size_y).expand(*full)
    return torch.cat((grid_x, grid_y), dim=1).to(device)

def make_smaller_patches(input: torch.tensor, patch_size: int = 128, step: int = 128, device: str = 'cuda'):
    """Tile the first image of a batch into square patches stacked as a new batch.

    Only ``input[0, 0]`` is used; any other batch entries or channels are ignored.

    Args:
        input: image tensor indexed as ``[batch, channel, H, W]``.
        patch_size: side length of each patch.
        step: stride between patches passed to ``patchify``.
        device: device of the returned patch tensor.

    Returns:
        ``(patches, intermediate_shape)``: ``patches`` has shape
        ``(n_rows * n_cols, 1, patch_size, patch_size)``; ``intermediate_shape``
        is the shape of the ``patchify`` output, which is needed to reassemble
        the image.
    """
    image = input.cpu().detach().numpy()[0, 0]
    grid = patchify(image, patch_size=patch_size, step=step)
    grid_shape = grid.shape
    flat = grid.reshape(grid_shape[0] * grid_shape[1], patch_size, patch_size)
    patches = torch.from_numpy(flat).unsqueeze(1).to(device)
    return patches, grid_shape

def patch_batch(input: torch.tensor, im_shape: tuple = (256,256), intermediate_shape:tuple = (2,2,256,256), device: str = 'cuda'):
    """Reassemble a batch of patches into one image with ``unpatchify``.

    Args:
        input: patch tensor whose elements can be reshaped to ``intermediate_shape``.
        im_shape: size of the reconstructed image.
        intermediate_shape: patch-grid shape expected by ``unpatchify``.
        device: NOTE(review): accepted but unused — the result stays on the CPU.

    Returns:
        CPU tensor of shape ``(1, *im_shape)``.

    NOTE(review): the defaults look inconsistent (a 2x2 grid of 256x256 patches
    does not tile a 256x256 image) — confirm before relying on them.
    """
    patch_grid = input.cpu().detach().numpy().reshape(intermediate_shape)
    image = unpatchify(patch_grid, imsize = im_shape)
    return torch.from_numpy(image).unsqueeze(0)


def make_img_crops(img_files, crops_per_hologram, crop_size, step_size, W, edge_crop_dist, ds_factor):
    """Downsample, edge-trim and tile hologram images into a batch of crops.

    Each file is opened with PIL, converted to float32 and average-pooled by
    ``ds_factor``. It is then trimmed to ``[edge_crop_dist, W - edge_crop_dist)``
    on both axes and split by ``patchify`` into ``crop_size`` x ``crop_size``
    crops taken every ``step_size`` pixels. The trim is applied after pooling,
    so ``W`` is presumably the downsampled side length — confirm.

    Args:
        img_files: paths of the hologram images.
        crops_per_hologram: number of output rows reserved per image.
        crop_size, step_size: crop side length and stride.
        W, edge_crop_dist: define the trimmed region (see above).
        ds_factor: average-pooling kernel size and stride.

    Returns:
        float32 tensor ``(crops_per_hologram * len(img_files), 1, crop_size, crop_size)``.
        Rows beyond the crops actually produced stay zero.

    Raises:
        ValueError: if the images yield more crops than the preallocated rows.
    """
    num_files = len(img_files)
    crops = np.zeros((crops_per_hologram*num_files, crop_size, crop_size))
    keep = slice(edge_crop_dist, W-edge_crop_dist)
    xy_slices = (keep, keep)
    avgpool = nn.AvgPool2d(ds_factor, ds_factor)

    filled = 0
    for f in img_files:
        # read the pixels inside a context manager so the file handle is closed
        with Image.open(f) as im:
            pixels = np.float32(np.array(im))
        holo = avgpool(torch.from_numpy(pixels).unsqueeze(0)).squeeze(0).numpy()[xy_slices]
        holo_patches = patchify(holo, patch_size=crop_size, step=step_size)
        holo_patches = np.reshape(holo_patches, (holo_patches.shape[0]*holo_patches.shape[1], crop_size, crop_size))

        n = holo_patches.shape[0]
        if filled + n > crops.shape[0]:
            raise ValueError(
                f"{f} yields {n} crops but only {crops.shape[0] - filled} rows remain; "
                f"check crops_per_hologram={crops_per_hologram}"
            )
        crops[filled:filled+n] = holo_patches
        filled += n

    return torch.from_numpy(crops).unsqueeze(1).float()  # add channel dim, float32


def make_msk_crops(m_files, crops_per_hologram, crop_size, step_size, W, edge_crop_dist, ds_factor, gkern_size, dist_from_crop_corner):
    """Build cropped 3-channel ground-truth masks from particle annotation files.

    Each file ``m_files[k]`` is read with ``np.genfromtxt``:
    - Columns 0 and 1 are the particle x and y positions.
    - Column 2 is multiplied by ``dr = 3`` (µm per pixel) to give the size channel.
    - Every particle in file ``k`` gets the same z, ``Z[k]``
      (50, 99, 167, 192, 75 mm for k = 0..4).

    Particles outside the open domain ``0 < x, y < 5120`` are dropped. The rest
    are rasterised with ``mask_maker`` at ``1/ds_factor`` resolution. Each mask
    is then edge-trimmed (same slice as ``make_img_crops``), flipped vertically
    with ``np.flipud`` and tiled with ``patchify``.

    Args:
        m_files: annotation file paths, ordered to match ``Z``.
        crops_per_hologram: output rows reserved per file.
        crop_size, step_size: crop side length and stride.
        W, edge_crop_dist: define the trimmed region ``[edge_crop_dist, W - edge_crop_dist)``.
        ds_factor: downsampling factor applied to the coordinates.
        gkern_size: kernel size passed to ``mask_maker``.
        dist_from_crop_corner: pixels trimmed from each side of every crop.

    Returns:
        float32 tensor ``(len(m_files) * crops_per_hologram, 3, c, c)`` with
        ``c = crop_size - 2 * dist_from_crop_corner``. Channels are 0 Gaussian
        location map, 1 z, 2 size.

    Raises:
        IndexError: if more files are given than there are known z planes.
    """
    num_files = len(m_files)
    target = torch.zeros((num_files*crops_per_hologram, 3, crop_size, crop_size))

    dr = 3  # um, pixel size
    Z = [50.0, 99.0, 167.0, 192.0, 75.0]  # z (mm) of each annotation file, by position in m_files
    if num_files > len(Z):
        raise IndexError(f"only {len(Z)} z planes are known, but {num_files} mask files were given")
    # domain of the ground truth
    xmin, ymin, xmax, ymax = 0, 0, 5120, 5120

    # Adjust this for matching. At the downsampled level, kernel_size = 7 would mean a
    # 3*4*3 x 3*4*3 µm² area around each ground-truth centre is checked for matching.
    kernel_size = gkern_size

    # edge trim, identical to the one make_img_crops applies to the images
    edge = slice(edge_crop_dist, W-edge_crop_dist)
    xy_slices = (slice(3), edge, edge)

    i = 0
    for file_num, m_file in enumerate(m_files):
        # atleast_2d keeps single-particle files (1-D genfromtxt output) working
        data = np.atleast_2d(np.genfromtxt(m_file))
        x = data[:, 0]
        y = data[:, 1]
        d = data[:, 2]*dr
        z = np.ones(x.shape[0])*Z[file_num]

        # drop particles outside the domain; the mask is computed once for all columns
        in_domain = (x>xmin)*(y>ymin)*(x<xmax)*(y<ymax)
        xx, yy, zz, dd = x[in_domain], y[in_domain], z[in_domain], d[in_domain]

        # sparse coordinates -> dense masks, then trim the same edges as the images
        tgt = mask_maker(xx/ds_factor, yy/ds_factor, zz, dd, xmax//ds_factor, ymax//ds_factor, kernel_size = kernel_size)[xy_slices]

        # flip each channel vertically and tile it into crops
        channel_crops = []
        for channel in tgt:
            crops = patchify(np.flipud(channel), crop_size, step_size)
            channel_crops.append(crops.reshape((crops.shape[0]*crops.shape[1], crop_size, crop_size)))
        crops_3ch = torch.from_numpy(np.stack(channel_crops, axis=1))

        target[i:i+crops_3ch.shape[0]] = crops_3ch
        i += crops_3ch.shape[0]

    # trim dist_from_crop_corner pixels from each side of every crop
    crop_window = slice(dist_from_crop_corner, crop_size-dist_from_crop_corner)
    return target[:, :, crop_window, crop_window]

def error_analysis(tgt, store_xy2, store_z2, store_r2, best_cutoff, min_distance, kernel_size, file_indices, min_r, max_r, ez_allowed):
    """Collect per-particle z and size values of matched detections at one threshold.

    For each crop ``i`` in ``range(len(store_xy2[file_indices]))``:
    - Predicted peaks of ``store_xy2[i]`` (threshold ``best_cutoff``) lying where
      ``tgt[i, 0] > gkern(kernel_size).min()`` count as matches.
    - At those peaks, predicted z/size (``store_z2`` / ``store_r2``) and
      ground-truth z/size (``tgt[i, 1]`` / ``tgt[i, 2]``) are read.
    - Matches are dropped if the ground-truth size is outside ``[min_r, max_r]``
      or the absolute z error exceeds ``ez_allowed``.

    Args:
        tgt: ground truth indexed ``tgt[i, c]`` with c = 0 location map, 1 z, 2 size.
        store_xy2, store_z2, store_r2: per-crop predicted location, z and size maps.
        best_cutoff: detection threshold on the location map.
        min_distance: NOTE(review): reassigned but not used; the peak search
            below uses ``kernel_size//2+1`` instead.
        kernel_size: kernel size passed to ``gkern`` to define the hit box.
        file_indices: slice; only its length is used. Crops are indexed from 0,
            so this assumes the slice starts at 0 — confirm.
        min_r, max_r: inclusive ground-truth size range that is kept.
        ez_allowed: maximum absolute z error of a kept match.

    Returns:
        Numpy arrays
        ``(ez_det, ed_det, z_det, z_pred, d_det, d_pred, d_total, z_total)``.
        The first six are, per kept match: absolute z error, absolute size
        error, ground-truth z, predicted z, ground-truth size and predicted size.
        ``d_total`` / ``z_total`` hold size and z of all ground-truth peaks
        (location >= 0.6) within the size range.
    """
    # threshold chosen by the caller (best-F1 cutoff)
    threshold = best_cutoff
    min_distance = min_distance

    # per-crop tensors of predicted / ground-truth z and size values at matched peaks
    z_predictions = []
    r_predictions = []
    z_detected = []
    r_detected = []
    
    # extract matched peaks at the threshold and read z / size there
    for i in range(len(store_xy2[file_indices])):
        # hit box: pixels where the ground-truth Gaussian exceeds the kernel minimum.
        # NOTE(review): strict `>` here, `>=` in _eval_prec_rec_f1_ — confirm which is intended.
        true_mask = ((tgt[i,0] > gkern(kernel_size).min()).float()).unsqueeze(0)
        peak_finding_mask = peak_local_max((store_xy2[i]*true_mask), threshold_abs=threshold, min_distance=kernel_size//2+1)
        
        z_values_predicted = store_z2[i][peak_finding_mask]
        z_predictions.append(z_values_predicted)
        z_values_detected = (tgt[i,1]).unsqueeze(0)[peak_finding_mask]
        z_detected.append(z_values_detected)

        r_values_predicted = store_r2[i][peak_finding_mask]
        r_predictions.append(r_values_predicted)
        r_values_detected = (tgt[i,2]).unsqueeze(0)[peak_finding_mask]
        r_detected.append(r_values_detected)



    # ground-truth z and size at every ground-truth peak (location channel >= 0.6)
    z_true = []
    r_true = []
    for i in range(len(store_xy2[file_indices])):
        peak_finding_mask = peak_local_max(tgt[i,0].unsqueeze(0), threshold_abs=0.6, min_distance=1).squeeze(0)
        z_values_true = tgt[i,1][peak_finding_mask]
        z_true.append(z_values_true)
        r_values_true = tgt[i,2][peak_finding_mask]
        r_true.append(r_values_true)

    # flatten across crops, keeping ground-truth particles with min_r <= size <= max_r
    d_total  = []
    z_total = []
    for r,z in zip(r_true, z_true):
        for rr, zz in zip(r, z):
            if rr > max_r or rr < min_r:
                continue
            d_total.append(rr)
            z_total.append(zz)
    z_total = np.array(z_total)
    d_total = np.array(d_total)

    # flatten the per-crop matches into per-particle lists with absolute z and size errors
    Z_true = []
    Z_pred = []
    R_true = []
    R_pred = []
    ez_det = []
    er_det = [] 
    

    for zpred, rpred, ztrue, rtrue  in zip(z_predictions, r_predictions, z_detected, r_detected):

        for zt, zp, rt, rp, ez, er in zip(ztrue, zpred ,rtrue, rpred, np.abs(zpred-ztrue), np.abs(rpred-rtrue)):
        
        # skip matches whose ground-truth size is outside [min_r, max_r] or whose z error exceeds ez_allowed
            if rt > max_r or rt < min_r or ez > ez_allowed:
                
                continue

            Z_true.append(zt.cpu())
            Z_pred.append(zp.cpu())
            R_true.append(rt.cpu())
            R_pred.append(rp.cpu())
            ez_det.append(ez.cpu())
            er_det.append(er.cpu())
            
    # convert the lists of 0-d tensors to numpy arrays
    z_det = torch.asarray(Z_true).numpy()
    d_det = torch.asarray(R_true).numpy()
    d_pred = torch.asarray(R_pred).numpy()
    z_pred = torch.asarray(Z_pred).numpy()
    ez_det = torch.asarray(ez_det).numpy()
    ed_det = torch.asarray(er_det).numpy()
    
    return ez_det, ed_det, z_det, z_pred, d_det, d_pred, d_total, z_total
    

def _eval_prec_rec_f1_(tgt, store_xy2, store_z2, store_r2, gkern_size, num_files, num_crops_per_holo, 
                       cutoff_range, min_distance, max_r, min_r, max_z, min_z, ez_allowed, error_calc):
    """Sweep detection thresholds and score particle detection against the ground truth.

    For each threshold ``cutoff / cutoff_range`` (cutoff = 1..cutoff_range), the
    predicted location map of every evaluated crop is peak-searched. Predicted
    peaks inside a ground-truth hit box (``tgt[i, 0] >= gkern(gkern_size).min()``)
    are hits, with two adjustments:
    - Hits whose ground-truth size lies outside ``[min_r, max_r]`` are discarded.
    - Hits whose absolute z error exceeds ``ez_allowed`` are counted as false
      positives.
    Counts are summed over all crops, giving one precision/recall/F1 per threshold,
    and the threshold with the best F1 is tracked.

    Args:
        tgt: ground truth indexed ``tgt[i, c]`` with c = 0 location map, 1 z, 2 size.
        store_xy2, store_z2, store_r2: per-crop predicted location, z and size maps.
        gkern_size: kernel size defining the hit box.
        num_files, num_crops_per_holo: holograms and crops per hologram to evaluate.
        cutoff_range: number of thresholds in the sweep.
        min_distance: ``min_distance`` for the predicted-peak search.
        max_r, min_r: size range of ground-truth particles that are scored.
        max_z, min_z: unused.
        ez_allowed: maximum absolute z error for a hit.
        error_calc: if true, also run ``error_analysis`` at the best threshold.

    Returns:
        If ``error_calc``: a dict with the precision/recall/F1 curves, the best F1
        with its precision and recall, and the z / size arrays from
        ``error_analysis``. Otherwise ``(P, R, f1_score, P_star, R_star, best_f1)``.
    """
    kernel_size = gkern_size

    # per-threshold totals over all evaluated crops
    false_postives = {}
    false_negatives = {}
    interesections = {}
    # per-threshold averages of per-crop metrics (kept for inspection; not returned)
    pr = {}  # precision
    rc = {}  # recall
    f1s = {}  # f1 score

    # predictions are only matched inside the ground-truth hit box
    hit_box_size_param = gkern(l = kernel_size, sig = 1).min()

    # ground-truth particle centres: peaks of the Gaussian location channel
    gt_particles = peak_local_max(tgt[:,0].unsqueeze(1), 0.8, 2)

    # Evaluate all five holograms. For a single one use
    # num_samples = 1*num_crops_per_holo and file_index in 0..4
    # (z = 50mm, 99mm, 167mm, 192mm, 75mm).
    num_samples = num_files*num_crops_per_holo
    file_index = 0

    # DO NOT CHANGE
    # NOTE(review): 9 looks like a hard-coded crops-per-hologram count — confirm.
    start_file_index = 0 + file_index*9
    file_indices = slice(start_file_index,start_file_index+num_samples)
    samples = store_xy2[file_indices]

    # Threshold-independent per-crop quantities, hoisted out of the cutoff loop
    hit_boxes = []    # float hit-box mask per crop
    gt_in_range = []  # ground-truth particle count within [min_r, max_r] per crop
    for j in range(len(samples)):
        i = start_file_index + j
        hit_boxes.append(((tgt[i,0]>=hit_box_size_param).float()).unsqueeze(0))
        gt_sizes = tgt[i,2][gt_particles[i,0]]
        gt_sizes_out_of_domain = ((gt_sizes > max_r) + (gt_sizes < min_r)).sum()
        gt_in_range.append(gt_particles[i,0].sum()-gt_sizes_out_of_domain)

    for cutoff in trange(1,cutoff_range+1):

        threshold = cutoff/cutoff_range  # maps the cutoff into (0, 1]

        if threshold not in false_negatives:
            false_postives[threshold] = 0
            false_negatives[threshold] = 0
            interesections[threshold] = 0
            rc[threshold] = 0
            pr[threshold] = 0
            f1s[threshold] = 0

        # sums of per-crop precision, recall and f1
        tmp_pr = 0
        tmp_rc = 0
        tmp_f1 = 0

        for j, sample in enumerate(samples):
            i = start_file_index + j

            # predicted particle locations
            predicted_particles = peak_local_max(sample, threshold_abs=threshold, min_distance=min_distance)
            # predicted particles that fall inside a ground-truth hit box
            hits = peak_local_max((sample*predicted_particles)*hit_boxes[j], threshold_abs=threshold, min_distance=kernel_size//2)
            # hits on ground-truth particles outside [min_r, max_r]
            hits_sizes_out_of_domain = tgt[i,2][hits.squeeze(0)]
            hits_sizes_out_of_domain = ((hits_sizes_out_of_domain > max_r) + (hits_sizes_out_of_domain < min_r)).sum()
            # ground-truth and predicted z at the hits
            gt_z_detected = tgt[i,1][hits.squeeze(0)]
            pred_z_detected = store_z2[i][hits]
            # hits that match in xy but whose z error is too large
            num_ez_outlier = (torch.abs(pred_z_detected-gt_z_detected) > ez_allowed).sum()

            # binary masks -> counts; out-of-range and z-outlier hits are not hits
            hits = hits.sum()-hits_sizes_out_of_domain-num_ez_outlier
            # z outliers stay among the predictions and so count as false positives
            fp = (predicted_particles.sum()-hits_sizes_out_of_domain) - (hits)
            # z outliers are added back here because they are already false positives
            fn = gt_in_range[j] - (hits+num_ez_outlier)

            # Overcounting: more matches than ground-truth particles (fn < 0).
            # Move the excess from hits to false positives. The excess must be
            # computed before fn is reset; previously fn was zeroed first, which
            # made this correction a no-op.
            if fn < 0:
                excess = -fn
                hits = hits - excess
                fp = fp + excess
                fn = 0

            # totals for precision / recall / f1 over all crops
            false_postives[threshold] += fp
            false_negatives[threshold] += fn
            interesections[threshold] += hits

            # per-crop metrics; f1 must use this crop's P/R, not the running sums
            crop_precision = precision(hits, fp)
            crop_recall = recall(hits, fn)
            tmp_pr += crop_precision
            tmp_rc += crop_recall
            tmp_f1 += f1(crop_precision, crop_recall)

        # per-crop averages for this cutoff
        pr[threshold] = tmp_pr/num_samples
        rc[threshold] = tmp_rc/num_samples
        f1s[threshold] = tmp_f1/num_samples


    # precision, recall and f1 over all evaluated crops together
    P = []
    R = []
    f1_score = []

    # best point in terms of f1 score
    best_cutoff = 0  # cutoff at best f1
    best_cutoff_index = 0
    best_f1 = 0
    R_star = 0  # recall at best f1
    P_star = 0  # precision at best f1

    for index, cutoff in enumerate(interesections):
        # At very high cutoffs (> 0.99) stray false positives are removed so the
        # precision-recall curve does not look weird.
        # NOTE(review): the condition tests false_negatives although the intent
        # is about false positives — confirm.
        if cutoff > 0.99:
            if false_negatives[cutoff] > 0:
                false_postives[cutoff] = 0

        prec = precision(interesections[cutoff], false_postives[cutoff])
        rec = recall(interesections[cutoff], false_negatives[cutoff])
        if prec == 0 and rec == 0:
            print(f"{prec} precision and {rec} recall at {cutoff}. Either manually remove such cases, or find the bug!")    
        F1 = f1(prec, rec)

        # track the best f1
        if F1 > best_f1:
            best_f1 = F1
            best_cutoff = cutoff
            best_cutoff_index = index
            R_star = rec
            P_star = prec

        P.append(float(prec))
        R.append(float(rec))
        f1_score.append(float(F1))

    if error_calc:
        _, _, z_det, z_pred, d_det, d_pred, d_total, z_total = error_analysis(tgt, store_xy2, store_z2, store_r2, best_cutoff, min_distance, kernel_size, file_indices, min_r, max_r, ez_allowed)
        print("Putting everyhting into a dictionary...")
        prediction_dict = {
            "precision": P,
            "recall": R,
            "F1": f1_score,
            "best_F1":  float(best_f1),
            "precision_at_best_f1": float(P_star),
            "recall_at_best_f1": float(R_star),
            "z_detected": z_det,
            "z_predicted": z_pred,
            "d_detected": d_det,
            "d_predicted": d_pred,
            "z_all_in_gt": z_total,
            "d_all_in_gt":d_total
        }
        return prediction_dict

    return P, R, f1_score, float(P_star), float(R_star), float(best_f1)


def _get_dfcs(config):
    """Build the models enabled in ``config["test"]`` and load their checkpoints.

    Args:
        config: nested config with a ``"test"`` section (``device``, ``dfc``,
            ``unet`` flags) plus model sections ``"dfc"`` and ``"unet"``.

    Returns:
        ``(dfc, unet)``: each is the model moved to ``config["test"]["device"]``
        with weights loaded, or ``None`` if its ``config["test"]`` flag is falsy.
    """
    device = config["test"]["device"]

    dfc, unet = None, None

    if config["test"]["dfc"]:
        fourier = config["dfc"]["fourier_part"]
        dilated = config["dfc"]["dilated_cnn_part"]
        n_modes = fourier["n_modes"]
        dfc = f_c_network(
            in_channels=fourier["in_channels"],
            width=fourier["hidden_channels"],
            n_modes=(n_modes, n_modes),
            fourier_interpolate=fourier["fourier_interpolation"],
            bias=fourier["bias"],
            spectral_dilation_fac=fourier["dilate_fourier_kernel_fac"],
            decomposition=fourier["factorization"],
            rank=fourier["rank"],
            implementation=fourier["implementation"],
            separable_fourier_layers=fourier["separable_fourier_layers"],
            mem_checkpoint=fourier["mem_checkpoint"],
            skip=fourier["skip"],
            batch_norm=fourier["batch_norm"],
            fno_block_precision=fourier["fourier_block_precision"],
            lifting_channels=fourier["lifting_channels"],
            projection_channels=fourier["projection_channels"],
            kernel_size=dilated["kernel_size"],
            padding=dilated["padding"],
            dilations=dilated["dilations"],
            num_layers=fourier["n_layers"],
        ).to(device)

        # map_location: tensors saved on another GPU index (e.g. cuda:1) would otherwise be
        # restored there, which fails when fewer GPUs are visible. torch.load unpickles the
        # file, so only load trusted checkpoints.
        checkpoint = torch.load(config["dfc"]["data"]["LOAD_CHECKPOINT_DIR"], map_location=device)
        load_checkpoint(checkpoint, dfc, None)

    if config["test"]["unet"]:
        build = config["unet"]["build"]
        unet = CNO(in_dim = build["in_channels"], in_size = build["in_size"], N_layers = build["n_layers"],
                   out_dim = build["out_channels"], activation = build["activations"], N_res = build["n_res"],
                   latent_lift_proj_dim = build["latent_lift_proj_dim"]).to(device = device)

        checkpoint = torch.load(config["unet"]["data"]["LOAD_CHECKPOINT_DIR"], map_location=device)
        load_checkpoint(checkpoint, unet, None)

    return dfc, unet
    
def _infer(inp, dfc, unet, config, device, dist_from_corner, crop_size,):
    """Run the trained model(s) over every crop and split the outputs by channel.

    UNet only:
        Each crop ``inp[i]`` is normalised with ``config["unet"]["evaluation"]``
        mean/std and passed through the UNet.
    DFC + UNet:
        The crop is normalised with ``config["dfc"]["evaluation"]`` mean/std,
        concatenated with a coordinate grid and passed through the DFC. The
        UNet then receives the level-adjusted raw crop together with the
        normalised DFC output.

    In both cases the sigmoid of the UNet output is trimmed by
    ``dist_from_corner`` pixels on every side of a ``crop_size`` window.

    Args:
        inp: crops indexed along dim 0.
        dfc, unet: models (``dfc`` may be ``None``).
        config: config holding the normalisation statistics.
        device: device the models run on.
        dist_from_corner, crop_size: define the output trim.

    Returns:
        ``(store_wholo, store_xy2, store_z2, store_r2)`` when ``dfc`` is given,
        otherwise ``(store_xy2, store_z2, store_r2)``. ``store_wholo`` holds the
        DFC outputs. The other three hold, per crop, UNet channel 0, channel 1
        times 200 and channel 2 times 100.

    Raises:
        ValueError: if ``dfc`` is given without ``unet`` (this used to fail with a
            NameError on ``store_wholo``).
    """
    if dfc is not None and unet is None:
        raise ValueError("_infer requires a unet when a dfc is given")

    predictions = []
    store_wholo = []  # DFC outputs; stays empty when dfc is None

    # trim dist_from_corner pixels from every side of the crop_size window
    crop_window = slice(dist_from_corner, crop_size-dist_from_corner)
    slice_holo = (slice(None), slice(None), crop_window, crop_window)
    # Empirical rescaling constants for the UNet's raw-image input in the DFC+UNet path.
    # NOTE(review): origin not shown here (names suggest a 128->384 FOV std ratio
    # and a hologram level factor) — confirm.
    std_scale_factor_128to384 = 0.6904758593396808
    std_diffccde_level_factor = 9200/12000

    if unet is not None and dfc is None:
        unet.eval()
        unet_mean = config["unet"]["evaluation"]["mean"]
        unet_std = config["unet"]["evaluation"]["std"]
        with torch.no_grad():
            for i in trange(inp.shape[0]):
                img = ((inp[i].unsqueeze(0)-unet_mean[0])/unet_std[0]).to(device = device)
                prediction = torch.sigmoid(unet(img))[slice_holo]
                predictions.append(prediction.cpu())

    if unet is not None and dfc is not None:
        unet.eval()
        dfc.eval()

        unet_mean = config["unet"]["evaluation"]["mean"]
        unet_std = config["unet"]["evaluation"]["std"]
        dfc_mean = config["dfc"]["evaluation"]["mean"]
        dfc_std = config["dfc"]["evaluation"]["std"]

        with torch.no_grad():
            for i in trange(inp.shape[0]):
                img = ((inp[i].unsqueeze(0)-dfc_mean[0])/dfc_std[0]).to(device = device)
                grid = get_grid(img.shape, device = device)
                img = torch.concat((img, grid), dim = 1)

                # stored for every crop so store_wholo stays aligned with the ground truth
                prediction_holo = dfc(img)
                store_wholo.append(prediction_holo[0].cpu())

                # rescale the raw crop around 128 and clip it to [0, 255] before UNet normalisation
                img = torch.clip((inp[i].unsqueeze(0).to(device = device) - 128)*std_diffccde_level_factor+128, 0, 255)
                prediction_holo = torch.concat(((img - unet_mean[0])/(unet_std[0]/std_scale_factor_128to384), (prediction_holo - unet_mean[1])/unet_std[1]), dim = 1)

                prediction = torch.sigmoid(unet(prediction_holo))[slice_holo]
                predictions.append(prediction.cpu())

    print("Separating xy, z and size channels in different lists for further analysis...")
    # channel 0 = location map; channels 1 and 2 are rescaled by 200 and 100
    # NOTE(review): presumably undoing the target normalisation of z and size — confirm.
    store_xy2 = [pred[:,0] for pred in predictions]
    store_z2 = [pred[:,1]*200 for pred in predictions]
    store_r2 = [pred[:,2]*100 for pred in predictions]
    if dfc is not None:
        return store_wholo, store_xy2, store_z2, store_r2
    return store_xy2, store_z2, store_r2


def _get_throughput(dfc, unet, optimal_batch_size, inp_shape, device, cfg=None, warmup=10):
    """Measure inference throughput (samples per second) on random input.

    Runs ``warmup`` untimed passes, then 100 timed forward passes with batch size
    ``optimal_batch_size``. All passes run under ``torch.cuda.amp.autocast`` and
    ``torch.no_grad``, and each timed pass is measured with CUDA events. With a
    ``dfc``, the timed pass is DFC followed by the UNet, mirroring ``_infer``.

    Args:
        dfc: optional DFC network run before the UNet.
        unet: UNet to time (required).
        optimal_batch_size: batch size of the random input.
        inp_shape: spatial shape of one input sample.
        device: device of the random input.
        cfg: config holding normalisation statistics for the DFC+UNet path;
            defaults to the notebook-level ``config``.
        warmup: untimed passes before timing (use 0 for the old behaviour).

    Returns:
        ``(throughput, [mean_time_per_batch, std_time_per_batch])``: samples/s,
        and per-batch latency in seconds.

    Raises:
        ValueError: if ``unet`` is None (previously a ZeroDivisionError).
    """
    if unet is None:
        raise ValueError("_get_throughput needs a unet (optionally preceded by a dfc)")

    unet.eval()
    dummy_input = torch.randn(optimal_batch_size, 1, *inp_shape, dtype=torch.float).to(device)

    if dfc is None:
        def forward():
            return torch.sigmoid(unet(dummy_input))
    else:
        dfc.eval()
        if cfg is None:
            cfg = config  # notebook-level config
        unet_mean = cfg["unet"]["evaluation"]["mean"]
        unet_std = cfg["unet"]["evaluation"]["std"]
        dfc_mean = cfg["dfc"]["evaluation"]["mean"]
        dfc_std = cfg["dfc"]["evaluation"]["std"]

        def forward():
            prediction_holo = dfc(torch.concat((dummy_input, get_grid(dummy_input.shape, device = device)), dim = 1))
            # undo the DFC normalisation of the input and re-normalise it for the UNet
            prediction_holo = torch.concat(((dummy_input*dfc_std[0]+dfc_mean[0] - unet_mean[0])/unet_std[0],
                                            (prediction_holo - unet_mean[1])/unet_std[1]), dim = 1)
            return torch.sigmoid(unet(prediction_holo))

    repetitions = 100
    timings = np.zeros(repetitions)
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            # warm up under the same settings as the timed loop
            for _ in range(warmup):
                forward()
            torch.cuda.synchronize()
            for i in trange(repetitions):
                starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
                starter.record()
                forward()
                ender.record()
                torch.cuda.synchronize()
                timings[i] = starter.elapsed_time(ender)/1000  # ms -> s

    total_time = np.sum(timings)
    Throughput = (repetitions*optimal_batch_size)/total_time
    mean_time_per_batch = np.mean(timings)
    std_timer_per_batch = np.std(timings)
    return Throughput, [mean_time_per_batch, std_timer_per_batch]

def _get_inference_times(dfc, unet, inp_shape, device, cfg=None, warmup=10, repetitions=300):
    """Time single-sample (batch size 1) inference on random input.

    Runs ``warmup`` untimed passes, then ``repetitions`` timed passes. All
    passes run under ``torch.cuda.amp.autocast`` and ``torch.no_grad``, and each
    timed pass is measured with CUDA events. With a ``dfc``, each pass is DFC
    followed by the UNet.

    Previously the UNet-only warm-up referenced an undefined ``prediction_holo``,
    and warm-ups ran outside ``no_grad``/autocast.

    Args:
        dfc: optional DFC network run before the UNet.
        unet: UNet to time (required).
        inp_shape: spatial shape of the input.
        device: device of the random input.
        cfg: config holding normalisation statistics for the DFC+UNet path;
            defaults to the notebook-level ``config``.
        warmup: untimed warm-up passes.
        repetitions: timed passes.

    Returns:
        String ``"<mean> $\\pm$ <std>"`` of per-pass latency in seconds, 4 decimals.

    Raises:
        ValueError: if ``unet`` is None.
    """
    if unet is None:
        raise ValueError("_get_inference_times needs a unet (optionally preceded by a dfc)")

    unet.eval()
    dummy_input = torch.randn(1, 1, *inp_shape, dtype=torch.float).to(device)

    if dfc is None:
        def forward():
            return torch.sigmoid(unet(dummy_input))
    else:
        dfc.eval()
        if cfg is None:
            cfg = config  # notebook-level config
        unet_mean = cfg["unet"]["evaluation"]["mean"]
        unet_std = cfg["unet"]["evaluation"]["std"]
        dfc_mean = cfg["dfc"]["evaluation"]["mean"]
        dfc_std = cfg["dfc"]["evaluation"]["std"]

        def forward():
            prediction_holo = dfc(torch.concat((dummy_input, get_grid(dummy_input.shape, device = device)), dim = 1))
            # undo the DFC normalisation of the input and re-normalise it for the UNet
            prediction_holo = torch.concat(((dummy_input*dfc_std[0]+dfc_mean[0] - unet_mean[0])/unet_std[0],
                                            (prediction_holo - unet_mean[1])/unet_std[1]), dim = 1)
            return torch.sigmoid(unet(prediction_holo))

    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    timings = np.zeros(repetitions)
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            # warm up under the same autocast / no_grad settings as the timed loop
            for _ in range(warmup):
                forward()
            torch.cuda.synchronize()
            for i in trange(repetitions):
                starter.record()
                forward()
                ender.record()
                torch.cuda.synchronize()
                timings[i] = starter.elapsed_time(ender)/1000  # ms -> s

    # "\\pm" yields the same text as the old "\pm" without an invalid-escape warning
    return f"{np.mean(timings):.4f} $\\pm$ {np.std(timings):.4f}"

def _get_memory_stats(dfc, unet, inp_shape, device):
    """Measure parameter memory and peak CUDA memory of one inference pass.

    Supported configurations:
      * ``unet`` only: a random (1, 1, *inp_shape) input is fed straight to ``unet``.
      * ``dfc`` + ``unet``: ``dfc`` runs on the input concatenated with the grid
        channels from ``get_grid``. Its output and the input are re-normalised
        with the mean/std from ``config`` and concatenated as the ``unet`` input.

    Args:
        dfc: first-stage model, or None.
        unet: detection model, or None.
        inp_shape: spatial shape of one input (appended to a (1, 1) batch/channel prefix).
        device: device the dummy input is moved to.

    Returns:
        unet only:  ``[(unet_param_term, unet_peak_gb)]``
        dfc + unet: ``[(unet_param_term, unet_peak_gb - unet_param_term),
                      (dfc_param_term,  dfc_peak_gb  - dfc_param_term)]``
        Any other combination matches no branch and returns None.

    NOTE(review): ``count_params`` already divides bytes by 1024**3, so the extra
    ``*4/(1024**3)`` in the param term looks like a unit slip. It is kept
    unchanged for comparability with previously reported numbers, but confirm.
    NOTE(review): the unet-only branch reports the raw peak while the two-model
    branch subtracts the param term. Confirm this asymmetry is intended.
    """
    if unet is not None and dfc is None:
        unet.eval()
        # Inference needs no autograd graph. Without no_grad the retained graph
        # inflates the peak (the timing code also runs under no_grad).
        with torch.no_grad(), torch.cuda.amp.autocast():
            dummy_input = torch.randn(1, 1, *inp_shape, dtype=torch.float).to(device)
            torch.cuda.reset_peak_memory_stats()
            # The unet consumes the raw input here. The name `prediction_holo` only
            # exists in the dfc branch, so using it here raised NameError.
            torch.sigmoid(unet(dummy_input))
            unet_act_mem = torch.cuda.max_memory_allocated() / 1024**3
        unet_param_term = count_params(unet)*4/(1024**3)
        return [(unet_param_term, unet_act_mem)]

    if unet is not None and dfc is not None:
        unet.eval()
        dfc.eval()

        unet_mean = config["unet"]["evaluation"]["mean"]
        unet_std = config["unet"]["evaluation"]["std"]
        dfc_mean = config["dfc"]["evaluation"]["mean"]
        dfc_std = config["dfc"]["evaluation"]["std"]
        with torch.no_grad(), torch.cuda.amp.autocast():
            dummy_input = torch.randn(1, 1, *inp_shape, dtype=torch.float).to(device)
            dummy_input_with_xychannels = torch.concat((dummy_input, get_grid(dummy_input.shape, device = device)), dim = 1)

            # Peak while only the dfc stage runs.
            torch.cuda.reset_peak_memory_stats()
            prediction_holo = dfc(dummy_input_with_xychannels)
            dfc_act_mem = torch.cuda.max_memory_allocated() / 1024**3

            # Undo the dfc normalisation of the input, re-normalise both channels for the unet.
            prediction_holo = torch.concat(((dummy_input*dfc_std[0]+dfc_mean[0] - unet_mean[0])/unet_std[0],
                                            (prediction_holo - unet_mean[1])/unet_std[1]), dim = 1)

            # Peak while only the unet stage runs.
            torch.cuda.reset_peak_memory_stats()
            torch.sigmoid(unet(prediction_holo))
            unet_act_mem = torch.cuda.max_memory_allocated() / 1024**3

        unet_param_term = count_params(unet)*4/(1024**3)
        dfc_param_term = count_params(dfc)*4/(1024**3)
        return [(unet_param_term, unet_act_mem - unet_param_term),
                (dfc_param_term, dfc_act_mem - dfc_param_term)]


def get_forward_backward_memory_max_batch_size(model, shape):
    """Measure peak CUDA memory of one mixed-precision training step.

    The step is forward, MSE loss against input channel 0, backward, and an
    Adam step through ``GradScaler``.

    Args:
        model: module to profile. If it is an ``f_c_network`` (dfc), the grid
            channels from ``get_grid`` are concatenated to the input.
        shape: shape of the random input tensor; ``shape[0]`` is the batch size.

    Returns:
        ``(forward_peak, backward_peak, initial)``, each in units of 1024**3 bytes.

    Raises:
        Whatever the step raises (typically CUDA out-of-memory). Callers use this
        to search for the largest batch size that fits.

    Side effects: gradients on ``model`` are set to None on exit, and its
    weights/buffers are restored, so profiling does not alter the model.
    """
    # Build the input on the model's own device rather than the default CUDA device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")
    x = torch.randn(shape, device=device)
    if isinstance(model, f_c_network):
        x = torch.concat((x, get_grid(x.shape, device=x.device)), dim = 1)

    loss_fn = nn.MSELoss()
    scaler = torch.cuda.amp.GradScaler()
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

    # scaler.step(optimizer) really updates the weights. Keep a CPU snapshot
    # (no extra GPU memory) so the caller's model is returned unchanged.
    saved_state = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}
    try:
        torch.cuda.reset_peak_memory_stats()
        init_mem = torch.cuda.max_memory_allocated()
        with torch.cuda.amp.autocast(enabled=True):
            y = model(x)
            torch.cuda.synchronize()
            forw_mem = torch.cuda.max_memory_allocated()
            loss = loss_fn(x[:, 0, :, :].unsqueeze(1), y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)  # allocates Adam state, part of the measured training peak
        scaler.update()
        torch.cuda.synchronize()
        back_mem = torch.cuda.max_memory_allocated()
    finally:
        # Free gradients so they do not inflate the next (retry) measurement,
        # then undo the optimizer update.
        model.zero_grad(set_to_none=True)
        model.load_state_dict(saved_state)

    print(f"Batch size:{x.shape[0]}", ("Forward:%.2f"%((forw_mem) / 1024**3), "Backward:%.2f"%((back_mem) / 1024**3), "Init:%.2f"%(init_mem/1024**3)), "%.2f" % (count_params(model)* 32/8/1024**3))
    print("\n")

    return (forw_mem) / 1024**3, (back_mem) / 1024**3, (init_mem/1024**3)


def _get_fig_axes(nrows, ncols, width, fraction, sharex = 'col', sharey = 'row', layout = 'constrained'):
    """Create a subplot grid sized by ``set_size`` and apply thin-line axis styling.

    Args:
        nrows, ncols: grid shape passed to ``plt.subplots``.
        width, fraction: forwarded to ``set_size`` to compute ``figsize``.
        sharex, sharey, layout: forwarded to ``plt.subplots``.

    Returns:
        ``(fig, axes)`` exactly as returned by ``plt.subplots``.
    """
    fig, axes = plt.subplots(nrows, ncols,
                             figsize=set_size(width=width, fraction=fraction, subplots=(nrows, ncols)),
                             sharex=sharex, sharey=sharey, layout=layout)

    # For a 1x1 grid plt.subplots returns a bare Axes, which has no `.flat`.
    axes_to_style = axes.flat if isinstance(axes, np.ndarray) else [axes]
    for ax in axes_to_style:
        # Placeholder labels; plotting code (e.g. _plot_) overwrites them.
        ax.set_xlabel('X Label')
        ax.set_ylabel('Y Label')

        ax.xaxis.set_tick_params(width=0.5)
        ax.yaxis.set_tick_params(width=0.5)

        ax.spines['bottom'].set_linewidth(0.5)  # line width for the bottom axis
        ax.spines['left'].set_linewidth(0.5)

    return fig, axes


def _color_boxplot_artists(bplot, color):
    """Paint every artist in an ``Axes.boxplot`` result dict with ``color``; boxes stay unfilled."""
    for box in bplot['boxes']:
        box.set_edgecolor(color)
        box.set_facecolor("none")
    for flier in bplot['fliers']:
        flier.set(markeredgecolor=color, markersize=2)
    for key in ('whiskers', 'medians', 'caps'):
        for line in bplot[key]:
            line.set_color(color)


def _plot_(P, R, ez, ed, z_det, d_det, d_total, z_total):
    """Per-depth particle histograms and error box plots, one column per depth in ``Z``.

    Rows of the 3 x 5 figure:
      0: GT particle counts (outlined) vs detected counts (filled) per diameter bin,
      1: box plots of ``ed`` (diameter error, µm per the axis label) per bin,
      2: box plots of ``ez`` (depth error, mm per the axis label) per bin.

    Args:
        P, R: unused; kept for call compatibility.
        ez, ed: per-detection error arrays, aligned element-wise with ``z_det``/``d_det``.
        z_det, d_det: per-detection depth and diameter arrays. ``z_det`` is compared
            for equality against the values in ``Z``; ``d_det`` is binned with ``d_bins``.
        d_total, z_total: diameter and depth arrays of all GT particles.

    Returns:
        ``(errors_d_d, errors_z_d)``: dicts ``{z: {bin_center: [[array]]}}`` holding the
        ``ed`` / ``ez`` values that fall in each diameter bin.
    """
    color_fc = colors[0]
    Z = [50, 75, 99, 167, 192,]
    d_bins = [5.5, 6.5, 7.5, 8.5, 9.5, 10.5, 13.5, 16.5, 23.5, 56.5, 83.5]
    n_bins = len(d_bins) - 1
    center_points_for_d_bins = [int((d_bins[i]+d_bins[i+1])/2) for i in range(n_bins)]
    # Bars and boxes both sit at 1-based positions (boxplot's default).
    positions = list(range(1, n_bins + 1))

    hist_ticklabels = [f"{d_bins[i]-0.5}-{d_bins[i+1]-0.5}" for i in range(n_bins)]
    ed_ticklabels = [f"{d_bins[i]}-{d_bins[i+1]}" for i in range(n_bins)]
    ez_ticklabels = [f"{int(d_bins[i]-0.5)}-{int(d_bins[i+1]-0.5)}" for i in range(n_bins)]

    fig, axes = _get_fig_axes(3, 5, 'iccv', 1,)
    for ax in axes.flat:
        ax.set_xlabel(r'Diameter$_{GT}$ [µm]')
    sns.despine()

    # NOTE(review): whis=True is numerically 1.0 (whiskers at 1.0 x IQR, not
    # matplotlib's default 1.5) — confirm this is intended.
    boxplot_kwargs = dict(patch_artist=True, showfliers=False, showcaps=True, whis=True)

    errors_d_d = {}
    errors_z_d = {}

    for j, z in enumerate(Z):
        z_mask = z_det == z
        # digitize: index k (1-based) means d_bins[k-1] <= d < d_bins[k].
        bin_idx = np.digitize(d_det[z_mask], bins=d_bins)
        ed_z, ez_z = ed[z_mask], ez[z_mask]

        errors_d_d.setdefault(z, {})
        errors_z_d.setdefault(z, {})
        for i, size in enumerate(center_points_for_d_bins):
            in_bin = bin_idx == i + 1
            # Nested [[array]] layout preserved for consumers of the returned dicts.
            errors_d_d[z].setdefault(size, []).append([ed_z[in_bin]])
            errors_z_d[z].setdefault(size, []).append([ez_z[in_bin]])

        fc_ed_d = [np.array(errors_d_d[z][key][0][0]) for key in errors_d_d[z]]
        fc_ez_d = [np.array(errors_z_d[z][key][0][0]) for key in errors_z_d[z]]
        tmp_error_xticklabels = range(len(fc_ed_d))

        # Draw each panel of this column exactly once (previously redrawn per row,
        # which stacked the semi-transparent bars).
        ax1, ax2, ax3 = axes[:, j]

        # Row 0: particle counts per diameter bin.
        gt_hist, _ = np.histogram(d_total[z_total == z], d_bins)
        ax1.bar(positions, gt_hist, color='white', edgecolor='black', label='GT', linewidth=0.8)
        det_hist, _ = np.histogram(d_det[z_mask], d_bins)
        ax1.bar(positions, det_hist, alpha=0.3, label='NOA', color=color_fc)
        ax1.set_yscale("linear")
        ax1.set_xticks(positions)  # aligned with the bars (was off by one)
        ax1.set_xticklabels(hist_ticklabels, rotation=45)
        ax1.set_ylabel(r"$\#$particles")

        # Row 1: diameter error per bin.
        bplot_ed = ax2.boxplot(fc_ed_d, label=tmp_error_xticklabels, **boxplot_kwargs)
        _color_boxplot_artists(bplot_ed, color_fc)
        ax2.set_xticks(range(1, len(fc_ed_d) + 1))
        ax2.set_xticklabels(ed_ticklabels, rotation=45)
        ax2.set_ylabel(r"err$_{diam}$[µm]")
        ax2.yaxis.set_minor_locator(plticker.MultipleLocator(5))

        # Row 2: depth error per bin.
        bplot_ez = ax3.boxplot(fc_ez_d, label=tmp_error_xticklabels, **boxplot_kwargs)
        _color_boxplot_artists(bplot_ez, color_fc)
        ax3.set_xticks(range(1, len(fc_ez_d) + 1))
        ax3.set_xticklabels(ez_ticklabels, rotation=60)
        ax3.set_xlabel(r"diameter$_{GT}$[µm]")
        ax3.set_ylabel(r"err$_z$[mm]")
        ax3.yaxis.set_minor_locator(plticker.MultipleLocator(5))

    # Keep x decorations only on the bottom row and y decorations only on the first column.
    n_rows, n_cols = axes.shape
    for i in range(n_rows):
        for j in range(n_cols):
            ax = axes[i, j]
            if i != n_rows - 1:
                ax.set_xlabel("")
                ax.xaxis.set_visible(False)
                ax.spines['bottom'].set_visible(False)
            if j != 0:
                ax.set_ylabel("")
                ax.yaxis.set_visible(False)
                ax.spines['left'].set_visible(False)
            if i == 0:
                ax.set_yscale('linear')

    return errors_d_d, errors_z_d

In [None]:
# DO NOT CHANGE
img_files = glob.glob("/CloudTarget_holograms/") # these are master filtered holograms 


######### DO NOT CHANGE ###########
m_files = ["/CloudTarget_labels/"]
###################################
# NOTE(review): glob on a bare directory path returns [that path] (or [] if it
# does not exist), not the files inside it — confirm this is intended.

num_files = len(img_files)
# An assert message must be a value; `print(...)` evaluates to None.
assert num_files == len(m_files), "Image and Mask lists are not of equal length!"
print("Number of files:", num_files)


ds_factor = config["dfc"]["evaluation"]["ds_factor"]  # downsampling factor applied for inference

W = config["dfc"]["evaluation"]["HOLO_SIZE"]  # recorded holograms are square and of this dimension
H = config["dfc"]["evaluation"]["HOLO_SIZE"]
crop_size = config["dfc"]["evaluation"]["crop_size"]  # full-resolution size; downsampled below
step_size = config['dfc']["evaluation"]["step_size"]

# Apply the downsampling to every pixel quantity.
W = W//ds_factor
H = H//ds_factor
crop_size = crop_size//ds_factor
step_size = step_size//ds_factor
mask_crop_size = crop_size  # (xy, z, d) mask crop size; currently tied to crop_size
assert W-crop_size >= 0, "Asked crop size is bigger than the image itself"

# crop_size and the hologram size (H, W) may not divide evenly, so some edge is cut off.
dist_to_cut_from_edge = int(config["dfc"]["evaluation"]["dist_to_cut_from_edge"]//ds_factor)
total_edge_to_crop = dist_to_cut_from_edge
holo_size = W - total_edge_to_crop*2  # after downsampling and edge cropping
print("Hologram size after cutting given edge distance to crop:", holo_size*ds_factor)

anti_step_size = crop_size - step_size  # overlap between neighbouring crops; 0 if no overlap
# n crops per side cover crop_size*n - anti_step_size*(n-1) pixels of the W-wide hologram.
crops_per_side = (W - crop_size)//step_size + 1
num_crops_per_holo = crops_per_side**2
dist_to_cut_from_each_edge = int((W - (crop_size*crops_per_side - anti_step_size*(crops_per_side-1)))//2)
print(f"Distance being cut from each corner in the downsampled hologram ({W*ds_factor}/{ds_factor}): {dist_to_cut_from_each_edge*ds_factor}")

crops_per_hologram = num_crops_per_holo
print(f"Number of crops per {H}x{W} downsampled hologram: {crops_per_hologram}")

gkern_size = config["unet"]["evaluation"]["gkern_size"]  # size of the gaussian blob, controls the strictness of the hitbox

# Precision / recall evaluation settings
cutoff_range = config["unet"]["evaluation"]["cutoff_range"]  # detection thresholds to sweep (presumably in [0, 1])
min_distance = config["unet"]["evaluation"]["min_distance"]  # min distance between consecutive peaks in the prediction

error_calc = config["unet"]["evaluation"]["error_calc"]
max_r = config["unet"]["evaluation"]["max_r"]  # max particle size in the gt, µm
min_r = config["unet"]["evaluation"]["min_r"]  # min particle size in the gt
min_z = config["unet"]["evaluation"]["min_z"]  # mm
max_z = config["unet"]["evaluation"]["max_z"]
ez_allowed = config["unet"]["evaluation"]["ez_allowed"]  # predictions with a larger z error (presumably mm) count as false positives




In [None]:
if __name__ == "__main__":
    print("Getting i/o data...")
    inp = make_img_crops(img_files, crops_per_hologram, crop_size, step_size, W, total_edge_to_crop, ds_factor)
    tgt = make_msk_crops(m_files, crops_per_hologram, crop_size, step_size, W, total_edge_to_crop, ds_factor, gkern_size, anti_step_size//2)
    print(inp.shape)
    print(f"\nTest data input shape is {inp.shape} and target shape is {tgt.shape}")
    print("\nGetting models...")
    dfc, unet = _get_dfcs(config, )
    print("\n Making predictions...")
    test_device = config["test"]["device"]
    if config["test"]["dfc"]:
        # NOTE(review): dist_from_corner is hard-coded to 64 here while the other
        # branch uses dist_to_cut_from_each_edge — confirm which is intended.
        store_wholo, store_xy2, store_z2, store_r2 = _infer(inp, dfc, unet, config, device=test_device, dist_from_corner=64, crop_size=crop_size)
        print(inp.shape)
    else:
        store_xy2, store_z2, store_r2 = _infer(inp, dfc, unet, config, device=test_device, dist_from_corner=dist_to_cut_from_each_edge, crop_size=crop_size)

    print("\nAnalyzing results...")
    eval_args = (tgt, store_xy2, store_z2, store_r2, gkern_size, num_files, num_crops_per_holo,
                 cutoff_range, min_distance, max_r, min_r, max_z, min_z, ez_allowed, error_calc)
    if error_calc:
        # With error_calc the evaluator returns a dict that includes the best-F1 point.
        prediction_dict = _eval_prec_rec_f1_(*eval_args)
        P_star = prediction_dict["precision_at_best_f1"]
        R_star = prediction_dict["recall_at_best_f1"]
        best_f1 = prediction_dict["best_F1"]
    else:
        P, R, f1_score, P_star, R_star, best_f1 = _eval_prec_rec_f1_(*eval_args)
        # Build a dict so the summary and the pickle below work in both modes.
        # NOTE(review): "precision"/"recall"/"F1" keys are chosen here — align them
        # with the dict _eval_prec_rec_f1_ returns when error_calc is set.
        prediction_dict = {
            "precision": P,
            "recall": R,
            "F1": f1_score,
            "precision_at_best_f1": P_star,
            "recall_at_best_f1": R_star,
            "best_F1": best_f1,
        }

    # First candidate inference batch size (in config order) that runs successfully.
    optimal_batch_size = config["dfc"]["evaluation"]["optimal_batch_size"]
    throughput, timings, opt_batch_size = None, None, None
    for batch_size in optimal_batch_size:
        torch.cuda.empty_cache()
        try:
            throughput, timings = _get_throughput(dfc, unet, batch_size, inp.shape[2:], device=test_device)
        except Exception as e:  # usually CUDA OOM; print it so other failures stay visible
            print(f"Exceeding GPU memory for batch size {batch_size}", e)
            torch.cuda.empty_cache()
        else:
            opt_batch_size = batch_size
            break

    # Largest training batch size whose forward+backward step fits in memory.
    train_max_batch_size = [32,28,24,20,16,14,12,10,8,6,4,2,1]
    profiled_model = dfc if dfc is not None else unet
    fow_mem, back_mem = None, None
    for batch_size in train_max_batch_size:
        torch.cuda.empty_cache()
        try:
            fow_mem, back_mem, _ = get_forward_backward_memory_max_batch_size(profiled_model, (batch_size, *inp.shape[1:]))
            break
        except Exception as e:
            print(f"Exceeding GPU memory for batch size {batch_size}", e)
            torch.cuda.empty_cache()

    crop_sz = config["dfc"]["training"]["IMG_SIZE"]
    kern_sz = config["dfc"]["fourier_part"]["n_modes"]
    rank = config["dfc"]["fourier_part"]["rank"]

    mem_str = f"{fow_mem:.2f}" if fow_mem is not None else "n/a"
    if timings is not None:
        # Presumably timings are per batch of opt_batch_size crops; rescale to per hologram.
        holos_per_batch = opt_batch_size / crops_per_hologram
        time_str = f"{timings[0]/holos_per_batch:.4f}$\\pm${timings[1]/holos_per_batch:.4f}"
    else:
        time_str = "n/a"
    output_str = f"{crop_sz} & {rank}  & {P_star:.3f} & {R_star:.3f} & {best_f1:.3f} & {count_params(unet)*32/8/1024**3:.2f} & {mem_str} & {time_str}"
    print(output_str)

    print("saving...")
    with open(config["unet"]["data"]["SAVE_FOLDER"]+"flashµ_predictions.pkl", mode = 'wb+') as f:
        pickle.dump(prediction_dict, f)
    

