### File operation

In [None]:
# Breast cancer (Wisconsin) data: 0 categorical features, 10 numerical features, 699 samples in total.

In [None]:
cp -r /kaggle/input/uci-breast/data_breast /kaggle/working/

### Import packages and set parameters

In [None]:
import re
import os
import math
import json
import yaml
import torch
import pickle
import argparse
import datetime
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

In [None]:
# Command-line style options for TabCSDI. `parse_args([])` parses an empty
# argument list so the notebook kernel's own argv is ignored and every option
# takes its default.
parser = argparse.ArgumentParser(description="TabCSDI")
parser.add_argument("--device", default="cpu", help="Device")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--testmissingratio", type=float, default=0.2)
parser.add_argument("--nfold", type=int, default=5, help="for 5-fold test")
# store_true flags are booleans; default to False rather than the int 0.
parser.add_argument("--unconditional", action="store_true", default=False)
parser.add_argument("--modelfolder", type=str, default="")
parser.add_argument("--nsample", type=int, default=100)
args = parser.parse_args([])

### Load data

In [None]:
def process_func(path: str, aug_rate= 1, missing_ratio= 0.1):
    """Read the CSV table and build the value array plus its two masks.

    The first row and the first column of the file are dropped
    (``iloc[1:, 1:]``) and ``"?"`` cells are treated as missing. The table is
    repeated ``aug_rate`` times.

    Args:
        path: CSV file to read (read with ``header=None``).
        aug_rate: number of times the table is stacked vertically.
        missing_ratio: fraction of each column's observed cells that is
            additionally hidden in ``gt_masks``.

    Returns:
        observed_values: float32 array (rows, cols); missing cells set to 0.
        observed_masks: int array, 1 where the file had a value.
        gt_masks: int array equal to ``observed_masks`` except that, per
            column, ``int(n_observed * missing_ratio)`` observed cells chosen
            with the global ``np.random`` state are set to 0.
    """
    table = pd.read_csv(path, header=None).iloc[1:, 1:].replace("?", np.nan)
    stacked = pd.concat([table] * aug_rate)
    values = stacked.values.astype("float32")

    present = ~np.isnan(values)  # True where a value exists in the file
    kept = present.copy()
    for column in range(values.shape[1]):
        candidates = np.where(kept[:, column])[0]
        n_hide = int(len(candidates) * missing_ratio)
        hidden = np.random.choice(candidates, n_hide, replace=False)
        kept[hidden, column] = False

    return np.nan_to_num(values), present.astype(int), kept.astype(int)

In [None]:
class tabular_dataset(Dataset):
    """Row-indexed view of the breast-cancer table.

    Data are taken from a pickle cache under ``./data_breast``:

    * no raw cache yet -> build it with ``process_func`` and save it;
    * normalized cache present -> load the normalized arrays;
    * only the raw cache present -> load the raw arrays (normalization has
      not been run yet, e.g. an earlier run stopped before it).

    Each item is a dict with the row's values, observed mask, ground-truth
    mask and ``timepoints = arange(eval_length)``.
    """

    # eval_length should be equal to the number of attributes (columns).
    def __init__(self, eval_length= 10, use_index_list= None, aug_rate= 1, missing_ratio= 0.1, seed= 0):
        self.eval_length = eval_length
        # Seeds the global NumPy RNG that process_func uses for masking.
        np.random.seed(seed)
        dataset_path = "./data_breast/breast-cancer-wisconsin.data"
        processed_data_path = f"./data_breast/missing_ratio-{missing_ratio}_seed-{seed}.pk"
        processed_data_path_norm = f"./data_breast/missing_ratio-{missing_ratio}_seed-{seed}_max-min_norm.pk"
        # NOTE: the caches are trusted local files; never point these paths at
        # untrusted pickles.
        if not os.path.isfile(processed_data_path):
            # gt_masks: 0 for originally missing cells and for the extra
            # `missing_ratio` fraction hidden per column.
            self.observed_values, self.observed_masks, self.gt_masks = process_func(
                dataset_path, aug_rate=aug_rate, missing_ratio=missing_ratio
            )
            with open(processed_data_path, "wb") as f:
                pickle.dump([self.observed_values, self.observed_masks, self.gt_masks], f)
            print("--------Dataset created--------")
        elif os.path.isfile(processed_data_path_norm):
            with open(processed_data_path_norm, "rb") as f:
                self.observed_values, self.observed_masks, self.gt_masks = pickle.load(f)
            print("--------Normalized dataset loaded--------")
        else:
            # Raw cache exists but the normalized one does not: without this
            # branch no arrays would be set and __init__ would crash below.
            with open(processed_data_path, "rb") as f:
                self.observed_values, self.observed_masks, self.gt_masks = pickle.load(f)
            print("--------Raw dataset loaded--------")
        if use_index_list is None:
            self.use_index_list = np.arange(len(self.observed_values))
        else:
            # Row indices exposed by this dataset (train/valid/test subset).
            self.use_index_list = use_index_list

    def __getitem__(self, org_index):
        """Return the sample at position ``org_index`` of ``use_index_list``."""
        index = self.use_index_list[org_index]
        s = {
            "observed_data": self.observed_values[index],
            "observed_mask": self.observed_masks[index],
            "gt_mask": self.gt_masks[index],
            "timepoints": np.arange(self.eval_length),  # np.arange(K), K is the feature count
        }
        return s

    def __len__(self):
        """Number of rows exposed by this dataset."""
        return len(self.use_index_list)

In [None]:
def get_dataloader(seed= 1, nfold= 5, batch_size= 128, missing_ratio= 0.1):
    """Split the table into train / validation / test loaders.

    Row indices are shuffled with seed ``seed + 1``. The last ``1/nfold``
    slice becomes the test set. The remaining rows are reshuffled and split
    80/20 into train and validation (64% / 16% / 20% overall for
    ``nfold=5``).

    If the normalized cache does not exist yet, every column is rescaled as
    ``(x + 1) / (max_train + 1)``. Here ``max_train`` is the maximum over the
    *observed* training entries of that column. Entries with mask 0 are
    multiplied back to 0. The result is pickled, and the three datasets
    created afterwards load it.
    NOTE(review): the cache name depends only on ``missing_ratio`` and
    ``seed``. A cache produced with a different ``nfold`` is reused as-is.

    Args:
        seed: seed for mask generation (``seed``) and splitting (``seed + 1``).
        nfold: number of equal slices; the last one is the test set.
        batch_size: batch size for all three loaders.
        missing_ratio: extra masking ratio forwarded to ``tabular_dataset``.

    Returns:
        ``(train_loader, valid_loader, test_loader)``.

    Raises:
        ValueError: if a column has no observed value in the training split
            (its maximum is undefined).
    """
    dataset = tabular_dataset(missing_ratio=missing_ratio, seed=seed)
    print(f"Dataset size:{len(dataset)} entries")
    indlist = np.arange(len(dataset))
    np.random.seed(seed + 1)
    np.random.shuffle(indlist)
    tmp_ratio = 1 / nfold
    start = int((nfold - 1) * len(dataset) * tmp_ratio)
    end = int(nfold * len(dataset) * tmp_ratio)
    # Last 1/nfold of the shuffled rows is the test set.
    test_index = indlist[start:end]
    remain_index = np.delete(indlist, np.arange(start, end))

    np.random.shuffle(remain_index)
    # train 80% / valid 20% of the remaining rows.
    num_train = int(len(remain_index) * 0.8)
    train_index = remain_index[:num_train]
    valid_index = remain_index[num_train:]

    # Per-column scaling (x + 1) / (max + 1) fitted on the training rows only.
    processed_data_path_norm = f"./data_breast/missing_ratio-{missing_ratio}_seed-{seed}_max-min_norm.pk"
    if not os.path.isfile(processed_data_path_norm):
        print("--------------Dataset has not been normalized yet. Perform data normalization and store the mean value of each column.--------------")
        col_num = dataset.observed_values.shape[1]
        max_arr = np.zeros(col_num)
        min_arr = np.zeros(col_num)
        for k in range(col_num):
            # Only observed entries count; missing ones were filled with 0.
            obs_ind = dataset.observed_masks[train_index, k].astype(bool)
            column = dataset.observed_values[train_index, k][obs_ind]
            if column.size == 0:
                raise ValueError(
                    f"Column {k} has no observed value in the training split; "
                    "cannot compute its maximum for normalization."
                )
            max_arr[k] = column.max()
            min_arr[k] = column.min()
        print(f"--------------Max-value for each column {max_arr}--------------")
        print(f"--------------Min-value for each column {min_arr}--------------")
        # The +1 keeps the denominator away from 0 for all-zero columns.
        dataset.observed_values = ((dataset.observed_values + 1) / (max_arr + 1)) * dataset.observed_masks
        with open(processed_data_path_norm, "wb") as f:
            pickle.dump([dataset.observed_values, dataset.observed_masks, dataset.gt_masks], f)

    # Each split reloads the (now normalized) cache and restricts to its rows.
    train_dataset = tabular_dataset(
        use_index_list=train_index, missing_ratio=missing_ratio, seed=seed
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_dataset = tabular_dataset(
        use_index_list=valid_index, missing_ratio=missing_ratio, seed=seed
    )
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_dataset = tabular_dataset(
        use_index_list=test_index, missing_ratio=missing_ratio, seed=seed
    )
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    print(f"Training dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(valid_dataset)}")
    print(f"Testing dataset size: {len(test_dataset)}")
    return train_loader, valid_loader, test_loader

In [None]:
# Loaders for the last of 5 folds, seed 1, 10% of observed cells hidden per column.
(
    train_loader,
    valid_loader,
    test_loader,
) = get_dataloader(
    seed=1,
    nfold=5,
    batch_size=128,
    missing_ratio=0.1,
)

### Set parameters

In [None]:
# Hyper-parameters, grouped by consumer:
#   train     -> optimizer / epoch settings used by `train`
#   diffusion -> denoiser architecture and beta schedule (diff_CSDI / CSDI_base)
#   model     -> side-information embedding sizes and masking strategy
config = dict(
    train=dict(
        epochs=100,
        batch_size=128,
        lr=0.0005,
    ),
    diffusion=dict(
        layers=4,
        channels=256,
        nheads=2,
        diffusion_embedding_dim=128,
        beta_start=0.0001,
        beta_end=0.5,
        num_steps=150,
        schedule="quad",
        mixed=False,
    ),
    model=dict(
        is_unconditional=0,
        timeemb=32,
        featureemb=32,
        target_strategy="random",
        mixed=False,
    ),
)

In [None]:
# Copy the command-line choices into the model section of the config.
config["model"].update(
    is_unconditional=args.unconditional,
    test_missing_ratio=args.testmissingratio,
)

### Model module

In [None]:
def get_torch_trans(heads= 8, layers= 1, channels= 64, dim_feedforward= 64):
    """Build a stack of ``layers`` identical transformer encoder layers.

    Args:
        heads: number of attention heads (PyTorch requires it to divide
            ``channels``).
        layers: number of stacked encoder layers.
        channels: model width ``d_model``.
        dim_feedforward: hidden width of each layer's feed-forward sub-block.
            The default of 64 keeps the previous fixed value.

    Returns:
        ``nn.TransformerEncoder`` with GELU activations. It uses PyTorch's
        default ``batch_first=False``, so inputs are ``(seq, batch, channels)``.
    """
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=channels, nhead=heads, dim_feedforward=dim_feedforward, activation="gelu"
    )
    return nn.TransformerEncoder(encoder_layer, num_layers=layers)

In [None]:
def Conv1d_with_init(in_channels, out_channels, kernel_size):
    """Return an ``nn.Conv1d`` whose weight is re-initialised Kaiming-normal.

    The bias keeps PyTorch's default initialisation.
    """
    conv = nn.Conv1d(in_channels, out_channels, kernel_size)
    nn.init.kaiming_normal_(conv.weight)
    return conv

In [None]:
class DiffusionEmbedding(nn.Module):
    """Embedding of the diffusion step index.

    A fixed sinusoidal table (``num_steps`` x ``embedding_dim``) is looked up
    by step and passed through two Linear + SiLU layers.
    """

    def __init__(self, num_steps, embedding_dim= 128, projection_dim= None):
        super().__init__()
        # The table holds embedding_dim // 2 sine and as many cosine columns;
        # an odd width could never match projection1's input size.
        if embedding_dim % 2 != 0:
            raise ValueError(f"embedding_dim must be even, got {embedding_dim}")
        if projection_dim is None:
            projection_dim = embedding_dim
        # Non-trainable, not saved in state_dict (persistent=False).
        self.register_buffer(
            "embedding", self._build_embedding(num_steps, embedding_dim // 2), persistent=False
        )
        self.projection1 = nn.Linear(embedding_dim, projection_dim)
        self.projection2 = nn.Linear(projection_dim, projection_dim)

    def forward(self, diffusion_step):
        """Map step indices (long tensor) to ``(..., projection_dim)`` features."""
        x = self.embedding[diffusion_step]
        x = self.projection1(x)
        x = F.silu(x)
        x = self.projection2(x)
        x = F.silu(x)
        return x

    def _build_embedding(self, num_steps, dim= 64):
        """Return the ``(num_steps, 2 * dim)`` sin/cos table.

        Frequencies are log-spaced from 10**0 to 10**4.
        """
        steps = torch.arange(num_steps).unsqueeze(1)  # (T, 1)
        # max(..., 1) avoids a 0/0 when dim == 1.
        frequencies = 10.0 ** (torch.arange(dim) / max(dim - 1, 1) * 4.0).unsqueeze(0)  # (1, dim)
        table = steps * frequencies  # (T, dim)
        table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)  # (T, 2 * dim)
        return table

In [None]:
class ResidualBlock(nn.Module):
    """One residual layer of the CSDI denoiser.

    The block works as follows:

    1. Add the projected diffusion-step embedding to its input.
    2. Run a transformer along the L axis, then another along the K axis.
    3. Mix in the side information through a 1x1 convolution.
    4. Apply a sigmoid/tanh gate.
    5. Split the output into a residual branch and a skip branch.
    """

    def __init__(self, side_dim, channels, diffusion_embedding_dim, nheads):
        super().__init__()
        # Diffusion-step embedding -> one bias per channel.
        self.diffusion_projection = nn.Linear(diffusion_embedding_dim, channels)
        # 1x1 convolutions producing 2*channels so the result can be halved.
        self.cond_projection = Conv1d_with_init(side_dim, 2 * channels, 1)
        self.mid_projection = Conv1d_with_init(channels, 2 * channels, 1)
        self.output_projection = Conv1d_with_init(channels, 2 * channels, 1)
        # Attention along L ("time") and along K ("feature").
        self.time_layer = get_torch_trans(heads=nheads, layers=1, channels=channels)
        self.feature_layer = get_torch_trans(heads=nheads, layers=1, channels=channels)

    def forward_time(self, y, base_shape):
        """Attend along L; each (sample, k) pair is one sequence.

        ``y`` is ``(B, C, K*L)`` and so is the result; identity when L == 1.
        """
        batch, chans, k_len, l_len = base_shape
        if l_len == 1:
            return y
        seq = y.reshape(batch, chans, k_len, l_len).permute(0, 2, 1, 3)
        seq = seq.reshape(batch * k_len, chans, l_len)
        # The encoder expects (seq, batch, emb).
        out = self.time_layer(seq.permute(2, 0, 1)).permute(1, 2, 0)
        out = out.reshape(batch, k_len, chans, l_len).permute(0, 2, 1, 3)
        return out.reshape(batch, chans, k_len * l_len)

    def forward_feature(self, y, base_shape):
        """Attend along K; each (sample, l) pair is one sequence.

        ``y`` is ``(B, C, K*L)`` and so is the result; identity when K == 1.
        """
        batch, chans, k_len, l_len = base_shape
        if k_len == 1:
            return y
        seq = y.reshape(batch, chans, k_len, l_len).permute(0, 3, 1, 2)
        seq = seq.reshape(batch * l_len, chans, k_len)
        out = self.feature_layer(seq.permute(2, 0, 1)).permute(1, 2, 0)
        out = out.reshape(batch, l_len, chans, k_len).permute(0, 2, 3, 1)
        return out.reshape(batch, chans, k_len * l_len)

    def forward(self, x, cond_info, diffusion_emb):
        """Run the block.

        Args:
            x: ``(B, C, K, L)`` hidden state.
            cond_info: ``(B, side_dim, K, L)`` side information.
            diffusion_emb: ``(B or 1, diffusion_embedding_dim)`` step embedding.

        Returns:
            ``((x + residual) / sqrt(2), skip)``, both shaped like ``x``.
        """
        base_shape = x.shape
        batch, chans, k_len, l_len = base_shape
        flat_x = x.reshape(batch, chans, k_len * l_len)

        step_bias = self.diffusion_projection(diffusion_emb).unsqueeze(-1)  # (B, C, 1)
        hidden = flat_x + step_bias
        hidden = self.forward_time(hidden, base_shape)
        hidden = self.forward_feature(hidden, base_shape)  # (B, C, K*L)
        hidden = self.mid_projection(hidden)  # (B, 2C, K*L)

        _, cond_dim, _, _ = cond_info.shape
        cond = self.cond_projection(cond_info.reshape(batch, cond_dim, k_len * l_len))
        hidden = hidden + cond

        gate, filt = torch.chunk(hidden, 2, dim=1)
        hidden = torch.sigmoid(gate) * torch.tanh(filt)  # (B, C, K*L)
        hidden = self.output_projection(hidden)

        residual, skip = torch.chunk(hidden, 2, dim=1)
        merged = (flat_x.reshape(base_shape) + residual.reshape(base_shape)) / math.sqrt(2.0)
        return merged, skip.reshape(base_shape)

In [None]:
class diff_CSDI(nn.Module):
    """Noise-prediction network used by ``CSDI_base``.

    The input ``(B, inputdim * token_emb_dim, K, L)`` is projected to
    ``channels``. It then passes through ``config["layers"]`` residual blocks,
    whose skip outputs are summed. The output is the predicted noise of shape
    ``(B, K, L)`` (or ``(B, K, L * token_emb_dim)`` when ``config["mixed"]``).

    During training in the conditional setting (inputdim = 2), the two input
    planes are:

    * the conditioning observations (unmasked parts of the non-test dataset);
    * the noisy targets (masked parts of the non-test dataset).
    """

    def __init__(self, config, inputdim= 2):
        super().__init__()
        self.config = config
        self.channels = config["channels"]
        self.diffusion_embedding = DiffusionEmbedding(
            num_steps=config["num_steps"],
            embedding_dim=config["diffusion_embedding_dim"],
        )
        # Channels per input plane; 1 unless the "mixed" token embedding is on.
        self.token_emb_dim = config["token_emb_dim"] if config["mixed"] else 1
        # Honour the caller's plane count: 1 for the unconditional model
        # (noisy data only), 2 for the conditional one (observations + noisy
        # targets). Hard-coding 2 here broke the unconditional model.
        inputdim = inputdim * self.token_emb_dim
        self.input_projection = Conv1d_with_init(inputdim, self.channels, 1)
        self.output_projection1 = Conv1d_with_init(self.channels, self.channels, 1)
        self.output_projection2 = Conv1d_with_init(self.channels, self.token_emb_dim, 1)
        # Zero-init the last layer so the initial noise prediction is ~0.
        nn.init.zeros_(self.output_projection2.weight)
        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    side_dim=config["side_dim"],
                    channels=self.channels,
                    diffusion_embedding_dim=config["diffusion_embedding_dim"],
                    nheads=config["nheads"],
                )
                for _ in range(config["layers"])
            ]
        )

    def forward(self, x, cond_info, diffusion_step):
        """Predict the noise.

        Args:
            x: ``(B, inputdim, K, L)`` model input.
            cond_info: ``(B, side_dim, K, L)`` side information.
            diffusion_step: long tensor of step indices, shape ``(B,)`` or ``(1,)``.
        """
        B, inputdim, K, L = x.shape
        x = x.reshape(B, inputdim, K * L)
        x = self.input_projection(x)
        x = F.relu(x)
        x = x.reshape(B, self.channels, K, L)
        diffusion_emb = self.diffusion_embedding(diffusion_step)
        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(x, cond_info, diffusion_emb)
            skip.append(skip_connection)
        # Average-scale the summed skip connections.
        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
        x = x.reshape(B, self.channels, K * L)
        x = self.output_projection1(x)
        x = F.relu(x)
        x = self.output_projection2(x)
        if self.config["mixed"]:
            x = x.permute(0, 2, 1)
            x = x.reshape(B, K, L * self.token_emb_dim)
        else:
            x = x.reshape(B, K, L)
        return x

In [None]:
class CSDI_base(nn.Module):
    """Conditional score-based diffusion model for imputation (CSDI).

    It holds the noise schedule and builds the side information for the
    denoiser ``diff_CSDI``. It also implements the training/validation loss
    and the reverse (sampling) process.

    Subclasses provide ``process_data(batch)``, which returns ``(observed_data,
    observed_mask, observed_tp, gt_mask, for_pattern_mask, cut_length)`` with
    data and masks shaped ``(B, K, L)``.

    Mask convention: ``cond_mask == 1`` marks observed entries given to the
    model as conditioning. Entries with ``observed_mask == 1`` and
    ``cond_mask == 0`` are the imputation targets.
    """

    def __init__(self, target_dim, config, device):
        super().__init__()
        self.device = device
        self.target_dim = target_dim
        self.emb_time_dim = config["model"]["timeemb"]
        self.emb_feature_dim = config["model"]["featureemb"]
        self.is_unconditional = config["model"]["is_unconditional"]
        self.target_strategy = config["model"]["target_strategy"]
        # Side-information channels: time + feature embedding (+ cond mask).
        self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
        if self.is_unconditional == False:
            self.emb_total_dim += 1  # for conditional mask
        self.embed_layer = nn.Embedding(num_embeddings=self.target_dim, embedding_dim=self.emb_feature_dim)
        config_diff = config["diffusion"]
        # NOTE: this writes into the caller's config["diffusion"] dict.
        config_diff["side_dim"] = self.emb_total_dim
        # 1 input plane (noisy data) unconditional, 2 (observations + noisy targets) conditional.
        input_dim = 1 if self.is_unconditional == True else 2
        self.diffmodel = diff_CSDI(config_diff, input_dim)

        # Noise schedule.
        self.num_steps = config_diff["num_steps"]
        schedule = config_diff["schedule"]
        if schedule == "quad":
            self.beta = (
                np.linspace(
                    config_diff["beta_start"] ** 0.5,
                    config_diff["beta_end"] ** 0.5,
                    self.num_steps,
                )
                ** 2
            )
        elif schedule == "linear":
            self.beta = np.linspace(
                config_diff["beta_start"], config_diff["beta_end"], self.num_steps
            )
        else:
            # Without this, self.beta would be undefined and fail obscurely below.
            raise ValueError(f"Unknown noise schedule {schedule!r}; expected 'quad' or 'linear'.")

        # alpha_hat: per-step 1 - beta; alpha: its cumulative product.
        self.alpha_hat = 1 - self.beta
        self.alpha = np.cumprod(self.alpha_hat)
        # (T, 1, 1) so that alpha_torch[t] broadcasts against (B, K, L).
        self.alpha_torch = (
            torch.tensor(self.alpha).float().to(self.device).unsqueeze(1).unsqueeze(1)
        )

    def time_embedding(self, pos, d_model= 128):
        """Sinusoidal encoding of positions ``pos`` (B, L) -> (B, L, d_model)."""
        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(self.device)
        position = pos.unsqueeze(2)
        div_term = 1 / torch.pow(10000.0, torch.arange(0, d_model, 2).to(self.device) / d_model)
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)
        return pe

    def get_randmask(self, observed_mask):
        """Randomly hide part of every sample's observed entries.

        For each sample a ratio is drawn uniformly from [0, 1). That fraction
        (rounded) of its observed entries is removed from the conditioning set.

        Returns:
            Float tensor shaped like ``observed_mask``. It is 1 for entries kept
            as conditioning, and 0 for hidden targets and originally missing
            entries.
        """
        # Observed entries get a random score, missing entries get 0.
        rand_for_mask = torch.rand_like(observed_mask) * observed_mask
        rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1)
        for i in range(len(observed_mask)):
            sample_ratio = np.random.rand()  # fraction of observed entries to hide
            num_observed = observed_mask[i].sum().item()
            num_masked = round(num_observed * sample_ratio)
            # The highest-scoring entries are hidden (marked -1).
            rand_for_mask[i][rand_for_mask[i].topk(num_masked).indices] = -1
        cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float()
        return cond_mask

    def get_side_info(self, observed_tp, cond_mask):
        """Build the side information ``(B, emb_total_dim, K, L)``.

        It concatenates a time-embedding slot, the learned embedding of
        positions ``0..target_dim-1`` along L, and (conditional model only)
        ``cond_mask``. The time-embedding slot is multiplied by 0, so it
        contributes only zeros.
        """
        B, K, L = cond_mask.shape
        time_embed = self.time_embedding(observed_tp, self.emb_time_dim)  # (B, L, emb_time)
        time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1)
        feature_embed = self.embed_layer(torch.arange(self.target_dim).to(self.device))
        feature_embed = feature_embed.unsqueeze(0).unsqueeze(2).expand(B, -1, K, -1)
        side_info = torch.cat([0 * time_embed, feature_embed], dim=-1)  # (B, L, K, *)
        side_info = side_info.permute(0, 3, 2, 1)  # (B, *, K, L)
        if self.is_unconditional == False:
            side_mask = cond_mask.unsqueeze(1)  # (B, 1, K, L)
            side_info = torch.cat([side_info, side_mask], dim=1)
        return side_info

    def calc_loss_valid(
        self, observed_data, cond_mask, observed_mask, side_info, is_train
    ):
        """Validation loss: mean of ``calc_loss`` over every step t."""
        loss_sum = 0
        for t in range(self.num_steps):
            loss = self.calc_loss(
                observed_data, cond_mask, observed_mask, side_info, is_train, set_t=t
            )
            loss_sum += loss.detach()
        return loss_sum / self.num_steps

    def calc_loss(self, observed_data, cond_mask, observed_mask, side_info, is_train, set_t= -1):
        """Noise-prediction MSE over target entries (observed but not conditioning).

        Training (``is_train == 1``) draws a random step per sample; otherwise
        every sample uses step ``set_t``.
        """
        B, K, L = observed_data.shape
        if is_train != 1:  # for validation
            t = (torch.ones(B) * set_t).long().to(self.device)
        else:
            t = torch.randint(0, self.num_steps, [B]).to(self.device)
        current_alpha = self.alpha_torch[t]  # (B, 1, 1)
        noise = torch.randn_like(observed_data)
        noisy_data = (current_alpha ** 0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise
        total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask)
        predicted = self.diffmodel(total_input, side_info, t)  # (B, K, L)
        target_mask = observed_mask - cond_mask
        residual = (noise - predicted) * target_mask
        num_eval = target_mask.sum()
        loss = (residual ** 2).sum() / (num_eval if num_eval > 0 else 1)
        return loss

    def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask):
        """Stack the denoiser input: (B, 1, K, L) unconditional, (B, 2, K, L) conditional."""
        if self.is_unconditional == True:
            total_input = noisy_data.unsqueeze(1)  # (B, 1, K, L)
        else:
            cond_obs = (cond_mask * observed_data).unsqueeze(1)
            noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1)
            total_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B, 2, K, L)
        return total_input

    def impute(self, observed_data, cond_mask, side_info, n_samples):
        """Draw ``n_samples`` imputations with ancestral DDPM sampling.

        Returns a ``(B, n_samples, K, L)`` tensor.
        """
        B, K, L = observed_data.shape
        imputed_samples = torch.zeros(B, n_samples, K, L).to(self.device)
        for i in range(n_samples):
            # Unconditional model: precompute noised versions of the observations.
            if self.is_unconditional == True:
                noisy_obs = observed_data
                noisy_cond_history = []
                for t in range(self.num_steps):
                    noise = torch.randn_like(noisy_obs)
                    noisy_obs = (self.alpha_hat[t] ** 0.5) * noisy_obs + self.beta[t] ** 0.5 * noise
                    noisy_cond_history.append(noisy_obs * cond_mask)
            current_sample = torch.randn_like(observed_data)
            # Reverse process t = T-1 .. 0.
            for t in range(self.num_steps - 1, -1, -1):
                if self.is_unconditional == True:
                    diff_input = (cond_mask * noisy_cond_history[t] + (1.0 - cond_mask) * current_sample)
                    diff_input = diff_input.unsqueeze(1)  # (B, 1, K, L)
                else:
                    cond_obs = (cond_mask * observed_data).unsqueeze(1)
                    noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1)
                    diff_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B, 2, K, L)
                predicted = self.diffmodel(diff_input, side_info, torch.tensor([t]).to(self.device))
                coeff1 = 1 / self.alpha_hat[t] ** 0.5
                coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5
                current_sample = coeff1 * (current_sample - coeff2 * predicted)
                if t > 0:
                    noise = torch.randn_like(current_sample)
                    sigma = ((1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]) ** 0.5
                    current_sample += sigma * noise
            imputed_samples[:, i] = current_sample.detach()
        return imputed_samples

    def forward(self, batch, is_train= 1):
        """Return the training loss (``is_train=1``) or the validation loss (``is_train=0``)."""
        (observed_data, observed_mask, observed_tp, gt_mask, for_pattern_mask, _,) = self.process_data(batch)
        if is_train == 0:
            # Validation/testing: the fixed ground-truth mask.
            cond_mask = gt_mask
        else:
            # Training: a fresh random mask per call.
            cond_mask = self.get_randmask(observed_mask)
        side_info = self.get_side_info(observed_tp, cond_mask)
        loss_func = self.calc_loss if is_train == 1 else self.calc_loss_valid
        return loss_func(observed_data, cond_mask, observed_mask, side_info, is_train)

    def evaluate(self, batch, n_samples):
        """Impute a batch using ``gt_mask`` as conditioning.

        Returns ``(samples, observed_data, target_mask, observed_mask,
        observed_tp)``, where ``samples`` is ``(B, n_samples, K, L)``.
        """
        (observed_data, observed_mask, observed_tp, gt_mask, _, cut_length,) = self.process_data(batch)
        with torch.no_grad():
            cond_mask = gt_mask
            target_mask = observed_mask - cond_mask
            side_info = self.get_side_info(observed_tp, cond_mask)
            samples = self.impute(observed_data, cond_mask, side_info, n_samples)
        return samples, observed_data, target_mask, observed_mask, observed_tp

In [None]:
class TabCSDI(CSDI_base):
    """CSDI for flat table rows (``target_dim`` columns, 10 by default).

    A row is treated as K = 1 with the table columns along L, so every data
    and mask tensor handed to ``CSDI_base`` is ``(B, 1, L)``.
    """

    def __init__(self, config, device, target_dim= 10):
        super().__init__(target_dim, config, device)

    def process_data(self, batch):
        """Move a collated batch to ``self.device`` and insert the K = 1 axis.

        Returns ``(observed_data, observed_mask, observed_tp, gt_mask,
        for_pattern_mask, cut_length)``. ``for_pattern_mask`` is the same
        tensor as ``observed_mask``, and ``cut_length`` is a zero long tensor
        of shape ``(B,)``.
        """
        device = self.device

        def _add_k_axis(key):
            # (B, L) -> (B, 1, L), as float on the model device.
            return batch[key].unsqueeze(1).to(device).float()

        observed_data = _add_k_axis("observed_data")
        observed_mask = _add_k_axis("observed_mask")
        observed_tp = batch["timepoints"].to(device).float()  # (B, L)
        gt_mask = _add_k_axis("gt_mask")
        cut_length = torch.zeros(len(observed_data), dtype=torch.long).to(device)
        return (observed_data, observed_mask, observed_tp, gt_mask, observed_mask, cut_length,)

In [None]:
def train(model, config, train_loader, valid_loader= None, valid_epoch_interval= 20, foldername= "",):
    """Train ``model`` with Adam and periodic validation.

    Args:
        model: a module with the following interface (as ``CSDI_base``):
            ``model(batch)`` returns the training loss,
            ``model(batch, is_train=0)`` returns the validation loss, and
            ``model.evaluate(batch, n)`` returns ``(samples, target,
            eval_points, observed_points, observed_time)``.
        config: training section of the config; reads ``"lr"`` and ``"epochs"``.
        train_loader: training batches.
        valid_loader: optional validation batches. When given, validation runs
            every ``valid_epoch_interval`` epochs.
        valid_epoch_interval: number of epochs between validations.
        foldername: output directory. When non-empty, the state dict with the
            lowest validation RMSE so far is written to ``model.pth``, and the
            loss/metric history to ``saved_history.pkl``. When empty, nothing
            is written.
    """
    # Fix the random seeds for this run.
    torch.manual_seed(0)
    np.random.seed(0)
    optimizer = Adam(model.parameters(), lr=config["lr"], weight_decay=1e-6)
    output_path = foldername + "/model.pth" if foldername != "" else None
    p0 = int(0.25 * config["epochs"])
    p1 = int(0.5 * config["epochs"])
    p2 = int(0.75 * config["epochs"])
    p3 = int(0.9 * config["epochs"])
    # Learning rate is multiplied by 0.1 at each of the milestones p0..p3.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[p0, p1, p2, p3], gamma=0.1)
    history = {'train_loss': [], 'val_loss': [], 'val_rmse': [], 'val_mae': []}
    best_valid_rmse = 1e10
    for epoch_no in range(config["epochs"]):
        train_avg_loss = 0
        model.train()
        with tqdm(train_loader, mininterval=5.0, maxinterval=50.0) as it:
            for batch_no, train_batch in enumerate(it, start=1):
                optimizer.zero_grad()
                loss = model(train_batch)
                loss.backward()
                train_avg_loss += loss.item()
                optimizer.step()
                it.set_postfix(
                    ordered_dict={"avg_epoch_loss": train_avg_loss / batch_no, "epoch": epoch_no,},
                    refresh=False,
                )
            lr_scheduler.step()

        if valid_loader is not None and (epoch_no + 1) % valid_epoch_interval == 0:
            history['train_loss'].append(train_avg_loss / batch_no)
            print("Start validation")
            model.eval()
            avg_loss_valid = 0
            val_nsample = 10
            val_scaler = 1
            mse_total = 0
            mae_total = 0
            evalpoints_total = 0
            with torch.no_grad():
                with tqdm(valid_loader, mininterval=5.0, maxinterval=50.0) as it:
                    for batch_no, valid_batch in enumerate(it, start=1):
                        # Validation loss on the *validation* batch with the fixed gt_mask.
                        loss = model(valid_batch, is_train=0)
                        avg_loss_valid += loss.item()
                        # samples (B, n, K, L) and target/masks (B, K, L) -> (..., L, K).
                        (samples, c_target, eval_points, observed_points, observed_time,) = model.evaluate(valid_batch, val_nsample)
                        samples = samples.permute(0, 1, 3, 2)
                        c_target = c_target.permute(0, 2, 1)
                        eval_points = eval_points.permute(0, 2, 1)
                        observed_points = observed_points.permute(0, 2, 1)
                        samples_median = samples.median(dim=1)
                        mse_current = (((samples_median.values - c_target) * eval_points) ** 2) * (val_scaler ** 2)
                        mae_current = (torch.abs((samples_median.values - c_target) * eval_points)) * val_scaler
                        mae_total += torch.sum(mae_current, dim=0)
                        mse_total += torch.sum(mse_current, dim=0)
                        evalpoints_total += torch.sum(eval_points, dim=0)
                        it.set_postfix(
                            ordered_dict={
                                "rmse_total": torch.mean(torch.sqrt(torch.div(mse_total, evalpoints_total))).item(),
                                "batch_no": batch_no,
                            },
                            refresh=True,
                        )
                    valid_rmse = torch.mean(torch.sqrt(torch.div(mse_total, evalpoints_total))).item()
                    history['val_rmse'].append(valid_rmse)
                    history['val_mae'].append(torch.mean(torch.div(mae_total, evalpoints_total)).item())
                    history['val_loss'].append(avg_loss_valid / batch_no)
                    # Keep only the best checkpoint (the best RMSE must be tracked).
                    if valid_rmse < best_valid_rmse:
                        best_valid_rmse = valid_rmse
                        if output_path is not None:
                            torch.save(model.state_dict(), output_path)
    # Save the training history; skipped when no output folder was given.
    if foldername != "":
        with open(foldername + '/saved_history.pkl', 'wb') as f:
            pickle.dump(history, f)

In [None]:
def _imputation_metrics(mse_total, mae_total, evalpoints_total):
    """Turn per-position error sums into ``(rmse, mae)`` Python floats.

    Each position's RMSE / MAE is computed from its own error sum and count, and
    the per-position values are then averaged (the same aggregation ``train`` uses
    for its validation metrics). A position with zero evaluated cells gives 0/0,
    which makes the averaged metric NaN.
    """
    rmse = torch.mean(torch.sqrt(torch.div(mse_total, evalpoints_total))).item()
    mae = torch.mean(torch.div(mae_total, evalpoints_total)).item()
    return rmse, mae


def evaluate(model, test_loader, nsample= 100, scaler= 1, mean_scaler=0, foldername=""):
    """Impute the test set with ``model`` and report RMSE / MAE on the target cells.

    For each batch the model draws ``nsample`` imputations, and their per-cell
    median is the point estimate. Errors count only where ``eval_points`` (the
    target mask returned by ``model.evaluate``) is non-zero. They are multiplied
    by ``scaler`` (``scaler ** 2`` for squared errors) before aggregation.

    Side effects:
        * Reseeds the global torch and numpy RNGs to 0.
        * Prints RMSE and MAE.
        * Writes two pickles into ``foldername`` (the current directory when it
          is empty):
            - ``result_nsample{nsample}.pk`` holding ``[rmse]``.
            - ``generated_outputs_nsample{nsample}.pk`` holding the concatenated
              samples, targets, target masks, observed masks and time points,
              followed by ``scaler`` and ``mean_scaler``.

    Args:
        model: object with ``eval()`` and ``evaluate(batch, nsample)``. The latter
            returns ``(samples, observed_data, target_mask, observed_mask,
            observed_tp)``.
        test_loader: iterable of batches accepted by ``model.evaluate``.
        nsample: number of imputations drawn per batch.
        scaler: multiplicative factor applied to errors before aggregation.
        mean_scaler: not used in the metrics; only stored in the outputs pickle.
        foldername: directory the result pickles are written to.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    # Fixed seeds make the stochastic sampling reproducible between runs.
    torch.manual_seed(0)
    np.random.seed(0)

    with torch.no_grad():
        model.eval()
        mse_total = 0
        mae_total = 0
        evalpoints_total = 0
        all_target = []
        all_observed_point = []
        all_observed_time = []
        all_evalpoint = []
        all_generated_samples = []
        with tqdm(test_loader, mininterval=5.0, maxinterval=50.0) as it:
            for batch_no, test_batch in enumerate(it, start=1):
                samples, c_target, eval_points, observed_points, observed_time = model.evaluate(
                    test_batch, nsample
                )
                # Put the feature axis last: samples -> (B, nsample, L, K); the masks/targets -> (B, L, K).
                samples = samples.permute(0, 1, 3, 2)
                c_target = c_target.permute(0, 2, 1)
                eval_points = eval_points.permute(0, 2, 1)
                observed_points = observed_points.permute(0, 2, 1)
                # The median over the nsample draws is the point imputation.
                samples_median = samples.median(dim=1).values

                all_target.append(c_target)
                all_evalpoint.append(eval_points)
                all_observed_point.append(observed_points)
                all_observed_time.append(observed_time)
                all_generated_samples.append(samples)

                # Masked errors summed over the batch axis only, giving per-position totals.
                masked_err = (samples_median - c_target) * eval_points
                mse_total += torch.sum((masked_err ** 2) * (scaler ** 2), dim=0)
                mae_total += torch.sum(torch.abs(masked_err) * scaler, dim=0)
                evalpoints_total += torch.sum(eval_points, dim=0)
                it.set_postfix(
                    ordered_dict={
                        "rmse_total": _imputation_metrics(mse_total, mae_total, evalpoints_total)[0],
                        "batch_no": batch_no,
                    },
                    refresh=True,
                )

        if not all_target:
            raise ValueError("test_loader produced no batches to evaluate")

        rmse, mae = _imputation_metrics(mse_total, mae_total, evalpoints_total)
        # os.path.join keeps an empty foldername relative to the CWD (plain "" + "/..." pointed at "/").
        with open(os.path.join(foldername, f"result_nsample{nsample}.pk"), "wb") as f:
            pickle.dump([rmse], f)
        print("RMSE:", rmse)
        # Report MAE alongside RMSE.
        print("MAE:", mae)

        # Concatenate before opening the file so a failure cannot leave a truncated pickle behind.
        outputs = [
            torch.cat(all_generated_samples, dim=0),
            torch.cat(all_target, dim=0),
            torch.cat(all_evalpoint, dim=0),
            torch.cat(all_observed_point, dim=0),
            torch.cat(all_observed_time, dim=0),
            scaler,
            mean_scaler,
        ]
        with open(os.path.join(foldername, f"generated_outputs_nsample{nsample}.pk"), "wb") as f:
            pickle.dump(outputs, f)

In [None]:
# Run on the GPU when one is visible, otherwise on the CPU.
args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# TabCSDI over the 10 feature columns of the breast data, moved onto the chosen device.
model = TabCSDI(config, args.device, target_dim=10)
model = model.to(args.device)

In [None]:
# Create a fresh, timestamped output folder for this run and record the config in it.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
foldername = "./save/breast_fold" + str(args.nfold) + "_" + current_time + "/"
print("model folder:", foldername)
os.makedirs(foldername, exist_ok=True)
with open(foldername + "config.json", "w") as f:
    json.dump(config, f, indent=4)
# Train from scratch unless --modelfolder names a previous run whose weights should be reused.
if args.modelfolder == "":
    train(
        model,
        config["train"],
        train_loader,
        valid_loader=valid_loader,
        foldername=foldername,
    )
else:
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine (and vice versa).
    model.load_state_dict(
        torch.load("./save/" + args.modelfolder + "/model.pth", map_location=args.device)
    )

In [None]:
print("---------------Start testing---------------")
# Evaluate this run's checkpoint rather than a path hard-coded from an earlier Kaggle
# session, whose timestamped folder generally does not match this run's `foldername`.
# When --modelfolder was given, the previous cell already restored the weights.
# NOTE(review): assumes train() writes its best checkpoint to <foldername>/model.pth,
# the same layout the --modelfolder branch reads from. Confirm against train's output_path.
if args.modelfolder == "":
    checkpoint_path = os.path.join(foldername, "model.pth")
    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path, map_location=args.device))
    else:
        print("No checkpoint at", checkpoint_path, "- evaluating the in-memory weights.")
evaluate(model, test_loader, nsample=10, scaler=1, foldername=foldername)

### Mean imputation

In [None]:
class MeanNet(nn.Module):
    """Mean-imputation baseline that fills every hidden cell with a per-column constant.

    Args:
        feat_mean_val: tensor of fill values, indexed by the column position
            (dimension 1) of the batch's ``gt_mask``.
    """

    def __init__(self, feat_mean_val):
        super().__init__()
        # Kept as a plain attribute (not a registered buffer or parameter).
        self.fmv = feat_mean_val

    def forward(self, x):
        """Return one fill value per cell where ``x['gt_mask'] == 0``.

        Cells are enumerated in row-major order, the same order that boolean-mask
        indexing ``tensor[gt_mask == 0]`` uses. The result therefore lines up
        element-for-element with the hidden ground-truth values.
        """
        hidden = x['gt_mask'] == 0
        # nonzero(as_tuple=True) gives one index tensor per dimension; dimension 1 is the column.
        column_index = torch.nonzero(hidden, as_tuple=True)[1]
        return self.fmv[column_index]

In [None]:
# Build a dataset only to estimate the per-feature means used by the baseline.
dataset = tabular_dataset(missing_ratio=0.1, seed=args.seed)
# Average each feature over its visible cells (gt_mask != 0) only. The product
# `observed_values * gt_masks` is 0 at every missing or masked cell, so taking
# `.mean(axis=0)` of it divides by *all* rows and pulls each mean toward 0.
# np.maximum(..., 1) keeps a feature with no visible cells at 0 instead of 0/0.
feat_obs_count = dataset.gt_masks.sum(axis=0)
feat_mean = (dataset.observed_values * dataset.gt_masks).sum(axis=0) / np.maximum(feat_obs_count, 1)
meannet = MeanNet(torch.from_numpy(feat_mean.astype(np.float32)))

In [None]:
# Score the mean-imputation baseline on the test loader.
# NOTE: this is a pooled RMSE/MAE over every scored cell, whereas evaluate() averages
# per-position RMSEs, so the two numbers are not directly comparable.
mse_total = 0
mae_total = 0
evalpoints_total = 0
for batch_test in test_loader:
    hidden = batch_test['gt_mask'] == 0
    # One prediction per hidden cell, in the same row-major order as `[hidden]` indexing.
    pred = meannet(batch_test)
    label = batch_test['observed_data'][hidden]
    # gt_mask is 0 both for cells missing in the raw data and for artificially masked
    # cells. The former only hold a placeholder 0 (process_func applies nan_to_num), so
    # only cells that were actually observed are scored. The 'observed_mask' key is
    # assumed from the CSDI-style dataset layout (TODO confirm). Without it, every
    # hidden cell is scored as before.
    if 'observed_mask' in batch_test:
        keep = batch_test['observed_mask'][hidden].to(torch.bool)
        pred, label = pred[keep], label[keep]
    err = pred - label
    mse_total += torch.sum(err ** 2)
    mae_total += torch.sum(torch.abs(err))
    evalpoints_total += label.numel()
print(f'rmse:{torch.sqrt(mse_total/ evalpoints_total)}\nmae:{mae_total/ evalpoints_total}')

### Visualization of Model Metrics

In [None]:
import pickle
import matplotlib.pyplot as plt

# Reload the history dict that train() pickled into the run folder.
with open(f'./{foldername}/saved_history.pkl', 'rb') as history_file:
    history = pickle.load(history_file)

# Overlay training and validation loss on one figure, then save and display it.
# NOTE(review): the title suggests validation runs every 20 epochs. If train_loss is
# logged more often than val_loss, the two curves do not share an x-axis. Confirm.
plt.figure(figsize=(8, 6))
for series_key, series_label in (('train_loss', 'Train Loss'), ('val_loss', 'Valid Loss')):
    plt.plot(history[series_key], label=series_label)
plt.xlabel('Epochs')
plt.ylabel('Loss value')
plt.title('Training and Validation Loss over each 20 Epochs')
plt.legend()
plt.savefig('Training and Validation Loss.png', dpi=300)
plt.show()

In [None]:
import pickle
import matplotlib.pyplot as plt

# Reload the history dict that train() pickled into the run folder.
with open(f'./{foldername}/saved_history.pkl', 'rb') as history_file:
    history = pickle.load(history_file)

# Plot validation RMSE and MAE on a shared axis, save to disk, then display.
plt.figure(figsize=(8, 6))
for series_key, series_label in (('val_rmse', 'Valid RMSE'), ('val_mae', 'Valid MAE')):
    plt.plot(history[series_key], label=series_label)
plt.xlabel('Epochs')
plt.ylabel('Value')
plt.title('Validation RMSE and MAE over each 20 Epochs')
plt.legend()
plt.savefig('Validation RMSE and MAE.png', dpi=300)
plt.show()