# Feed-forward training

## Imports

In [3]:
import sys
sys.path.append('../')
import thermonets
import torch
import datetime
import numpy as np
import matplotlib.pyplot as plt
import pickle


## Loading input data

In [4]:
#I load the data generated via `/scripts/generate_nrlmsise00_db.py` and print the columns
#note that columns can be (len 16):
#'day', 'month', 'year', 'hour', 'minute', 'second', 'microsecond', 'alt [km]', 'lat [deg]', 'lon [deg]', 'f107A', 'f107', 'ap', 'wind zonal [m/s]', 'wind meridional [m/s]', 'density [kg/m^3]'
#or (len 14):
#'day', 'month', 'year', 'hour', 'minute', 'second', 'microsecond', 'alt [km]', 'lat [deg]', 'lon [deg]', 'f107A', 'f107', 'ap', 'density [kg/m^3]'
db=np.loadtxt('../dbs/nrlmsise00_db.txt',delimiter=',',skiprows=1)
print(f'Shape of database is: {db.shape}')

Shape of database is: (999800, 16)


## Extracting features of interest

In [5]:
#I now construct the day of the year and seconds in day:
years=db[:,2]
months=db[:,1]
days=db[:,0]
hours=db[:,3]
minutes=db[:,4]
seconds=db[:,5]
microseconds=db[:,6]
seconds_in_day=hours*3600+minutes*60+seconds+microseconds/1e6
print('seconds in day min and max:')
print(seconds_in_day.max(), seconds_in_day.min())
doys=np.zeros(db.shape[0])
for i in range(len(db)):
    #date is a string, so I first convert it to datetime:
    date_=datetime.datetime(year=int(years[i]), 
                            month=int(months[i]), 
                            day=int(days[i]),
                            hour=int(hours[i]),
                            minute=int(minutes[i]),
                            second=int(seconds[i]),
                            microsecond=int(microseconds[i]))
    doys[i]=date_.timetuple().tm_yday
print('day of the year min and max:')
print(doys.max(), doys.min())

seconds in day min and max:
86381.805516 5.198408
day of the year min and max:
365.0 1.0


In [6]:
#I extract the altitude:
alt=db[:,7]
#I now extract the longitude and latitude, and convert them to radians:
lat=np.deg2rad(db[:,8])
lon=np.deg2rad(db[:,9])
#now the space weather indices:
f107a=db[:,10]
f107=db[:,11]
ap=db[:,12]
#let's extract the target density as well:
target_density=db[:,-1]

## Normalization

In [7]:
#this function normalizes the data to the range [-1,1]
def normalize_min_max(data,min_val,max_val):
    normalized_data = (2 * (data - min_val) / (max_val - min_val)) - 1
    return normalized_data
def unnormalize_min_max(data,min_val,max_val):
    unnormalized_data = 1/2 * (data + 1) * (max_val - min_val) + min_val
    return unnormalized_data
#verify: unnormalize_min_max(normalize_min_max(alt,alt.min(),alt.max()),alt.min(),alt.max())==alt

In [8]:
db_normalized=np.zeros((db.shape[0],13))
db_normalized[:,0]=np.sin(lon)
db_normalized[:,1]=np.cos(lon)
db_normalized[:,2]=np.sin(lat)
db_normalized[:,3]=np.sin(2*np.pi*seconds_in_day/86400.)
db_normalized[:,4]=np.cos(2*np.pi*seconds_in_day/86400.)
db_normalized[:,5]=np.sin(2*np.pi*doys/365.25)
db_normalized[:,6]=np.cos(2*np.pi*doys/365.25)
db_normalized[:,7]=normalize_min_max(alt, 150., 650.)
db_normalized[:,8]=normalize_min_max(f107, 60., 290.)
db_normalized[:,9]=normalize_min_max(f107a, 50., 190.)
db_normalized[:,10]=normalize_min_max(ap, 0., 140.)
#I add the non-normalized density & altitude columns (useful to extract during training):
db_normalized[:,11]=alt
db_normalized[:,12] = target_density

In [9]:
#cross check that the max is <=1 and the min is >=-1
print(f"maximum and minimum of all the normalized data: {db_normalized.max()}, {db_normalized.min()}")
print(f"maximum and minimum of target density: {target_density.max()}, {target_density.min()}")

maximum and minimum of all the normalized data: 630.957344480193, -1.0
maximum and minimum of target density: 1.720909357141654e-09, 2.1621841115988526e-15


## NN Training

In [10]:
torch_data = torch.tensor(db_normalized,
                          dtype=torch.float32)

In [11]:
# NN hyperparameters
device = torch.device('cpu')
batch_size = 4096
model_path = None #pass a path to a model in case you want to continue training
lr = 0.00001
epochs = 100

In [12]:
# Dataloader creation
dataloader = torch.utils.data.DataLoader(torch_data, 
                                         batch_size=batch_size, 
                                         shuffle=True)


In [14]:
#NN creation
model = thermonets.ffnn(input_dim=db_normalized.shape[1]-2,
                        hidden_layer_dims=[32, 32],
                        output_dim=12,
                        mid_activation=torch.nn.Tanh(),
                        last_activation=torch.nn.Tanh()).to(device)

if model_path is not None:
    model.load_state_dict(torch.load(model_path,
                                     map_location=device.type))

In [15]:
#NN training
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.MSELoss()

In [16]:
print(f'Total number of model parameters: {sum(p.numel() for p in model.parameters())}')

Total number of model parameters: 1836


We load the global fit (see notebook: `rho_global_fit.ipynb`: this will be the baseline from which we ask the NN to learn corrections)

In [17]:
with open('/Users/ga00693/Develop/thermonets/global_fits/global_fit_nrlmsise00_180.0-1000.0-4.txt','rb') as f:
    best_global_fit=torch.from_numpy(pickle.load(f)).to(device)

In [18]:
def mean_absolute_percentage_error(y_pred, y_true):
    """
    Compute the mean absolute percentage error (MAPE) between true and predicted values.
    
    Args:
        y_true (`torch.tensor`): True values.
        y_pred (`torch.tensor`): Predicted values.
        
    Returns:
        `torch.tensor`: Mean absolute percentage error.
    """
    return torch.mean(torch.abs((y_true - y_pred) / y_true)) * 100

In [20]:
# Training loop
ratio_losses=[]
rmse_per_minibatch_nn=[]
mape_per_minibatch_nn=[]
rmse_per_minibatch_fit=[]
mape_per_minibatch_fit=[]
for epoch in range(epochs):
    model.train(True)  # Set model to training mode
    total_rmse = 0.0
    total_mape = 0.0
    k=0
    for batch_idx,el in enumerate(dataloader):
        minibatch=el[:,:-2]
        altitude=el[:,-2]
        optimizer.zero_grad()  # Clear accumulated gradients    
        minibatch=minibatch.to(device)
        params = model(minibatch).to(device)

        k+=1

        minibatch=minibatch.to(device)
        optimizer.zero_grad()  # Clear accumulated gradients
        delta_params = model(minibatch).to(device)
        #now I construct the inputs for the compute_approximated_density function as corrections from the global fit:
        params = best_global_fit*(1+delta_params)
        rho_nn=thermonets.rho_approximation(h=altitude,
                                                params=params,
                                                backend='torch')
        rho_fit=torch.from_numpy(thermonets.rho_approximation(h=altitude.numpy(),
                                                                params=best_global_fit.numpy()))
        rho_target=el[:,-1].to(device)

        loss = criterion(torch.log10(rho_nn), torch.log10(rho_target))
        loss.backward()
        #I also compute the global fit loss:
        loss_fit =  torch.nn.MSELoss()(torch.log10(rho_fit).squeeze(), torch.log10(rho_target).squeeze())
        #I update the weights:
        optimizer.step()
        #let's store the losses for the NN:
        rmse_per_minibatch_nn.append(loss.item())
        mape_per_minibatch_nn.append(mean_absolute_percentage_error(rho_nn, rho_target).item())
        total_rmse+=rmse_per_minibatch_nn[-1]
        total_mape+=mape_per_minibatch_nn[-1]
        #now the same but for the global fit:
        rmse_per_minibatch_fit.append(loss_fit.item())
        mape_per_minibatch_fit.append(mean_absolute_percentage_error(rho_fit, rho_target).item())

        #ratio of the loss between the NN and the fit (the lower, the more the NN is doing better than a global fit)
        ratio_losses.append(loss.item()/loss_fit.item())
        #I only save the best model:
        if k>1:
            if rmse_per_minibatch_nn[-1]<min(rmse_per_minibatch_nn[:-1]):    
                #updating torch best model:
                torch.save(model.state_dict(), f'best_model.pth')
                best_loss=loss.item()
                #print(f'Saving model - current best loss: {best_loss}\n')
        else:
            best_loss=loss.item()
        #I print every 10 minibatches:
        if k%10:    
            print(f'minibatch: {k}/{len(dataloader)}, ratio: {ratio_losses[-1]:.10f}, best loss till now: {best_loss:.10f}, loss RMSE (log10) & MAPE -----  NN: {loss.item():.10f}, {mape_per_minibatch_nn[-1]:.7f}; fit: {loss_fit.item():.10f}, {mape_per_minibatch_fit[-1]:.7f}', end='\r')
    #I also print at the end of the epoch
    print(f'End of epoch {epoch + 1}/{epochs}, average RMSE (log10) loss: {total_rmse / len(dataloader)}, average MAPE: {total_mape / len(dataloader)}, ')


End of epoch 1/100, average RMSE (log10) loss: 0.1348125112604122, average MAPE: 72.86179965272241, -----  NN: 0.1247022524, 62.6941757; fit: 0.0751954949, 48.6998353
End of epoch 2/100, average RMSE (log10) loss: 0.10745000820987079, average MAPE: 63.015891826396086, ---  NN: 0.0902358145, 58.7182884; fit: 0.0723802327, 49.0179714
End of epoch 3/100, average RMSE (log10) loss: 0.08764881166268368, average MAPE: 56.7876057994609, -----  NN: 0.0800742283, 49.7547569; fit: 0.0771722338, 46.0925415
End of epoch 4/100, average RMSE (log10) loss: 0.07243711650371551, average MAPE: 51.60860265615035, ----  NN: 0.0787589476, 57.5506401; fit: 0.0879144889, 60.0736066
End of epoch 5/100, average RMSE (log10) loss: 0.060235488125864343, average MAPE: 46.72686671042929, ---  NN: 0.0522562303, 47.3934593; fit: 0.0651796062, 52.0998396
End of epoch 6/100, average RMSE (log10) loss: 0.05052070438253636, average MAPE: 42.270780913683836, ---  NN: 0.0589060634, 45.0034676; fit: 0.0983349249, 56.795762