In [None]:
!pip install import_ipynb
!pip install -U -q PyDrive
!pip install pytorch_pretrained_bert
!pip install sparse
!pip install transformers
!pip install torchmetrics
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

In [None]:
!pip install einops

In [None]:
# Core imports for the notebook: PyTorch, PyTorch Geometric and supporting
# libraries. (No PyDrive authentication is performed in this cell.)

import torch
from torch_geometric.data import Data

import numpy as np
import sparse

import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as tgmnn
from torch_geometric.nn import global_mean_pool
from torch_geometric.loader import DataListLoader as GraphLoader
from torch_geometric.data import Batch

from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
import time
from sklearn import preprocessing
import math
from torch.utils.data import Dataset
import copy
import sklearn.metrics as skm
import pandas as pd
import random
from torch.utils.data.dataset import Dataset
import pytorch_pretrained_bert as Bert
import itertools
from einops import rearrange, repeat

In [None]:
import ast
from typing import Optional, Tuple, Union
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptTensor, PairTensor, SparseTensor
from torch_geometric.utils import softmax
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn import LayerNorm
import torch.nn.functional as F
from torch import Tensor

class TransformerConv(MessagePassing):
    """Pre-norm graph transformer block with edge-aware multi-head attention.

    The block is made of two residual sub-blocks:

    1. ``layernorm1`` -> multi-head attention over incoming edges (optionally
       adding projected edge features to keys and values) -> head merge ->
       dropout -> residual add with the block input.
    2. ``layernorm2`` -> ``ffn`` -> GELU -> dropout -> ``ffn2`` -> dropout ->
       residual add.

    The attention part appears to be adapted from PyG's ``TransformerConv``.
    Because ``layernorm1`` is sized with ``out_channels`` and the block input is
    added back to tensors with ``out_channels`` features, the block expects a
    single node-feature tensor with ``in_channels == out_channels``.

    Args:
        in_channels: Input feature size (or a ``(source, target)`` pair used
            only to size the key/value and query projections).
        out_channels: Feature size per attention head and of the block output.
        heads: Number of attention heads.
        concat: If ``True`` the heads are concatenated and projected back to
            ``out_channels`` by ``proj``; otherwise they are averaged.
        beta: Stored as ``beta and root_weight``; not otherwise used here.
        dropout: Dropout probability applied to the attention coefficients and
            after each sub-block.
        edge_dim: Edge feature size. ``None`` disables edge features.
        bias: Accepted for signature compatibility; not used by this class.
        root_weight: Stored on the instance; not otherwise used here.
        **kwargs: Forwarded to ``MessagePassing`` (``aggr`` defaults to
            ``'add'``).
    """
    _alpha: OptTensor

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        beta: bool = False,
        dropout: float = 0.,
        edge_dim: Optional[int] = None,
        bias: bool = True,
        root_weight: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.edge_dim = edge_dim
        self._alpha = None

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        # Attention projections: keys/values from source nodes, queries from targets.
        self.lin_key = Linear(in_channels[0], heads * out_channels)
        self.lin_query = Linear(in_channels[1], heads * out_channels)
        self.lin_value = Linear(in_channels[0], heads * out_channels)
        self.layernorm1 = LayerNorm(out_channels)
        self.layernorm2 = LayerNorm(out_channels)
        self.gelu = nn.GELU()
        # Merges the concatenated heads back to out_channels (used when concat=True).
        self.proj = Linear(heads * out_channels, out_channels)
        self.ffn = Linear(out_channels, out_channels)
        self.ffn2 = Linear(out_channels, out_channels)
        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
        else:
            self.lin_edge = self.register_parameter('lin_edge', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable sub-module of the block.

        All projections and both layer norms are reset, so an explicit call
        really restores a freshly initialised layer (the previous version left
        ``proj``, ``ffn``, ``ffn2`` and the layer norms untouched).
        """
        super().reset_parameters()
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        self.proj.reset_parameters()
        self.ffn.reset_parameters()
        self.ffn2.reset_parameters()
        self.layernorm1.reset_parameters()
        self.layernorm2.reset_parameters()
        # Same condition as in __init__ (`edge_dim is not None`), not truthiness.
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, batch=None, return_attention_weights=None):
        # type: (Union[Tensor, PairTensor], Tensor, OptTensor, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], Tensor, OptTensor, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, bool) -> Tuple[Tensor, SparseTensor]  # noqa
        r"""Runs the forward pass of the module.

        Args:
            x: Node features. The residual connections require a single tensor
                whose last dimension equals ``out_channels``.
            edge_index: Graph connectivity (edge index tensor or
                ``SparseTensor``).
            edge_attr: Optional edge features; required when the layer was
                built with ``edge_dim``.
            batch: Optional node-to-graph assignment vector handed to both
                layer norms. When ``None`` the layer norms are called without
                graph assignment.
            return_attention_weights (bool, optional): If set to a bool,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)

        Returns:
            The updated node features, or a tuple of node features and
            attention information when ``return_attention_weights`` is a bool.
        """

        H, C = self.heads, self.out_channels
        residual = x
        x = self.layernorm1(x, batch)
        if isinstance(x, Tensor):
            x = (x, x)
        query = self.lin_query(x[1]).view(-1, H, C)
        key = self.lin_key(x[0]).view(-1, H, C)
        value = self.lin_value(x[0]).view(-1, H, C)
        # propagate_type: (query: Tensor, key:Tensor, value: Tensor, edge_attr: OptTensor) # noqa
        out = self.propagate(edge_index, query=query, key=key, value=value,
                             edge_attr=edge_attr, size=None)
        # message() stashes the attention coefficients on the instance; take
        # them and clear the slot so they do not leak into the next call.
        alpha = self._alpha
        self._alpha = None
        if self.concat:
            out = self.proj(out.view(-1, H * C))
        else:
            out = out.mean(dim=1)
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = out + residual
        residual = out

        # Feed-forward sub-block. `batch` is passed here as well so both layer
        # norms normalise over the same groups of nodes (identical to the old
        # behaviour whenever batch is None).
        out = self.layernorm2(out, batch)
        out = self.gelu(self.ffn(out))
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.ffn2(out)
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = out + residual
        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Compute the attention-weighted message for every edge.

        Edge features (when enabled) are projected once and added to both the
        keys and the values. The scaled dot-product scores are normalised with
        a softmax grouped by ``index`` and kept in ``self._alpha`` (before
        dropout) so ``forward`` can return them.
        """
        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
                                                      self.out_channels)
            key_j = key_j + edge_attr

        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        out = value_j
        if edge_attr is not None:
            out = out + edge_attr

        out = out * alpha.view(-1, self.heads, 1)
        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')


class GraphTransformer(torch.nn.Module):
    """Encode each graph with three ``TransformerConv`` layers and read out the
    hidden state of every node whose token id is 0.

    Node ids are embedded with ``embed`` and edge-type ids with ``embed_ee``
    (7 edge types); all feature sizes are ``config.hidden_size // 5``.
    """

    def __init__(self, config):
        super().__init__()

        dim = config.hidden_size // 5
        conv_signature = 'x, edge_index, edge_attr -> x'

        def make_conv(concat):
            return TransformerConv(dim, dim, heads=2, edge_dim=dim,
                                   dropout=config.hidden_dropout_prob, concat=concat)

        # NOTE(review): `batch` is an input of the Sequential but is not routed
        # to the conv layers (their signature is 'x, edge_index, edge_attr -> x'),
        # so the layer norms inside TransformerConv run with batch=None. With
        # PyG's LayerNorm this presumably pools statistics over the whole
        # mini-batch -- confirm this is intended before changing it, since it
        # alters the numerics of any existing checkpoint.
        self.conv = tgmnn.Sequential('x, edge_index, edge_attr, batch', [
            (make_conv(True), conv_signature),
            nn.GELU(),
            (make_conv(True), conv_signature),
            nn.GELU(),
            (make_conv(False), conv_signature),
        ])

        self.embed = nn.Embedding(config.vocab_size, dim)
        self.embed_ee = nn.Embedding(7, dim)

    def forward(self, x, edge_index, edge_index_readout, edge_attr, batch):
        """Return the final hidden states of the nodes whose id equals 0.

        Args:
            x: Node token ids.
            edge_index: Graph connectivity.
            edge_index_readout: Unused; kept for interface compatibility.
            edge_attr: Edge-type ids, embedded with ``embed_ee``.
            batch: Node-to-graph assignment (see the note in ``__init__``).

        Returns:
            Tensor with one row per node where ``x == 0``.
        """
        # squeeze(-1) instead of squeeze(): with exactly one matching node a
        # bare squeeze() yields a 0-dim index and the row dimension is lost.
        indices = (x == 0).nonzero().squeeze(-1)
        h_nodes = self.conv(self.embed(x), edge_index, self.embed_ee(edge_attr), batch)
        return h_nodes[indices]



class BertEmbeddings(nn.Module):
    """Build the transformer input from graph, type, position, age and time embeddings.

    Five embeddings of size ``config.hidden_size // 5`` are concatenated per
    visit, a learnable CLS token is prepended, and the result goes through a
    small MLP (``seq_layers``) and a final layer norm.

    The ``delta`` and ``los`` embedding tables are still created so existing
    checkpoints keep loading, but they are not part of the concatenation.
    """

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        dim = config.hidden_size // 5

        def sinusoid_table(num_rows):
            # from_pretrained is a classmethod; calling it on the class avoids
            # building (and randomly initialising) a throwaway nn.Embedding.
            # It freezes the weights by default, as before.
            return nn.Embedding.from_pretrained(
                embeddings=self._init_posi_embedding(num_rows, dim))

        self.word_embeddings = GraphTransformer(config)
        self.type_embeddings = nn.Embedding(11, dim, padding_idx=0)

        self.age_embeddings = sinusoid_table(config.age_vocab_size)
        self.time_embeddings = sinusoid_table(367)
        self.delta_embeddings = sinusoid_table(config.delta_size)
        self.los_embeddings = sinusoid_table(1192)
        self.posi_embeddings = sinusoid_table(config.max_position_embeddings)

        self.seq_layers = nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU()
        )
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.acti = nn.GELU()
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))

    def forward(self, nodes, edge_index, edge_index_readout, edge_attr, batch, age_ids, time_ids, delta_ids, type_ids, posi_ids, los):
        """Return embeddings of shape ``(batch, seq_len + 1, hidden_size)``.

        ``delta_ids`` and ``los`` are accepted for interface compatibility but
        are not looked up: their embeddings were never part of the output, and
        the old unconditional ``los`` lookup crashed when ``los`` was ``None``
        (the default in ``BertModel.forward``).
        """
        word_embed = self.word_embeddings(nodes, edge_index, edge_index_readout, edge_attr, batch)
        type_embeddings = self.type_embeddings(type_ids)
        age_embed = self.age_embeddings(age_ids)
        time_embeddings = self.time_embeddings(time_ids)
        posi_embeddings = self.posi_embeddings(posi_ids)

        # The graph readout is flat (one row per visit); fold it back into the
        # (batch, seq_len, dim) layout of the other embeddings.
        word_embed = torch.reshape(word_embed, type_embeddings.shape)
        embeddings = torch.cat((word_embed, type_embeddings, posi_embeddings, age_embed, time_embeddings), dim=2)
        b, n, _ = embeddings.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
        embeddings = self.seq_layers(embeddings)
        embeddings = self.LayerNorm(embeddings)

        return embeddings

    def _init_posi_embedding(self, max_position_embedding, hidden_size):
        """Build a fixed sinusoidal lookup table.

        Entry ``[pos, idx]`` is ``sin(pos / 10000 ** (2 * idx / hidden_size))``
        for even ``idx`` and the cosine of the same angle for odd ``idx``.

        Returns:
            ``torch.float32`` tensor of shape
            ``(max_position_embedding, hidden_size)``.
        """
        positions = np.arange(max_position_embedding, dtype=np.float64)[:, None]
        dims = np.arange(hidden_size, dtype=np.float64)[None, :]
        angles = positions / np.power(10000.0, 2.0 * dims / hidden_size)

        lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)
        lookup_table[:, 0::2] = np.sin(angles[:, 0::2])  # even dimensions
        lookup_table[:, 1::2] = np.cos(angles[:, 1::2])  # odd dimensions

        return torch.tensor(lookup_table)

#%%

class BertModel(Bert.modeling.BertPreTrainedModel):
    """BERT encoder whose input embeddings come from ``BertEmbeddings``.

    ``BertEmbeddings`` prepends a CLS token, so the encoder sees
    ``seq_len + 1`` positions and the attention mask must have that length.
    """

    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config=config)
        self.encoder = Bert.modeling.BertEncoder(config=config)
        self.pooler = Bert.modeling.BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, nodes, edge_index, edge_index_readout, edge_attr, batch, age_ids, time_ids, delta_ids, type_ids, posi_ids, attention_mask=None, los=None,
                output_all_encoded_layers=True):
        """Encode a batch of visit sequences.

        Args:
            attention_mask: ``(batch, seq_len + 1)`` mask with 1 for positions
                to attend to (first column is the CLS token). Defaults to
                attending everywhere.
            output_all_encoded_layers: If ``False`` only the last encoder layer
                is returned instead of the list of all layers.

        Returns:
            ``(encoded_layers, pooled_output)``.
        """
        if attention_mask is None:
            # One extra column for the CLS token prepended by the embeddings;
            # ones_like(age_ids) was one position short and failed to broadcast
            # against the attention scores.
            attention_mask = torch.ones(age_ids.shape[0], age_ids.shape[1] + 1,
                                        device=age_ids.device)

        # Expand the 2D mask to [batch_size, 1, 1, to_seq_length] so it
        # broadcasts to [batch_size, num_heads, from_seq_length, to_seq_length].
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # 1 -> 0.0 (attend), 0 -> -10000.0 (masked); added to the raw scores
        # before the softmax, which effectively removes masked positions.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(nodes, edge_index, edge_index_readout, edge_attr, batch, age_ids, time_ids, delta_ids, type_ids, posi_ids, los)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output


#%%

class BertForMTR(Bert.modeling.BertPreTrainedModel):
    """Binary classification head on top of ``BertModel``'s pooled output."""

    def __init__(self, config):
        super(BertForMTR, self).__init__(config)
        self.num_labels = 1
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.relu = nn.ReLU()
        self.apply(self.init_bert_weights)

    def forward(self, nodes, edge_index, edge_index_readout, edge_attr, batch, age_ids, time_ids, delta_ids, type_ids, posi_ids, attention_mask=None, labels=None, masks=None, los=None):
        """Compute one logit per sequence and, when labels are given, the BCE loss.

        Args:
            labels: Float targets with one value per sequence. When ``None``
                (the default) no loss is computed, so the model can be used for
                inference; previously this default crashed inside the loss.
            masks: Unused; kept for interface compatibility.

        Returns:
            ``(loss, logits)`` where ``loss`` is the mean
            binary-cross-entropy-with-logits, or ``None`` without labels.
        """
        _, pooled_output = self.bert(nodes, edge_index, edge_index_readout, edge_attr, batch, age_ids, time_ids, delta_ids, type_ids, posi_ids, attention_mask,los,
                                     output_all_encoded_layers=False)
        logits = self.classifier(pooled_output).squeeze(dim=1)
        if labels is None:
            return None, logits

        bce_logits_loss = nn.BCEWithLogitsLoss(reduction='mean')
        discr_supervised_loss = bce_logits_loss(logits, labels)

        return discr_supervised_loss, logits

#%%

class BertConfig(Bert.modeling.BertConfig):
    """Library BERT configuration populated from a plain dict.

    ``hidden_size`` is mandatory (read with ``[]``); every other entry is read
    with ``dict.get`` and is therefore ``None`` when absent. Three extra
    settings used by this notebook's model are stored as attributes.
    """

    def __init__(self, config):
        lookup = config.get
        base_kwargs = {
            'vocab_size_or_config_json_file': lookup('vocab_size'),
            'hidden_size': config['hidden_size'],
            'num_hidden_layers': lookup('num_hidden_layers'),
            'num_attention_heads': lookup('num_attention_heads'),
            'intermediate_size': lookup('intermediate_size'),
            'hidden_act': lookup('hidden_act'),
            'hidden_dropout_prob': lookup('hidden_dropout_prob'),
            'attention_probs_dropout_prob': lookup('attention_probs_dropout_prob'),
            # Note the key spelling: the dict uses the singular form.
            'max_position_embeddings': lookup('max_position_embedding'),
            'initializer_range': lookup('initializer_range'),
        }
        super(BertConfig, self).__init__(**base_kwargs)

        # Notebook-specific settings that the library config does not know.
        self.age_vocab_size = lookup('age_vocab_size')
        self.delta_size = lookup('delta_size')
        self.graph_dropout_prob = lookup('graph_dropout_prob')

class TrainConfig(object):
    """Attribute-style view of the training options found in a dict.

    Each known option is copied from ``config`` with ``dict.get``, so options
    missing from the dict become ``None``.
    """

    def __init__(self, config):
        option_names = (
            'batch_size',
            'use_cuda',
            'max_len_seq',
            'train_loader_workers',
            'test_loader_workers',
            'device',
            'output_dir',
            'output_name',
            'best_name',
        )
        for option in option_names:
            setattr(self, option, config.get(option))

#%%

class GDSet(Dataset):
    """Dataset over a collection of sequences.

    Fetching an item tags every element of the selected sequence with its
    position (``element['posi_ids'] = position``) and returns that sequence.
    The tagging is done in place on the stored objects.
    """

    def __init__(self, g):
        self.g = g

    def __len__(self):
        return len(self.g)

    def __getitem__(self, index):
        sequence = self.g[index]
        length = len(sequence)
        for position in range(length):
            sequence[position]['posi_ids'] = position
        return sequence

In [None]:
import pickle
# Load the pickled dataset from Google Drive into `dataset`.
# NOTE(review): pickle.load can execute arbitrary code while unpickling --
# only open files from a trusted source. The Drive path is hard-coded and
# assumes Drive is already mounted at /content/drive -- TODO confirm.
with open('/content/drive/My Drive/GANBEHRT/final_data/new/data', 'rb') as handle:
    dataset = pickle.load(handle)

In [None]:
# Nominal 70% / 10% / remainder split sizes. In the visible notebook these
# numbers are only recorded in `train_params`; the actual train/val/test
# indices are drawn later with ShuffleSplit.
train_l = int(len(dataset)*0.70)
val_l = int(len(dataset)*0.10)
test_l = len(dataset) - val_l - train_l
# Stored in model_config['number_output'] below.
number_output = 1

In [None]:
# Output locations for the model and log.
file_config = {
    'model_path': 'model/', # where to save model
    'model_name': 'CVDTransformer', # model name
    'file_name': 'log.txt',  # log path
}
#create_folder(file_config['model_path'])

# Settings shared by the data pipeline and the training loop.
global_params = {
    'max_seq_len': 50,
    'month': 1,
    'gradient_accumulation_steps': 1
}

# Optimizer hyper-parameters.
optim_param = {
    'lr': 3e-5,
    'warmup_proportion': 0.1,
    'weight_decay': 0.01
}

# Training-loop settings; 'batch_size' and 'epochs' are read by train()/eval().
train_params = {
    'batch_size': 64,
    'use_cuda': True,
    'max_len_seq': global_params['max_seq_len'],
    'device': "cuda" if torch.cuda.is_available() else "cpu",
    'data_len' : len(dataset),
    'train_data_len' : train_l,
    'val_data_len' : val_l,
    'test_data_len' : test_l,
    'epochs' : 30,
    'action' : 'train'
}

# Model hyper-parameters. BertConfig (defined above) reads: vocab_size,
# hidden_size, num_hidden_layers, num_attention_heads, intermediate_size,
# hidden_act, hidden_dropout_prob, attention_probs_dropout_prob,
# max_position_embedding, initializer_range, age_vocab_size, delta_size and
# graph_dropout_prob. The remaining keys are not read by BertConfig as defined
# in this notebook.
model_config = {
    'vocab_size': 7204, # number of disease + symbols for word embedding
    'hidden_size': 108*5, # five concatenated embeddings of size hidden_size // 5 (108 each)
    'seg_vocab_size': 2, # number of vocab for seg embedding
    'age_vocab_size': 103, # number of vocab for age embedding
    'delta_size': 144, # number of rows in the delta embedding table
    'gender_vocab_size': 2,
    'ethnicity_vocab_size': 2,
    'race_vocab_size': 6,
    'num_labels':1,
    'feature_dict':7204,
    'max_position_embedding': train_params['max_len_seq'], # maximum number of tokens
    'hidden_dropout_prob': 0.2, # dropout rate
    'graph_dropout_prob': 0.2, # dropout rate
    'num_hidden_layers': 6, # number of multi-head attention layers required
    'num_attention_heads': 12, # number of attention heads
    'attention_probs_dropout_prob': 0.2, # multi-head attention dropout rate
    'intermediate_size': 512, # the size of the "intermediate" layer in the transformer encoder
    'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler "gelu", 'relu', 'swish' are supported
    'initializer_range': 0.02, # parameter weight initializer range
    'number_output' : number_output,
    'n_layers' : 3 - 1,
    'alpha' : 0.1
}

In [None]:
from sklearn.model_selection import ShuffleSplit

# Seed shared by both ShuffleSplit objects and by the few-shot subsampling.
rr = 1
rs = ShuffleSplit(n_splits=1, test_size=.20, random_state=rr)

k = 5
# Fraction of the training split that is kept (values >= 1 keep everything).
few_shots = 0.05
# Dedicated, seeded RNG: the previous call to the global `random.sample` made
# the few-shot subset differ on every run although the splits were seeded.
few_shot_rng = random.Random(rr)

# Outer split: 80% train+val / 20% test. Inner split: 12.5% of train+val is
# held out for validation. Both use n_splits=1, so each loop runs once.
for i, (train_index_tmp, test_index) in enumerate(rs.split(dataset)):
    rs2 = ShuffleSplit(n_splits=1, test_size=.125, random_state=rr)
    for j, (train_index, val_index) in enumerate(rs2.split(train_index_tmp)):
        # rs2 yields positions into train_index_tmp; map them back to dataset indices.
        train_index = train_index_tmp[train_index]
        if few_shots < 1:
            train_index = few_shot_rng.sample(list(train_index), int(len(train_index) * few_shots))
        val_index = train_index_tmp[val_index]

        trainDSet = [dataset[x] for x in train_index]
        valDSet = [dataset[x] for x in val_index]
        testDSet = [dataset[x] for x in test_index]




In [None]:
import os
# Debugging aid: makes CUDA kernel launches synchronous so errors surface at
# the offending call. It slows training; remove once the run is stable.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
conf = BertConfig(model_config)
behrt = BertForMTR(conf)

behrt = behrt.to(train_params['device'])

# All model parameters are handed to the optimizer.
transformer_vars = list(behrt.parameters())

# Optimizer: hyper-parameters come from `optim_param` instead of a duplicated
# literal (values are unchanged: lr=3e-5, weight_decay=0.01, which is also
# AdamW's default).
import transformers
optim_behrt = torch.optim.AdamW(transformer_vars,
                                lr=optim_param['lr'],
                                weight_decay=optim_param['weight_decay'])
#sched = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optim_behrt, 1000, 500*train_params['epochs'], 4, -1)

In [None]:
def run_epoch(e, trainload, device):
    """Train the global ``behrt`` model for one pass over ``trainload``.

    Args:
        e: Epoch number (not used inside the function).
        trainload: Loader yielding lists of per-patient graph sequences.
        device: Device the batch is moved to.

    Returns:
        ``(tr_loss, cost)``: the sum of the per-batch losses (after any
        gradient-accumulation scaling) and the elapsed wall time in seconds.
    """
    # Per-visit attributes arrive flattened; they are folded back into
    # (n_sequences, seq_len) using the configured sequence length.
    seq_len = train_params['max_len_seq']
    accum_steps = max(1, global_params['gradient_accumulation_steps'])
    n_batches = len(trainload)

    def per_visit(flat):
        return torch.reshape(flat, [flat.shape[0] // seq_len, seq_len])

    tr_loss = 0
    start = time.time()
    behrt.train()
    optim_behrt.zero_grad()
    for step, data in enumerate(trainload):
        # `data` is a list of sequences of graphs; flatten it into one batch.
        graph_batch = Batch.from_data_list(list(itertools.chain.from_iterable(data)))
        graph_batch = graph_batch.to(device)
        nodes = graph_batch.x
        edge_index = graph_batch.edge_index
        edge_index_readout = graph_batch.edge_index
        edge_attr = graph_batch.edge_attr
        batch = graph_batch.batch
        age_ids = per_visit(graph_batch.age)
        time_ids = per_visit(graph_batch.time)
        delta_ids = per_visit(graph_batch.delta)
        type_ids = per_visit(graph_batch.adm_type)
        posi_ids = per_visit(graph_batch.posi_ids)
        attMask = per_visit(graph_batch.mask_v)
        # Extra always-visible column for the CLS token the model prepends.
        attMask = torch.cat((torch.ones((attMask.shape[0], 1)).to(device), attMask), dim=1)
        los = per_visit(graph_batch.los)

        # Label and mask are repeated on every visit; take the first one.
        labels = per_visit(graph_batch.label)[:, 0].float()
        masks = per_visit(graph_batch.mask)[:, 0]
        loss, logits = behrt(nodes, edge_index, edge_index_readout, edge_attr, batch, age_ids, time_ids,delta_ids,type_ids,posi_ids,attMask, labels, masks, los)

        if accum_steps > 1:
            loss = loss / accum_steps
        loss.backward()
        tr_loss += loss.item()
        if step % 500 == 0:
            print(loss.item())
        # Step only once per `accum_steps` batches (and on the last batch) so
        # gradients really accumulate; with accum_steps == 1 this is the same
        # zero_grad/backward/step cycle as before.
        if (step + 1) % accum_steps == 0 or (step + 1) == n_batches:
            optim_behrt.step()
            optim_behrt.zero_grad()
        #sched.step()
        del loss
    cost = time.time() - start
    return tr_loss, cost
#%%

def train(trainload, valload, device):
    """Run the full training loop with per-epoch validation.

    Each epoch trains with ``run_epoch``, evaluates with ``eval``, appends both
    results to ``v_behrt_log_train.txt`` and saves the model weights whenever
    the validation loss improves.

    Returns:
        ``(train_loss, val_loss)`` of the last epoch, both scaled by
        ``batch_size / number_of_batches``.
    """
    log_path = "v_behrt_log_train.txt"

    # Start every run with an empty log file.
    with open(log_path, 'w') as log_file:
        log_file.write('')

    best_val = math.inf
    for e in range(train_params["epochs"]):
        epoch_tag = "Epoch n" + str(e)
        print(epoch_tag)

        summed_train_loss, train_time_cost = run_epoch(e, trainload, device)
        summed_val_loss, val_time_cost, _pred, _label, _mask = eval(valload, False, device)
        train_loss = (summed_train_loss * train_params['batch_size']) / len(trainload)
        val_loss = (summed_val_loss * train_params['batch_size']) / len(valload)

        eval_line = 'EVAL {}\t{} secs\n'.format(val_loss, val_time_cost)

        print('TRAIN {}\t{} secs\n'.format(train_loss, train_time_cost))
        with open(log_path, 'a') as log_file:
            log_file.write(epoch_tag + '\n TRAIN {}\t{} secs\n'.format(train_loss, train_time_cost))
            log_file.write(eval_line + '\n\n\n')
        print(eval_line)

        if not val_loss < best_val:
            continue

        # New best validation loss: persist the weights.
        print("** ** * Saving fine - tuned model ** ** * ")
        model_to_save = behrt.module if hasattr(behrt, 'module') else behrt
        save_model(model_to_save.state_dict(), '/content/drive/My Drive/GANBEHRT/models/v_behrt')
        best_val = val_loss
    return train_loss, val_loss


#%%

def eval(_valload, saving, device):
    """Evaluate the global ``behrt`` model on ``_valload``.

    Note: this function shadows Python's builtin ``eval`` within the notebook.

    Args:
        _valload: Loader yielding lists of per-patient graph sequences.
        saving: When true, logits, labels and masks of every batch are
            appended to CSV files under the Drive ``preds`` folder (the files
            are truncated first).
        device: Device the batch is moved to.

    Returns:
        ``(tr_loss, cost, logits, labels, masks)``: the sum of per-batch
        losses, elapsed seconds, and the tensors of the LAST batch only.
    """
    seq_len = train_params['max_len_seq']
    preds_path = "/content/drive/My Drive/GANBEHRT/preds/v_behrt_preds.csv"
    labels_path = "/content/drive/My Drive/GANBEHRT/preds/v_behrt_labels.csv"
    masks_path = "/content/drive/My Drive/GANBEHRT/preds/v_behrt_masks.csv"

    def per_visit(flat):
        return torch.reshape(flat, [flat.shape[0] // seq_len, seq_len])

    tr_loss = 0
    start = time.time()
    behrt.eval()
    if saving:
        # Truncate the output files before appending batch by batch.
        for path in (preds_path, labels_path, masks_path):
            with open(path, 'w') as f:
                f.write('')

    # No gradients are needed for evaluation; skipping autograd saves memory
    # and time (the old optimizer.zero_grad() call here had no purpose).
    with torch.no_grad():
        for step, data in enumerate(_valload):
            graph_batch = Batch.from_data_list(list(itertools.chain.from_iterable(data)))
            graph_batch = graph_batch.to(device)
            nodes = graph_batch.x
            edge_index = graph_batch.edge_index
            edge_index_readout = graph_batch.edge_index
            edge_attr = graph_batch.edge_attr
            batch = graph_batch.batch
            age_ids = per_visit(graph_batch.age)
            time_ids = per_visit(graph_batch.time)
            delta_ids = per_visit(graph_batch.delta)
            type_ids = per_visit(graph_batch.adm_type)
            posi_ids = per_visit(graph_batch.posi_ids)
            attMask = per_visit(graph_batch.mask_v)
            # Extra always-visible column for the CLS token the model prepends.
            attMask = torch.cat((torch.ones((attMask.shape[0], 1)).to(device), attMask), dim=1)
            los = per_visit(graph_batch.los)

            # Label and mask are repeated on every visit; take the first one.
            labels = per_visit(graph_batch.label)[:, 0].float()
            masks = per_visit(graph_batch.mask)[:, 0]
            loss, logits = behrt(nodes, edge_index, edge_index_readout, edge_attr, batch, age_ids, time_ids,delta_ids,type_ids,posi_ids,attMask, labels, masks, los)

            if saving:
                with open(preds_path, 'a') as f:
                    pd.DataFrame(logits.detach().cpu().numpy()).to_csv(f, header=False)
                with open(labels_path, 'a') as f:
                    pd.DataFrame(labels.detach().cpu().numpy()).to_csv(f, header=False)
                with open(masks_path, 'a') as f:
                    pd.DataFrame(masks.detach().cpu().numpy()).to_csv(f, header=False)

            tr_loss += loss.item()
            del loss

    print("TOTAL LOSS", (tr_loss * train_params['batch_size']) / len(_valload))

    cost = time.time() - start
    return tr_loss, cost, logits, labels, masks

#%%

def save_model(_model_dict, file_name):
    """Serialise ``_model_dict`` to ``file_name`` with ``torch.save``."""
    torch.save(obj=_model_dict, f=file_name)

In [None]:
def count_parameters(model):
    """Return the total number of elements over all of ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
# Total number of parameter elements in the model (frozen ones included);
# shown as the cell output.
count_parameters(behrt)

In [None]:
# Warm-start `behrt` from a pretrained checkpoint.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
pretrained_dict = torch.load("/content/drive/My Drive/GANBEHRT/models/v_behrt_pre_graph_nam_vtpmnp_7k", map_location=train_params['device'])
model_dict = behrt.state_dict()
checkpoint_key_count = len(pretrained_dict)
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# Report the overlap so a silent partial (or empty) load is visible, e.g.
# when the checkpoint keys carry a different prefix.
print("Pretrained tensors used: {} of {} in checkpoint; model has {} entries".format(
    len(pretrained_dict), checkpoint_key_count, len(model_dict)))
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
behrt.load_state_dict(model_dict)

In [None]:
# Show the configured sequence length before launching training.
print(train_params['max_len_seq'])
# Build the loaders and run the training loop. 'resume' takes the same path
# as 'train' here.
if train_params['action'] == 'train' or train_params['action'] == 'resume':
    # NOTE(review): the training loader uses shuffle=False, so batches are
    # visited in the same order every epoch -- confirm this is intended.
    trainload = GraphLoader(GDSet(trainDSet), batch_size=train_params['batch_size'], shuffle=False)
    valload = GraphLoader(GDSet(valDSet), batch_size=train_params['batch_size'], shuffle=False)

    train_loss, val_loss = train(trainload, valload, train_params['device'])