In [1]:
import os
import torch
import yaml
import wandb
import argparse
import pandas as pd
import time

from datetime import datetime
from pathlib import Path

# user-defined modules
from trainers import BertMLMTrainer

# user-defined functions
from construct_vocab import construct_pheno_vocab
from utils import get_split_indices
from data_preprocessing import preprocess_TESSy

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
os.environ['WANDB_MODE'] = 'disabled' # 'dryrun' or 'run' or 'offline' or 'disabled' or 'online'
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    torch.cuda.empty_cache()
else:
    print("Using CPU") 
    
BASE_DIR = Path(os.path.abspath('')).parent
RESULTS_DIR = Path(os.path.join(BASE_DIR / "results" / "temp"))
os.chdir(BASE_DIR)
print("base directory:", BASE_DIR)

config_path = BASE_DIR / "config_pheno.yaml"
with open(config_path, "r") as config_file:
    config = yaml.safe_load(config_file)

Using GPU: NVIDIA GeForce RTX 3080
base directory: C:\Users\jespe\Documents\GitHub_local\ARFusion


## Model

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class JointEmbedding(nn.Module):
    
    def __init__(self, config, vocab_size):
        super(JointEmbedding, self).__init__()
        
        self.emb_dim = config['emb_dim']
        self.vocab_size = vocab_size
        self.dropout_prob = config['dropout_prob']
        
        self.token_emb = nn.Embedding(self.vocab_size, self.emb_dim) 
        # self.token_type_emb = nn.Embedding(self.vocab_size, self.emb_dim) 
        self.position_emb = nn.Embedding(self.vocab_size, self.emb_dim) 
        
        self.dropout = nn.Dropout(self.dropout_prob)
        self.layer_norm = nn.LayerNorm(self.emb_dim)
        
    def forward(self, input_tensor):
        # input_tensor: (batch_size, seq_len)
        # token_type_ids: (batch_size, seq_len)
        # position_ids: (batch_size, seq_len)
        
        seq_len = input_tensor.size(-1)
        
        pos_tensor = self.numeric_position(seq_len, input_tensor)
        # token_type not relevant for unimodal data
        # token_type_tensor = torch.zeros_like(input_tensor).to(device) # (batch_size, seq_len)
        # token_type_tensor[:, (seq_len//2 + 1):] = 1 # here, we assume that the sentence is split in half
        
        token_emb = self.token_emb(input_tensor) # (batch_size, seq_len, emb_dim)
        # token_type_emb = self.token_type_emb(token_type_tensor) # (batch_size, seq_len, emb_dim)
        position_emb = self.position_emb(pos_tensor) # (batch_size, seq_len, emb_dim)
        
        # emb = token_emb + token_type_emb + position_emb
        emb = token_emb + position_emb
        emb = self.dropout(emb)
        emb = self.layer_norm(emb) 
        return emb
                
    def numeric_position(self, dim, input_tensor): # input_tensor: (batch_size, seq_len)
        # dim is the length of the sequence
        position_ids = torch.arange(dim, dtype=torch.long, device=device) # create tensor of [0, 1, 2, ..., dim-1]
        return position_ids.expand_as(input_tensor) # expand to (batch_size, seq_len)
    
################################################################################################################
################################################################################################################
################################################################################################################

class MultiHeadAttention(nn.Module):
    
    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()
        
        self.emb_dim = config['emb_dim']
        self.num_heads = config['num_heads']
        self.dropout_prob = config['dropout_prob']
        
        self.head_dim = self.emb_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.emb_dim, f"Embedding dimension must be divisible by number of heads, got {self.emb_dim} and {self.num_heads}"
        
        self.q = nn.Linear(self.emb_dim, self.emb_dim)
        self.k = nn.Linear(self.emb_dim, self.emb_dim)
        self.v = nn.Linear(self.emb_dim, self.emb_dim)
        
        self.dropout = nn.Dropout(self.dropout_prob)
    
    def forward(self, input_emb: torch.Tensor, attn_mask:torch.Tensor = None):
        B, L, D = input_emb.size() # (L=batch_size, L=seq_len, D=emb_dim)
        
        # project input embeddings to query, key, value, then split into num_heads, reducing the embedding dimension
        query = self.q(input_emb).view(B, L, self.num_heads, self.head_dim).transpose(1,2) # (B, num_heads, L, head_dim)
        key = self.k(input_emb).view(B, L, self.num_heads, self.head_dim).transpose(1,2) # (B, num_heads, L, head_dim)
        value = self.v(input_emb).view(B, L, self.num_heads, self.head_dim).transpose(1,2) # (B, num_heads, L, head_dim)
        
        scale_factor = query.size(-1) ** 0.5
        attn_scores = torch.matmul(query, key.transpose(-1, -2)) / scale_factor # (B, num_heads, L, L)
        
        attn_scores = attn_scores.masked_fill_(~attn_mask, -1e9) if attn_mask is not None else attn_scores 
        attn_weights = F.softmax(attn_scores, dim=-1) # (B, num_heads, L, L)
        attn_weights = self.dropout(attn_weights)
        
        attn = torch.matmul(attn_weights, value) # (B, num_heads, L, head_dim)
        attn = attn.transpose(1, 2).contiguous().view(B, L, D) # (B, L, num_heads, head_dim) -> (B, L, D), concatenate the heads
        
        return attn
        

class EncoderLayer(nn.Module):
    def __init__(self, config):
        super(EncoderLayer, self).__init__()
        
        self.emb_dim = config['emb_dim']
        self.num_heads = config['num_heads']
        self.hidden_dim = config['hidden_dim']
        self.dropout_prob = config['dropout_prob']
        
        self.attention = MultiHeadAttention(config)
        self.feed_forward = nn.Sequential(
            nn.Linear(self.emb_dim, self.hidden_dim),
            nn.Dropout(self.dropout_prob),
            nn.GELU(),
            nn.Linear(self.hidden_dim, self.emb_dim),
            nn.Dropout(self.dropout_prob)
        )
        self.layer_norm = nn.LayerNorm(self.emb_dim)
        
    def forward(self, input_emb: torch.Tensor, attn_mask: torch.Tensor = None):
        x = input_emb
        attn = self.attention(x, attn_mask)
        x = x + attn
        x = self.layer_norm(x)
        res = x
        x = self.feed_forward(x)
        x = x + res
        x = self.layer_norm(x)
        
        return x

class PhenoBERT(nn.Module):
    
    def __init__(self, config, vocab_size, antibiotics):
        super(PhenoBERT, self).__init__()
                
        self.vocab_size = vocab_size
        
        self.emb_dim = config['emb_dim']
        self.vocab_size = vocab_size
        self.max_seq_len = None # Can be set later
        self.num_heads = config['num_heads']
        self.num_layers = config['num_layers']
        self.hidden_dim = config['hidden_dim']
        self.dropout_prob = config['dropout_prob']
        
        self.embedding = JointEmbedding(config, vocab_size)
        self.encoder = nn.ModuleList([EncoderLayer(config) for _ in range(self.num_layers)])
        
        self.token_prediction_layer = nn.Linear(self.emb_dim, self.vocab_size) # MLM task
        self.softmax = nn.LogSoftmax(dim=-1) # log softmax improves numerical stability, we use NLLLoss later
        if antibiotics:
            self.classification_layer = [AbPredictor(self.emb_dim).to(device) for _ in range(len(antibiotics))] # classification task
        
    def forward(self, input_tensor, attn_mask): 
        embedded = self.embedding(input_tensor)
        for layer in self.encoder:
            embedded = layer(embedded, attn_mask)
        encoded = embedded # ouput of the BERT Encoder
        
        if self.classification_layer: # ASSUMES MLM AND CLASSIFICATION ARE NOT DONE AT THE SAME TIME
            cls_token = encoded[:, 0, :] # (batch_size, emb_dim)
            predictions = torch.cat([net(cls_token) for net in self.classification_layer], dim=1) # (batch_size, num_ab)
            return predictions
        else:
            token_prediction = self.token_prediction_layer(encoded) # (batch_size, seq_len, vocab_size)
            return self.softmax(token_prediction)

################################################################################################################

class AbPredictor(nn.Module): # predicts resistance or susceptibility for an antibiotic
    def __init__(self, hidden_dim):
        super(AbPredictor, self).__init__()
        
        self.hidden_dim = hidden_dim
        
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(self.hidden_dim),
            nn.Linear(self.hidden_dim, 1), # binary classification
        )
        # self.classifiers = nn.ModuleList(
        #     [nn.Sequential(
        #         nn.Linear(self.hidden_dim, self.hidden_dim),
        #         nn.ReLU(),
        #         nn.LayerNorm(self.hidden_dim),
        #         nn.Linear(self.hidden_dim, 2), # one value for R and one for S
        #     ) for _ in range(self.num_ab)]
        # )
    
    def forward(self, X):
        # X is the CLS token of the BERT model
        return self.classifier(X)

## Dataset

In [4]:
from copy import deepcopy
from torch.utils.data import Dataset
import numpy as np

class PhenotypeDataset(Dataset):      
    # df column names
    INDICES_MASKED = 'indices_masked' # input to BERT, token indices of the masked sequence
    TARGET_RESISTANCES = 'target_resistances' # resistance of the target antibiotics, what we want to predict
    TOKEN_MASK = 'token_mask' # True if token is masked, False otherwise
    AB_MASK = 'ab_mask' # True if antibiotic is masked, False otherwise
    # # if original text is included
    ORIGINAL_SEQUENCE = 'original_sequence'
    MASKED_SEQUENCE = 'masked_sequence'
    
    def __init__(self,
                 ds: pd.DataFrame,
                 vocab,
                 antibiotics: list,
                 specials: dict,
                 max_seq_len: int,
                 base_dir: Path,
                 include_sequences: bool = False,
                 random_state: int = 42,
                 ):
        
        os.chdir(base_dir)
        self.random_state = random_state
        np.random.seed(self.random_state)
        
        self.ds = ds.reset_index(drop=True) 
        self.num_samples = self.ds.shape[0]
        self.vocab = vocab
        self.antibiotics = antibiotics
        self.num_ab = len(self.antibiotics)
        self.ab_to_idx = {ab: i for i, ab in enumerate(self.antibiotics)}
        self.enc_res = {'S': 0, 'R': 1}
        self.vocab_size = len(self.vocab)
        self.CLS = specials['CLS']
        self.PAD = specials['PAD']
        self.MASK = specials['MASK']
        self.UNK = specials['UNK']
        self.special_tokens = specials.values()
        self.max_seq_len = max_seq_len
           
        self.include_sequences = include_sequences
        if self.include_sequences:
            self.columns = [self.INDICES_MASKED, self.TARGET_RESISTANCES, self.TOKEN_MASK, self.AB_MASK,
                            self.ORIGINAL_SEQUENCE, self.MASKED_SEQUENCE]
        else: 
            self.columns = [self.INDICES_MASKED, self.TARGET_RESISTANCES, self.TOKEN_MASK, self.AB_MASK]        
        
    def __len__(self):
        return self.num_samples
    
    
    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        
        input = torch.tensor(item[self.INDICES_MASKED], dtype=torch.long, device=device)
        target_res = torch.tensor(item[self.TARGET_RESISTANCES], dtype=torch.float32, device=device)
        token_mask = torch.tensor(item[self.TOKEN_MASK], dtype=torch.bool, device=device)
        ab_mask = torch.tensor(item[self.AB_MASK], dtype=torch.bool, device=device)
        attn_mask = (input != self.vocab[self.PAD]).unsqueeze(0).unsqueeze(1) # one dim for batch, one for heads
        
        if self.include_sequences:
            original_sequence = item[self.ORIGINAL_SEQUENCE]
            masked_sequence = item[self.MASKED_SEQUENCE]
            return input, target_res, token_mask, ab_mask, attn_mask, original_sequence, masked_sequence
        else:
            return input, target_res, token_mask, ab_mask, attn_mask

       
    def prepare_dataset(self, mask_prob: float = 0.15): # will be called at the start of each epoch (dynamic masking)
        sequences, masked_sequences, target_resistances, token_masks, ab_masks = self._construct_masked_sequences(mask_prob)
        
        indices_masked = [self.vocab.lookup_indices(masked_seq) for masked_seq in masked_sequences]
        
        if self.include_sequences:
            rows = zip(indices_masked, target_resistances, token_masks, ab_masks, sequences, masked_sequences)
        else:
            rows = zip(indices_masked, target_resistances, token_masks, ab_masks)
        self.df = pd.DataFrame(rows, columns=self.columns)

    
    def _encode_sequence(self, seq: list):
        dict = {ab: res for ab, res in [token.split('_') for token in seq]}
        indices = [self.ab_to_idx[ab] for ab in dict.keys()]
        resistances = [self.enc_res[res] for res in dict.values()]
        
        return indices, resistances
    
    
    def _construct_masked_sequences(self, mask_prob: float):  
        # RoBERTa: 80% -> [MASK], 10% -> original token, 10% -> random token
        self.mask_prob = mask_prob
        sequences = deepcopy(self.ds['phenotypes'].tolist())
        masked_sequences = list()
        # all_target_indices = list()
        all_target_resistances = list()
        ab_masks = list() # will be applied to the output of the model, i.e. (batch_size, num_ab)
        token_masks = list() # will be applied to the the sequence itself, i.e. (batch_size, seq_len)
        for seq in deepcopy(sequences):
            seq_len = len(seq)
            # target_indices, target_resistances = self._encode_sequence(seq) # if we don't want to indicate masking here, we 
            # all_target_indices.append(target_indices)                       # encode the whole sequence, and use a token mask
            # all_target_resistances.append(target_resistances)
            
            token_mask = [False] * seq_len # indicates which tokens in the sequence are masked, includes all tokens
            ab_mask = [False] * self.num_ab # will indicate which antibiotics are masked, indexed in the order of self.antibiotics
            target_resistances = [-1]*self.num_ab # -1 indicates padding, will indicate the target resistance, same indexing as ab_mask
            tokens_masked = 0
            for i in range(seq_len):
                if np.random.rand() < self.mask_prob: 
                    ab, res = seq[i].split('_')
                    ab_idx = self.ab_to_idx[ab]
                    tokens_masked += 1
                    r = np.random.rand()
                    if r < 0.8: 
                        seq[i] = self.MASK
                    elif r < 0.9:
                        j = np.random.randint(self.vocab_size-self.num_ab*2, self.vocab_size) # select random pheno token
                        seq[i] = self.vocab.lookup_token(j)
                    # else: do nothing, since r > 0.9 and we keep the same token
                    token_mask[i] = True
                    ab_mask[ab_idx] = True # indicate which antibiotic is masked at this position
                    target_resistances[ab_idx] = self.enc_res[res] # the target resistance of the antibiotic
            if tokens_masked == 0: # mask at least one token
                i = np.random.randint(seq_len)
                ab, res = seq[i].split('_')
                ab_idx = self.ab_to_idx[ab]
                r = np.random.rand()
                if r < 0.8: 
                    seq[i] = self.MASK
                elif r < 0.9:
                    j = np.random.randint(self.vocab_size-self.num_ab*2, self.vocab_size) # select random token, excluding specials
                    seq[i] = self.vocab.lookup_token(j)
                # else: do nothing, since r > 0.9 and we keep the same token
                token_mask[i] = True
                ab_mask[ab_idx] = True # indicate which antibiotic is masked at this position
                target_resistances[ab_idx] = self.enc_res[res] # the target resistance of the antibiotic
                
            masked_sequences.append(seq)
            token_masks.append(token_mask)
            ab_masks.append(ab_mask)
            all_target_resistances.append(target_resistances)
        
        for i in range(len(sequences)):
            token_masks[i] = 5*[False] + token_masks[i]
            seq_start = [self.CLS, 
                         str(self.ds['year'].iloc[i]), 
                         self.ds['country'].iloc[i], 
                         self.ds['gender'].iloc[i], 
                         str(int(self.ds['age'].iloc[i]))]
            
            sequences[i][:0] = seq_start
            masked_sequences[i][:0] = seq_start
            # all_target_indices[i][:0] = [-1]*5 
            
            seq_len = len(sequences[i])
            if seq_len < self.max_seq_len:
                sequences[i].extend([self.PAD] * (self.max_seq_len - seq_len))
                masked_sequences[i].extend([self.PAD] * (self.max_seq_len - seq_len))
                token_masks[i].extend([False] * (self.max_seq_len - seq_len))
            # the antibiotic-specific lists should always be of length num_ab
            pheno_len = len(all_target_resistances[i])
            all_target_resistances[i].extend([-1] * (self.num_ab - pheno_len))
            # all_target_indices[i].extend([-1] * (self.num_ab - pheno_len)) # -1 indicates padding
            # ab_mask is defined with correct length
                
        return sequences, masked_sequences, all_target_resistances, token_masks, ab_masks  
    
    
    def reconstruct_sequence(self, seq_from_batch):
        tuple_len = len(seq_from_batch[0])
        sequences = list()
        for j in range(tuple_len):
            sequence = list()
            for i in range(self.max_seq_len):
                sequence.append(seq_from_batch[i][j])
            sequences.append(sequence)
        return sequences

## Trainer

In [5]:

import os
import torch
import torch.nn as nn
import time 
import matplotlib.pyplot as plt
import wandb

# from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from pathlib import Path

from datetime import datetime

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

############################################ Trainer for MLM task ############################################

class BertPhenoTrainer(nn.Module):
    
    def __init__(self,
                 config: dict,
                 model: PhenoBERT,
                 antibiotics: list, # list of antibiotics in the dataset
                 train_set,
                 val_set,
                 test_set, # can be None
                 results_dir: Path = None,
                 ):
        super(BertPhenoTrainer, self).__init__()
        
        self.random_state = config["random_state"]
        torch.manual_seed(self.random_state)
        torch.cuda.manual_seed(self.random_state)
        
        self.model = model
        self.antibiotics = antibiotics
        self.num_ab = len(self.antibiotics) 
        self.train_set, self.train_size = train_set, len(train_set)
        self.train_size = len(self.train_set)      
        self.model.max_seq_len = self.train_set.max_seq_len 
        self.val_set, self.val_size = val_set, len(val_set)
        if test_set:
            self.test_set, self.test_size = test_set, len(test_set)
        self.split = config["split"]
        self.project_name = config["project_name"]
        self.wandb_name = config["name"] if config["name"] else datetime.now().strftime("%Y%m%d-%H%M%S")
         
        self.batch_size = config["batch_size"]
        self.lr = config["lr"]
        self.weight_decay = config["weight_decay"]
        self.epochs = config["epochs"]
        self.patience = config["early_stopping_patience"]
        self.save_model = config["save_model"] if config["save_model"] else False
        
        self.do_testing = config["do_testing"] if config["do_testing"] else False
        self.num_batches = self.train_size // self.batch_size
        
        self.mask_prob = config["mask_prob"]
        self.criterions = [nn.BCEWithLogitsLoss() for _ in range(self.num_ab)] # the list is so that we can introduce individual weights
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        self.scheduler = None
        # self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.9)
        # self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.98)
                 
        self.current_epoch = 0
        self.report_every = config["report_every"] if config["report_every"] else 100
        self.print_progress_every = config["print_progress_every"] if config["print_progress_every"] else 1000
        self._splitter_size = 70
        self.results_dir = results_dir
        os.makedirs(self.results_dir) if not os.path.exists(self.results_dir) else None
        
        
    def print_model_summary(self):        
        print("Model summary:")
        print("="*self._splitter_size)
        print(f"Embedding dim: {self.model.emb_dim}")
        print(f"Hidden dim: {self.model.hidden_dim}")
        print(f"Number of heads: {self.model.num_heads}")
        print(f"Number of encoder layers: {self.model.num_layers}")
        print(f"Max sequence length: {self.model.max_seq_len}")
        print(f"Vocab size: {len(self.train_set.vocab):,}")
        print(f"Number of parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad):,}")
        print("="*self._splitter_size)
        
    
    def print_trainer_summary(self):
        print("Trainer summary:")
        print("="*self._splitter_size)
        print(f"Device: {device} ({torch.cuda.get_device_name(0)})")
        print(f"Training dataset size: {self.train_size:,}")
        print(f"Number of antibiotics: {self.num_ab}")
        print(f"Antibiotics: {self.antibiotics}")
        print(f"Train-val-test split {self.split[0]:.0%} - {self.split[1]:.0%} - {self.split[2]:.0%}")
        print(f"Will test? {'Yes' if self.do_testing else 'No'}")
        print(f"Mask probability: {self.mask_prob:.0%}")
        print(f"Number of epochs: {self.epochs}")
        print(f"Early stopping patience: {self.patience}")
        print(f"Batch size: {self.batch_size}")
        print(f"Number of batches: {self.num_batches:,}")
        print(f"Dropout probability: {self.model.dropout_prob:.0%}")
        print(f"Learning rate: {self.lr}")
        print(f"Weight decay: {self.weight_decay}")
        print("="*self._splitter_size)
        
        
    def __call__(self):      
        self.wandb_run = self._init_wandb()
        self.val_set.prepare_dataset(mask_prob=self.mask_prob) 
        self.val_loader = DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False)
        if self.do_testing:
            self.test_set.prepare_dataset(mask_prob=self.mask_prob) 
            self.test_loader = DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False)
        
        start_time = time.time()
        self.best_val_loss = float('inf')
        
        self.losses = []
        self.val_losses = []
        self.val_accuracies = []
        self.train_accuracies = []
        self.val_seq_accuracies = []
        self.val_ab_stats = []
        self.val_iso_stats = []
        for self.current_epoch in range(self.current_epoch, self.epochs):
            self.model.train()
            # Dynamic masking: New mask for training set each epoch
            self.train_set.prepare_dataset(mask_prob=self.mask_prob)
            self.train_loader = DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
            epoch_start_time = time.time()
            loss = self.train(self.current_epoch) # returns loss, averaged over batches
            self.losses.append(loss) 
            print(f"Epoch completed in {(time.time() - epoch_start_time)/60:.1f} min")
            print(f"Elapsed time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time))}")
            # print("Evaluating on training set...")
            # _, train_acc = self.evaluate(self.train_loader)
            # self.train_accuracies.append(train_acc)
            print("Evaluating on validation set...")
            results = self.evaluate(self.val_loader, self.val_set)
            self._update_val_lists(results)
            self._report_epoch_results()
            early_stop = self.early_stopping()
            if early_stop:
                print(f"Early stopping at epoch {self.current_epoch+1} with validation loss {self.val_losses[-1]:.3f}")
                s = f"Best validation loss {self.best_val_loss:.3f}"
                s += f" | Validation accuracy {self.val_accuracies[self.best_epoch]:.2%}"
                s += f" | Validation sequence accuracy {self.val_seq_accuracies[self.best_epoch]:.2%}"
                s += f" at epoch {self.best_epoch+1}"
                print(s)
                self.wandb_run.log({"Losses/final_val_loss": self.best_val_loss, 
                           "Accuracies/final_val_acc":self.val_accuracies[self.best_epoch],
                           "Accuracies/final_val_seq_acc": self.val_seq_accuracies[self.best_epoch],
                           "final_epoch": self.best_epoch+1})
                print("="*self._splitter_size)
                self.model.load_state_dict(self.best_model_state) 
                self.current_epoch = self.best_epoch
                break
            self.scheduler.step() if self.scheduler else None
        if not early_stop:    
            self.wandb_run.log({"Losses/final_val_loss": self.val_losses[-1], 
                    "Accuracies/final_val_acc":self.val_accuracies[-1],
                    "Accuracies/final_val_seq_acc": self.val_seq_accuracies[-1],
                    "final_epoch": self.current_epoch+1})
        self.save_model(self.results_dir / "model_state.pt") if self.save_model else None
        train_time = (time.time() - start_time)/60
        self.wandb_run.log({"Training time (min)": train_time})
        disp_time = f"{train_time//60:.0f}h {train_time % 60:.1f} min" if train_time > 60 else f"{train_time:.1f} min"
        print(f"Training completed in {disp_time}")
        if not early_stop:
            s = f"Final validation loss {self.val_losses[-1]:.3f}"
            s += f" | Final validation accuracy {self.val_accuracies[-1]:.2%}"
            s += f" | Final validation sequence accuracy {self.val_seq_accuracies[-1]:.2%}"
            print(s)
        
        if self.do_testing:
            print("Evaluating on test set...")
            results = self.evaluate(self.test_loader, self.test_set)
            self.test_loss = results["loss"]
            self.test_acc = results["acc"]
            test_seq_acc = results["seq_acc"]
            self.test_ab_stats = results["ab_stats"]
            self.test_iso_stats = results["iso_stats"]
            self.wandb_run.log({"Losses/test_loss": self.test_loss, 
                                "Accuracies/test_acc": self.test_acc,
                                "Accuracies/test_seq_acc": test_seq_acc})
        self._visualize_losses(savepath=self.results_dir / "losses.png")
        self._visualize_accuracy(savepath=self.results_dir / "accuracy.png")
        
        return self.val_ab_stats, self.val_iso_stats, self.current_epoch
        
    def train(self, epoch: int):
        print(f"Epoch {epoch+1}/{self.epochs}")
        time_ref = time.time()
        
        epoch_loss = 0
        reporting_loss = 0
        printing_loss = 0
        for i, batch in enumerate(self.train_loader):
            batch_index = i + 1
            # input, target_res, token_mask, ab_mask, attn_mask, original_seq, masked_seq = batch
            # original_seq = self.train_set.reconstruct_sequence(original_seq)
            # masked_seq = self.train_set.reconstruct_sequence(masked_seq)
            input, target_res, token_mask, ab_mask, attn_mask = batch
            
            # print("original sequence:", original_seq)
            # print("masked sequence:", masked_seq)
            
            # print("input shape:", input.shape)
            # print("input:", input)
            # print("target_res shape:", target_res.shape)
            # print("target_res:", target_res)
            # print("attn_mask shape:", attn_mask.shape)
            # print("attn_mask:", attn_mask)
            # print("token_mask shape:", token_mask.shape)
            # print("token_mask:", token_mask)
            # print("ab_mask shape:", ab_mask.shape)
            # print("ab_mask:", ab_mask)
            
            self.optimizer.zero_grad() # zero out gradients
            pred_logits = self.model(input, attn_mask) # get predictions for all antibiotics
            
            # select only the antibtiotics present in the samples - to calculate loss
            losses = list()
            for j in range(self.num_ab): # for each antibiotic
                mask = ab_mask[:, j] # (batch_size,), indicates which samples contain the antibiotic masked
                if mask.any(): # if there is at least one masked sample for this antibiotic
                    # isolate the predictions and targets for the antibiotic
                    ab_pred_logits = pred_logits[mask, j] # (num_masked_samples,)
                    ab_targets = target_res[mask, j] # (num_masked_samples,)
                    ab_loss = self.criterions[j](ab_pred_logits, ab_targets)
                    losses.append(ab_loss)
            loss = sum(losses) / len(losses) # average loss over antibiotics
            epoch_loss += loss.item() 
            reporting_loss += loss.item()
            printing_loss += loss.item()
            
            loss.backward() 
            self.optimizer.step() 
            if batch_index % self.report_every == 0:
                self._report_loss_results(batch_index, reporting_loss)
                reporting_loss = 0 
                
            if batch_index % self.print_progress_every == 0:
                time_elapsed = time.gmtime(time.time() - time_ref) 
                self._print_loss_summary(time_elapsed, batch_index, printing_loss) 
                printing_loss = 0           
        avg_epoch_loss = epoch_loss / self.num_batches
        return avg_epoch_loss 
    
    def early_stopping(self):
        if self.val_losses[-1] < self.best_val_loss:
            self.best_val_loss = self.val_losses[-1]
            self.best_epoch = self.current_epoch
            self.best_model_state = self.model.state_dict()
            self.early_stopping_counter = 0
            return False
        else:
            self.early_stopping_counter += 1
            return True if self.early_stopping_counter >= self.patience else False
        
            
    def evaluate(self, loader: DataLoader, ds_obj, print_mode: bool = True):
        self.model.eval()
        # prepare evaluation statistics dataframes
        eval_stats_ab, eval_stats_iso = self._init_eval_stats(ds_obj)
        with torch.no_grad():
            loss = 0
            num_preds = np.zeros((self.num_ab, 2)) # tracks the number of predictions for each antibiotic & resistance
            num_correct = np.zeros_like(num_preds) # tracks the number of correct predictions for each antibiotic & resistance
            for batch_idx, batch in enumerate(loader):
                input, target_res, token_mask, ab_mask, attn_mask = batch
                pred_logits = self.model(input, attn_mask) # get predictions for all antibiotics
                pred_res = torch.where(pred_logits > 0, torch.ones_like(pred_logits), torch.zeros_like(pred_logits))
                
                eval_stats_iso = self._update_iso_stats(batch_idx, pred_res, target_res, ab_mask, eval_stats_iso)
                batch_loss = list()
                for j in range(self.num_ab): # for each antibiotic
                    mask = ab_mask[:, j] # (batch_size,), indicates which samples contain the antibiotic masked
                    if mask.any(): # if there is at least one masked sample for this antibiotic
                        ab_pred_logits = pred_logits[mask, j] # (num_masked_samples,)
                        ab_targets = target_res[mask, j] # (num_masked_samples,)
                        num_R = ab_targets.sum().item()
                        num_S = ab_targets.shape[0] - num_R
                        num_preds[j, :] += [num_S, num_R]
                        
                        ab_loss = self.criterions[j](ab_pred_logits, ab_targets)
                        batch_loss.append(ab_loss.item())
                        
                        ab_pred_res = pred_res[mask, j] # (num_masked_samples,)
                        num_correct[j, :] += self._get_num_correct(ab_pred_res, ab_targets)    
                loss += sum(batch_loss) / len(batch_loss) # average loss over antibiotics
            loss /= len(loader) # average loss over batches
            acc = num_correct.sum() / num_preds.sum() # accuracy over all predictions
            seq_acc = eval_stats_iso['correct_all'].sum() / eval_stats_iso.shape[0] # accuracy over all sequences
            
            eval_stats_ab = self._update_ab_eval_stats(eval_stats_ab, num_preds, num_correct)
            
            eval_stats_iso['accuracy_S'] = eval_stats_iso.apply(
                lambda row: row['correct_S']/row['num_masked_S'] if row['num_masked_S'] > 0 else np.nan, axis=1)
            eval_stats_iso['accuracy_R'] = eval_stats_iso.apply(
                lambda row: row['correct_R']/row['num_masked_R'] if row['num_masked_R'] > 0 else np.nan, axis=1)
        if print_mode:
            print(f"Loss: {loss:.3f} | Accuracy: {acc:.2%} | Sequence accuracy: {seq_acc:.2%}")
            print("="*self._splitter_size)
        
        results = {"loss": loss, 
                   "acc": acc, 
                   "seq_acc": seq_acc,
                   "ab_stats": eval_stats_ab,
                   "iso_stats": eval_stats_iso}
        return results
    
    
    def _update_val_lists(self, results: dict):
        self.val_losses.append(results["loss"])
        self.val_accuracies.append(results["acc"])
        self.val_seq_accuracies.append(results["seq_acc"])
        self.val_ab_stats.append(results["ab_stats"])
        self.val_iso_stats.append(results["iso_stats"])
    
    
    def _init_eval_stats(self, ds_obj):
        eval_stats_ab = pd.DataFrame(columns=['ab', 'res', 'num_pred', 'num_correct'])
        tmp = []
        [tmp.extend([ab, ab]) for ab in self.antibiotics]
        eval_stats_ab['ab'] = tmp
        eval_stats_ab['res'] = ['S', 'R']*self.num_ab
        eval_stats_ab['num_pred'], eval_stats_ab['num_correct'] = 0, 0
        eval_stats_iso = ds_obj.ds.copy()
        eval_stats_iso['num_masked'] = 0
        eval_stats_iso['num_masked_S'] = 0
        eval_stats_iso['num_masked_R'] = 0
        eval_stats_iso['correct_S'] = 0
        eval_stats_iso['correct_R'] = 0
        eval_stats_iso['correct_all'] = False
        # eval_stats_iso['correct_mask'] = [-1]*eval_stats_iso.shape[0] # indicates which antibiotics are -1: not masked, 0: incorrect, 1:correct
        eval_stats_iso.drop(columns=['phenotypes'], inplace=True)
        
        return eval_stats_ab, eval_stats_iso
    
    
    def _update_ab_eval_stats(self, eval_stats_ab: pd.DataFrame, num_preds: np.ndarray, num_correct: np.ndarray):
        for j in range(self.num_ab): 
            eval_stats_ab.loc[2*j, 'num_pred'] = num_preds[j, 0]
            eval_stats_ab.loc[2*j+1, 'num_pred'] = num_preds[j, 1]
            eval_stats_ab.loc[2*j, 'num_correct'] = num_correct[j, 0]
            eval_stats_ab.loc[2*j+1, 'num_correct'] = num_correct[j, 1]
        eval_stats_ab['accuracy'] = eval_stats_ab['num_correct'] / eval_stats_ab['num_pred']
        return eval_stats_ab
    
    def _get_num_correct(self, pred_res: torch.Tensor, target_res: torch.Tensor):
        eq = torch.eq(pred_res, target_res)
        num_correct_S = eq[target_res == 0].sum().item()
        num_correct_R = eq[target_res == 1].sum().item()
        return [num_correct_S, num_correct_R]
    
    
    def _update_iso_stats(self, batch_index: int, pred_res: torch.Tensor, target_res: torch.Tensor, ab_mask: torch.Tensor,
                          eval_stats_iso: pd.DataFrame):
        for i in range(pred_res.shape[0]): # for each isolate
            global_idx = batch_index * self.batch_size + i # index of the isolate in the dataframe
            iso_ab_mask = ab_mask[i]
            
            # counts
            eval_stats_iso.loc[global_idx, 'num_masked'] = int(iso_ab_mask.sum().item())
            iso_target_res = target_res[i][iso_ab_mask]
            num_R = iso_target_res.sum().item()
            num_S = iso_target_res.shape[0] - num_R
            eval_stats_iso.loc[global_idx, 'num_masked_S'] = num_S
            eval_stats_iso.loc[global_idx, 'num_masked_R'] = num_R
            
            # correct predictions
            iso_pred_res = pred_res[i][iso_ab_mask]
            eq = torch.eq(iso_pred_res, iso_target_res)
            num_R_correct = eq[iso_target_res == 1].sum().item()
            num_S_correct = eq[iso_target_res == 0].sum().item()
            eval_stats_iso.loc[global_idx, 'correct_S'] = num_S_correct
            eval_stats_iso.loc[global_idx, 'correct_R'] = num_R_correct
            
            eval_stats_iso.loc[global_idx, 'correct_all'] = bool(eq.all().item()) # 1 if all antibiotics are predicted correctly, 0 otherwise       
        return eval_stats_iso
     
     
    def _init_wandb(self):
        self.wandb_run = wandb.init(
            project=self.project_name, # name of the project
            name=self.wandb_name, # name of the run
            
            config={
                # "dataset": "NCBI",
                "epochs": self.epochs,
                "batch_size": self.batch_size,
                # "model": "BERT",
                "hidden_dim": self.model.hidden_dim,
                "num_layers": self.model.num_layers,
                "num_heads": self.model.num_heads,
                "emb_dim": self.model.emb_dim,
                "lr": self.lr,
                "weight_decay": self.weight_decay,
                "mask_prob": self.mask_prob,
                "max_seq_len": self.model.max_seq_len,
                "vocab_size": len(self.train_set.vocab),
                "num_parameters": sum(p.numel() for p in self.model.parameters() if p.requires_grad),
                "train_size": self.train_size,
                "random_state": self.random_state,
                # "val_size": self.val_size,
                # "test_size": self.test_size,
                # "early_stopping_patience": self.patience,
                # "dropout_prob": self.model.dropout_prob,
            }
        )
        self.wandb_run.watch(self.model) # watch the model for gradients and parameters
        self.wandb_run.define_metric("epoch", hidden=True)
        self.wandb_run.define_metric("batch", hidden=True)
        
        self.wandb_run.define_metric("Losses/live_loss", step_metric="batch")
        self.wandb_run.define_metric("Losses/train_loss", summary="min", step_metric="epoch")
        self.wandb_run.define_metric("Losses/val_loss", summary="min", step_metric="epoch")
        self.wandb_run.define_metric("Accuracies/val_acc", summary="max", step_metric="epoch")
        self.wandb_run.define_metric("Accuracies/val_seq_acc", summary="max", step_metric="epoch")
        
        if self.do_testing:
            self.wandb_run.define_metric("Losses/test_loss")
            self.wandb_run.define_metric("Accuracies/test_acc")
            self.wandb_run.define_metric("Accuracies/test_seq_acc")
        self.wandb_run.define_metric("Losses/final_val_loss")
        self.wandb_run.define_metric("Accuracies/final_val_acc")
        self.wandb_run.define_metric("Accuracies/final_val_seq_acc")
        self.wandb_run.define_metric("final_epoch")

        return self.wandb_run
     
    def _report_epoch_results(self):
        wandb_dict = {
            "epoch": self.current_epoch+1,
            "Losses/train_loss": self.losses[-1],
            "Losses/val_loss": self.val_losses[-1],
            "Accuracies/val_acc": self.val_accuracies[-1],
            "Accuracies/val_seq_acc": self.val_seq_accuracies[-1]
        }
        self.wandb_run.log(wandb_dict)
    
        
    def _report_loss_results(self, batch_index, tot_loss):
        avg_loss = tot_loss / self.report_every
        
        global_step = self.current_epoch * self.num_batches + batch_index # global step, total #batches seen
        self.wandb_run.log({"batch": global_step, "Losses/live_loss": avg_loss})
        # self.writer.add_scalar("Loss", avg_loss, global_step=global_step)
    
        
    def _print_loss_summary(self, time_elapsed, batch_index, tot_loss):
        progress = batch_index / self.num_batches
        mlm_loss = tot_loss / self.print_progress_every
          
        s = f"{time.strftime('%H:%M:%S', time_elapsed)}" 
        s += f" | Epoch: {self.current_epoch+1}/{self.epochs} | {batch_index}/{self.num_batches} ({progress:.2%}) | "\
                f"Loss: {mlm_loss:.3f}"
        print(s)
    
    
    def _visualize_losses(self, savepath: Path = None):
        fig, ax = plt.subplots()
        ax.plot(range(len(self.losses)), self.losses, '-o', label='Training')
        ax.plot(range(len(self.val_losses)), self.val_losses, '-o', label='Validation')
        ax.axhline(y=self.test_loss, color='r', linestyle='--', label='Test') if self.do_testing else None
        ax.set_title('MLM losses')
        ax.set_xlabel('Epoch')
        ax.set_xticks(range(len(self.losses))) if len(self.losses) < 10 else ax.set_xticks(range(0, len(self.losses), 5))
        ax.set_ylabel('Loss')
        ax.legend()
        plt.savefig(savepath, dpi=300) if savepath else None
        # self.wandb_run.log({"Losses/losses": wandb.log(ax)})
        self.wandb_run.log({"Losses/losses": wandb.Image(ax)})
        plt.close()
        
    
    def _visualize_accuracy(self, savepath: Path = None):
        fig, ax = plt.subplots()
        ax.plot(range(len(self.val_accuracies)), self.val_accuracies, '-o', label='Validation')
        ax.axhline(y=self.test_acc, color='r', linestyle='--', label='Test') if self.do_testing else None
        ax.set_title('MLM accuracy')
        ax.set_xlabel('Epoch')
        ax.set_xticks(range(len(self.val_accuracies))) if len(self.val_accuracies) < 10 else ax.set_xticks(range(0, len(self.val_accuracies), 5))
        ax.set_ylabel('Accuracy')
        ax.legend()
        plt.savefig(savepath, dpi=300) if savepath else None
        # self.wandb_run.log({"Accuracies/accuracy": wandb.log(ax)})
        self.wandb_run.log({"Accuracies/accuracy": wandb.Image(ax)})
        plt.close() 
    
    
    def save_model(self, savepath: Path):
        torch.save(self.model.state_dict(), savepath)
        print(f"Model saved to {savepath}")
        print("="*self._splitter_size)
        
        
    def load_model(self, savepath: Path):
        print("="*self._splitter_size)
        print(f"Loading model from {savepath}")
        self.model.load_state_dict(torch.load(savepath))
        print("Model loaded")
        print("="*self._splitter_size)

## Testing

In [6]:
print("Loading dataset...")
ds = pd.read_pickle(config['data']['load_path'])
# ds = ds.iloc[:20000]
num_samples = ds.shape[0]
print(f"Number of samples: {num_samples:,}")

specials = config['specials']
print("Constructing vocabulary...")
savepath_vocab = BASE_DIR / "data" / "pheno_vocab.pt" if config['save_vocab'] else None
vocab, antibiotics = construct_pheno_vocab(ds,
                                           specials,
                                           savepath_vocab=savepath_vocab, 
                                           separate_phenotypes=config['separate_phenotypes'])
print(f"Found {len(antibiotics)} antibiotics: {antibiotics}")
vocab_size = len(vocab)

max_phenotypes_len = ds['num_phenotypes'].max()    
if config['max_seq_len'] == 'auto':
    if config['separate_phenotypes']:
        max_seq_len = 2*max_phenotypes_len + 4 + 1 # +4 for year, country, age & gender, +1 for CLS token
    else:
        max_seq_len = max_phenotypes_len + 4 + 1
else:
    max_seq_len = config['max_seq_len']

train_indices, val_indices, test_indices = get_split_indices(num_samples, config['split'], 
                                                                 random_state=config['random_state'])
train_set = PhenotypeDataset(ds.iloc[train_indices], vocab, antibiotics, specials, max_seq_len, base_dir=BASE_DIR,
                            #  include_sequences=True
                             )
val_set = PhenotypeDataset(ds.iloc[val_indices], vocab, antibiotics, specials, max_seq_len, base_dir=BASE_DIR,
                        #    include_sequences=True
                           )
test_set = PhenotypeDataset(ds.iloc[test_indices], vocab, antibiotics, specials, max_seq_len, base_dir=BASE_DIR,
                            # include_sequences=True
                            )

os.environ['WANDB_MODE'] = config['wandb_mode']
           
print("Loading model...")
config['name'] = "test_new_code"
config['batch_size'] = 64
config['num_heads'] = 4
config['num_layers'] = 4
config['epochs'] = 10
config['print_progress_every'] = 1500
bert = PhenoBERT(config, vocab_size, antibiotics).to(device)
trainer = BertPhenoTrainer(
    config=config,
    model=bert,
    train_set=train_set,
    val_set=val_set,
    test_set=test_set,
    results_dir=RESULTS_DIR,
    antibiotics=antibiotics,
)

trainer.print_model_summary()
trainer.print_trainer_summary()
ab_stats, iso_stats, best_epoch = trainer()
print("Done!")

Loading dataset...
Number of samples: 1,439,018
Constructing vocabulary...
Found 20 antibiotics: ['AMP', 'CTX', 'GEN', 'TOB', 'CIP', 'CAZ', 'CRO', 'OFX', 'AMK', 'AMX', 'LVX', 'TZP', 'AMC', 'FEP', 'COL', 'MFX', 'NOR', 'NET', 'PIP', 'NAL']
Loading model...
Model summary:
Embedding dim: 256
Hidden dim: 256
Number of heads: 4
Number of encoder layers: 4
Max sequence length: 22
Vocab size: 218
Number of parameters: 1,486,042
Trainer summary:
Device: cuda (NVIDIA GeForce RTX 3080)
Training dataset size: 1,151,214
Number of antibiotics: 20
Antibiotics: ['AMP', 'CTX', 'GEN', 'TOB', 'CIP', 'CAZ', 'CRO', 'OFX', 'AMK', 'AMX', 'LVX', 'TZP', 'AMC', 'FEP', 'COL', 'MFX', 'NOR', 'NET', 'PIP', 'NAL']
Train-val-test split 80% - 10% - 10%
Will test? No
Mask probability: 25%
Number of epochs: 10
Early stopping patience: 3
Batch size: 64
Number of batches: 17,987
Dropout probability: 10%
Learning rate: 5e-05
Weight decay: 0.01


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Epoch 1/10
00:00:56 | Epoch: 1/10 | 1500/17987 (8.34%) | Loss: 0.306
00:01:54 | Epoch: 1/10 | 3000/17987 (16.68%) | Loss: 0.247
00:02:48 | Epoch: 1/10 | 4500/17987 (25.02%) | Loss: 0.235
00:03:42 | Epoch: 1/10 | 6000/17987 (33.36%) | Loss: 0.225
00:04:39 | Epoch: 1/10 | 7500/17987 (41.70%) | Loss: 0.227
00:05:35 | Epoch: 1/10 | 9000/17987 (50.04%) | Loss: 0.219
00:06:31 | Epoch: 1/10 | 10500/17987 (58.38%) | Loss: 0.213
00:07:27 | Epoch: 1/10 | 12000/17987 (66.71%) | Loss: 0.217
00:08:22 | Epoch: 1/10 | 13500/17987 (75.05%) | Loss: 0.214
00:09:18 | Epoch: 1/10 | 15000/17987 (83.39%) | Loss: 0.214
00:10:15 | Epoch: 1/10 | 16500/17987 (91.73%) | Loss: 0.212
Epoch completed in 11.2 min
Elapsed time: 00:11:50
Evaluating on validation set...
Loss: 0.208 | Accuracy: 91.52% | Sequence accuracy: 85.51%
Epoch 2/10
00:00:53 | Epoch: 2/10 | 1500/17987 (8.34%) | Loss: 0.207
00:01:48 | Epoch: 2/10 | 3000/17987 (16.68%) | Loss: 0.210
00:02:42 | Epoch: 2/10 | 4500/17987 (25.02%) | Loss: 0.209
00:03:3

### Work with eval_stats

Antibiotic stats

In [7]:
print("Antibiotic statistics at first epoch:")
print(ab_stats[0])
print("="*50)
print(f"Antibiotic statistics at best epoch ({best_epoch}/{config['epochs']}):")
print(ab_stats[best_epoch])

Antibiotic statistics at first epoch:
     ab res  num_pred  num_correct  accuracy
0   AMP   S     11605        10154  0.874968
1   AMP   R     14501         9064  0.625060
2   CTX   S     27512        27335  0.993566
3   CTX   R      2957         2114  0.714914
4   GEN   S     32221        31762  0.985755
5   GEN   R      2790         1007  0.360932
6   TOB   S     16347        16015  0.979690
7   TOB   R      1499         1019  0.679787
8   CIP   S     30147        29396  0.975089
9   CIP   R      7258         3896  0.536787
10  CAZ   S     31013        30663  0.988714
11  CAZ   R      2770         2196  0.792780
12  CRO   S      9950         9845  0.989447
13  CRO   R       989          818  0.827098
14  OFX   S      2942         2919  0.992182
15  OFX   R       667          508  0.761619
16  AMK   S     19013        19013  1.000000
17  AMK   R       215            0  0.000000
18  AMX   S      4229         3962  0.936865
19  AMX   R      5009         2863  0.571571
20  LVX   S      

Isolate stats

In [8]:
print(f"Isolate statistics at best epoch ({best_epoch}/{config['epochs']}):")
iso_stats[best_epoch].head(n=20)

Isolate statistics at best epoch (9/10):


Unnamed: 0,year,country,gender,age,num_phenotypes,num_R,num_S,num_masked,num_masked_S,num_masked_R,correct_S,correct_R,correct_all,accuracy_S,accuracy_R
0,2019,IE,M,74.0,7,5,2,3,1,2,1,2,True,1.0,1.0
1,2018,PT,M,87.0,10,3,7,5,3,2,3,1,False,1.0,0.5
2,2013,AT,F,3.0,6,1,5,1,1,0,1,0,True,1.0,
3,2020,NO,M,72.0,9,0,9,3,3,0,3,0,True,1.0,
4,2004,NL,F,65.0,4,0,4,1,1,0,1,0,True,1.0,
5,2017,FI,F,64.0,5,0,5,1,1,0,1,0,True,1.0,
6,2009,BE,M,16.0,4,3,1,1,0,1,0,1,True,,1.0
7,2011,BE,M,70.0,7,0,7,1,1,0,0,0,False,0.0,
8,2020,IT,F,69.0,8,2,6,1,1,0,1,0,True,1.0,
9,2018,IT,M,94.0,6,2,4,1,1,0,1,0,True,1.0,
