In [None]:
!git clone https://kekayan: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@github.com/kekayan/progNet-SAINT.git

In [None]:
# Installs einops into the kernel's environment (quietly).
# NOTE(review): presumably required by the progNet-SAINT modules imported below — confirm and pin a version.
%pip install -q einops

In [None]:
# Work from inside the cloned repository; the relative "../data/..." paths used
# later are resolved from this directory.
# NOTE(review): presumably this is also what makes `models`, `utils`,
# `augmentations` and `pretraining` importable — confirm.
%cd progNet-SAINT/

In [1]:
import numpy as np
from sklearn.preprocessing import LabelEncoder


In [2]:
import pandas as pd

# Table used for the pre-training stage; the path is relative to the current
# working directory.
df = pd.read_csv("../data/clinical_and_other_features.csv")

In [3]:
def data_split(X,y,nan_mask,indices):
    """Select the rows at ``indices`` from the features, the mask and the labels.

    Args:
        X: DataFrame of features; rows are taken positionally from ``X.values``.
        y: 1-D NumPy array of labels aligned with the rows of ``X``.
        nan_mask: DataFrame laid out like ``X``; rows are taken from
            ``nan_mask.values`` with the same positions.
        indices: integer row positions to keep.

    Returns:
        Tuple ``(x_d, y_d)`` where ``x_d`` is ``{'data': ..., 'mask': ...}`` and
        ``y_d`` is ``{'data': ...}`` with the labels reshaped to a column.

    Raises:
        ValueError: if the selected data and mask do not have the same shape.
    """
    x_d = {
        'data': X.values[indices],
        'mask': nan_mask.values[indices]
    }

    if x_d['data'].shape != x_d['mask'].shape:
        # Raising a bare string is itself a TypeError in Python 3; raise a real
        # exception so the actual problem is reported.
        raise ValueError('Shape of data not same as that of nan mask!')

    y_d = {
        'data': y[indices].reshape(-1, 1)
    }
    return x_d, y_d

In [4]:
from torch.utils.data import Dataset, DataLoader

class DataSetCatCon(Dataset):
    """Dataset that serves categorical and continuous features separately.

    Every item is the tuple
    ``(categorical_with_cls, continuous, label, categorical_mask_with_cls, continuous_mask)``.
    A constant ``0`` column (the CLS slot) is prepended to the categorical
    values and a constant ``1`` column to the categorical mask.
    """

    def __init__(self, X, Y, cat_cols,task='clf',continuous_mean_std=None):
        """Split the feature matrix into categorical and continuous parts.

        Args:
            X: dict with ``'data'`` and ``'mask'`` 2-D arrays of equal shape.
            Y: dict with ``'data'`` holding the labels as a column array.
            cat_cols: column positions of the categorical features; their order
                is preserved.
            task: accepted for interface compatibility; not used here.
            continuous_mean_std: optional ``(mean, std)`` pair used to
                standardise the continuous columns. It must be ordered by
                ascending column position, which is the order used for the
                continuous block.
        """
        cat_cols = list(cat_cols)
        X_mask =  X['mask'].copy()
        X = X['data'].copy()
        # Sort explicitly: the continuous column order has to line up with
        # continuous_mean_std and must not depend on set iteration order.
        con_cols = sorted(set(range(X.shape[1])) - set(cat_cols))
        self.X1 = X[:,cat_cols].copy().astype(np.int64) #categorical columns
        self.X2 = X[:,con_cols].copy().astype(np.float32) #numerical columns
        self.X1_mask = X_mask[:,cat_cols].copy().astype(np.int64) #categorical columns
        self.X2_mask = X_mask[:,con_cols].copy().astype(np.int64) #numerical columns
        self.y = Y['data']#.astype(np.float32) if regression
        # CLS slot: value 0 everywhere, mask 1 everywhere.
        self.cls = np.zeros_like(self.y,dtype=int)
        self.cls_mask = np.ones_like(self.y,dtype=int)
        if continuous_mean_std is not None:
            mean, std = continuous_mean_std
            self.X2 = (self.X2 - mean) / std

    def __len__(self):
        """Number of samples."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return one sample with the CLS column prepended to the categorical parts."""
        # X1 has categorical data, X2 has continuous
        return (
            np.concatenate((self.cls[idx], self.X1[idx])),
            self.X2[idx],
            self.y[idx],
            np.concatenate((self.cls_mask[idx], self.X1_mask[idx])),
            self.X2_mask[idx],
        )

In [5]:
def prepare_dataset(df, seed=None):
  """Turn the raw clinical table into SAINT-ready splits and data loaders.

  Steps: drop the two alternative response columns, strip column names, split
  features from the target, build a missing-value mask, label-encode the
  categorical columns, impute continuous columns with the training mean, and
  wrap the train/valid/test splits in ``DataSetCatCon`` loaders.

  Args:
      df: raw DataFrame containing the feature columns and the three response
          columns referenced below.
      seed: optional seed for the random 65/15/20 split. ``None`` (default)
          draws from NumPy's global random state.

  Returns:
      Tuple ``(trainloader, validloader, testloader, cat_dims, con_idxs,
      cat_idxs, y_dim, continuous_mean_std, X_train, y_train, X_valid, y_valid,
      X_test, y_test)``. ``con_idxs`` is in ascending column order and
      ``continuous_mean_std`` follows that same order. ``cat_dims`` has a
      leading ``1`` for the CLS token.
  """
  df1 = df.drop(['Overall Near-complete Response:  Looser Definition','Near-complete Response (Graded Measure)'],axis=1)
  df1.columns = df1.columns.str.strip()
  X = df1.drop('Overall Near-complete Response:  Stricter Definition',axis=1)
  y = df1['Overall Near-complete Response:  Stricter Definition']
  cont_columns = ['Date of Birth (Days)', 'Days to Surgery (from the date of diagnosis)', 'Age at last contact in EMR f/u(days)(from the date of diagnosis) ,last time patient known to be alive, unless age of death is reported(in such case the age of death',
    'Age at mammo (days)', 'Days to distant recurrence(from the date of diagnosis)', 'Days to local recurrence (from the date of diagnosis)',
    'Days to death (from the date of diagnosis)', 'Days to last local recurrence free assessment (from the date of diagnosis)',
    ]
  # Keep continuous columns in frame order: DataSetCatCon slices the continuous
  # block by ascending column position, and the mean/std computed below must
  # use that same order.
  cont_columns = sorted(cont_columns, key=X.columns.get_loc)
  cont_lookup = set(cont_columns)
  # Frame order (not set order) keeps the categorical layout, and therefore the
  # embedding layout, identical from one run to the next.
  categorical_columns = [c for c in X.columns if c not in cont_lookup]

  cat_idxs = [X.columns.get_loc(c) for c in categorical_columns]
  con_idxs = [X.columns.get_loc(c) for c in cont_columns]

  # Positional row indices: data_split indexes X.values / y by position.
  rng = np.random if seed is None else np.random.RandomState(seed)
  split = rng.choice(["train", "valid", "test"], p = [.65, .15, .2], size=(X.shape[0],))
  train_indices = np.flatnonzero(split == "train")
  valid_indices = np.flatnonzero(split == "valid")
  test_indices = np.flatnonzero(split == "test")

  # BERT-style mask (1 = present, 0 = missing). It has to be built before the
  # categorical columns are cast to str, otherwise NaN becomes the ordinary
  # string 'nan' and is never marked as missing.
  # NOTE(review): continuous values that fail numeric parsing below are imputed
  # but still count as present here — confirm that is intended.
  nan_mask = X.notna().astype(int)

  cat_dims = []
  for col in categorical_columns:
    # Missing values get their own category; everything is encoded as str.
    l_enc = LabelEncoder()
    X[col] = l_enc.fit_transform(X[col].fillna("MissingValue").astype(str).values)
    cat_dims.append(len(l_enc.classes_))

  for col in cont_columns:
    numeric = pd.to_numeric(X[col], errors='coerce')
    fill_value = numeric.iloc[train_indices].mean()
    if pd.isna(fill_value):
      # No usable training value at all for this column.
      fill_value = 0.0
    # Impute this column only, using the training-split mean.
    X[col] = numeric.fillna(fill_value)

  y = y.values
  l_enc = LabelEncoder()
  y = l_enc.fit_transform(y)
  X_train, y_train = data_split(X,y,nan_mask,train_indices)
  X_valid, y_valid = data_split(X,y,nan_mask,valid_indices)
  X_test, y_test = data_split(X,y,nan_mask,test_indices)

  train_cont = np.array(X_train['data'][:,con_idxs],dtype=np.float32)
  train_mean, train_std = train_cont.mean(0), train_cont.std(0)
  # Guard against division by ~0 for constant columns.
  train_std = np.where(train_std < 1e-6, 1e-6, train_std)
  continuous_mean_std = np.array([train_mean,train_std]).astype(np.float32)

  train_ds = DataSetCatCon(X_train, y_train, cat_idxs,'clf',continuous_mean_std)
  trainloader = DataLoader(train_ds, batch_size=64, shuffle=True,num_workers=1)

  valid_ds = DataSetCatCon(X_valid, y_valid, cat_idxs,'clf', continuous_mean_std)
  validloader = DataLoader(valid_ds, batch_size=64, shuffle=False,num_workers=1)

  test_ds = DataSetCatCon(X_test, y_test, cat_idxs,'clf', continuous_mean_std)
  testloader = DataLoader(test_ds, batch_size=64, shuffle=False,num_workers=1)
  y_dim = len(np.unique(y_train['data'][:,0]))

  cat_dims = np.append(np.array([1]),np.array(cat_dims)).astype(int) #Appending 1 for CLS token, this is later used to generate embeddings.

  return trainloader, validloader, testloader, cat_dims, con_idxs , cat_idxs, y_dim , continuous_mean_std , X_train, y_train, X_valid, y_valid, X_test, y_test

In [6]:
# Build splits, loaders and encoding metadata from the full table `df`; these
# names feed model construction and pre-training below.
trainloader, validloader, testloader, cat_dims, con_idxs , cat_idxs, y_dim , continuous_mean_std, X_train, y_train, X_valid, y_valid, X_test, y_test = prepare_dataset(df)

In [7]:
from models import SAINT


In [8]:
import torch
from torch import nn
from models import SAINT

from torch.utils.data import DataLoader
import torch.optim as optim
from utils import count_parameters, classification_scores, mean_sq_error
from augmentations import embed_data_mask
from augmentations import add_noise

In [9]:
from pretraining import SAINT_pretrain

In [10]:
# Run options. In this notebook only 'dtask', 'task', 'batchsize',
# 'ssl_samples', 'train_noise_type' and 'train_noise_level' are read directly.
# NOTE(review): the remaining keys are presumably consumed by SAINT_pretrain —
# confirm against that module.
opt_dict = {
    'd_task': 'clf',
    'dtask': 'clf',                 # passed as `task` when the fine-tuning dataset is rebuilt
    'task': 'multiclass',           # selects the loss / metric branches in the training loops
    'batchsize': 32,                # upper bound for the fine-tuning batch size
    'pt_aug': ['mixup', 'cutmix'],
    'pt_aug_lam': 0.1,
    'pretrain_epochs': 80, #50
    'nce_temp': 0.7,
    'lam0': 0.5,
    'lam1': 10,
    'lam2': 1,
    'lam3': 10,
    'pt_projhead_style': 'diff',
    'pt_tasks': ['contrastive','denoising'],
    'mixup_lam': 0.3,
    'ssl_samples': 312,             # fine-tuning batch size is min(ssl_samples // 4, batchsize)
    'lr':0.0001,
    'train_noise_type':None,        # 'cutmix' or 'missing' turns on add_noise during training
    'train_noise_level':0,
}

class AttributeDict(dict):
    """dict whose items can also be read, set and deleted as attributes.

    A missing key raises ``AttributeError`` (not ``KeyError``) on attribute
    access, so ``hasattr``, ``getattr(obj, name, default)`` and the protocol
    probing done by ``copy``/``pickle`` behave as they do for ordinary objects.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None

opt = AttributeDict(opt_dict)

In [11]:
# Device used for model construction and the pre-training stage below. The
# automatic CUDA selection is kept for reference but disabled, so this stage
# always runs on the CPU.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

In [12]:
# Cross-entropy loss shared by both supervised training loops further down.
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
# Build the SAINT network from the encoding metadata produced by
# prepare_dataset, then run pre-training on the training split.
model = SAINT(
    categories=tuple(cat_dims),
    num_continuous=len(con_idxs),
    dim=32,                     # embedding dimension
    dim_out=1,
    depth=1,                    # depth of the network (nr. of transformer blocks)
    heads=8,                    # number of attention heads
    attn_dropout=0.1,
    ff_dropout=0.1,
    mlp_hidden_mults=(4, 2),
    cont_embeddings='MLP',      # options: 'MLP', 'linear', 'hybrid' (MLP with continuous embeddings concatenated to the transformer block outputs)
    attentiontype='colrow',     # options: 'col', 'row', 'colrow', 'colrowv2'
    final_mlp_style='sep',
    y_dim=y_dim,
)
model.to(device)
model = SAINT_pretrain(
    model,
    cat_idxs,
    X_train,
    y_train,
    continuous_mean_std,
    opt,
    device=device,
)

### Finetune

In [None]:
# Filtered version of the clinical table, used for the supervised stage.
df2 = pd.read_csv('../data/clinical_and_other_features_filtered.csv')

In [None]:
# Displays whether the filtered table has the same (rows, columns) shape as the
# full table; this is informational only and does not stop execution.
df.shape == df2.shape

In [None]:
# Re-run the preparation on the filtered table. This rebinds every name below,
# including cat_dims, cat_idxs and continuous_mean_std.
# NOTE(review): the pre-trained model was built with the cat_dims obtained from
# `df`, and the label encoders are refit here — confirm that the categorical
# encodings and cardinalities still match that model.
trainloader, validloader, testloader, cat_dims, con_idxs , cat_idxs, y_dim , continuous_mean_std, X_train, y_train, X_valid, y_valid, X_test, y_test = prepare_dataset(df2)

In [None]:
print('We are in semi-supervised learning case')

# Batch size: a quarter of the SSL sample budget, capped by the configured batch size.
train_bsize = min(opt.ssl_samples // 4, opt.batchsize)

# Rebuild the training dataset / loader for fine-tuning with that batch size.
train_ds = DataSetCatCon(
    X_train,
    y_train,
    cat_idxs,
    task=opt.dtask,
    continuous_mean_std=continuous_mean_std,
)
trainloader = DataLoader(
    train_ds,
    batch_size=train_bsize,
    shuffle=True,
    num_workers=2,
)

In [None]:
import torch.optim as optim

In [None]:
# AdamW over the parameters of the current `model` (the pre-trained network at
# this point). The optimizer is bound to these parameter objects; a model
# created later needs its own optimizer.
optimizer = optim.AdamW(model.parameters(),lr=0.001, betas=(0.9,0.999))

In [None]:
# Fine-tune on the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The model was built and pre-trained on the CPU; move it as well, otherwise
# the training loop feeds it batches that live on a different device.
model = model.to(device)

In [None]:
import os

# Directory that receives 'bestmodel.pth'. torch.save does not create missing
# directories, so make sure it exists before training starts.
modelsave_path='outputs'
os.makedirs(modelsave_path, exist_ok=True)

In [None]:
# Supervised fine-tuning: train for 600 epochs, evaluate on the validation and
# test loaders every 5th epoch, and checkpoint the model whenever the
# validation metric improves.
best_valid_auroc = 0
best_valid_accuracy = 0
best_test_auroc = 0
best_test_accuracy = 0
best_valid_rmse = 100000
checkpoint_path = '%s/bestmodel.pth' % (modelsave_path)
print('Training begins now.')
for epoch in range(600):
    model.train()
    running_loss = 0.0
    for data in trainloader:
        optimizer.zero_grad()
        # Batch layout produced by DataSetCatCon: categorical values with a
        # leading CLS column (all zeros), continuous values, labels, and the
        # two matching masks (the CLS mask column is all ones).
        x_categ, x_cont, y_gts, cat_mask, con_mask = (t.to(device) for t in data)
        if opt.train_noise_type is not None and opt.train_noise_level>0:
            noise_dict = {
                'noise_type' : opt.train_noise_type,
                'lambda' : opt.train_noise_level
            }
            if opt.train_noise_type == 'cutmix':
                x_categ, x_cont = add_noise(x_categ,x_cont, noise_params = noise_dict)
            elif opt.train_noise_type == 'missing':
                cat_mask, con_mask = add_noise(cat_mask, con_mask, noise_params = noise_dict)
        # We are converting the data to embeddings in the next step
        _ , x_categ_enc, x_cont_enc = embed_data_mask(x_categ, x_cont, cat_mask, con_mask,model)
        reps = model.transformer(x_categ_enc, x_cont_enc)
        # Position 0 is the CLS slot; its representation goes to the prediction head.
        y_reps = reps[:,0,:]

        y_outs = model.mlpfory(y_reps)
        if opt.task == 'regression':
            loss = criterion(y_outs,y_gts)
        else:
            # Labels arrive as a (batch, 1) column. Squeeze only that axis: a
            # bare squeeze() would also drop the batch axis of a size-1 batch.
            loss = criterion(y_outs,y_gts.squeeze(1))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(running_loss)
    if epoch%5==0:
        model.eval()
        with torch.no_grad():
            if opt.task in ['binary','multiclass']:
                accuracy, auroc = classification_scores(model, validloader, device, opt.task)
                test_accuracy, test_auroc = classification_scores(model, testloader, device, opt.task)

                print('[EPOCH %d] VALID ACCURACY: %.3f' %
                    (epoch + 1, accuracy ))
                print('[EPOCH %d] TEST ACCURACY: %.3f' %
                    (epoch + 1, test_accuracy ))

                # Model selection: validation accuracy for multiclass,
                # validation AUROC otherwise.
                if opt.task =='multiclass':
                    improved = accuracy > best_valid_accuracy
                    if improved:
                        best_valid_accuracy = accuracy
                else:
                    improved = auroc > best_valid_auroc
                    if improved:
                        best_valid_auroc = auroc
                if improved:
                    best_test_auroc = test_auroc
                    best_test_accuracy = test_accuracy
                    torch.save(model.state_dict(),checkpoint_path)

            else:
                valid_rmse = mean_sq_error(model, validloader, device)
                test_rmse = mean_sq_error(model, testloader, device)
                print('[EPOCH %d] VALID RMSE: %.3f' %
                    (epoch + 1, valid_rmse ))
                print('[EPOCH %d] TEST RMSE: %.3f' %
                    (epoch + 1, test_rmse ))
                if valid_rmse < best_valid_rmse:
                    best_valid_rmse = valid_rmse
                    best_test_rmse = test_rmse
                    torch.save(model.state_dict(),checkpoint_path)
        model.train()



total_parameters = count_parameters(model)
print('TOTAL NUMBER OF PARAMS: %d' %(total_parameters))
print('Accuracy on best model:  %.3f' %(best_test_accuracy))

### Without pre-training

In [None]:
# Fresh, randomly initialised SAINT network for the baseline run.
model = SAINT(
categories = tuple(cat_dims),
num_continuous = len(con_idxs),
dim = 32,              # embedding dimension
dim_out = 1,
depth = 1,             # depth of the network (nr. of transformer blocks)
heads = 8,             # number of attention heads
attn_dropout = 0.1,
ff_dropout = 0.1,
mlp_hidden_mults = (4, 2),
cont_embeddings = 'MLP', # options: 'MLP', 'linear', 'hybrid' (MLP with continuous embeddings concatenated to the transformer block outputs)
attentiontype = 'colrow', # options: 'col', 'row', 'colrow', 'colrowv2'
final_mlp_style = 'sep',
y_dim = y_dim
)
# Use the same device the training loop moves its batches to.
model.to(device)
# The existing optimizer still references the previous model's parameters, so
# this model would never be updated. Create a new one bound to this model.
optimizer = optim.AdamW(model.parameters(),lr=0.001, betas=(0.9,0.999))

In [None]:
# Baseline training: train for 600 epochs, evaluate on the validation and test
# loaders every 5th epoch, and checkpoint the model whenever the validation
# metric improves.
best_valid_auroc = 0
best_valid_accuracy = 0
best_test_auroc = 0
best_test_accuracy = 0
best_valid_rmse = 100000
checkpoint_path = '%s/bestmodel.pth' % (modelsave_path)
print('Training begins now.')
for epoch in range(600):
    model.train()
    running_loss = 0.0
    for data in trainloader:
        optimizer.zero_grad()
        # Batch layout produced by DataSetCatCon: categorical values with a
        # leading CLS column (all zeros), continuous values, labels, and the
        # two matching masks (the CLS mask column is all ones).
        x_categ, x_cont, y_gts, cat_mask, con_mask = (t.to(device) for t in data)
        if opt.train_noise_type is not None and opt.train_noise_level>0:
            noise_dict = {
                'noise_type' : opt.train_noise_type,
                'lambda' : opt.train_noise_level
            }
            if opt.train_noise_type == 'cutmix':
                x_categ, x_cont = add_noise(x_categ,x_cont, noise_params = noise_dict)
            elif opt.train_noise_type == 'missing':
                cat_mask, con_mask = add_noise(cat_mask, con_mask, noise_params = noise_dict)
        # We are converting the data to embeddings in the next step
        _ , x_categ_enc, x_cont_enc = embed_data_mask(x_categ, x_cont, cat_mask, con_mask,model)
        reps = model.transformer(x_categ_enc, x_cont_enc)
        # Position 0 is the CLS slot; its representation goes to the prediction head.
        y_reps = reps[:,0,:]

        y_outs = model.mlpfory(y_reps)
        if opt.task == 'regression':
            loss = criterion(y_outs,y_gts)
        else:
            # Labels arrive as a (batch, 1) column. Squeeze only that axis: a
            # bare squeeze() would also drop the batch axis of a size-1 batch.
            loss = criterion(y_outs,y_gts.squeeze(1))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(running_loss)
    if epoch%5==0:
        model.eval()
        with torch.no_grad():
            if opt.task in ['binary','multiclass']:
                accuracy, auroc = classification_scores(model, validloader, device, opt.task)
                test_accuracy, test_auroc = classification_scores(model, testloader, device, opt.task)

                print('[EPOCH %d] VALID ACCURACY: %.3f' %
                    (epoch + 1, accuracy ))
                print('[EPOCH %d] TEST ACCURACY: %.3f' %
                    (epoch + 1, test_accuracy ))

                # Model selection: validation accuracy for multiclass,
                # validation AUROC otherwise.
                if opt.task =='multiclass':
                    improved = accuracy > best_valid_accuracy
                    if improved:
                        best_valid_accuracy = accuracy
                else:
                    improved = auroc > best_valid_auroc
                    if improved:
                        best_valid_auroc = auroc
                if improved:
                    best_test_auroc = test_auroc
                    best_test_accuracy = test_accuracy
                    torch.save(model.state_dict(),checkpoint_path)

            else:
                valid_rmse = mean_sq_error(model, validloader, device)
                test_rmse = mean_sq_error(model, testloader, device)
                print('[EPOCH %d] VALID RMSE: %.3f' %
                    (epoch + 1, valid_rmse ))
                print('[EPOCH %d] TEST RMSE: %.3f' %
                    (epoch + 1, test_rmse ))
                if valid_rmse < best_valid_rmse:
                    best_valid_rmse = valid_rmse
                    best_test_rmse = test_rmse
                    torch.save(model.state_dict(),checkpoint_path)
        model.train()



total_parameters = count_parameters(model)
print('TOTAL NUMBER OF PARAMS: %d' %(total_parameters))
print('Accuracy on best model:  %.3f' %(best_test_accuracy))