# Multimodal Model Implementation

In [1]:
import os
import torch
import yaml
import wandb
import argparse
import pandas as pd
import numpy as np
import time
import sys
from pathlib import Path

base_dir = Path(os.path.abspath('')).parent
sys.path.append(str(base_dir))

from datetime import datetime

# user-defined functions
# from construct_vocab import construct_MM_vocab
from data_preprocessing import preprocess_NCBI, preprocess_TESSy
from utils import get_split_indices

# set the base directory
os.chdir(base_dir)
print("base directory:", base_dir)

base directory: c:\Users\jespe\Documents\GitHub_local\ARFusion


In [2]:
with open(base_dir / 'config_MM.yaml') as f:
    config = yaml.safe_load(f)

## Dataset

### Prepare unimodal datasets

In [3]:
data_dict = config['data']
if data_dict['TESSy']['prepare_data']:
    ds_TESSy = preprocess_TESSy(
        path=data_dict['TESSy']['raw_path'],
        pathogens=data_dict['pathogens'],
        save_path=data_dict['TESSy']['save_path'],
        exclude_antibiotics=data_dict['exclude_antibiotics'],
        impute_age=data_dict['TESSy']['impute_age'],
        impute_gender=data_dict['TESSy']['impute_gender']
    )
else:
    ds_TESSy = pd.read_pickle(os.path.join(base_dir, data_dict['TESSy']['load_path']))
num_TESSy = len(ds_TESSy)
print(f"Number of isolates in the TESSy dataset: {num_TESSy:,}")
ds_TESSy.head()

Number of isolates in the TESSy dataset: 1,437,004


Unnamed: 0,year,country,gender,age,phenotypes,num_ab,num_R,num_S
0,2001,AT,F,61.0,"[AMP_S, CTX_S, GEN_S, TOB_S]",4,0,4
1,2001,AT,M,37.0,"[AMP_S, CTX_S, GEN_S, TOB_S]",4,0,4
2,2001,AT,F,79.0,"[AMP_S, CTX_S, GEN_S, TOB_S]",4,0,4
3,2001,AT,F,54.0,"[AMP_S, CIP_S, CTX_S, GEN_S, TOB_S]",5,0,5
4,2001,AT,M,63.0,"[AMP_R, CAZ_S, CIP_R, CRO_S, CTX_S, GEN_S, OFX...",8,3,5


In [4]:
if data_dict['NCBI']['prepare_data']:
    ds_NCBI = preprocess_NCBI(
        path=data_dict['NCBI']['raw_path'],
        save_path=data_dict['NCBI']['save_path'],
        include_phenotype=data_dict['NCBI']['include_phenotype'],
        ab_names_to_abbr=data_dict['antibiotics']['ab_names_to_abbr'],
        exclude_antibiotics=data_dict['exclude_antibiotics'], 
        threshold_year=data_dict['NCBI']['threshold_year'],
        exclude_genotypes=data_dict['NCBI']['exclude_genotypes'],
        exclude_assembly_variants=data_dict['NCBI']['exclude_assembly_variants'],
        exclusion_chars=data_dict['NCBI']['exclusion_chars'],
        gene_count_threshold=data_dict['NCBI']['gene_count_threshold']
    )
else:
    ds_NCBI = pd.read_pickle(os.path.join(base_dir, data_dict['NCBI']['load_path']))
num_NCBI = len(ds_NCBI)
ds_NCBI[ds_NCBI['num_ab'] > 0].head()

Parsing phenotypes...
Parsing genotypes...
Number of isolates before parsing: 341,565
Removing 253 isolates with year < 1970
Removing genotypes with assembly variants: ['=PARTIAL', '=MISTRANSLATION', '=HMM']
Dropping 52 isolates with more than 35 genotypes
Number of isolates after parsing: 339,349
Number of isolates with phenotype info after parsing: 6,439


Unnamed: 0,year,country,genotypes,phenotypes,num_ab,num_genotypes,num_point_mutations
2819,2013,USA,"[glpT_E448K=POINT, parC_E84V=POINT, parE_I529L...","[AMP_R, FEP_R, CAZ_R, CRO_R, CIP_R, GEN_S, LVX...",9.0,18,9
2820,2014,USA,"[dfrA14, glpT_E448K=POINT, parC_E84V=POINT, pa...","[AMP_R, CTX_R, CAZ_R, CRO_R, CIP_R, GEN_S, LVX...",9.0,21,9
3210,2012,USA,"[glpT_E448K=POINT, aph(3'')-Ib, blaTEM-1, pmrB...","[AMP_R, CRO_S, CIP_S, GEN_S, NAL_S]",5.0,8,2
3211,2012,USA,"[glpT_E448K=POINT, aph(3'')-Ib, blaTEM-1, pmrB...","[AMP_R, CRO_S, CIP_S, GEN_S, NAL_S]",5.0,8,2
3212,2012,USA,"[floR, aph(3')-Ia=PARTIAL_END_OF_CONTIG, sul1,...","[AMP_S, CRO_S, CIP_S, GEN_S, NAL_S]",5.0,6,0


### Vocabulary construction

In [5]:
from itertools import chain
from collections import Counter
from torchtext.vocab import vocab as Vocab

def construct_MM_vocab(df_geno: pd.DataFrame,
                       df_pheno: pd.DataFrame,
                       antibiotics: list,
                       specials: dict,
                       savepath_vocab: Path = None):
    token_counter = Counter()
    ds_geno = df_geno.copy()
    ds_pheno = df_pheno.copy()
    
    PAD, UNK = specials['PAD'], specials['UNK']
    special_tokens = specials.values()
    
    year_geno = ds_geno[ds_geno['year'] != PAD]['year'].astype('Int16')
    min_year = min(year_geno.min(), ds_pheno['year'].min())
    max_year = max(year_geno.max(), ds_pheno['year'].max())
    year_range = range(min_year, max_year+1)
    token_counter.update([str(y) for y in year_range])
    
    min_age, max_age = ds_pheno['age'].min(), ds_pheno['age'].max()
    age_range = range(int(min_age), int(max_age+1))
    token_counter.update([str(a) for a in age_range])
    
    genders = ds_pheno['gender'].unique().astype(str).tolist()
    token_counter.update(genders)
    
    pheno_countries = ds_pheno['country'].sort_values().unique()
    geno_countries = ds_geno['country'].sort_values().dropna().unique()
    countries = set(pheno_countries).union(set(geno_countries))
    token_counter.update(countries)
    
    token_counter.update(list(chain(*ds_geno['genotypes'])))
    token_counter.update([ab + '_' + res for ab in antibiotics for res in ['R', 'S']])  
    
    vocab = Vocab(token_counter, specials=special_tokens)
    vocab.set_default_index(vocab[UNK])
    if savepath_vocab:
        torch.save(vocab, savepath_vocab)
    
    return vocab

### Multimodal dataset

In [6]:
import numpy as np
import torch
import pandas as pd

from copy import deepcopy
from itertools import chain
from torch.utils.data import Dataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    

class MMPretrainDataset(Dataset):
    # df column names
    INDICES_MASKED = 'indices_masked' # input to BERT, token indices of the masked sequence
    TARGET_RESISTANCES = 'target_resistances' # resistance of the masked antibiotic, what we want to predict
    TARGET_INDICES = 'target_indices' # indices of the target tokens for the genotype masking
    TOKEN_TYPES = 'token_types' # 0 for patient info, 1 for genotype, 2 for phenotype
    # if sequences are included
    MASKED_SEQUENCE = 'masked_sequence'
    # ORIGINAL_SEQUENCE = 'original_sequence'
    
    def __init__(
        self, 
        ds_geno: pd.DataFrame,
        ds_pheno: pd.DataFrame,
        vocab,
        antibiotics: list,
        specials: dict,
        max_seq_len: int,
        mask_prob_geno: float,
        mask_prob_pheno: float = None,
        num_known_ab: int = None,
        include_sequences: bool = False,
        random_state: int = 42
    ):
        self.random_state = random_state
        np.random.seed(random_state)
        
        self.ds_geno = ds_geno.reset_index(drop=True)
        self.num_geno = ds_geno.shape[0]
        self.ds_pheno = ds_pheno.reset_index(drop=True)
        self.num_pheno = ds_pheno.shape[0]
        self.num_samples = self.num_geno + self.num_pheno
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.antibiotics = antibiotics
        self.num_ab = len(antibiotics)
        self.ab_to_idx = {ab: idx for idx, ab in enumerate(antibiotics)}
        self.enc_res = {'S': 0, 'R': 1}
        self.max_seq_len = max_seq_len
        self.CLS, self.PAD, self.MASK, self.UNK = specials.values()
        
        self.mask_prob_geno = mask_prob_geno
        self.mask_prob_pheno = mask_prob_pheno
        self.num_known_ab = num_known_ab
        assert not (self.mask_prob_pheno and self.num_known_ab), "Either mask_prob_pheno or num_known_ab should be given, not both"
        
        self.ds_geno['source'] = 'geno'
        self.ds_pheno['source'] = 'pheno'
        geno_cols = ['year', 'country', 'num_genotypes', 'source']
        pheno_cols = ['year', 'country', 'gender', 'age', 'num_ab', 'source']
        self.combined_ds = pd.concat([self.ds_geno[geno_cols], self.ds_pheno[pheno_cols]], ignore_index=True)
        
        self.include_sequences = include_sequences
        if self.include_sequences:
            self.columns = [self.INDICES_MASKED, self.TARGET_INDICES, self.TARGET_RESISTANCES, self.TOKEN_TYPES,
                            self.MASKED_SEQUENCE]
        else:
            self.columns = [self.INDICES_MASKED, self.TARGET_INDICES, self.TARGET_RESISTANCES, self.TOKEN_TYPES]
        
        
    def __len__(self):
        return self.num_samples
    
    
    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        
        input = torch.tensor(item[self.INDICES_MASKED], dtype=torch.long, device=device)
        target_res = torch.tensor(item[self.TARGET_RESISTANCES], dtype=torch.float32, device=device)
        token_types = torch.tensor(item[self.TOKEN_TYPES], dtype=torch.long, device=device)
        target_indices = torch.tensor(item[self.TARGET_INDICES], dtype=torch.long, device=device) 
        attn_mask = (input != self.vocab[self.PAD]).unsqueeze(0).unsqueeze(1) # one dim for batch, one for heads   
        
        if self.include_sequences:
            # original_sequence = item[self.ORIGINAL_SEQUENCE]
            masked_sequence = item[self.MASKED_SEQUENCE]
            return input, target_indices, target_res, token_types, attn_mask, masked_sequence
        else:
            return input, target_indices, target_res, token_types, attn_mask
    
    
    def prepare_dataset(self):
        geno_sequences = deepcopy(self.ds_geno['genotypes'].tolist())
        pheno_sequences = deepcopy(self.ds_pheno['phenotypes'].tolist())
                
        masked_geno_sequences, geno_target_indices = self._mask_geno_sequences(geno_sequences)
        geno_target_resistances = [[-1]*self.num_ab for _ in range(self.num_geno)] # no ab masking for genotypes
        geno_token_types = [[0]*3 + [1]*(self.max_seq_len - 3) for _ in range(self.num_geno)]
        
        # pheno sequences use ab masking, so token_mask is all False, CE-Loss won't be calculated for these sequences
        masked_pheno_sequences, pheno_target_resistances = self._mask_pheno_sequences(pheno_sequences)
        pheno_token_types = [[0]*5 + [2]*(self.max_seq_len - 5) for _ in range(self.num_pheno)]
        pheno_target_indices = [[-1]*self.max_seq_len for _ in range(self.num_pheno)]

        masked_sequences = masked_geno_sequences + masked_pheno_sequences
        indices_masked = [self.vocab.lookup_indices(masked_seq) for masked_seq in masked_sequences]
        target_indices = geno_target_indices + pheno_target_indices
        token_types = geno_token_types + pheno_token_types
        target_resistances = geno_target_resistances + pheno_target_resistances
        
        if self.include_sequences:
            rows = zip(indices_masked, target_indices, target_resistances, token_types,
                       masked_sequences)
        else:
            rows = zip(indices_masked, target_indices, target_resistances, token_types)
        self.df = pd.DataFrame(rows, columns=self.columns)
        
        
    def _mask_geno_sequences(self, geno_sequences):
        masked_geno_sequences = list()
        target_indices_list = list()
        
        years = self.ds_geno['year'].astype(str).tolist()
        countries = self.ds_geno['country'].tolist()
        seq_starts = [[self.CLS, years[i], countries[i]] for i in range(self.ds_geno.shape[0])]
        for i, geno_seq in enumerate(geno_sequences):
            seq_len = len(geno_seq)
            token_mask = np.random.rand(seq_len) < self.mask_prob_geno   
            target_indices = np.array([-1]*seq_len)
            if not token_mask.any():
                # if no tokens are masked, mask one random token
                idx = np.random.randint(seq_len)
                target_indices[idx] = self.vocab[geno_seq[idx]]
                r = np.random.rand()
                if r < 0.8:
                    geno_seq[idx] = self.MASK
                elif r < 0.9:
                    geno_seq[idx] = self.vocab.lookup_token(np.random.randint(self.vocab_size))
            else:
                indices = token_mask.nonzero()[0]
                target_indices[indices] = self.vocab.lookup_indices([geno_seq[i] for i in indices])
                for i in indices:
                    r = np.random.rand()
                    if r < 0.8:
                        geno_seq[i] = self.MASK
                    elif r < 0.9:
                        geno_seq[i] = self.vocab.lookup_token(np.random.randint(self.vocab_size))
            geno_seq = seq_starts[i] + geno_seq
            target_indices = [-1]*3 + target_indices.tolist() 
            masked_geno_sequences.append(geno_seq)
            target_indices_list.append(target_indices)
            
        masked_geno_sequences = [seq + [self.PAD]*(self.max_seq_len - len(seq)) for seq in masked_geno_sequences]
        target_indices_list = [indices + [-1]*(self.max_seq_len - len(indices)) for indices in target_indices_list]
        return masked_geno_sequences, target_indices_list
    
    
    def _mask_pheno_sequences(self, pheno_sequences):
        masked_pheno_sequences = list()
        target_resistances = list()
        
        years = self.ds_pheno['year'].astype('Int16').astype(str).tolist()
        countries = self.ds_pheno['country'].tolist()
        genders = self.ds_pheno['gender'].tolist()
        ages = self.ds_pheno['age'].astype(int).astype(str).tolist()
        seq_starts = [[self.CLS, years[i], countries[i], genders[i], ages[i]] for i in range(self.num_pheno)]

        if self.mask_prob_pheno:
            for i, pheno_seq in enumerate(pheno_sequences):
                seq_len = len(pheno_seq)
                token_mask = np.random.rand(seq_len) < self.mask_prob_pheno
                target_res = [-1]*self.num_ab
                if not token_mask.any():
                    idx = np.random.randint(seq_len)
                    ab, res = pheno_seq[idx].split('_')
                    target_res[self.ab_to_idx[ab]] = self.enc_res[res]  
                    r = np.random.rand()
                    if r < 0.8:
                        pheno_seq[idx] = self.MASK
                    elif r < 0.9:
                        pheno_seq[idx] = self.vocab.lookup_token(np.random.randint(self.vocab_size)) 
                else:
                    for idx in token_mask.nonzero()[0]:
                        ab, res = pheno_seq[idx].split('_')
                        target_res[self.ab_to_idx[ab]] = self.enc_res[res]
                        r = np.random.rand()
                        if r < 0.8:
                            pheno_seq[idx] = self.MASK
                        elif r < 0.9:
                            pheno_seq[idx] = self.vocab.lookup_token(np.random.randint(self.vocab_size))
                pheno_seq = seq_starts[i] + pheno_seq
                masked_pheno_sequences.append(pheno_seq)
                target_resistances.append(target_res)
        else:
            for i, pheno_seq in enumerate(pheno_sequences):
                seq_len = len(pheno_seq)
                target_res = [-1]*self.num_ab
                indices = np.random.choice(seq_len, self.num_known_ab, replace=False)
                for idx in indices:
                    ab, res = pheno_seq[idx].split('_')
                    target_res[self.ab_to_idx[ab]] = self.enc_res[res]
                    r = np.random.rand()
                    if r < 0.8:
                        pheno_seq[idx] = self.MASK
                    elif r < 0.9:
                        pheno_seq[idx] = self.vocab.lookup_token(np.random.randint(self.vocab_size))
                pheno_seq = seq_starts[i] + pheno_seq
                masked_pheno_sequences.append(pheno_seq)
                target_resistances.append(target_res)
            
        masked_pheno_sequences = [seq + [self.PAD]*(self.max_seq_len - len(seq)) for seq in masked_pheno_sequences]
        return masked_pheno_sequences, target_resistances
            
    def shuffle(self):
        self.df = self.df.sample(frac=1, random_state=self.random_state)
        self.combined_ds = self.combined_ds.loc[self.df.index].reset_index(drop=True) # combined dataset is aligned with df
        self.df.reset_index(drop=True, inplace=True)


## Pre-trainer

In [7]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from multimodal.models import BERT

class MMBertPreTrainer(nn.Module):
    
    def __init__(self,
                 config: dict,
                 model: BERT,
                 antibiotics: list, # list of antibiotics in the dataset
                 train_set,
                 val_set,
                 results_dir: Path = None,
    ):
        super(MMBertPreTrainer, self).__init__()
        
        self.random_state = config["random_state"]
        np.random.seed(self.random_state)
        torch.manual_seed(self.random_state)
        torch.cuda.manual_seed(self.random_state)
        
        self.model = model
        self.project_name = config["project_name"]
        self.wandb_name = config["name"] if config["name"] else datetime.now().strftime("%Y%m%d-%H%M%S")
        self.antibiotics = antibiotics
        self.num_ab = len(self.antibiotics) 
        
        self.train_set, self.train_size = train_set, len(train_set)
        self.val_set, self.val_size = val_set, len(val_set) 
        assert round(self.val_size / (self.train_size + self.val_size), 2) == config["val_share"], "Validation set size does not match intended val_share"
        self.val_share, self.train_share = config["val_share"], 1 - config["val_share"]
        self.batch_size = config["batch_size"]
        self.num_batches = round(self.train_size / self.batch_size)
        self.vocab = self.train_set.vocab
         
        self.lr = config["lr"]
        self.weight_decay = config["weight_decay"]
        self.epochs = config["epochs"]
        self.patience = config["early_stopping_patience"]
        self.save_model_ = config["save_model"] if config["save_model"] else False
        
        self.mask_prob_geno = self.train_set.mask_prob_geno
        self.mask_prob_pheno = self.train_set.mask_prob_pheno
        self.mask_probs = {'geno': self.mask_prob_geno, 'pheno': self.mask_prob_pheno}
        self.num_known_ab = self.train_set.num_known_ab
        
        self.ab_criterions = [nn.BCEWithLogitsLoss().to(device) for _ in range(self.num_ab)] # the list is so that we can introduce individual weights
        self.geno_criterion = nn.CrossEntropyLoss(ignore_index = -1).to(device) # ignores loss where target_indices == -1
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        self.scheduler = None
        # self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.9)
        # self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.98)
                 
        self.current_epoch = 0
        self.report_every = config["report_every"] if config["report_every"] else 1000
        self.print_progress_every = config["print_progress_every"] if config["print_progress_every"] else 1000
        self._splitter_size = 70
        self.results_dir = results_dir
        if self.results_dir:
            self.results_dir.mkdir(parents=True, exist_ok=True) 
        
        
    def print_model_summary(self):        
        print("Model summary:")
        print("="*self._splitter_size)
        print(f"Embedding dim: {self.model.emb_dim}")
        print(f"Feed-forward dim: {self.model.ff_dim}")
        print(f"Hidden dim: {self.model.hidden_dim}")
        print(f"Number of heads: {self.model.num_heads}")
        print(f"Number of encoder layers: {self.model.num_layers}")
        print(f"Dropout probability: {self.model.dropout_prob:.0%}")
        print(f"Max sequence length: {self.model.max_seq_len}")
        print(f"Vocab size: {len(self.vocab):,}")
        print(f"Number of parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad):,}")
        print("="*self._splitter_size)
        
    
    def print_trainer_summary(self):
        print("Trainer summary:")
        print("="*self._splitter_size)
        if device.type == "cuda":
            print(f"Device: {device} ({torch.cuda.get_device_name(0)})")
        else:
            print(f"Device: {device}")        
        print(f"Training dataset size: {self.train_size:,}")
        print(f"Batch size: {self.batch_size}")
        print(f"Number of batches: {self.num_batches:,}")
        print(f"Number of antibiotics: {self.num_ab}")
        print(f"Antibiotics: {self.antibiotics}")
        print(f"CV split: {self.train_share:.0%} train | {self.val_share:.0%} val")
        print(f"Mask probability (genotypes): {self.mask_prob_geno:.0%}")
        if self.mask_prob_pheno:
            print(f"Mask probability (phenotypes): {self.mask_prob_pheno:.0%}")
        if self.num_known_ab:
            print(f"Number of known antibiotics: {self.num_known_ab}")
        print(f"Number of epochs: {self.epochs}")
        print(f"Early stopping patience: {self.patience}")
        print(f"Learning rate: {self.lr}")
        print(f"Weight decay: {self.weight_decay}")
        print("="*self._splitter_size)
        
        
    def __call__(self):      
        print("Initializing training...")
        self.wandb_run = self._init_wandb()
        self.val_set.prepare_dataset()
        self.val_set.shuffle() # to avoid batches of only genotypes or only phenotypes
        self.val_loader = DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False)
        
        start_time = time.time()
        self.best_val_loss = float('inf') 
        self._init_result_lists()
        for self.current_epoch in range(self.current_epoch, self.epochs):
            self.model.train()
            # Dynamic masking: New mask for training set each epoch
            self.train_set.prepare_dataset()
            self.train_loader = DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
            epoch_start_time = time.time()
            train_losses = self.train(self.current_epoch) # returns loss, averaged over batches
            self.losses.append(train_losses['loss']) 
            print(f"Epoch completed in {(time.time() - epoch_start_time)/60:.1f} min")
            print("Loss: {:.4f} | Genotype loss: {:.4f} | Phenotype loss: {:.4f}".format(
                train_losses['loss'], train_losses['geno_loss'], train_losses['pheno_loss']))
            print("Evaluating on validation set...")
            val_results = self.evaluate(self.val_loader, self.val_set)
            print("Val loss: {:.4f} | Genotype loss: {:.4f} | Phenotype loss: {:.4f}".format(
                val_results['loss'], val_results['geno_loss'], val_results['pheno_loss']))
            print("Phenotype accuracy: {:.2%} | Phenotype isolate accuracy: {:.2%}".format(
                val_results['pheno_acc'], val_results['pheno_iso_acc']))
            print("Genotype accuracy: {:.2%} | Genotype isolate accuracy: {:.2%}".format(
                val_results['geno_acc'], val_results['geno_iso_acc']))
            print(f"Elapsed time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time))}")
            self._update_val_lists(val_results)
            self._report_epoch_results()
            early_stop = self.early_stopping()
            if early_stop:
                print(f"Early stopping at epoch {self.current_epoch+1} with validation loss {self.val_losses[-1]:.4f}")
                print(f"Validation stats at best epoch ({self.best_epoch+1}):")
                s1 = f"Loss: {self.val_losses[self.best_epoch]:.4f}" 
                s1 += f"| Phenotype Loss: {self.val_pheno_losses[self.best_epoch]:.4f}"
                s1 += f" | Genotype Loss: {self.val_geno_losses[self.best_epoch]:.4f}"
                print(s1)
                s2 = f" | Phenotype accuracy: {self.val_pheno_accs[self.best_epoch]:.2%}"
                s2 += f" | Phenotype isolate accuracy: {self.val_pheno_iso_accs[self.best_epoch]:.2%}"
                print(s2)
                s3 = f" | Genotype accuracy: {self.val_geno_accs[self.best_epoch]:.2%}"
                s3 += f" | Genotype isolate accuracy: {self.val_geno_iso_accs[self.best_epoch]:.2%}"
                print(s3)
                self.wandb_run.log({
                    "Losses/final_val_loss": self.best_val_loss, 
                    "Losses/final_val_geno_loss": self.val_geno_losses[self.best_epoch],
                    "Losses/final_val_pheno_loss": self.val_pheno_losses[self.best_epoch],
                    "Accuracies/final_val_pheno_acc": self.val_pheno_accs[self.best_epoch],
                    "Accuracies/final_val_pheno_iso_acc": self.val_pheno_iso_accs[self.best_epoch],
                    "Accuracies/final_val_geno_acc": self.val_geno_accs[self.best_epoch],
                    "Accuracies/final_val_geno_iso_acc": self.val_geno_iso_accs[self.best_epoch],
                    "final_epoch": self.best_epoch+1
                })
                print("="*self._splitter_size)
                self.model.load_state_dict(self.best_model_state) 
                self.current_epoch = self.best_epoch
                break
            if self.scheduler:
                self.scheduler.step()
        if not early_stop:    
            self.wandb_run.log({
                    "Losses/final_val_loss": self.best_val_loss, 
                    "Losses/final_val_geno_loss": self.val_geno_losses[-1],
                    "Losses/final_val_pheno_loss": self.val_pheno_losses[-1],
                    "Accuracies/final_val_pheno_acc": self.val_pheno_accs[-1],
                    "Accuracies/final_val_pheno_iso_acc": self.val_pheno_iso_accs[-1],
                    "Accuracies/final_val_geno_acc": self.val_geno_accs[-1],
                    "Accuracies/final_val_geno_iso_acc": self.val_geno_iso_accs[-1],
                    "final_epoch": self.current_epoch+1
                })
        self.model.is_pretrained = True
        if self.save_model_:
            self.save_model() 
        train_time = (time.time() - start_time)/60
        self.wandb_run.log({"Training time (min)": train_time})
        disp_time = f"{train_time//60:.0f}h {train_time % 60:.1f} min" if train_time > 60 else f"{train_time:.1f} min"
        print(f"Training completed in {disp_time}")
        print("="*self._splitter_size)
        if not early_stop:
            print("Final validation stats:")
            s1 = f"Loss: {self.val_losses[-1]:.4f} | Phenotype Loss: {self.val_pheno_losses[-1]:.4f}"
            s1 += f" | Genotype Loss: {self.val_geno_losses[-1]:.4f}"
            print(s1)
            s2 = f"Phenotype accuracy: {self.val_pheno_accs[-1]:.2%}"
            s2 += f" | Phenotype isolate accuracy: {self.val_pheno_iso_accs[-1]:.2%}"
            print(s2)
            s3 = f" Genotype accuracy: {self.val_geno_accs[-1]:.2%}"
            s3 += f" | Genotype isolate accuracy: {self.val_geno_iso_accs[-1]:.2%}"
            print(s3)
        
        results = {
            "best epoch": self.best_epoch,
            "train_losses": self.losses,
            "val_losses": self.val_losses,
            "val_pheno_losses": self.val_pheno_losses,
            "val_geno_losses": self.val_geno_losses,
            "val_pheno_accs": self.val_pheno_accs,
            "val_geno_accs": self.val_geno_accs,
            "val_pheno_iso_accs": self.val_pheno_iso_accs,
            "val_geno_iso_accs": self.val_geno_iso_accs,
            "train_time": train_time,
            "val_iso_stats_geno": self.val_iso_stats_geno,
            "val_iso_stats_pheno": self.val_iso_stats_pheno,
            "val_ab_stats": self.val_ab_stats
        }
        return results
    
    
    def train(self, epoch: int):
        print(f"Epoch {epoch+1}/{self.epochs}")
        time_ref = time.time()
        
        epoch_geno_loss, epoch_pheno_loss = 0, 0
        geno_batches, pheno_batches = 0, 0
        reporting_loss, printing_loss = 0, 0
        for i, batch in enumerate(self.train_loader):
            batch_index = i + 1
            self.optimizer.zero_grad() # zero out gradients
            
            input, target_indices, target_res, token_types, attn_mask = batch   
            # input, target_indices, target_res, token_types, attn_mask, masked_sequences = batch   
            pred_logits, token_pred = self.model(input, token_types, attn_mask) # get predictions for all antibiotics
            ab_mask = target_res != -1 # (batch_size, num_ab), True if antibiotic is masked, False otherwise
            
            loss = 0
            if ab_mask.any(): # if there are phenotypes in the batch
                ## Phenotype loss ##
                ab_indices = ab_mask.any(dim=0).nonzero().squeeze(-1).tolist() # list of indices of antibiotics present in the batch
                losses = list()
                for j in ab_indices: 
                    mask = ab_mask[:, j] # (batch_size,), indicates which samples contain the antibiotic masked
                    # isolate the predictions and targets for the antibiotic
                    ab_pred_logits = pred_logits[mask, j] # (num_masked_samples,)
                    ab_targets = target_res[mask, j] # (num_masked_samples,)
                    ab_loss = self.ab_criterions[j](ab_pred_logits, ab_targets)
                    losses.append(ab_loss)
                pheno_loss = sum(losses) / len(losses) # average loss over antibiotics
                epoch_pheno_loss += pheno_loss.item()
                pheno_batches += 1
                loss += pheno_loss
                
            if (target_indices != -1).any(): # if there are genotypes in the batch
                ## Genotype loss ##
                geno_loss = self.geno_criterion(token_pred.transpose(-1, -2), target_indices) # DOUBLE-CHECK DIMENSIONS
                epoch_geno_loss += geno_loss.item()
                geno_batches += 1
                loss += geno_loss
            reporting_loss += loss.item()
            printing_loss += loss.item()
            
            loss.backward() 
            self.optimizer.step() 
            if batch_index % self.report_every == 0:
                self._report_loss_results(batch_index, reporting_loss)
                reporting_loss = 0 
                
            if batch_index % self.print_progress_every == 0:
                time_elapsed = time.gmtime(time.time() - time_ref) 
                self._print_loss_summary(time_elapsed, batch_index, printing_loss) 
                printing_loss = 0  
        avg_pheno_loss = epoch_pheno_loss / pheno_batches
        avg_geno_loss = epoch_geno_loss / geno_batches
        avg_epoch_loss = avg_geno_loss + avg_pheno_loss
        losses = {"loss": avg_epoch_loss, "geno_loss": avg_geno_loss, "pheno_loss": avg_pheno_loss}
        return losses 
    
    
    def early_stopping(self):
        if self.val_losses[-1] < self.best_val_loss:
            self.best_val_loss = self.val_losses[-1]
            self.best_epoch = self.current_epoch
            self.best_model_state = self.model.state_dict()
            self.early_stopping_counter = 0
            return False
        else:
            self.early_stopping_counter += 1
            return True if self.early_stopping_counter >= self.patience else False
        
            
    def evaluate(self, loader: DataLoader, ds_obj):
        self.model.eval()
        # prepare evaluation statistics dataframes
        ab_stats, iso_stats_pheno, iso_stats_geno = self._init_eval_stats(ds_obj)
        with torch.no_grad(): 
            ## Antibiotic tracking ##
            ab_num = np.zeros((self.num_ab, 2)) # tracks the occurence for each antibiotic & resistance
            ab_num_preds = np.zeros_like(ab_num) # tracks the number of predictions for each antibiotic & resistance
            ab_num_correct = np.zeros_like(ab_num) # tracks the number of correct predictions for each antibiotic & resistance
            ## General tracking ##
            tot_pheno_loss, tot_geno_loss = 0, 0
            geno_batches, pheno_batches = 0, 0
            for i, batch in enumerate(loader):                
                input, target_indices, target_res, token_types, attn_mask = batch   
                # input, target_indices, target_res, token_types, attn_mask, sequences, masked_sequences = batch  
                 
                pred_logits, token_pred = self.model(input, token_types, attn_mask) # get predictions for all antibiotics
                pred_res = torch.where(pred_logits > 0, torch.ones_like(pred_logits), torch.zeros_like(pred_logits)) # logits -> 0/1 (S/R)
                        
                ###### Phenotype loss ######
                ab_mask = target_res >= 0 # (batch_size, num_ab), True if antibiotic is masked, False otherwise
                if ab_mask.any(): # if there are phenotypes in the batch
                    iso_stats_pheno = self._update_pheno_stats(i, pred_res, target_res, ab_mask, iso_stats_pheno)
                    
                    ab_indices = ab_mask.any(dim=0).nonzero().squeeze(-1).tolist() # list of indices of antibiotics present in the batch
                    losses = list()
                    for j in ab_indices: 
                        mask = ab_mask[:, j] # (batch_size,)
                        
                        # isolate the predictions and targets for the antibiotic
                        ab_pred_logits = pred_logits[mask, j] # (num_masked_samples,)
                        ab_targets = target_res[mask, j] # (num_masked_samples,)
                        num_R = ab_targets.sum().item()
                        num_S = ab_targets.shape[0] - num_R
                        ab_num[j, :] += [num_S, num_R]
                        
                        ab_loss = self.ab_criterions[j](ab_pred_logits, ab_targets)
                        losses.append(ab_loss)
                        
                        ab_pred_res = pred_res[mask, j]
                        ab_num_correct[j, :] += self._get_num_correct(ab_pred_res, ab_targets)    
                        ab_num_preds[j, :] += self._get_num_preds(ab_pred_res)
                    pheno_loss = sum(losses) / len(losses) # average loss over antibiotics
                    tot_pheno_loss += pheno_loss.item()
                    pheno_batches += 1
                    
                ###### Genotype loss ######
                token_mask = target_indices != -1 # (batch_size, max_seq_len), True if token is masked, False otherwise
                if token_mask.any(): # if there are genotypes in the batch
                    iso_stats_geno = self._update_geno_stats(i, token_pred, target_indices, token_mask, iso_stats_geno)
                    
                    geno_loss = self.geno_criterion(token_pred.transpose(-1, -2), target_indices) 
                    tot_geno_loss += geno_loss.item()
                    geno_batches += 1
                    
        avg_geno_loss = tot_geno_loss / geno_batches
        avg_pheno_loss = tot_pheno_loss / pheno_batches
        avg_loss = avg_geno_loss + avg_pheno_loss  
        
        ab_stats = self._update_ab_eval_stats(ab_stats, ab_num, ab_num_preds, ab_num_correct)
        iso_stats_geno, iso_stats_pheno = self._calculate_iso_stats(iso_stats_geno, iso_stats_pheno)
        
        pheno_acc = iso_stats_pheno['num_correct'].sum() / iso_stats_pheno['num_masked'].sum()
        pheno_iso_acc = iso_stats_pheno['all_correct'].sum() / iso_stats_pheno.shape[0]
        geno_acc = iso_stats_geno['num_correct'].sum() / iso_stats_geno['num_masked'].sum()
        geno_iso_acc = iso_stats_geno['all_correct'].sum() / iso_stats_geno.shape[0]

        results = {
            "loss": avg_loss, 
            "geno_loss": avg_geno_loss, 
            "pheno_loss": avg_pheno_loss,
            "pheno_acc": pheno_acc,
            "pheno_iso_acc": pheno_iso_acc,
            "geno_acc": geno_acc,
            "geno_iso_acc": geno_iso_acc,
            "ab_stats": ab_stats,
            "iso_stats_pheno": iso_stats_pheno,
            "iso_stats_geno": iso_stats_geno
        }
        return results
            
    
    def _init_result_lists(self):
        self.losses = []
        self.val_losses = []
        self.val_geno_losses = []
        self.val_pheno_losses = []
        self.val_pheno_accs = []
        self.val_pheno_iso_accs = []
        self.val_geno_accs = []
        self.val_geno_iso_accs = []
        self.val_ab_stats = []
        self.val_iso_stats_pheno = []
        self.val_iso_stats_geno = []
        
        
    def _update_val_lists(self, results: dict):
        self.val_losses.append(results["loss"])
        self.val_geno_losses.append(results["geno_loss"])
        self.val_pheno_losses.append(results["pheno_loss"])
        self.val_pheno_accs.append(results["pheno_acc"])
        self.val_pheno_iso_accs.append(results["pheno_iso_acc"])
        self.val_geno_accs.append(results["geno_acc"])
        self.val_geno_iso_accs.append(results["geno_iso_acc"])
        self.val_ab_stats.append(results["ab_stats"])
        self.val_iso_stats_pheno.append(results["iso_stats_pheno"])
        self.val_iso_stats_geno.append(results["iso_stats_geno"])
    
    
    def _init_eval_stats(self, ds_obj):
        ab_stats = pd.DataFrame(columns=[
            'antibiotic', 'num_tot', 'num_S', 'num_R', 'num_pred_S', 'num_pred_R', 
            'num_correct', 'num_correct_S', 'num_correct_R',
            'accuracy', 'sensitivity', 'specificity', 'precision', 'F1'
        ])
        ab_stats['antibiotic'] = self.antibiotics
        ab_stats['num_tot'], ab_stats['num_S'], ab_stats['num_R'] = 0, 0, 0
        ab_stats['num_pred_S'], ab_stats['num_pred_R'] = 0, 0
        ab_stats['num_correct'], ab_stats['num_correct_S'], ab_stats['num_correct_R'] = 0, 0, 0

        combined_ds = ds_obj.combined_ds
        ## Extract phenotype samples 
        iso_stats_pheno = combined_ds[combined_ds['source'] == 'pheno'].drop(columns=['source', 'num_genotypes'])
        iso_stats_pheno['num_masked'], iso_stats_pheno['num_masked_S'], iso_stats_pheno['num_masked_R'] = 0, 0, 0
        iso_stats_pheno['num_correct'], iso_stats_pheno['correct_S'], iso_stats_pheno['correct_R'] = 0, 0, 0
        iso_stats_pheno['sensitivity'], iso_stats_pheno['specificity'], iso_stats_pheno['accuracy'] = 0, 0, 0
        iso_stats_pheno['all_correct'] = False 
        
        ## Extract genotype samples
        iso_stats_geno = combined_ds[combined_ds['source'] == 'geno'].drop(columns=['source', 'age', 'gender', 'num_ab'])
        iso_stats_geno.replace(self.val_set.PAD, np.nan, inplace=True)
        iso_stats_geno['num_masked'], iso_stats_geno['num_correct'], iso_stats_geno['accuracy'] = 0, 0, 0
        iso_stats_geno['all_correct'] = False
      
        return ab_stats, iso_stats_pheno, iso_stats_geno
    
    
    def _update_ab_eval_stats(self, ab_stats: pd.DataFrame, num, num_preds, num_correct):
        for j in range(self.num_ab): 
            ab_stats.loc[j, 'num_tot'] = num[j, :].sum()
            ab_stats.loc[j, 'num_S'], ab_stats.loc[j, 'num_R'] = num[j, 0], num[j, 1]
            ab_stats.loc[j, 'num_pred_S'], ab_stats.loc[j, 'num_pred_R'] = num_preds[j, 0], num_preds[j, 1]
            ab_stats.loc[j, 'num_correct'] = num_correct[j, :].sum()
            ab_stats.loc[j, 'num_correct_S'], ab_stats.loc[j, 'num_correct_R'] = num_correct[j, 0], num_correct[j, 1]
        ab_stats['accuracy'] = ab_stats.apply(
            lambda row: row['num_correct']/row['num_tot'] if row['num_tot'] > 0 else np.nan, axis=1)
        ab_stats['sensitivity'] = ab_stats.apply(
            lambda row: row['num_correct_R']/row['num_R'] if row['num_R'] > 0 else np.nan, axis=1)
        ab_stats['specificity'] = ab_stats.apply(
            lambda row: row['num_correct_S']/row['num_S'] if row['num_S'] > 0 else np.nan, axis=1)
        ab_stats['precision'] = ab_stats.apply(
            lambda row: row['num_correct_R']/row['num_pred_R'] if row['num_pred_R'] > 0 else np.nan, axis=1)
        ab_stats['F1'] = ab_stats.apply(
            lambda row: 2*row['precision']*row['sensitivity']/(row['precision']+row['sensitivity']) 
            if row['precision'] > 0 and row['sensitivity'] > 0 else np.nan, axis=1)
        return ab_stats
    
    
    def _get_num_correct(self, pred_res: torch.Tensor, target_res: torch.Tensor):
        eq = torch.eq(pred_res, target_res)
        num_correct_S = eq[target_res == 0].sum().item()
        num_correct_R = eq[target_res == 1].sum().item()
        return [num_correct_S, num_correct_R]
    
    
    def _get_num_preds(self, pred_res: torch.Tensor):
        num_pred_S = (pred_res == 0).sum().item()
        num_pred_R = (pred_res == 1).sum().item()
        return [num_pred_S, num_pred_R]
    
    
    def _update_pheno_stats(self, batch_idx, pred_res: torch.Tensor, target_res: torch.Tensor, 
                          ab_mask: torch.Tensor, iso_stats_pheno: pd.DataFrame):
        indices = ab_mask.any(dim=1).nonzero().squeeze(-1).tolist() # list of isolates where phenotypes are present
        for idx in indices: 
            iso_ab_mask = ab_mask[idx]
            df_idx = batch_idx * self.batch_size + idx # index of the isolate in the combined dataset
            
            # counts
            num_masked_tot = iso_ab_mask.sum().item()
            num_masked_R = target_res[idx][iso_ab_mask].sum().item()
            num_masked_S = num_masked_tot - num_masked_R
            
            # statistics            
            iso_target_res = target_res[idx][iso_ab_mask]
            eq = torch.eq(pred_res[idx][iso_ab_mask], iso_target_res)
            num_correct_R = eq[iso_target_res == 1].sum().item()
            num_correct_S = eq[iso_target_res == 0].sum().item()
            num_correct = num_correct_S + num_correct_R
            all_correct = eq.all().item()
            
            data = {
                'num_masked': num_masked_tot, 'num_masked_S': num_masked_S, 'num_masked_R': num_masked_R, 
                'num_correct': num_correct, 'correct_S': num_correct_S, 'correct_R': num_correct_R,
                'all_correct': all_correct
            }
            iso_stats_pheno.loc[df_idx, data.keys()] = data.values()
                          
        return iso_stats_pheno
    
    
    def _update_geno_stats(self, batch_idx, token_pred: torch.Tensor, target_indices: torch.Tensor, 
                           token_mask: torch.Tensor, iso_stats_geno: pd.DataFrame):
        indices = token_mask.any(dim=1).nonzero().squeeze(-1).tolist() # list of isolates where genotypes are present
        for idx in indices:
            iso_token_mask = token_mask[idx]    
            df_idx = batch_idx * self.batch_size + idx # index of the isolate in the combined dataset
            
            num_masked = iso_token_mask.sum().item()
            pred_tokens = token_pred[idx, iso_token_mask].argmax(dim=-1)
            targets = target_indices[idx, iso_token_mask]

            eq = torch.eq(pred_tokens, targets)
            data = {
                'num_masked': num_masked, 'num_correct': eq.sum().item(), 'all_correct': eq.all().item()
            }
            iso_stats_geno.loc[df_idx, data.keys()] = data.values()
                
        return iso_stats_geno
    
    
    def _calculate_iso_stats(self, iso_stats_geno: pd.DataFrame, iso_stats_pheno: pd.DataFrame):
        
        iso_stats_geno['accuracy'] = iso_stats_geno['num_correct'] / iso_stats_geno['num_masked']
        
        iso_stats_pheno['accuracy'] = iso_stats_pheno['num_correct'] / iso_stats_pheno['num_masked']
        iso_stats_pheno['sensitivity'] = iso_stats_pheno.apply(
            lambda row: row['correct_R']/row['num_masked_R'] if row['num_masked_R'] > 0 else np.nan, axis=1
        )
        iso_stats_pheno['specificity'] = iso_stats_pheno.apply(
            lambda row: row['correct_S']/row['num_masked_S'] if row['num_masked_S'] > 0 else np.nan, axis=1
        )
        
        return iso_stats_geno, iso_stats_pheno
        
    
     
    def _init_wandb(self):
        self.wandb_run = wandb.init(
            project=self.project_name, # name of the project
            name=self.wandb_name, # name of the run
            
            config={
                "epochs": self.epochs,
                "batch_size": self.batch_size,
                "hidden_dim": self.model.hidden_dim,
                "num_layers": self.model.num_layers,
                "num_heads": self.model.num_heads,
                "emb_dim": self.model.emb_dim,
                'ff_dim': self.model.ff_dim,
                "lr": self.lr,
                "weight_decay": self.weight_decay,
                "mask_probs": self.mask_probs,
                "max_seq_len": self.model.max_seq_len,
                "vocab_size": len(self.vocab),
                "num_parameters": sum(p.numel() for p in self.model.parameters() if p.requires_grad),
                "num_antibiotics": self.num_ab,
                "antibiotics": self.antibiotics,
                "train_size": self.train_size,
                "random_state": self.random_state,
                'val_share': self.val_share,
                "val_size": self.val_size,
                # "early_stopping_patience": self.patience,
                # "dropout_prob": self.model.dropout_prob,
            }
        )
        self.wandb_run.watch(self.model) # watch the model for gradients and parameters
        self.wandb_run.define_metric("epoch", hidden=True)
        self.wandb_run.define_metric("batch", hidden=True)
        
        self.wandb_run.define_metric("Losses/live_loss", step_metric="batch")
        self.wandb_run.define_metric("Losses/train_loss", summary="min", step_metric="epoch")
        self.wandb_run.define_metric("Losses/val_loss", summary="min", step_metric="epoch")
        self.wandb_run.define_metric("Losses/val_geno_loss", summary="min", step_metric="epoch")
        self.wandb_run.define_metric("Losses/val_pheno_loss", summary="min", step_metric="epoch")
        self.wandb_run.define_metric("Accuracies/val_pheno_acc", summary="max", step_metric="epoch")
        self.wandb_run.define_metric("Accuracies/val_pheno_iso_acc", summary="max", step_metric="epoch")
        self.wandb_run.define_metric("Accuracies/val_geno_acc", summary="max", step_metric="epoch")
        self.wandb_run.define_metric("Accuracies/val_geno_iso_acc", summary="max", step_metric="epoch")
        
        self.wandb_run.define_metric("Losses/final_val_loss")
        self.wandb_run.define_metric("Losses/final_val_geno_loss")
        self.wandb_run.define_metric("Losses/final_val_pheno_loss")
        self.wandb_run.define_metric("Accuracies/final_val_pheno_acc")
        self.wandb_run.define_metric("Accuracies/final_val_pheno_iso_acc")
        self.wandb_run.define_metric("Accuracies/final_val_geno_acc")
        self.wandb_run.define_metric("Accuracies/final_val_geno_iso_acc")
        
        self.wandb_run.define_metric("final_epoch")

        return self.wandb_run
     
    def _report_epoch_results(self):
        wandb_dict = {
            "epoch": self.current_epoch+1,
            "Losses/train_loss": self.losses[-1],
            "Losses/val_loss": self.val_losses[-1],
            "Losses/val_geno_loss": self.val_geno_losses[-1],
            "Losses/val_pheno_loss": self.val_pheno_losses[-1],
            "Accuracies/val_pheno_acc": self.val_pheno_accs[-1],
            "Accuracies/val_pheno_iso_acc": self.val_pheno_iso_accs[-1],
            "Accuracies/val_geno_acc": self.val_geno_accs[-1],
            "Accuracies/val_geno_iso_acc": self.val_geno_iso_accs[-1],
        }
        self.wandb_run.log(wandb_dict)
    
        
    def _report_loss_results(self, batch_index, tot_loss):
        avg_loss = tot_loss / self.report_every
        
        global_step = self.current_epoch * self.num_batches + batch_index # global step, total #batches seen
        self.wandb_run.log({"batch": global_step, "Losses/live_loss": avg_loss})
    
        
    def _print_loss_summary(self, time_elapsed, batch_index, tot_loss):
        progress = batch_index / self.num_batches
        mlm_loss = tot_loss / self.print_progress_every
          
        s = f"{time.strftime('%H:%M:%S', time_elapsed)}" 
        s += f" | Epoch: {self.current_epoch+1}/{self.epochs} | {batch_index}/{self.num_batches} ({progress:.2%}) | "\
                f"Loss: {mlm_loss:.4f}"
        print(s)
    
    
    def save_model(self, savepath: Path = None):
        if not savepath:
            savepath = self.results_dir / "model_state.pt"
        torch.save(self.model.state_dict(), savepath)
        print(f"Model saved to {savepath}")
        print("="*self._splitter_size)
        
        
    def _load_model(self, savepath: Path):
        print("="*self._splitter_size)
        print(f"Loading model from {savepath}")
        self.model.load_state_dict(torch.load(savepath))
        print("Model loaded")
        print("="*self._splitter_size)

## Main

In [9]:
from utils import get_multimodal_split_indices, export_results
from multimodal.models import BERT

specials = config['specials']
pad_token = specials['PAD']
pad_idx = list(specials.values()).index(pad_token) # pass to model for embedding

ds_geno = ds_NCBI[ds_NCBI['num_ab'] == 0].reset_index(drop=True)
ds_geno.fillna(pad_token, inplace=True)
ds_pheno = ds_TESSy.copy()
ds_pheno['country'] = ds_pheno['country'].map(config['data']['TESSy']['country_code_to_name'])

# ds_geno = ds_geno.iloc[:20000]
# ds_pheno = ds_pheno.iloc[:100000]

antibiotics = list(set(data_dict['antibiotics']['abbr_to_names'].keys()) - set(data_dict['exclude_antibiotics']))
vocab = construct_MM_vocab(
    df_geno=ds_NCBI,
    df_pheno=ds_pheno,
    antibiotics=antibiotics,
    specials=specials,
)
vocab_size = len(vocab)
config['name'] = 'pt_test'
if config['name']:
        results_dir = Path(os.path.join(base_dir / "results" / "MM", config['name']))
else:
    time_str = datetime.now().strftime("%Y%m%d-%H%M%S")
    results_dir = Path(os.path.join(base_dir / "results" / "MM", "experiment_" + str(time_str)))

os.environ['WANDB_MODE'] = config['wandb_mode']
if config['max_seq_len'] == 'auto':
    max_seq_len = int(max((ds_NCBI['num_genotypes'] + ds_NCBI['num_ab']).max() + 3, ds_pheno['num_ab'].max() + 5))

train_indices, val_indices = get_multimodal_split_indices(
    [ds_geno.shape[0], ds_pheno.shape[0]], 
    val_share=config['val_share'], 
    random_state=config['random_state']
)

ds_pt_train = MMPretrainDataset(
    ds_geno=ds_geno.iloc[train_indices[0]],
    ds_pheno=ds_pheno.iloc[train_indices[1]],
    vocab=vocab,
    antibiotics=antibiotics,
    specials=specials,
    max_seq_len=max_seq_len,
    mask_prob_geno=config['mask_prob_geno'],
    mask_prob_pheno=config['mask_prob_pheno'],
    random_state=config['random_state']
)
ds_pt_val = MMPretrainDataset(
    ds_geno=ds_geno.iloc[val_indices[0]],
    ds_pheno=ds_pheno.iloc[val_indices[1]],
    vocab=vocab,
    antibiotics=antibiotics,
    specials=specials,
    max_seq_len=max_seq_len,
    mask_prob_geno=config['mask_prob_geno'],
    mask_prob_pheno=config['mask_prob_pheno'],
    random_state=config['random_state']
)

bert = BERT(config, vocab_size, max_seq_len, len(antibiotics), pad_idx=pad_idx).to(device)
trainer = MMBertPreTrainer(
    config=config,
    model=bert,
    antibiotics=antibiotics,
    train_set=ds_pt_train,
    val_set=ds_pt_val,
    results_dir=results_dir,
)
trainer.print_model_summary()
trainer.print_trainer_summary()
pt_results = trainer()
export_results(pt_results, results_dir / 'pt_results.pkl')
wandb.finish()

Model summary:
Embedding dim: 256
Feed-forward dim: 256
Hidden dim: 256
Number of heads: 4
Number of encoder layers: 6
Dropout probability: 10%
Max sequence length: 40
Vocab size: 1,550
Number of parameters: 3,178,254
Trainer summary:
Device: cuda (NVIDIA GeForce RTX 3080)
Training dataset size: 1,504,426
Batch size: 64
Number of batches: 23,507
Number of antibiotics: 18
Antibiotics: ['OFX', 'NAL', 'CRO', 'CIP', 'LVX', 'AMX', 'MFX', 'FEP', 'AMC', 'AMP', 'CTX', 'TZP', 'NET', 'CAZ', 'NOR', 'TOB', 'GEN', 'PIP']
CV split: 85% train | 15% val
Mask probability (genotypes): 25%
Mask probability (phenotypes): 25%
Number of epochs: 100
Early stopping patience: 3
Learning rate: 0.0001
Weight decay: 0.01
Initializing training...


Epoch 1/100
00:00:39 | Epoch: 1/100 | 1000/23507 (4.25%) | Loss: 3.4236
00:01:18 | Epoch: 1/100 | 2000/23507 (8.51%) | Loss: 2.4712
00:01:56 | Epoch: 1/100 | 3000/23507 (12.76%) | Loss: 2.1257
00:02:35 | Epoch: 1/100 | 4000/23507 (17.02%) | Loss: 1.9206
00:03:14 | Epoch: 1/100 | 5000/23507 (21.27%) | Loss: 1.8152
00:03:52 | Epoch: 1/100 | 6000/23507 (25.52%) | Loss: 1.6995
00:04:31 | Epoch: 1/100 | 7000/23507 (29.78%) | Loss: 1.6559
00:05:09 | Epoch: 1/100 | 8000/23507 (34.03%) | Loss: 1.5799
00:05:48 | Epoch: 1/100 | 9000/23507 (38.29%) | Loss: 1.5414
00:06:26 | Epoch: 1/100 | 10000/23507 (42.54%) | Loss: 1.4960
00:07:04 | Epoch: 1/100 | 11000/23507 (46.79%) | Loss: 1.4640
00:07:42 | Epoch: 1/100 | 12000/23507 (51.05%) | Loss: 1.4326
00:08:20 | Epoch: 1/100 | 13000/23507 (55.30%) | Loss: 1.4398
00:08:59 | Epoch: 1/100 | 14000/23507 (59.56%) | Loss: 1.3851
00:09:37 | Epoch: 1/100 | 15000/23507 (63.81%) | Loss: 1.3754
00:10:15 | Epoch: 1/100 | 16000/23507 (68.06%) | Loss: 1.3606
00:10:5

## Fine-tuning 

### Fine-tuning dataset

In [None]:
class MMFinetuneDataset(Dataset):
    # df column names
    # ORIGINAL_SEQUENCE = 'original_sequence'
    INDICES_MASKED = 'indices_masked' # input to BERT, token indices of the masked sequence
    # TARGET_INDICES = 'target_indices' # target indices of the masked sequence ## USE IF GENOTYPES ARE ALSO MASKED ##
    TARGET_RESISTANCES = 'target_resistances' # resistance of the target antibiotics, what we want to predict
    TOKEN_TYPES = 'token_types' # # 0 for patient info, 1 for genotype, 2 for phenotype
    # if sequences are included
    MASKED_SEQUENCE = 'masked_sequence'
    
    def __init__(
        self,
        df_MM: pd.DataFrame, 
        vocab,
        antibiotics: list,
        specials: dict,
        max_seq_len: int,
        mask_prob: float,
        num_known_ab: int,
        random_state: int = 42,
        include_sequences: bool = False
    ):
        self.random_state = random_state
        np.random.seed(self.random_state)
        
        self.ds_MM = df_MM.reset_index(drop=True)
        assert all(self.ds_MM['num_ab'] > 0), "Dataset contains isolates without phenotypes"
        self.num_samples = self.ds_MM.shape[0]
        self.vocab = vocab
        self.vocab_size = len(self.vocab)
        self.antibiotics = antibiotics
        self.num_ab = len(self.antibiotics)
        self.ab_to_idx = {ab: idx for idx, ab in enumerate(self.antibiotics)}
        self.enc_res = {'S': 0, 'R': 1}
        self.max_seq_len = max_seq_len
        self.CLS, self.PAD, self.MASK, self.UNK = specials.values()
        
        self.mask_prob = mask_prob
        self.num_known_ab = num_known_ab
        assert not (self.num_known_ab and self.mask_prob), "Cannot specify both num_known_ab and mask_prob"
        
        self.include_sequences = include_sequences
        if self.include_sequences:
            self.columns = [self.INDICES_MASKED, self.TARGET_RESISTANCES, self.TOKEN_TYPES,
                            self.MASKED_SEQUENCE]
        else:
            self.columns = [self.INDICES_MASKED, self.TARGET_RESISTANCES, self.TOKEN_TYPES]
    
    
    def __len__(self):
        return self.num_samples
    
    
    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        
        input = torch.tensor(item[self.INDICES_MASKED], dtype=torch.long, device=device)
        target_res = torch.tensor(item[self.TARGET_RESISTANCES], dtype=torch.float32, device=device)
        token_types = torch.tensor(item[self.TOKEN_TYPES], dtype=torch.long, device=device)
        attn_mask = (input != self.vocab[self.PAD]).unsqueeze(0).unsqueeze(1) # one dim for batch, one for heads
        
        if self.include_sequences:
            masked_sequence = item[self.MASKED_SEQUENCE]
            return input, target_res, token_types, attn_mask, masked_sequence
        else:
            return input, target_res, token_types, attn_mask   
    
    
    def prepare_dataset(self):
        geno_sequences = deepcopy(self.ds_MM['genotypes'].tolist())
        pheno_sequences = deepcopy(self.ds_MM['phenotypes'].tolist())
        years = self.ds_MM['year'].astype(str).tolist()
        countries = self.ds_MM['country'].tolist()
        
        masked_pheno_sequences, target_resistances = self._mask_pheno_sequences(pheno_sequences)
        pheno_token_types = [[2]*len(seq) for seq in masked_pheno_sequences]
        
        geno_token_types = [[1]*len(seq) for seq in geno_sequences]
        seq_starts = [[self.CLS, years[i], countries[i]] for i in range(self.num_samples)]
        
        # combine sequences and pad
        masked_sequences = [seq_starts[i] + geno_sequences[i] + masked_pheno_sequences[i] for i in range(self.num_samples)]
        masked_sequences = [seq + [self.PAD]*(self.max_seq_len - len(seq)) for seq in masked_sequences]
        indices_masked = [self.vocab.lookup_indices(seq) for seq in masked_sequences]
        
        token_types = [[0]*3 + geno_token_types[i] + pheno_token_types[i] for i in range(self.num_samples)]
        token_types = [seq + [2]*(self.max_seq_len - len(seq)) for seq in token_types]
        
        if self.include_sequences:
            rows = zip(indices_masked, target_resistances, token_types, masked_sequences)
        else:
            rows = zip(indices_masked, target_resistances, token_types)
        self.df = pd.DataFrame(rows, columns=self.columns)
         
    
    def _mask_pheno_sequences(self, pheno_sequences):
        masked_pheno_sequences = list()
        target_resistances = list()

        if self.mask_prob:
            for pheno_seq in pheno_sequences:
                seq_len = len(pheno_seq)
                token_mask = np.random.rand(seq_len) < self.mask_prob
                target_res = [-1]*self.num_ab
                if not token_mask.any():
                    idx = np.random.randint(seq_len)
                    ab, res = pheno_seq[idx].split('_')
                    target_res[self.ab_to_idx[ab]] = self.enc_res[res]  
                    r = np.random.rand()
                    if r < 0.8:
                        pheno_seq[idx] = self.MASK
                    elif r < 0.9:
                        pheno_seq[idx] = self.vocab.lookup_token(np.random.randint(self.vocab_size)) 
                else:
                    for idx in token_mask.nonzero()[0]:
                        ab, res = pheno_seq[idx].split('_')
                        target_res[self.ab_to_idx[ab]] = self.enc_res[res]
                        r = np.random.rand()
                        if r < 0.8:
                            pheno_seq[idx] = self.MASK
                        elif r < 0.9:
                            pheno_seq[idx] = self.vocab.lookup_token(np.random.randint(self.vocab_size))
                masked_pheno_sequences.append(pheno_seq)
                target_resistances.append(target_res)
        else:
            for pheno_seq in pheno_sequences:
                seq_len = len(pheno_seq)
                target_res = [-1]*self.num_ab
                indices = np.random.choice(seq_len, self.num_known_ab, replace=False)
                for idx in indices:
                    ab, res = pheno_seq[idx].split('_')
                    target_res[self.ab_to_idx[ab]] = self.enc_res[res]
                    r = np.random.rand()
                    if r < 0.8:
                        pheno_seq[idx] = self.MASK
                    elif r < 0.9:
                        pheno_seq[idx] = self.vocab.lookup_token(np.random.randint(self.vocab_size))
                masked_pheno_sequences.append(pheno_seq)
                target_resistances.append(target_res)    
        return masked_pheno_sequences, target_resistances
        

### Fine-tuner

In [None]:
class MMBertFineTuner():
    
    def __init__(
        self,
        config: dict,
        model,
        antibiotics: list,
        train_set,
        val_set,
        results_dir: Path
    ):
        super(MMBertFineTuner, self).__init__()
        
        config_ft = config["fine_tune"]
        self.random_state = config_ft['random_state']
        np.random.seed(self.random_state)
        torch.manual_seed(self.random_state)
        torch.cuda.manual_seed(self.random_state)
        
        self.model = model
        self.project_name = config_ft["project_name"]
        self.wandb_name = config_ft["name"] if config_ft["name"] else datetime.now().strftime("%Y%m%d-%H%M%S")
        self.antibiotics = antibiotics
        self.num_ab = len(self.antibiotics) 
        
        self.train_set, self.train_size = train_set, len(train_set)
        self.val_set, self.val_size = val_set, len(val_set) 
        assert round(self.val_size / (self.train_size + self.val_size), 2) == config_ft["val_share"], "Validation set size does not match intended val_share"
        self.val_share, self.train_share = config_ft["val_share"], 1 - config_ft["val_share"]
        self.batch_size = config_ft["batch_size"]
        self.num_batches = round(self.train_size / self.batch_size)
        self.vocab = self.train_set.vocab
         
        self.lr = config_ft["lr"]
        self.weight_decay = config_ft["weight_decay"]
        self.epochs = config_ft["epochs"]
        self.patience = config_ft["early_stopping_patience"]
        self.save_model_ = config_ft["save_model"]
        
        self.mask_prob = self.train_set.mask_prob
        self.num_known_ab = self.train_set.num_known_ab
        
        self.ab_criterions = [nn.BCEWithLogitsLoss().to(device) for _ in range(self.num_ab)] # the list is so that we can introduce individual weights
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        self.scheduler = None
        # self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.9)
        # self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.98)
                 
        self.current_epoch = 0
        self.report_every = config_ft["report_every"] 
        self.print_progress_every = config_ft["print_progress_every"]
        self._splitter_size = 70
        self.results_dir = results_dir
        if self.results_dir:
            self.results_dir.mkdir(parents=True, exist_ok=True)
            
    def print_model_summary(self):        
        print("Model summary:")
        print("="*self._splitter_size)
        print(f"Is pre-trained: {'Yes' if self.model.is_pretrained else 'No'}")
        print(f"Embedding dim: {self.model.emb_dim}")
        print(f"Feed-forward dim: {self.model.ff_dim}")
        print(f"Hidden dim: {self.model.hidden_dim}")
        print(f"Number of heads: {self.model.num_heads}")
        print(f"Number of encoder layers: {self.model.num_layers}")
        print(f"Dropout probability: {self.model.dropout_prob:.0%}")
        print(f"Max sequence length: {self.model.max_seq_len}")
        print(f"Vocab size: {len(self.vocab):,}")
        print(f"Number of parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad):,}")
        print("="*self._splitter_size)
        
    
    def print_trainer_summary(self):
        print("Trainer summary:")
        print("="*self._splitter_size)
        if device.type == "cuda":
            print(f"Device: {device} ({torch.cuda.get_device_name(0)})")
        else:
            print(f"Device: {device}")        
        print(f"Training dataset size: {self.train_size:,}")
        print(f"Batch size: {self.batch_size}")
        print(f"Number of batches: {self.num_batches:,}")
        print(f"Number of antibiotics: {self.num_ab}")
        print(f"Antibiotics: {self.antibiotics}")
        print(f"CV split: {self.train_share:.0%} train | {self.val_share:.0%} val")
        if self.mask_prob:
            print(f"Mask probability: {self.mask_prob:.0%}")
        if self.num_known_ab:
            print(f"Number of known antibiotics: {self.num_known_ab}")
        print(f"Number of epochs: {self.epochs}")
        print(f"Early stopping patience: {self.patience}")
        print(f"Learning rate: {self.lr}")
        print(f"Weight decay: {self.weight_decay}")
        print("="*self._splitter_size)
    
    def __call__(self):      
        if not self.model.pheno_only:
            self.model.pheno_only = True
        self.wandb_run = self._init_wandb()
        self.val_set.prepare_dataset()
        self.val_loader = DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False)
        
        start_time = time.time()
        self.best_val_loss = float('inf') 
        self._init_result_lists()
        for self.current_epoch in range(self.current_epoch, self.epochs):
            self.model.train()
            self.train_set.prepare_dataset()
            self.train_loader = DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
            epoch_start_time = time.time()
            train_loss = self.train(self.current_epoch) # returns loss, averaged over batches
            self.losses.append(train_loss)
            print(f"Epoch completed in {(time.time() - epoch_start_time)/60:.1f} min | Loss: {train_loss:.4f}")
            print("Evaluating on validation set...")
            val_results = self.evaluate(self.val_loader, self.val_set)
            s = f"Val loss: {val_results['loss']:.4f}"
            s += f" | Accuracy {val_results['acc']:.2%} | Isolate accuracy {val_results['iso_acc']:.2%}"
            print(s)
            print(f"Elapsed time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time))}")
            print("="*self._splitter_size)
            self._update_val_lists(val_results)
            self._report_epoch_results()
            early_stop = self.early_stopping()
            if early_stop:
                print(f"Early stopping at epoch {self.current_epoch+1} with validation loss {self.val_losses[-1]:.4f}")
                print(f"Validation stats at best epoch ({self.best_epoch+1}):")
                s = f"Loss: {self.val_losses[self.best_epoch]:.4f}" 
                s += f" | Accuracy: {self.val_accs[self.best_epoch]:.2%}"
                s += f" | Isolate accuracy: {self.val_iso_accs[self.best_epoch]:.2%}"
                print(s)
                self.wandb_run.log({
                    "Losses/final_val_loss": self.best_val_loss, 
                    "Accuracies/final_val_acc": self.val_accs[self.best_epoch],
                    "Accuracies/final_val_iso_acc": self.val_iso_accs[self.best_epoch],
                    "final_epoch": self.best_epoch+1
                })
                self.model.load_state_dict(self.best_model_state) 
                self.current_epoch = self.best_epoch
                break
            if self.scheduler:
                self.scheduler.step()
        if not early_stop:    
            self.wandb_run.log({
                    "Losses/final_val_loss": self.best_val_loss, 
                    "Accuracies/final_val_acc": self.val_accs[-1],
                    "Accuracies/final_val_iso_acc": self.val_iso_accs[-1],
                    "final_epoch": self.current_epoch+1
                })
        self.model.is_pretrained = True
        if self.save_model_:
            self.save_model(self.results_dir / "model_state.pt") 
        train_time = (time.time() - start_time)/60
        self.wandb_run.log({"Training time (min)": train_time})
        disp_time = f"{train_time//60:.0f}h {train_time % 60:.1f} min" if train_time > 60 else f"{train_time:.1f} min"
        print(f"Training completed in {disp_time}")
        print("="*self._splitter_size)
        if not early_stop:
            print("Final validation stats:")
            s = f"Loss: {self.val_losses[-1]:.4f}"
            s = f" | Accuracy: {self.val_accs[-1]:.2%}"
            s += f" | Isolate accuracy: {self.val_iso_accs[-1]:.2%}"
            print(s)
        
        results = {
            "best epoch": self.best_epoch,
            "train_losses": self.losses,
            "val_losses": self.val_losses,
            "val_accs": self.val_accs,
            "val_iso_accs": self.val_iso_accs,
            "train_time": train_time,
            "val_iso_stats": self.val_iso_stats,
            "val_ab_stats": self.val_ab_stats
        }
        return results
    
    
    def train(self, epoch: int):
        print(f"Epoch {epoch+1}/{self.epochs}")
        time_ref = time.time()
        
        epoch_loss, reporting_loss, printing_loss = 0, 0, 0
        for i, batch in enumerate(self.train_loader):
            batch_index = i + 1
            self.optimizer.zero_grad() # zero out gradients
            
            input, target_res, token_types, attn_mask = batch 
            # input, target_indices, target_res, token_types, attn_mask, masked_sequences = batch   
            pred_logits = self.model(input, token_types, attn_mask) # get predictions for all antibiotics
            ab_mask = target_res != -1 # (batch_size, num_ab), True if antibiotic is masked, False otherwise
            
            ab_indices = ab_mask.any(dim=0).nonzero().squeeze(-1).tolist() # list of indices of antibiotics present in the batch
            losses = list()
            for j in ab_indices: 
                mask = ab_mask[:, j] # (batch_size,), indicates which samples contain the antibiotic masked
                # isolate the predictions and targets for the antibiotic
                ab_pred_logits = pred_logits[mask, j] # (num_masked_samples,)
                ab_targets = target_res[mask, j] # (num_masked_samples,)
                ab_loss = self.ab_criterions[j](ab_pred_logits, ab_targets)
                losses.append(ab_loss)
            loss = sum(losses) / len(losses) # average loss over antibiotics
            epoch_loss += loss
            reporting_loss += loss
            printing_loss += loss
            
            loss.backward() 
            self.optimizer.step() 
            if batch_index % self.report_every == 0:
                self._report_loss_results(batch_index, reporting_loss)
                reporting_loss = 0 
                
            if batch_index % self.print_progress_every == 0:
                time_elapsed = time.gmtime(time.time() - time_ref) 
                self._print_loss_summary(time_elapsed, batch_index, printing_loss) 
                printing_loss = 0  
        avg_epoch_loss = epoch_loss / self.num_batches
        return avg_epoch_loss 
    
    
    def early_stopping(self):
        if self.val_losses[-1] < self.best_val_loss:
            self.best_val_loss = self.val_losses[-1]
            self.best_epoch = self.current_epoch
            self.best_model_state = self.model.state_dict()
            self.early_stopping_counter = 0
            return False
        else:
            self.early_stopping_counter += 1
            return True if self.early_stopping_counter >= self.patience else False
        
            
    def evaluate(self, loader: DataLoader, ds_obj):
        self.model.eval()
        # prepare evaluation statistics dataframes
        ab_stats, iso_stats = self._init_eval_stats(ds_obj)
        with torch.no_grad(): 
            ## Antibiotic tracking ##
            ab_num = np.zeros((self.num_ab, 2)) # tracks the occurence for each antibiotic & resistance
            ab_num_preds = np.zeros_like(ab_num) # tracks the number of predictions for each antibiotic & resistance
            ab_num_correct = np.zeros_like(ab_num) # tracks the number of correct predictions for each antibiotic & resistance
            ## General tracking ##
            loss = 0
            for i, batch in enumerate(loader):                
                input, target_res, token_types, attn_mask = batch   
                 
                pred_logits = self.model(input, token_types, attn_mask) # get predictions for all antibiotics
                pred_res = torch.where(pred_logits > 0, torch.ones_like(pred_logits), torch.zeros_like(pred_logits)) # logits -> 0/1 (S/R)
                        
                ab_mask = target_res >= 0 # (batch_size, num_ab), True if antibiotic is masked, False otherwise
                iso_stats = self._update_pheno_stats(i, pred_res, target_res, ab_mask, iso_stats)
                
                ab_indices = ab_mask.any(dim=0).nonzero().squeeze(-1).tolist() # list of indices of antibiotics present in the batch
                losses = list()
                for j in ab_indices: 
                    mask = ab_mask[:, j] # (batch_size,)
                    
                    # isolate the predictions and targets for the antibiotic
                    ab_pred_logits = pred_logits[mask, j] # (num_masked_samples,)
                    ab_targets = target_res[mask, j] # (num_masked_samples,)
                    num_R = ab_targets.sum().item()
                    num_S = ab_targets.shape[0] - num_R
                    ab_num[j, :] += [num_S, num_R]
                    
                    ab_loss = self.ab_criterions[j](ab_pred_logits, ab_targets)
                    losses.append(ab_loss)
                    
                    ab_pred_res = pred_res[mask, j]
                    ab_num_correct[j, :] += self._get_num_correct(ab_pred_res, ab_targets)    
                    ab_num_preds[j, :] += self._get_num_preds(ab_pred_res)
                loss += sum(losses) / len(losses) # average loss over antibiotics
                    
        avg_loss = loss / len(loader)
        
        ab_stats = self._update_ab_eval_stats(ab_stats, ab_num, ab_num_preds, ab_num_correct)
        iso_stats = self._calculate_iso_stats(iso_stats)
        
        acc = iso_stats['num_correct'].sum() / iso_stats['num_masked'].sum()
        iso_acc = iso_stats['all_correct'].sum() / iso_stats.shape[0]

        results = {
            "loss": avg_loss, 
            "acc": acc,
            "iso_acc": iso_acc,
            "ab_stats": ab_stats,
            "iso_stats": iso_stats,
        }
        return results
            
    
    def _init_result_lists(self):
        self.losses = []
        self.val_losses = []
        self.val_accs = []
        self.val_iso_accs = []
        self.val_ab_stats = []
        self.val_iso_stats = []
        
        
    def _update_val_lists(self, results: dict):
        self.val_losses.append(results["loss"])
        self.val_accs.append(results["acc"])
        self.val_iso_accs.append(results["iso_acc"])
        self.val_ab_stats.append(results["ab_stats"])
        self.val_iso_stats.append(results["iso_stats"])
    
    
    def _init_eval_stats(self, ds_obj):
        ab_stats = pd.DataFrame(columns=[
            'antibiotic', 'num_tot', 'num_S', 'num_R', 'num_pred_S', 'num_pred_R', 
            'num_correct', 'num_correct_S', 'num_correct_R',
            'accuracy', 'sensitivity', 'specificity', 'precision', 'F1'
        ])
        ab_stats['antibiotic'] = self.antibiotics
        ab_stats['num_tot'], ab_stats['num_S'], ab_stats['num_R'] = 0, 0, 0
        ab_stats['num_pred_S'], ab_stats['num_pred_R'] = 0, 0
        ab_stats['num_correct'], ab_stats['num_correct_S'], ab_stats['num_correct_R'] = 0, 0, 0
        
        iso_stats = ds_obj.ds_MM.drop(columns=['genotypes', 'phenotypes'])
        iso_stats['num_masked'], iso_stats['num_masked_S'], iso_stats['num_masked_R'] = 0, 0, 0
        iso_stats['num_correct'], iso_stats['correct_S'], iso_stats['correct_R'] = 0, 0, 0
        iso_stats['sensitivity'], iso_stats['specificity'], iso_stats['accuracy'] = 0, 0, 0
        iso_stats['all_correct'] = False  
        return ab_stats, iso_stats
    
    
    def _update_ab_eval_stats(self, ab_stats: pd.DataFrame, num, num_preds, num_correct):
        for j in range(self.num_ab): 
            ab_stats.loc[j, 'num_tot'] = num[j, :].sum()
            ab_stats.loc[j, 'num_S'], ab_stats.loc[j, 'num_R'] = num[j, 0], num[j, 1]
            ab_stats.loc[j, 'num_pred_S'], ab_stats.loc[j, 'num_pred_R'] = num_preds[j, 0], num_preds[j, 1]
            ab_stats.loc[j, 'num_correct'] = num_correct[j, :].sum()
            ab_stats.loc[j, 'num_correct_S'], ab_stats.loc[j, 'num_correct_R'] = num_correct[j, 0], num_correct[j, 1]
        ab_stats['accuracy'] = ab_stats.apply(
            lambda row: row['num_correct']/row['num_tot'] if row['num_tot'] > 0 else np.nan, axis=1)
        ab_stats['sensitivity'] = ab_stats.apply(
            lambda row: row['num_correct_R']/row['num_R'] if row['num_R'] > 0 else np.nan, axis=1)
        ab_stats['specificity'] = ab_stats.apply(
            lambda row: row['num_correct_S']/row['num_S'] if row['num_S'] > 0 else np.nan, axis=1)
        ab_stats['precision'] = ab_stats.apply(
            lambda row: row['num_correct_R']/row['num_pred_R'] if row['num_pred_R'] > 0 else np.nan, axis=1)
        ab_stats['F1'] = ab_stats.apply(
            lambda row: 2*row['precision']*row['sensitivity']/(row['precision']+row['sensitivity']) 
            if row['precision'] > 0 and row['sensitivity'] > 0 else np.nan, axis=1)
        return ab_stats
    
    
    def _get_num_correct(self, pred_res: torch.Tensor, target_res: torch.Tensor):
        eq = torch.eq(pred_res, target_res)
        num_correct_S = eq[target_res == 0].sum().item()
        num_correct_R = eq[target_res == 1].sum().item()
        return [num_correct_S, num_correct_R]
    
    
    def _get_num_preds(self, pred_res: torch.Tensor):
        num_pred_S = (pred_res == 0).sum().item()
        num_pred_R = (pred_res == 1).sum().item()
        return [num_pred_S, num_pred_R]
    
    
    def _update_pheno_stats(self, batch_idx, pred_res: torch.Tensor, target_res: torch.Tensor, 
                          ab_mask: torch.Tensor, iso_stats: pd.DataFrame):
        for i in range(pred_res.shape[0]): 
            iso_ab_mask = ab_mask[i]
            df_idx = batch_idx * self.batch_size + i # index of the isolate in the combined dataset
            
            # counts
            num_masked_tot = iso_ab_mask.sum().item()
            num_masked_R = target_res[i][iso_ab_mask].sum().item()
            num_masked_S = num_masked_tot - num_masked_R
            
            # statistics            
            iso_target_res = target_res[i][iso_ab_mask]
            eq = torch.eq(pred_res[i][iso_ab_mask], iso_target_res)
            num_correct_R = eq[iso_target_res == 1].sum().item()
            num_correct_S = eq[iso_target_res == 0].sum().item()
            num_correct = num_correct_S + num_correct_R
            all_correct = eq.all().item()
            
            data = {
                'num_masked': num_masked_tot, 'num_masked_S': num_masked_S, 'num_masked_R': num_masked_R, 
                'num_correct': num_correct, 'correct_S': num_correct_S, 'correct_R': num_correct_R,
                'all_correct': all_correct
            }
            iso_stats.loc[df_idx, data.keys()] = data.values()
                          
        return iso_stats
    
    def _calculate_iso_stats(self, iso_stats: pd.DataFrame): 
        iso_stats['accuracy'] = iso_stats['num_correct'] / iso_stats['num_masked']
        iso_stats['sensitivity'] = iso_stats.apply(
            lambda row: row['correct_R']/row['num_masked_R'] if row['num_masked_R'] > 0 else np.nan, axis=1
        )
        iso_stats['specificity'] = iso_stats.apply(
            lambda row: row['correct_S']/row['num_masked_S'] if row['num_masked_S'] > 0 else np.nan, axis=1
        )
        
        return iso_stats
        
     
    def _init_wandb(self):
        print("Initializing wandb...")
        self.wandb_run = wandb.init(
            project=self.project_name, # name of the project
            name=self.wandb_name, # name of the run
            
            config={
                "epochs": self.epochs,
                "batch_size": self.batch_size,
                "hidden_dim": self.model.hidden_dim,
                "num_layers": self.model.num_layers,
                "num_heads": self.model.num_heads,
                "emb_dim": self.model.emb_dim,
                'ff_dim': self.model.ff_dim,
                "lr": self.lr,
                "weight_decay": self.weight_decay,
                "mask_prob": self.mask_prob,
                "num_known_ab": self.num_known_ab,
                "max_seq_len": self.model.max_seq_len,
                "vocab_size": len(self.vocab),
                "num_parameters": sum(p.numel() for p in self.model.parameters() if p.requires_grad),
                "num_antibiotics": self.num_ab,
                "antibiotics": self.antibiotics,
                "train_size": self.train_size,
                "random_state": self.random_state,
                'val_share': self.val_share,
                "val_size": self.val_size,
                "is_pretrained": self.model.is_pretrained,
                # "early_stopping_patience": self.patience,
                # "dropout_prob": self.model.dropout_prob,
            }
        )
        self.wandb_run.watch(self.model) # watch the model for gradients and parameters
        self.wandb_run.define_metric("epoch", hidden=True)
        self.wandb_run.define_metric("batch", hidden=True)
        
        self.wandb_run.define_metric("Losses/live_loss", step_metric="batch")
        self.wandb_run.define_metric("Losses/train_loss", summary="min", step_metric="epoch")
        self.wandb_run.define_metric("Losses/val_loss", summary="min", step_metric="epoch")
        self.wandb_run.define_metric("Accuracies/val_acc", summary="max", step_metric="epoch")
        self.wandb_run.define_metric("Accuracies/val_iso_acc", summary="max", step_metric="epoch")
        
        self.wandb_run.define_metric("Losses/final_val_loss")
        self.wandb_run.define_metric("Accuracies/final_val_acc")
        self.wandb_run.define_metric("Accuracies/final_val_iso_acc")
        
        self.wandb_run.define_metric("final_epoch")

        return self.wandb_run
     
    def _report_epoch_results(self):
        wandb_dict = {
            "epoch": self.current_epoch+1,
            "Losses/train_loss": self.losses[-1],
            "Losses/val_loss": self.val_losses[-1],
            # "Losses/val_geno_loss": self.val_geno_losses[-1],
            "Losses/val_loss": self.val_losses[-1],
            "Accuracies/val_acc": self.val_accs[-1],
            "Accuracies/val_iso_acc": self.val_iso_accs[-1],
            # "Accuracies/val_geno_acc": self.val_geno_accs[-1],
            # "Accuracies/val_geno_iso_acc": self.val_geno_iso_accs[-1],
        }
        self.wandb_run.log(wandb_dict)
    
        
    def _report_loss_results(self, batch_index, tot_loss):
        avg_loss = tot_loss / self.report_every
        
        global_step = self.current_epoch * self.num_batches + batch_index # global step, total #batches seen
        self.wandb_run.log({"batch": global_step, "Losses/live_loss": avg_loss})
    
        
    def _print_loss_summary(self, time_elapsed, batch_index, tot_loss):
        progress = batch_index / self.num_batches
        mlm_loss = tot_loss / self.print_progress_every
          
        s = f"{time.strftime('%H:%M:%S', time_elapsed)}" 
        s += f" | Epoch: {self.current_epoch+1}/{self.epochs} | {batch_index}/{self.num_batches} ({progress:.2%}) | "\
                f"Loss: {mlm_loss:.4f}"
        print(s)
    
    
    def save_model(self, savepath: Path):
        print(type(self.best_model_state))
        torch.save(self.best_model_state, savepath)
        print(f"Model saved to {savepath}")
        print("="*self._splitter_size)
        
        
    def load_model(self, savepath: Path):
        print("="*self._splitter_size)
        print(f"Loading model from {savepath}")
        self.model.load_state_dict(torch.load(savepath))
        self.model.to(device)
        print("Model loaded")
        print("="*self._splitter_size)
    

### Main

In [None]:
from utils import get_split_indices, export_results

config_ft = config['fine_tune']
print("\n Fine-tuning BERT on the multimodal dataset \n")
assert config_ft['ds_path'], "Please specify the path to the pre-processed NCBI dataset"
ds_NCBI = pd.read_pickle(config_ft['ds_path'])
ds_MM = ds_NCBI[ds_NCBI['num_ab'] > 0].reset_index(drop=True)
pad_token = specials['PAD']
ds_MM.fillna(pad_token, inplace=True)
pad_idx = vocab[pad_token]

os.environ['WANDB_MODE'] = config['wandb_mode']

train_indices, val_indices = get_split_indices(
    ds_MM.shape[0], 
    val_share=config_ft['val_share'], 
    random_state=config['random_state']
)
ds_ft_train = MMFinetuneDataset(
    df_MM=ds_MM.iloc[train_indices],
    vocab=vocab,
    antibiotics=antibiotics,
    specials=specials,
    max_seq_len=max_seq_len,
    mask_prob=config_ft['mask_prob'],
    num_known_ab=config_ft['num_known_ab'],
    random_state=config_ft['random_state']
)
ds_ft_val = MMFinetuneDataset(
    df_MM=ds_MM.iloc[val_indices],
    vocab=vocab,
    antibiotics=antibiotics,
    specials=specials,
    max_seq_len=max_seq_len,
    mask_prob=config_ft['mask_prob'],
    num_known_ab=config_ft['num_known_ab'],
    random_state=config_ft['random_state']
)
# set bert in pheno_only mode
bert = BERT(config, vocab_size, max_seq_len, len(antibiotics), pad_idx, pheno_only=True)
    
tuner = MMBertFineTuner(
    config=config,
    model=bert,
    antibiotics=antibiotics,
    train_set=ds_ft_train,
    val_set=ds_ft_val,
    results_dir=results_dir,
)
tuner.load_model(config_ft['model_path'])
tuner.print_model_summary()
tuner.print_trainer_summary()
ft_results = tuner()
export_results(ft_results, results_dir / 'pt_results.pkl')


 Fine-tuning BERT on the multimodal dataset 

Loading model from c:\Users\jespe\Documents\GitHub_local\ARFusion\results\MM\test_run_with_ft\model_state.pt
Model loaded
Model summary:
Is pre-trained: No
Embedding dim: 256
Feed-forward dim: 256
Hidden dim: 256
Number of heads: 4
Number of encoder layers: 6
Dropout probability: 10%
Max sequence length: 40
Vocab size: 1,546
Number of parameters: 3,176,202
Trainer summary:
Device: cuda (NVIDIA GeForce RTX 3080)
Training dataset size: 5,473
Batch size: 16
Number of batches: 342
Number of antibiotics: 18
Antibiotics: ['CIP', 'NAL', 'MFX', 'AMC', 'NOR', 'AMP', 'CTX', 'TOB', 'NET', 'CRO', 'FEP', 'CAZ', 'OFX', 'TZP', 'AMX', 'GEN', 'LVX', 'PIP']
CV split: 85% train | 15% val
Mask probability (phenotypes): 25%
Number of epochs: 100
Early stopping patience: 3
Learning rate: 1e-05
Weight decay: 0.01
Initializing wandb...
Epoch 1/100
Epoch completed in 0.1 min | Loss: 0.4290
Evaluating on validation set...
Val loss: 0.3226 | Accuracy 86.69% | Isolat

KeyboardInterrupt: 