### File operation

In [None]:
# Census data: 9 categorical features and 6 numerical features; 20000 samples in total.

In [None]:
# Copy the dataset into the working directory. The cells below read
# ./census_dataset/adult_trim.data and write .pk caches into ./census_dataset,
# which presumes the notebook's working directory is /kaggle/working.
# NOTE: if the destination already exists, `cp -r` nests a second copy inside it.
cp -r /kaggle/input/census-dataset/ /kaggle/working/census_dataset

In [None]:
# Cleanup only: deletes the working copy, including any cached .pk files written below.
# Running this before the data-loading cells makes them fail; run it only to reset.
rm -rf /kaggle/working/census_dataset

### Import packages and set parameters

In [None]:
import re
import os
import math
import json
import yaml
import torch
import pickle
import argparse
import datetime
import typing as ty
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.nn as nn
from torch import Tensor
from torch.optim import Adam
import torch.nn.init as nn_init
import category_encoders as ce
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

In [None]:
# Command-line style options for the experiment, declared as a table and registered in order.
parser = argparse.ArgumentParser(description="TabCSDI")
_cli_options = {
    "--device": {"default": 'cpu', "help": "Device"},
    "--seed": {"type": int, "default": 1},
    "--testmissingratio": {"type": float, "default": 0.2},
    "--nfold": {"type": int, "default": 5, "help": "for 5-fold test"},
    "--unconditional": {"action": "store_true", "default": 0},
    "--modelfolder": {"type": str, "default": ""},
    "--nsample": {"type": int, "default": 100},
}
for _flag, _opts in _cli_options.items():
    parser.add_argument(_flag, **_opts)
# Parse an empty argv instead of sys.argv (in a notebook, sys.argv belongs to the kernel).
args = parser.parse_args([])

### Load data

In [None]:
def process_func(path= "/kaggle/working/census_dataset/adult_trim.data", cat_list= None, missing_ratio= 0.1, encode= True, output_dir= "./census_dataset"):
    """Load the census CSV, reorder its columns and build observation / ground-truth masks.

    Columns listed in ``cat_list`` are moved to the end, so continuous columns occupy
    indices ``0 .. n_cont-1`` and categorical columns the remaining indices. For every
    column, ``int(n_observed * missing_ratio)`` observed entries are hidden at random
    (drawn from the global NumPy RNG) to form the ground-truth mask.

    Args:
        path: CSV file without a header row.
        cat_list: original column indices of the categorical features. Defaults to
            ``[1, 3, 5, 6, 7, 8, 9, 13, 14]``.
        missing_ratio: fraction of each column's observed entries to hide.
        encode: if True, ordinal-encode the categorical columns, replace NaN with 0,
            cast to float, and pickle ``[cont_list, num_cate_list]`` and the fitted
            encoder into ``output_dir``.
        output_dir: directory receiving ``transformed_columns.pk`` and ``encoder.pk``
            (only used when ``encode`` is True; it must already exist).

    Returns:
        ``(observed_values, observed_masks, gt_masks, cont_list)``. The masks are boolean
        arrays (True = value present). ``gt_masks`` is additionally False at the
        artificially hidden entries.
    """
    if cat_list is None:
        cat_list = [1, 3, 5, 6, 7, 8, 9, 13, 14]
    data = pd.read_csv(path, header=None)
    # Reorder columns (continuous first, categorical last), then renumber them 0..n-1.
    new_cols_order = [i for i in range(data.shape[1]) if i not in cat_list] + list(cat_list)
    data = data.reindex(columns=data.columns[new_cols_order])
    data.columns = list(range(data.shape[1]))
    n_cols = data.shape[1]
    cont_list = list(range(n_cols - len(cat_list)))  # continuous feature indices
    cat_list = list(range(len(cont_list), n_cols))  # categorical feature indices
    observed_values = data.values
    observed_masks = (~pd.isnull(data)).values
    # gt_masks: False for originally missing AND manually masked elements.
    gt_masks = observed_masks.copy()
    for col in range(gt_masks.shape[1]):
        obs_indices = np.where(gt_masks[:, col])[0]
        miss_indices = np.random.choice(obs_indices, int(len(obs_indices) * missing_ratio), replace=False)
        gt_masks[miss_indices, col] = False
    if not encode:
        return observed_values, observed_masks, gt_masks, cont_list

    encoder = ce.ordinal.OrdinalEncoder(cols=data.columns[cat_list])
    encoder.fit(data)
    new_df = encoder.transform(data)
    # Encoders that expand one feature into several columns named "<col>_<level>" need the
    # masks duplicated for each extra column. With ordinal encoding the column names stay
    # integers, so presumably no expansion occurs and the masks are unchanged.
    new_observed_masks = observed_masks.copy()
    new_gt_masks = gt_masks.copy()
    num_cate_list = []
    cum_num_bits = 0
    for col in cat_list:
        num_cate_list.append(new_df.iloc[:, col].nunique())
        n_expanded = len([s for s in new_df.columns if isinstance(s, str) and s.startswith(f"{col}_")])
        extra_cols = max(n_expanded - 1, 0)
        for _ in range(extra_cols):
            new_observed_masks = np.insert(new_observed_masks, cum_num_bits + col, observed_masks[:, col], axis=1)
            new_gt_masks = np.insert(new_gt_masks, cum_num_bits + col, gt_masks[:, col], axis=1)
        cum_num_bits += extra_cols
    new_observed_values = np.nan_to_num(new_df.values).astype(float)
    with open(os.path.join(output_dir, "transformed_columns.pk"), "wb") as f:
        pickle.dump([cont_list, num_cate_list], f)
    with open(os.path.join(output_dir, "encoder.pk"), "wb") as f:
        pickle.dump(encoder, f)
    return new_observed_values, new_observed_masks, new_gt_masks, cont_list

In [None]:
class tabular_Dataset(Dataset):
    """Row-wise census dataset backed by pickled, pre-processed arrays.

    Data is loaded from (in order of preference):
      * nothing cached yet -> ``process_func`` is run and its output pickled;
      * the max-min normalized cache written by ``get_dataloader``;
      * the un-normalized processed cache (e.g. a previous run stopped before
        normalization was written).

    ``np.random.seed(seed)`` reseeds the global NumPy RNG, so the masking done by
    ``process_func`` is reproducible for a given seed.
    """
    # eval_length should be equal to the number of attributes.
    def __init__(self, eval_length= 15, use_index_list= None, missing_ratio= 0.1, seed= 0):
        self.eval_length = eval_length
        np.random.seed(seed)
        dataset_path = "./census_dataset/adult_trim.data"
        processed_data_path = f"./census_dataset/missing_ratio-{missing_ratio}_seed-{seed}.pk"
        processed_data_path_norm = f"./census_dataset/missing_ratio-{missing_ratio}_seed-{seed}_max-min_norm.pk"
        # self.cont_cols is only stored in the un-normalized .pk file.
        cat_list = [1, 3, 5, 6, 7, 8, 9, 13, 14]
        if not os.path.isfile(processed_data_path):
            (self.observed_values, self.observed_masks, self.gt_masks, self.cont_cols) = process_func(
                dataset_path, cat_list=cat_list, missing_ratio=missing_ratio, encode=True
            )
            with open(processed_data_path, "wb") as f:
                pickle.dump([self.observed_values, self.observed_masks, self.gt_masks, self.cont_cols], f)
            print("--------Dataset created--------")
        elif os.path.isfile(processed_data_path_norm):
            with open(processed_data_path_norm, "rb") as f:
                self.observed_values, self.observed_masks, self.gt_masks = pickle.load(f)
            print("--------Normalized dataset loaded--------")
        else:
            # Processed but not yet normalized: load the raw processed arrays so the
            # object is usable (get_dataloader can then normalize them).
            with open(processed_data_path, "rb") as f:
                (self.observed_values, self.observed_masks, self.gt_masks, self.cont_cols) = pickle.load(f)
            print("--------Dataset loaded--------")
        if use_index_list is None:
            self.use_index_list = np.arange(len(self.observed_values))
        else:
            self.use_index_list = use_index_list

    def __getitem__(self, org_index):
        """Return the row at position ``org_index`` of ``use_index_list`` as a dict of arrays."""
        index = self.use_index_list[org_index]
        s = {
            "observed_data": self.observed_values[index],
            "observed_mask": self.observed_masks[index],
            "gt_mask": self.gt_masks[index],
            "timepoints": np.arange(self.eval_length),
        }
        return s

    def __len__(self):
        return len(self.use_index_list)

In [None]:
def get_dataloader(seed= 1, nfold= 5, batch_size= 128, missing_ratio= 0.1):
    """Split the census data and return ``(train_loader, valid_loader, test_loader)``.

    After a shuffle seeded with ``seed + 1``, the last ``1/nfold`` slice is the test set.
    The remainder is shuffled again and split 80/20 into train/validation. On first use,
    continuous columns are max-min scaled with training-split statistics as
    ``(x - (min - 1)) / (max - min + 1)``, missing entries are zeroed, and the result is
    cached to ``./census_dataset/..._max-min_norm.pk`` for later runs.

    Raises:
        ValueError: a continuous column has no observed value in the training split.
    """
    dataset = tabular_Dataset(missing_ratio=missing_ratio, seed=seed)
    n_rows = len(dataset)
    print(f"Dataset size:{n_rows} entries")
    indlist = np.arange(n_rows)
    np.random.seed(seed + 1)
    np.random.shuffle(indlist)
    tmp_ratio = 1 / nfold
    start = int((nfold - 1) * n_rows * tmp_ratio)
    end = int(nfold * n_rows * tmp_ratio)
    test_index = indlist[start:end]
    remain_index = np.delete(indlist, np.arange(start, end))
    np.random.shuffle(remain_index)
    num_train = int(len(remain_index) * 0.8)
    train_index = remain_index[:num_train]
    valid_index = remain_index[num_train:]
    # Max-min normalization, fitted on the training split only.
    processed_data_path_norm = f"./census_dataset/missing_ratio-{missing_ratio}_seed-{seed}_max-min_norm.pk"
    if not os.path.isfile(processed_data_path_norm):
        print("--------------Dataset has not been normalized yet. Perform data normalization and store the mean value of each column.--------------")
        n_cont = len(dataset.cont_cols)
        max_arr = np.zeros(n_cont)
        min_arr = np.zeros(n_cont)
        for index, k in enumerate(dataset.cont_cols):
            # Use observed_mask so missing values (stored as 0) don't affect the range.
            obs_ind = dataset.observed_masks[train_index, k].astype(bool)
            observed_train = dataset.observed_values[train_index, k][obs_ind]
            if observed_train.size == 0:
                raise ValueError(
                    f"Continuous column {k} has no observed values in the training split; "
                    "cannot compute its max-min range."
                )
            max_arr[index] = observed_train.max()
            min_arr[index] = observed_train.min()
        print(f"--------------Max-value for cont-variable column {max_arr}--------------")
        print(f"--------------Min-value for cont-variable column {min_arr}--------------")
        for index, k in enumerate(dataset.cont_cols):
            # Training values map into [1/(max-min+1), 1]: they stay > 0 (distinct from the
            # 0 used for missing entries) and the denominator is never 0.
            dataset.observed_values[:, k] = (
                (dataset.observed_values[:, k] - (min_arr[index] - 1)) / (max_arr[index] - min_arr[index] + 1)
            ) * dataset.observed_masks[:, k]
        with open(processed_data_path_norm, "wb") as f:
            pickle.dump([dataset.observed_values, dataset.observed_masks, dataset.gt_masks], f)
    # The normalized file now exists, so each dataset below loads it.
    train_dataset = tabular_Dataset(use_index_list=train_index, missing_ratio=missing_ratio, seed=seed)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_dataset = tabular_Dataset(use_index_list=valid_index, missing_ratio=missing_ratio, seed=seed)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_dataset = tabular_Dataset(use_index_list=test_index, missing_ratio=missing_ratio, seed=seed)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    print(f"Training dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(valid_dataset)}")
    print(f"Testing dataset size: {len(test_dataset)}")
    return train_loader, valid_loader, test_loader

In [None]:
# Build the three loaders (seed 1; last of 5 folds held out as the test set; 10% masking).
train_loader, valid_loader, test_loader = get_dataloader(
    seed=1,
    nfold=5,
    batch_size=128,
    missing_ratio=0.1,
)

### Set parameters

In [None]:
# Experiment configuration: optimizer schedule, denoiser architecture, and model options.
config = dict(
    train=dict(epochs=250, batch_size=128, lr=0.0005),
    diffusion=dict(
        layers=4,
        channels=128,
        nheads=4,
        diffusion_embedding_dim=128,
        beta_start=0.0001,
        beta_end=0.5,
        num_steps=100,
        schedule="quad",
        mixed=True,
        token_emb_dim=8,
    ),
    model=dict(
        is_unconditional=0,
        timeemb=128,
        featureemb=8,
        target_strategy="random",
        mixed=True,
        token_emb_dim=8,
    ),
)

In [None]:
# Let the parsed command-line options override the model section of the config.
config["model"].update(
    is_unconditional=args.unconditional,
    test_missing_ratio=args.testmissingratio,
)

### Model module

In [None]:
# Partially adapted from https://github.com/Yura52/tabular-dl-revisiting-models/blob/main/bin/ft_transformer.py
class Tokenizer(nn.Module):
    """Map every numerical and categorical feature to a d_token-dimensional vector.

    Numerical feature j becomes ``x_j * weight[j] (+ bias[j])``. Categorical feature i
    with integer code v becomes
    ``category_embeddings[category_offsets[i] + v] (+ bias[d_numerical + i])``.
    Codes are expected in ``1 .. categories[i]``; ``recover`` searches exactly those rows.
    All weights have ``requires_grad = False``, so the mapping is a fixed random projection.
    """

    def __init__(self, d_numerical: int, categories: ty.Optional[ty.List[int]], d_token: int, bias: bool,)-> None:
        super().__init__()
        d_bias = d_numerical + len(categories)
        category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
        self.d_token = d_token
        self.register_buffer("category_offsets", category_offsets)
        # +1 row so the largest code of the last category is a valid index.
        self.category_embeddings = nn.Embedding(sum(categories) + 1, self.d_token)
        self.category_embeddings.weight.requires_grad = False
        nn_init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5))
        self.weight = nn.Parameter(Tensor(d_numerical, self.d_token))
        self.weight.requires_grad = False
        self.bias = nn.Parameter(Tensor(d_bias, self.d_token)) if bias else None
        nn_init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn_init.kaiming_uniform_(self.bias, a=math.sqrt(5))
            self.bias.requires_grad = False

    @property
    def n_tokens(self)-> int:
        return len(self.weight) + (0 if self.category_offsets is None else len(self.category_offsets))

    def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor])-> Tensor:
        """Tokenize a batch.

        Args:
            x_num: (B, 1, d_numerical); the broadcast against ``weight.T`` requires the
                middle dim to be 1.
            x_cat: (B, 1, n_categories) integer-valued codes, or None.

        Returns:
            (B, 1, n_tokens_present, d_token). Numerical tokens come first, then
            categorical tokens.
        """
        x = (self.weight.T * x_num).unsqueeze(1).permute(0, 1, 3, 2)  # (B, 1, d_numerical, d_token)
        if x_cat is not None:
            x_cat = x_cat.type(torch.int32)
            x = torch.cat([x, self.category_embeddings(x_cat + self.category_offsets[None])], dim=2)
        if self.bias is not None:
            # Bias rows are ordered numerical-then-categorical; use those for tokens present.
            x = x + self.bias[None, : x.shape[2]]
        return x

    def recover(self, Batch, d_numerical):
        """Invert ``forward``.

        Numerical values are recovered as the mean over token dims of
        ``(token - bias) / weight``. Categorical values are recovered as the 1-based code
        of the nearest (L2) embedding within the feature's own category range.

        Args:
            Batch: (B, n_tokens * d_token, 1). The reshape below only works when the last
                dim is 1.
            d_numerical: number of leading numerical tokens.

        Returns:
            (B, n_tokens) tensor: numerical values followed by float-valued category codes.
        """
        B, L, K = Batch.shape
        L_new = int(L / self.d_token)
        Batch = Batch.reshape(B, L_new, self.d_token)
        if self.bias is not None:
            Batch = Batch - self.bias
        Batch_numerical = Batch[:, :d_numerical, :] / self.weight
        Batch_numerical = torch.mean(Batch_numerical, 2, keepdim=False)
        Batch_cat = Batch[:, d_numerical:, :]
        n_cat = Batch_cat.shape[1]
        new_Batch_cat = torch.zeros([Batch_cat.shape[0], n_cat], device=Batch_numerical.device)
        for i in range(n_cat):
            token_start = self.category_offsets[i] + 1
            if i == n_cat - 1:
                token_end = self.category_embeddings.weight.shape[0] - 1
            else:
                token_end = self.category_offsets[i + 1]
            emb_vec = self.category_embeddings.weight[token_start: token_end + 1, :]  # (n_i, d_token)
            # Distance from each sample's token to every candidate embedding: (B, n_i).
            distance = torch.norm(Batch_cat[:, i, :].unsqueeze(1) - emb_vec.unsqueeze(0), dim=2)
            new_Batch_cat[:, i] = torch.argmin(distance, dim=1) + 1
        return torch.cat([Batch_numerical, new_Batch_cat], dim=1)

In [None]:
def Conv1d_with_init(in_channels, out_channels, kernel_size):
    """Create an ``nn.Conv1d`` whose weight is re-initialised with Kaiming-normal.

    The bias keeps PyTorch's default initialisation.
    """
    conv = nn.Conv1d(in_channels, out_channels, kernel_size)
    nn.init.kaiming_normal_(conv.weight)  # in-place re-initialisation of the weight only
    return conv

In [None]:
def get_torch_trans(heads= 8, layers= 1, channels= 64, dim_feedforward= 64):
    """Build a stack of ``layers`` GELU Transformer encoder layers.

    ``batch_first`` is left at PyTorch's default (False), so inputs are
    ``(seq, batch, channels)``.

    Args:
        heads: attention heads per layer.
        layers: number of stacked encoder layers.
        channels: model width (``d_model``).
        dim_feedforward: hidden size of each layer's feed-forward block
            (default 64, the previously hard-coded value).
    """
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=channels, nhead=heads, dim_feedforward=dim_feedforward, activation="gelu"
    )
    return nn.TransformerEncoder(encoder_layer, num_layers=layers)

In [None]:
class DiffusionEmbedding(nn.Module):
    """Embed integer diffusion steps: fixed sinusoidal table followed by a 2-layer SiLU MLP.

    Args:
        num_steps: number of diffusion steps (rows of the table).
        embedding_dim: width of the sinusoidal features. An odd value uses
            ``embedding_dim - 1`` sinusoidal features.
        projection_dim: output width (defaults to ``embedding_dim``).
    """

    def __init__(self, num_steps, embedding_dim= 128, projection_dim= None):
        super().__init__()
        if projection_dim is None:
            projection_dim = embedding_dim
        half_dim = embedding_dim // 2
        # A buffer (not a Parameter) is excluded from optimization; persistent=False also
        # keeps the table out of state_dict since it is rebuilt on construction.
        self.register_buffer("embedding", self._build_embedding(num_steps, half_dim), persistent=False)
        self.projection1 = nn.Linear(2 * half_dim, projection_dim)
        self.projection2 = nn.Linear(projection_dim, projection_dim)

    def forward(self, diffusion_step):
        x = self.embedding[diffusion_step]
        x = self.projection1(x)
        x = F.silu(x)
        x = self.projection2(x)
        x = F.silu(x)
        return x

    def _build_embedding(self, num_steps, dim= 64):
        """Return a (num_steps, 2*dim) table ``[sin(t*f_k), cos(t*f_k)]``.

        The frequencies f_k are log-spaced from 1 to 1e4.
        """
        steps = torch.arange(num_steps).unsqueeze(1)  # (T, 1)
        # max(dim - 1, 1) avoids 0/0 when dim == 1.
        frequencies = 10.0 ** (torch.arange(dim) / max(dim - 1, 1) * 4.0).unsqueeze(0)  # (1, dim)
        table = steps * frequencies  # (T, dim)
        table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)  # (T, dim*2)
        return table

In [None]:
class ResidualBlock(nn.Module):
    """One CSDI residual layer.

    The layer adds the diffusion-step embedding to ``x`` and runs Transformer attention
    along the L axis, then along the K axis. It then adds the projected side information,
    applies a sigmoid/tanh gate, and splits the output into a residual and a skip branch.

    Args:
        side_dim: channel count of the side information ``cond_info``.
        channels: hidden channel count of ``x``.
        diffusion_embedding_dim: size of the incoming diffusion-step embedding.
        nheads: attention heads used by both Transformer layers.
    """
    def __init__(self, side_dim, channels, diffusion_embedding_dim, nheads):
        super().__init__()
        self.diffusion_projection= nn.Linear(diffusion_embedding_dim, channels)
        self.cond_projection= Conv1d_with_init(side_dim, 2* channels, 1)
        self.mid_projection= Conv1d_with_init(channels, 2* channels, 1)
        # 2*channels so the result can be split into residual and skip halves.
        self.output_projection= Conv1d_with_init(channels, 2* channels, 1)
        # Temporal Transformer layer (attends along the L axis)
        self.time_layer= get_torch_trans(heads= nheads, layers= 1, channels= channels)
        # Feature Transformer layer (attends along the K axis)
        self.feature_layer= get_torch_trans(heads= nheads, layers= 1, channels= channels)

    # Self-attention over L, independently for every (batch, k); y: (B, channel, K*L) -> same.
    # Skipped when L == 1.
    def forward_time(self, y, base_shape):
        B, channel, K, L= base_shape
        if L== 1:
            return y
        y= y.reshape(B, channel, K, L).permute(0, 2, 1, 3).reshape(B* K, channel, L)
        # input shape for TransformerEncoder (batch_first left at its default False): [seq, batch, emb]
        y= self.time_layer(y.permute(2, 0, 1)).permute(1, 2, 0)
        y= y.reshape(B, K, channel, L).permute(0, 2, 1, 3).reshape(B, channel, K* L)
        return y

    # Self-attention over K, independently for every (batch, l); y: (B, channel, K*L) -> same.
    # Skipped when K == 1 (NOTE(review): with TabCSDI's (B, 1, L) layout K looks like 1,
    # making this a no-op there — confirm).
    def forward_feature(self, y, base_shape):
        B, channel, K, L= base_shape
        if K== 1:
            return y
        y= y.reshape(B, channel, K, L).permute(0, 3, 1, 2).reshape(B* L, channel, K)
        y= self.feature_layer(y.permute(2, 0, 1)).permute(1, 2, 0)
        y= y.reshape(B, L, channel, K).permute(0, 2, 3, 1).reshape(B, channel, K* L)
        return y
    
    # x, (B, channels, K, L); cond_info, (B, side_dim, K, L); diffusion_emb, (B, diffusion_embedding_dim).
    # Returns ((x + residual) / sqrt(2), skip), both shaped like x.
    def forward(self, x, cond_info, diffusion_emb):
        B, channel, K, L= x.shape
        base_shape= x.shape
        x= x.reshape(B, channel, K* L)
        # Diffusion embedding processing
        diffusion_emb= self.diffusion_projection(diffusion_emb).unsqueeze(-1)  # (B,channel,1)
        # x, (B, channel, KL); diffusion_emb, (B, channel, 1); y, (B, channel, KL).
        y= x+ diffusion_emb
        # Temporal transformer
        y= self.forward_time(y, base_shape)
        # Feature transformer
        y= self.forward_feature(y, base_shape)  # (B, channel, K* L)
        # Combining conditional information
        y= self.mid_projection(y)  # (B, 2*channel, K* L)
        _, cond_dim, _, _= cond_info.shape
        cond_info= cond_info.reshape(B, cond_dim, K* L)
        cond_info= self.cond_projection(cond_info)  # (B, 2* channel, K* L)
        y= y+ cond_info
        # Gated activation: sigmoid(gate) * tanh(filter)
        gate, filter= torch.chunk(y, 2, dim= 1)
        y= torch.sigmoid(gate)* torch.tanh(filter) # y, (B, channel, KL)
        y= self.output_projection(y)
        # Residual / skip split; the residual sum is scaled by 1/sqrt(2).
        residual, skip= torch.chunk(y, 2, dim= 1)
        x= x.reshape(base_shape)
        residual= residual.reshape(base_shape)
        skip= skip.reshape(base_shape)
        return (x+ residual)/ math.sqrt(2.0), skip

In [None]:
class diff_CSDI(nn.Module):
    """Denoising network: input projection, stacked ResidualBlocks, skip-sum, output head.

    Args:
        config: the ``config["diffusion"]`` dict. Reads channels, num_steps,
            diffusion_embedding_dim, token_emb_dim (when mixed), mixed, side_dim,
            nheads and layers.
        inputdim: accepted for API compatibility but ignored; the input channel count
            is always ``2 * token_emb_dim`` (conditional half + noisy half).
    """

    def __init__(self, config, inputdim=2):
        super().__init__()
        self.config = config
        width = config["channels"]
        self.channels = width
        self.diffusion_embedding = DiffusionEmbedding(
            num_steps=config["num_steps"],
            embedding_dim=config["diffusion_embedding_dim"],
        )
        self.token_emb_dim = 1
        if config["mixed"]:
            self.token_emb_dim = config["token_emb_dim"]
        in_channels = 2 * self.token_emb_dim
        self.input_projection = Conv1d_with_init(in_channels, width, 1)
        self.output_projection1 = Conv1d_with_init(width, width, 1)
        self.output_projection2 = Conv1d_with_init(width, self.token_emb_dim, 1)
        # Zero the final weight (its bias keeps the default init).
        nn.init.zeros_(self.output_projection2.weight)
        blocks = [
            ResidualBlock(
                side_dim=config["side_dim"],
                channels=width,
                diffusion_embedding_dim=config["diffusion_embedding_dim"],
                nheads=config["nheads"],
            )
            for _ in range(config["layers"])
        ]
        self.residual_layers = nn.ModuleList(blocks)

    def forward(self, x, cond_info, diffusion_step):
        """x: (B, 2*token_emb_dim, K, L) -> (B, K, L*token_emb_dim), or (B, K, L) if not mixed."""
        B, in_ch, K, L = x.shape
        h = F.relu(self.input_projection(x.reshape(B, in_ch, K * L)))
        h = h.reshape(B, self.channels, K, L)
        step_emb = self.diffusion_embedding(diffusion_step)
        skips = []
        for block in self.residual_layers:
            h, skip = block(h, cond_info, step_emb)
            skips.append(skip)
        h = torch.stack(skips).sum(dim=0) / math.sqrt(len(self.residual_layers))
        h = h.reshape(B, self.channels, K * L)
        h = self.output_projection2(F.relu(self.output_projection1(h)))  # (B, token_emb_dim, K*L)
        if self.config["mixed"]:
            return h.permute(0, 2, 1).reshape(B, K, L * self.token_emb_dim)
        return h.reshape(B, K, L)

In [None]:
class CSDI_base(nn.Module):
    """Conditional diffusion imputer (CSDI) over tokenized tabular rows.

    Subclasses must implement ``process_data(batch)`` returning
    ``(observed_data, observed_mask, observed_tp, gt_mask, for_pattern_mask, cut_length)``.

    Args:
        exe_name: dataset name. Only ``"census"`` is supported; column metadata is read
            from ``./census_dataset/transformed_columns.pk``.
        target_dim: number of K-axis entries that receive a learned feature embedding.
        config: dict with ``"model"`` and ``"diffusion"`` sections. Note that
            ``config["diffusion"]["side_dim"]`` is written by this constructor.
        device: torch device for tensors created inside the model.

    Raises:
        ValueError: unsupported ``exe_name`` or unknown noise schedule.
    """

    def __init__(self, exe_name, target_dim, config, device):
        super().__init__()
        self.device = device
        self.target_dim = target_dim
        # Embedding sizes for the side information.
        self.emb_time_dim = config["model"]["timeemb"]
        self.emb_feature_dim = config["model"]["featureemb"]
        self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
        self.is_unconditional = config["model"]["is_unconditional"]
        self.target_strategy = config["model"]["target_strategy"]
        # For categorical variables
        self.mixed = config["model"]["mixed"]
        if exe_name != "census":
            raise ValueError(f"Unsupported exe_name {exe_name!r}: only 'census' is implemented.")
        with open("./census_dataset/transformed_columns.pk", "rb") as f:
            cont_list, num_cate_list = pickle.load(f)
        self.cont_list = cont_list
        if self.mixed:
            self.token_dim = config["model"]["token_emb_dim"]
            self.tokenizer = Tokenizer(len(cont_list), num_cate_list, self.token_dim, True)
        else:
            # Without tokenization every feature is a single scalar.
            self.token_dim = 1
        if self.is_unconditional == False:
            self.emb_total_dim += 1  # for conditional mask
        self.embed_layer = nn.Embedding(num_embeddings=self.target_dim, embedding_dim=self.emb_feature_dim)
        config_diff = config["diffusion"]
        config_diff["side_dim"] = self.emb_total_dim
        input_dim = 1 if self.is_unconditional == True else 2
        self.diffmodel = diff_CSDI(config_diff, input_dim)
        # parameters for diffusion models
        self.num_steps = config_diff["num_steps"]
        if config_diff["schedule"] == "quad":
            self.beta = np.linspace(config_diff["beta_start"] ** 0.5, config_diff["beta_end"] ** 0.5, self.num_steps) ** 2
        elif config_diff["schedule"] == "linear":
            self.beta = np.linspace(config_diff["beta_start"], config_diff["beta_end"], self.num_steps)
        else:
            raise ValueError(f"Unknown noise schedule {config_diff['schedule']!r}")
        self.alpha_hat = 1 - self.beta
        self.alpha = np.cumprod(self.alpha_hat)
        self.alpha_torch = torch.tensor(self.alpha).float().to(self.device).unsqueeze(1).unsqueeze(1)

    def time_embedding(self, pos, d_model=128):
        """Sinusoidal encoding of positions ``pos`` (B, L) -> (B, L, d_model)."""
        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(self.device)
        position = pos.unsqueeze(2)
        div_term = 1 / torch.pow(10000.0, torch.arange(0, d_model, 2).to(self.device) / d_model)
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)
        return pe

    def get_randmask(self, observed_mask):
        """For each sample, drop a uniformly random fraction of its observed entries.

        Returns a float mask shaped like ``observed_mask``: 1 = kept as condition,
        0 = dropped or never observed.
        """
        rand_for_mask = torch.rand_like(observed_mask) * observed_mask
        rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1)
        for i in range(len(observed_mask)):
            sample_ratio = np.random.rand()
            num_observed = observed_mask[i].sum().item()
            num_masked = round(num_observed * sample_ratio)
            rand_for_mask[i][rand_for_mask[i].topk(num_masked).indices] = -1
        cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float()
        return cond_mask

    def get_side_info(self, observed_tp, cond_mask):
        """Concatenate time, feature and (if conditional) mask embeddings -> (B, *, K, L)."""
        B, K, L = cond_mask.shape
        time_embed = self.time_embedding(observed_tp, self.emb_time_dim)  # (B,L,emb)
        time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1)
        feature_embed = self.embed_layer(torch.arange(self.target_dim).to(self.device))  # (K,emb)
        feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1)
        side_info = torch.cat([time_embed, feature_embed], dim=-1)  # (B,L,K,*)
        side_info = side_info.permute(0, 3, 2, 1)  # (B,*,K,L)
        if self.is_unconditional == False:
            side_mask = cond_mask.unsqueeze(1)  # (B,1,K,L)
            side_info = torch.cat([side_info, side_mask], dim=1)
        return side_info

    def calc_loss_valid(self, observed_data, cond_mask, observed_mask, side_info, is_train):
        """Average ``calc_loss`` over every diffusion step (deterministic validation loss)."""
        loss_sum = 0
        for t in range(self.num_steps):
            loss = self.calc_loss(observed_data, cond_mask, observed_mask, side_info, is_train, set_t=t)
            loss_sum += loss.detach()
        return loss_sum / self.num_steps

    def calc_loss(self, observed_data, cond_mask, observed_mask, side_info, is_train, set_t=-1):
        """Noise-prediction MSE on target entries (observed but not conditioned).

        A random step is drawn per sample when ``is_train == 1``; otherwise ``set_t`` is used.
        """
        B, K, L = observed_data.shape
        if is_train != 1:
            t = (torch.ones(B) * set_t).long().to(self.device)
        else:
            t = torch.randint(0, self.num_steps, [B]).to(self.device)
        current_alpha = self.alpha_torch[t]  # (B,1,1)
        noise = torch.randn_like(observed_data)
        # Forward step: add noise to all data.
        noisy_data = (current_alpha ** 0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise
        total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask)
        predicted = self.diffmodel(total_input, side_info, t)  # (B,K,L*token_dim)
        target_mask = observed_mask - cond_mask
        target_mask = torch.repeat_interleave(target_mask, self.token_dim, dim=2)
        residual = (noise - predicted) * target_mask
        num_eval = target_mask.sum()
        loss = (residual ** 2).sum() / (num_eval if num_eval > 0 else 1)
        return loss

    def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask):
        """Build the (B, 2*token_dim, K, L_features) denoiser input.

        The first channel group is the conditioned observed values; the second is the noisy
        values at non-conditioned positions. ``cond_mask`` is per feature (B, K, L_features)
        and is expanded here to the token layout.
        """
        cond_mask = torch.repeat_interleave(cond_mask, self.token_dim, dim=2)
        cond_obs = (cond_mask * observed_data).unsqueeze(1)
        noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1)
        total_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)
        B, old_input_dim, K, L = total_input.shape
        total_input = total_input.reshape(B, old_input_dim, K, int(L / self.token_dim), self.token_dim)
        total_input = total_input.permute(0, 1, 4, 2, 3)
        total_input = total_input.reshape(B, old_input_dim * self.token_dim, K, int(L / self.token_dim))
        return total_input

    def impute(self, observed_data, cond_mask, side_info, n_samples):
        """Draw ``n_samples`` reverse-diffusion samples -> (B, n_samples, K, L) in token space."""
        B, K, L = observed_data.shape
        # Per-token mask aligned with the tokenized data (used by the unconditional path).
        token_cond_mask = torch.repeat_interleave(cond_mask, self.token_dim, dim=2)
        imputed_samples = torch.zeros(B, n_samples, K, L).to(self.device)
        # Run the full reverse process once per sample for the same input.
        for i in range(n_samples):
            if self.is_unconditional == True:
                noisy_obs = observed_data
                noisy_cond_history = []
                # perform T steps forward
                for t in range(self.num_steps):
                    noise = torch.randn_like(noisy_obs)
                    noisy_obs = (self.alpha_hat[t] ** 0.5) * noisy_obs + self.beta[t] ** 0.5 * noise
                    noisy_cond_history.append(noisy_obs * token_cond_mask)
            current_sample = torch.randn_like(observed_data)
            # perform T steps backward
            for t in range(self.num_steps - 1, -1, -1):
                if self.is_unconditional == True:
                    diff_input = token_cond_mask * noisy_cond_history[t] + (1.0 - token_cond_mask) * current_sample
                    diff_input = diff_input.unsqueeze(1)  # (B,1,K,L)
                else:
                    # Fix the original x^{co} as condition; same layout as in training.
                    diff_input = self.set_input_to_diffmodel(current_sample, observed_data, cond_mask)
                predicted = self.diffmodel(diff_input, side_info, torch.tensor([t]).to(self.device))  # (B,K,L)
                coeff1 = 1 / self.alpha_hat[t] ** 0.5
                coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5
                current_sample = coeff1 * (current_sample - coeff2 * predicted)
                if t > 0:
                    noise = torch.randn_like(current_sample)
                    sigma = ((1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]) ** 0.5
                    current_sample += sigma * noise
            imputed_samples[:, i] = current_sample.detach()
        return imputed_samples

    def forward(self, batch, is_train=1):
        """Return the training loss (``is_train == 1``) or the all-steps validation loss."""
        (observed_data, observed_mask, observed_tp, gt_mask, for_pattern_mask, _,) = self.process_data(batch)
        # In testing, use `gt_mask` (generated with a fixed missing rate) as cond_mask.
        # In training, generate a random mask as cond_mask.
        if is_train == 0:
            cond_mask = gt_mask
        else:
            cond_mask = self.get_randmask(observed_mask)
        side_info = self.get_side_info(observed_tp, cond_mask)
        loss_func = self.calc_loss if is_train == 1 else self.calc_loss_valid
        return loss_func(observed_data, cond_mask, observed_mask, side_info, is_train)

    def evaluate(self, batch, n_samples):
        """Impute a batch using ``gt_mask`` as the condition.

        Returns ``(samples, observed_data, target_mask, observed_mask, observed_tp)``.
        """
        (observed_data, observed_mask, observed_tp, gt_mask, _, cut_length,) = self.process_data(batch)
        with torch.no_grad():
            # gt_mask: 0 for missing elements and manually masked elements
            cond_mask = gt_mask
            # target_mask: 1 for manually masked elements
            target_mask = observed_mask - cond_mask
            side_info = self.get_side_info(observed_tp, cond_mask)
            samples = self.impute(observed_data, cond_mask, side_info, n_samples)
        return samples, observed_data, target_mask, observed_mask, observed_tp

In [None]:
class TabCSDI(CSDI_base):
    """CSDI_base for the census table.

    Each row is one sample with a single K axis (target_dim=1). Attributes run along L.
    """

    def __init__(self, exe_name, config, device, target_dim= 1):
        super().__init__(exe_name, target_dim, config, device)

    def process_data(self, batch):
        """Move a DataLoader batch to the device and tokenize it.

        All per-attribute tensors gain a K=1 axis, giving (B, 1, L) with L = number of
        attributes. The observed data is then tokenized and flattened to
        (B, 1, L * token_dim).
        """
        dev = self.device
        n_cont = len(self.cont_list)

        def _with_k_axis(key):
            # (B, L) -> (B, 1, L) float tensor on the model device.
            return batch[key][:, None, :].to(dev).float()

        raw = _with_k_axis("observed_data")
        tokens = self.tokenizer(raw[:, :, self.cont_list], raw[:, :, n_cont:])
        B, K, L, C = tokens.shape
        observed_data = tokens.reshape(B, K, L * C)
        observed_mask = _with_k_axis("observed_mask")
        observed_tp = batch["timepoints"].to(dev).float()
        gt_mask = _with_k_axis("gt_mask")
        cut_length = torch.zeros(len(observed_data)).long().to(dev)
        # for_pattern_mask is the observed mask itself.
        return (observed_data, observed_mask, observed_tp, gt_mask, observed_mask, cut_length)

In [None]:
def train(exe_name, model, config, train_loader, valid_loader= None, valid_epoch_interval= 20, foldername= "",):
    """Train ``model`` with Adam and a step-decayed learning rate.

    The learning rate is multiplied by 0.1 at 25%, 50%, 75% and 90% of
    ``config["epochs"]``. When ``valid_loader`` is given, the model is
    validated every ``valid_epoch_interval`` epochs. Each validation records
    the following in the history:

    * the validation loss;
    * RMSE and MAE over the masked numerical columns;
    * the per-category error rate of the median of 10 imputation samples.

    The checkpoint with the lowest ``RMSE + mean error rate`` is kept.

    Args:
        exe_name: dataset name; only ``"census"`` is supported. Continuous
            column indices and per-category value counts are read from
            ``./census_dataset/transformed_columns.pk``.
        model: called as ``model(batch)`` for the training loss and as
            ``model(batch, is_train=0)`` for the validation loss; must also
            provide ``evaluate(batch, n_samples)`` and ``tokenizer.recover``.
        config: training settings; reads ``"lr"`` and ``"epochs"``.
        train_loader: iterable of training batches.
        valid_loader: optional iterable of validation batches.
        valid_epoch_interval: validate after every this many epochs.
        foldername: output directory. When non-empty, the best checkpoint is
            written to ``<foldername>/model.pth`` and the history dict to
            ``<foldername>/saved_history.pkl``.

    Raises:
        ValueError: if ``exe_name`` is not ``"census"``.
    """
    if exe_name != "census":
        raise ValueError(f"Unsupported exe_name: {exe_name!r}")
    with open("./census_dataset/transformed_columns.pk", "rb") as f:
        cont_list, num_cate_list = pickle.load(f)
    n_cont = len(cont_list)

    # Control random seed in the current script.
    torch.manual_seed(0)
    np.random.seed(0)
    optimizer = Adam(model.parameters(), lr=config["lr"], weight_decay=1e-6)
    output_path = os.path.join(foldername, "model.pth") if foldername != "" else None
    epochs = config["epochs"]
    # Decay the learning rate by 10x at 25%, 50%, 75% and 90% of training.
    milestones = [int(frac * epochs) for frac in (0.25, 0.5, 0.75, 0.9)]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    history = {'train_loss': [], 'val_loss': [], 'val_rmse': [], 'val_mae': [], 'erro_rate': []}
    best_valid_rmse_erro_rate = 1e10

    for epoch_no in range(epochs):
        train_avg_loss = 0
        model.train()
        with tqdm(train_loader, mininterval=5.0, maxinterval=50.0) as it:
            for batch_no, train_batch in enumerate(it, start=1):
                optimizer.zero_grad()
                # The forward method returns the training loss.
                loss = model(train_batch)
                loss.backward()
                train_avg_loss += loss.item()
                optimizer.step()
                it.set_postfix(
                    ordered_dict={"avg_epoch_loss": train_avg_loss / batch_no, "epoch": epoch_no},
                    refresh=False,
                )
            # The scheduler steps once per epoch.
            lr_scheduler.step()

        if valid_loader is None or (epoch_no + 1) % valid_epoch_interval != 0:
            continue

        history['train_loss'].append(train_avg_loss / batch_no)
        print("Start validation")
        model.eval()
        val_nsample = 10
        val_scaler = 1
        avg_loss_valid = 0
        mse_total = 0
        mae_total = 0
        evalpoints_total = 0
        # Reset every round so the error rate describes this round only.
        err_total = np.zeros([len(num_cate_list)])
        err_total_eval_nums = np.zeros([len(num_cate_list)])
        with torch.no_grad():
            with tqdm(valid_loader, mininterval=5.0, maxinterval=50.0) as it:
                for batch_no, valid_batch in enumerate(it, start=1):
                    # The validation loss must be computed on the validation batch.
                    loss = model(valid_batch, is_train=0)
                    avg_loss_valid += loss.item()
                    # eval_points is the target mask: observed but withheld entries.
                    samples, c_target, eval_points, _, _ = model.evaluate(valid_batch, val_nsample)
                    samples = samples.permute(0, 1, 3, 2)
                    c_target = c_target.permute(0, 2, 1)
                    eval_points = eval_points.permute(0, 2, 1)
                    # Decode the median sample and the target back to feature space.
                    samples_median = model.tokenizer.recover(samples.median(dim=1).values, n_cont)
                    c_target = model.tokenizer.recover(c_target, n_cont)

                    # Continuous variables: squared / absolute error on masked entries.
                    cont_mask = eval_points[:, cont_list, 0]
                    diff = (samples_median[:, cont_list] - c_target[:, cont_list]) * cont_mask
                    mse_total += torch.sum((diff ** 2) * (val_scaler ** 2), dim=0)
                    mae_total += torch.sum(torch.abs(diff) * val_scaler, dim=0)
                    evalpoints_total += torch.sum(cont_mask, dim=0)

                    # Categorical variables: count mismatches on masked entries.
                    for i in range(len(num_cate_list)):
                        col = n_cont + i
                        cat_mask = eval_points[:, col, 0]
                        matched = ((samples_median[:, col] == c_target[:, col]) * cat_mask).sum()
                        n_eval = cat_mask.sum()
                        err_total[i] += (n_eval - matched).item()
                        err_total_eval_nums[i] += n_eval.item()

                    it.set_postfix(
                        ordered_dict={
                            "rmse_total": torch.mean(torch.sqrt(torch.div(mse_total, evalpoints_total))).item(),
                            "erro_rate": (err_total / err_total_eval_nums).mean(),
                            "batch_no": batch_no,
                        },
                        refresh=True,
                    )

        rmse = torch.mean(torch.sqrt(torch.div(mse_total, evalpoints_total))).item()
        mae = torch.mean(torch.div(mae_total, evalpoints_total)).item()
        err_rate = (err_total / err_total_eval_nums).mean()
        history['val_rmse'].append(rmse)
        history['val_mae'].append(mae)
        history['val_loss'].append(avg_loss_valid / batch_no)
        history['erro_rate'].append(err_rate)

        score = rmse + err_rate
        if score < best_valid_rmse_erro_rate:
            # Track the best score so only improvements overwrite the checkpoint.
            best_valid_rmse_erro_rate = score
            if output_path is not None:
                torch.save(model.state_dict(), output_path)

    # Save the training history next to the checkpoint.
    if foldername != "":
        with open(os.path.join(foldername, "saved_history.pkl"), "wb") as f:
            pickle.dump(history, f)

In [None]:
def evaluate_ft(exe_name, model, test_loader, nsample=100, scaler=1, mean_scaler=0, foldername=""):
    """Evaluate imputation quality of ``model`` on ``test_loader``.

    For every batch, ``nsample`` imputations are drawn and their median is
    decoded back to feature space with ``model.tokenizer.recover``. The
    function reports:

    * RMSE and MAE over the masked numerical entries;
    * the per-category error rate over the masked categorical entries.

    Args:
        exe_name: dataset name; only ``"census"`` is supported.
        model: provides ``evaluate(batch, n_samples)`` and ``tokenizer.recover``.
        test_loader: iterable of test batches.
        nsample: number of imputation samples per batch.
        scaler: multiplies errors (MSE by ``scaler**2``, MAE by ``scaler``).
        mean_scaler: unused; kept for call compatibility.
        foldername: directory receiving ``result_nsample<nsample>.pk``, which
            holds ``[rmse, per_category_error_rates]``.

    Raises:
        ValueError: if ``exe_name`` is not ``"census"``.
    """
    if exe_name != "census":
        raise ValueError(f"Unsupported exe_name: {exe_name!r}")
    with open("./census_dataset/transformed_columns.pk", "rb") as f:
        cont_list, num_cate_list = pickle.load(f)
    print(cont_list, num_cate_list)
    n_cont = len(cont_list)
    torch.manual_seed(0)
    np.random.seed(0)

    mse_total = 0
    mae_total = 0
    evalpoints_total = 0
    err_total = np.zeros([len(num_cate_list)])
    err_total_eval_nums = np.zeros([len(num_cate_list)])
    model.eval()
    with torch.no_grad():
        with tqdm(test_loader, mininterval=5.0, maxinterval=50.0) as it:
            for batch_no, test_batch in enumerate(it, start=1):
                # Only the metrics are kept; per-batch samples are not
                # accumulated, so device memory stays bounded.
                samples, c_target, eval_points, _, _ = model.evaluate(test_batch, nsample)
                samples = samples.permute(0, 1, 3, 2)  # (B, nsample, L, K)
                c_target = c_target.permute(0, 2, 1)  # (B, L, K)
                eval_points = eval_points.permute(0, 2, 1)
                # Take the median over samples and decode back to feature space.
                samples_median = model.tokenizer.recover(samples.median(dim=1).values, n_cont)
                c_target = model.tokenizer.recover(c_target, n_cont)

                # Continuous variables: squared / absolute error on masked entries.
                cont_mask = eval_points[:, cont_list, 0]
                diff = (samples_median[:, cont_list] - c_target[:, cont_list]) * cont_mask
                mse_total += torch.sum((diff ** 2) * (scaler ** 2), dim=0)
                mae_total += torch.sum(torch.abs(diff) * scaler, dim=0)
                evalpoints_total += torch.sum(cont_mask, dim=0)

                # Categorical variables: count mismatches on masked entries.
                for i in range(len(num_cate_list)):
                    col = n_cont + i
                    cat_mask = eval_points[:, col, 0]
                    matched = ((samples_median[:, col] == c_target[:, col]) * cat_mask).sum()
                    n_eval = cat_mask.sum()
                    err_total[i] += (n_eval - matched).item()
                    err_total_eval_nums[i] += n_eval.item()

                it.set_postfix(
                    ordered_dict={
                        "rmse_total": torch.mean(torch.sqrt(torch.div(mse_total, evalpoints_total))).item(),
                        "batch_no": batch_no,
                    },
                    refresh=True,
                )

    rmse = torch.mean(torch.sqrt(torch.div(mse_total, evalpoints_total))).item()
    mae = torch.mean(torch.div(mae_total, evalpoints_total)).item()
    err_rates = err_total / err_total_eval_nums
    with open(os.path.join(foldername, "result_nsample" + str(nsample) + ".pk"), "wb") as f:
        pickle.dump([rmse, err_rates], f)
    print("RMSE:", rmse)
    print("MAE:", mae)
    print("ERR_CATE:", err_rates)

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU, and build the
# TabCSDI model for the census experiment on that device.
args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
exe_name = "census"
model = TabCSDI(exe_name, config, device=args.device)
model = model.to(args.device)

In [None]:
# Create a timestamped output folder and record the config used for this run.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
foldername = "./save/census_fold" + str(args.nfold) + "_" + current_time + "/"
print("model folder:", foldername)
os.makedirs(foldername, exist_ok=True)
with open(foldername + "config.json", "w") as f:
    json.dump(config, f, indent=4)
# Train from scratch, or reuse the weights saved under ./save/<args.modelfolder>/.
if args.modelfolder == "":
    train(exe_name, model, config["train"], train_loader, valid_loader=valid_loader, foldername=foldername)
else:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    # (and vice versa).
    state_dict = torch.load("./save/" + args.modelfolder + "/model.pth", map_location=args.device)
    model.load_state_dict(state_dict)

In [None]:
# Score the model on the held-out split; results are written under `foldername`.
print("---------------Start testing---------------")
evaluate_ft(exe_name, model, test_loader, foldername=foldername, scaler=1, nsample=args.nsample)

### Visualization of Model Metrics

In [None]:
import pickle
import matplotlib.pyplot as plt

# Must match the `valid_epoch_interval` used by train() (its default is 20).
valid_epoch_interval = 20
with open(f'{foldername}saved_history.pkl', 'rb') as f:
    history = pickle.load(f)
# One history entry is recorded per validation round, i.e. after epochs
# valid_epoch_interval, 2 * valid_epoch_interval, ... -- plot against those
# epoch numbers rather than the round index.
train_epochs = [valid_epoch_interval * (k + 1) for k in range(len(history['train_loss']))]
val_epochs = [valid_epoch_interval * (k + 1) for k in range(len(history['val_loss']))]
plt.figure(figsize=(8, 6))
plt.plot(train_epochs, history['train_loss'], label='Train Loss')
plt.plot(val_epochs, history['val_loss'], label='Valid Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss value')
plt.title('Training and Validation Loss over each 20 Epochs')
plt.legend()
plt.savefig('Training and Validation Loss.png', dpi=300)
plt.show()

In [None]:
import pickle
import matplotlib.pyplot as plt

# Must match the `valid_epoch_interval` used by train() (its default is 20).
valid_epoch_interval = 20
with open(f'{foldername}saved_history.pkl', 'rb') as f:
    history = pickle.load(f)
# One entry per validation round, recorded after every
# `valid_epoch_interval` epochs -- plot against those epoch numbers.
rmse_epochs = [valid_epoch_interval * (k + 1) for k in range(len(history['val_rmse']))]
mae_epochs = [valid_epoch_interval * (k + 1) for k in range(len(history['val_mae']))]
plt.figure(figsize=(8, 6))
plt.plot(rmse_epochs, history['val_rmse'], label='Valid RMSE')
plt.plot(mae_epochs, history['val_mae'], label='Valid MAE')
plt.xlabel('Epochs')
plt.ylabel('Value')
plt.title('Validation RMSE and MAE over each 20 Epochs')
plt.legend()
plt.savefig('Validation RMSE and MAE.png', dpi=300)
plt.show()

### Tokenizer

In [None]:
# Partially adapted from https://github.com/Yura52/tabular-dl-revisiting-models/blob/main/bin/ft_transformer.py
class Tokenizer(nn.Module):
    """Embed numerical and categorical features as frozen tokens.

    Numerical feature ``j`` with value ``x`` becomes ``x * weight[j]``.
    Categorical feature ``i`` with code ``v`` becomes
    ``category_embeddings[category_offsets[i] + v]``. An optional per-token
    bias is added to every token. All parameters have
    ``requires_grad = False``, so :meth:`recover` can invert the embedding.

    Args:
        d_numerical: number of numerical features.
        categories: number of distinct values of each categorical feature,
            or ``None`` when there are no categorical features.
        d_token: embedding size of every token.
        bias: whether to add a per-token bias.
    """

    def __init__(self, d_numerical: int, categories: ty.Optional[ty.List[int]], d_token: int, bias: bool,) -> None:
        super().__init__()
        self.d_token = d_token
        if categories is None:
            d_bias = d_numerical
            self.register_buffer("category_offsets", None)
            self.category_embeddings = None
        else:
            d_bias = d_numerical + len(categories)
            # Start of each feature's block in the shared embedding table,
            # e.g. categories [9, 16, 8, ...] -> offsets [0, 9, 25, ...].
            category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
            self.register_buffer("category_offsets", category_offsets)
            # Codes are 1-based (see recover), so offset + code ranges up to
            # sum(categories); the table therefore needs sum(categories) + 1 rows.
            self.category_embeddings = nn.Embedding(sum(categories) + 1, self.d_token)
            self.category_embeddings.weight.requires_grad = False
            nn_init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5))
        self.weight = nn.Parameter(Tensor(d_numerical, self.d_token))
        self.weight.requires_grad = False
        self.bias = nn.Parameter(Tensor(d_bias, self.d_token)) if bias else None
        nn_init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn_init.kaiming_uniform_(self.bias, a=math.sqrt(5))
            self.bias.requires_grad = False

    @property
    def n_tokens(self) -> int:
        """Number of tokens produced per row (numerical + categorical)."""
        return len(self.weight) + (0 if self.category_offsets is None else len(self.category_offsets))

    def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor]) -> Tensor:
        """Embed a batch of rows.

        Args:
            x_num: numerical values of shape (B, K, d_numerical); the caller
                in this file uses K = 1.
            x_cat: categorical codes of shape (B, K, n_categories), or None.
                Non-integer codes are truncated to integers.

        Returns:
            Tensor of shape (B, K, n_tokens, d_token): numerical tokens first,
            then categorical tokens.
        """
        # (B, K, d_num, 1) * (d_num, d_token) -> (B, K, d_num, d_token)
        x = x_num.unsqueeze(-1) * self.weight
        if x_cat is not None:
            # Shift each feature's codes into its own block of the shared table.
            codes = x_cat.long() + self.category_offsets[None]
            x = torch.cat([x, self.category_embeddings(codes)], dim=2)
        if self.bias is not None:
            x = x + self.bias[None]
        return x

    def recover(self, Batch, d_numerical):
        """Invert :meth:`forward` on a flattened token batch.

        Args:
            Batch: tensor of shape (B, L, K) whose L * K entries are
                n_tokens * d_token token values (callers in this file pass K = 1).
            d_numerical: number of leading numerical tokens.

        Returns:
            Tensor of shape (B, n_tokens). It holds the recovered numerical
            values (the mean over the token dimension of ``token / weight``),
            followed by the 1-based code of the nearest embedding for each
            categorical feature.
        """
        B, L, K = Batch.shape
        n_tokens = int(L / self.d_token)
        tokens = Batch.reshape(B, n_tokens, self.d_token)
        if self.bias is not None:
            tokens = tokens - self.bias
        # Numerical tokens are x * weight: divide out the weight and average
        # the d_token estimates of x.
        numerical = torch.mean(tokens[:, :d_numerical, :] / self.weight, 2, keepdim=False)

        cat_tokens = tokens[:, d_numerical:, :]
        n_cat = cat_tokens.shape[1]
        codes = torch.zeros([B, n_cat], dtype=numerical.dtype, device=numerical.device)
        for i in range(n_cat):
            # Rows offset+1 .. offset+n_i of the table hold feature i's codes 1..n_i.
            token_start = self.category_offsets[i] + 1
            if i == n_cat - 1:
                token_end = self.category_embeddings.weight.shape[0] - 1
            else:
                token_end = self.category_offsets[i + 1]
            emb_vec = self.category_embeddings.weight[token_start : token_end + 1, :]  # (n_i, d_token)
            # Quantize: assign each row the code of its nearest fixed embedding.
            distance = torch.norm(emb_vec[None] - cat_tokens[:, i, None, :], dim=2)  # (B, n_i)
            codes[:, i] = torch.argmin(distance, dim=1) + 1
        return torch.cat([numerical, codes], dim=1)

In [None]:
# Round-trip sanity check of the Tokenizer on one test batch:
# embed, flatten, then recover and compare with the original values.
with open("./census_dataset/transformed_columns.pk", "rb") as f:
    cont_list, num_cate_list = pickle.load(f)
# Number of numerical features, derived from the metadata instead of hard-coded.
n_cont = len(cont_list)
tok = Tokenizer(n_cont, num_cate_list, config["model"]["token_emb_dim"], True)
print(num_cate_list)
print(torch.tensor([0]+ num_cate_list[:-1]).cumsum(0)[None])
observed_data = next(iter(test_loader))['observed_data'][:, np.newaxis, :]
# (B, K, L)
print(f'{observed_data.shape}')
print(f'original data: {observed_data[0]}')
observed_data = tok(
    observed_data[:, :, cont_list],
    observed_data[:, :, n_cont:],
)
print(f'encoding data: {observed_data[0, 0].shape}')
# Flatten tokens to (B, n_tokens * d_token, 1), the layout recover() expects.
observed_data = observed_data.view(observed_data.shape[0], -1).unsqueeze(-1)
covered_data = tok.recover(observed_data, n_cont)
print(f'recovered data: {covered_data[0]}')