# Feed-forward training

## Imports

In [1]:
import sys
sys.path.append('../')

import thermonets as tn
import torch
import datetime
import numpy as np
import matplotlib.pyplot as plt
import pickle

## Loading input data

In [2]:
# Load the database generated via `/scripts/generate_nrlmsise00_db.py` (comma separated, one header row).
# Its columns are either (len 16):
# 'day', 'month', 'year', 'hour', 'minute', 'second', 'microsecond', 'alt [km]', 'lat [deg]', 'lon [deg]', 'f107A', 'f107', 'ap', 'wind zonal [m/s]', 'wind meridional [m/s]', 'density [kg/m^3]'
# or (len 14):
# 'day', 'month', 'year', 'hour', 'minute', 'second', 'microsecond', 'alt [km]', 'lat [deg]', 'lon [deg]', 'f107A', 'f107', 'ap', 'density [kg/m^3]'
# In both layouts the first 13 columns coincide and the density is the last column,
# which is what the feature-extraction cells index into.
# ndmin=2 keeps the result 2-D even if the file holds a single data row.
db=np.loadtxt('../dbs/nrlmsise00_db.txt',delimiter=',',skiprows=1,ndmin=2)
# Fail early on an unexpected layout instead of silently reading the wrong columns later.
assert db.shape[1] in (14, 16), f'Unexpected number of columns in database: {db.shape}'
print(f'Shape of database is: {db.shape}')

Shape of database is: (999800, 16)


## Extracting features of interest

In [3]:
# Build the two time features: seconds elapsed in the day and the day of the year.
# The date/time fields are numeric (float) columns of `db`, not strings.
years=db[:,2]
months=db[:,1]
days=db[:,0]
hours=db[:,3]
minutes=db[:,4]
seconds=db[:,5]
microseconds=db[:,6]
seconds_in_day=hours*3600+minutes*60+seconds+microseconds/1e6
print('seconds in day min and max:')
print(seconds_in_day.min(), seconds_in_day.max())
assert np.all((seconds_in_day >= 0.) & (seconds_in_day < 86400.)), 'time of day out of range'

# Day of the year (1 = January 1st), computed vectorially with numpy datetime64 instead of
# constructing one `datetime` object per row (the database has ~1e6 rows).
assert np.all((months >= 1) & (months <= 12)), 'month out of range'
_year_start = (years.astype(np.int64) - 1970).astype('datetime64[Y]')
_month_start = _year_start.astype('datetime64[M]') + (months.astype(np.int64) - 1).astype('timedelta64[M]')
_date = _month_start.astype('datetime64[D]') + (days.astype(np.int64) - 1).astype('timedelta64[D]')
# datetime64 arithmetic silently rolls invalid days over (e.g. Feb 30 -> Mar 2), whereas
# `datetime.datetime` raised on them: keep that validation explicit.
assert np.all(_date.astype('datetime64[M]') == _month_start), 'day out of range for its month'
doys=(_date - _year_start.astype('datetime64[D]')).astype(np.int64).astype(np.float64) + 1.
print('day of the year min and max:')
print(doys.min(), doys.max())

seconds in day min and max:
0.042381 86396.718623
day of the year min and max:
1.0 365.0


In [4]:
# Altitude [km], left in its original units:
alt = db[:, 7]
# Geographic coordinates, converted from degrees to radians:
lat = np.deg2rad(db[:, 8])
lon = np.deg2rad(db[:, 9])
# Space-weather indices (columns 10, 11 and 12):
f107a, f107, ap = db[:, 10], db[:, 11], db[:, 12]
# Regression target: the density is the last column in both database layouts.
target_density = db[:, -1]

## Normalization

In [5]:
# Assemble the network inputs column by column. Periodic quantities (longitude, time of day,
# day of year) are encoded as sin/cos pairs so that the ends of each period map to nearby
# points; latitude is encoded with its sine only. Altitude and the space-weather indices are
# min-max normalized with fixed bounds.
db_normalized = np.column_stack([
    np.sin(lon),                                   # 0
    np.cos(lon),                                   # 1
    np.sin(lat),                                   # 2
    np.sin(2*np.pi*seconds_in_day/86400.),         # 3
    np.cos(2*np.pi*seconds_in_day/86400.),         # 4
    np.sin(2*np.pi*doys/365.25),                   # 5
    np.cos(2*np.pi*doys/365.25),                   # 6
    tn.normalize_min_max(alt, 150., 650.),         # 7
    tn.normalize_min_max(f107, 60., 290.),         # 8
    tn.normalize_min_max(f107a, 50., 190.),        # 9
    tn.normalize_min_max(ap, 0., 140.),            # 10
    # Raw (non-normalized) altitude and density, split off again during training:
    alt,                                           # 11
    target_density,                                # 12
])

In [6]:
# Cross check that the 11 input features lie in [-1, 1]. A violation presumably means some
# input exceeded the hard-coded min/max normalization bounds of the previous cell.
# The target density must be strictly positive because the training loop takes its log10.
# NaNs also fail these asserts, since every comparison with NaN is False.
print(f"maximum and minimum of all the normalized data: {db_normalized[:,:11].max()}, {db_normalized[:,:11].min()}")
print(f"maximum and minimum of target density: {target_density.max()}, {target_density.min()}")
assert -1. <= db_normalized[:,:11].min() and db_normalized[:,:11].max() <= 1., 'normalized features outside [-1, 1]'
assert target_density.min() > 0., 'non-positive density: log10 would be undefined'

maximum and minimum of all the normalized data: 1.0, -1.0
maximum and minimum of target density: 1.712652587471857e-09, 1.9853483043363358e-15


## NN Training

In [7]:
# The whole table as one float32 tensor: 11 input features, then raw altitude, then raw density.
# (db_normalized is float64, so the dtype change always produces a fresh copy.)
torch_data = torch.as_tensor(db_normalized, dtype=torch.float32)

In [8]:
# NN hyperparameters
device = torch.device('cpu')
batch_size = 4096
model_path = None  # path to a saved state dict to continue training from, or None to start fresh
lr = 1e-5
epochs = 100
# Seed torch's global RNG so that weight initialization and DataLoader shuffling
# (both drawn from it) are reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)

In [9]:
# Shuffled mini-batch iterator over the rows of the float32 table; each batch is a
# (batch_size, n_columns) slice of `torch_data`.
dataloader = torch.utils.data.DataLoader(
    dataset=torch_data,
    batch_size=batch_size,
    shuffle=True,
)


In [10]:
# Small tanh MLP mapping the normalized input features (every column except the last two) to the
# relative corrections of the global-fit parameters used in the training loop. The final Tanh
# presumably bounds each correction to (-1, 1), assuming `last_activation` is applied to the output.
model = tn.ffnn(
    input_dim=db_normalized.shape[1] - 2,  # exclude the raw altitude & density columns
    hidden_layer_dims=[32, 32],
    output_dim=12,  # NOTE(review): must equal the number of global-fit parameters — confirm
    mid_activation=torch.nn.Tanh(),
    last_activation=torch.nn.Tanh(),
).to(device)

# Optionally resume training from a previously saved state dict:
if model_path is not None:
    saved_state = torch.load(model_path, map_location=device.type)
    model.load_state_dict(saved_state)

In [11]:
# Loss (mean squared error, applied to log10 densities in the loop) and the Adam optimizer
# over every network weight:
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [12]:
# Report the total number of trainable scalars in the network:
print(f'Total number of model parameters: {sum(map(torch.numel, model.parameters()))}')

Total number of model parameters: 1836


We load the global fit (see the notebook `rho_global_fit.ipynb`). It is the baseline, and the NN is trained to learn corrections to its parameters.

In [13]:
# The global-fit parameters are stored as a pickled numpy array (despite the .txt extension).
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
with open('../global_fits/global_fit_nrlmsise00_180.0-1000.0-4.txt', 'rb') as fit_file:
    fit_array = pickle.load(fit_file)
best_global_fit = torch.from_numpy(fit_array).to(device)

In [14]:
# Training loop: the NN predicts relative corrections `delta_params` to the global-fit parameters,
# the corrected parameters are passed to `tn.rho_approximation`, and the loss is the mean squared
# error between log10 densities. The global fit alone is evaluated on the same minibatches as a
# reference.
# NOTE: `criterion` is an MSE loss, so the `rmse_*` lists and `total_rmse` (names kept for
# later cells) actually hold per-minibatch MSE values of log10(density).
ratio_losses=[]
rmse_per_minibatch_nn=[]
mape_per_minibatch_nn=[]
rmse_per_minibatch_fit=[]
mape_per_minibatch_fit=[]
# Lowest minibatch loss seen so far, across all epochs. It is tracked incrementally instead of
# re-scanning the whole loss history at every step.
best_loss=float('inf')
model.train()  # Set model to training mode
for epoch in range(epochs):
    total_rmse = 0.0
    total_mape = 0.0
    for batch_idx,el in enumerate(dataloader):
        minibatch=el[:,:-2].to(device)   # normalized input features
        altitude=el[:,-2].to(device)     # raw altitude column
        rho_target=el[:,-1].to(device)   # raw density column

        optimizer.zero_grad()  # Clear accumulated gradients
        delta_params = model(minibatch)
        # Corrected parameters: relative corrections applied to the global fit.
        params = best_global_fit*(1+delta_params)
        rho_nn=tn.rho_approximation(h=altitude,
                                    params=params,
                                    backend='torch')
        # Squeeze both sides consistently, so a possible trailing singleton dimension cannot make
        # MSELoss broadcast (batch, 1) against (batch,) into a (batch, batch) comparison.
        log_rho_target = torch.log10(rho_target).squeeze()
        loss = criterion(torch.log10(rho_nn).squeeze(), log_rho_target)
        loss.backward()

        with torch.no_grad():
            # Reference and metrics only: no autograd graph is needed here.
            rho_fit=tn.rho_approximation(h=altitude,
                                         params=best_global_fit,
                                         backend='torch')
            loss_fit = criterion(torch.log10(rho_fit).squeeze(), log_rho_target)
            mape_nn = tn.mean_absolute_percentage_error(rho_nn, rho_target).item()
            mape_fit = tn.mean_absolute_percentage_error(rho_fit, rho_target).item()

        loss_value = loss.item()
        loss_fit_value = loss_fit.item()
        # Store the NN losses:
        rmse_per_minibatch_nn.append(loss_value)
        mape_per_minibatch_nn.append(mape_nn)
        total_rmse += loss_value
        total_mape += mape_nn
        # Same for the global fit:
        rmse_per_minibatch_fit.append(loss_fit_value)
        mape_per_minibatch_fit.append(mape_fit)
        # Ratio of the NN loss to the fit loss (the lower, the more the NN beats the global fit):
        ratio_losses.append(loss_value/loss_fit_value)

        # Save the best model. This is done BEFORE optimizer.step(), so the checkpoint holds
        # exactly the weights that produced `loss` (after the step they would already differ).
        if loss_value < best_loss:
            best_loss = loss_value
            torch.save(model.state_dict(), 'best_model.pyt')

        optimizer.step()  # Update the weights

        # Progress report every 10 minibatches (overwrites the same console line):
        if batch_idx % 10 == 0:
            print(f'minibatch: {batch_idx}/{len(dataloader)}, ratio: {ratio_losses[-1]:.10f}, best loss till now: {best_loss:.10f}, loss MSE (log10) & MAPE -----  NN: {loss_value:.10f}, {mape_nn:.7f}; fit: {loss_fit_value:.10f}, {mape_fit:.7f}', end='\r')
    # Summary at the end of each epoch:
    print(f'End of epoch {epoch + 1}/{epochs}, average MSE (log10) loss: {total_rmse / len(dataloader)}, average MAPE: {total_mape / len(dataloader)}, ')


End of epoch 1/100, average RMSE (log10) loss: 0.24944231060086464, average MAPE: 67.14630390089386, ----  NN: 0.2057335377, 59.3295326; fit: 0.0732013658, 53.8712158
End of epoch 2/100, average RMSE (log10) loss: 0.16142051974121405, average MAPE: 62.56833642453564, ----  NN: 0.1398187578, 62.8487778; fit: 0.0752398595, 60.1677132
End of epoch 3/100, average RMSE (log10) loss: 0.11007382480465636, average MAPE: 57.672035108293805, ---  NN: 0.0852707401, 51.8967285; fit: 0.0626801699, 53.4257774
End of epoch 4/100, average RMSE (log10) loss: 0.07967869681971414, average MAPE: 52.274841464295676, ---  NN: 0.0666145235, 50.1263390; fit: 0.0787915364, 64.1585922
End of epoch 5/100, average RMSE (log10) loss: 0.06082096440451486, average MAPE: 46.652806464993226, ---  NN: 0.0606885962, 47.8941994; fit: 0.0831792653, 60.4612007
End of epoch 6/100, average RMSE (log10) loss: 0.04841321878591362, average MAPE: 41.44194877780214, ----  NN: 0.0461282916, 38.7175751; fit: 0.0808902904, 59.964263