In [None]:
from lib import *
from TAILMIL import *
import os
import re
from sklearn.metrics import roc_auc_score
from sklearn.metrics import mean_squared_error
from sklearn.metrics import average_precision_score

# Root directories of the SMD data: preprocessed pickles and the text files
# with the interpretation labels.
preprocessed_path = 'SMD/preprocessed_set/'
interpretation_path = 'SMD/interpretation_label/'


def _paths_containing(directory, token):
    """Return ``directory + name`` for every entry whose name contains ``token``.

    Entries are taken from ``os.listdir(directory)`` in sorted order, so the
    result order depends only on the file names.
    """
    matches = []
    for name in sorted(os.listdir(directory)):
        if token in name:
            matches.append(directory + name)
    return matches


# The lists below are indexed with the same position by the evaluation loop,
# so they rely on the sorted file names lining up across the four lists.
x_train_path = _paths_containing(preprocessed_path, 'train')
x_test_path = _paths_containing(preprocessed_path, 'test.')
y_test_path = _paths_containing(preprocessed_path, 'test_')
y_test_label = _paths_containing(interpretation_path, 'machine')

In [None]:
def _load_pickle(path):
    """Return the object unpickled from ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use it on
    files produced by a trusted preprocessing step.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


# Matches label lines of the form ``<start>-<end>:<col>,<col>,...``.
_INTERPRETATION_PATTERN = re.compile(r'(\d+)-(\d+):([\d,]+)')


def _parse_interpretation_labels(text):
    """Parse interpretation-label text into parallel lists.

    Returns:
        (start_indexes, end_indexes, column_lists), one entry per line that
        matches ``_INTERPRETATION_PATTERN``. Non-matching lines are skipped.
    """
    starts, ends, column_lists = [], [], []
    # splitlines() keeps the last entry even when the file has no trailing
    # newline (a plain split('\n')[:-1] would drop it).
    for line in text.splitlines():
        match = _INTERPRETATION_PATTERN.match(line)
        if not match:
            continue
        starts.append(int(match.group(1)))
        ends.append(int(match.group(2)))
        column_lists.append(list(map(int, match.group(3).split(','))))
    return starts, ends, column_lists


def _windows_and_errors(model, loader, device):
    """Run ``model`` over ``loader`` and collect inputs and squared errors.

    Returns:
        (inputs, errors): the concatenated input batches, and for each sample
        the squared difference between the model output and the target summed
        over axes 2 and 1.
    """
    inputs, errors = [], []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            y_hat = model(x)

            inputs.append(x.detach().cpu().numpy())
            errors.append((y_hat - y).detach().cpu().numpy() ** 2)

    inputs = np.concatenate(inputs, axis=0)
    errors = np.concatenate(errors, axis=0).sum(2).sum(1)
    return inputs, errors


# For each of the first 11 machines: load data and the pretrained
# reconstruction model, train TAILMIL to regress the reconstruction error,
# then report detection and interpretation metrics on the test set.
for ii in tqdm_notebook(range(11)):
    # File name of the training pickle without its 4-character extension.
    machine_name = x_train_path[ii].split('/')[-1][:-4]

    x_train = _load_pickle(x_train_path[ii])
    x_test = _load_pickle(x_test_path[ii])
    y_test = _load_pickle(y_test_path[ii]).reshape((-1))

    with open(y_test_label[ii], 'r') as f:
        lines = f.read()

    print('-------------------------------------------------------------')
    print(machine_name)

    start_indexes, end_indexes, i_labels = _parse_interpretation_labels(lines)

    # Mark the given index range (end inclusive) and columns with 1.
    # NOTE(review): allocated with x_train's shape and not used further in this
    # loop; confirm it matches the series the label indices refer to.
    y_interpretated_label = np.zeros(x_train.shape)
    for start_index, end_index, columns_to_label in zip(start_indexes, end_indexes, i_labels):
        y_interpretated_label[start_index:end_index + 1, columns_to_label] = 1

    # Fit the scaler on the training data and reuse it for the test data.
    x_train, scaler = normalize_data(x_train, scaler=None)
    x_test, _ = normalize_data(x_test, scaler=scaler)

    n_features = x_train.shape[1]
    window_size, target_dims = 12, x_train.shape[1]
    out_dim = 1
    batch_size, val_split, shuffle_dataset = 128, 0.2, True

    train_dataset = SlidingWindowDataset(x_train, window_size, target_dims)
    test_dataset = SlidingWindowDataset(x_test, window_size, target_dims)

    train_loader, val_loader, test_loader = create_data_loaders(
        train_dataset, batch_size, val_split, shuffle_dataset, test_dataset=test_dataset
    )

    # Use the first GPU when present; otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = MTAD_GAT_RECON(
        n_features,
        window_size,
        n_features,
        kernel_size=7,
        use_gatv2=True,
        feat_gat_embed_dim=None,
        time_gat_embed_dim=None,
        gru_n_layers=1,
        gru_hid_dim=300,
        recon_n_layers=1,
        recon_hid_dim=300,
        dropout=0.3,
        alpha=0.2
    ).to(device)

    save_path = 'Model/' + machine_name + '.p'
    # map_location lets a checkpoint saved on GPU load on whichever device is in use.
    model.load_state_dict(torch.load(save_path, map_location=device))
    model.eval()

    # Regression targets for TAILMIL: per-window summed squared error of the
    # pretrained model.
    x_train_new, y_train = _windows_and_errors(model, train_loader, device)
    # NOTE(review): despite the name this is the 50th percentile (median) of
    # the training errors; it is the threshold that binarises labels for AUC.
    percentile_95 = np.percentile(y_train, 50)
    x_valid_new, y_valid = _windows_and_errors(model, val_loader, device)

    trainset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_train_new),
        torch.FloatTensor(y_train),
    )
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=1020,
        shuffle=True
    )

    validset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_valid_new),
        torch.FloatTensor(y_valid),
    )
    valid_loader = torch.utils.data.DataLoader(
        validset,
        batch_size=512,
        shuffle=False
    )

    max_grad_norm = 0.9
    n_epochs = 20

    # window/dim follow the data instead of repeating the literals 12 and 38.
    milnet = TAILMIL(window=window_size, dim=n_features, sub_window=1).to(device)
    optimizer = torch.optim.AdamW(milnet.parameters(), 1e-2)
    loss_fn = torch.nn.MSELoss()

    best_loss_path2 = f"{machine_name}_TAILMIL.p"

    # Start at infinity so the first epoch always writes a checkpoint.
    best_mse = float('inf')

    train_auc = []
    valid_auc = []

    for e in tqdm_notebook(range(1, n_epochs + 1)):
        train_output = []
        train_label = []

        valid_output = []
        valid_label = []

        milnet.train()
        for batch_id, (x, label) in enumerate(train_loader):
            optimizer.zero_grad()
            label = label.float().to(device)
            # Swap axes 1 and 2 of the window before feeding TAILMIL.
            x = x.transpose(1, 2).to(device)
            out, _, _ = milnet(x)
            loss = loss_fn(out, label.reshape(out.shape))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(milnet.parameters(), max_grad_norm)
            optimizer.step()

            train_output.extend(out.detach().cpu().numpy())
            train_label.extend(label.detach().cpu().numpy())

        # Score with the raw outputs, exactly as the validation AUC does;
        # argmax over axis 1 would throw the score values away.
        train_auc.append(
            roc_auc_score(
                np.array(train_label) > percentile_95,
                np.array(train_output).reshape(-1),
            )
        )

        milnet.eval()
        with torch.no_grad():
            for batch_id, (x, label) in enumerate(valid_loader):
                label = label.float().to(device)
                x = x.transpose(1, 2).to(device)
                out, _, _ = milnet(x)

                valid_output.extend(out.detach().cpu().numpy())
                valid_label.extend(label.detach().cpu().numpy())

        valid_scores = np.array(valid_output).reshape(-1)
        valid_targets = np.array(valid_label)

        valid_auc.append(roc_auc_score(valid_targets > percentile_95, valid_scores))

        # Keep the checkpoint with the lowest validation MSE. The running best
        # must be updated here, otherwise every epoch overwrites the file.
        valid_mse = mean_squared_error(valid_targets, valid_scores)
        if valid_mse < best_mse:
            best_mse = valid_mse
            torch.save(milnet.state_dict(), best_loss_path2)

    milnet.load_state_dict(torch.load(best_loss_path2, map_location=device))

    # A window is labelled anomalous when any point inside it is anomalous.
    sliding_window_label = np.array([
        1 if sum(y_test[i:i + window_size]) > 0 else 0
        for i in range(len(y_test) - window_size)
    ])

    recons = []
    recons2 = []
    preds = []

    milnet.eval()
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.transpose(1, 2).to(device)
            y_hat, y_hat2, y_hat3 = milnet(x)

            recons.append(y_hat.detach().cpu().numpy())
            # Only the last position along axis 1 of the second output is kept.
            recons2.append(y_hat2[:, -1].detach().cpu().numpy())
            preds.append(y_hat3.detach().cpu().numpy())

    recons = np.concatenate(recons, axis=0)
    recons2 = np.concatenate(recons2, axis=0)
    recons2 = recons2.reshape(-1)
    preds = np.concatenate(preds, axis=0)

    print('Slinding Window Performance')
    print(f'AUROC: {roc_auc_score(sliding_window_label, recons)}')
    print(f'AUPR: {average_precision_score(sliding_window_label, recons)}')

    print('One Step Performance')
    print(f'AUROC: {roc_auc_score(y_test[window_size:], recons2)}')
    print(f'AUPR: {average_precision_score(y_test[window_size:], recons2)}')

    # Interpretation accuracy: within each labelled segment, count a hit when
    # the highest-scoring column is one of the labelled columns.
    # NOTE(review): rows of preds are indexed with the raw label indices (end
    # exclusive here) and column numbers are compared as-is; confirm the
    # window offset and the index base of the label file.
    feature_scores = preds.reshape(-1, n_features)
    corrects = []
    for start_index, end_index, columns_to_label in zip(start_indexes, end_indexes, i_labels):
        top_columns = feature_scores[start_index:end_index].argmax(1)
        corrects.extend(1 if col in columns_to_label else 0 for col in top_columns)

    # Avoid ZeroDivisionError when no labelled segment produced a prediction.
    accuracy = sum(corrects) / len(corrects) if corrects else float('nan')

    print('Interpretation Performance')
    print(f'ACC: {accuracy}')
    print('-------------------------------------------------------------')