# MultimodalAMR

In [1]:
import copy
import numpy as np
import pandas as pd
import sys
sys.path.insert(0,'../../')
import torch
from torch import nn
import torchmetrics.classification
from torch import cuda
from torch.utils.data import DataLoader
from multimodal_amr.models.modules import ResMLP
from maldi2resistance.data.ms_data import MS_Data
assert cuda.is_available()
assert cuda.device_count() > 0
print(cuda.get_device_name(cuda.current_device()))

DEVICE = torch.device("cuda")
SEED = 42
torch.manual_seed(SEED)

save_folder = "multimodalAMR_5cv-DRIAMS-A"

NVIDIA RTX 2000 Ada Generation Laptop GPU


### Load the Dataset

UMG_antibiotics = ['Ampicillin',
 'Cefotaxim',
 'Ceftazidime',
 'Ceftriaxone',
 'Ciprofloxacin',
 'Clindamycin',
 'Cotrimoxazole',
 'Erythromycin',
 'Fosfomycin',
 'Gentamicin',
 'Imipenem',
 'Levofloxacin',
 'Meropenem',
 'Moxifloxacin',
 'Oxacillin',
 'PenicillinG',
 'Piperacillin-Tazobactam',
 'Tetracycline',
 'Vancomycin']

In [2]:
ms_data = MS_Data(
    root_dir="/home/youngjunpark/Data/MS_data",
    sites=["DRIAMS-A"],
    years=[2015,2016,2017,2018],
    bin_size=1,
    #antibiotics=UMG_antibiotics,
)
ms_data.loading_type = "memory"
ms_data

Loading Spectra into Memory:   0%|          | 0/38331 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin-Amoxicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Rifampicin,Teicoplanin,Tetracycline,Tobramycin
Number resistant:,975,9920,21966,4223,6518,2338,2455,7299,4475,7462,3637,2850,4872,288,4529,1326,3413,2303,6872,3973,5145,412,871,4641,9881,6546,542,226,3082,1695
Number susceptible:,16247,15308,4905,5813,21958,4382,14937,19246,6103,23081,7975,15483,21768,14465,6550,4803,7224,8276,22519,16811,24386,1696,5234,6344,3525,21852,10424,7465,6836,16495
Number data points:,17222,25228,26871,10036,28476,6720,17392,26545,10578,30543,11612,18333,26640,14753,11079,6129,10637,10579,29391,20784,29531,2108,6105,10985,13406,28398,10966,7691,9918,18190


In [3]:
len(ms_data.label_stats.columns)

30

In [4]:
class Residual_AMR_Classifier(nn.Module):
    """Overall model definition."""

    def __init__(self, config, n_input_spectrum, n_input_drug):
        super().__init__()
        self.config = config

        # Sample embedding
        if config["species_embedding_dim"]==0 and config["conv_out_size"]==config["sample_embedding_dim"]:
            self.sample_emb = nn.Identity()
        elif config["species_embedding_dim"]==0:
            self.sample_emb = nn.Linear(
                config["conv_out_size"],
                config["sample_embedding_dim"],
            )
        else:
            self.sample_emb = nn.Linear(
                config["conv_out_size"] + config["species_embedding_dim"],
                config["sample_embedding_dim"],
            )

        self.spectrum_emb = nn.Linear(n_input_spectrum ,config["conv_out_size"])

        # Drugs layers
        # if config["drug_emb_type"] == "vae_embedding" or config["drug_emb_type"] == "gnn_embedding":
        #     self.drug_emb = nn.Identity()
        # elif config["drug_emb_type"] == "fingerprint":
        self.drug_emb = nn.Linear(
            n_input_drug, config["drug_embedding_dim"]
        )

        # Output network
        self.net = ResMLP(
            config["n_hidden_layers"],
            config["sample_embedding_dim"] + config["drug_embedding_dim"],
            1,
            p_dropout=0.2,
        )

    def forward(self, spectrum, fingerprint):
        spectrum_embedding = self.spectrum_emb(spectrum)
        dr_emb = self.drug_emb(fingerprint)

        return self.net(torch.cat([dr_emb, spectrum_embedding], dim=1))




conf = {
    "conv_out_size": 512,
    "species_embedding_dim":0,
    "sample_embedding_dim":512,
    "drug_embedding_dim":512,
    "n_hidden_layers": 5,
}

In [5]:
model = Residual_AMR_Classifier(config= conf,n_input_spectrum = ms_data.n_bins, n_input_drug= 1024)
model = model.to(DEVICE)

model_state = copy.deepcopy(model.state_dict()) 

In [6]:
from torchinfo import summary

print(summary(model))

Layer (type:depth-idx)                        Param #
Residual_AMR_Classifier                       --
├─Identity: 1-1                               --
├─Linear: 1-2                                 9,216,512
├─Linear: 1-3                                 524,800
├─ResMLP: 1-4                                 --
│    └─Sequential: 2-1                        --
│    │    └─ResBlock: 3-1                     1,051,648
│    │    └─ResBlock: 3-2                     1,051,648
│    │    └─ResBlock: 3-3                     1,051,648
│    │    └─ResBlock: 3-4                     1,051,648
│    │    └─ResBlock: 3-5                     1,051,648
│    │    └─Linear: 3-6                       1,025
Total params: 15,000,577
Trainable params: 15,000,577
Non-trainable params: 0


In [7]:
from torchmetrics.utilities.compute import auc
from torchmetrics.classification import BinaryAUROC, BinaryPrecisionRecallCurve
from maldi2resistance.data.ms_data import MS_Data_SingleAntibiotic
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

print("Start training ...")
model.train()

batch_size = 64
fig_path = Path(f"./{save_folder}/figures")
fig_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(f"./{save_folder}/csv")
csv_path.mkdir(parents=True, exist_ok=True)

loss_per_batch = []

gen = torch.Generator()

all_auc_roc_micro = []
all_auc_pr_micro = []

all_auc_roc_macro = []
all_auc_pr_macro = []

for fold, (train_data, test_data) in enumerate(ms_data.getK_fold(n_splits=5, shuffle=True, random_state= SEED)):
    train_dataset_single_antibiotic = MS_Data_SingleAntibiotic(ms_data=train_data, use_morganFingerprint4Drug=True)
    test_dataset_single_antibiotic = MS_Data_SingleAntibiotic(ms_data=test_data, use_morganFingerprint4Drug=True)
    
    train_loader = DataLoader(train_dataset_single_antibiotic, batch_size=batch_size, shuffle=True, drop_last=True, generator= gen.manual_seed(SEED))
    test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=len(test_dataset_single_antibiotic), shuffle=True, drop_last=True, generator= gen.manual_seed(SEED))
    
    model.load_state_dict(model_state)
    model.train()
    
    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad = True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    for epoch in tqdm(range(30), leave= False, position=1):
        overall_loss = 0
        
        for batch_idx, (x, y, pos) in enumerate(train_loader):
    
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            pos = pos.to(DEVICE)
            
            optimizer.zero_grad()
    
            output = model(x, pos)
        
            loss = F.binary_cross_entropy_with_logits(torch.squeeze(output), y)
            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)
            
            overall_loss += current_loss_value
            
            loss.backward()
            optimizer.step()
    
        scheduler.step()
        with tqdm.external_write_mode():
            print(f"\tAverage Loss: {overall_loss / (batch_idx*batch_size):.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")
        
    print(f"Finished Fold {fold}")
    
    model.eval()
    
    DEVICE = torch.device("cpu")
    test_features, test_labels, test_pos = next(iter(test_loader))
    test_features = test_features.to(DEVICE)
    test_labels = test_labels.to(DEVICE)
    test_pos = test_pos.to(DEVICE)
    model = model.to(DEVICE)
    
    output = model(test_features, test_pos)
    output = torch.squeeze(output)
    
    auRoc = BinaryAUROC()
    auc_roc = auRoc(output, test_labels.int())
    all_auc_roc_micro.append(auc_roc)
    
    metric = BinaryPrecisionRecallCurve()
    metric.update(output, test_labels.int())
    precision, recall, thresholds = metric.compute()
    aucPC = auc(recall, precision)
    
    all_auc_pr_micro.append(aucPC)
    
    macro_aucroc = 0
    macro_aucpr = 0
    
    n_not_empty = 0
    
    test_dataset_single_antibiotic = MS_Data_SingleAntibiotic(ms_data=test_data, use_morganFingerprint4Drug= False)
    test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=len(test_dataset_single_antibiotic), shuffle=True)
    _, _, test_pos = next(iter(test_loader))
    
    outcome = []
    for pos, antibiotic in enumerate(ms_data.selected_antibiotics):
        out_part = output[test_pos == pos]
        label_part = test_labels[test_pos == pos]
        
        occurrences = torch.bincount(label_part.int())
        try:
            if label_part.numel() == 0 or occurrences[1].item() == 0 or occurrences[0].item()== 0:
                continue
        except IndexError:
            # no positives
            continue
        
        au_roc = auRoc(out_part, label_part.int())
        
        metric.update(out_part, label_part.int())
        precision, recall, thresholds = metric.compute()
        aucPC = auc(recall, precision)
        
        n_not_empty +=1
        macro_aucroc += au_roc.item()
        macro_aucpr += aucPC.item()
    
        outcome.append({
            'antibiotics': antibiotic,
            'AUROC': au_roc.detach().item(),
            'AUPRC': aucPC.detach().item(),
            'Susceptible': occurrences[0].detach().item(),
            'Resistance': occurrences[1].detach().item(),
        })
    pd.DataFrame().from_dict(outcome).to_csv(f"{csv_path}/cv{fold}.csv")
        
        
    macro_aucroc = macro_aucroc / n_not_empty
    macro_aucpr = macro_aucpr / n_not_empty
    
    all_auc_roc_macro.append(macro_aucroc)
    all_auc_pr_macro.append(macro_aucpr)
    DEVICE = torch.device("cuda")
    model = model.to(DEVICE)

Start training ...


Create single label Dataset:   0%|          | 0/30664 [00:00<?, ?it/s]



Create single label Dataset:   0%|          | 0/7667 [00:00<?, ?it/s]



  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.005473 	Learning rate: 0.001000
	Average Loss: 0.004351 	Learning rate: 0.001000
	Average Loss: 0.003842 	Learning rate: 0.001000
	Average Loss: 0.003482 	Learning rate: 0.001000
	Average Loss: 0.003221 	Learning rate: 0.001000
	Average Loss: 0.003017 	Learning rate: 0.001000
	Average Loss: 0.002838 	Learning rate: 0.001000
	Average Loss: 0.002678 	Learning rate: 0.001000
	Average Loss: 0.002534 	Learning rate: 0.001000
	Average Loss: 0.002401 	Learning rate: 0.000500
	Average Loss: 0.002043 	Learning rate: 0.000500
	Average Loss: 0.001939 	Learning rate: 0.000500
	Average Loss: 0.001843 	Learning rate: 0.000500
	Average Loss: 0.001762 	Learning rate: 0.000500
	Average Loss: 0.001683 	Learning rate: 0.000500
	Average Loss: 0.001611 	Learning rate: 0.000500
	Average Loss: 0.001542 	Learning rate: 0.000500
	Average Loss: 0.001472 	Learning rate: 0.000500
	Average Loss: 0.001400 	Learning rate: 0.000500
	Average Loss: 0.001344 	Learning rate: 0.000250
	Average Loss: 0.001

Create single label Dataset:   0%|          | 0/7667 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/30665 [00:00<?, ?it/s]



Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]



  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.005504 	Learning rate: 0.001000
	Average Loss: 0.004379 	Learning rate: 0.001000
	Average Loss: 0.003869 	Learning rate: 0.001000
	Average Loss: 0.003503 	Learning rate: 0.001000
	Average Loss: 0.003242 	Learning rate: 0.001000
	Average Loss: 0.003033 	Learning rate: 0.001000
	Average Loss: 0.002868 	Learning rate: 0.001000
	Average Loss: 0.002713 	Learning rate: 0.001000
	Average Loss: 0.002573 	Learning rate: 0.001000
	Average Loss: 0.002442 	Learning rate: 0.000500
	Average Loss: 0.002087 	Learning rate: 0.000500
	Average Loss: 0.001970 	Learning rate: 0.000500
	Average Loss: 0.001890 	Learning rate: 0.000500
	Average Loss: 0.001814 	Learning rate: 0.000500
	Average Loss: 0.001739 	Learning rate: 0.000500
	Average Loss: 0.001665 	Learning rate: 0.000500
	Average Loss: 0.001607 	Learning rate: 0.000500
	Average Loss: 0.001539 	Learning rate: 0.000500
	Average Loss: 0.001485 	Learning rate: 0.000500
	Average Loss: 0.001421 	Learning rate: 0.000250
	Average Loss: 0.001

Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/30665 [00:00<?, ?it/s]



Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]



  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.005466 	Learning rate: 0.001000
	Average Loss: 0.004351 	Learning rate: 0.001000
	Average Loss: 0.003847 	Learning rate: 0.001000
	Average Loss: 0.003494 	Learning rate: 0.001000
	Average Loss: 0.003227 	Learning rate: 0.001000
	Average Loss: 0.003007 	Learning rate: 0.001000
	Average Loss: 0.002818 	Learning rate: 0.001000
	Average Loss: 0.002666 	Learning rate: 0.001000
	Average Loss: 0.002528 	Learning rate: 0.001000
	Average Loss: 0.002407 	Learning rate: 0.000500
	Average Loss: 0.002036 	Learning rate: 0.000500
	Average Loss: 0.001923 	Learning rate: 0.000500
	Average Loss: 0.001844 	Learning rate: 0.000500
	Average Loss: 0.001771 	Learning rate: 0.000500
	Average Loss: 0.001704 	Learning rate: 0.000500
	Average Loss: 0.001621 	Learning rate: 0.000500
	Average Loss: 0.001557 	Learning rate: 0.000500
	Average Loss: 0.001506 	Learning rate: 0.000500
	Average Loss: 0.001437 	Learning rate: 0.000500
	Average Loss: 0.001370 	Learning rate: 0.000250
	Average Loss: 0.001

Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/30665 [00:00<?, ?it/s]



Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]



  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.005568 	Learning rate: 0.001000
	Average Loss: 0.004430 	Learning rate: 0.001000
	Average Loss: 0.003879 	Learning rate: 0.001000
	Average Loss: 0.003503 	Learning rate: 0.001000
	Average Loss: 0.003229 	Learning rate: 0.001000
	Average Loss: 0.003029 	Learning rate: 0.001000
	Average Loss: 0.002848 	Learning rate: 0.001000
	Average Loss: 0.002709 	Learning rate: 0.001000
	Average Loss: 0.002576 	Learning rate: 0.001000
	Average Loss: 0.002449 	Learning rate: 0.000500
	Average Loss: 0.002092 	Learning rate: 0.000500
	Average Loss: 0.001973 	Learning rate: 0.000500
	Average Loss: 0.001899 	Learning rate: 0.000500
	Average Loss: 0.001810 	Learning rate: 0.000500
	Average Loss: 0.001743 	Learning rate: 0.000500
	Average Loss: 0.001669 	Learning rate: 0.000500
	Average Loss: 0.001604 	Learning rate: 0.000500
	Average Loss: 0.001529 	Learning rate: 0.000500
	Average Loss: 0.001475 	Learning rate: 0.000500
	Average Loss: 0.001417 	Learning rate: 0.000250
	Average Loss: 0.001

Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/30665 [00:00<?, ?it/s]



Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]



  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.005517 	Learning rate: 0.001000
	Average Loss: 0.004342 	Learning rate: 0.001000
	Average Loss: 0.003822 	Learning rate: 0.001000
	Average Loss: 0.003483 	Learning rate: 0.001000
	Average Loss: 0.003224 	Learning rate: 0.001000
	Average Loss: 0.003010 	Learning rate: 0.001000
	Average Loss: 0.002823 	Learning rate: 0.001000
	Average Loss: 0.002676 	Learning rate: 0.001000
	Average Loss: 0.002536 	Learning rate: 0.001000
	Average Loss: 0.002407 	Learning rate: 0.000500
	Average Loss: 0.002026 	Learning rate: 0.000500
	Average Loss: 0.001914 	Learning rate: 0.000500
	Average Loss: 0.001826 	Learning rate: 0.000500
	Average Loss: 0.001750 	Learning rate: 0.000500
	Average Loss: 0.001670 	Learning rate: 0.000500
	Average Loss: 0.001601 	Learning rate: 0.000500
	Average Loss: 0.001529 	Learning rate: 0.000500
	Average Loss: 0.001460 	Learning rate: 0.000500
	Average Loss: 0.001400 	Learning rate: 0.000500
	Average Loss: 0.001336 	Learning rate: 0.000250
	Average Loss: 0.001

Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]

In [8]:
print(f"Mean\t: {np.mean([x.cpu() for x in all_auc_roc_micro])}")
print(f" SD \t: {np.std([x.cpu() for x in all_auc_roc_micro]) }")
print(f"Mean\t: {np.mean([x.cpu() for x in all_auc_pr_micro])}")
print(f" SD \t: {np.std([x.cpu() for x in all_auc_pr_micro]) }")

Mean	: 0.9223073124885559
 SD 	: 0.0031904003117233515
Mean	: 0.8642351031303406
 SD 	: 0.0035700583830475807


In [9]:
print(f"Mean\t: {np.mean([x for x in all_auc_roc_macro])}")
print(f" SD \t: {np.std([x for x in all_auc_roc_macro]) }")
print(f"Mean\t: {np.mean([x for x in all_auc_pr_macro])}")
print(f" SD \t: {np.std([x for x in all_auc_pr_macro]) }")

Mean	: 0.9221437613169353
 SD 	: 0.0034870115657113765
Mean	: 0.8641235069433849
 SD 	: 0.003565214931805852
