# Multimodal Model Implementation

In [1]:
import os
import torch
import yaml
import wandb
import argparse
import pandas as pd
import time

from datetime import datetime
from pathlib import Path

# user-defined modules
from trainers import BertMLMTrainer

# user-defined functions
# from construct_vocab import construct_MM_vocab
from data_preprocessing import preprocess_NCBI, preprocess_TESSy
from utils import get_split_indices

# set the base directory
base_dir = Path(os.path.abspath('')).parent
os.chdir(base_dir)
print("base directory:", base_dir)

base directory: C:\Users\jespe\Documents\GitHub_local\ARFusion


In [2]:
with open(base_dir / 'config_MM.yaml') as f:
    config = yaml.safe_load(f)

## Dataset

### Prepare unimodal datasets

In [3]:
data_dict = config['data']
if data_dict['TESSy']['prepare_data']:
    ds_TESSy = preprocess_TESSy(
        path=data_dict['TESSy']['raw_path'],
        pathogens=data_dict['pathogens'],
        save_path=data_dict['TESSy']['save_path'],
        exclude_antibiotics=data_dict['exclude_antibiotics'],
        impute_age=data_dict['TESSy']['impute_age'],
        impute_gender=data_dict['TESSy']['impute_gender']
    )
else:
    ds_TESSy = pd.read_pickle(os.path.join(base_dir, data_dict['TESSy']['load_path']))
ds_TESSy.head()
num_TESSy = len(ds_TESSy)

In [4]:
ds_NCBI = preprocess_NCBI(
    path=data_dict['NCBI']['raw_path'],
    include_phenotype=data_dict['NCBI']['include_phenotype'],
    ab_names_to_abbr=data_dict['antibiotics']['ab_names_to_abbr'],
    exclude_antibiotics=data_dict['exclude_antibiotics'], 
    threshold_year=data_dict['NCBI']['threshold_year'],
    exclude_genotypes=data_dict['NCBI']['exclude_genotypes'],
    exclude_assembly_variants=data_dict['NCBI']['exclude_assembly_variants'],
    exclusion_chars=data_dict['NCBI']['exclusion_chars'],
    gene_count_threshold=data_dict['NCBI']['gene_count_threshold']
)
ds_NCBI[ds_NCBI['num_ab'] > 0].head()
num_NCBI = len(ds_NCBI)

Parsing phenotypes...
Parsing genotypes...
Number of isolates before parsing: 341,565
Removing 253 isolates with year < 1970
Removing genotypes with assembly variants: ['=PARTIAL', '=MISTRANSLATION', '=HMM']
Dropping 52 isolates with more than 35 genotypes
Number of isolates after parsing: 339,349
Number of isolates with phenotype info after parsing: 339,349


### Vocabulary construction

In [6]:
from itertools import chain
from collections import Counter
from torchtext.vocab import vocab as Vocab

def construct_MM_vocab(df_geno: pd.DataFrame,
                       df_pheno: pd.DataFrame,
                       antibiotics: list,
                       specials: dict,
                       savepath_vocab: Path = None):
    token_counter = Counter()
    ds_geno = df_geno.copy()
    ds_pheno = df_pheno.copy()
    
    CLS, PAD, MASK, UNK = specials['CLS'], specials['PAD'], specials['MASK'], specials['UNK']
    # SEP = specials['SEP']
    special_tokens = specials.values()
    
    year_geno = ds_geno[ds_geno['year'] != PAD]['year'].astype('Int16')
    min_year = min(year_geno.min(), ds_pheno['year'].min())
    max_year = max(year_geno.max(), ds_pheno['year'].max())
    year_range = range(min_year, max_year+1)
    token_counter.update([str(y) for y in year_range])
    
    min_age, max_age = ds_pheno['age'].min(), ds_pheno['age'].max()
    age_range = range(int(min_age), int(max_age+1))
    token_counter.update([str(a) for a in age_range])
    
    genders = ds_pheno['gender'].unique().astype(str).tolist()
    token_counter.update(genders)
    
    country_code_to_name = data_dict['TESSy']['country_code_to_name']
    pheno_countries = ds_pheno['country'].sort_values().map(country_code_to_name).unique()
    geno_countries = ds_geno['country'].sort_values().dropna().unique()
    countries = set(pheno_countries).union(set(geno_countries))
    token_counter.update(countries)
    print(list(countries))
    
    token_counter.update(list(chain(*ds_geno['genotypes'])))
    token_counter.update([ab + '_' + res for ab in antibiotics for res in ['R', 'S']])  
    print(token_counter)  
    
    vocab = Vocab(token_counter, specials=special_tokens)
    vocab.set_default_index(vocab[UNK])
    if savepath_vocab:
        torch.save(vocab, savepath_vocab)

antibiotics = set(data_dict['antibiotics']['abbr_to_names'].keys()) - set(data_dict['exclude_antibiotics'])
vocab = construct_MM_vocab(
    df_geno=ds_NCBI,
    df_pheno=ds_TESSy,
    antibiotics=antibiotics,
    specials=config['specials'],
)

['Finland', 'India', 'Niger', 'Portugal', 'Costa Rica', 'Kazakhstan', 'South Africa', 'Greece', 'Sweden', 'Puerto Rico', 'Iceland', 'Gambia', 'Switzerland', 'Taiwan', 'Senegal', 'Hong Kong', 'Japan', 'Thailand', 'Lithuania', 'Gaza Strip', 'Zambia', 'Russia', 'Croatia', 'Sri Lanka', 'Argentina', 'Benin', 'Antarctica', 'Slovakia', 'French Guiana', 'Singapore', 'Estonia', 'Jamaica', 'Papua New Guinea', 'DRC', 'Dominican Republic', 'Ukraine', 'Philippines', 'Georgia', 'Lebanon', 'Mozambique', 'Madagascar', 'Nepal', 'Chile', 'Burkina Faso', 'Bangladesh', 'Mexico', 'Guam', 'Algeria', 'Turkey', 'Greenland', 'Ethiopia', 'Norway', 'Armenia', 'Ecuador', 'New Zealand', 'Malawi', 'Uzbekistan', 'Paraguay', 'Peru', 'Somalia', 'Spain', 'Latvia', 'Canada', 'Cambodia', "Cote d'Ivoire", 'Burundi', 'Tonga', 'Tanzania', 'Oman', 'Kenya', 'Myanmar', 'Saudi Arabia', 'Kosovo', 'Romania', 'Tunisia', 'Gabon', 'Pacific Ocean', 'Guyana', 'Denmark', 'Czechia', 'Togo', 'China', 'Uganda', 'South Korea', 'Djibouti', 

### Create multimodal dataset

In [None]:
import numpy as np
import torch
import pandas as pd

from copy import deepcopy
from itertools import chain
from torch.utils.data import Dataset, DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class PhenotypeDataset(Dataset):      
    # df column names
    INDICES_MASKED = 'indices_masked' # input to BERT, token indices of the masked sequence
    TARGET_RESISTANCES = 'target_resistances' # resistance of the target antibiotics, what we want to predict
    TOKEN_MASK = 'token_mask' # True if token is masked, False otherwise
    AB_MASK = 'ab_mask' # True if antibiotic is masked, False otherwise
    # # if original text is included
    ORIGINAL_SEQUENCE = 'original_sequence'
    MASKED_SEQUENCE = 'masked_sequence'
    
    def __init__(self,
                 ds: pd.DataFrame,
                 vocab,
                 antibiotics: list,
                 specials: dict,
                 max_seq_len: int,
                 include_sequences: bool = False,
                 random_state: int = 42,
                 ):
        
        self.random_state = random_state
        np.random.seed(self.random_state)
        
        self.ds = ds.reset_index(drop=True) 
        self.original_ds = deepcopy(self.ds) 
        tot_pheno = self.ds['num_ab'].sum()
        tot_S = self.ds['num_S'].sum()
        tot_R = self.ds['num_R'].sum()
        print(f"Proportion of S/R {tot_S / tot_pheno:.1%}/{tot_R / tot_pheno:.1%}")
        self.num_samples = self.ds.shape[0]
        self.vocab = vocab
        self.antibiotics = antibiotics
        self.num_ab = len(self.antibiotics)
        self.ab_to_idx = {ab: i for i, ab in enumerate(self.antibiotics)}
        self.enc_res = {'S': 0, 'R': 1}
        self.vocab_size = len(self.vocab)
        self.CLS = specials['CLS']
        self.PAD = specials['PAD']
        self.MASK = specials['MASK']
        self.UNK = specials['UNK']
        self.special_tokens = specials.values()
        self.max_seq_len = max_seq_len
           
        self.include_sequences = include_sequences
        if self.include_sequences:
            self.columns = [self.INDICES_MASKED, self.TARGET_RESISTANCES, self.TOKEN_MASK, self.AB_MASK,
                            self.ORIGINAL_SEQUENCE, self.MASKED_SEQUENCE]
        else: 
            self.columns = [self.INDICES_MASKED, self.TARGET_RESISTANCES, self.TOKEN_MASK, self.AB_MASK]        
        
        
    def __len__(self):
        return self.num_samples
    
    
    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        
        input = torch.tensor(item[self.INDICES_MASKED], dtype=torch.long, device=device)
        target_res = torch.tensor(item[self.TARGET_RESISTANCES], dtype=torch.float32, device=device)
        token_mask = torch.tensor(item[self.TOKEN_MASK], dtype=torch.bool, device=device)
        ab_mask = torch.tensor(item[self.AB_MASK], dtype=torch.bool, device=device)
        attn_mask = (input != self.vocab[self.PAD]).unsqueeze(0).unsqueeze(1) # one dim for batch, one for heads
        
        if self.include_sequences:
            original_sequence = item[self.ORIGINAL_SEQUENCE]
            masked_sequence = item[self.MASKED_SEQUENCE]
            return input, target_res, token_mask, ab_mask, attn_mask, original_sequence, masked_sequence
        else:
            return input, target_res, token_mask, ab_mask, attn_mask

       
    def prepare_dataset(self,  mask_prob: float = None, num_known_ab: int = None): # will be called at the start of each epoch (dynamic masking)
        ## IT IS PROBABLY MORE EFFICIENT TO DO THIS IN THE PREPROCESSING STEP, GIVEN MASKING METHOD IS CONSTANT ACROSS EPOCHS
        if num_known_ab:
            print(f"Preparing dataset for masking with {num_known_ab} known antibiotics")
            self.num_known_ab = num_known_ab
            self.mask_prob = None
            self.ds = self.original_ds[self.original_ds['num_ab'] > self.num_known_ab].reset_index(drop=True)
            self.num_samples = self.ds.shape[0]
            print(f"Dropping {self.original_ds.shape[0] - self.num_samples} samples with less than {self.num_known_ab+1} antibiotics")
            tot_pheno = self.ds['num_ab'].sum()
            tot_S = self.ds['num_S'].sum()
            tot_R = self.ds['num_R'].sum()
            print(f"Now {self.num_samples} samples left, S/R proportion: {tot_S/tot_pheno:.1%}/{tot_R / tot_pheno:.1%}")
        else:
            print(f"Preparing dataset for masking with mask_prob = {mask_prob}")
            self.mask_prob = mask_prob
            self.num_known_ab = None
            
        sequences, masked_sequences, target_resistances, token_masks, ab_masks = self._construct_masked_sequences()
        indices_masked = [self.vocab.lookup_indices(masked_seq) for masked_seq in masked_sequences]
        
        if self.include_sequences:
            rows = zip(indices_masked, target_resistances, token_masks, ab_masks, sequences, masked_sequences)
        else:
            rows = zip(indices_masked, target_resistances, token_masks, ab_masks)
        self.df = pd.DataFrame(rows, columns=self.columns)

    
    def _encode_sequence(self, seq: list):
        dict = {ab: res for ab, res in [token.split('_') for token in seq]}
        indices = [self.ab_to_idx[ab] for ab in dict.keys()]
        resistances = [self.enc_res[res] for res in dict.values()]
        
        return indices, resistances
    
    
    def _construct_masked_sequences(self):  
        # RoBERTa: 80% -> [MASK], 10% -> original token, 10% -> random token
        
        sequences = deepcopy(self.ds['phenotypes'].tolist())
        masked_sequences = list()
        all_target_resistances = list()
        ab_masks = list() # will be applied to the output of the model, i.e. (batch_size, num_ab)
        token_masks = list() # will be applied to the the sequence itself, i.e. (batch_size, seq_len)
        for seq in deepcopy(sequences):
            seq_len = len(seq)
            
            token_mask = [False] * seq_len # indicates which tokens in the sequence are masked, includes all tokens
            ab_mask = [False] * self.num_ab # will indicate which antibiotics are masked, indexed in the order of self.antibiotics
            target_resistances = [-1]*self.num_ab # -1 indicates padding, will indicate the target resistance, same indexing as ab_mask
            if self.mask_prob:
                tokens_masked = 0
                for i in range(seq_len):
                    if np.random.rand() < self.mask_prob: 
                        ab, res = seq[i].split('_')
                        ab_idx = self.ab_to_idx[ab]
                        tokens_masked += 1
                        r = np.random.rand()
                        if r < 0.8: 
                            seq[i] = self.MASK
                        elif r < 0.9:
                            j = np.random.randint(self.vocab_size-self.num_ab*2, self.vocab_size) # select random pheno token
                            seq[i] = self.vocab.lookup_token(j)
                        # else: do nothing, since r > 0.9 and we keep the same token
                        token_mask[i] = True
                        ab_mask[ab_idx] = True # indicate which antibiotic is masked at this position
                        target_resistances[ab_idx] = self.enc_res[res] # the target resistance of the antibiotic
                if tokens_masked == 0: # mask at least one token
                    i = np.random.randint(seq_len)
                    ab, res = seq[i].split('_')
                    ab_idx = self.ab_to_idx[ab]
                    r = np.random.rand()
                    if r < 0.8: 
                        seq[i] = self.MASK
                    elif r < 0.9:
                        j = np.random.randint(self.vocab_size-self.num_ab*2, self.vocab_size) # select random token, excluding specials
                        seq[i] = self.vocab.lookup_token(j)
                    # else: do nothing, since r > 0.9 and we keep the same token
                    token_mask[i] = True
                    ab_mask[ab_idx] = True # indicate which antibiotic is masked at this position
                    target_resistances[ab_idx] = self.enc_res[res] # the target resistance of the antibiotic
            else:
                # randomly select seq_len - num_known_ab antibiotics to mask
                mask_indices = np.random.choice(seq_len, seq_len - self.num_known_ab, replace=False)
                for i in mask_indices: # implement ROBERTa masking
                    ab, res = seq[i].split('_')
                    ab_idx = self.ab_to_idx[ab]
                    seq[i] = self.MASK
                    token_mask[i] = True
                    ab_mask[ab_idx] = True
                    target_resistances[ab_idx] = self.enc_res[res]
            masked_sequences.append(seq)
            token_masks.append(token_mask)
            ab_masks.append(ab_mask)
            all_target_resistances.append(target_resistances)
        
        for i in range(len(sequences)):
            token_masks[i] = 5*[False] + token_masks[i]
            seq_start = [self.CLS, 
                         str(self.ds['year'].iloc[i]), 
                         self.ds['country'].iloc[i], 
                         self.ds['gender'].iloc[i], 
                         str(int(self.ds['age'].iloc[i]))]
            
            sequences[i][:0] = seq_start
            masked_sequences[i][:0] = seq_start
            
            seq_len = len(sequences[i])
            if seq_len < self.max_seq_len:
                sequences[i].extend([self.PAD] * (self.max_seq_len - seq_len))
                masked_sequences[i].extend([self.PAD] * (self.max_seq_len - seq_len))
                token_masks[i].extend([False] * (self.max_seq_len - seq_len))
            # the antibiotic-specific lists should always be of length num_ab
            pheno_len = len(all_target_resistances[i])
            all_target_resistances[i].extend([-1] * (self.num_ab - pheno_len))
            # ab_mask is defined with correct length
        return sequences, masked_sequences, all_target_resistances, token_masks, ab_masks  
    
    
    def reconstruct_sequence(self, seq_from_batch):
        tuple_len = len(seq_from_batch[0])
        sequences = list()
        for j in range(tuple_len):
            sequence = list()
            for i in range(self.max_seq_len):
                sequence.append(seq_from_batch[i][j])
            sequences.append(sequence)
        return sequence