# Notebook para treinar SpectroViT

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
import random
random.seed(5)
import matplotlib.pyplot as plt
import numpy as np
from scipy import signal,stats
import os

In [2]:
from datasets import DatasetSpgramSyntheticData
from models import SpectroViT
from losses import RangeMAELoss
from lr_scheduler import CustomLRScheduler
from save_models import SaveBestModel, SaveCurrentModel
from main_functions_adapted import valid_on_the_fly, run_train_epoch, run_validation
from main import calculate_parameters
from utils import clean_directory

Using cuda:0


In [3]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Dados

In [4]:
# Both splits are built from the same HDF5 file; they differ only in the
# start/end arguments and in whether augmentation is requested.
# NOTE(review): whether `end` is inclusive or exclusive is decided inside
# DatasetSpgramSyntheticData — confirm there.
data_file = '../sample_data.h5'

# Training split: start=0, end=84, augmentation enabled.
dataset_train = DatasetSpgramSyntheticData(
    path_data=data_file,
    start=0,
    end=84,
    augment=True,
)

# Validation split: start=84, end=108, augmentation disabled.
dataset_validation = DatasetSpgramSyntheticData(
    path_data=data_file,
    start=84,
    end=108,
    augment=False,
)

### Modelo

In [5]:
# Build the model with its default constructor arguments, then move it to the
# selected device.
spectrovit = SpectroViT()
spectrovit = spectrovit.to(device)

### Loss e Optimizer

In [6]:
# Training criterion (project-defined loss, default arguments).
loss = RangeMAELoss()

# Adam over every model parameter, with an initial learning rate of 1e-4.
optimizer = torch.optim.Adam(params=spectrovit.parameters(), lr=1e-4)

# Project wrapper around a learning-rate scheduler, configured here by name
# ('cosineannealinglr') with T_max=10 and eta_min=1e-6. The training loop
# decides when this scheduler actually starts stepping.
lr_scheduler = CustomLRScheduler(
    optimizer,
    'cosineannealinglr',
    T_max=10,
    eta_min=1e-6,
)

### Loop de treino e validação

In [7]:
# --- Training schedule ---
n_epochs = 30
# Epoch index at which the training loop lowers the LR by hand; the scheduler
# only starts stepping on the epochs after this one.
epoch_to_switch_to_cosine = 20

# --- Batch sizes ---
batch_size_train, batch_size_validation = 100, 6

# --- Cadence (in epochs) for periodic checkpoints and validation plots ---
step_for_safe_saving, step_for_saving_plots = 5, 3

# --- Output location and naming ---
name_model = 'model_hop_10_mfft_256_zp'
filename = 'model_hop_10_mfft_256_zp'
save_dir_path = '../model_hop_10_mfft_256_zp/'
# Checkpoint helpers; both are given save_dir_path as their model directory.
# The best-model helper is constructed first, then the periodic one.
save_best_model, save_current_model = (
    SaveBestModel(dir_model=save_dir_path),
    SaveCurrentModel(dir_model=save_dir_path),
)

In [8]:
# Training batches are reshuffled on every pass over the data.
dataloader_train = DataLoader(dataset_train, batch_size=batch_size_train, shuffle=True)

# Validation batches are served in a fixed order. Shuffling brings no benefit
# when the model is only being evaluated, and it makes batch composition (and
# therefore any per-batch metric) differ from run to run and epoch to epoch.
dataloader_validation = DataLoader(dataset_validation, batch_size=batch_size_validation, shuffle=False)

In [14]:
# Per-epoch history. Each list receives exactly one entry per epoch and is
# written to disk by the np.savetxt cell that follows the training loop.
train_loss_list = []
val_loss_list = []
val_mean_mse_list = []
val_mean_snr_list = []
val_mean_linewidth_list = []
val_mean_shape_score_list = []
score_challenge_list = []

# Make sure the output directory exists, then hand it to clean_directory.
# NOTE(review): clean_directory presumably removes whatever is already inside
# save_dir_path — confirm before re-running this cell on a directory that
# holds checkpoints worth keeping.
os.makedirs(save_dir_path, exist_ok=True)
clean_directory(save_dir_path)

# Report the model size once. Nothing in the loop changes the architecture, so
# calling this every epoch only repeated the same log line.
# NOTE(review): assumes calculate_parameters has no side effect beyond
# reporting the parameter count — confirm in main.py.
calculate_parameters(spectrovit)

for epoch in range(n_epochs):

  # One full pass over the training data, followed by a full validation pass.
  train_loss = run_train_epoch(model=spectrovit, optimizer=optimizer, criterion=loss, loader=dataloader_train, epoch=epoch, device=device)
  (val_loss,
   loader_mean_mse,
   loader_mean_snr,
   loader_mean_linewidth,
   loader_mean_shape_score,
   score_challenge) = run_validation(model=spectrovit, criterion=loss, loader=dataloader_validation, epoch=epoch, device=device)

  # Record this epoch's metrics.
  train_loss_list.append(train_loss)
  val_loss_list.append(val_loss)
  val_mean_mse_list.append(loader_mean_mse)
  val_mean_snr_list.append(loader_mean_snr)
  val_mean_linewidth_list.append(loader_mean_linewidth)
  val_mean_shape_score_list.append(loader_mean_shape_score)
  score_challenge_list.append(score_challenge)

  # Learning-rate policy:
  #   * before epoch_to_switch_to_cosine the optimizer's LR is left untouched;
  #   * at that epoch the LR of every param group is set to 1e-5 by hand;
  #   * on every later epoch the scheduler is stepped once and the new LR is
  #     printed.
  if epoch == epoch_to_switch_to_cosine:
    for param_group in optimizer.param_groups:
      param_group['lr'] = 1e-5
  elif epoch > epoch_to_switch_to_cosine:
    lr_scheduler.step()
    print("Current learning rate:",lr_scheduler.scheduler.get_last_lr()[0])

  # Offer this epoch's challenge score to the best-model tracker on every epoch.
  save_best_model(current_valid_score=score_challenge, model=spectrovit, name_model=name_model)
  # Every step_for_saving_plots epochs (starting at epoch 0), run the
  # on-the-fly validation helper on the validation dataset.
  if epoch%step_for_saving_plots == 0:
    valid_on_the_fly(model=spectrovit, epoch=epoch, val_dataset=dataset_validation, save_dir_path=save_dir_path, filename=filename, device=device)
  # Every step_for_safe_saving epochs (starting at epoch 0), save the current
  # model regardless of its score.
  if epoch%step_for_safe_saving == 0:
    save_current_model(current_valid_score=score_challenge, model=spectrovit, name_model=name_model)


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [10:19<00:00,  3.68s/it, desc=[epoch: 1], iteration: 167/168, loss: 0.005826089858254861] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.37it/s, desc=[Epoch 1] Loss: 0.001976113533601165 | MSE:0.0002497 | SNR:53.0256158 | FWHM:0.0764546 | Shape Score:0.9993597] 


Best validation score: 0.20032790420198343
Saving current model with score: 0.20032790420198343
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [10:14<00:00,  3.66s/it, desc=[epoch: 2], iteration: 167/168, loss: 0.0022913603690020473]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.55it/s, desc=[Epoch 2] Loss: 0.002223386662080884 | MSE:0.0002491 | SNR:67.9903933 | FWHM:0.0764546 | Shape Score:0.9995342] 


Best validation score: 0.20029250845491073
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:51<00:00,  3.52s/it, desc=[epoch: 3], iteration: 167/168, loss: 0.0019718620287243367]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.61it/s, desc=[Epoch 3] Loss: 0.0020960327237844467 | MSE:0.0002258 | SNR:84.1284273 | FWHM:0.0764546 | Shape Score:0.9995630]


Best validation score: 0.200268042027839
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:51<00:00,  3.52s/it, desc=[epoch: 4], iteration: 167/168, loss: 0.001865463082573288] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.55it/s, desc=[Epoch 4] Loss: 0.001526281121186912 | MSE:0.0002334 | SNR:94.4634118 | FWHM:0.0764546 | Shape Score:0.9996500] 


Best validation score: 0.200256711329502
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:48<00:00,  3.51s/it, desc=[epoch: 5], iteration: 167/168, loss: 0.001813507622933858] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.58it/s, desc=[Epoch 5] Loss: 0.0016905679367482662 | MSE:0.0001827 | SNR:85.3246792 | FWHM:0.0764546 | Shape Score:0.9996334]


Best validation score: 0.20021951294825288
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:50<00:00,  3.51s/it, desc=[epoch: 6], iteration: 167/168, loss: 0.0016527930905188744]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.57it/s, desc=[Epoch 6] Loss: 0.0015000698622316122 | MSE:0.0001803 | SNR:93.5357181 | FWHM:0.0764546 | Shape Score:0.9996500]


Best validation score: 0.20021425317142055
Saving current model with score: 0.20021425317142055
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:52<00:00,  3.52s/it, desc=[epoch: 7], iteration: 167/168, loss: 0.0015306184193052883]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.51it/s, desc=[Epoch 7] Loss: 0.0015896391123533249 | MSE:0.0001219 | SNR:89.6081746 | FWHM:0.0764546 | Shape Score:0.9996119]


Best validation score: 0.20017515464416002
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:52<00:00,  3.53s/it, desc=[epoch: 8], iteration: 167/168, loss: 0.0014432348460624261]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.57it/s, desc=[Epoch 8] Loss: 0.0016371747478842735 | MSE:0.0001726 | SNR:89.3625167 | FWHM:0.0764546 | Shape Score:0.9996734]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [10:05<00:00,  3.61s/it, desc=[epoch: 9], iteration: 167/168, loss: 0.0013705024009271125]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.52it/s, desc=[Epoch 9] Loss: 0.0015360750257968903 | MSE:0.0001711 | SNR:100.4066740 | FWHM:0.0764546 | Shape Score:0.9996715]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:52<00:00,  3.53s/it, desc=[epoch: 10], iteration: 167/168, loss: 0.001323714403302542] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.44it/s, desc=[Epoch 10] Loss: 0.0018019663402810693 | MSE:0.0001448 | SNR:102.5727603 | FWHM:0.0764546 | Shape Score:0.9996777]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:53<00:00,  3.53s/it, desc=[epoch: 11], iteration: 167/168, loss: 0.0012253537022929994]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.57it/s, desc=[Epoch 11] Loss: 0.001501705963164568 | MSE:0.0001714 | SNR:101.8063415 | FWHM:0.0764546 | Shape Score:0.9996044] 


Saving current model with score: 0.20021631815627833
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:54<00:00,  3.54s/it, desc=[epoch: 12], iteration: 167/168, loss: 0.001125663096900098] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.57it/s, desc=[Epoch 12] Loss: 0.001673137187026441 | MSE:0.0001887 | SNR:101.1495907 | FWHM:0.0764546 | Shape Score:0.9996643] 


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:53<00:00,  3.53s/it, desc=[epoch: 13], iteration: 167/168, loss: 0.0010867759999763664]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.61it/s, desc=[Epoch 13] Loss: 0.0015657416079193354 | MSE:0.0001525 | SNR:106.8245675 | FWHM:0.0764546 | Shape Score:0.9996505]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:53<00:00,  3.53s/it, desc=[epoch: 14], iteration: 167/168, loss: 0.0010947880516704615]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.51it/s, desc=[Epoch 14] Loss: 0.001512196147814393 | MSE:0.0001549 | SNR:118.3630127 | FWHM:0.0764546 | Shape Score:0.9995715] 


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:54<00:00,  3.54s/it, desc=[epoch: 15], iteration: 167/168, loss: 0.0009538485416366408]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.49it/s, desc=[Epoch 15] Loss: 0.00172244303394109 | MSE:0.0001512 | SNR:107.0996055 | FWHM:0.0767732 | Shape Score:0.9996667]  


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:53<00:00,  3.54s/it, desc=[epoch: 16], iteration: 167/168, loss: 0.0009759911720591065]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.58it/s, desc=[Epoch 16] Loss: 0.0016394301783293486 | MSE:0.0001759 | SNR:109.6433487 | FWHM:0.0764546 | Shape Score:0.9996310]


Saving current model with score: 0.20021453559841887
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:51<00:00,  3.52s/it, desc=[epoch: 17], iteration: 167/168, loss: 0.0009584715610669393]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.58it/s, desc=[Epoch 17] Loss: 0.0016635386273264885 | MSE:0.0002045 | SNR:121.5058246 | FWHM:0.0764546 | Shape Score:0.9996753]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:57<00:00,  3.56s/it, desc=[epoch: 18], iteration: 167/168, loss: 0.0009250808857142969]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.62it/s, desc=[Epoch 18] Loss: 0.0017467804718762636 | MSE:0.0001981 | SNR:112.4663277 | FWHM:0.0764546 | Shape Score:0.9997180]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:52<00:00,  3.53s/it, desc=[epoch: 19], iteration: 167/168, loss: 0.0008691254240277756]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.59it/s, desc=[Epoch 19] Loss: 0.001573382061906159 | MSE:0.0001269 | SNR:138.0411184 | FWHM:0.0764546 | Shape Score:0.9996680] 


Best validation score: 0.2001679307995597
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [10:04<00:00,  3.60s/it, desc=[epoch: 20], iteration: 167/168, loss: 0.0008563438223929898]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.53it/s, desc=[Epoch 20] Loss: 0.0016471856506541371 | MSE:0.0001402 | SNR:128.8503716 | FWHM:0.0764546 | Shape Score:0.9996273]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:52<00:00,  3.53s/it, desc=[epoch: 21], iteration: 167/168, loss: 0.0008566456457455864]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.52it/s, desc=[Epoch 21] Loss: 0.0020410260185599327 | MSE:0.0001798 | SNR:137.4175773 | FWHM:0.0761361 | Shape Score:0.9996563]


Saving current model with score: 0.20021262153674563
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:51<00:00,  3.52s/it, desc=[epoch: 22], iteration: 167/168, loss: 0.0005102292289804955]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.49it/s, desc=[Epoch 22] Loss: 0.001718676183372736 | MSE:0.0001341 | SNR:224.4804793 | FWHM:0.0764546 | Shape Score:0.9996972] 


Current learning rate: 9.779754323328192e-06
Best validation score: 0.20016783758726664
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:54<00:00,  3.54s/it, desc=[epoch: 23], iteration: 167/168, loss: 0.0004770486557390541] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.55it/s, desc=[Epoch 23] Loss: 0.0015971470857039094 | MSE:0.0001279 | SNR:235.2547449 | FWHM:0.0764546 | Shape Score:0.9997194]


Current learning rate: 9.140576474687263e-06
Best validation score: 0.20015842617721596
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:52<00:00,  3.52s/it, desc=[epoch: 24], iteration: 167/168, loss: 0.00046660350355003696]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.58it/s, desc=[Epoch 24] Loss: 0.0016284917946904898 | MSE:0.0001560 | SNR:228.5419261 | FWHM:0.0764546 | Shape Score:0.9996809]


Current learning rate: 8.145033635316128e-06
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:54<00:00,  3.54s/it, desc=[epoch: 25], iteration: 167/168, loss: 0.0004515726892956688] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.45it/s, desc=[Epoch 25] Loss: 0.001555431867018342 | MSE:0.0001572 | SNR:235.5380983 | FWHM:0.0761361 | Shape Score:0.9997140] 


Current learning rate: 6.890576474687263e-06
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:58<00:00,  3.56s/it, desc=[epoch: 26], iteration: 167/168, loss: 0.0004421980125438755] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.39it/s, desc=[Epoch 26] Loss: 0.0015169454272836447 | MSE:0.0001306 | SNR:237.8468931 | FWHM:0.0761361 | Shape Score:0.9997408]


Current learning rate: 5.5e-06
Best validation score: 0.20015633505961714
Saving current model with score: 0.20015633505961714
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:52<00:00,  3.53s/it, desc=[epoch: 27], iteration: 167/168, loss: 0.0004409617667122456] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.54it/s, desc=[Epoch 27] Loss: 0.0017284895293414593 | MSE:0.0001459 | SNR:243.2099027 | FWHM:0.0764546 | Shape Score:0.9996862]


Current learning rate: 4.109423525312737e-06
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:53<00:00,  3.54s/it, desc=[epoch: 28], iteration: 167/168, loss: 0.00042114843814661506]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.54it/s, desc=[Epoch 28] Loss: 0.0016045018564909697 | MSE:0.0001569 | SNR:243.1679889 | FWHM:0.0767732 | Shape Score:0.9997385]


Current learning rate: 2.8549663646838717e-06
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:52<00:00,  3.53s/it, desc=[epoch: 29], iteration: 167/168, loss: 0.0004152638485677363] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.54it/s, desc=[Epoch 29] Loss: 0.001762802479788661 | MSE:0.0001504 | SNR:235.4901030 | FWHM:0.0764546 | Shape Score:0.9997148] 


Current learning rate: 1.859423525312737e-06
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [09:55<00:00,  3.54s/it, desc=[epoch: 30], iteration: 167/168, loss: 0.0004138933613181247] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  4.52it/s, desc=[Epoch 30] Loss: 0.0017332143615931273 | MSE:0.0001707 | SNR:250.1720122 | FWHM:0.0764546 | Shape Score:0.9996456]

Current learning rate: 1.220245676671809e-06





In [15]:
# Dump every per-epoch history to its own text file under save_dir_path.
# The mapping keeps the original write order (dicts preserve insertion order).
metric_histories = {
  'train_loss_list.txt': train_loss_list,
  'val_loss_list.txt': val_loss_list,
  'val_mse_list.txt': val_mean_mse_list,
  'val_snr_list.txt': val_mean_snr_list,
  'val_linewidth_list.txt': val_mean_linewidth_list,
  'val_mean_shape_score_list.txt': val_mean_shape_score_list,
  'score_challenge_list.txt': score_challenge_list,
}
for history_file, history_values in metric_histories.items():
  # save_dir_path already ends with '/', so plain concatenation yields the
  # full output path.
  np.savetxt(save_dir_path+history_file, np.array(history_values), delimiter='\n')