# Montecarlo Approximation of Electron-Matter-Interaction

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import sys
from fastai.vision.all import * 
import fastai
import torch
print("sys.version", sys.version)
print("cuda device name(0)", torch.cuda.get_device_name(0))
print("torch.__version__", torch.__version__)
print("fastai.__version__", fastai.__version__)

sys.version 3.11.0 (main, Oct 24 2022, 18:26:48) [MSC v.1933 64 bit (AMD64)]
cuda device name(0) NVIDIA GeForce RTX 2080
torch.__version__ 2.0.1+cu118
fastai.__version__ 2.7.12


In [5]:
import pathlib
from os import listdir
input_path = Path('/mnt/aetna-cluster-workspace/data/')

def get_items(input_path):
    file_names = listdir(input_path)
    print(input_path)
    print(file_names)
    file_names = [filename for filename in file_names if "_bse.tif" in str( filename ) ]
    return file_names
    
file_names = get_items(input_path)    
print(file_names[0:5])

\mnt\aetna-cluster-workspace\data
[]
[]


## Datensatz laden

In [4]:
def get_hf_0(filename):
    filename_parts = str(filename.stem).split("_")
    y = filename.parent / Path( "_".join(filename_parts[:-1]) + "_entry_hf_0.tif" )
    return str(y)

def get_hf_1(filename):
    filename_parts = str(filename.stem).split("_")
    y = filename.parent / Path( "_".join(filename_parts[:-1]) + "_exit_hf_0.tif" )
    return str(y)

def get_hf_2(filename):
    filename_parts = str(filename.stem).split("_")
    y = filename.parent / Path( "_".join(filename_parts[:-1]) + "_entry_hf_1.tif" )
    return str(y)

def get_hf_3(filename):
    filename_parts = str(filename.stem).split("_")
    y = filename.parent / Path( "_".join(filename_parts[:-1]) + "_exit_hf_1.tif" )
    return str(y)

def get_normal(filename):
    filename_parts = str(filename.stem).split("_")
    y = filename.parent / Path( "_".join(filename_parts[:-1]) + "_normal.tif" )
    return str(y)

def get_bse(filename):
    filename_parts = str(filename.stem).split("_")
    y = filename.parent / Path( "_".join(filename_parts[:-1]) + "_bse.tif" )
    return str(y)

def get_se(filename):
    filename_parts = str(filename.stem).split("_")
    y = filename.parent / Path( "_".join(filename_parts[:-1]) + "_se.tif" )
    return str(y)

for i in range(5):
    print(file_names[i], "/", get_normal(file_names[i]), "/", get_bse(file_names[i]))

IndexError: list index out of range

In [None]:
item_transforms  = [RandomCrop((512,512)),DihedralItem]

datablocks = DataBlock(blocks=(ImageBlock(cls=PILImageBW), ImageBlock(cls=PILImageBW), ImageBlock(cls=PILImageBW), ImageBlock(cls=PILImageBW), ImageBlock, ImageBlock(cls=PILImageBW), ImageBlock(cls=PILImageBW)),
                       n_inp=5,
                       get_items=get_items,
                       getters=[get_hf_0, get_hf_1, get_hf_2, get_hf_3, get_normal, get_bse, get_se],
                       splitter=RandomSplitter(valid_pct=0.2, seed=42),
                       item_tfms=item_transforms)

data_loader = datablocks.dataloaders(input_path, bs=1, num_workers=0 )

@typedispatch
def show_batch(x:tuple, y, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):    
    hf0,hf1,hf2,hf3,normal = x
    
    show_output = False
    bse,se = y
    batch_size = hf0.shape[0]
    
    nrows = min(x[0].shape[0], max_n)
    if show_output:
        ncols = 7
    else:
        ncols = 5
    if ctxs is None: ctxs = get_grid(nrows*ncols, nrows=nrows, ncols=ncols, figsize=figsize)
    
    ctxi = 0
    for i in range(batch_size):
        hf0_image = hf0[i,:,:,:].squeeze(0)
        show_image(hf0_image, title="heightmap (entry 1)", ctx=ctxs[ctxi], cmap="gray", **kwargs)
        
        ctxi = ctxi+1
        
        hf1_image = hf1[i,:,:,:].squeeze(0)
        show_image(hf1_image, title="heightmap (exit 1)", ctx=ctxs[ctxi], cmap="gray", **kwargs)
        ctxi = ctxi+1

        hf2_image = hf2[i,:,:,:].squeeze(0)
        show_image(hf2_image, title="heightmap (entry 2) ", ctx=ctxs[ctxi], cmap="gray", **kwargs)
        ctxi = ctxi+1
        
        hf3_image = hf3[i,:,:,:].squeeze(0)
        show_image(hf3_image, title="heightmap (exit 2) ", ctx=ctxs[ctxi], cmap="gray", **kwargs)
        ctxi = ctxi+1        
        
        normal_image = normal[i,:,:,:].squeeze(0)      
        show_image(normal_image, title="normalmap", ctx=ctxs[ctxi], **kwargs)
        ctxi = ctxi+1
        
        if show_output:
            bse_image = bse[i,:,:,:].squeeze(0)      
            show_image(bse_image, title="BSE", ctx=ctxs[ctxi], **kwargs)
            ctxi = ctxi+1

            se_image = se[i,:,:,:].squeeze(0)      
            show_image(se_image, title="SE", ctx=ctxs[ctxi], **kwargs)
            ctxi = ctxi+1
    plt.savefig("data_representation.png", dpi=300)  
    
data_loader.show_batch( )

# Netzwerk Architektur

In [None]:
def create_inner_model( dataloader, backbone, datalayout, **kwargs ):
    img_size = dataloader.one_batch()[0].shape[-2:]
    
    if datalayout == "exthf_normal" or datalayout == "normal_exthf":
        n_in=7
    if datalayout == "exthf_only" or datalayout == "hf_normal":
        n_in=4
    if datalayout == "hf_only":
        n_in=1
    
    model = create_unet_model(backbone, 2, img_size, n_in=n_in, **kwargs)
    return model

class FIBModel(torch.nn.Module):
    def __init__(self, inner_model, datalayout = "normal_exthf"):
        super().__init__()
        self.inner_model = inner_model
        self.datalayout = datalayout
        
    def forward(self, x_hf_0, x_hf_1, x_hf_2, x_hf_3, x_normal):

        print("normal_map", x_normal.shape, x_normal.dtype, torch.min(x_normal), "-", torch.max(x_normal))
        print("x_hf_0    ", x_hf_0.shape, x_hf_0.dtype, torch.min(x_hf_0), "-", torch.max(x_hf_0))
        print("x_hf_1    ", x_hf_1.shape, x_hf_1.dtype, torch.min(x_hf_1), "-", torch.max(x_hf_1))
        print("x_hf_2    ", x_hf_2.shape, x_hf_2.dtype, torch.min(x_hf_2), "-", torch.max(x_hf_2))
        print("x_hf_3    ", x_hf_3.shape, x_hf_3.dtype, torch.min(x_hf_3), "-", torch.max(x_hf_3))
        
        if self.datalayout == "normal_exthf":
            x = torch.cat( (x_normal, x_hf_0, x_hf_1, x_hf_2, x_hf_3 ), dim=1 )
        elif self.datalayout == "exthf_normal":
            x = torch.cat( (x_hf_0, x_hf_1, x_hf_2, x_hf_3, x_normal ), dim=1 )
        elif self.datalayout == "exthf_only":
            x = torch.cat( (x_hf_0, x_hf_1, x_hf_2, x_hf_3 ), dim=1 )
        elif self.datalayout == "hf_normal":
            x = torch.cat( (x_hf_0, x_normal ), dim=1 )
        elif self.datalayout == "hf_only":
            x = x_hf_0
        output_of_inner_model = self.inner_model(x)
        output = torch.split(output_of_inner_model, 1, dim=1)
        # output = output_of_inner_model
        return output

## Determine Learning Rate

In [None]:
learning_rate=0.001

## Adding noise

In [None]:
from skimage.util import random_noise

def add_noise( img, var ):
    return torch.tensor( random_noise(img, mode='gaussian', mean=0, var=var, clip=True) )

def estimate_snr( img ):
    return img.mean() / img.std();

## Loss Function

In [None]:
class CombinedLoss():
    def __init__(self, losses, weights=None, reduction='mean', axis=-1):
        self.losses = losses
        self._reduction = reduction
        self.axis = axis
        self.weights = weights
        if weights is None:
            self.weights = []
            for _ in losses:
                self.weights.append(1.0)
        
    def __call__(self, out, *yb):
        total_loss = 0.0
        # out = torch.split(out, 1, dim=1)
        for i,loss_fct in enumerate( self.losses ):
            loss_fct.reduction = 'none'
            total_loss += loss_fct(out[i], yb[i]) * self.weights[i]
        if self.reduction == "mean":
            total_loss = total_loss.mean()
        elif self.reduction == "sum":
            total_loss = total_loss.sum()
        return total_loss
    
    @property
    def reduction(self) -> str:
        return self._reduction    
    
    @reduction.setter
    def reduction(self, reduction:str):
        self._reduction = reduction  
    
    def decodes(self, x:Tensor) -> Tensor:    
        return x
        # return x.argmax(dim=self.axis)
    
    def activation(self, x:Tensor) -> Tensor:                 
        activation = torch.zeros(x[0].shape)
        for xi in list(x):
            activation += F.softmax(xi, dim=self.axis)    
        return activation  
    
l1          = CombinedLoss([ L1LossFlat(), L1LossFlat()] )    
l1_weighted = CombinedLoss([ L1LossFlat(), L1LossFlat()], [1.0, 1.691]  )
l2          = CombinedLoss([ MSELossFlat(), MSELossFlat()])
l2_weighted = CombinedLoss([ MSELossFlat(), MSELossFlat()], [1.0, 1.691] )

## Metrics 

In [None]:
from fastai import metrics

def total_mse(inp, *targ):
    total = 0.0
    for i in range(len(inp)):
        total += mse(inp[i], targ[i])
    return total

def total_l1(inp, *targ):
    total = 0.0
    for i in range(len(inp)):
        total += mae(inp[i], targ[i])
    return total

# Evaluation

In [None]:
import matplotlib as mp
cpts = [0.0, 127.0/255.0, 1.0]
colors = [(cpts[0], (0, 0, 1.0)), (cpts[1], (0, 1.0, 0.0)), (cpts[2], (1.0, 0, 0))]
cmap_name = 'my_list'
colormap = mp.colors.LinearSegmentedColormap.from_list(cmap_name, colors)

In [None]:
def filler(depth = 0):
    result = ""
    for i in range(depth):
        result = result + "  "
    return result

def print_tuple(x, depth = 0):
    if isinstance(x,fastai.torch_core.TensorBase):
        print(filler(depth), type(x), x.shape)
    elif isinstance(x,torch.Tensor):
        print(filler(depth), type(x), x.shape)
    elif type(x) is tuple:
        print(filler(depth), type(x), len(x) )
        for child_x in list(x):
            print_tuple( child_x, depth+1 )
    else:
        print(filler(depth), type(x))

class ConstantFunc():
    "Returns a function that returns `o`"
    def __init__(self, o): self.o = o
    def __call__(self, *args, **kwargs): return self.o        
        
class PredictionsFromTupleCallback(Callback):    
    def before_validate(self):
        self.preds = []
        self.targets = []
            
    def after_pred(self, **kwargs:Any)->None:
        se,bse = self.pred
        se = to_detach(se)
        bse = to_detach(bse)
        self.preds.append((se,bse))
        self.targets.append(self.yb)       

def create_loss_image( learner, output_filename, top=True ):
    n_images = 6
    
    interpretation = Interpretation.from_learner( learner )    
    values,indices = interpretation.top_losses(k=n_images, largest=top)

    metrics_string = ""
    for metric in learner.metrics:
        metrics_string = metrics_string + metric.name + ": " + "{:.3f}".format(metric.value.item() ) + "\n"
    
    tmp_data_loader = learner.dls[1].new( get_idxs = ConstantFunc( indices ), bs=1 )
    cb = PredictionsFromTupleCallback()
    ctx_mgrs = learner.validation_context(cbs=[cb])
    with ContextManagers(ctx_mgrs):
        learner._do_epoch_validate(dl=tmp_data_loader)

    all_predictions = cb.preds
    all_targets     = cb.targets

    figure      = plt.figure( constrained_layout=True, figsize=(16,12*n_images) )
    figure.suptitle(metrics_string, fontsize=16 )
    
    sub_figures = figure.subfigures(nrows=n_images, ncols=1)
    
    # fig, axs = plt.subplots(2*4, 3, figsize=(16,48))

    for i,(idx,loss_value) in enumerate(zip(indices,values)):
        filename = str( learner.dls.valid_ds.items[idx] )
        
        sub_figures[i].suptitle( filename + "\nLoss:{:.3f}".format(loss_value) )
        axs = sub_figures[i].subplots(nrows=2, ncols=3)        
        
        # hf_preds,bse_preds = preds
        se_pred,bse_pred = all_predictions[i]
        se_pred  = torch.squeeze( se_pred, 0 )
        bse_pred = torch.squeeze( bse_pred, 0 )

        se_target,bse_target = all_targets[i]
        se_target  = torch.squeeze(  se_target, 0 )
        bse_target = torch.squeeze( bse_target, 0 )

        se_resid  = abs(se_pred.cpu() - se_target.cpu())
        bse_resid = abs(bse_pred.cpu() - bse_target.cpu())

        show_image( ax=axs[0,0], im=se_pred,        title="SE Prediction ", vmin=0, vmax=1,  cmap="gray")
        show_image( ax=axs[1,0], im=bse_pred,       title="BSE Prediction", vmin=0, vmax=1,  cmap="gray")

        show_image( ax=axs[0,1], im=se_target,      title="SE Target", vmin=0, vmax=1,       cmap="gray")
        show_image( ax=axs[1,1], im=bse_target,     title="BSE Target", vmin=0, vmax=1,      cmap="gray")

        show_image( ax=axs[0,2], im=se_resid,       title="SE Residual", vmin=0, vmax=0.20,  cmap="coolwarm")
        show_image( ax=axs[1,2], im=bse_resid,      title="BSE Residual", vmin=0, vmax=0.20, cmap="coolwarm")

    plt.savefig(output_filename, dpi=150)    

In [None]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import time

def get_predictions( learner ):
    callback = PredictionsFromTupleCallback()
    ctx_mgrs = learner.validation_context(cbs=[callback])
    with ContextManagers(ctx_mgrs):
        learner._do_epoch_validate(dl=learner.dls.valid)
    return callback.preds,callback.targets

def evaluate_predictions( learner ):
    print(learner)
    
    print("performing predictions")
    start = time.perf_counter()
    predictions,targets = get_predictions( learner )
    stop = time.perf_counter()
    print("done performing predictions of ", len(predictions), f"items in {stop-start:0.4f}")

    n_images = 0
    
    bse_std = 0.0
    se_std  = 0.0
       
    ssim_bse = 0.0
    ssim_se  = 0.0
    
    psnr_bse = 0.0
    psnr_se  = 0.0
        
    for (bse_pred,se_pred),(bse_target,se_target) in zip(predictions,targets) :
        bse_pred   = bse_pred.cpu().to(dtype=torch.float32)
        se_pred    = se_pred.cpu().to(dtype=torch.float32)
        bse_target = bse_target.cpu().to(dtype=torch.float32)
        se_target  = se_target.cpu().to(dtype=torch.float32)
        
        bse_pred   = bse_pred * 255.0
        se_pred    = se_pred * 255.0
        bse_target = bse_target * 255.0
        se_target  = se_target * 255.0
        
        bse_residual = bse_pred - bse_target
        se_residual  = se_pred - se_target
       
        for i in range(bse_residual.shape[0]):
            n_images = n_images + 1
            bse_std = bse_std + torch.std(bse_residual[i,0,:,:])            
            se_std = se_std + torch.std(se_residual[i,0,:,:])
            
            bse_a = bse_pred[i,0,:,:].squeeze(0).squeeze(0).cpu().detach().numpy()
            bse_b = bse_target[i,0,:,:].squeeze(0).squeeze(0).cpu().detach().numpy()            
            ssim_bse = ssim_bse + ssim(bse_a, bse_b, data_range=255)
            psnr_bse = psnr_bse + psnr(bse_a, bse_b, data_range=255)

            se_a = se_pred[i,0,:,:].squeeze(0).squeeze(0).cpu().detach().numpy()
            se_b = se_target[i,0,:,:].squeeze(0).squeeze(0).cpu().detach().numpy()            
            ssim_se = ssim_se + ssim(se_a, se_b, data_range=255)
            psnr_se = psnr_se + psnr(se_a, se_b, data_range=255)
            
    print( "bse std",  "{:.2f}".format(bse_std / float(n_images) ) )
    print( "se  std ", "{:.2f}".format(se_std / float(n_images) ) )
    
    print( "bse ssim",  "{:.2f}".format(ssim_bse / float(n_images) ) )
    print( "se  ssim ", "{:.2f}".format(ssim_se / float(n_images) ) )
    
    print( "bse psnr",  "{:.2f}".format(psnr_bse / float(n_images) ) )
    print( "se  psnr ", "{:.2f}".format(psnr_se / float(n_images) ) )
            
# evaluate_predictions( learner )            

## Learner

In [None]:
all_losses =  [ (l1, "l1"), (l2, "l2"), (l1_weighted, "l1_weighted"), (l2_weighted, "l2_weighted") ]
all_backbones = [ (resnet152, "resnet152"), (resnet101, "resnet101"), (resnet50,"resnet50"), (resnet34,"resnet34") ]
all_datalayout = ["exthf_normal", "normal_exthf", "exthf_only", "hf_normal", "hf_only" ]

experiments = []
experiments.append( ( (l2, "l2"), (resnet152, "resnet152"), "exthf_normal") )

# for loss in all_losses:
#    experiments.append( (loss, (resnet152, "resnet152"), "exthf_normal") )#
#for backbone in all_backbones:
#    experiments.append(  ((l1, "l1"), backbone, "exthf_normal") )
#for layout in all_datalayout:
#    experiments.append(  ((l1, "l1"), (resnet152, "resnet152"), layout ))

In [None]:
train    = True
evaluate = True
test     = False

n_epochs = [19,30,50]

model_dir = "/mnt/aetna-cluster-workspace/models/"
    
for (loss_func, lossname),(backbone, backbone_name),datalayout in experiments:
    model = FIBModel( create_inner_model( data_loader, backbone, datalayout ), datalayout )
    learner = Learner( data_loader, model, model_dir=model_dir, loss_func=loss_func, metrics=[total_mse, total_l1] )

    if train:
        learner.fit( 1, lr=learning_rate )
        epochs = 1

        for step in n_epochs:
            epochs = epochs + step
            weight_name = datalayout + "_" + backbone_name + "_" + lossname + "_" + str(epochs)
            print(weight_name)
            learner.fit( step, lr=learning_rate, cbs=[ShowGraphCallback(), CSVLogger(filename=weight_name+".csv")] )    
            learner.save( weight_name )
    else:
        epochs = 100
        weight_name = datalayout + "_" + backbone_name + "_" + lossname + "_" + str(epochs)
        print("loading weights", weight_name)
        learner.load(weight_name)

    if evaluate:
        evaluate_predictions( learner )
        create_loss_image( learner, weight_name + "_worst6.png", top=True)
        create_loss_image( learner, weight_name + "_bestp6.png", top=False)
        
    if test:
        test_input_path = Path('./test_data')
        test_files = get_items( test_input_path )
        test_dataloader = learner.dls.test_dl( test_files )

        cb = PredictionsFromTupleCallback()
        ctx_mgrs = learner.validation_context(cbs=[cb])
        with ContextManagers(ctx_mgrs):
            learner._do_epoch_validate(dl=test_dataloader)
        all_preds = cb.preds

        # all_preds = learner.get_preds( dl=test_dataloader, with_decoded=False, with_input=True )            

## Testing 

In [None]:
n_rows = 8

figure   = plt.figure  ( constrained_layout=True )
fig, axs = plt.subplots( nrows=n_rows, ncols=2, figsize=(16,8*n_rows) )        

for i,preds in enumerate(all_preds[0:n_rows]):
    se,bse = preds
    show_image( ax=axs[i,0], im=se[0,:,:],  title="SE", cmap="gray")    
    show_image( ax=axs[i,1], im=bse[0,:,:], title="BSE",  cmap="gray")