In [None]:
import torch
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
class DefaultConfig(object):
    """Default experiment settings, exposed as plain class attributes.

    The notebook reads these through the module-level ``opt`` instance, and the
    experiment loop overwrites ``model``, ``data_name``, ``id_num`` and
    ``miss_rate`` for each run.
    """

    # ---- model selection / hardware ----
    model = 'GTRAE_MVPT_MTL'
    use_gpu = True
    # NOTE(review): 'gpu' does not look like a device type torch.device() accepts
    # ('cuda' is); the visible cells pick their device from
    # torch.cuda.is_available() instead of this field -- confirm before using it.
    device = 'gpu' if use_gpu else 'cpu'
    load_model_path = None  # when truthy, the experiment loop calls model.load() on it

    # ---- data selection ----
    data_name = 'iris'  # dataset folder / file prefix
    miss_rate = 0.05    # fraction of missing values; also part of the file names
    id_num = 1          # "NUM_<id>" suffix of the data files
    train_rate = 0.6
    val_rate = 0.1

    # ---- data loading ----
    batch_size = 128
    num_workers = 0

    # ---- optimisation ----
    max_epoch = 5000
    patience = 1000  # epochs without validation-loss improvement before stopping
    patience_acc = patience
    patience_mae = patience
    lr = 0.001
    lr_decay = 0.97
    weight_decay = 1e-5

    # ---- model hyper-parameters ----
    droput = 0.1  # (sic) dropout probability; this spelling is what the model reads
    whiten_rate = 0.1
    TOPK = 5       # neighbours kept per row when building the adjacency matrix
    n_hidden = 30  # hidden width passed to the model constructor
    use_all_to_train = True

    # ---- checkpoint destinations ----
    model_save_path_acc = 'zz_saved_model/best_model_acc.pth'
    model_save_path_mae = 'zz_saved_model/best_model_mae.pth'
# Shared configuration instance; later cells read it and the experiment loop
# overwrites some of its fields (model, data_name, id_num, miss_rate) per run.
opt = DefaultConfig()
# Use CUDA whenever PyTorch reports it as available, otherwise the CPU.
# This choice does not consult opt.use_gpu / opt.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device=", device)
print("opt",opt.num_workers)

In [None]:
import os
import csv
import random
import copy
import numpy as np
import scipy.sparse as sp
import xlwt, xlrd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils import data as torch_data
from sklearn import metrics
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error, accuracy_score

def calculate_imputation_metrics(imputed_data, target_data):
    """Compare imputed values with their targets over every entry.

    Args:
        imputed_data: torch tensor holding the model's reconstruction.
        target_data: torch tensor of reference values with the same shape.

    Returns:
        Tuple ``(mae, rmse, mape)``; ``mape`` is expressed in percent and its
        denominator is offset by 1e-8.
    """
    # Move both tensors to host memory once and reuse the arrays below.
    reference = target_data.cpu().numpy()
    estimate = imputed_data.cpu().numpy()

    mae = mean_absolute_error(reference, estimate)
    rmse = np.sqrt(mean_squared_error(reference, estimate))
    relative_error = np.abs((reference - estimate) / (reference + 1e-8))
    mape = np.mean(relative_error) * 100
    return mae, rmse, mape


def calculate_imputation_metrics_mask(imputed_data, target_data, mask): 
    """Compare imputed values with their targets on the masked-out entries only.

    Only positions where ``mask == 0`` are scored.

    Args:
        imputed_data: torch tensor holding the model's reconstruction.
        target_data: torch tensor of reference values with the same shape.
        mask: torch tensor of the same shape; entries equal to 0 select the
            positions that are evaluated.

    Returns:
        Tuple ``(mae, rmse, mape)``; ``mape`` is expressed in percent and its
        denominator is offset by 1e-8.
    """
    epsilon = 1e-8
    selected = mask.cpu().numpy() == 0
    estimate = imputed_data.cpu().numpy()[selected]
    reference = target_data.cpu().numpy()[selected]

    mae = mean_absolute_error(reference, estimate)
    rmse = np.sqrt(mean_squared_error(reference, estimate))
    relative_error = np.abs((reference - estimate) / (reference + epsilon))
    mape = np.mean(relative_error) * 100
    return mae, rmse, mape

def calculate_accuracy(estimated_label, target_labels):
    """Return the accuracy score of predicted labels against the true labels.

    Args:
        estimated_label: torch tensor of predicted class indices.
        target_labels: torch tensor of true class indices.
    """
    y_true = target_labels.cpu().numpy()
    y_pred = estimated_label.cpu().numpy()
    return accuracy_score(y_true, y_pred)

def writeline_csv(filename, rmse, mae, mape, acc, norm_rmse, norm_mae, norm_mape):
    """Append one result row to a CSV file, writing the header for a new file.

    The dataset name and missing rate in the row come from the global ``opt``.

    Args:
        filename: path of the CSV file; it is opened in append mode.
        rmse, mae, mape: imputation metrics written to the row.
        acc: classification accuracy written to the row.
        norm_rmse, norm_mae, norm_mape: values for the 'Norm ...' columns.
    """
    columns = ('Dataset', 'Miss Rate', 'RMSE', 'MAE', 'MAPE', 'Acc',
               'Norm RMSE', 'Norm MAE', 'Norm MAPE')
    # Checked before opening: append mode would create the file otherwise.
    is_new_file = not os.path.isfile(filename)
    with open(filename, mode='a', newline='') as handle:
        out = csv.writer(handle)
        if is_new_file:
            out.writerow(list(columns))
        out.writerow([opt.data_name, opt.miss_rate, rmse, mae, mape, acc,
                      norm_rmse, norm_mae, norm_mape])

def whiten(tensor_,whiten_rate = 0.05):
    """Return a copy of ``tensor_`` with a random share of mask entries cleared.

    The first ``tensor_.shape[1] // 3`` columns are treated as a mask. Among
    the entries of that region equal to 1, ``int(whiten_rate * count)`` are
    picked at random (``torch.randperm``) and set to 0. The input tensor is
    left untouched.

    Args:
        tensor_: 2-D torch tensor whose leading third of columns is the mask.
        whiten_rate: fraction of the mask's 1-entries to clear.

    Returns:
        The modified clone.
    """
    whitened = tensor_.clone()
    n_mask_cols = whitened.shape[1] // 3
    # Basic slicing yields a view, so writes below land directly in `whitened`.
    mask_region = whitened[:, :n_mask_cols]
    hits = torch.where(mask_region == 1)
    n_hits = len(hits[0])
    n_to_clear = int(whiten_rate * n_hits)
    if n_to_clear > 0:
        picked = torch.randperm(n_hits)[:n_to_clear]
        mask_region[hits[0][picked], hits[1][picked]] = 0
    return whitened

In [None]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import pandas as pd

class BasicModule(torch.nn.Module):
    """``nn.Module`` base class that adds ``load`` / ``save`` checkpoint helpers.

    ``model_name`` defaults to ``str(type(self))``; subclasses may overwrite it.
    """

    def __init__(self):
        super(BasicModule, self).__init__()
        self.model_name = str(type(self))

    def load(self, path):
        """Load a state dict saved with ``save`` into this module.

        Args:
            path: checkpoint file to read.
        """
        # Deserialize onto the CPU so a checkpoint written on a GPU machine can
        # still be read without CUDA; load_state_dict then copies the values
        # into the parameters on whatever device they live on.
        self.load_state_dict(torch.load(path, map_location='cpu'))

    def save(self, name=None):
        """Write this module's state dict to disk.

        Args:
            name: destination path. When ``None`` the file is written to
                ``'./checkpoints/' + self.model_name``.

        Returns:
            The path the checkpoint was written to.
        """
        import os  # local import: this cell does not import os itself

        # Previously the target path was only bound when `name` was None, so an
        # explicit name raised UnboundLocalError before anything was saved.
        if name is None:
            name = './checkpoints/' + self.model_name
        parent = os.path.dirname(name)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save(self.state_dict(), name)
        return name
    
class GTRAE_MVPT_MTL(BasicModule):
    """Joint imputation / classification network over a fixed sample graph.

    At construction the module clones one split of ``data_class`` (normalized
    data with missing entries, observation mask, adjacency matrix) and
    precomputes neighbour-aggregated feature values for every sample. It also
    owns a learnable table ``global_missing_values`` that supplies the values
    substituted for missing inputs in ``forward``.
    """

    _MODEL_TYPES = ('train', 'val', 'test', 'all')

    def __init__(self, hidden_dim, num_classes, global_mask_shape, data_class, model_type):
        """Build the layers and the precomputed graph aggregates.

        Args:
            hidden_dim: width of both graph layers.
            num_classes: number of outputs of the classification head.
            global_mask_shape: ``(n_samples, n_features)`` of the chosen split;
                sizes the learnable missing-value table and the feature head.
            data_class: object providing ``missing_data_<type>_norm``,
                ``mask_<type>`` and ``adj_<type>`` tensors.
            model_type: one of ``'train'``, ``'val'``, ``'test'``, ``'all'``.

        Raises:
            ValueError: if ``model_type`` is not one of the supported values.
        """
        super(GTRAE_MVPT_MTL, self).__init__()
        # An unknown type used to fall through silently and only fail later
        # with an unrelated AttributeError; reject it up front instead.
        if model_type not in self._MODEL_TYPES:
            raise ValueError(
                f"model_type must be one of {self._MODEL_TYPES}, got {model_type!r}"
            )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        input_feature_dim = global_mask_shape[1]
        self.model_name = 'GTRAE_MVPT_MTL'
        # Layer creation order is kept stable so seeded initialisation and the
        # state-dict layout do not change.
        self.dropout_layer = nn.Dropout(p=opt.droput)  # not applied in forward()
        self.relu = nn.ReLU()  # not applied in forward()
        self.graph_sage_layer1 = Graphsage(input_feature_dim, hidden_dim)
        self.graph_sage_layer2 = Graphsage_hidden(hidden_dim, hidden_dim)
        self.feature_transform = nn.Linear(hidden_dim, input_feature_dim)
        self.class_transform = nn.Linear(hidden_dim, num_classes)
        # Learnable stand-ins for missing entries, one per (sample, feature).
        self.global_missing_values = nn.Parameter(torch.rand(global_mask_shape))
        self.data_class = data_class

        # X (normalized data, NaN where missing), M (1 = observed), A (adjacency).
        self.global_norm_miss_data = getattr(data_class, f'missing_data_{model_type}_norm').clone()
        self.global_mask = getattr(data_class, f'mask_{model_type}').clone()
        self.global_adj = getattr(data_class, f'adj_{model_type}').clone()
        # Missing entries contribute nothing to the neighbour sums.
        self.global_norm_miss_data = torch.nan_to_num(self.global_norm_miss_data, nan=0.0)

        aggregated = self._accumulate_hops(self.global_adj, self.global_norm_miss_data)
        valid_weights = self._accumulate_hops(self.global_adj, self.global_mask)
        # Weighted neighbour average; entries backed by too little observed
        # weight (<= 0.2) are set to zero.
        self.global_aggregated_values = torch.where(
            valid_weights > 0.2,
            aggregated / (valid_weights + 1e-8),
            torch.zeros_like(valid_weights),
        )

    @staticmethod
    def _accumulate_hops(adj, values):
        """Return ``A·V + 2·A²·V + A³·V`` via three successive propagation steps."""
        hop1 = torch.mm(adj, values)
        hop2 = hop1 + torch.mm(adj, hop1)
        hop3 = hop2 + torch.mm(adj, hop2)
        return hop3

    def forward(self, norm_missing_data, batch_idx):
        """Impute a batch and predict its class log-probabilities.

        Args:
            norm_missing_data: batch of normalized features, NaN where missing.
            batch_idx: indices of the batch rows in the global tensors built at
                construction time.

        Returns:
            Tuple ``(input_data, imputed_data, estimated_label)``:
            ``input_data`` is the batch with NaNs replaced by the learnable
            missing-value entries; ``imputed_data`` holds, in column ``j``, the
            ``j``-th output of the feature head applied to the ``j``-th hidden
            state; ``estimated_label`` is the log-softmax of the class head
            applied to the last hidden state.
        """
        batch_mask = ~torch.isnan(norm_missing_data).to(self.device)
        batch_missing_values = self.global_missing_values[batch_idx]
        input_data = torch.where(batch_mask, norm_missing_data, batch_missing_values)
        batch_aggregated_values = self.global_aggregated_values[batch_idx]

        hidden_states_1 = self.graph_sage_layer1(input_data, batch_aggregated_values)
        hidden_states_2 = self.graph_sage_layer2(hidden_states_1)

        # Collect the per-feature columns and concatenate once instead of
        # growing a tensor with torch.cat on every iteration.
        feature_columns = [
            self.feature_transform(hidden)[:, j].unsqueeze(1)
            for j, hidden in enumerate(hidden_states_2[:-1])
        ]
        if feature_columns:
            imputed_data = torch.cat(feature_columns, dim=1)
        else:
            # No feature columns: keep the empty tensor the loop used to produce.
            imputed_data = torch.tensor([], device=self.device)

        class_logits = self.class_transform(hidden_states_2[-1])
        estimated_label = F.log_softmax(class_logits, dim=1)

        return input_data, imputed_data, estimated_label

class Graphsage(nn.Module): 
    """First graph layer: one shared affine map applied to leave-one-out inputs.

    The layer concatenates a sample's own features with its aggregated
    neighbour features and projects the result with ``W`` and ``bias``. It
    emits one projection per input attribute (with that attribute's own column
    zeroed) followed by one projection of the unmodified concatenation.
    """

    def __init__(self, input_feature_dim, output_feature_dim):
        super(Graphsage, self).__init__()
        self.in_features = input_feature_dim
        self.model_name = 'Graphsage'
        # Rows cover [own features | aggregated features], hence 2 * input dim.
        self.W = nn.Parameter(torch.zeros(2 * input_feature_dim, output_feature_dim))
        self.bias = nn.Parameter(torch.zeros(output_feature_dim))
        self.reset_parameters()
        self.bn = nn.BatchNorm1d(output_feature_dim)  # registered but not used in forward()

    def reset_parameters(self): 
        """Draw ``W`` and ``bias`` uniformly from ``[-b, b)``, b ≈ 1/sqrt(out_dim)."""
        bound = 1. / (math.sqrt(self.W.size(1)) + 1e-8)
        for parameter in (self.W, self.bias):
            parameter.data.uniform_(-bound, bound)

    def _project(self, features):
        """Apply the shared affine map ``features @ W + bias``."""
        return torch.mm(features, self.W) + self.bias

    def forward(self, input_data, aggregated_data):
        """Return ``n_attributes + 1`` hidden states as a list.

        Entry ``j`` (for ``j < n_attributes``) is computed with column ``j`` of
        the concatenated input set to zero; the final entry uses the full input.
        """
        stacked = torch.cat([input_data, aggregated_data], dim=1)
        outputs = []
        for column in range(input_data.size(1)):
            masked = stacked.clone()
            masked[:, column] = 0
            outputs.append(self._project(masked))
        outputs.append(self._project(stacked))
        return outputs

class Graphsage_hidden(nn.Module): 
    """Second graph layer: applies one shared affine map to each hidden state."""

    def __init__(self, hidden_feature_dim, output_feature_dim):
        super(Graphsage_hidden, self).__init__()
        self.model_name = 'Graphsage_hidden'
        self.W = nn.Parameter(torch.zeros(hidden_feature_dim, output_feature_dim))
        self.bias = nn.Parameter(torch.zeros(output_feature_dim))
        self.reset_parameters()
        self.bn = nn.BatchNorm1d(output_feature_dim)  # registered but not used in forward()

    def reset_parameters(self):
        """Draw ``W`` and ``bias`` uniformly from ``[-b, b)``, b = 1/sqrt(out_dim)."""
        bound = 1. / math.sqrt(self.W.size(1))
        for parameter in (self.W, self.bias):
            parameter.data.uniform_(-bound, bound)

    def forward(self, hidden_states):
        """Map every tensor in ``hidden_states`` through ``h @ W + bias``.

        Returns a new list with one transformed tensor per input, in order.
        """
        return [torch.mm(state, self.W) + self.bias for state in hidden_states]

In [None]:
from torch.utils import data as torch_data
import numpy as np

class Dataload_zz(torch_data.Dataset):
    """Row-wise dataset over four aligned containers.

    Each item bundles the sample index with the matching row of the normalized
    data with missing values, the normalized complete data, the one-hot labels
    and the adjacency matrix.
    """

    def __init__(self, norm_missing_data, norm_full_data, labels_onehot,adj_matrix):
        self.norm_missing_data = norm_missing_data
        self.norm_full_data = norm_full_data
        self.labels_onehot = labels_onehot
        self.adj_matrix = adj_matrix

    def __len__(self):
        """Number of samples, taken from ``norm_missing_data``."""
        return len(self.norm_missing_data)

    def __getitem__(self, idx):
        """Return ``(idx, missing_row, full_row, onehot_row, adjacency_row)``."""
        return (
            idx,
            self.norm_missing_data[idx],
            self.norm_full_data[idx],
            self.labels_onehot[idx],
            self.adj_matrix[idx],
        )

class Dataset_zz:
    """Loads one pre-split dataset with missing values and prepares it for training.

    Construction runs the whole pipeline: read the CSV files, normalize with
    statistics of the observed entries, one-hot encode the labels, derive
    observation masks and similarity graphs, define the global row indices of
    each split, and finally convert everything to torch tensors.

    For every part in ``train`` / ``val`` / ``test`` / ``all`` the instance
    exposes ``missing_data_<part>``, ``missing_data_<part>_norm``,
    ``<part>_data``, ``<part>_data_norm``, ``<part>_label``,
    ``<part>_label_onehot``, ``mask_<part>`` and ``adj_<part>``, plus
    ``all_mean`` / ``all_std`` and ``train_indices`` / ``val_indices`` /
    ``test_indices``.
    """

    # NumPy attributes that to_tensor() converts to float tensors on self.device.
    _FLOAT_TENSOR_ATTRS = (
        'missing_data_train', 'missing_data_val', 'missing_data_test', 'missing_data_all',
        'missing_data_train_norm', 'missing_data_val_norm',
        'missing_data_test_norm', 'missing_data_all_norm',
        'train_data', 'val_data', 'test_data', 'all_data',
        'train_data_norm', 'val_data_norm', 'test_data_norm', 'all_data_norm',
        'train_label', 'val_label', 'test_label', 'all_label',
        'train_label_onehot', 'val_label_onehot', 'test_label_onehot', 'all_label_onehot',
        'mask_train', 'mask_val', 'mask_test', 'mask_all',
        'adj_train', 'adj_val', 'adj_test', 'adj_all',
        'all_mean', 'all_std',
    )

    def __init__(self, opt):
        """Run the full preparation pipeline.

        Args:
            opt: configuration object providing ``data_name``, ``miss_rate``
                (a fraction), ``id_num`` and ``TOPK``.
        """
        self.data_name = opt.data_name
        # Round before truncating: e.g. 100 * 0.29 == 28.999999999999996, which
        # plain int() would turn into 28 and thereby select the wrong files.
        self.miss_rate = int(round(100 * opt.miss_rate))
        self.id_num = opt.id_num
        self.TOPK = opt.TOPK
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.load_data()
        self.normalize_data()
        self.encode_labels()
        self.generate_masks()
        self.generate_adjacency_matrices()
        self.define_indices()
        self.to_tensor()

    def _read_split(self, split, kind):
        """Read ``datasets/<name>/<name>_<split>_<kind>_<rate>%_NUM_<id>.csv`` as an array."""
        base_path = f'datasets/{self.data_name}/{self.data_name}'
        path = f'{base_path}_{split}_{kind}_{self.miss_rate}%_NUM_{self.id_num}.csv'
        return np.array(pd.read_csv(path, header=None))

    def load_data(self):
        """Load the RANDOM (with missing values) and REAL (complete) CSV files.

        The last column of every file is split off: it is discarded for the
        RANDOM files and used as the label for the REAL files. The ``all``
        variants stack train, val and test in that order.
        """
        self.missing_data_train = self._read_split('train', 'RANDOM')[:, :-1]
        self.missing_data_val = self._read_split('val', 'RANDOM')[:, :-1]
        self.missing_data_test = self._read_split('test', 'RANDOM')[:, :-1]
        self.missing_data_all = np.concatenate(
            (self.missing_data_train, self.missing_data_val, self.missing_data_test), axis=0)

        self.train_set = self._read_split('train', 'REAL')
        self.val_set = self._read_split('val', 'REAL')
        self.test_set = self._read_split('test', 'REAL')
        self.all_set = np.concatenate((self.train_set, self.val_set, self.test_set), axis=0)

        self.train_data, self.train_label = self.train_set[:, :-1], self.train_set[:, -1]
        self.val_data, self.val_label = self.val_set[:, :-1], self.val_set[:, -1]
        self.test_data, self.test_label = self.test_set[:, :-1], self.test_set[:, -1]
        self.all_data, self.all_label = self.all_set[:, :-1], self.all_set[:, -1]

    def normalize_data(self):
        """Standardize every data array with the statistics of ``missing_data_all``.

        Mean and standard deviation are computed per column over the observed
        (non-NaN) entries of the stacked missing data and then applied to both
        the missing and the complete arrays of each part.
        """
        _, self.all_mean, self.all_std = self.normalization(self.missing_data_all.copy())

        def scale(raw):
            return self.normalization_with_mean_std(raw.copy(), self.all_mean, self.all_std)

        self.missing_data_train_norm = scale(self.missing_data_train)
        self.missing_data_val_norm = scale(self.missing_data_val)
        self.missing_data_test_norm = scale(self.missing_data_test)
        self.missing_data_all_norm = scale(self.missing_data_all)
        self.train_data_norm = scale(self.train_data)
        self.val_data_norm = scale(self.val_data)
        self.test_data_norm = scale(self.test_data)
        self.all_data_norm = scale(self.all_data)

    @staticmethod
    def normalization(data):
        """Standardize ``data`` column-wise, ignoring NaNs.

        Returns:
            ``(normalized, mean, std)``; NaN entries of the input stay NaN.
        """
        temp = np.array(data)
        mean = np.nanmean(temp, axis=0)
        std = np.nanstd(temp, axis=0)
        temp_masked = np.ma.masked_invalid(temp)
        temp = (temp_masked - mean) / (std + 1e-8)
        temp = temp.filled(np.nan)
        return temp, mean, std

    @staticmethod
    def normalization_with_mean_std(data, mean, std):
        """Standardize ``data`` with the given column statistics (float32 result).

        NaN entries of the input stay NaN.
        """
        temp = np.array(data)
        temp_masked = np.ma.masked_invalid(temp)
        temp = (temp_masked - mean).astype(np.float32) / (std + 1e-8).astype(np.float32)
        temp = temp.filled(np.nan)
        return temp

    @staticmethod
    def denormalization_with_mean_std(data, mean, std):
        """Invert ``normalization_with_mean_std``: ``data * (std + 1e-8) + mean``."""
        return data * (std + 1e-8) + mean

    def encode_labels(self):
        """Build the label -> one-hot mapping from all splits and encode each part."""
        unique_all_labels = np.unique(
            np.concatenate((self.train_label, self.val_label, self.test_label)))
        identity = np.eye(len(unique_all_labels))
        self.classes_dict = {label: identity[i, :] for i, label in enumerate(unique_all_labels)}
        self.train_label_onehot = self.encode_onehot(self.train_label)
        self.val_label_onehot = self.encode_onehot(self.val_label)
        self.test_label_onehot = self.encode_onehot(self.test_label)
        self.all_label_onehot = self.encode_onehot(self.all_label)

    def encode_onehot(self, labels):
        """Return the int32 one-hot rows of ``labels`` according to ``classes_dict``."""
        map_func = np.vectorize(self.classes_dict.get, otypes=[np.ndarray])
        labels_onehot = np.array(list(map_func(labels)), dtype=np.int32)
        return labels_onehot

    def generate_masks(self):
        """Create observation masks: 1 where a value is present, 0 where it is NaN."""
        self.mask_train = (~np.isnan(self.missing_data_train)).astype(int)
        self.mask_val = (~np.isnan(self.missing_data_val)).astype(int)
        self.mask_test = (~np.isnan(self.missing_data_test)).astype(int)
        self.mask_all = (~np.isnan(self.missing_data_all)).astype(int)

    def generate_adjacency_matrices(self):
        """Build one similarity graph per part from the raw (unnormalized) missing data."""
        self.adj_train = self.get_adj_euclidean(self.missing_data_train, self.mask_train)
        self.adj_val = self.get_adj_euclidean(self.missing_data_val, self.mask_val)
        self.adj_test = self.get_adj_euclidean(self.missing_data_test, self.mask_test)
        self.adj_all = self.get_adj_euclidean(self.missing_data_all, self.mask_all)

    def get_data_var(self, data, data_mask):
        """Return the per-attribute population variance of the observed entries.

        Attributes without any observed value get NaN.
        """
        (n_sample, n_attribute) = data.shape
        data_var = np.zeros(n_attribute)
        for i in range(n_attribute):
            valid_data = data[data_mask[:, i] == 1, i]
            if valid_data.size > 0:
                data_mean = np.mean(valid_data)
                data_var[i] = np.sum((valid_data - data_mean) ** 2) / valid_data.size
            else:
                data_var[i] = np.nan
        return data_var

    def get_dis_euclidean(self, x, y, data_var, x_mask, y_mask):
        """Variance-scaled Euclidean distance over the attributes both samples observe.

        The squared differences are divided by the attribute variance, summed,
        and rescaled by ``n_attributes / n_used`` so that pairs sharing few
        attributes stay comparable with pairs sharing many.

        Attributes whose variance is zero or NaN are left out: dividing by a
        zero variance produced 0/0 = NaN distances, which then leaked NaNs into
        the adjacency matrix.

        Returns:
            The distance, or ``np.inf`` when no usable attribute is shared.
        """
        both_observed = (x_mask * y_mask) == 1
        with np.errstate(invalid='ignore'):  # NaN variances simply compare False
            informative = np.isfinite(data_var) & (data_var > 0)
        valid_indices = np.where(both_observed & informative)[0]
        if valid_indices.size == 0:
            return np.inf
        difference = x[valid_indices] - y[valid_indices]
        n_attributes = x.shape[0]
        distance_squared_sum = ((difference ** 2) / data_var[valid_indices]).sum()
        scaled_distance_squared = distance_squared_sum * n_attributes / valid_indices.size
        return np.sqrt(scaled_distance_squared)

    def get_adj_euclidean(self, data, matrix_mask):
        """Build a row-normalized inverse-distance graph with TOPK pruning.

        Edge weights are ``1 / distance`` for every pair with a finite,
        non-zero distance (identical samples therefore get no edge). Rows with
        more than ``self.TOPK`` neighbours keep only the entries at or above
        their TOPK-th largest weight, and each row is finally divided by its
        sum (offset by 1e-6).
        """
        n_sample = data.shape[0]
        matrix_adj = np.zeros((n_sample, n_sample))
        data_var = self.get_data_var(data, matrix_mask)
        for i in range(n_sample):
            for j in range(i + 1, n_sample):
                Dij = self.get_dis_euclidean(
                    data[i], data[j], data_var, matrix_mask[i], matrix_mask[j])
                # isfinite also rejects NaN, which `!= 0 and != inf` let through.
                if np.isfinite(Dij) and Dij > 0:
                    matrix_adj[i][j] = matrix_adj[j][i] = 1 / Dij
        for i in range(n_sample):
            row = matrix_adj[i]  # view: edits below modify matrix_adj in place
            if np.count_nonzero(row) > self.TOPK:
                threshold = np.partition(row, -self.TOPK)[-self.TOPK]
                row[row < threshold] = 0
        row_sums = matrix_adj.sum(axis=1) + 1e-6
        matrix_adj = matrix_adj / row_sums[:, np.newaxis]
        return matrix_adj

    def define_indices(self):
        """Record the row indices of each split inside the stacked ``all`` arrays."""
        train_len = len(self.train_data)
        val_len = len(self.val_data)
        test_len = len(self.test_data)

        val_start = train_len
        test_start = train_len + val_len
        self.train_indices = list(range(0, val_start))
        self.val_indices = list(range(val_start, test_start))
        self.test_indices = list(range(test_start, test_start + test_len))

    def to_tensor(self):
        """Convert every prepared array to a torch tensor.

        Data, labels, masks, graphs and statistics become float tensors on
        ``self.device``; the split index lists become tensors that are not
        moved to ``self.device``.
        """
        for attr in self._FLOAT_TENSOR_ATTRS:
            as_tensor = torch.from_numpy(getattr(self, attr)).float().to(self.device)
            setattr(self, attr, as_tensor)

        self.train_indices = torch.tensor(self.train_indices)
        self.val_indices = torch.tensor(self.val_indices)
        self.test_indices = torch.tensor(self.test_indices)

In [None]:
import matplotlib.pyplot as plt
opt.model = 'GTRAE_MVPT_MTL'
for opt.data_name in ['banknote']: 
    for opt.id_num in [2,3,4,5]*3: # 1,2,3,4,5
        for opt.miss_rate in [0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7]: #0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7
            DATA_ZZ = Dataset_zz(opt)
            all_data_zz = Dataload_zz(DATA_ZZ.missing_data_all_norm, DATA_ZZ.all_data_norm, DATA_ZZ.all_label_onehot, DATA_ZZ.adj_all) 
            global_train_idx, global_val_idx, global_test_idx = DATA_ZZ.train_indices, DATA_ZZ.val_indices, DATA_ZZ.test_indices

            n_class = DATA_ZZ.train_label_onehot.shape[1] 
            global_mask_shape = DATA_ZZ.all_data.shape 
            model = GTRAE_MVPT_MTL(opt.n_hidden, n_class, global_mask_shape, DATA_ZZ, 'all').to(device) 
            
            if opt.load_model_path:
                model.load(opt.load_model_path)
            
            all_dataloader = DataLoader(all_data_zz, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
            
            criterion_classicfic = torch.nn.NLLLoss().to(device)
            criterion_imputation = torch.nn.MSELoss().to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
            
            # Initialize the lists that store the per-epoch metrics
            train_losses, val_losses, test_losses = [], [], []
            train_maes, train_rmses, train_mapes, train_accs = [], [], [], []
            val_maes, val_rmses, val_mapes, val_accs = [], [], [], []
            test_maes, test_rmses, test_mapes, test_accs = [], [], [], []
            train_mae_accs, val_mae_accs, test_mae_accs = [], [], []
            train_de_maes, train_de_rmses, train_de_mapes = [], [], []
            val_de_maes, val_de_rmses, val_de_mapes = [], [], []
            test_de_maes, test_de_rmses, test_de_mapes = [], [], []

            best_val_loss = float('inf')
            patience, triggered = opt.patience, 0
            for epoch in range(opt.max_epoch):
                epoch_input_data_list = []
                epoch_imputed_data_list = []
                epoch_target_data_list = [] 
                epoch_real_data_list = [] 
                epoch_mask_list = [] 
                epoch_estimated_label_onehot_list = []
                epoch_target_label_onehot_list = []
                epoch_all_idx_list = []

                for batch_data in all_dataloader:
                    batch_all_idx, batch_norm_missing_data, batch_norm_full_data, batch_label_onehot, batch_adj_matrix = batch_data
                    batch_mask = ~torch.isnan(batch_norm_missing_data).to(device) 
                    
                    outputs = model(batch_norm_missing_data, batch_all_idx) 
                    input_data, imputed_data, estimated_label_onehot = outputs 
                    target_data = batch_norm_full_data * batch_mask + input_data * (~batch_mask) 
                    target_label_onehot = batch_label_onehot 
                    target_labels = torch.argmax(target_label_onehot, dim=1) 

                    mask_train_idx = torch.isin(batch_all_idx, global_train_idx)
                    mask_val_idx = torch.isin(batch_all_idx, global_val_idx)
                    mask_test_idx = torch.isin(batch_all_idx, global_test_idx)
                    batch_train_idx = batch_all_idx[mask_train_idx]
                    batch_val_idx = batch_all_idx[mask_val_idx]
                    batch_test_idx = batch_all_idx[mask_test_idx]
                    batch_train_relative_idx = torch.where(mask_train_idx)
                    batch_val_relative_idx = torch.where(mask_val_idx)
                    batch_test_relative_idx = torch.where(mask_test_idx)

                    loss_imputation = criterion_imputation(imputed_data * batch_mask, target_data * batch_mask) 
                    loss_imputation = loss_imputation/(1-opt.miss_rate)
                    loss_classicfic_train = criterion_classicfic(estimated_label_onehot[batch_train_relative_idx], target_labels[batch_train_relative_idx]) 
                    batch_model_loss_train = loss_classicfic_train + loss_imputation 
                    
                    optimizer.zero_grad()
                    batch_model_loss_train.backward(retain_graph=True)
                    optimizer.step()

                    epoch_input_data_list.append(input_data.detach())
                    epoch_imputed_data_list.append(imputed_data.detach())
                    epoch_real_data_list.append(batch_norm_full_data.detach()) 
                    epoch_mask_list.append(batch_mask.detach()) 
                    epoch_target_data_list.append(target_data.detach())
                    epoch_estimated_label_onehot_list.append(estimated_label_onehot.detach())
                    epoch_target_label_onehot_list.append(target_label_onehot.detach())
                    epoch_all_idx_list.append(batch_all_idx.detach())
                ###################################################
                #               val and test                      #
                ###################################################
                epoch_input_data = torch.cat(epoch_input_data_list)
                epoch_imputed_data = torch.cat(epoch_imputed_data_list)
                epoch_target_data = torch.cat(epoch_target_data_list) 
                epoch_real_data = torch.cat(epoch_real_data_list) 
                epoch_mask = torch.cat(epoch_mask_list) 
                epoch_de_input_data = DATA_ZZ.denormalization_with_mean_std(epoch_input_data.clone(), DATA_ZZ.all_mean, DATA_ZZ.all_std)
                epoch_de_imputed_data = DATA_ZZ.denormalization_with_mean_std(epoch_imputed_data.clone(), DATA_ZZ.all_mean, DATA_ZZ.all_std)
                epoch_de_target_data = DATA_ZZ.denormalization_with_mean_std(epoch_target_data.clone(), DATA_ZZ.all_mean, DATA_ZZ.all_std) 
                epoch_de_real_data = DATA_ZZ.denormalization_with_mean_std(epoch_real_data.clone(), DATA_ZZ.all_mean, DATA_ZZ.all_std) 
                epoch_estimated_label_onehot = torch.cat(epoch_estimated_label_onehot_list)
                epoch_target_label_onehot = torch.cat(epoch_target_label_onehot_list)
                epoch_all_idx = torch.cat(epoch_all_idx_list)
                epoch_estimated_labels = torch.argmax(epoch_estimated_label_onehot, dim=1) 
                epoch_target_labels = torch.argmax(epoch_target_label_onehot, dim=1) 

                epoch_mask_train_idx = torch.isin(epoch_all_idx, global_train_idx)
                epoch_mask_val_idx = torch.isin(epoch_all_idx, global_val_idx)
                epoch_mask_test_idx = torch.isin(epoch_all_idx, global_test_idx)
                epoch_batch_train_idx = epoch_all_idx[epoch_mask_train_idx]
                epoch_batch_val_idx = epoch_all_idx[epoch_mask_val_idx]
                epoch_batch_test_idx = epoch_all_idx[epoch_mask_test_idx]
                epoch_train_relative_idx = torch.where(epoch_mask_train_idx)
                epoch_val_relative_idx = torch.where(epoch_mask_val_idx)
                epoch_test_relative_idx = torch.where(epoch_mask_test_idx)

                epoch_loss_imputation = criterion_imputation(epoch_input_data * epoch_mask, epoch_imputed_data * epoch_mask) 
                epoch_loss_imputation = epoch_loss_imputation/(1-opt.miss_rate)
                epoch_loss_classicfic_train = criterion_classicfic(epoch_estimated_label_onehot[epoch_train_relative_idx], epoch_target_labels[epoch_train_relative_idx]) 
                epoch_loss_classicfic_val = criterion_classicfic(epoch_estimated_label_onehot[epoch_val_relative_idx], epoch_target_labels[epoch_val_relative_idx]) 
                epoch_loss_classicfic_test = criterion_classicfic(epoch_estimated_label_onehot[epoch_test_relative_idx], epoch_target_labels[epoch_test_relative_idx]) 
                epoch_model_loss_train = epoch_loss_classicfic_train + epoch_loss_imputation 
                epoch_model_loss_val = epoch_loss_classicfic_val + epoch_loss_imputation 
                epoch_model_loss_test = epoch_loss_classicfic_test + epoch_loss_imputation 
                
                epoch_mae_train, epoch_rmse_train, epoch_mape_train = calculate_imputation_metrics(epoch_imputed_data[epoch_train_relative_idx], epoch_target_data[epoch_train_relative_idx])
                epoch_mae_val, epoch_rmse_val, epoch_mape_val = calculate_imputation_metrics(epoch_imputed_data[epoch_val_relative_idx], epoch_target_data[epoch_val_relative_idx])
                epoch_mae_test, epoch_rmse_test, epoch_mape_test = calculate_imputation_metrics_mask(epoch_imputed_data[epoch_test_relative_idx], epoch_real_data[epoch_test_relative_idx], epoch_mask[epoch_test_relative_idx])
                
                epoch_de_mae_train, epoch_de_rmse_train, epoch_de_mape_train = calculate_imputation_metrics(epoch_de_imputed_data[epoch_train_relative_idx], epoch_de_target_data[epoch_train_relative_idx])
                epoch_de_mae_val, epoch_de_rmse_val, epoch_de_mape_val = calculate_imputation_metrics(epoch_de_imputed_data[epoch_val_relative_idx], epoch_de_target_data[epoch_val_relative_idx])
                epoch_de_mae_test, epoch_de_rmse_test, epoch_de_mape_test = calculate_imputation_metrics_mask(epoch_de_imputed_data[epoch_test_relative_idx], epoch_de_real_data[epoch_test_relative_idx], epoch_mask[epoch_test_relative_idx])
                
                epoch_acc_train = calculate_accuracy(epoch_estimated_labels[epoch_train_relative_idx], epoch_target_labels[epoch_train_relative_idx])
                epoch_acc_val = calculate_accuracy(epoch_estimated_labels[epoch_val_relative_idx], epoch_target_labels[epoch_val_relative_idx])
                epoch_acc_test = calculate_accuracy(epoch_estimated_labels[epoch_test_relative_idx], epoch_target_labels[epoch_test_relative_idx])

                # Record this epoch's values in the running histories. The lists grow by one
                # entry per recorded epoch, so position i in every list refers to the same
                # epoch; the code after the loop relies on that to look up the best epoch
                # and to draw the curves.
                # Losses are stored as plain Python numbers via .item().
                train_losses.append(epoch_model_loss_train.item())
                val_losses.append(epoch_model_loss_val.item())
                test_losses.append(epoch_model_loss_test.item())
                # Error metrics and accuracy, grouped by subset.
                train_maes.append(epoch_mae_train)
                train_rmses.append(epoch_rmse_train)
                train_mapes.append(epoch_mape_train)
                train_accs.append(epoch_acc_train)
                val_maes.append(epoch_mae_val)
                val_rmses.append(epoch_rmse_val)
                val_mapes.append(epoch_mape_val)
                val_accs.append(epoch_acc_val)
                test_maes.append(epoch_mae_test)
                test_rmses.append(epoch_rmse_test)
                test_mapes.append(epoch_mape_test)
                test_accs.append(epoch_acc_test)
                # Metrics computed on the epoch_de_* tensors.
                train_de_maes.append(epoch_de_mae_train)
                train_de_rmses.append(epoch_de_rmse_train)
                train_de_mapes.append(epoch_de_mape_train)
                val_de_maes.append(epoch_de_mae_val)
                val_de_rmses.append(epoch_de_rmse_val)
                val_de_mapes.append(epoch_de_mape_val)
                test_de_maes.append(epoch_de_mae_test)
                test_de_rmses.append(epoch_de_rmse_test)
                test_de_mapes.append(epoch_de_mape_test)

                # Combined score: MAE divided by accuracy (note: a ratio, not a product).
                # NOTE(review): an accuracy of exactly 0 makes this division fail or yield
                # inf, depending on what calculate_accuracy returns — confirm and guard if needed.
                train_mae_acc = epoch_mae_train / epoch_acc_train
                val_mae_acc = epoch_mae_val / epoch_acc_val
                test_mae_acc = epoch_mae_test / epoch_acc_test

                train_mae_accs.append(train_mae_acc)
                val_mae_accs.append(val_mae_acc)
                test_mae_accs.append(test_mae_acc)

                # Early stopping on validation loss: a strict improvement resets the counter,
                # anything else increments it, and the enclosing loop is left once `patience`
                # consecutive non-improving epochs have been seen.
                # NOTE(review): best_val_loss keeps epoch_model_loss_val itself rather than its
                # .item() value; it appears to be a tensor, so it may hold on to more than the
                # scalar — confirm whether this runs under no_grad / on a detached loss.
                if epoch_model_loss_val < best_val_loss:
                    best_val_loss = epoch_model_loss_val
                    triggered = 0
                else:
                    triggered += 1
                    if triggered >= patience:
                        print("Early stopping triggered at epoch {}".format(epoch))
                        break
            
            # Pick the recorded epoch with the lowest validation loss (the first one on
            # ties) and report every metric as it was at that epoch.
            lowest_val_loss = min(val_losses)
            min_val_loss_index = val_losses.index(lowest_val_loss)

            # Run identification; 'Epoch' is the 1-based position in the histories.
            data_record = {
                'Dataset': opt.data_name,
                'Missing Rate': opt.miss_rate,
                'Data ID': opt.id_num,
                'Model': opt.model,
                'Epoch': min_val_loss_index + 1,
            }

            # Losses at the selected epoch.
            data_record['Train Loss'] = train_losses[min_val_loss_index]
            data_record['Validation Loss'] = val_losses[min_val_loss_index]
            data_record['Test Loss'] = test_losses[min_val_loss_index]

            # Accuracies are stored as fractions and reported as percentages.
            data_record['Train Accuracy'] = train_accs[min_val_loss_index] * 100
            data_record['Val Accuracy'] = val_accs[min_val_loss_index] * 100
            data_record['Test Accuracy'] = test_accs[min_val_loss_index] * 100

            # Error metrics: train/val columns carry the "(reconstruct)" suffix, the test
            # column does not. Insertion order defines the CSV column order, so the
            # families are added in the order MAE, RMSE, MAPE, then their de_ variants.
            # The MAE/ACC ratio histories are not part of the record.
            for metric_tag, train_series, val_series, test_series in (
                ('MAE', train_maes, val_maes, test_maes),
                ('RMSE', train_rmses, val_rmses, test_rmses),
                ('MAPE', train_mapes, val_mapes, test_mapes),
                ('de_MAE', train_de_maes, val_de_maes, test_de_maes),
                ('de_RMSE', train_de_rmses, val_de_rmses, test_de_rmses),
                ('de_MAPE', train_de_mapes, val_de_mapes, test_de_mapes),
            ):
                data_record[f'Train {metric_tag}(reconstruct)'] = train_series[min_val_loss_index]
                data_record[f'Val {metric_tag}(reconstruct)'] = val_series[min_val_loss_index]
                data_record[f'Test {metric_tag}'] = test_series[min_val_loss_index]

            # Persist the best-epoch record to two CSV logs: one shared by all runs and
            # one per dataset. A log is created on first use and grows by one row on
            # every later run.
            for filename in (
                'zz_result/results.csv',
                f'zz_result/{opt.data_name}/{opt.data_name}.csv',
            ):
                # to_csv does not create missing parent folders, so make sure the target
                # directory exists instead of failing on a fresh checkout / new dataset.
                os.makedirs(os.path.dirname(filename), exist_ok=True)

                df = pd.DataFrame([data_record])
                if os.path.isfile(filename):
                    # DataFrame.append was removed in pandas 2.0; pd.concat is the
                    # supported way to add the new row below the existing ones.
                    df = pd.concat([pd.read_csv(filename), df], ignore_index=True)

                df.to_csv(filename, index=False)

            # Summarise the run as a 3x3 grid of per-epoch curves (train / val / test in
            # every panel). Each spec is (title, y-axis label or None, curves), where
            # curves is a sequence of (history list, legend label) pairs; panels are
            # filled row by row in the order given here.
            # `plt` is expected to be matplotlib.pyplot, imported elsewhere in the notebook.
            panel_specs = (
                ('Loss over epochs', 'Loss', (
                    (train_losses, 'Train Loss'),
                    (val_losses, 'Validation Loss'),
                    (test_losses, 'Test Loss'),
                )),
                ('Metrics over epochs', None, (
                    (train_maes, 'Train MAE'),
                    (val_maes, 'Val MAE'),
                    (test_maes, 'Test MAE'),
                )),
                ('Metrics over epochs', None, (
                    (train_de_maes, 'Train de_MAE'),
                    (val_de_maes, 'Val de_MAE'),
                    (test_de_maes, 'Test de_MAE'),
                )),
                ('Metrics over epochs', None, (
                    (train_rmses, 'Train RMSE'),
                    (val_rmses, 'Val RMSE'),
                    (test_rmses, 'Test RMSE'),
                )),
                ('Metrics over epochs', None, (
                    (train_de_rmses, 'Train de_RMSE'),
                    (val_de_rmses, 'Val de_RMSE'),
                    (test_de_rmses, 'Test de_RMSE'),
                )),
                ('Metrics over epochs', None, (
                    (train_mapes, 'Train MAPE'),
                    (val_mapes, 'Val MAPE'),
                    (test_mapes, 'Test MAPE'),
                )),
                ('Metrics over epochs', None, (
                    (train_de_mapes, 'Train de_MAPE'),
                    (val_de_mapes, 'Val de_MAPE'),
                    (test_de_mapes, 'Test de_MAPE'),
                )),
                ('Metrics over epochs', None, (
                    (train_accs, 'Train Accuracy'),
                    (val_accs, 'Val Accuracy'),
                    (test_accs, 'Test Accuracy'),
                )),
                # The plotted series are MAE divided by accuracy, so the title and axis
                # label say "MAE/ACC" to match the data and the legend (they previously
                # read "MAE*ACC").
                ('MAE/ACC over epochs', 'MAE/ACC', (
                    (train_mae_accs, 'Train MAE/ACC'),
                    (val_mae_accs, 'Validation MAE/ACC'),
                    (test_mae_accs, 'Test MAE/ACC'),
                )),
            )

            plt.figure(figsize=(15, 9))
            for panel_no, (panel_title, panel_ylabel, panel_curves) in enumerate(panel_specs, start=1):
                plt.subplot(3, 3, panel_no)
                for curve_values, curve_label in panel_curves:
                    plt.plot(curve_values, label=curve_label)
                plt.title(panel_title)
                plt.xlabel('Epoch')
                # Only the loss and MAE/ACC panels carry an explicit y-axis label.
                if panel_ylabel is not None:
                    plt.ylabel(panel_ylabel)
                plt.legend()

            plt.tight_layout()
            plt.show()
