# MultimodalAMR

In [1]:
import copy
import numpy as np
import pandas as pd
import sys
sys.path.insert(0,'../../')
import torch
from torch import nn
import torchmetrics.classification
from torch import cuda
from torch.utils.data import DataLoader
from multimodal_amr.models.modules import ResMLP
from maldi2resistance.data.ms_data import MS_Data
assert cuda.is_available()
assert cuda.device_count() > 0
print(cuda.get_device_name(cuda.current_device()))

DEVICE = torch.device("cuda")
SEED = 42
torch.manual_seed(SEED)

save_folder = "Dualbranch_5cv-DRIAMS-A"

NVIDIA RTX 2000 Ada Generation Laptop GPU


### Load the Dataset

In [2]:
UMG_antibiotics = ['Ampicillin',
 'Cefotaxim',
 'Ceftazidime',
 'Ceftriaxone',
 'Ciprofloxacin',
 'Clindamycin',
 'Cotrimoxazole',
 'Erythromycin',
 'Fosfomycin',
 'Gentamicin',
 'Imipenem',
 'Levofloxacin',
 'Meropenem',
 'Moxifloxacin',
 'Oxacillin',
 'PenicillinG',
 'Piperacillin-Tazobactam',
 'Tetracycline',
 'Vancomycin']

In [2]:
ms_data = MS_Data(
    root_dir="/home/youngjunpark/Data/MS_data",
    sites=["DRIAMS-A"],
    years=[2015,2016,2017,2018],
    bin_size=1,
    #antibiotics=UMG_antibiotics,
)
ms_data.loading_type = "memory"
ms_data

Loading Spectra into Memory:   0%|          | 0/38331 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin-Amoxicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Rifampicin,Teicoplanin,Tetracycline,Tobramycin
Number resistant:,975,9920,21966,4223,6518,2338,2455,7299,4475,7462,3637,2850,4872,288,4529,1326,3413,2303,6872,3973,5145,412,871,4641,9881,6546,542,226,3082,1695
Number susceptible:,16247,15308,4905,5813,21958,4382,14937,19246,6103,23081,7975,15483,21768,14465,6550,4803,7224,8276,22519,16811,24386,1696,5234,6344,3525,21852,10424,7465,6836,16495
Number data points:,17222,25228,26871,10036,28476,6720,17392,26545,10578,30543,11612,18333,26640,14753,11079,6129,10637,10579,29391,20784,29531,2108,6105,10985,13406,28398,10966,7691,9918,18190


In [3]:
len(ms_data.label_stats.columns)

30

In [4]:
from maldi2resistance.model.dualBranch import DualBranchOneHot
import copy
from maldi2resistance.model.singleBranchMlp import SingleBranchMLP

model = DualBranchOneHot( input_dim_spectrum= 18000, input_dim_drug= len(ms_data.selected_antibiotics))
model = model.to(DEVICE)

model_state = copy.deepcopy(model.state_dict()) 

In [5]:
from torchinfo import summary

print(summary(model))

Layer (type:depth-idx)                   Param #
DualBranchOneHot                         --
├─Sequential: 1-1                        --
│    └─Linear: 2-1                       9,216,512
│    └─GELU: 2-2                         --
│    └─Dropout: 2-3                      --
│    └─LayerNorm: 2-4                    1,024
│    └─Linear: 2-5                       131,328
│    └─GELU: 2-6                         --
│    └─Dropout: 2-7                      --
│    └─LayerNorm: 2-8                    512
│    └─Linear: 2-9                       32,896
│    └─GELU: 2-10                        --
│    └─Dropout: 2-11                     --
│    └─LayerNorm: 2-12                   256
│    └─Linear: 2-13                      8,256
├─Embedding: 1-2                         1,920
Total params: 9,392,704
Trainable params: 9,392,704
Non-trainable params: 0


In [6]:
from torchmetrics.utilities.compute import auc
from torchmetrics.classification import BinaryAUROC, BinaryPrecisionRecallCurve
from maldi2resistance.data.ms_data import MS_Data_SingleAntibiotic
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

print("Start training ...")
model.train()

batch_size = 64
fig_path = Path(f"./{save_folder}/figures")
fig_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(f"./{save_folder}/csv")
csv_path.mkdir(parents=True, exist_ok=True)

loss_per_batch = []

gen = torch.Generator()

all_auc_roc_micro = []
all_auc_pr_micro = []

all_auc_roc_macro = []
all_auc_pr_macro = []

for fold, (train_data, test_data) in enumerate(ms_data.getK_fold(n_splits=5, shuffle=True, random_state= SEED)):
    train_dataset_single_antibiotic = MS_Data_SingleAntibiotic(ms_data=train_data)
    test_dataset_single_antibiotic = MS_Data_SingleAntibiotic(ms_data=test_data)
    
    train_loader = DataLoader(train_dataset_single_antibiotic, batch_size=batch_size, shuffle=True, drop_last=True, generator= gen.manual_seed(SEED))
    test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=len(test_dataset_single_antibiotic), shuffle=True, drop_last=True, generator= gen.manual_seed(SEED))
    
    model.load_state_dict(model_state)
    model.train()
    
    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad = True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    for epoch in tqdm(range(30), leave= False, position=1):
        overall_loss = 0
        
        for batch_idx, (x, y, pos) in enumerate(train_loader):
    
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            pos = pos.to(DEVICE)
            
            optimizer.zero_grad()
    
            output = model(x, pos)
        
            loss = F.binary_cross_entropy_with_logits(output, y)
            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)
            
            overall_loss += current_loss_value
            
            loss.backward()
            optimizer.step()
    
        scheduler.step()
        with tqdm.external_write_mode():
            print(f"\tAverage Loss: {overall_loss / (batch_idx*batch_size):.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")
        
    print(f"Finished Fold {fold}")
    
    model.eval()
    
    DEVICE = torch.device("cpu")
    test_features, test_labels, test_pos = next(iter(test_loader))
    test_features = test_features.to(DEVICE)
    test_labels = test_labels.to(DEVICE)
    test_pos = test_pos.to(DEVICE)
    model = model.to(DEVICE)
    
    output = model(test_features, test_pos)
    output = torch.squeeze(output)
    
    auRoc = BinaryAUROC()
    auc_roc = auRoc(output, test_labels.int())
    all_auc_roc_micro.append(auc_roc)
    
    metric = BinaryPrecisionRecallCurve()
    metric.update(output, test_labels.int())
    precision, recall, thresholds = metric.compute()
    aucPC = auc(recall, precision)
    
    all_auc_pr_micro.append(aucPC)
    
    macro_aucroc = 0
    macro_aucpr = 0
    
    n_not_empty = 0
    
    test_dataset_single_antibiotic = MS_Data_SingleAntibiotic(ms_data=test_data, use_morganFingerprint4Drug= False)
    test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=len(test_dataset_single_antibiotic), shuffle=True)
    _, _, test_pos = next(iter(test_loader))
    
    outcome = []
    for pos, antibiotic in enumerate(ms_data.selected_antibiotics):
        out_part = output[test_pos == pos]
        label_part = test_labels[test_pos == pos]
        
        occurrences = torch.bincount(label_part.int())
        try:
            if label_part.numel() == 0 or occurrences[1].item() == 0 or occurrences[0].item()== 0:
                continue
        except IndexError:
            # no positives
            continue
        
        au_roc = auRoc(out_part, label_part.int())
        
        metric.update(out_part, label_part.int())
        precision, recall, thresholds = metric.compute()
        aucPC = auc(recall, precision)
        
        n_not_empty +=1
        macro_aucroc += au_roc.item()
        macro_aucpr += aucPC.item()
    
        outcome.append({
            'antibiotics': antibiotic,
            'AUROC': au_roc.detach().item(),
            'AUPRC': aucPC.detach().item(),
            'Susceptible': occurrences[0].detach().item(),
            'Resistance': occurrences[1].detach().item(),
        })
    pd.DataFrame().from_dict(outcome).to_csv(f"{csv_path}/cv{fold}.csv")
        
        
    macro_aucroc = macro_aucroc / n_not_empty
    macro_aucpr = macro_aucpr / n_not_empty
    
    all_auc_roc_macro.append(macro_aucroc)
    all_auc_pr_macro.append(macro_aucpr)
    DEVICE = torch.device("cuda")
    model = model.to(DEVICE)

Start training ...


Create single label Dataset:   0%|          | 0/30664 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/7667 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.005311 	Learning rate: 0.001000
	Average Loss: 0.004396 	Learning rate: 0.001000
	Average Loss: 0.004002 	Learning rate: 0.001000
	Average Loss: 0.003769 	Learning rate: 0.001000
	Average Loss: 0.003560 	Learning rate: 0.001000
	Average Loss: 0.003417 	Learning rate: 0.001000
	Average Loss: 0.003263 	Learning rate: 0.001000
	Average Loss: 0.003155 	Learning rate: 0.001000
	Average Loss: 0.003061 	Learning rate: 0.001000
	Average Loss: 0.002977 	Learning rate: 0.000500
	Average Loss: 0.002704 	Learning rate: 0.000500
	Average Loss: 0.002638 	Learning rate: 0.000500
	Average Loss: 0.002591 	Learning rate: 0.000500
	Average Loss: 0.002547 	Learning rate: 0.000500
	Average Loss: 0.002513 	Learning rate: 0.000500
	Average Loss: 0.002461 	Learning rate: 0.000500
	Average Loss: 0.002417 	Learning rate: 0.000500
	Average Loss: 0.002388 	Learning rate: 0.000500
	Average Loss: 0.002353 	Learning rate: 0.000500
	Average Loss: 0.002310 	Learning rate: 0.000250
	Average Loss: 0.002

Create single label Dataset:   0%|          | 0/7667 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/30665 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.005377 	Learning rate: 0.001000
	Average Loss: 0.004446 	Learning rate: 0.001000
	Average Loss: 0.004085 	Learning rate: 0.001000
	Average Loss: 0.003815 	Learning rate: 0.001000
	Average Loss: 0.003588 	Learning rate: 0.001000
	Average Loss: 0.003448 	Learning rate: 0.001000
	Average Loss: 0.003321 	Learning rate: 0.001000
	Average Loss: 0.003192 	Learning rate: 0.001000
	Average Loss: 0.003100 	Learning rate: 0.001000
	Average Loss: 0.003016 	Learning rate: 0.000500
	Average Loss: 0.002761 	Learning rate: 0.000500
	Average Loss: 0.002687 	Learning rate: 0.000500
	Average Loss: 0.002633 	Learning rate: 0.000500
	Average Loss: 0.002591 	Learning rate: 0.000500
	Average Loss: 0.002532 	Learning rate: 0.000500
	Average Loss: 0.002504 	Learning rate: 0.000500
	Average Loss: 0.002461 	Learning rate: 0.000500
	Average Loss: 0.002414 	Learning rate: 0.000500
	Average Loss: 0.002376 	Learning rate: 0.000500
	Average Loss: 0.002355 	Learning rate: 0.000250
	Average Loss: 0.002

Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/30665 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.005309 	Learning rate: 0.001000
	Average Loss: 0.004370 	Learning rate: 0.001000
	Average Loss: 0.004015 	Learning rate: 0.001000
	Average Loss: 0.003759 	Learning rate: 0.001000
	Average Loss: 0.003568 	Learning rate: 0.001000
	Average Loss: 0.003403 	Learning rate: 0.001000
	Average Loss: 0.003284 	Learning rate: 0.001000
	Average Loss: 0.003158 	Learning rate: 0.001000
	Average Loss: 0.003067 	Learning rate: 0.001000
	Average Loss: 0.002988 	Learning rate: 0.000500
	Average Loss: 0.002710 	Learning rate: 0.000500
	Average Loss: 0.002645 	Learning rate: 0.000500
	Average Loss: 0.002605 	Learning rate: 0.000500
	Average Loss: 0.002558 	Learning rate: 0.000500
	Average Loss: 0.002511 	Learning rate: 0.000500
	Average Loss: 0.002462 	Learning rate: 0.000500
	Average Loss: 0.002425 	Learning rate: 0.000500
	Average Loss: 0.002382 	Learning rate: 0.000500
	Average Loss: 0.002349 	Learning rate: 0.000500
	Average Loss: 0.002310 	Learning rate: 0.000250
	Average Loss: 0.002

Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/30665 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.005310 	Learning rate: 0.001000
	Average Loss: 0.004405 	Learning rate: 0.001000
	Average Loss: 0.004038 	Learning rate: 0.001000
	Average Loss: 0.003782 	Learning rate: 0.001000
	Average Loss: 0.003579 	Learning rate: 0.001000
	Average Loss: 0.003416 	Learning rate: 0.001000
	Average Loss: 0.003298 	Learning rate: 0.001000
	Average Loss: 0.003192 	Learning rate: 0.001000
	Average Loss: 0.003085 	Learning rate: 0.001000
	Average Loss: 0.002993 	Learning rate: 0.000500
	Average Loss: 0.002728 	Learning rate: 0.000500
	Average Loss: 0.002670 	Learning rate: 0.000500
	Average Loss: 0.002620 	Learning rate: 0.000500
	Average Loss: 0.002575 	Learning rate: 0.000500
	Average Loss: 0.002519 	Learning rate: 0.000500
	Average Loss: 0.002486 	Learning rate: 0.000500
	Average Loss: 0.002439 	Learning rate: 0.000500
	Average Loss: 0.002390 	Learning rate: 0.000500
	Average Loss: 0.002359 	Learning rate: 0.000500
	Average Loss: 0.002324 	Learning rate: 0.000250
	Average Loss: 0.002

Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/30665 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.005313 	Learning rate: 0.001000
	Average Loss: 0.004392 	Learning rate: 0.001000
	Average Loss: 0.004015 	Learning rate: 0.001000
	Average Loss: 0.003769 	Learning rate: 0.001000
	Average Loss: 0.003577 	Learning rate: 0.001000
	Average Loss: 0.003400 	Learning rate: 0.001000
	Average Loss: 0.003277 	Learning rate: 0.001000
	Average Loss: 0.003163 	Learning rate: 0.001000
	Average Loss: 0.003075 	Learning rate: 0.001000
	Average Loss: 0.002973 	Learning rate: 0.000500
	Average Loss: 0.002715 	Learning rate: 0.000500
	Average Loss: 0.002658 	Learning rate: 0.000500
	Average Loss: 0.002595 	Learning rate: 0.000500
	Average Loss: 0.002557 	Learning rate: 0.000500
	Average Loss: 0.002504 	Learning rate: 0.000500
	Average Loss: 0.002472 	Learning rate: 0.000500
	Average Loss: 0.002419 	Learning rate: 0.000500
	Average Loss: 0.002393 	Learning rate: 0.000500
	Average Loss: 0.002337 	Learning rate: 0.000500
	Average Loss: 0.002317 	Learning rate: 0.000250
	Average Loss: 0.002

Create single label Dataset:   0%|          | 0/7666 [00:00<?, ?it/s]

In [7]:
print(f"Mean\t: {np.mean([x.cpu() for x in all_auc_roc_micro])}")
print(f" SD \t: {np.std([x.cpu() for x in all_auc_roc_micro]) }")
print(f"Mean\t: {np.mean([x.cpu() for x in all_auc_pr_micro])}")
print(f" SD \t: {np.std([x.cpu() for x in all_auc_pr_micro]) }")

Mean	: 0.9226457476615906
 SD 	: 0.0037608046550303698
Mean	: 0.8614680171012878
 SD 	: 0.005926158279180527


In [8]:
print(f"Mean\t: {np.mean([x for x in all_auc_roc_macro])}")
print(f" SD \t: {np.std([x for x in all_auc_roc_macro]) }")
print(f"Mean\t: {np.mean([x for x in all_auc_pr_macro])}")
print(f" SD \t: {np.std([x for x in all_auc_pr_macro]) }")

Mean	: 0.9232978133360545
 SD 	: 0.00341953552856399
Mean	: 0.8610860661665599
 SD 	: 0.006223688003158097
