# Knowledge distillation

## 1) UC1: Anomaly Detection


In [1]:
import os
# Run the notebook from the repository root: later cells write result files to
# the relative 'Results/' directory and read them back through the absolute
# '/home/benfenati/code/tle-supervised/Results/' path, so the two must coincide.
# NOTE(review): machine-specific absolute path -- adjust when running elsewhere.
os.chdir('/home/benfenati/code/tle-supervised/')

In [2]:
import torch
import pandas as pd
import numpy as np

from Algorithms.models_audio_mae_evaluate import audioMae_vit_base_evaluate

import torch.nn as nn
import torch.optim as optim

from utils import *
from util.engine_pretrain import evaluate
from plot_anomaly import compute_threshold_accuracy

import datetime

from Datasets.AnomalyDetection_SS335.get_dataset import get_dataset as get_dataset_ss335

In [3]:
# UC1 configuration shared by the cells of this section.
device = "cuda:3"  # GPU on which teacher, student and evaluation run
dir = "/home/benfenati/code/Datasets/SHM/AnomalyDetection_SS335/"  # dataset root; NOTE(review): shadows the builtin dir()
window_size = 1190  # passed to get_dataset_ss335 as windowLength and embedded in result file names
lr = 0.25e-2  # NOTE(review): the UC1 training cell passes lr=0.001 to Adam directly, so this value is not used there
total_epochs = 10  # number of epochs of the UC1 distillation loop

### Teacher model

In [4]:
# Build the UC1 teacher and restore its weights from the uc1 checkpoint.
device = torch.device(device)  # convert the "cuda:N" string from the config cell
teacher = audioMae_vit_base_evaluate(norm_pix_loss=False)
teacher.to(device)
# checkpoint = torch.load(f"/home/benfenati/code/tle-supervised/Results/checkpoints/checkpoint-768-512-pretrain_all-200.pth", map_location='cpu')
checkpoint = torch.load(f"/home/benfenati/code/tle-supervised/Results/checkpoints/uc1/checkpoint--400.pth", map_location='cpu')
checkpoint_model = checkpoint['model']  # weights are stored under the 'model' key
# strict=False tolerates missing/unexpected keys instead of raising.
# NOTE(review): `msg` (the key-mismatch report) is never inspected -- confirm
# that the checkpoint actually matched the teacher's parameters.
msg = teacher.load_state_dict(checkpoint_model, strict=False)

# Report the parameter count and size of the teacher
# (get_model_info / millify presumably come from `from utils import *`).
params, size = get_model_info(teacher)
print("params={} | size={:.3f} MB".format(millify(params), size))

params=36 Million | size=137.772 MB


### Student model

In [None]:
# Build the UC1 student with the same constructor as the teacher but with the
# encoder/decoder widths chosen below. No checkpoint is loaded in this cell.
embed_dim = 384 # 96, 192, 384, 768(original)
decoder_embed_dim = 256 # 64, 128, 256, 512(original)
student = audioMae_vit_base_evaluate(embed_dim=embed_dim, decoder_embed_dim=decoder_embed_dim, norm_pix_loss=False)
student.to(device)

# Report the parameter count and size of the student, in the same format as the teacher.
params, size = get_model_info(student)
print("params={} | size={:.3f} MB".format(millify(params), size))

### Training

In [None]:
# UC1 distillation: the student is trained on its own loss plus an MSE term
# that pulls its prediction towards the frozen teacher's prediction.
starting_date = datetime.date(2019,5,22)
num_days = 7
print("Creating Training Dataset")
dataset = get_dataset_ss335(dir, starting_date, num_days, sensor = 'S6.1.3', time_frequency = "frequency", windowLength = window_size)
sampler_train = torch.utils.data.RandomSampler(dataset)
data_loader_train = torch.utils.data.DataLoader(
    dataset, sampler=sampler_train,
    batch_size=64,
    num_workers=1,
    # DataLoader expects a bool here; the previous value was the truthy string
    # 'store_true' (an argparse leftover) that only worked by accident.
    pin_memory=True,
    drop_last=True,
)
device = torch.device(device)
torch.manual_seed(0)
np.random.seed(0)

# NOTE(review): the learning rate is hard-coded here; the `lr` value from the
# config cell is not used by this optimizer.
optimizer = optim.Adam(student.parameters(), lr=0.001, weight_decay=1e-6)
loss_fn_1 = nn.L1Loss()
loss_fn_2 = nn.L1Loss()
loss_fn_3 = nn.MSELoss()

# The teacher is only queried, never trained: put it in eval mode once.
# Nothing in the loop below switches it back, so no per-step call is needed.
teacher.eval()

# Weight of the student's own loss; (1 - b) weights the teacher-matching term.
b = 0.5

best_loss = 100000000
best_epoch = 0

for epoch in range(total_epochs):

    student.train()
    train_loss = 0
    counter = 0
    for samples, targets in data_loader_train:
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss_student, pred_student, _ = student(samples, mask_ratio=0.8)

        # Enter BOTH context managers. The previous `no_grad() and autocast()`
        # evaluated to the autocast object alone, so no_grad was never entered
        # and the teacher forward pass built a full autograd graph.
        # NOTE(review): student and teacher are called separately with
        # mask_ratio=0.8; if masking is random, the two predictions come from
        # different masks -- confirm this is intended.
        with torch.no_grad(), torch.cuda.amp.autocast():
            loss_teacher, pred_teacher, _ = teacher(samples, mask_ratio=0.8)

        loss_1 = loss_student
        loss_2 = loss_fn_3(pred_student, pred_teacher)

        loss = b*loss_1 + (1-b)*loss_2

        loss.backward()
        optimizer.step()

        # Track only the student's own loss (not the combined objective).
        train_loss += loss_student.item()
        # train_loss += loss.item()
        counter +=1

### Testing

In [5]:
# Evaluate one model on a normal and an anomalous period, store the losses
# returned by `evaluate` as CSV, then compute detection metrics for several
# values of `dim_filtering`.
model_to_evaluate = teacher
who = "768-512"  # tag embedded in the result file names

# Single results location for BOTH writing and reading the CSV files. The
# files used to be written to the relative 'Results/' directory and read back
# from this absolute path, which only matched while the working directory was
# the repository root.
directory = "/home/benfenati/code/tle-supervised/Results/"

### Creating Testing Dataset for Normal Data
starting_date = datetime.date(2019,5,10)
num_days = 4
print("Creating Testing Dataset -- Normal")
dataset = get_dataset_ss335(dir, starting_date, num_days, sensor = 'S6.1.3', time_frequency = "frequency", windowLength = window_size)
data_loader_test_normal = torch.utils.data.DataLoader(
    dataset, shuffle=False,
    batch_size=1,
    num_workers=1,
    pin_memory=True,  # bool instead of the truthy string 'store_true'
    drop_last=True,
)
losses_normal, _ = evaluate(data_loader_test_normal, model_to_evaluate, device)
df = pd.DataFrame.from_dict(losses_normal)
df.to_csv(directory + f'masked_{window_size}samples_normal_{who}.csv', index = False, header = True)

### Creating Testing Dataset for Anomaly Data
starting_date = datetime.date(2019,4,17)
num_days = 4
print("Creating Testing Dataset -- Anomaly")
dataset = get_dataset_ss335(dir, starting_date, num_days, sensor = 'S6.1.3', time_frequency = "frequency", windowLength = window_size)
data_loader_test_anomaly = torch.utils.data.DataLoader(
    dataset, shuffle=False,
    batch_size=1,
    num_workers=1,
    pin_memory=True,  # bool instead of the truthy string 'store_true'
    drop_last=True,
)
losses_anomaly, _ = evaluate(data_loader_test_anomaly, model_to_evaluate, device)
df = pd.DataFrame.from_dict(losses_anomaly)
df.to_csv(directory + f'masked_{window_size}samples_anomaly_{who}.csv', index = False, header = True)

# Metrics (in percent), one entry per dim_filtering value.
acc_enc = []
sens_enc = []
spec_enc = []

for dim_filtering in [15,30,60,120, 240]:
    print(f"Dim {dim_filtering}")
    print("Autoencoder")
    # The CSVs are re-read on every iteration so each call receives fresh arrays.
    data_normal = pd.read_csv(directory + f"masked_{window_size}samples_normal_{who}.csv")
    data_anomaly = pd.read_csv(directory + f"masked_{window_size}samples_anomaly_{who}.csv")
    # NOTE(review): `min` and `max` here are the Python builtins, passed
    # positionally -- confirm that compute_threshold_accuracy expects them.
    spec, sens, acc = compute_threshold_accuracy(data_anomaly.values, data_normal.values, None, min, max, only_acc = 1, dim_filtering = dim_filtering)
    acc_enc.append(acc*100)
    sens_enc.append(sens*100)
    spec_enc.append(spec*100)

Creating Testing Dataset -- Normal
Loading AnomalyDetection dataset Test:  [   0/5279]  eta: 0:36:56  loss: 0.0001 (0.0001)  mae1: 0.0001 (0.0001)  time: 0.4198  data: 0.1256  max mem: 226
Test:  [  10/5279]  eta: 0:03:41  loss: 0.0001 (0.0001)  mae1: 0.0001 (0.0001)  time: 0.0420  data: 0.0115  max mem: 226
Test:  [  20/5279]  eta: 0:02:06  loss: 0.0001 (0.0001)  mae1: 0.0001 (0.0001)  time: 0.0043  data: 0.0001  max mem: 226
Test:  [  30/5279]  eta: 0:01:32  loss: 0.0001 (0.0001)  mae1: 0.0001 (0.0001)  time: 0.0041  data: 0.0001  max mem: 226
Test:  [  40/5279]  eta: 0:01:14  loss: 0.0001 (0.0001)  mae1: 0.0001 (0.0001)  time: 0.0038  data: 0.0001  max mem: 226
Test:  [  50/5279]  eta: 0:01:03  loss: 0.0001 (0.0001)  mae1: 0.0001 (0.0001)  time: 0.0039  data: 0.0001  max mem: 226
Test:  [  60/5279]  eta: 0:00:56  loss: 0.0001 (0.0001)  mae1: 0.0001 (0.0001)  time: 0.0038  data: 0.0001  max mem: 226
Test:  [  70/5279]  eta: 0:00:51  loss: 0.0001 (0.0001)  mae1: 0.0001 (0.0001)  time:

In [5]:
# PCA baseline: compute the same detection metrics from the precomputed
# PCA result files, for each value of dim_filtering.
directory = "/home/benfenati/code/tle-supervised/Results/"
window_size = 1190
acc_enc, sens_enc, spec_enc = [], [], []

for dim_filtering in (15, 30, 60, 120, 240):
    print(f"Dim {dim_filtering}")
    print("PCA")
    normal_csv = directory + f"PCA_{window_size}samples_normal.csv"
    anomaly_csv = directory + f"PCA_{window_size}samples_anomaly.csv"
    data_normal = pd.read_csv(normal_csv)
    data_anomaly = pd.read_csv(anomaly_csv)
    spec, sens, acc = compute_threshold_accuracy(
        data_anomaly.values,
        data_normal.values,
        None,
        min,
        max,
        only_acc = 1,
        dim_filtering = dim_filtering,
    )
    # Store the metrics as percentages.
    acc_enc.append(acc*100)
    sens_enc.append(sens*100)
    spec_enc.append(spec*100)

Dim 15
PCA
Sensitivity: 0.011723141159466838 Spcificity: 0.9996211403674938 Accuracy: 0.46497479575873457
Dim 30
PCA
Sensitivity: 0.052995021679781595 Spcificity: 1.0 Accuracy: 0.48748479054406396
Dim 60
PCA
Sensitivity: 0.23847759755901718 Spcificity: 1.0 Accuracy: 0.5878671997218843
Dim 120
PCA
Sensitivity: 0.8644612172795889 Spcificity: 0.9996211403674938 Accuracy: 0.9264731444463759
Dim 240
PCA
Sensitivity: 0.9966275895294684 Spcificity: 0.9969691229399508 Accuracy: 0.9967842864592387


## 2) UC2: TLE on Roccaprebalza

In [8]:
import os
import torch
import numpy as np
import pandas as pd

# from Algorithms.models_audio_mae_regression import audioMae_vit_base_R
from Algorithms.models_audio_mae_regression_modified import audioMae_vit_base_R

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from time import time
from utils import *
import util.misc as misc
from util.misc import interpolate_pos_embed

import datetime
from Datasets.Vehicles_Roccaprebalza.get_dataset import get_dataset as get_dataset_roccaprebalza

from util.engine_pretrain import evaluate_finetune
from vehicles_roccaprebalza_example import compute_accuracy

### Params

In [11]:
# UC2 configuration shared by the cells of this section.
device = "cuda:0"  # GPU on which teacher, student and evaluation run
# Selects the target passed to get_dataset_roccaprebalza(car=...); it is also
# embedded in the teacher checkpoint name and in the result CSV name.
car = "y_car" # y_car, y_camion
dir = "/home/benfenati/code/Datasets/SHM/Vehicles_Roccaprebalza/"  # dataset root; NOTE(review): shadows the builtin dir()

lr = 0.25e-5  # learning rate given to Adam in the UC2 training cell
total_epochs = 300  # number of epochs of the UC2 distillation loop

### Teacher model

In [12]:
# Build the UC2 regression teacher and restore its fine-tuned weights; the
# checkpoint file is selected by `car`.
teacher = audioMae_vit_base_R(norm_pix_loss=True, mask_ratio = 0.2)
teacher.to(device)
checkpoint = torch.load(f"/home/benfenati/code/tle-supervised/Results/checkpoints/uc2/checkpoint-768-512-{car}_roccaprebalza_finetune-500.pth", map_location='cpu')
checkpoint_model = checkpoint['model']  # weights are stored under the 'model' key
state_dict = teacher.state_dict()  # NOTE(review): not used in this cell
# strict=True: any missing or unexpected key raises, so a successful load
# means the checkpoint keys matched the teacher exactly.
msg = teacher.load_state_dict(checkpoint_model, strict=True)

# Report the parameter count and size of the teacher.
params, size = get_model_info(teacher)
print("params={} | size={:.3f} MB".format(millify(params), size))

params=26 Million | size=97.750 MB


### Student model

In [None]:
# Build the UC2 student and initialise it from the pretrained student checkpoint.
embed_dim = 96 # 96, 192, 384, 768 (original)
decoder_embed_dim = 512 # 256, 512 (original)
student = audioMae_vit_base_R(embed_dim=embed_dim, decoder_embed_dim=decoder_embed_dim, 
                              norm_pix_loss=True, mask_ratio = 0.2)
student.to(device)
# NOTE(review): the same checkpoint file is loaded in the UC3 section with
# embed_dim=384; load_state_dict raises on tensor size mismatches even with
# strict=False -- confirm the checkpoint matches the width chosen above.
checkpoint = torch.load(f"/home/benfenati/code/tle-supervised/Results/checkpoints/checkpoint-student372-pretrain_all-200.pth", map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = student.state_dict()
# Drop the regression head from the checkpoint when its shape does not match
# the freshly built model, so the head keeps its new initialisation.
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]

# Adapt the checkpoint's positional embedding BEFORE loading it. It used to be
# called after load_state_dict, where it could no longer affect the weights
# already copied into the model.
# NOTE(review): assumes interpolate_pos_embed rewrites checkpoint_model in
# place, as in the MAE reference implementation -- confirm in util.misc.
interpolate_pos_embed(student, checkpoint_model)
msg = student.load_state_dict(checkpoint_model, strict=False)

# Report the parameter count and size of the student.
params, size = get_model_info(student)
print("params={} | size={:.3f} MB".format(millify(params), size))

### Training

In [None]:
# Train
# UC2 distillation: the student is fitted to the ground-truth targets, to the
# teacher's final output and to the teacher's intermediate output.
dataset_train, dataset_test = get_dataset_roccaprebalza(dir, window_sec_size = 60, shift_sec_size = 2, time_frequency = "frequency", car = car)
sampler_train = torch.utils.data.RandomSampler(dataset_train)
data_loader_train = torch.utils.data.DataLoader(
    dataset_train, sampler=sampler_train,
    batch_size=8,
    num_workers=1,
    # DataLoader expects a bool; the previous value was the truthy string
    # 'store_true' (an argparse leftover).
    pin_memory=True,
    drop_last=True)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, shuffle=False,
    batch_size=1,
    num_workers=1,
    pin_memory=True,
    drop_last=True,
    )

# Seed once (the original seeded twice in a row with the same values).
device = torch.device(device)
torch.manual_seed(0)
np.random.seed(0)

optimizer = optim.Adam(student.parameters(), lr=lr, weight_decay=1e-6)
loss_fn_1 = nn.L1Loss()   # student output vs. ground truth
loss_fn_2 = nn.L1Loss()   # student output vs. teacher output
loss_fn_3 = nn.MSELoss()  # student vs. teacher intermediate output

# The teacher is only queried, never trained: put it in eval mode once.
teacher.eval()

# b balances ground truth against the teacher's output; g balances that
# output-level term against the intermediate-output term.
b = 0.5
g = 0.6667

best_loss = 100000000
best_epoch = 0

for epoch in range(total_epochs):

    student.train()
    train_loss = 0
    counter = 0
    for samples, targets in data_loader_train:
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            middle_student, final_student = student(samples)

        # Enter BOTH context managers. The previous `no_grad() and autocast()`
        # evaluated to the autocast object alone, so no_grad was never entered
        # and gradients were tracked through the teacher.
        with torch.no_grad(), torch.cuda.amp.autocast():
            middle_teacher, final_teacher = teacher(samples)

        final_student = final_student.squeeze()
        final_teacher = final_teacher.squeeze()
        # print("student", middle_student.shape, final_student.shape)
        # print("teacher", middle_teacher.shape, final_teacher.shape)
        loss_1 = loss_fn_1(final_student, targets.float())
        loss_2 = loss_fn_2(final_student, final_teacher)
        loss_3 = loss_fn_3(middle_student, middle_teacher)

        loss = g*(b*loss_1 + (1-b)*loss_2) + (1-g)*loss_3
        # loss = b*loss_1 + (1-b)*loss_2

        loss.backward()
        optimizer.step()

        # Track the supervised L1 term. Reuse loss_1 rather than recomputing
        # it on the uncast targets, which built a second, discarded graph.
        train_loss += loss_1.item()
        counter +=1

### Testing

In [13]:
# Evaluate the selected model on the UC2 test split and store the
# predictions next to the ground truth.
model_to_evaluate = teacher

dataset_train, dataset_test = get_dataset_roccaprebalza(dir, window_sec_size = 60, shift_sec_size = 2, time_frequency = "frequency", car = car)
# The test loader iterates in order (shuffle=False). The RandomSampler that
# was created here before was never passed to the loader, so it is removed.
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, shuffle=False,
    batch_size=1,
    num_workers=1,
    pin_memory=True,  # bool instead of the truthy string 'store_true'
    drop_last=True,
)

y_predicted, y_test = evaluate_finetune(data_loader_test, model_to_evaluate, device)
compute_accuracy(y_test, y_predicted)

# Relative path: resolved against the current working directory.
df = pd.DataFrame({"Y_true": y_test, "Y_predicted": y_predicted})
df.to_csv(f'Results/uc2_results/Roccaprebalza_autoencoder_{car}.csv', index = False, header = True)

Loading Roccaprebalza dataset
Loading Roccaprebalza dataset


  new_labels = pd.concat([new_labels, pd.DataFrame.from_dict(dict)])
  new_labels = pd.concat([new_labels, pd.DataFrame.from_dict(dict)])
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.l1_loss(input, target, reduction=self.reduction)


Test:  [  0/203]  eta: 0:00:15  loss: 0.6460 (0.6460)  mae1: 0.8037 (0.8037)  time: 0.0777  data: 0.0634  max mem: 210
Test:  [ 10/203]  eta: 0:00:03  loss: 0.2699 (0.7279)  mae1: 0.5195 (0.6929)  time: 0.0187  data: 0.0058  max mem: 210
Test:  [ 20/203]  eta: 0:00:02  loss: 0.3479 (0.9286)  mae1: 0.5898 (0.7802)  time: 0.0125  data: 0.0001  max mem: 210
Test:  [ 30/203]  eta: 0:00:02  loss: 0.5684 (0.8692)  mae1: 0.7539 (0.7685)  time: 0.0124  data: 0.0001  max mem: 210
Test:  [ 40/203]  eta: 0:00:02  loss: 0.2927 (0.9799)  mae1: 0.5410 (0.8007)  time: 0.0127  data: 0.0001  max mem: 210
Test:  [ 50/203]  eta: 0:00:02  loss: 0.8826 (1.0583)  mae1: 0.9395 (0.8648)  time: 0.0124  data: 0.0001  max mem: 210
Test:  [ 60/203]  eta: 0:00:01  loss: 0.7656 (1.1106)  mae1: 0.8750 (0.8642)  time: 0.0121  data: 0.0001  max mem: 210
Test:  [ 70/203]  eta: 0:00:01  loss: 0.1682 (1.0148)  mae1: 0.4102 (0.8180)  time: 0.0126  data: 0.0001  max mem: 210
Test:  [ 80/203]  eta: 0:00:01  loss: 0.7932 (1.

## 3) UC3: TLE on Sacertis

In [1]:
import os
import torch
import numpy as np

# from Algorithms.models_audio_mae_regression import audioMae_vit_base_R
from Algorithms.models_audio_mae_regression_modified import audioMae_vit_base_R

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from time import time
from utils import *
import util.misc as misc
from util.misc import interpolate_pos_embed

import datetime
from Datasets.Vehicles_Sacertis.get_dataset import get_dataset as get_dataset_sacertis

from util.engine_pretrain import evaluate_finetune
from vehicles_roccaprebalza_example import compute_accuracy

  from .autonotebook import tqdm as notebook_tqdm


### Params

In [2]:
# UC3 configuration shared by the cells of this section.
device = "cuda:0"  # GPU on which teacher, student and evaluation run
dir = "/home/benfenati/code/Datasets/SHM/Vehicles_Sacertis/"  # dataset root; NOTE(review): shadows the builtin dir()

lr = 0.25e-3  # learning rate given to Adam in the UC3 training cell
total_epochs = 10  # number of epochs of the UC3 distillation loop

### Teacher model

In [3]:
# Build the UC3 regression teacher and restore its fine-tuned weights.
teacher = audioMae_vit_base_R(norm_pix_loss=True, mask_ratio = 0.2)
teacher.to(device)
checkpoint = torch.load(f"/home/benfenati/code/tle-supervised/Results/checkpoints/uc3/checkpoint-pretrainig_all_vehicles_sacertis_finetune-200.pth", map_location='cpu')
checkpoint_model = checkpoint['model']  # weights are stored under the 'model' key
state_dict = teacher.state_dict()  # NOTE(review): not used in this cell
# strict=True: any missing or unexpected key raises, so a successful load
# means the checkpoint keys matched the teacher exactly.
msg = teacher.load_state_dict(checkpoint_model, strict=True)

# Report the raw parameter count and size of the teacher.
params, size = get_model_info(teacher)
print("N. params = {}; Size = {:.3f}".format(params, size))

N. params = 25624626; Size = 97.750


### Student model

In [4]:
# Build the UC3 student and initialise it from the pretrained student checkpoint.
embed_dim = 384 # 768 (original)
decoder_embed_dim = 512 # 256, 512 (original)
student = audioMae_vit_base_R(embed_dim=embed_dim, decoder_embed_dim=decoder_embed_dim, 
                              norm_pix_loss=True, mask_ratio = 0.2)
student.to(device)
checkpoint = torch.load(f"/home/benfenati/code/tle-supervised/Results/checkpoints/checkpoint-student372-pretrain_all-200.pth", map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = student.state_dict()
# Drop the regression head from the checkpoint when its shape does not match
# the freshly built model, so the head keeps its new initialisation.
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]

# Adapt the checkpoint's positional embedding BEFORE loading it. It used to be
# called after load_state_dict, where it could no longer affect the weights
# already copied into the model.
# NOTE(review): assumes interpolate_pos_embed rewrites checkpoint_model in
# place, as in the MAE reference implementation -- confirm in util.misc.
interpolate_pos_embed(student, checkpoint_model)
msg = student.load_state_dict(checkpoint_model, strict=False)

# Report the raw parameter count and size of the student.
params, size = get_model_info(student)
print("N. params = {}; Size = {:.3f}".format(params, size))

N. params = 9326130; Size = 35.576


### Training

In [5]:
# Train
# UC3 distillation: the student is fitted to the ground-truth targets, to the
# teacher's final output and to the teacher's intermediate output.
dataset_train = get_dataset_sacertis(dir, False, True, False,  sensor = "None", time_frequency = "frequency")
sampler_train = torch.utils.data.RandomSampler(dataset_train)
data_loader_train = torch.utils.data.DataLoader(
    dataset_train, sampler=sampler_train,
    batch_size=128,
    num_workers=1,
    # DataLoader expects a bool; the previous value was the truthy string
    # 'store_true' (an argparse leftover).
    pin_memory=True,
    drop_last=True)

print("\nDone!")

device = torch.device(device)
torch.manual_seed(0)
np.random.seed(0)

optimizer = optim.Adam(student.parameters(), lr=lr, weight_decay=1e-6)
loss_fn_1 = nn.L1Loss()   # student output vs. ground truth
loss_fn_2 = nn.L1Loss()   # student output vs. teacher output
loss_fn_3 = nn.MSELoss()  # student vs. teacher intermediate output

# The teacher is only queried, never trained: put it in eval mode once.
teacher.eval()

# b balances ground truth against the teacher's output; g balances that
# output-level term against the intermediate-output term.
b = 0.5
g = 0.6667

best_loss = 100000000
best_epoch = 0

# len() of a DataLoader is the number of batches, not of samples; the label
# previously read "Number of samples in the dataset".
print(f"Number of batches per epoch: {len(data_loader_train)}")
print(f"Size of the first sample in the dataset: {dataset_train[0][0].size()}")

for epoch in range(total_epochs):

    student.train()
    train_loss = 0
    counter = 0

    for samples, targets in data_loader_train:
        # Running average of the supervised loss, every 10 batches.
        if counter % 10 == 0 and counter != 0: 
            print(f"Epoch {epoch} - Loss {train_loss/counter}")
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            middle_student, final_student = student(samples)

        # Enter BOTH context managers. The previous `no_grad() and autocast()`
        # evaluated to the autocast object alone, so no_grad was never entered
        # and gradients were tracked through the teacher.
        with torch.no_grad(), torch.cuda.amp.autocast():
            middle_teacher, final_teacher = teacher(samples)

        final_student = final_student.squeeze()
        final_teacher = final_teacher.squeeze()
        loss_1 = loss_fn_1(final_student, targets.float())
        loss_2 = loss_fn_2(final_student, final_teacher)
        loss_3 = loss_fn_3(middle_student, middle_teacher)

        loss = g*(b*loss_1 + (1-b)*loss_2) + (1-g)*loss_3
        # loss = b*loss_1 + (1-b)*loss_2

        loss.backward()
        optimizer.step()

        # Track the supervised L1 term. Reuse loss_1 rather than recomputing
        # it on the uncast targets, which built a second, discarded graph.
        train_loss += loss_1.item()
        counter +=1

    # End-of-epoch summary. This check used to sit right after `counter = 0`,
    # where it could never be true, so the summary was never printed.
    if counter != 0:
        print(f"Epoch {epoch} - Loss {train_loss/counter}")

Loading Sacertis dataset 
Done!
Number of samples in the dataset: 1562
Size of the first sample in the dataset: torch.Size([1, 100, 100])
Epoch 0 - Loss 3.44638671875
Epoch 0 - Loss 2.4655517578125
Epoch 0 - Loss 2.0519368489583334
Epoch 0 - Loss 1.80025634765625
Epoch 0 - Loss 1.650927734375
Epoch 0 - Loss 1.5622802734375
Epoch 0 - Loss 1.4873116629464285
Epoch 0 - Loss 1.432537841796875
Epoch 0 - Loss 1.3876898871527779
Epoch 0 - Loss 1.3570166015625
Epoch 0 - Loss 1.3272238991477272
Epoch 0 - Loss 1.30159912109375
Epoch 0 - Loss 1.2871056189903847
Epoch 0 - Loss 1.2736363002232143
Epoch 0 - Loss 1.260966796875
Epoch 0 - Loss 1.246038818359375
Epoch 0 - Loss 1.2337230009191176
Epoch 0 - Loss 1.2248155381944446
Epoch 0 - Loss 1.2150159333881578
Epoch 0 - Loss 1.20544921875
Epoch 0 - Loss 1.1962053571428573
Epoch 0 - Loss 1.1879283558238636
Epoch 0 - Loss 1.1780082370923912
Epoch 0 - Loss 1.17017822265625
Epoch 0 - Loss 1.16494140625
Epoch 0 - Loss 1.160516826923077
Epoch 0 - Loss 1.15

### Testing

In [6]:
# Build the UC3 test set and an in-order, one-sample-per-batch loader.
dataset = get_dataset_sacertis(dir, False, False, True,  sensor = "None", time_frequency = "frequency")
data_loader_test = torch.utils.data.DataLoader(
    dataset, shuffle=False,
    batch_size=1,
    num_workers=1,
    # DataLoader expects a bool; the previous value was the truthy string
    # 'store_true' (an argparse leftover).
    pin_memory=True,
    drop_last=True,
)

Loading Sacertis dataset 

In [7]:
# Evaluate the distilled student on the UC3 test loader.
# NOTE(review): the training cell leaves the student in train() mode --
# confirm that evaluate_finetune switches the model to eval().
model_to_evaluate = student

y_predicted, y_test = evaluate_finetune(data_loader_test, model_to_evaluate, device)
compute_accuracy(y_test, y_predicted)

  return F.mse_loss(input, target, reduction=self.reduction)
  return F.l1_loss(input, target, reduction=self.reduction)


Test:  [    0/50000]  eta: 5:40:18  loss: 0.0005 (0.0005)  mae1: 0.0234 (0.0234)  time: 0.4084  data: 0.3692  max mem: 13952
Test:  [   10/50000]  eta: 0:33:16  loss: 0.0005 (1.0660)  mae1: 0.0234 (0.5520)  time: 0.0399  data: 0.0336  max mem: 13952
Test:  [   20/50000]  eta: 0:18:31  loss: 0.9442 (1.1571)  mae1: 0.9717 (0.7128)  time: 0.0029  data: 0.0001  max mem: 13952
Test:  [   30/50000]  eta: 0:13:09  loss: 0.9537 (1.6352)  mae1: 0.9766 (0.8990)  time: 0.0025  data: 0.0001  max mem: 13952
Test:  [   40/50000]  eta: 0:10:33  loss: 1.0434 (1.7153)  mae1: 1.0215 (0.9713)  time: 0.0026  data: 0.0001  max mem: 13952
Test:  [   50/50000]  eta: 0:08:51  loss: 0.9556 (1.5099)  mae1: 0.9775 (0.9162)  time: 0.0027  data: 0.0001  max mem: 13952
Test:  [   60/50000]  eta: 0:07:44  loss: 0.9556 (1.7781)  mae1: 0.9775 (0.9949)  time: 0.0024  data: 0.0001  max mem: 13952
Test:  [   70/50000]  eta: 0:06:55  loss: 0.9537 (1.6794)  mae1: 0.9766 (0.9541)  time: 0.0024  data: 0.0001  max mem: 13952


In [None]:
# Persist the distilled student in the same layout used by every checkpoint
# loaded in this notebook: a dict whose 'model' entry holds the state_dict.
# The whole nn.Module used to be pickled instead, which ties the file to the
# exact class/module paths and cannot be read back via checkpoint['model'].
# NOTE(review): a consumer that expected the pickled module must now rebuild
# the model and call load_state_dict(checkpoint['model']).
student_path = "/home/benfenati/code/tle-supervised/Results/checkpoints/checkpoint-student-finetune-sacertis.pth"
torch.save({'model': student.state_dict()}, student_path)