# Knowledge distillation

## 1) UC1: Anomaly Detection


In [None]:
# Work from the repository root: the UC1 result CSVs are written with paths
# relative to it ('Results/...'), and presumably the project packages imported
# in the next cell (Algorithms, utils, Datasets, ...) are located from here too.
import os

REPO_ROOT = '/home/benfenati/code/tle-supervised/'
os.chdir(REPO_ROOT)

In [None]:
import torch
import pandas as pd
import numpy as np

from Algorithms.models_audio_mae_evaluate import audioMae_vit_base_evaluate

import torch.nn as nn
import torch.optim as optim

from utils import *
from util.engine_pretrain import evaluate
from plot_anomaly import compute_threshold_accuracy

import datetime

from Datasets.AnomalyDetection_SS335.get_dataset import get_dataset as get_dataset_ss335

In [None]:
# UC1 configuration (SS335 anomaly detection).
device = 'cuda:0'  # turned into a torch.device in the teacher cell
# Dataset root. Shadows the builtin `dir`, but later cells rely on this name.
dir = '/home/benfenati/code/Datasets/SHM/AnomalyDetection_SS335/'
window_size = 1190  # passed as windowLength to get_dataset_ss335
lr = 2.5e-3         # NOTE(review): the UC1 training cell passes lr=0.001 to Adam instead
total_epochs = 10

### Teacher model

In [None]:
# Teacher: full-width AudioMAE ViT-Base (evaluation variant) restored from a
# pretraining checkpoint. It is never passed to the optimizer, so its weights
# stay fixed during distillation.
device = torch.device(device)
teacher = audioMae_vit_base_evaluate(norm_pix_loss=False)
teacher.to(device)
# Alternative checkpoint:
# checkpoint = torch.load(f"/home/benfenati/code/tle-supervised/checkpoints/checkpoint-pretrain_all-200.pth", map_location='cpu')
checkpoint = torch.load("/home/benfenati/code/tle-supervised/Results/checkpoints/checkpoint--400.pth", map_location='cpu')
checkpoint_model = checkpoint['model']
# strict=False silently skips missing/unexpected keys; print the report so a
# checkpoint that does not fit this architecture is noticed.
msg = teacher.load_state_dict(checkpoint_model, strict=False)
print(msg)

params, size = get_model_info(teacher)
print("params={} | size={:.3f} MB".format(millify(params), size))

### Student model

In [None]:
# Student: the same AudioMAE evaluation architecture with narrower widths.
# No checkpoint is loaded, so it starts from the constructor's initialisation.
# Widths tried: embed_dim 96/192/384/768 (768 = original),
#               decoder_embed_dim 64/128/256/512 (512 = original).
embed_dim, decoder_embed_dim = 384, 256
student = audioMae_vit_base_evaluate(
    embed_dim=embed_dim,
    decoder_embed_dim=decoder_embed_dim,
    norm_pix_loss=False,
)
student.to(device)

params, size = get_model_info(student)
print("params={} | size={:.3f} MB".format(millify(params), size))

### Training

In [None]:
# Distil the teacher into the student on 7 days of SS335 data (sensor S6.1.3).
# Per batch: loss = b * (student's own reconstruction loss)
#                 + (1 - b) * MSE(student prediction, teacher prediction).
starting_date = datetime.date(2019, 5, 22)
num_days = 7
print("Creating Training Dataset")
dataset = get_dataset_ss335(dir, starting_date, num_days, sensor = 'S6.1.3', time_frequency = "frequency", windowLength = window_size)
sampler_train = torch.utils.data.RandomSampler(dataset)
data_loader_train = torch.utils.data.DataLoader(
    dataset, sampler=sampler_train,
    batch_size=64,
    num_workers=1,
    pin_memory=True,  # a bool; the string 'store_true' only worked by being truthy
    drop_last=True,
)
device = torch.device(device)
torch.manual_seed(0)
np.random.seed(0)

# NOTE(review): the `lr` set in the params cell (0.25e-2) is not used here;
# 0.001 is kept as in the original run -- confirm which value is intended.
optimizer = optim.Adam(student.parameters(), lr=0.001, weight_decay=1e-6)
loss_fn_1 = nn.L1Loss()
loss_fn_2 = nn.L1Loss()
loss_fn_3 = nn.MSELoss()
# Backward through an fp16 autocast forward can underflow small gradients;
# GradScaler scales the loss and skips steps whose gradients are inf/NaN.
scaler = torch.cuda.amp.GradScaler()

# The teacher is only ever run in inference mode below.
teacher.eval()

b = 0.5  # weight of the student's own reconstruction loss

best_loss = 100000000
best_epoch = 0

for epoch in range(total_epochs):

    student.train()
    train_loss = 0
    counter = 0
    # Targets are not needed for the reconstruction objective.
    for samples, _targets in data_loader_train:
        samples = samples.to(device, non_blocking=True)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss_student, pred_student, _ = student(samples, mask_ratio=0.8)

        # Both context managers must be listed with a comma: `a and b` evaluates
        # to `b` alone, which left no_grad inactive and made the teacher build an
        # autograd graph (and accumulate .grad) on every step.
        # NOTE(review): if masking is random, teacher and student mask different
        # patches here -- confirm the predictions are meant to be compared as-is.
        with torch.no_grad(), torch.cuda.amp.autocast():
            loss_teacher, pred_teacher, _ = teacher(samples, mask_ratio=0.8)

        loss_1 = loss_student
        loss_2 = loss_fn_3(pred_student, pred_teacher)

        loss = b*loss_1 + (1-b)*loss_2

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Track only the student's reconstruction loss, not the combined loss.
        train_loss += loss_student.item()
        counter += 1

    epoch_loss = train_loss / max(counter, 1)
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        best_epoch = epoch
    print(f"Epoch {epoch} - Loss {epoch_loss}")

### Testing

In [None]:
# Evaluate the model on a normal period and an anomalous period of SS335 data.
# Save the per-window losses to CSV, then sweep the filtering window used by
# compute_threshold_accuracy.
model_to_evaluate = student
who = "student"
# The training cell leaves the student in train mode; switch to inference mode
# explicitly (a no-op if `evaluate` already does this).
model_to_evaluate.eval()


def _ss335_test_loader(split_start, split_days):
    """Return an ordered, batch-size-1 DataLoader over `split_days` days of
    SS335 sensor S6.1.3 data starting at `split_start`."""
    split_dataset = get_dataset_ss335(dir, split_start, split_days, sensor = 'S6.1.3', time_frequency = "frequency", windowLength = window_size)
    return torch.utils.data.DataLoader(
        split_dataset, shuffle=False,
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        drop_last=True,
    )


### Creating Testing Dataset for Normal Data
starting_date = datetime.date(2019, 5, 10)
num_days = 4
print("Creating Testing Dataset -- Normal")
data_loader_test_normal = _ss335_test_loader(starting_date, num_days)
losses_normal, _ = evaluate(data_loader_test_normal, model_to_evaluate, device)
df = pd.DataFrame.from_dict(losses_normal)
df.to_csv(f'Results/masked_{window_size}samples_normal_{who}.csv', index = False, header = True)

### Creating Testing Dataset for Anomaly Data
starting_date = datetime.date(2019, 4, 17)
num_days = 4
print("Creating Testing Dataset -- Anomaly")
data_loader_test_anomaly = _ss335_test_loader(starting_date, num_days)
losses_anomaly, _ = evaluate(data_loader_test_anomaly, model_to_evaluate, device)
df = pd.DataFrame.from_dict(losses_anomaly)
df.to_csv(f'Results/masked_{window_size}samples_anomaly_{who}.csv', index = False, header = True)

directory = "/home/benfenati/code/tle-supervised/Results/"
acc_enc = []
sens_enc = []
spec_enc = []

# The saved losses do not depend on the filtering window, so read them once.
data_normal = pd.read_csv(directory + f"masked_{window_size}samples_normal_{who}.csv")
data_anomaly = pd.read_csv(directory + f"masked_{window_size}samples_anomaly_{who}.csv")

for dim_filtering in [15, 30, 60, 120, 240]:
    print(f"Dim {dim_filtering}")
    print("Autoencoder")
    # Pass fresh copies each iteration, as re-reading the CSVs did, so any
    # in-place processing inside compute_threshold_accuracy cannot leak
    # between window sizes.
    # NOTE(review): `min`/`max` are whatever those names are bound to here --
    # the Python builtins unless `from utils import *` rebinds them. Confirm
    # that is what compute_threshold_accuracy expects.
    spec, sens, acc = compute_threshold_accuracy(
        data_anomaly.to_numpy(copy=True), data_normal.to_numpy(copy=True),
        None, min, max, only_acc = 1, dim_filtering = dim_filtering)
    acc_enc.append(acc*100)
    sens_enc.append(sens*100)
    spec_enc.append(spec*100)
    spec_enc.append(spec*100)

## 2) UC2: TLE on Roccaprebalza

In [None]:
import os
import torch
import numpy as np

# from Algorithms.models_audio_mae_regression import audioMae_vit_base_R
from Algorithms.models_audio_mae_regression_modified import audioMae_vit_base_R

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from time import time
from utils import *
import util.misc as misc
from util.misc import interpolate_pos_embed

import datetime
from Datasets.Vehicles_Roccaprebalza.get_dataset import get_dataset as get_dataset_roccaprebalza

from util.engine_pretrain import evaluate_finetune
from vehicles_roccaprebalza_example import compute_accuracy

### Params

In [None]:
# UC2 configuration (Roccaprebalza vehicles dataset).
device = 'cuda:0'
car = 'y_camion'  # passed as car=... to the dataset and names the teacher checkpoint; 'y_car' or 'y_camion'
# Dataset root. Shadows the builtin `dir`, but later cells rely on this name.
dir = '/home/benfenati/code/Datasets/SHM/Vehicles_Roccaprebalza/'

lr = 2.5e-6
total_epochs = 300

### Teacher model

In [None]:
# Teacher: full-size regression AudioMAE restored from the fine-tuned
# checkpoint for the selected target `car`.
teacher = audioMae_vit_base_R(norm_pix_loss=True, mask_ratio=0.2)
teacher.to(device)

teacher_ckpt_path = f"/home/benfenati/code/tle-supervised/Results/checkpoints/checkpoint-pretrainig_all_{car}_roccaprebalza_finetune-500.pth"
checkpoint = torch.load(teacher_ckpt_path, map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = teacher.state_dict()
# strict=True: every key must match, so an incompatible checkpoint raises here.
msg = teacher.load_state_dict(checkpoint_model, strict=True)

params, size = get_model_info(teacher)
print("params={} | size={:.3f} MB".format(millify(params), size))

### Student model

In [None]:
# Student: narrower regression AudioMAE, initialised from a student checkpoint
# ("student372-pretrain_all").
embed_dim = 96 # 96, 192, 384, 768 (original)
decoder_embed_dim = 512 # 256, 512 (original)
student = audioMae_vit_base_R(embed_dim=embed_dim, decoder_embed_dim=decoder_embed_dim, 
                              norm_pix_loss=True, mask_ratio = 0.2)
student.to(device)
checkpoint = torch.load("/home/benfenati/code/tle-supervised/Results/checkpoints/checkpoint-student372-pretrain_all-200.pth", map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = student.state_dict()
# Drop checkpoint head weights whose shape differs from this model's head;
# with strict=False below, those parameters keep their fresh initialisation.
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]
# Adapt the checkpoint's positional embedding BEFORE loading. load_state_dict
# raises on tensor shape mismatches even with strict=False, and interpolating
# after the load can no longer affect the model. (Presumably, as in the MAE
# fine-tuning recipe, this rewrites checkpoint_model['pos_embed'] in place.)
interpolate_pos_embed(student, checkpoint_model)
msg = student.load_state_dict(checkpoint_model, strict=False)
print(msg)  # report keys skipped by strict=False

params, size = get_model_info(student)
print("params={} | size={:.3f} MB".format(millify(params), size))

### Training

In [None]:
# Train
dataset_train, dataset_test = get_dataset_roccaprebalza(dir, window_sec_size = 60, shift_sec_size = 2, time_frequency = "frequency", car = car)
sampler_train = torch.utils.data.RandomSampler(dataset_train)
data_loader_train = torch.utils.data.DataLoader(
    dataset_train, sampler=sampler_train,
    batch_size=8,
    num_workers=1,
    pin_memory='store_true',
    drop_last=True)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, shuffle=False,
    batch_size=1,
    num_workers=1,
    pin_memory='store_true',
    drop_last=True,
    )

torch.manual_seed(0)
np.random.seed(0)

device = torch.device(device)
torch.manual_seed(0)
np.random.seed(0)

optimizer = optim.Adam(student.parameters(), lr=lr, weight_decay=1e-6)
loss_fn_1 = nn.L1Loss()
loss_fn_2 = nn.L1Loss()
loss_fn_3 = nn.MSELoss()

teacher.eval()

b = 0.5
g = 0.6667

best_loss = 100000000
best_epoch = 0

for epoch in range(total_epochs):

    student.train()
    train_loss = 0
    counter = 0
    for samples, targets in data_loader_train:
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
    
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            middle_student, final_student = student(samples)

        with torch.no_grad() and torch.cuda.amp.autocast():
            teacher.eval()
            middle_teacher, final_teacher = teacher(samples)
        
        final_student = final_student.squeeze()
        final_teacher = final_teacher.squeeze()
        # print("student", middle_student.shape, final_student.shape)
        # print("teacher", middle_teacher.shape, final_teacher.shape)
        loss_1 = loss_fn_1(final_student, targets.float())
        loss_2 = loss_fn_2(final_student, final_teacher)


        loss_3 = loss_fn_3(middle_student, middle_teacher)
        
        loss = g*(b*loss_1 + (1-b)*loss_2) + (1-g)*loss_3
        # loss = b*loss_1 + (1-b)*loss_2

        loss.backward()
        optimizer.step()

        train_loss += loss_fn_1(final_student, targets).item()
        # train_loss += loss_1.item()
        counter +=1

### Testing

In [None]:
# Evaluate the distilled student on the Roccaprebalza test split.
model_to_evaluate = student
# The training cell leaves the student in train mode; switch to inference mode
# explicitly (a no-op if evaluate_finetune already does this).
model_to_evaluate.eval()

dataset_train, dataset_test = get_dataset_roccaprebalza(dir, window_sec_size = 60, shift_sec_size = 2, time_frequency = "frequency", car = car)
# Test windows are consumed in order, one at a time.
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, shuffle=False,
    batch_size=1,
    num_workers=1,
    pin_memory=True,
    drop_last=True,
)

y_predicted, y_test = evaluate_finetune(data_loader_test, model_to_evaluate, device)
compute_accuracy(y_test, y_predicted)

## 3) UC3: TLE on Sacertis

In [None]:
import os
import torch
import numpy as np

# from Algorithms.models_audio_mae_regression import audioMae_vit_base_R
from Algorithms.models_audio_mae_regression_modified import audioMae_vit_base_R

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from time import time
from utils import *
import util.misc as misc
from util.misc import interpolate_pos_embed

import datetime
from Datasets.Vehicles_Sacertis.get_dataset import get_dataset as get_dataset_sacertis

from util.engine_pretrain import evaluate_finetune
from vehicles_roccaprebalza_example import compute_accuracy

### Params

In [None]:
# UC3 configuration (Sacertis vehicles dataset).
device = 'cuda:0'
# Dataset root. Shadows the builtin `dir`, but later cells rely on this name.
dir = '/home/benfenati/code/Datasets/SHM/Vehicles_Sacertis/'

lr = 2.5e-4
total_epochs = 10

### Teacher model

In [None]:
# Teacher: full-size regression AudioMAE restored from the Sacertis
# fine-tuning checkpoint.
teacher = audioMae_vit_base_R(norm_pix_loss=True, mask_ratio=0.2)
teacher.to(device)

teacher_ckpt_path = f"/home/benfenati/code/tle-supervised/Results/checkpoints/checkpoint-pretrainig_all_vehicles_sacertis_finetune-200.pth"
checkpoint = torch.load(teacher_ckpt_path, map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = teacher.state_dict()
# strict=True: every key must match, so an incompatible checkpoint raises here.
msg = teacher.load_state_dict(checkpoint_model, strict=True)

params, size = get_model_info(teacher)
print("N. params = {}; Size = {:.3f}".format(params, size))

### Student model

In [None]:
# Student: narrower regression AudioMAE, initialised from a student checkpoint
# ("student-pretrain_all").
embed_dim = 384 # 768 (original)
decoder_embed_dim = 512 # 256, 512 (original)
student = audioMae_vit_base_R(embed_dim=embed_dim, decoder_embed_dim=decoder_embed_dim, 
                              norm_pix_loss=True, mask_ratio = 0.2)
student.to(device)
checkpoint = torch.load("/home/benfenati/code/tle-supervised/Results/checkpoints/checkpoint-student-pretrain_all-200.pth", map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = student.state_dict()
# Drop checkpoint head weights whose shape differs from this model's head;
# with strict=False below, those parameters keep their fresh initialisation.
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]
# Adapt the checkpoint's positional embedding BEFORE loading. load_state_dict
# raises on tensor shape mismatches even with strict=False, and interpolating
# after the load can no longer affect the model. (Presumably, as in the MAE
# fine-tuning recipe, this rewrites checkpoint_model['pos_embed'] in place.)
interpolate_pos_embed(student, checkpoint_model)
msg = student.load_state_dict(checkpoint_model, strict=False)
print(msg)  # report keys skipped by strict=False

params, size = get_model_info(student)
print("N. params = {}; Size = {:.3f}".format(params, size))

### Training

In [None]:
# Train
dataset_train = get_dataset_sacertis(dir, False, True, False,  sensor = "None", time_frequency = "frequency")
sampler_train = torch.utils.data.RandomSampler(dataset_train)
data_loader_train = torch.utils.data.DataLoader(
    dataset_train, sampler=sampler_train,
    batch_size=128,
    num_workers=1,
    pin_memory='store_true',
    drop_last=True)

print("\nDone!")

device = torch.device(device)
torch.manual_seed(0)
np.random.seed(0)

optimizer = optim.Adam(student.parameters(), lr=lr, weight_decay=1e-6)
loss_fn_1 = nn.L1Loss()
loss_fn_2 = nn.L1Loss()
loss_fn_3 = nn.MSELoss()

teacher.eval()

b = 0.5
g = 0.6667

best_loss = 100000000
best_epoch = 0

print(f"Number of samples in the dataset: {len(data_loader_train)}")
print(f"Size of the first sample in the dataset: {dataset_train[0][0].size()}")

for epoch in range(total_epochs):

    student.train()
    train_loss = 0
    counter = 0
    if counter != 0:
        print(f"Epoch {epoch} - Loss {train_loss/counter}")

    for samples, targets in data_loader_train:
        if counter % 10 == 0 and counter != 0: 
            print(f"Epoch {epoch} - Loss {train_loss/counter}")
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
    
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            middle_student, final_student = student(samples)

        with torch.no_grad() and torch.cuda.amp.autocast():
            teacher.eval()
            middle_teacher, final_teacher = teacher(samples)
        
        final_student = final_student.squeeze()
        final_teacher = final_teacher.squeeze()
        loss_1 = loss_fn_1(final_student, targets.float())
        loss_2 = loss_fn_2(final_student, final_teacher)


        loss_3 = loss_fn_3(middle_student, middle_teacher)
        
        loss = g*(b*loss_1 + (1-b)*loss_2) + (1-g)*loss_3
        # loss = b*loss_1 + (1-b)*loss_2

        loss.backward()
        optimizer.step()

        train_loss += loss_fn_1(final_student, targets).item()
        # train_loss += loss_1.item()
        counter +=1

### Testing

In [None]:
# Sacertis test split, consumed in order one window at a time.
dataset = get_dataset_sacertis(dir, False, False, True,  sensor = "None", time_frequency = "frequency")
data_loader_test = torch.utils.data.DataLoader(
    dataset, shuffle=False,
    batch_size=1,
    num_workers=1,
    pin_memory=True,  # a bool; the string 'store_true' only worked by being truthy
    drop_last=True,
)

In [None]:
# Evaluate the distilled student on the Sacertis test split.
model_to_evaluate = student
# The training cell leaves the student in train mode; switch to inference mode
# explicitly (a no-op if evaluate_finetune already does this).
model_to_evaluate.eval()

y_predicted, y_test = evaluate_finetune(data_loader_test, model_to_evaluate, device)
compute_accuracy(y_test, y_predicted)

In [None]:
# Save the distilled student in the same {'model': state_dict} layout that the
# checkpoints in this notebook are read with (checkpoint['model']). Pickling the
# whole nn.Module would instead tie the file to this exact class and module path.
student_path = "/home/benfenati/code/tle-supervised/Results/checkpoints/checkpoint-student-finetune-sacertis.pth"
torch.save({'model': student.state_dict()}, student_path)