In [None]:
# Reload imported modules automatically before each cell runs, so edits to the
# project packages imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Device the models are moved to (`.to(device)`); note that several later cells
# still pass the literal 'cuda' instead of this variable.
device = 'cuda'

In [None]:
import torch 
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import torch.nn as nn 

# Seed the global NumPy and PyTorch RNGs. random_split and the shuffling
# DataLoaders below draw from the global torch RNG by default.
np.random.seed(seed=0)
torch.manual_seed(seed=0)

# Fully Connected Network

In [None]:
from Elemental_Mapping.datasets.Pixel2PixelDataset import Pixel2PixelDataset

# Held-out image, used only for the final evaluation.
test_image = 'saintjohn'
# Images used for training/validation.
images = ['gogo', 'dionisios', 'fanourios', 'minos', 'odigitria']

# Spectral channel window [0, 4096) passed to the dataset.
band_range = range(4096)

In [None]:
dataset = Pixel2PixelDataset(
    '/home/igeor/MSC-THESIS/data/h5',
    image_names=images, 
    sample_step = 10, 
    device=device, 
    band_range=(band_range.start, band_range.stop), 
    target_elems=['S_K','K_K','Ca_K','Cr_K','Mn_K','Fe_K','Cu_K','Zn_K','Sr_K','Au_L','Hg_L','Pb_L'])

# 80/20 train/validation split drawn from a dedicated, seeded generator. This makes the
# split identical on every re-run of this cell, instead of depending on how many draws
# earlier cells happened to make from the global torch RNG.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
# Validation order need not be random; a fixed order makes evaluation repeatable.
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

## Model

In [None]:
# Fully connected network configured with a PriorLayer built from the
# reference spectra in spec_db.pure_elements.
from Elemental_Mapping.models.FullyConnectedModel import FullyConnectedModel 
from spec_db import pure_elements
from Elemental_Mapping.models.PriorLayer import PriorLayer 

# Prior weight matrix: one row per spectrum in pure_elements, stacked along a new leading dim.
w = torch.stack(list(pure_elements.values()), dim=0)
# `requres_grad` (sic) is passed through unchanged; check PriorLayer's signature.
prior_layer = PriorLayer(
    w,
    s=None,
    bias=False,
    apply_sum=True,
    requres_grad=False,
    device='cuda',
)

fcn = FullyConnectedModel(
    in_features=4096,
    out_features=12,
    hidden_dims=[512, 64, 64],
    prior_layer=prior_layer,
    dropout=0.0,
).to(device)

fcn

## Training

In [None]:
from Elemental_Mapping.loss_functions.AdaptiveL1Loss import AdaptiveL1Loss

# Adam (lr = 1e-3) over every FCN parameter; the adaptive L1 loss is the training objective.
fcn_optimizer = torch.optim.Adam(params=fcn.parameters(), lr=0.001)
train_criterion = AdaptiveL1Loss()

In [None]:
import os

n_epochs = 2000
eval_n_epochs = 2
# NOTE(review): this file name says "testOdigitria", but odigitria is a training image here
# and test_image is 'saintjohn'. The evaluation cell also loads a different checkpoint
# (models/FCNplus_testSaintJohn.pt). Confirm which file this run is meant to produce.
fcn_ckpt_path = './results/Elemental_Mapping/new models/fcn_testOdigitria_v3.pt'
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(fcn_ckpt_path), exist_ok=True)

min_val_loss = np.inf
for epoch in range(n_epochs):
    train_loss = fcn.train(train_loader, fcn_optimizer, train_criterion, epochs=1, device='cuda')

    # Validate every eval_n_epochs epochs. This always happens at epoch 0, so eval_loss
    # exists before the first print. Checkpoint only when a fresh validation loss improves.
    if epoch % eval_n_epochs == 0:
        eval_loss, _ = fcn.eval(val_loader, train_criterion, device='cuda')
        if eval_loss < min_val_loss:
            min_val_loss = eval_loss
            torch.save(fcn.state_dict(), fcn_ckpt_path)

    # On epochs without validation, Eval Loss repeats the most recent value.
    print(f'Epoch: {epoch}, Train Loss: {train_loss} Eval Loss: {eval_loss}')

## Evaluation

In [None]:
# Every pixel (sample_step=1) of the held-out image, using the same channels and targets as training.
test_dataset = Pixel2PixelDataset(
    '/home/igeor/MSC-THESIS/data/h5',
    image_names=[test_image],
    target_elems=[
        'S_K', 'K_K', 'Ca_K', 'Cr_K', 'Mn_K', 'Fe_K',
        'Cu_K', 'Zn_K', 'Sr_K', 'Au_L', 'Hg_L', 'Pb_L',
    ],
    band_range=(band_range.start, band_range.stop),
    sample_step=1,
    device='cuda',
)

# Pixel order must be preserved (shuffle=False) because predictions are reshaped back into the image grid.
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=64)

In [None]:
## Open target image file (elemental_maps)
# Ground-truth maps of the test image: one line per pixel with 'row' and 'column'
# coordinates plus one column per target element.
df = pd.read_csv(f'/home/igeor/MSC-THESIS/data/h5/elem_maps/{test_image}.dat' , sep='  ', engine='python')

# Grid size from the largest (0-based) coordinates, so it does not rely on the last
# line holding both the maximum row and the maximum column.
width, height = df['row'].max() + 1, df['column'].max() + 1
assert len(df) == width * height, (
    f'{test_image}.dat has {len(df)} pixels, expected {width} x {height} = {width * height}')

# NOTE(review): the reshape assumes lines are ordered row-major (row, then column),
# matching the pixel order produced by test_dataset.
y_real = np.array(df[test_dataset.target_elems])
y_real = y_real.reshape((width, height, len(test_dataset.target_elems)))

In [None]:
# Restore the saved FCN weights and predict every test pixel (the returned loss is ignored).
fcn.load_state_dict(torch.load('./results/Elemental_Mapping/models/FCNplus_testSaintJohn.pt'))
_, y_pred = fcn.eval(test_loader, torch.nn.L1Loss(), device='cuda')

# Fold the flat per-pixel predictions into the (width, height, element) layout of y_real.
fcnplus_y_pred = y_pred.reshape((width, height, len(test_dataset.target_elems))).detach().cpu().numpy()

# 1D Convolutional Network

In [None]:
from Elemental_Mapping.datasets.Pixel2PixelDataset import Pixel2PixelDataset

# Held-out image, used only for the final evaluation.
test_image = 'saintjohn'
# Images used for training/validation.
images = ['gogo', 'dionisios', 'fanourios', 'minos', 'odigitria']
# Spectral channel window [0, 4096) passed to the dataset.
band_range = range(4096)

In [None]:
dataset = Pixel2PixelDataset(
    '/home/igeor/MSC-THESIS/data/h5',
    image_names=images, 
    sample_step = 10, 
    device=device, 
    band_range=(band_range.start, band_range.stop), 
    target_elems=['S_K','K_K','Ca_K','Cr_K','Mn_K','Fe_K','Cu_K','Zn_K','Sr_K','Au_L','Hg_L','Pb_L'])

# 80/20 train/validation split drawn from a dedicated, seeded generator. This makes the
# split identical on every re-run of this cell, instead of depending on how many draws
# earlier cells (e.g. a full FCN training run) made from the global torch RNG.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
# Validation order need not be random; a fixed order makes evaluation repeatable.
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

## Model

In [None]:
# 1D convolutional network configured with a PriorLayer built from the
# reference spectra in spec_db.pure_elements.
from Elemental_Mapping.models.Conv1DModel import Conv1DModel 
from spec_db import pure_elements
from Elemental_Mapping.models.PriorLayer import PriorLayer 

# Prior weight matrix: one row per spectrum in pure_elements, stacked along a new leading dim.
w = torch.stack(list(pure_elements.values()), dim=0)
# Unlike the FCN prior, apply_sum=False here. `requres_grad` (sic) is passed through
# unchanged; check PriorLayer's signature.
prior_layer = PriorLayer(w, s=None, bias=False, apply_sum=False, requres_grad=False, device='cuda') 

cnn1d = Conv1DModel(
    in_features=4096,
    hidden_dims=[64, 64, 64, 64, 128],
    out_features=12,
    prior_layer=prior_layer,
    iis=True,
    flatten_dims=512,
    dropout=0.0,
).to(device)

print(cnn1d.alias)
cnn1d

## Training

In [None]:
from Elemental_Mapping.loss_functions.AdaptiveL1Loss import AdaptiveL1Loss

# Adam (lr = 1e-3) over every 1D-CNN parameter; the adaptive L1 loss is the training objective.
cnn1d_optimizer = torch.optim.Adam(params=cnn1d.parameters(), lr=0.001)
train_criterion = AdaptiveL1Loss()

In [None]:
import csv
import os

# Per-epoch loss log.
log_path = './results/Elemental_Mapping/CNN1Dplus_testSaintJohn.csv'
# NOTE(review): the best checkpoint is written to f'./results/Elemental_Mapping/{cnn1d.alias}.pt',
# but the evaluation cell loads './results/Elemental_Mapping/models/CNN1Dplus_testSaintJohn.pt'.
# Confirm both refer to the same model before trusting the reported metrics.
os.makedirs(os.path.dirname(log_path), exist_ok=True)

n_epochs = 1500
eval_n_epochs = 5

min_val_loss = np.inf
# The csv module requires file objects opened with newline=''. The file stays open for
# the whole run and is flushed after every row, so the log survives an interrupted run.
with open(log_path, 'w', newline='') as log_file:
    writer = csv.writer(log_file)
    writer.writerow(['epoch', 'train_loss', 'val_loss'])

    for epoch in range(n_epochs):
        train_loss = cnn1d.train(train_loader, cnn1d_optimizer, train_criterion, epochs=1, device='cuda')

        # Validate every eval_n_epochs epochs. This always happens at epoch 0, so eval_loss
        # exists before first use. Checkpoint only when a fresh validation loss improves.
        if epoch % eval_n_epochs == 0:
            eval_loss, _ = cnn1d.eval(val_loader, train_criterion, device='cuda')
            if eval_loss < min_val_loss:
                min_val_loss = eval_loss
                torch.save(cnn1d.state_dict(), f'./results/Elemental_Mapping/{cnn1d.alias}.pt')

        # On epochs without validation, val_loss repeats the most recent value.
        writer.writerow([epoch, train_loss, eval_loss])
        log_file.flush()

        print(f'Epoch: {epoch}, Train Loss: {round(train_loss, 4)} Eval Loss: {round(eval_loss, 4)}')

## Evaluation

In [None]:
# Spectral channel window [0, 4096), the same as in training.
band_range = range(4096)

# Every pixel (sample_step=1) of the held-out image.
test_dataset = Pixel2PixelDataset(
    '/home/igeor/MSC-THESIS/data/h5',
    image_names=[test_image],
    target_elems=[
        'S_K', 'K_K', 'Ca_K', 'Cr_K', 'Mn_K', 'Fe_K',
        'Cu_K', 'Zn_K', 'Sr_K', 'Au_L', 'Hg_L', 'Pb_L',
    ],
    band_range=(band_range.start, band_range.stop),
    sample_step=1,
    device='cuda',
)

# Pixel order must be preserved (shuffle=False) because predictions are reshaped back into the image grid.
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=64)

In [None]:
## Open target image file (elemental_maps)
# Ground-truth maps of the test image: one line per pixel with 'row' and 'column'
# coordinates plus one column per target element.
df = pd.read_csv(f'/home/igeor/MSC-THESIS/data/h5/elem_maps/{test_image}.dat' , sep='  ', engine='python')

# Grid size from the largest (0-based) coordinates, so it does not rely on the last
# line holding both the maximum row and the maximum column.
width, height = df['row'].max() + 1, df['column'].max() + 1
assert len(df) == width * height, (
    f'{test_image}.dat has {len(df)} pixels, expected {width} x {height} = {width * height}')

# NOTE(review): the reshape assumes lines are ordered row-major (row, then column),
# matching the pixel order produced by test_dataset.
y_real = np.array(df[test_dataset.target_elems])
y_real = y_real.reshape((width, height, len(test_dataset.target_elems)))

In [None]:
# Restore the saved 1D-CNN weights and predict every test pixel (the returned loss is ignored).
cnn1d.load_state_dict(torch.load('./results/Elemental_Mapping/models/CNN1Dplus_testSaintJohn.pt'))
_, y_pred = cnn1d.eval(test_loader, torch.nn.L1Loss(), device='cuda')

# Fold the flat per-pixel predictions into the (width, height, element) layout of y_real.
y_pred = y_pred.reshape((width, height, len(test_dataset.target_elems))).detach().cpu().numpy()

# Evaluation Metrics (computed on the 1D CNN predictions in `y_pred`)

### Visualization

In [None]:
# Ground truth (top row) vs. prediction (bottom row) for every target element.
n_elems = len(test_dataset.target_elems)
# squeeze=False keeps `axs` 2-D even when there is only one element.
fig, axs = plt.subplots(2, n_elems, figsize=(20, 10), squeeze=False)
for i, elem in enumerate(test_dataset.target_elems):
    axs[0, i].imshow(y_real[:,:,i])
    axs[0, i].set_title(elem)
    axs[1, i].imshow(y_pred[:,:,i])
    for ax in axs[:, i]:
        ax.set_xticks([]); ax.set_yticks([])

# Row labels only need to be set once, on the first column.
axs[0, 0].set_ylabel('GT')
axs[1, 0].set_ylabel('Pred')

plt.tight_layout()
# plt.savefig(f'./results/Elemental_Mapping/{cnn1d.alias}.png')

### Z-score

In [None]:
def z_score_eval(y_real, y_pred, elem_idx=None):
    """Classify every pixel of one element map by its z-score against the ground truth.

    The z-score of a pixel is ``|y_pred - y_real| / sqrt(y_real + 1)``. Pixels are
    coloured by band: white for [0, 1), orange for [1, 2), red for [2, 3) and
    black for >= 3. Pixels whose z-score is NaN match no band; they stay black and
    are not counted.

    Args:
        y_real: ground-truth maps indexed as ``[row, column, element]``.
        y_pred: predicted maps with the same layout as ``y_real``.
        elem_idx: index of the element channel to evaluate. If omitted, the
            notebook-global ``i`` is used. This is the legacy behaviour that
            callers looping ``for i in ...`` rely on; pass it explicitly in new code.

    Returns:
        ``(out_image, num_0to1, num_1to2, num_2to3, num_3toInf)``. ``out_image`` is a
        float RGB array of shape ``(rows, columns, 3)``; the counts are the number of
        pixels in each band.

    Raises:
        TypeError: if ``elem_idx`` is omitted and no global ``i`` exists.
    """
    if elem_idx is None:
        if 'i' not in globals():
            raise TypeError("z_score_eval() needs elem_idx when no global loop index 'i' is defined")
        elem_idx = globals()['i']

    real = y_real[:, :, elem_idx]
    pred = y_pred[:, :, elem_idx]
    zscore = np.abs(pred - real) / np.sqrt(real + 1)

    # Output size comes from the input itself rather than the global width/height.
    out_image = np.zeros(real.shape[:2] + (3,))

    # (lower bound, upper bound or None for unbounded, RGB colour) per band.
    bands = [
        (0, 1, (1, 1, 1)),      # white
        (1, 2, (1, 0.5, 0)),    # orange
        (2, 3, (1, 0, 0)),      # red
        (3, None, (0, 0, 0)),   # black
    ]
    counts = []
    for lower, upper, colour in bands:
        mask = zscore >= lower
        if upper is not None:
            mask &= zscore < upper
        out_image[mask] = colour
        counts.append(mask.sum())

    return (out_image, *counts)

In [None]:
# Per-element z-score maps: ground truth on the top row, z-score band map on the bottom row.
band_labels = ['between 0 and 1', 'between 1 and 2', 'between 2 and 3', 'greater than 3']
n_elems = len(test_dataset.target_elems)
n_pixels = width * height

# Fraction of pixels in each z-score band, keyed by element.
zscore_per_elem = { elem: None for elem in test_dataset.target_elems }
band_totals = np.zeros(len(band_labels))

# squeeze=False keeps `axes` 2-D even when there is only one element.
fig, axes = plt.subplots(2, n_elems, figsize=(20, 10), squeeze=False)
for i, elem in enumerate(test_dataset.target_elems):
    axes[0, i].imshow(y_real[:,:,i])
    axes[0, i].set_title(elem)

    # Called without an element index, so z_score_eval picks up this loop's `i`
    # from the notebook namespace.
    out_image, *band_counts = z_score_eval(y_real, y_pred)
    axes[1, i].imshow(out_image)

    band_totals += band_counts
    zscore_per_elem[elem] = {
        label: count / n_pixels for label, count in zip(band_labels, band_counts)
    }

    for ax in axes[:, i]:
        ax.set_xticks([]); ax.set_yticks([])

axes[0, 0].set_ylabel('GT')
axes[1, 0].set_ylabel('Pred')

# Aggregate over all elements: each element contributes n_pixels pixels.
for label, total in zip(band_labels, band_totals):
    print(f'Percentage of pixels with zscore {label}: {total / (n_pixels * n_elems):.2%}')

plt.tight_layout()
pd.DataFrame(zscore_per_elem).T

### SSIM

In [None]:
from skimage.metrics import structural_similarity as ssim

# NOTE(review): data_range=1.0 assumes the elemental maps are scaled to [0, 1]. Confirm this;
# otherwise pass the actual value range (e.g. y_real.max() - y_real.min()).
ssim_data_range = 1.0

# skimage needs NumPy arrays. Convert once up front, so this also happens when there are no
# target elements, before the 'total' call below.
if isinstance(y_pred, torch.Tensor):
    y_pred = y_pred.cpu().detach().numpy()

# SSIM of each element map on its own.
ssim_per_elem = {
    elem: ssim(y_real[:,:,i], y_pred[:,:,i], data_range=ssim_data_range)
    for i, elem in enumerate(test_dataset.target_elems)
}

# Mean of the per-element scores (computed before 'total' is added to the dict).
print(f'Mean SSIM per element: {np.mean(list(ssim_per_elem.values()))}')
# Without channel_axis, skimage treats the (width, height, element) stack as one 3-D
# volume, so the SSIM window here also spans neighbouring element channels.
ssim_per_elem['total'] = ssim(y_real, y_pred, data_range=ssim_data_range)
ssim_per_elem

### Pearson

In [None]:
# Pearson correlation between the ground-truth and predicted map of each element.
from scipy.stats import pearsonr

pearson_per_elem = {}
for i, elem in enumerate(test_dataset.target_elems):
    r_value = pearsonr(y_real[:,:,i].flatten(), y_pred[:,:,i].flatten())[0]
    # pearsonr yields NaN when a map is constant; report such elements as 0.
    pearson_per_elem[elem] = 0.0 if np.isnan(r_value) else r_value

# Mean of the per-element correlations, computed before the 'total' entry is added.
print(f'Mean Pearson per element: {np.mean(list(pearson_per_elem.values()))}')
# Correlation over all elements' pixels pooled together.
pearson_per_elem['total'] = pearsonr(y_real.flatten(), y_pred.flatten())[0]
pearson_per_elem

### Slope

In [None]:
# "Slope" per element: the mean pixel-wise ratio (prediction + 1) / (ground truth + 1).
# It equals 1.0 when the prediction matches the ground truth exactly.
slope_per_elem = {}
for i, elem in enumerate(test_dataset.target_elems):
    ratio = (y_pred[:,:,i].flatten() + 1) / (y_real[:,:,i].flatten() + 1)
    mean_ratio = np.mean(ratio)
    # A NaN mean is reported as 0.
    slope_per_elem[elem] = 0.0 if np.isnan(mean_ratio) else mean_ratio

print(f'Mean slope per element: {np.mean(list(slope_per_elem.values()))}')
slope_per_elem