In [3]:
"""
Description : Train neural network model to predict one time step of M7
Options:

  --signs=<need_extra_signs_for_log_mass>
  --classification=<train_classification_net>
  --scale=<scaler>
  --model=<model_version>
"""

import numpy as np
from utils import standard_transform_x, standard_transform_y, get_model, train_model, create_report, calculate_stats, log_full_norm_transform_x, log_tend_norm_transform_y, create_dataloader, create_test_dataloader
# from models import Softmax_model
from utils import add_nn_arguments_jupyter
import torch.nn as nn 
import torch
import torch.optim as optim

from sklearn.metrics import mean_squared_error, r2_score

import torch
import torch.nn as nn
import torch.nn.functional as F

# KB add for active development in models or utils
# %load_ext autoreload
# %autoreload 2

In [4]:
# define full path 
path_to_data = "/home/kim/data/aerosols/aerosol_emulation_data/"

X_test = np.load(path_to_data + 'X_test.npy')
y_test = np.load(path_to_data + 'y_test.npy')

X_train = np.load(path_to_data + 'X_train.npy')
y_train = np.load(path_to_data + 'y_train.npy')

X_valid = np.load(path_to_data + 'X_val.npy')
y_valid = np.load(path_to_data + 'y_val.npy')

# Select the correct 24 columns
X_test_24 = X_test[:, 8:]
X_train_24 = X_train[:, 8:] 

y_test_24 = y_test[:, :24]
y_train_24 = y_train[:, :24]

y_valid_24 = y_valid[:, :24]
X_valid_24 = X_valid[:, 8:]

# How much has it changes between x (at t = 0)  and y (at t = 1)
y_delta_train_24 = y_train_24 - X_train_24
y_delta_test_24 = y_test_24 - X_test_24
y_delta_valid_24 = y_valid_24 - X_valid_24

# Define column indices for each of the components (24 column version)
so4_indices = [0, 1, 2, 3, 4]
bc_indices = [5, 6, 7, 8]
oc_indices = [9, 10, 11, 12]
du_indices = [13, 14, 15, 16]

# Define aerosol species and their corresponding indices
species_indices = {
    'so4': so4_indices,
    'bc': bc_indices,
    'oc': oc_indices,
    'du': du_indices
}

# What are these indices?!
extra_indices = [17, 18, 19, 20, 21, 22, 23] 

# Define aerosol species and their corresponding indices

### ARGS ###
args = add_nn_arguments_jupyter()
# Overwrite the model name, keep everything else the same
# Have one model for now as each input dim can be different
args.model = 'transition_model'
# args.model_id = 'transition_' + species # save different models
# Run for only 3 epochs for proof of concept
# Took around 2 mins per epoch
args.epochs = 3 
### DIFFERENT DIMS
# Takes a minute
# stats = calculate_stats(X_train, (y_train - X_train), X_test, (y_test - X_test), args)
# y's can be delata and 24, X is raw
stats = calculate_stats(X_train, y_delta_train_24, X_test, y_delta_test_24, args)

# Look at stats
np.set_printoptions(precision = 4, suppress = True, formatter = {'all': lambda x: f'{x:.4f}'})
# stats

In [None]:
class LogSoftmax_model(nn.Module):
    def __init__(self, in_features, out_features, width, depth = 2):
        super(LogSoftmax_model, self).__init__()
        self.fc_in = nn.Linear(in_features = in_features, out_features = width)
        # Create the hidden layers
        self.hidden_layers = nn.ModuleList()
        for i in range(depth - 1):
            self.hidden_layers.append(nn.ReLU())
            self.hidden_layers.append(nn.Linear(in_features = width, out_features = width))
            self.hidden_layers.append(nn.ReLU())
        # Output layer (fc: fully connected)
        self.fc_out = nn.Linear(in_features = width, out_features = out_features + 1)
        # ADD softmax layer: same as probabilities per class (classification)
        self.softmax = nn.Softmax(dim = -1)  # Apply softmax along the output dimension

    def forward(self, x):
        # x_relative = 
        x_log = torch.log(x + 1e-8)  # Apply log transformation to the input
        # Pass through the input layer (fully connected)
        out = self.fc_in(x_log)
        # Pass through hidden layers
        for layer in self.hidden_layers:
            out = layer(out)
        # Final output layer
        out = self.fc_out(out)

        # Split it up
        softmax_input = out[:, :-1]
        scalar = out[:, -1]

        # Apply softmax to the final output for classification
        # calculate in double precision
        softmax_out = self.softmax(softmax_input).double()

        # Make softmax zero_sum
        zero_sum_output = softmax_out - softmax_out.mean(dim = -1, keepdim = True)

        scaled_zero_sum_output = zero_sum_output * scalar.unsqueeze(1)

        # Avoid division by zero by ensuring denominator is never zero
        denominator = scaled_zero_sum_output.clone()
        denominator = torch.where(denominator == 0, torch.tensor(1e-10).to(denominator.device), denominator)

        ### Ensure non-negativivity constraint ###
        safe_beta = torch.where((
            (scaled_zero_sum_output + x) < 0), # In the case of negative values
            (- x / denominator), # scalar candidates, maybe add noise?! works per row
            torch.tensor(float('inf')) # infinity if no violation (so it doesn't get selected)
            ).min(dim = 1)[0].unsqueeze(-1) # select the minimum value over columns (for each row)
            # minimum by which we have to scale it backk. 0 in worst case
        # unsqueeze to make torch.Size([n_batch, 1])

        # In case of no violation (safe_beta == inf), set beta to 1 - no change occurs
        # In case of violation, set beta to the minimum value (zero in worst case)
        # row-wise i.e. batch-wise min selection
        beta = torch.min(torch.ones(safe_beta.shape), safe_beta)

        safe_out = beta * zero_sum_output

        return safe_out

In [19]:
x_train_so4_subset = X_train_24[np.ix_([-4, 0, 1000, 20000, 21000, 400000, -1], so4_indices)]
y_train_so4_subset = y_train_24[np.ix_([-4, 0, 1000, 20000, 21000, 400000, -1], so4_indices)]
y_delta_train_so4_subset = y_delta_train_24[np.ix_([-4, 0, 1000, 20000, 21000, 400000, -1], so4_indices)]

In [17]:
input

tensor([[0.0000e+00, 1.3182e+04, 2.6606e+08, 6.9101e+06, 1.3711e+05],
        [2.4033e+01, 1.3761e+03, 3.6511e+05, 5.0815e+05, 1.5746e+02],
        [3.4661e+01, 1.5090e+03, 3.0621e+05, 4.0179e+05, 6.1046e+01],
        [2.2522e+01, 2.6219e+03, 1.2649e+06, 2.7109e+06, 1.1954e+03],
        [6.1790e+01, 9.6141e+02, 1.9868e+06, 7.6254e+06, 6.5636e+03],
        [2.7293e+06, 2.4266e+03, 2.2450e+08, 7.1790e+08, 6.0305e+07],
        [0.0000e+00, 1.3179e+04, 2.6545e+08, 7.3312e+06, 1.5054e+05]])

In [55]:
model = LogSoftmax_model(
    in_features = x_train_so4_subset.shape[-1], 
    out_features = y_delta_train_so4_subset.shape[-1], 
    width = 128, depth = 2)

input = torch.tensor(x_train_so4_subset, dtype = torch.float64)
# force issues: input = input - input *0.9999
out = model(input)
out

tensor([[0.0995, 0.0528, 0.3892, 0.4178, 0.0407],
        [0.0382, 0.1223, 0.4379, 0.3292, 0.0725],
        [0.0389, 0.1317, 0.4294, 0.3286, 0.0714],
        [0.0315, 0.1037, 0.4659, 0.3349, 0.0640],
        [0.0281, 0.1026, 0.4739, 0.3285, 0.0670],
        [0.0063, 0.0606, 0.5250, 0.3565, 0.0516],
        [0.0991, 0.0527, 0.3896, 0.4179, 0.0407]], grad_fn=<SoftmaxBackward0>)
tensor([[0.6785],
        [0.9608],
        [0.9130],
        [1.1076],
        [1.2154],
        [2.0176],
        [0.6809]], grad_fn=<UnsqueezeBackward0>)
scaled
tensor([[-0.0682, -0.0999,  0.1284,  0.1478, -0.1081],
        [-0.1554, -0.0747,  0.2285,  0.1241, -0.1225],
        [-0.1471, -0.0624,  0.2095,  0.1174, -0.1174],
        [-0.1867, -0.1067,  0.2946,  0.1495, -0.1507],
        [-0.2089, -0.1184,  0.3329,  0.1561, -0.1617],
        [-0.3909, -0.2812,  0.6556,  0.3157, -0.2993],
        [-0.0687, -0.1003,  0.1291,  0.1484, -0.1085]], grad_fn=<MulBackward0>)
safe_beta
torch.Size([7, 1])
tensor([[0.],
    

tensor([[-0.0000, -0.0000,  0.0000,  0.0000, -0.0000],
        [-0.1618, -0.0777,  0.2379,  0.1292, -0.1275],
        [-0.1611, -0.0683,  0.2294,  0.1286, -0.1286],
        [-0.1685, -0.0963,  0.2659,  0.1349, -0.1360],
        [-0.1719, -0.0974,  0.2739,  0.1285, -0.1330],
        [-0.1937, -0.1394,  0.3250,  0.1565, -0.1484],
        [-0.0000, -0.0000,  0.0000,  0.0000, -0.0000]], grad_fn=<MulBackward0>)

In [None]:
# elif args.scale == 'log':
#            y_test = log_tend_norm_transform_y_inv(stats, y_test)
#            X_test = log_full_norm_transform_x_inv(stats, X_test)

def log_transform(x):
    return np.log(np.abs(x)+1e-8) # absolute? What about signs?

def exp_transform(x):
    return np.exp(x)-1e-8

# Tranforms applied in paper
def log_tend_norm_transform_y(stats, x):
    x = log_transform(x)
    x = (x - stats['y_log_eps_mean'])/stats['y_log_eps_std']
    return x

def log_full_norm_transform_x(stats, x):
    x = log_transform(x)
    x = (x - stats['X_log_eps_mean'])/stats['X_log_eps_std']
    return x

# INVERSE TRANSFORM PRED

# elif args.scale == 'log':
#        pred = log_tend_norm_transform_y_inv(stats, pred)

def log_tend_norm_transform_y_inv(stats, x):    
    x = x*stats['y_log_eps_std']+stats['y_log_eps_mean']
    x = exp_transform(x)
    return x

In [118]:
torch.set_printoptions(precision = 6, sci_mode = False)
model = LogSoftmax_model(
    in_features = x_train_so4_subset.shape[-1], 
    out_features = y_delta_train_so4_subset.shape[-1], 
    width = 128, depth = 2)

inp = torch.tensor(x_train_so4_subset, dtype = torch.float32)
inp = inp - inp*0.9999
print("Input")
print(inp)

out = model(inp)
print("Output (delta)")
print(out)

print("Input + Output (delta)")
print("We have a violation if this is negativ.")
# Violation is this is negative
print(inp + out)

Input
tensor([[    0.000000,     1.318359, 26608.000000,   691.000000,    13.718750],
        [    0.002403,     0.137695,    36.531250,    50.812500,     0.015747],
        [    0.003468,     0.150879,    30.625000,    40.187500,     0.006107],
        [    0.002253,     0.262207,   126.500000,   271.250000,     0.119507],
        [    0.006180,     0.096130,   198.750000,   762.500000,     0.656250],
        [  273.000000,     0.242676, 22448.000000, 71808.000000,  6032.000000],
        [    0.000000,     1.318359, 26544.000000,   733.000000,    15.062500]])
Output (delta)
tensor([[-0., 0., -0., -0., 0.],
        [-0., 0., -0., -0., 0.],
        [-0., 0., 0., 0., 0.],
        [-0., 0., -0., -0., 0.],
        [-0., 0., -0., -0., 0.],
        [0., 0., -0., -0., 0.],
        [-0., 0., -0., -0., 0.]], dtype=torch.float64, grad_fn=<MulBackward0>)
Input + Output (delta)
We have a violation if this is negativ.
tensor([[    0.000000,     1.318359, 26608.000000,   691.000000,    13.718750],
 

In [108]:
torch.where(((out + inp) < 0), (-inp / out) + 1e-5, torch.tensor(float('inf')))

tensor([[    0.0000,        inf,        inf,        inf,        inf],
        [    0.0428,        inf,        inf,        inf,        inf],
        [    0.0628,        inf,        inf,        inf,        inf],
        [    0.0283,        inf,        inf,        inf,        inf],
        [    0.0751,        inf,        inf,        inf,        inf],
        [       inf,        inf,        inf,        inf,        inf],
        [    0.0000,        inf,        inf,        inf,        inf]],
       dtype=torch.float64, grad_fn=<WhereBackward0>)

In [109]:
# Input is always non-negative so 
# Beta must be larger than this ratio
-inp/out

tensor([[     0.0000,    -12.1310, -283228.6969,   4177.9132,   -135.3324],
        [     0.0428,     -1.5328,  -1377.5941,    562.5725,     -0.5226],
        [     0.0628,     -1.5061,  -1116.3906,    467.0323,     -0.4467],
        [     0.0283,     -3.0542,  -3340.6744,   2395.6851,     -1.7310],
        [     0.0751,     -1.3428,  -4804.7660,   6363.0303,     -7.3601],
        [  3655.2058,     -1.4863, -1276366.1838, 458349.6959, -119457.9955],
        [     0.0000,    -12.2776, -279312.6114,   4428.3950,   -147.6224]],
       dtype=torch.float64, grad_fn=<DivBackward0>)

In [110]:
safe_beta = torch.where(out < 0, -inp / out, torch.tensor(float('inf')).to(out.device)).min(dim = 1)[0]
print(safe_beta)
beta = torch.min(torch.ones(1), torch.min(safe_beta))
print(beta)
beta.unsqueeze(-1) * out + inp

tensor([    0.0000,     0.0428,     0.0628,     0.0283,     0.0751,  3655.2058,
            0.0000], dtype=torch.float64, grad_fn=<MinBackward0>)
tensor([0.], grad_fn=<MinimumBackward0>)


tensor([[    0.0000,     1.3184, 26608.0000,   691.0000,    13.7188],
        [    0.0024,     0.1377,    36.5312,    50.8125,     0.0157],
        [    0.0035,     0.1509,    30.6250,    40.1875,     0.0061],
        [    0.0023,     0.2622,   126.5000,   271.2500,     0.1195],
        [    0.0062,     0.0961,   198.7500,   762.5000,     0.6562],
        [  273.0000,     0.2427, 22448.0000, 71808.0000,  6032.0000],
        [    0.0000,     1.3184, 26544.0000,   733.0000,    15.0625]],
       dtype=torch.float64, grad_fn=<AddBackward0>)

In [None]:
# Avoid division by zero; only consider elements where out ≠ 0
safe_beta = torch.where(out != 0, -inp / out, torch.tensor(float('inf')).to(out.device))

# Beta should be the smallest positive value ensuring non-negativity
beta = torch.min(torch.ones(1), torch.min(safe_beta))

In [68]:
- inp / out

tensor([[     0.1581,     -5.0879,    340.5997,   -942.7436,     -0.3795],
        [     0.1581,     -5.0899,    340.5916,   -943.3430,     -0.3794],
        [     0.1582,     -5.0914,    340.6124,   -943.4422,     -0.3793],
        [     0.1582,     -5.0914,    340.6124,   -943.4422,     -0.3793],
        [     0.1583,     -5.0925,    340.6102,   -943.4239,     -0.3796],
        [     0.1583,     -5.0925,    340.6102,   -943.4239,     -0.3796],
        [     0.1583,     -5.0925,    340.6101,   -943.4243,     -0.3796]],
       dtype=torch.float64, grad_fn=<DivBackward0>)

In [62]:
# first row would be neg
print(out + inp)
# Select the minimum per row, because this is what we need to set to zero
print(torch.min(out + inp, axis = -1)[0])
violation = torch.min(out + inp, axis = -1)[0]
violation_tiled = torch.tile(violation.unsqueeze(-1), (1, inp.shape[-1]))
violation_tiled - inp

print(torch.relu(out + inp))
# print(torch.nn.functional.softplus(out + inp))

tensor([[-1.2800e-02,  1.6476e-01,  3.6424e+01,  5.0866e+01,  5.7245e-02],
        [-1.2797e-02,  1.6475e-01,  3.6424e+01,  5.0898e+01,  5.7256e-02],
        [-1.2795e-02,  1.6474e-01,  3.6424e+01,  5.0898e+01,  5.7262e-02],
        [-1.2795e-02,  1.6474e-01,  3.6424e+01,  5.0898e+01,  5.7262e-02],
        [-1.2793e-02,  1.6473e-01,  3.6424e+01,  5.0898e+01,  5.7281e-02],
        [-1.2793e-02,  1.6473e-01,  3.6424e+01,  5.0898e+01,  5.7281e-02],
        [-1.2793e-02,  1.6473e-01,  3.6424e+01,  5.0898e+01,  5.7281e-02]],
       dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([-0.0128, -0.0128, -0.0128, -0.0128, -0.0128, -0.0128, -0.0128],
       dtype=torch.float64, grad_fn=<MinBackward0>)
tensor([[ 0.0000,  0.1648, 36.4240, 50.8664,  0.0572],
        [ 0.0000,  0.1647, 36.4240, 50.8976,  0.0573],
        [ 0.0000,  0.1647, 36.4240, 50.8976,  0.0573],
        [ 0.0000,  0.1647, 36.4240, 50.8976,  0.0573],
        [ 0.0000,  0.1647, 36.4240, 50.8976,  0.0573],
        [ 0.0000,  0.16

In [67]:
torch.set_printoptions(precision = 4, sci_mode = False)
(violation_tiled - inp) / out

tensor([[     1.0000,     -5.5609,    340.7190,   -942.9811,     -0.6879],
        [     1.0000,     -5.5630,    340.7109,   -943.5805,     -0.6877],
        [     1.0000,     -5.5645,    340.7317,   -943.6796,     -0.6875],
        [     1.0000,     -5.5645,    340.7317,   -943.6796,     -0.6875],
        [     1.0000,     -5.5657,    340.7294,   -943.6613,     -0.6878],
        [     1.0000,     -5.5657,    340.7294,   -943.6613,     -0.6878],
        [     1.0000,     -5.5657,    340.7294,   -943.6617,     -0.6878]],
       dtype=torch.float64, grad_fn=<DivBackward0>)

In [41]:
- inp / (out + 1e-8)

tensor([[ 4.5226e+00, -9.8981e+01,  1.4223e+05,  1.9712e+05,  4.5718e+01],
        [ 4.5232e+00, -9.8994e+01,  1.4224e+05,  1.9715e+05,  4.5727e+01],
        [ 4.5237e+00, -9.9007e+01,  1.4224e+05,  1.9718e+05,  4.5736e+01],
        [ 4.5243e+00, -9.9019e+01,  1.4224e+05,  1.9721e+05,  4.5745e+01],
        [ 4.5248e+00, -9.9031e+01,  1.4225e+05,  1.9724e+05,  4.5752e+01],
        [ 4.5252e+00, -9.9042e+01,  1.4225e+05,  1.9727e+05,  4.5760e+01],
        [ 4.5257e+00, -9.9053e+01,  1.4225e+05,  1.9730e+05,  4.5767e+01]],
       dtype=torch.float64, grad_fn=<DivBackward0>)

In [None]:
# This sum must be non-negative
(out + inp)
torch.relu()

safeguard_scalar = torch.max(- inp / (out + 1e-8), torch.zeros_like(inp)) 
safeguard_scalar

tensor([[  595.9630, 14112.4749,     0.0000,     0.0000,  1817.9433],
        [  596.0211, 14114.2689,     0.0000,     0.0000,  1818.1510],
        [  596.0771, 14116.0047,     0.0000,     0.0000,  1818.3509],
        [  596.1312, 14117.6609,     0.0000,     0.0000,  1818.5434],
        [  596.1830, 14119.2519,     0.0000,     0.0000,  1818.7272],
        [  596.2323, 14120.7728,     0.0000,     0.0000,  1818.9036],
        [  596.2796, 14122.2215,     0.0000,     0.0000,  1819.0703]],
       dtype=torch.float64, grad_fn=<MaximumBackward0>)

In [13]:
model(torch.tensor(x_train_so4_subset, dtype = torch.float32))

tensor([[2.4033e+01, 1.3761e+03, 3.6511e+05, 5.0815e+05, 1.5746e+02],
        [2.4036e+01, 1.3763e+03, 3.6515e+05, 5.0822e+05, 1.5748e+02],
        [2.4039e+01, 1.3765e+03, 3.6520e+05, 5.0828e+05, 1.5750e+02],
        [2.4042e+01, 1.3766e+03, 3.6524e+05, 5.0834e+05, 1.5752e+02],
        [2.4044e+01, 1.3768e+03, 3.6528e+05, 5.0840e+05, 1.5754e+02],
        [2.4047e+01, 1.3769e+03, 3.6532e+05, 5.0845e+05, 1.5755e+02],
        [2.4049e+01, 1.3771e+03, 3.6536e+05, 5.0850e+05, 1.5757e+02]])