# Notebook para treinar SpectroVit

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
import random
random.seed(5)
import matplotlib.pyplot as plt
import numpy as np
from scipy import signal,stats
import os

In [2]:
from datasets import DatasetSpgramSyntheticData
from models import SpectroViT
from losses import RangeMAELoss
from lr_scheduler import CustomLRScheduler
from save_models import SaveBestModel, SaveCurrentModel
from main_functions_adapted import valid_on_the_fly, run_train_epoch, run_validation
from main import calculate_parameters
from utils import clean_directory

Using cuda:0


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Dados

In [4]:
hop_size = 16
window_size = 256
window = signal.windows.hann(256,sym = True)

In [5]:
dataset_train = DatasetSpgramSyntheticData(path_data='../sample_data.h5',
                           start=0, end=84,
                           augment=True,hop_size=hop_size,window_size=window_size,window=window)
dataset_validation = DatasetSpgramSyntheticData(path_data='../sample_data.h5',
                           start=84, end=108,
                           augment=False,hop_size=hop_size,window_size=window_size, window=window)

### Modelo

In [6]:
spectrovit = SpectroViT().to(device)

### Loss e Optimizer

In [7]:
loss = RangeMAELoss()
optimizer = torch.optim.Adam(spectrovit.parameters(), lr=1e-4)
lr_scheduler = CustomLRScheduler(optimizer,'cosineannealinglr',T_max = 10, eta_min = 1e-6)

### Loop de treino e validação

In [8]:
n_epochs = 30
batch_size_train = 100
batch_size_validation = 6
step_for_safe_saving = 5
step_for_saving_plots = 3
epoch_to_switch_to_cosine = 20
save_dir_path = '../model_hop_16_mfft_256_zp/'
filename = 'model_hop_16_mfft_256_zp'
name_model = 'model_hop_16_mfft_256_zp'
save_best_model = SaveBestModel(dir_model=save_dir_path)
save_current_model = SaveCurrentModel(dir_model=save_dir_path)

In [9]:
dataloader_train = DataLoader(dataset_train, batch_size=batch_size_train, shuffle=True)
dataloader_validation = DataLoader(dataset_validation, batch_size=batch_size_validation, shuffle=True)

In [10]:
train_loss_list = []
val_loss_list = []
val_mean_mse_list = []
val_mean_snr_list = []
val_mean_linewidth_list = []
val_mean_shape_score_list = []
score_challenge_list = []

os.makedirs(save_dir_path, exist_ok=True)
clean_directory(save_dir_path)

for epoch in range(n_epochs):

  calculate_parameters(spectrovit)
  train_loss = run_train_epoch(model=spectrovit, optimizer=optimizer, criterion=loss, loader=dataloader_train, epoch=epoch, device=device)
  val_loss, loader_mean_mse, loader_mean_snr,loader_mean_linewidth,loader_mean_shape_score,score_challenge = run_validation(model=spectrovit, criterion=loss, loader=dataloader_validation, epoch=epoch, device=device)

  train_loss_list.append(train_loss)
  val_loss_list.append(val_loss)
  val_mean_mse_list.append(loader_mean_mse)
  val_mean_snr_list.append(loader_mean_snr)
  val_mean_linewidth_list.append(loader_mean_linewidth)
  val_mean_shape_score_list.append(loader_mean_shape_score)
  score_challenge_list.append(score_challenge)

  if epoch == epoch_to_switch_to_cosine:
    for param_group in optimizer.param_groups:
      param_group['lr'] = 1e-5
  elif epoch > epoch_to_switch_to_cosine:
    lr_scheduler.step()
    print("Current learning rate:",lr_scheduler.scheduler.get_last_lr()[0])

  save_best_model(current_valid_score=score_challenge, model=spectrovit, name_model=name_model)
  if epoch%step_for_saving_plots == 0:
    valid_on_the_fly(model=spectrovit, epoch=epoch, val_dataset=dataset_validation, save_dir_path=save_dir_path, filename=filename, device=device)
  if epoch%step_for_safe_saving == 0:
    save_current_model(current_valid_score=score_challenge, model=spectrovit, name_model=name_model)


Number of parameters: 90473472


Train Loop:   0%|          | 0/168 [00:00<?, ?it/s]

Generating Spectrograms of size:  (177, 124)
Zero padded to shape:  (1, 224, 224)


Train Loop: 100%|██████████| 168/168 [07:56<00:00,  2.84s/it, desc=[epoch: 1], iteration: 167/168, loss: 0.006324598500560526] 
Validation Loop:  25%|██▌       | 1/4 [00:00<00:00,  5.87it/s, desc=[Epoch 1] Loss: 0.0023229687940329313 | MSE:0.0003225 | SNR:54.7633384 | FWHM:0.0764546 | Shape Score:0.9992894]

Generating Spectrograms of size:  (177, 124)
Zero padded to shape:  (1, 224, 224)


Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.85it/s, desc=[Epoch 1] Loss: 0.002425465267151594 | MSE:0.0003272 | SNR:45.1928203 | FWHM:0.0764546 | Shape Score:0.9992136] 


Best validation score: 0.2004191679332423
Saving current model with score: 0.2004191679332423
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:50<00:00,  2.80s/it, desc=[epoch: 2], iteration: 167/168, loss: 0.002232827646020312] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.95it/s, desc=[Epoch 2] Loss: 0.0019516872707754374 | MSE:0.0002430 | SNR:68.5165137 | FWHM:0.0764546 | Shape Score:0.9995124]


Best validation score: 0.20029197339934687
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:39<00:00,  2.74s/it, desc=[epoch: 3], iteration: 167/168, loss: 0.0019538138343098885]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.93it/s, desc=[Epoch 3] Loss: 0.0016194740310311317 | MSE:0.0002463 | SNR:72.7033412 | FWHM:0.0764546 | Shape Score:0.9995846]


Best validation score: 0.20028017443276586
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:38<00:00,  2.73s/it, desc=[epoch: 4], iteration: 167/168, loss: 0.0018821768351786193]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.93it/s, desc=[Epoch 4] Loss: 0.0019656228832900524 | MSE:0.0002278 | SNR:83.6486054 | FWHM:0.0764546 | Shape Score:0.9996398]


Best validation score: 0.2002543519569465
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:39<00:00,  2.73s/it, desc=[epoch: 5], iteration: 167/168, loss: 0.0017930228274226898]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.80it/s, desc=[Epoch 5] Loss: 0.0015298088546842337 | MSE:0.0001632 | SNR:92.7966301 | FWHM:0.0764546 | Shape Score:0.9996548] 


Best validation score: 0.2001996054950914
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:39<00:00,  2.74s/it, desc=[epoch: 6], iteration: 167/168, loss: 0.0015954121987479517]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.95it/s, desc=[Epoch 6] Loss: 0.0022185244597494602 | MSE:0.0001720 | SNR:80.7839642 | FWHM:0.0767732 | Shape Score:0.9995991]


Saving current model with score: 0.20021777607513042
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:41<00:00,  2.75s/it, desc=[epoch: 7], iteration: 167/168, loss: 0.001641230309836655] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.86it/s, desc=[Epoch 7] Loss: 0.001554556773044169 | MSE:0.0001158 | SNR:98.4013284 | FWHM:0.0764546 | Shape Score:0.9997062] 


Best validation score: 0.2001514038629295
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:54<00:00,  2.83s/it, desc=[epoch: 8], iteration: 167/168, loss: 0.0015211868310524594]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.89it/s, desc=[Epoch 8] Loss: 0.0019668324384838343 | MSE:0.0001589 | SNR:95.7624953 | FWHM:0.0764546 | Shape Score:0.9996207]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:41<00:00,  2.75s/it, desc=[epoch: 9], iteration: 167/168, loss: 0.0014348847464480926]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.91it/s, desc=[Epoch 9] Loss: 0.0016418424202129245 | MSE:0.0001549 | SNR:102.5540761 | FWHM:0.0764546 | Shape Score:0.9996380]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:48<00:00,  2.79s/it, desc=[epoch: 10], iteration: 167/168, loss: 0.0013481232544152242]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.73it/s, desc=[Epoch 10] Loss: 0.00161182158626616 | MSE:0.0001508 | SNR:101.5230721 | FWHM:0.0764546 | Shape Score:0.9996747]  


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:50<00:00,  2.80s/it, desc=[epoch: 11], iteration: 167/168, loss: 0.001388516661661145] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.74it/s, desc=[Epoch 11] Loss: 0.0016340480651706457 | MSE:0.0001358 | SNR:97.3150049 | FWHM:0.0764546 | Shape Score:0.9997131]


Saving current model with score: 0.2001660200786386
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:50<00:00,  2.80s/it, desc=[epoch: 12], iteration: 167/168, loss: 0.0012514154881327635]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.59it/s, desc=[Epoch 12] Loss: 0.001786793814972043 | MSE:0.0001800 | SNR:98.1112369 | FWHM:0.0764546 | Shape Score:0.9996066]  


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:50<00:00,  2.80s/it, desc=[epoch: 13], iteration: 167/168, loss: 0.001208553019317887] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.65it/s, desc=[Epoch 13] Loss: 0.0014215062838047743 | MSE:0.0001191 | SNR:101.0621094 | FWHM:0.0764546 | Shape Score:0.9997006]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:50<00:00,  2.80s/it, desc=[epoch: 14], iteration: 167/168, loss: 0.0011831444831581653]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.71it/s, desc=[Epoch 14] Loss: 0.0017299854662269354 | MSE:0.0001660 | SNR:102.4969227 | FWHM:0.0764546 | Shape Score:0.9996785]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:49<00:00,  2.79s/it, desc=[epoch: 15], iteration: 167/168, loss: 0.0011164128712456052]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.81it/s, desc=[Epoch 15] Loss: 0.0015442497096955776 | MSE:0.0001379 | SNR:110.3973476 | FWHM:0.0764546 | Shape Score:0.9996766]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:49<00:00,  2.80s/it, desc=[epoch: 16], iteration: 167/168, loss: 0.0010388516096836178]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.85it/s, desc=[Epoch 16] Loss: 0.0015826085582375526 | MSE:0.0001457 | SNR:130.9270116 | FWHM:0.0767732 | Shape Score:0.9996764]


Saving current model with score: 0.20018129252111255
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:49<00:00,  2.79s/it, desc=[epoch: 17], iteration: 167/168, loss: 0.0010756599000333587]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.73it/s, desc=[Epoch 17] Loss: 0.0014768302207812667 | MSE:0.0001633 | SNR:113.9074810 | FWHM:0.0761361 | Shape Score:0.9996178]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:47<00:00,  2.78s/it, desc=[epoch: 18], iteration: 167/168, loss: 0.0010171099170942657]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.80it/s, desc=[Epoch 18] Loss: 0.0016066262032836676 | MSE:0.0001837 | SNR:116.9164966 | FWHM:0.0767732 | Shape Score:0.9997012]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:47<00:00,  2.78s/it, desc=[epoch: 19], iteration: 167/168, loss: 0.0009316800133092329]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.76it/s, desc=[Epoch 19] Loss: 0.001707621500827372 | MSE:0.0001804 | SNR:129.8125695 | FWHM:0.0764546 | Shape Score:0.9996652] 


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:48<00:00,  2.79s/it, desc=[epoch: 20], iteration: 167/168, loss: 0.000962719668418036] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.76it/s, desc=[Epoch 20] Loss: 0.0016310039209201932 | MSE:0.0001539 | SNR:104.3622648 | FWHM:0.0767732 | Shape Score:0.9996591]


Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:49<00:00,  2.79s/it, desc=[epoch: 21], iteration: 167/168, loss: 0.0009252577732522262]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.70it/s, desc=[Epoch 21] Loss: 0.0017582259606570005 | MSE:0.0002132 | SNR:124.1670175 | FWHM:0.0770917 | Shape Score:0.9996129]


Saving current model with score: 0.2002479894956464
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:50<00:00,  2.80s/it, desc=[epoch: 22], iteration: 167/168, loss: 0.0005921933739695565]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.77it/s, desc=[Epoch 22] Loss: 0.001503751496784389 | MSE:0.0001344 | SNR:228.0499068 | FWHM:0.0764546 | Shape Score:0.9997082] 


Current learning rate: 9.779754323328192e-06
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:52<00:00,  2.81s/it, desc=[epoch: 23], iteration: 167/168, loss: 0.0005552757684199605]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.84it/s, desc=[Epoch 23] Loss: 0.0015997396549209952 | MSE:0.0001583 | SNR:247.7476226 | FWHM:0.0764546 | Shape Score:0.9996748]


Current learning rate: 9.140576474687263e-06
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:50<00:00,  2.80s/it, desc=[epoch: 24], iteration: 167/168, loss: 0.0005385574816803759]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.83it/s, desc=[Epoch 24] Loss: 0.001525221741758287 | MSE:0.0001451 | SNR:231.3695546 | FWHM:0.0761361 | Shape Score:0.9996993] 


Current learning rate: 8.145033635316128e-06
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:52<00:00,  2.82s/it, desc=[epoch: 25], iteration: 167/168, loss: 0.0005237202298996548]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.78it/s, desc=[Epoch 25] Loss: 0.0017171984072774649 | MSE:0.0001673 | SNR:243.2953928 | FWHM:0.0764546 | Shape Score:0.9997338]


Current learning rate: 6.890576474687263e-06
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:50<00:00,  2.80s/it, desc=[epoch: 26], iteration: 167/168, loss: 0.0005203105805170102]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.76it/s, desc=[Epoch 26] Loss: 0.001721328473649919 | MSE:0.0001313 | SNR:243.3908711 | FWHM:0.0764546 | Shape Score:0.9997013] 


Current learning rate: 5.5e-06
Saving current model with score: 0.2001648162180429
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:51<00:00,  2.80s/it, desc=[epoch: 27], iteration: 167/168, loss: 0.0005057805529282806]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.74it/s, desc=[Epoch 27] Loss: 0.0016715137753635645 | MSE:0.0001448 | SNR:240.4372527 | FWHM:0.0764546 | Shape Score:0.9997027]


Current learning rate: 4.109423525312737e-06
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:49<00:00,  2.80s/it, desc=[epoch: 28], iteration: 167/168, loss: 0.0005030192320797747]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.84it/s, desc=[Epoch 28] Loss: 0.001740835839882493 | MSE:0.0001807 | SNR:240.6246098 | FWHM:0.0767732 | Shape Score:0.9996482] 


Current learning rate: 2.8549663646838717e-06
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:49<00:00,  2.79s/it, desc=[epoch: 29], iteration: 167/168, loss: 0.0004915804027058627]
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.65it/s, desc=[Epoch 29] Loss: 0.001420316519215703 | MSE:0.0001389 | SNR:247.8191147 | FWHM:0.0764546 | Shape Score:0.9996878] 


Current learning rate: 1.859423525312737e-06
Number of parameters: 90473472


Train Loop: 100%|██████████| 168/168 [07:43<00:00,  2.76s/it, desc=[epoch: 30], iteration: 167/168, loss: 0.0004816648120357145] 
Validation Loop: 100%|██████████| 4/4 [00:00<00:00,  5.78it/s, desc=[Epoch 30] Loss: 0.0016862963093444705 | MSE:0.0001537 | SNR:246.9064289 | FWHM:0.0764546 | Shape Score:0.9997241]

Current learning rate: 1.220245676671809e-06





In [11]:
np.savetxt(save_dir_path+'train_loss_list.txt', np.array(train_loss_list), delimiter='\n')
np.savetxt(save_dir_path+'val_loss_list.txt', np.array(val_loss_list), delimiter='\n')
np.savetxt(save_dir_path+'val_mse_list.txt', np.array(val_mean_mse_list), delimiter='\n')
np.savetxt(save_dir_path+'val_snr_list.txt', np.array(val_mean_snr_list), delimiter='\n')
np.savetxt(save_dir_path+'val_linewidth_list.txt', np.array(val_mean_linewidth_list), delimiter='\n')
np.savetxt(save_dir_path+'val_mean_shape_score_list.txt', np.array(val_mean_shape_score_list), delimiter='\n')
np.savetxt(save_dir_path+'score_challenge_list.txt', np.array(score_challenge_list), delimiter='\n')