In [None]:
import pickle
import torch
from data.datasets_pytorch import get_dataloaders
from networks.UNetPytorch import UNet
from tools import set_device
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib.colors import ListedColormap

sns.set_style("white")

# Pickled dict with the train/validation/test data consumed by get_dataloaders.
# NOTE: pickle.load can execute arbitrary code -- only load dataset files
# that come from a trusted source.
DATASET_PATH = "./data/dataset.pkl"
with open(DATASET_PATH, "rb") as dataset_file:
    dict_data = pickle.load(dataset_file)

# Pick the compute device once; models and inputs below are moved onto it.
device = set_device()

# Build every DataLoader from the same dict. The batch sizes / worker count are
# forwarded to the project's get_dataloaders (presumably to torch DataLoader --
# confirm which split uses batch_size_validation).
loader_options = dict(
    device=device,
    batch_size_train=16,
    batch_size_validation=2,
    num_workers=0,
)
dataloaders = get_dataloaders(dict_data, **loader_options)

In [None]:
# Segmentation network: 1 input channel -> 3 output channels with a softmax
# final activation (the channels are thresholded into class labels when plotting).
Nseg = UNet(
    n_channels=1,
    n_output_channels=3,
    initial_channels=16,
    ndepth=5,
    bilinear=False,
    activation="relu",
    dropout_rate=0.1,
    final_activation="softmax",
).to(device)

# The checkpoint is a plain state_dict, so restrict unpickling to tensors and
# standard containers (weights_only=True) instead of allowing arbitrary objects.
# map_location puts the weights directly on the selected device.
state_dict = torch.load("./models/Nseg", map_location=device, weights_only=True)
# strict=True (default): raises if the checkpoint keys do not match the model.
Nseg.load_state_dict(state_dict)

In [None]:
# Colour for each segmentation label: 0 -> white, 1 -> black, 2 -> blue.
color_dict = {0: 'white',
              1: 'black',
              2: 'blue',
            }
# Pass a real list: a dict view is not a sequence the colormap can index.
cm = ListedColormap(list(color_dict.values()))
# Fixed colour range so each label always maps to the same colour, even when a
# particular image does not contain every label.
cbar_lims = (0, 2)
size = 5  # figure width in inches


def decode_segmentation(probs, threshold=0.5):
    """Turn per-class probability maps into an integer label map.

    Parameters
    ----------
    probs : np.ndarray
        Indexed as ``probs[class_index]``; channels 1 and 2 are read.
        Presumably the (3, H, W) softmax output of Nseg -- TODO confirm.
    threshold : float
        A pixel gets label 1 or 2 only if that class's probability is strictly
        greater than ``threshold``; otherwise it stays 0.

    Returns
    -------
    np.ndarray
        Float array of shape ``probs.shape[1:]`` with values in {0, 1, 2}.
    """
    labels = np.zeros(probs.shape[1:])
    # Same strict comparison for both classes, and assignment rather than
    # summation, so a label can never exceed 2 (the colormap's upper limit).
    labels[probs[1] > threshold] = 1
    labels[probs[2] > threshold] = 2
    return labels


def plot_input_and_segmentation(image, labels, cmap=cm, vlims=cbar_lims, fig_size=size):
    """Show a grayscale input image next to its predicted label map.

    Parameters
    ----------
    image : np.ndarray
        2-D input image drawn with a gray colormap.
    labels : np.ndarray
        2-D label map drawn with ``cmap`` over the fixed range ``vlims``.
    cmap, vlims, fig_size
        Colormap, (vmin, vmax) for the label panel, and figure width in inches.
    """
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(fig_size, fig_size / 2),
                             constrained_layout=True)
    sns.heatmap(image, ax=axes[0], cmap='gray', cbar=False)
    sns.heatmap(labels, ax=axes[1], cmap=cmap, cbar=False, vmin=vlims[0], vmax=vlims[1])

    for ax in axes:
        ax.grid(False)
        ax.axis('tight')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('equal')
        # seaborn draws row 0 at the top; flip so row 0 is at the bottom.
        ax.invert_yaxis()
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1)

    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop


Nseg.eval()
with torch.no_grad():
    # output_seg_C (ground-truth labels) is not plotted in this cell.
    for input_sigma_C, output_seg_C in dataloaders["test_seg_C"]:
        # Nseg takes a single channel, so keep only the first input channel.
        input_seg_C = input_sigma_C[:, :1]
        preds_Nseg_C = Nseg(input_seg_C.to(device))

        input_seg_C = input_seg_C.detach().cpu().numpy()
        # Only the first sample of each batch is visualised.
        preds_Nseg_C = preds_Nseg_C.detach().cpu().numpy()[0]

        segmented = decode_segmentation(preds_Nseg_C)
        plot_input_and_segmentation(input_seg_C[0, 0], segmented)

In [None]:
# Stress-prediction network: 4 input channels -> 3 output channels (the
# sigma_xx, sigma_xy, sigma_yy maps plotted below). dropout_rate=None presumably
# disables dropout inside UNet -- confirm against networks.UNetPytorch.
Nsigma = UNet(
    n_channels=4,
    n_output_channels=3,
    initial_channels=32,
    ndepth=5,
    bilinear=False,
    activation="elu",
    dropout_rate=None,
).to(device)

# The checkpoint is a plain state_dict, so restrict unpickling to tensors and
# standard containers (weights_only=True) instead of allowing arbitrary objects.
state_dict = torch.load("./models/Nsigma", map_location=device, weights_only=True)
# strict=True (default): raises if the checkpoint keys do not match the model.
Nsigma.load_state_dict(state_dict)

In [None]:
Nsigma.eval()
titles = ['Prediction', 'Ground truth']
components = [r'$\sigma_{xx}$', r'$\sigma_{xy}$', r'$\sigma_{yy}$']
size = 5  # figure width in inches


def plot_prediction_vs_truth(pred, truth, component, panel_titles=titles, fig_size=size):
    """Plot one predicted stress component next to its ground truth.

    Both panels share the same colour limits (vmin/vmax taken over both arrays,
    ignoring NaNs), so equal colours mean equal values and the two maps can be
    compared directly. Without this, each heatmap autoscales independently.

    Parameters
    ----------
    pred, truth : np.ndarray
        2-D maps for the same stress component.
    component : str
        Label appended to each panel title (e.g. a LaTeX sigma string).
    panel_titles : sequence of str
        Titles for the prediction and ground-truth panels, in that order.
    fig_size : float
        Figure width in inches (height is half of it).
    """
    vmin = min(np.nanmin(pred), np.nanmin(truth))
    vmax = max(np.nanmax(pred), np.nanmax(truth))

    fig, axes = plt.subplots(1, 2, figsize=(fig_size, fig_size / 2))
    for ax, field, title in zip(axes, (pred, truth), panel_titles):
        sns.heatmap(field, cmap='plasma', ax=ax, vmin=vmin, vmax=vmax)
        # seaborn draws row 0 at the top; flip so row 0 is at the bottom.
        ax.invert_yaxis()
        ax.axis("tight")
        ax.set_axis_off()
        ax.set_aspect("equal")
        ax.set_title(f"{title} {component}", fontsize=10)

    plt.tight_layout(pad=0)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop


with torch.no_grad():
    for input_sigma_C, output_sigma_C in dataloaders['test_sigma_C']:
        preds_sigma_C = Nsigma(input_sigma_C.to(device)).detach().cpu().numpy()
        output_sigma_C = output_sigma_C.detach().cpu().numpy()
        # Only the first sample of each batch is plotted, one figure per component.
        for i, component in enumerate(components):
            plot_prediction_vs_truth(preds_sigma_C[0, i], output_sigma_C[0, i], component)