# Results request - CAZ_R performance on TESSy

In [1]:
import torch
import yaml
import wandb
import argparse
import numpy as np
import pandas as pd
import os
import sys
from pathlib import Path
from datetime import datetime
from sklearn.model_selection import KFold
from torch.utils.data import Dataset, DataLoader

# Resolve the notebook's working directory once; it is used both as the
# import root for the project-local modules and as the process cwd.
BASE_DIR = Path(os.path.abspath(''))
os.chdir(BASE_DIR)
sys.path.append(os.fspath(BASE_DIR))

# user-defined modules
from multimodal.models import BERT
from multimodal.datasets import MMFinetuneDataset
from multimodal.trainers import MMBertFineTuner

# user-defined functions
from utils import get_split_indices, export_results, get_average_and_std_df

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
torch.cuda.empty_cache()

In [2]:
# Read the experiment configuration. The context manager guarantees the file
# handle is closed even if YAML parsing raises.
with open("config_MM.yaml") as config_file:
    config = yaml.safe_load(config_file)
data_config = config['data']

# Antibiotics known to the model: every configured abbreviation minus the
# explicitly excluded ones, sorted so that the index mapping is deterministic.
defined_antibiotics = sorted(
    set(data_config['antibiotics']['abbr_to_names'].keys())
    - set(data_config['exclude_antibiotics'])
)
ab_to_idx = {ab: idx for idx, ab in enumerate(defined_antibiotics)}

# Special tokens used when building the input sequences.
specials = config['specials']
cls_token, pad_token, mask_token = specials['CLS'], specials['PAD'], specials['MASK']

# Fixed (padded) sequence length fed to the model.
# NOTE(review): presumably this must match the value used during fine-tuning - confirm.
max_seq_len = 56

## Load and prepare dataset for inference

In [3]:
# The configured path (data_config['TESSy']['load_path']) is bypassed here in
# favour of a fixed pickle file.
# ds_path = data_config['TESSy']['load_path']
ds_path = 'data/TESSy_15_all_pathogens.pkl'
ds_TESSy = pd.read_pickle(ds_path)
print(f"Total number of samples in TESSy: {ds_TESSy.shape[0]:,}")

Total number of samples in TESSy: 3,303,501


Isolate samples with CAZ_R

In [4]:
# Antibiotics whose phenotypes are kept in each sample.
antibiotics = ['CAZ', 'CIP', 'AMP', 'GEN']
CAZ_idx = ab_to_idx['CAZ']

# Keep only the isolates whose phenotype list contains 'CAZ_R'.
has_caz_r = ds_TESSy['phenotypes'].apply(lambda phens: 'CAZ_R' in phens)
ds_CAZ = ds_TESSy[has_caz_r].reset_index(drop=True)

# Restrict every phenotype list to the selected antibiotics and record its length.
ds_CAZ['phenotypes'] = ds_CAZ['phenotypes'].apply(
    lambda phens: [phen for phen in phens if phen.partition('_')[0] in antibiotics]
)
ds_CAZ['num_ab'] = ds_CAZ['phenotypes'].apply(len)

# Translate country codes through the mapping stored in the config.
ds_CAZ['country'] = ds_CAZ['country'].map(config['data']['TESSy']['country_code_to_name'])

# Drop the resistance/susceptibility counters, then shuffle reproducibly.
ds_CAZ = ds_CAZ.drop(columns=['num_R', 'num_S'])
ds_CAZ = ds_CAZ.sample(frac=1, random_state=config['random_state']).reset_index(drop=True)
print(f"Number of samples with CAZ_R phenotype: {len(ds_CAZ):,}")


Number of samples with CAZ_R phenotype: 191,477


Prepare dataset

In [5]:
# Load the vocabulary object saved at the path given by
# config['fine_tuning']['loadpath_vocab'] (resolved relative to BASE_DIR).
# NOTE(review): torch.load unpickles arbitrary objects - only load files from a trusted source.
vocab = torch.load(BASE_DIR / config['fine_tuning']['loadpath_vocab'])

class MMInferenceDataset(Dataset):
    """Map-style dataset producing masked-CAZ inference inputs.

    Each sample is turned into the token sequence
    ``[CLS, year, country, gender, age, MASK, ...]`` where one MASK token is
    emitted per CAZ phenotype of the sample (all other phenotypes are left
    out), padded with PAD up to ``max_seq_len``.

    This is a ``Dataset`` (not a ``DataLoader``): it is meant to be wrapped by
    ``torch.utils.data.DataLoader`` for batching.

    Args:
        ds: DataFrame with the columns 'phenotypes', 'year', 'country',
            'gender' and 'age'.
        vocab: object offering ``lookup_indices(tokens)`` and ``vocab[token]``.
        defined_antibiotics: accepted for interface compatibility; not used.
        max_seq_len: padded length of every sequence.
        specials: mapping with the keys 'CLS', 'PAD' and 'MASK'.
    """

    def __init__(self, ds, vocab, defined_antibiotics, max_seq_len, specials):
        self.ds = ds
        self.vocab = vocab
        self.max_seq_len = max_seq_len
        self.specials = specials
        self.CLS, self.PAD, self.MASK = specials['CLS'], specials['PAD'], specials['MASK']
        # Tensors are created directly on the notebook-level device.
        self.device = device

        self.phenotypes = self.ds['phenotypes'].tolist()
        self.year_col = self.ds['year'].astype(str).tolist()
        self.country_col = self.ds['country'].tolist()
        self.gender_col = self.ds['gender'].tolist()
        self.age_col = self.ds['age'].astype(int).astype(str).tolist()

        self.columns = ['indices_masked', 'token_types', 'attn_mask', 'masked_sequences']
        # Filled by prepare_dataset(); None means "not prepared yet".
        self.df = None

    def prepare_dataset(self):
        """Build ``self.df`` holding the encoded, padded inputs of every sample.

        Only instance state is used (no notebook globals), so the result always
        reflects the arguments given to the constructor.
        """
        rows = []
        samples = zip(self.year_col, self.country_col, self.gender_col,
                      self.age_col, self.phenotypes)
        for year, country, gender, age, phen_list in samples:
            # One MASK per CAZ phenotype; every other phenotype is dropped.
            num_masked = sum(1 for p in phen_list if p.split('_')[0] == 'CAZ')
            sequence = [self.CLS, year, country, gender, age] + [self.MASK] * num_masked

            # Type 0 for the five leading tokens, type 2 for everything after
            # them (padding positions included).
            token_types = [0] * 5 + [2] * (len(sequence) - 5)
            token_types = token_types + [2] * (self.max_seq_len - len(token_types))

            sequence = sequence + [self.PAD] * (self.max_seq_len - len(sequence))
            indices = self.vocab.lookup_indices(sequence)
            attn_mask = [token != self.PAD for token in sequence]

            rows.append((indices, token_types, attn_mask, sequence))

        self.df = pd.DataFrame(rows, columns=self.columns)

    def __len__(self):
        """Number of samples in the underlying frame."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return ``(input, token_types, attn_mask)`` tensors for sample ``idx``.

        ``attn_mask`` is True where the token is not PAD and has two leading
        singleton dimensions added. The dataset is prepared lazily if
        ``prepare_dataset()`` has not been called yet.
        """
        if getattr(self, 'df', None) is None:
            self.prepare_dataset()
        item = self.df.iloc[idx]

        token_ids = torch.tensor(item['indices_masked'], dtype=torch.long, device=self.device)
        token_types = torch.tensor(item['token_types'], dtype=torch.long, device=self.device)
        attn_mask = (token_ids != self.vocab[self.PAD]).unsqueeze(0).unsqueeze(1)

        return token_ids, token_types, attn_mask
    
# Run inference on the first `num_samples` rows of the shuffled CAZ_R frame.
num_samples = 50000
ds_inference = MMInferenceDataset(
    ds=ds_CAZ.iloc[:num_samples],
    vocab=vocab,
    defined_antibiotics=defined_antibiotics,
    max_seq_len=max_seq_len,
    specials=specials,
)
# Fixed order (no shuffling) so that batches are reproducible.
inference_loader = DataLoader(dataset=ds_inference, batch_size=512, shuffle=False)

## Load vocab & fine-tuned model

In [14]:
vocab_size = len(vocab)
num_ab = 15 # from fine-tuning
model_type = 'HardCPT'
# model_path = 'results/MM/Anna_Erik_request/FT_'+model_type+'_0.75_0.75_AE-request/best_model_state.pt'
model_path = 'results/MM/FT_test/best_model_state.pt'
# model_path = 'model_state.pt'

bert = BERT(
    config,
    vocab_size=vocab_size,
    max_seq_len=max_seq_len,
    num_ab=num_ab,
    pad_idx=vocab[pad_token],
    pheno_only=True
).to(device)
print("Randomly initialized model:")
print(bert.classification_layer[CAZ_idx].state_dict()['classifier.0.weight'])

# Read the checkpoint once and remap its tensors onto the active device, so the
# notebook also works on a CPU-only machine with a checkpoint saved from GPU.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted checkpoints.
saved_state = torch.load(model_path, map_location=device)
print("Saved model:")
print(saved_state['ab_predictors'][CAZ_idx]['classifier.0.weight'])

# Sanity check: after loading, the CAZ classifier weights printed below should
# match the "Saved model" weights rather than the random initialisation.
bert.set_state_dict(saved_state)
print("After loading the model:")
print(bert.classification_layer[CAZ_idx].state_dict()['classifier.0.weight'])

Randomly initialized model:
tensor([[ 0.0306, -0.0027,  0.0307,  ...,  0.0351, -0.0173, -0.0269],
        [ 0.0190, -0.0158,  0.0091,  ..., -0.0002,  0.0177,  0.0011],
        [-0.0013, -0.0269,  0.0433,  ..., -0.0033, -0.0327, -0.0112],
        ...,
        [ 0.0253, -0.0400,  0.0114,  ...,  0.0205,  0.0162,  0.0374],
        [-0.0391, -0.0191,  0.0089,  ...,  0.0239,  0.0140,  0.0355],
        [ 0.0343, -0.0301,  0.0399,  ..., -0.0383,  0.0250, -0.0319]],
       device='cuda:0')
Saved model:
tensor([[ 0.0296, -0.0004,  0.0302,  ...,  0.0383, -0.0157, -0.0272],
        [ 0.0175, -0.0161,  0.0082,  ..., -0.0029,  0.0134, -0.0006],
        [-0.0056, -0.0305,  0.0472,  ...,  0.0040, -0.0290, -0.0133],
        ...,
        [ 0.0236, -0.0335,  0.0165,  ...,  0.0200,  0.0181,  0.0316],
        [-0.0369, -0.0191,  0.0094,  ...,  0.0244,  0.0161,  0.0328],
        [ 0.0334, -0.0300,  0.0410,  ..., -0.0350,  0.0225, -0.0344]],
       device='cuda:0')
After loading the model:
tensor([[ 0.0296, 

In [7]:
# Dump every trainable parameter of the loaded model (name followed by values).
for param_name, tensor in bert.named_parameters():
    if not tensor.requires_grad:
        continue
    print(param_name, tensor.data)

embedding.token_emb.weight tensor([[ 1.9216e+00,  1.4828e+00,  8.9941e-01,  ...,  4.2293e-01,
         -3.3384e-01,  5.2018e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-1.2197e+00,  9.5632e-01, -1.5797e+00,  ...,  2.2140e-01,
          3.2251e-01,  1.3186e+00],
        ...,
        [-3.1087e-01, -9.1666e-01,  7.0923e-01,  ...,  2.4328e-01,
         -3.0090e-01,  1.2158e-03],
        [-1.7630e-01,  1.5282e+00,  1.4459e-01,  ...,  1.0812e-01,
          3.6922e-01,  1.0254e+00],
        [-8.9638e-01,  5.5180e-01, -6.0085e-01,  ..., -3.6476e-01,
         -1.2499e+00, -9.0787e-01]], device='cuda:0')
embedding.token_type_emb.weight tensor([[-1.5926,  1.8451, -0.7605,  ..., -2.2119,  0.8873, -0.2940],
        [-0.1488,  0.6961,  1.3563,  ..., -2.3248, -0.2301, -0.6946],
        [-0.9833,  0.9448,  1.9257,  ..., -0.9350,  1.5616, -0.2012]],
       device='cuda:0')
embedding.layer_norm.weight tensor([0.9974, 0.9943, 0.991

## Create an inference *evaluator*

## Main process

In [8]:
bert.eval()
ds_inference.prepare_dataset()
print("CAZ idx:", CAZ_idx)
print("Number of samples in CAZ dataset:", len(ds_CAZ))
print("Number of batches in inference loader:", len(inference_loader))

# Every sample in ds_CAZ carries the CAZ_R phenotype, so the share of samples
# predicted resistant for CAZ is reported below as the "accuracy".
tot_num_R_pred = 0
# Inference only: disable autograd so no graph is built for each batch.
with torch.no_grad():
    for token_ids, token_types, attn_mask in inference_loader:
        pred_logits = bert(token_ids, token_types, attn_mask)
        # A positive logit is read as a "resistant" prediction.
        pred_res = torch.where(pred_logits > 0, torch.ones_like(pred_logits), torch.zeros_like(pred_logits))
        num_R_pred = pred_res[:, CAZ_idx].sum().item()
        tot_num_R_pred += num_R_pred
        print(f"Number of predicted CAZ_R samples: {num_R_pred}")
        print(f"Accuracy: {num_R_pred/pred_res.shape[0]:.4f}")
# Divide by the number of samples actually evaluated, which can be smaller than
# num_samples when the source frame has fewer rows.
print(f"Total accuracy: {tot_num_R_pred/len(ds_inference):.4f}")

CAZ idx: 2
Number of samples in CAZ dataset: 191477
Number of batches in inference loader: 98
Number of predicted CAZ_R samples: 359.0
Accuracy: 0.7012
Number of predicted CAZ_R samples: 353.0
Accuracy: 0.6895
Number of predicted CAZ_R samples: 346.0
Accuracy: 0.6758
Number of predicted CAZ_R samples: 360.0
Accuracy: 0.7031
Number of predicted CAZ_R samples: 362.0
Accuracy: 0.7070
Number of predicted CAZ_R samples: 348.0
Accuracy: 0.6797
Number of predicted CAZ_R samples: 353.0
Accuracy: 0.6895
Number of predicted CAZ_R samples: 341.0
Accuracy: 0.6660
Number of predicted CAZ_R samples: 357.0
Accuracy: 0.6973
Number of predicted CAZ_R samples: 349.0
Accuracy: 0.6816
Number of predicted CAZ_R samples: 345.0
Accuracy: 0.6738
Number of predicted CAZ_R samples: 348.0
Accuracy: 0.6797
Number of predicted CAZ_R samples: 347.0
Accuracy: 0.6777
Number of predicted CAZ_R samples: 346.0
Accuracy: 0.6758
Number of predicted CAZ_R samples: 356.0
Accuracy: 0.6953
Number of predicted CAZ_R samples: 3