# Results request - CAZ_R performance on TESSy

In [1]:
import torch
import yaml
import wandb
import argparse
import numpy as np
import pandas as pd
import os
import sys
from pathlib import Path
from datetime import datetime
from sklearn.model_selection import KFold
from torch.utils.data import Dataset, DataLoader

BASE_DIR = Path(os.path.abspath(''))
sys.path.append(str(BASE_DIR))
os.chdir(BASE_DIR)

# user-defined modules
from multimodal.models import BERT
from multimodal.datasets import MMFinetuneDataset
from multimodal.trainers import MMBertFineTuner

# user-defined functions
from utils import get_split_indices, export_results, get_average_and_std_df

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
config = yaml.safe_load(open("config_MM.yaml"))
data_config = config['data']
defined_antibiotics = sorted(list(set(data_config['antibiotics']['abbr_to_names'].keys()) - set(data_config['exclude_antibiotics'])))
ab_to_idx = {ab: idx for idx, ab in enumerate(defined_antibiotics)}
specials = config['specials']
cls_token, pad_token, mask_token = specials['CLS'], specials['PAD'], specials['MASK']
max_seq_len = 56

## Load and prepare dataset for inference

In [3]:
# ds_path = data_config['TESSy']['load_path']
ds_path = 'data/TESSy_15_all_pathogens.pkl'
ds_TESSy = pd.read_pickle(ds_path)
print(f"Total number of samples in TESSy: {len(ds_TESSy):,}")

Total number of samples in TESSy: 3,303,501


Isolate samples with CAZ_R

In [4]:
antibiotics = ['CAZ', 'CIP', 'AMP', 'GEN']
CAZ_idx = ab_to_idx['CAZ']
ds_CAZ = ds_TESSy[ds_TESSy['phenotypes'].apply(lambda x: 'CAZ_R' in x)].reset_index(drop=True)
ds_CAZ['phenotypes'] = ds_CAZ['phenotypes'].apply(lambda x: [p for p in x if p.split('_')[0] in antibiotics])
ds_CAZ['num_ab'] = ds_CAZ['phenotypes'].apply(len)
ds_CAZ['country'] = ds_CAZ['country'].map(config['data']['TESSy']['country_code_to_name'])
ds_CAZ.drop(columns=['num_R', 'num_S'], inplace=True)
ds_CAZ = ds_CAZ.sample(frac=1, random_state=config['random_state']).reset_index(drop=True)
print(f"Number of samples with CAZ_R phenotype: {len(ds_CAZ):,}")


Number of samples with CAZ_R phenotype: 191,477


Prepare dataset

In [5]:
vocab = torch.load(BASE_DIR / config['fine_tuning']['loadpath_vocab'])

class MMInferenceDataset(DataLoader):
    
    def __init__(self, ds, vocab, defined_antibiotics, max_seq_len, specials):
        self.ds = ds
        self.vocab = vocab
        self.max_seq_len = max_seq_len
        self.specials = specials
        self.CLS, self.PAD, self.MASK = specials['CLS'], specials['PAD'], specials['MASK']
        self.device = device
        
        self.phenotypes = self.ds['phenotypes'].tolist()
        self.year_col = self.ds['year'].astype(str).tolist()
        self.country_col = self.ds['country'].tolist()
        self.gender_col = self.ds['gender'].tolist()
        self.age_col = self.ds['age'].astype(int).astype(str).tolist()
        
        self.columns = ['indices_masked', 'token_types', 'attn_mask', 'masked_sequences']
        
    def prepare_dataset(self):
        masked_phenotypes = []
        for phen_list in self.phenotypes:
            masked_phen_list = []
            for p in phen_list:
                if p.split('_')[0] != 'CAZ':
                    # masked_phen_list.append(p)
                    pass
                else:
                    masked_phen_list.append(specials['MASK'])
            masked_phenotypes.append(masked_phen_list)

        masked_sequences = [[specials['CLS'], self.year_col[i], self.country_col[i], self.gender_col[i], self.age_col[i]] + masked_phenotypes[i] for i in range(len(self.ds))]
        token_types = [[0]*5 + [2]*(len(masked_sequences[i])-5) for i in range(len(self.ds))]
        masked_sequences = [seq + [pad_token]*(max_seq_len-len(seq)) for seq in masked_sequences]
        indices_masked = [vocab.lookup_indices(masked_seq) for masked_seq in masked_sequences]
        token_types = [tt + [2]*(max_seq_len-len(tt)) for tt in token_types]
        attn_mask = [[False if token == pad_token else True for token in seq] for seq in masked_sequences]
        
        rows = zip(indices_masked, token_types, attn_mask, masked_sequences)
        self.df = pd.DataFrame(rows, columns=self.columns)
    
    def __len__(self):
        return len(self.ds)
    
    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        
        input = torch.tensor(item['indices_masked'], dtype=torch.long, device=self.device)
        token_types = torch.tensor(item['token_types'], dtype=torch.long, device=self.device)
        masked_sequences = item['masked_sequences']
        attn_mask = (input != self.vocab[self.PAD]).unsqueeze(0).unsqueeze(1)
        
        return input, token_types, attn_mask
    
num_samples = 50000
ds_inference = MMInferenceDataset(ds_CAZ.iloc[:num_samples], vocab, defined_antibiotics, max_seq_len, specials)
inference_loader = DataLoader(ds_inference, batch_size=512, shuffle=False)

## Load vocab & fine-tuned model

In [6]:
vocab_size = len(vocab)
num_ab = 15 # from fine-tuning
model_type = 'HardCPT'
model_path = 'results/MM/Anna_Erik_request/FT_'+model_type+'_0.75_0.75_AE-request/best_model_state.pt'

bert = BERT(
    config,
    vocab_size=vocab_size,
    max_seq_len=max_seq_len,
    num_ab=num_ab,
    pad_idx=vocab[pad_token],
    pheno_only=True
).to(device)
bert.set_state_dict(torch.load(model_path))
bert.eval()

BERT(
  (embedding): JointEmbedding(
    (token_emb): Embedding(1570, 512, padding_idx=1)
    (token_type_emb): Embedding(3, 512)
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (encoder): ModuleList(
    (0-5): 6 x EncoderLayer(
      (attention): MultiHeadAttention(
        (q): Linear(in_features=512, out_features=512, bias=True)
        (k): Linear(in_features=512, out_features=512, bias=True)
        (v): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (dense): Linear(in_features=512, out_features=512, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (feed_forward): Sequential(
        (0): Linear(in_features=512, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=512, bias=True)
        

## Create an inference *evaluator*

## Main process

In [8]:
torch.cuda.empty_cache()
ds_inference.prepare_dataset()
print("CAZ idx:", CAZ_idx)
print("Number of samples in CAZ dataset:", len(ds_CAZ))
print("Number of batches in inference loader:", len(inference_loader))
tot_num_R_pred = 0
for input, token_types, attn_mask in inference_loader:
    pred_logits = bert(input, token_types, attn_mask)
    pred_res = torch.where(pred_logits > 0, torch.ones_like(pred_logits), torch.zeros_like(pred_logits))    
    num_R_pred = pred_res[:, CAZ_idx].sum().item()
    tot_num_R_pred += num_R_pred
    print(f"Number of predicted CAZ_R samples: {num_R_pred}")
    print(f"Accuracy: {num_R_pred/pred_res.shape[0]:.4f}")
print(f"Total accuracy: {tot_num_R_pred/num_samples:.4f}")

CAZ idx: 2
Number of samples in CAZ dataset: 191477
Number of batches in inference loader: 98
Number of predicted CAZ_R samples: 205.0
Accuracy: 0.4004
Number of predicted CAZ_R samples: 184.0
Accuracy: 0.3594
Number of predicted CAZ_R samples: 205.0
Accuracy: 0.4004
Number of predicted CAZ_R samples: 201.0
Accuracy: 0.3926
Number of predicted CAZ_R samples: 188.0
Accuracy: 0.3672
Number of predicted CAZ_R samples: 197.0
Accuracy: 0.3848
Number of predicted CAZ_R samples: 215.0
Accuracy: 0.4199
Number of predicted CAZ_R samples: 210.0
Accuracy: 0.4102
Number of predicted CAZ_R samples: 192.0
Accuracy: 0.3750
Number of predicted CAZ_R samples: 199.0
Accuracy: 0.3887
Number of predicted CAZ_R samples: 208.0
Accuracy: 0.4062
Number of predicted CAZ_R samples: 211.0
Accuracy: 0.4121
Number of predicted CAZ_R samples: 213.0
Accuracy: 0.4160
Number of predicted CAZ_R samples: 203.0
Accuracy: 0.3965
Number of predicted CAZ_R samples: 212.0
Accuracy: 0.4141
Number of predicted CAZ_R samples: 1

KeyboardInterrupt: 