# thermoNET neural differentiable model for JB-08

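In this notebook, we train a neural network to reproduce the JB-08 empirical atmospheric density model. Rather than predicting the density directly, the network learns corrections to the parameters of a global fit of the density profile.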

## Imports

In [1]:
import sys
sys.path.append('../')

import thermonets as tn
import torch
import datetime
import numpy as np
import matplotlib.pyplot as plt
import pickle

## Loading input data

In [2]:
# Loads the JB-08 database and prints its shape.
# NOTE(review): the generating script is presumably the JB-08 counterpart of
# `/scripts/generate_nrlmsise00_db.py` — confirm which script produces `../dbs/jb08_db.txt`.
# Columns are (len 22):
# day, month, year, hour, minute, second, microsecond, alt [km], lat [deg], lon [deg], sun ra [deg], sun dec [deg], f107, f107A, s107, s107A, m107, m107A, y107, y107A, dDstdT, density [kg/m^3]
db=np.loadtxt('../dbs/jb08_db.txt',delimiter=',',skiprows=1)
# Every later cell indexes the columns by position, so fail early if the layout differs
assert db.ndim == 2 and db.shape[1] == 22, f'expected 22 columns, got shape {db.shape}'
print(f'Shape of database is: {db.shape}')

Shape of database is: (1000000, 22)


In [7]:
# Quick look at the raw sun right-ascension column (degrees, column 10 of the db)
raw_sun_ra_deg = db[:, 10]
raw_sun_ra_deg

array([320.56114061, 320.56114061, 320.56114061, ..., 335.47795232,
       335.47795232, 335.47795232])

## Extracting features of interest

In [3]:
# Renames some of the db content with readable names
days=db[:,0]
months=db[:,1]
years=db[:,2]
hours=db[:,3]
minutes=db[:,4]
seconds=db[:,5]
microseconds=db[:,6]
alt=db[:,7]
# Geodetic longitude and latitude are converted in radians:
lat=np.deg2rad(db[:,8])
lon=np.deg2rad(db[:,9])
# Solar right ascension and declination are also converted in radians:
sun_ra=np.deg2rad(db[:,10])
sun_dec=np.deg2rad(db[:,11])
# Space weather indices:
f107=db[:,12]
f107a=db[:,13]
s107=db[:,14]
s107a=db[:,15]
m107=db[:,16]
m107a=db[:,17]
y107=db[:,18]
y107a=db[:,19]
dDstdT=db[:,20]
# Atmospheric density as well:
target_density=db[:,-1]

# We need to extract from the db also the doy (Day of Year) and the sid (seconds in day)
seconds_in_day=hours*3600+minutes*60+seconds+microseconds/1e6
print('seconds in day min and max:')
print(seconds_in_day.min(), seconds_in_day.max())

# Day of year (1-based), computed vectorially with numpy datetime64 arithmetic instead of
# building one datetime object per row (the db has ~1e6 rows). The time of day does not
# affect the day of year, so only year/month/day are used.
year_start = (years.astype(np.int64) - 1970).astype('datetime64[Y]')
month_start = (year_start + (months.astype(np.int64) - 1).astype('timedelta64[M]')).astype('datetime64[D]')
dates = month_start + (days.astype(np.int64) - 1).astype('timedelta64[D]')
# Out-of-range months/days would silently roll over into a neighbouring month; the
# round-trip check below rejects them.
month_of_date = dates.astype('datetime64[M]')
assert (np.array_equal(month_of_date.astype(np.int64) % 12 + 1, months.astype(np.int64))
        and np.array_equal((dates - month_of_date.astype('datetime64[D]')).astype(np.int64) + 1,
                           days.astype(np.int64))), 'invalid calendar dates in the database'
doys = ((dates - year_start.astype('datetime64[D]')).astype(np.int64) + 1).astype(float)
print('day of the year min and max:')
print(doys.min(), doys.max())

seconds in day min and max:
5.873279 86381.406287
day of the year min and max:
1.0 365.0


## Normalization

In [8]:
# Builds the processed dataset: columns 0..19 hold the encoded/normalized features,
# column 20 the raw altitude and column 21 the target density.
n_normalized = 20
db_processed=np.zeros((db.shape[0], n_normalized + 2))
# Periodic quantities are encoded with sin/cos so that values across the wrap-around stay close
db_processed[:,0]=np.sin(lon)
db_processed[:,1]=np.cos(lon)
db_processed[:,2]=tn.normalize_min_max(lat,-np.pi/2,np.pi/2)
db_processed[:,3]=np.sin(2*np.pi*seconds_in_day/86400.)
db_processed[:,4]=np.cos(2*np.pi*seconds_in_day/86400.)
db_processed[:,5]=np.sin(2*np.pi*doys/365.25)
db_processed[:,6]=np.cos(2*np.pi*doys/365.25)
db_processed[:,7]=np.sin(sun_ra)
db_processed[:,8]=np.cos(sun_ra)
db_processed[:,9]=tn.normalize_min_max(sun_dec, np.deg2rad(-23.45), np.deg2rad(23.45))
db_processed[:,10]=tn.normalize_min_max(f107, 60., 266.)
db_processed[:,11]=tn.normalize_min_max(f107a, 60., 170.)
db_processed[:,12]=tn.normalize_min_max(s107, 50., 190.)
db_processed[:,13]=tn.normalize_min_max(s107a, 50., 170.)
db_processed[:,14]=tn.normalize_min_max(m107, 50., 190.)
db_processed[:,15]=tn.normalize_min_max(m107a, 50., 160.)
db_processed[:,16]=tn.normalize_min_max(y107, 50., 180.)
db_processed[:,17]=tn.normalize_min_max(y107a, 50., 170.)
db_processed[:,18]=tn.normalize_min_max(dDstdT, 0., 390.)
db_processed[:,19]=tn.normalize_min_max(alt, 170., 1010.)

# Add the non-normalized altitude & density columns (useful to extract during training).
# NOTE: the training cells feed only columns [:-3] to the NN, so the normalized
# altitude (column 19) is not a network input.
db_processed[:,n_normalized]= alt
db_processed[:,n_normalized+1]= target_density

# Cross check that the max is <=1 and the min is >=-1 over ALL normalized columns
# (0..19, normalized altitude included)
print(f"maximum and minimum of all the normalized data: {db_processed[:,:n_normalized].max()}, {db_processed[:,:n_normalized].min()}")
print(f"maximum and minimum of target density: {target_density.max()}, {target_density.min()}")

maximum and minimum of all the normalized data: 1.0, -1.0
maximum and minimum of target density: 7.462357839360134e-10, 5.853858909739958e-16


## NN Training

In [9]:
# Independent float32 copy of the processed dataset as a torch tensor
torch_data = torch.from_numpy(db_processed.astype(np.float32))

In [21]:
# NN hyperparameters
device = torch.device('cpu')
minibatch_size = 512
model_path = None #pass a path to a model in case you want to continue training from a file
lr = 0.001
epochs = 300

# Seeds torch's global RNG so that runs are reproducible (presumably this covers the
# weight initialization in `tn.ffnn` and the DataLoader shuffling, both created below)
seed = 42
torch.manual_seed(seed)

In [12]:
# NN creation: the network takes every processed column except the last three
# (normalized altitude, raw altitude, target density) and outputs 12 values, which
# the training loop uses as relative corrections to the global-fit parameters.
n_inputs = db_processed.shape[1] - 3
model = tn.ffnn(
    input_dim=n_inputs,
    hidden_layer_dims=[32, 32],
    output_dim=12,
    mid_activation=torch.nn.Tanh(),
    last_activation=torch.nn.Tanh(),
).to(device)

# Optionally resumes training from a previously saved state dict
if model_path is not None:
    saved_state = torch.load(model_path, map_location=device.type)
    model.load_state_dict(saved_state)

In [22]:
# Here we set the optimizer (Adam, AMSGrad variant)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True)
# The LR is multiplied by 0.8 at each milestone; the scheduler is stepped once per epoch
# in the training loop, so milestones are epoch indices. The `verbose` argument is
# omitted: its default is non-verbose and passing it is deprecated in recent PyTorch.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[25,50,75,100,125,150,175,200,225,230,240,250,260,270],
    gamma=0.8,
)

# Training criterion (presumably an MSE on log10 densities, per its name);
# tn.MAPE() is an alternative criterion
criterion = tn.MSE_LOG10()

# And the dataloader (reshuffled every epoch)
dataloader = torch.utils.data.DataLoader(torch_data,
                                         batch_size=minibatch_size,
                                         shuffle=True)


In [14]:
# Counts every scalar weight/bias in the network
n_model_params = sum(map(torch.numel, model.parameters()))
print(f'Total number of model parameters: {n_model_params}')

Total number of model parameters: 2092


We load the global fit (see the notebook `rho_global_fit.ipynb`). It is the baseline to which the NN learns corrections.

In [15]:
# Global-fit parameters (a pickled numpy array) used as the baseline that the NN corrects.
# SECURITY NOTE: pickle.load can execute arbitrary code; only load trusted files.
global_fit_file = '../global_fits/global_fit_jb08_180.0-1000.0-4.txt'
with open(global_fit_file, 'rb') as fit_fh:
    global_fit_array = pickle.load(fit_fh)
best_global_fit = torch.from_numpy(global_fit_array).to(device)

In [23]:
# Training loop. The NN predicts relative corrections `delta_params` to the global-fit
# parameters; the density is then evaluated at the sample altitude with
# `tn.rho_approximation`. Column layout of `torch_data`: [:-3] NN inputs,
# [-3] normalized altitude (not fed to the NN), [-2] altitude, [-1] target density.
ratio_losses=[]            # per-minibatch ratio: NN loss / global-fit loss
loss_plot = []             # full-dataset NN loss, one value per epoch
mse_per_minibatch_nn=[]
mape_per_minibatch_nn=[]
mse_per_minibatch_fit=[]
mape_per_minibatch_fit=[]
best_loss_total = np.inf   # best full-dataset loss so far: drives checkpointing
best_loss = np.inf         # best single-minibatch loss so far: logging only

# Full-dataset tensors for the end-of-epoch evaluation (loop invariant, so built once)
inputs_total = torch_data[:, :-3].to(device)
altitude_total = torch_data[:, -2].to(device)
rho_target_total = torch_data[:, -1].to(device)

for epoch in range(epochs):
    for batch_idx, el in enumerate(dataloader):
        minibatch = el[:, :-3].to(device)
        altitude = el[:, -2].to(device)
        rho_target = el[:, -1].to(device)
        delta_params = model(minibatch)

        # Constructs the density-model parameters as corrections to the global fit:
        params = best_global_fit*(1+delta_params)
        rho_nn = tn.rho_approximation(h=altitude,
                                      params=params,
                                      backend='torch')
        rho_fit = tn.rho_approximation(h=altitude,
                                       params=best_global_fit,
                                       backend='torch')

        loss = criterion(rho_nn, rho_target)
        # Global-fit loss on the same minibatch (reference only, not optimized):
        loss_fit = criterion(rho_fit, rho_target)

        # Standard optimization step: zero the gradients, backpropagate, update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Logged quantities for the NN...
        mse_per_minibatch_nn.append(loss.item())
        mape_per_minibatch_nn.append(tn.mean_absolute_percentage_error(rho_nn, rho_target).item())
        # ...and the same for the global fit
        mse_per_minibatch_fit.append(loss_fit.item())
        mape_per_minibatch_fit.append(tn.mean_absolute_percentage_error(rho_fit, rho_target).item())

        # Ratio of the loss between the NN and the fit (below 1 means the NN beats the global fit)
        ratio_losses.append(loss.item()/loss_fit.item())

        # Tracks the best minibatch loss for logging only; checkpointing uses the
        # full-dataset loss computed at the end of the epoch
        if loss.item() < best_loss:
            best_loss = loss.item()

        # Prints every 10 minibatches, overwriting the same console line
        if batch_idx % 10 == 0:
            print(f'minibatch: {batch_idx}/{len(dataloader)}, ratio: {ratio_losses[-1]:.4e}, best minibatch loss till now: {best_loss:.4e}, loss & MAPE -----  NN: {loss.item():.10f}, {mape_per_minibatch_nn[-1]:.7f}; fit: {loss_fit.item():.10f}, {mape_per_minibatch_fit[-1]:.7f}', end='\r')

    # End-of-epoch evaluation on the whole dataset. No gradients are needed here, so
    # no autograd graph is built over all the samples.
    with torch.no_grad():
        delta_params = model(inputs_total)
        params = best_global_fit*(1+delta_params)
        rho_nn_total = tn.rho_approximation(h=altitude_total,
                                            params=params,
                                            backend='torch')
        # First the nn loss
        loss_total = criterion(rho_nn_total, rho_target_total)
        mape_total = tn.MAPE()(rho_nn_total, rho_target_total)
    loss_plot.append(loss_total.item())

    # Perform a step in LR scheduler to update LR (once per epoch)
    scheduler.step()

    # Print at the end of the epoch (the spaces clear the in-place minibatch line)
    curr_lr = scheduler.optimizer.param_groups[0]['lr']
    print(" "*300, end="\r")
    print(f'Epoch {epoch + 1}/{epochs}, lr: {curr_lr:.1e}, loss: {loss_total.item():.3e},  MAPE: {mape_total.item():.3f}')

    # Checkpoints the model whenever the full-dataset loss improves
    if loss_total.item() < best_loss_total:
        torch.save(model.state_dict(), '../models/jb08_model_xxx.pyt')
        best_loss_total = loss_total.item()

Epoch 1/1000, lr: 2.0e-04, loss: 1.404e-04,  MAPE: 2.055                                                                                                                                                                                                                                                    
Epoch 2/1000, lr: 2.0e-04, loss: 1.391e-04,  MAPE: 2.040                                                                                                                                                                                                                                                    
Epoch 3/1000, lr: 2.0e-04, loss: 1.559e-04,  MAPE: 2.182                                                                                                                                                                                                                                                    
Epoch 4/1000, lr: 2.0e-04, loss: 1.295e-04,  MAPE: 1.982                                         