# MultimodalAMR

In [1]:
import copy
import numpy as np
import pandas as pd
import sys
sys.path.insert(0,'../../')
import torch
from torch import nn
import torchmetrics.classification
from torch import cuda
from torch.utils.data import DataLoader
from multimodal_amr.models.modules import ResMLP
from maldi2resistance.data.ms_data import MS_Data
assert cuda.is_available()
assert cuda.device_count() > 0
print(cuda.get_device_name(cuda.current_device()))

DEVICE = torch.device("cuda")
SEED = 42
torch.manual_seed(SEED)

save_folder = "assymetric_loss_multimodalAMR_5cv-UMG"

NVIDIA RTX 2000 Ada Generation Laptop GPU


### Load the Dataset

In [2]:
UMG_antibiotics = ['Ampicillin',
 'Cefotaxim',
 'Ceftazidime',
 'Ceftriaxone',
 'Ciprofloxacin',
 'Clindamycin',
 'Cotrimoxazole',
 'Erythromycin',
 'Fosfomycin',
 'Gentamicin',
 'Imipenem',
 'Levofloxacin',
 'Meropenem',
 'Moxifloxacin',
 'Oxacillin',
 'PenicillinG',
 'Piperacillin-Tazobactam',
 'Tetracycline',
 'Vancomycin']

In [3]:
ms_data = MS_Data(
    root_dir="/home/youngjunpark/Data/MS_data",
    sites=["UMG"],
    years=[2020,2021],
    bin_size=1,
    antibiotics=UMG_antibiotics,
)
ms_data.loading_type = "memory"
ms_data



Loading Spectra into Memory:   0%|          | 0/23669 [00:00<?, ?it/s]

Antibiotic:,Ampicillin,Cefotaxim,Ceftazidime,Ceftriaxone,Ciprofloxacin,Clindamycin,Cotrimoxazole,Erythromycin,Fosfomycin,Gentamicin,Imipenem,Levofloxacin,Meropenem,Moxifloxacin,Oxacillin,PenicillinG,Piperacillin-Tazobactam,Tetracycline,Vancomycin
Number resistant:,8045,1203,2937,597,3267,5459,5257,5475,1849,12669,3938,1599,515,2808,1868,5883,3343,1356,330
Number susceptible:,7465,9931,8980,4852,9245,5381,14354,4098,7920,4800,11342,4224,11917,5045,3804,1885,8624,4625,10551
Number data points:,15510,11134,11917,5449,12512,10840,19611,9573,9769,17469,15280,5823,12432,7853,5672,7768,11967,5981,10881


In [4]:
len(ms_data.label_stats.columns)

19

In [5]:
class Residual_AMR_Classifier(nn.Module):
    """Overall model definition."""

    def __init__(self, config, n_input_spectrum, n_input_drug):
        super().__init__()
        self.config = config

        # Sample embedding
        if config["species_embedding_dim"]==0 and config["conv_out_size"]==config["sample_embedding_dim"]:
            self.sample_emb = nn.Identity()
        elif config["species_embedding_dim"]==0:
            self.sample_emb = nn.Linear(
                config["conv_out_size"],
                config["sample_embedding_dim"],
            )
        else:
            self.sample_emb = nn.Linear(
                config["conv_out_size"] + config["species_embedding_dim"],
                config["sample_embedding_dim"],
            )

        self.spectrum_emb = nn.Linear(n_input_spectrum ,config["conv_out_size"])

        # Drugs layers
        # if config["drug_emb_type"] == "vae_embedding" or config["drug_emb_type"] == "gnn_embedding":
        #     self.drug_emb = nn.Identity()
        # elif config["drug_emb_type"] == "fingerprint":
        self.drug_emb = nn.Linear(
            n_input_drug, config["drug_embedding_dim"]
        )

        # Output network
        self.net = ResMLP(
            config["n_hidden_layers"],
            config["sample_embedding_dim"] + config["drug_embedding_dim"],
            1,
            p_dropout=0.2,
        )

    def forward(self, spectrum, fingerprint):
        spectrum_embedding = self.spectrum_emb(spectrum)
        dr_emb = self.drug_emb(fingerprint)

        return self.net(torch.cat([dr_emb, spectrum_embedding], dim=1))




conf = {
    "conv_out_size": 512,
    "species_embedding_dim":0,
    "sample_embedding_dim":512,
    "drug_embedding_dim":512,
    "n_hidden_layers": 5,
}

In [6]:
model = Residual_AMR_Classifier(config= conf,n_input_spectrum = ms_data.n_bins, n_input_drug= 1024)
model = model.to(DEVICE)

model_state = copy.deepcopy(model.state_dict()) 

In [7]:
from torchinfo import summary

print(summary(model))

Layer (type:depth-idx)                        Param #
Residual_AMR_Classifier                       --
├─Identity: 1-1                               --
├─Linear: 1-2                                 9,216,512
├─Linear: 1-3                                 524,800
├─ResMLP: 1-4                                 --
│    └─Sequential: 2-1                        --
│    │    └─ResBlock: 3-1                     1,051,648
│    │    └─ResBlock: 3-2                     1,051,648
│    │    └─ResBlock: 3-3                     1,051,648
│    │    └─ResBlock: 3-4                     1,051,648
│    │    └─ResBlock: 3-5                     1,051,648
│    │    └─Linear: 3-6                       1,025
Total params: 15,000,577
Trainable params: 15,000,577
Non-trainable params: 0


In [8]:
from torchmetrics.utilities.compute import auc
from maldi2resistance.loss.asymmetricLoss import AsymmetricLoss
from torchmetrics.classification import BinaryAUROC, BinaryPrecisionRecallCurve
from maldi2resistance.data.ms_data import MS_Data_SingleAntibiotic
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

print("Start training ...")
model.train()

batch_size = 64
fig_path = Path(f"./{save_folder}/figures")
fig_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(f"./{save_folder}/csv")
csv_path.mkdir(parents=True, exist_ok=True)

loss_per_batch = []
criterion = AsymmetricLoss()

gen = torch.Generator()

all_auc_roc_micro = []
all_auc_pr_micro = []

all_auc_roc_macro = []
all_auc_pr_macro = []

for fold, (train_data, test_data) in enumerate(ms_data.getK_fold(n_splits=5, shuffle=True, random_state= SEED)):
    train_dataset_single_antibiotic = MS_Data_SingleAntibiotic(ms_data=train_data, use_morganFingerprint4Drug=True)
    test_dataset_single_antibiotic = MS_Data_SingleAntibiotic(ms_data=test_data, use_morganFingerprint4Drug=True)
    
    train_loader = DataLoader(train_dataset_single_antibiotic, batch_size=batch_size, shuffle=True, drop_last=True, generator= gen.manual_seed(SEED))
    test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=len(test_dataset_single_antibiotic), shuffle=True, drop_last=True, generator= gen.manual_seed(SEED))
    
    model.load_state_dict(model_state)
    model.train()
    
    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad = True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    for epoch in tqdm(range(30), leave= False, position=1):
        overall_loss = 0
        
        for batch_idx, (x, y, pos) in enumerate(train_loader):
    
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            pos = pos.to(DEVICE)
            
            optimizer.zero_grad()
    
            output = model(x, pos)

            loss = criterion(output, y)
            #loss = F.binary_cross_entropy_with_logits(torch.squeeze(output), y)
            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)
            
            overall_loss += current_loss_value
            
            loss.backward()
            optimizer.step()
    
        scheduler.step()
        with tqdm.external_write_mode():
            print(f"\tAverage Loss: {overall_loss / (batch_idx*batch_size):.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")
        
    print(f"Finished Fold {fold}")
    
    model.eval()
    
    DEVICE = torch.device("cpu")
    test_features, test_labels, test_pos = next(iter(test_loader))
    test_features = test_features.to(DEVICE)
    test_labels = test_labels.to(DEVICE)
    test_pos = test_pos.to(DEVICE)
    model = model.to(DEVICE)
    
    output = model(test_features, test_pos)
    output = torch.squeeze(output)
    
    auRoc = BinaryAUROC()
    auc_roc = auRoc(output, test_labels.int())
    all_auc_roc_micro.append(auc_roc)
    
    metric = BinaryPrecisionRecallCurve()
    metric.update(output, test_labels.int())
    precision, recall, thresholds = metric.compute()
    aucPC = auc(recall, precision)
    
    all_auc_pr_micro.append(aucPC)
    
    macro_aucroc = 0
    macro_aucpr = 0
    
    n_not_empty = 0
    
    test_dataset_single_antibiotic = MS_Data_SingleAntibiotic(ms_data=test_data, use_morganFingerprint4Drug= False)
    test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=len(test_dataset_single_antibiotic), shuffle=True)
    _, _, test_pos = next(iter(test_loader))
    
    outcome = []
    for pos, antibiotic in enumerate(ms_data.selected_antibiotics):
        out_part = output[test_pos == pos]
        label_part = test_labels[test_pos == pos]
        
        occurrences = torch.bincount(label_part.int())
        try:
            if label_part.numel() == 0 or occurrences[1].item() == 0 or occurrences[0].item()== 0:
                continue
        except IndexError:
            # no positives
            continue
        
        au_roc = auRoc(out_part, label_part.int())
        
        metric.update(out_part, label_part.int())
        precision, recall, thresholds = metric.compute()
        aucPC = auc(recall, precision)
        
        n_not_empty +=1
        macro_aucroc += au_roc.item()
        macro_aucpr += aucPC.item()
    
        outcome.append({
            'antibiotics': antibiotic,
            'AUROC': au_roc.detach().item(),
            'AUPRC': aucPC.detach().item(),
            'Susceptible': occurrences[0].detach().item(),
            'Resistance': occurrences[1].detach().item(),
        })
    pd.DataFrame().from_dict(outcome).to_csv(f"{csv_path}/cv{fold}.csv")
        
        
    macro_aucroc = macro_aucroc / n_not_empty
    macro_aucpr = macro_aucpr / n_not_empty
    
    all_auc_roc_macro.append(macro_aucroc)
    all_auc_pr_macro.append(macro_aucpr)
    DEVICE = torch.device("cuda")
    model = model.to(DEVICE)

Start training ...


Create single label Dataset:   0%|          | 0/18935 [00:00<?, ?it/s]



Create single label Dataset:   0%|          | 0/4734 [00:00<?, ?it/s]



  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 95737627.516128 	Learning rate: 0.001000
	Average Loss: 10035.201840 	Learning rate: 0.001000
	Average Loss: 5712.775135 	Learning rate: 0.001000
	Average Loss: 5712.700149 	Learning rate: 0.001000
	Average Loss: 5714.970954 	Learning rate: 0.001000
	Average Loss: 5710.764998 	Learning rate: 0.001000
	Average Loss: 5712.701087 	Learning rate: 0.001000
	Average Loss: 5713.818888 	Learning rate: 0.001000
	Average Loss: 5711.633784 	Learning rate: 0.001000
	Average Loss: 5711.734707 	Learning rate: 0.000500
	Average Loss: 5711.930299 	Learning rate: 0.000500
	Average Loss: 5711.784824 	Learning rate: 0.000500
	Average Loss: 5711.754985 	Learning rate: 0.000500
	Average Loss: 5713.081252 	Learning rate: 0.000500
	Average Loss: 5712.859024 	Learning rate: 0.000500
	Average Loss: 5711.521234 	Learning rate: 0.000500
	Average Loss: 5715.577778 	Learning rate: 0.000500
	Average Loss: 5712.208826 	Learning rate: 0.000500
	Average Loss: 5712.502111 	Learning rate: 0.000500
	Averag

Create single label Dataset:   0%|          | 0/4734 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/18935 [00:00<?, ?it/s]



Create single label Dataset:   0%|          | 0/4734 [00:00<?, ?it/s]



  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 720669.415973 	Learning rate: 0.001000
	Average Loss: 3452.639057 	Learning rate: 0.001000
	Average Loss: 3484.712386 	Learning rate: 0.001000
	Average Loss: 3484.651631 	Learning rate: 0.001000
	Average Loss: 3484.013930 	Learning rate: 0.001000
	Average Loss: 3484.674325 	Learning rate: 0.001000
	Average Loss: 3484.377790 	Learning rate: 0.001000
	Average Loss: 3484.979202 	Learning rate: 0.001000
	Average Loss: 3485.878761 	Learning rate: 0.001000
	Average Loss: 3484.244678 	Learning rate: 0.000500
	Average Loss: 3484.522518 	Learning rate: 0.000500
	Average Loss: 3484.957158 	Learning rate: 0.000500
	Average Loss: 3484.870363 	Learning rate: 0.000500
	Average Loss: 3484.353580 	Learning rate: 0.000500
	Average Loss: 3485.196716 	Learning rate: 0.000500
	Average Loss: 3485.362126 	Learning rate: 0.000500
	Average Loss: 3483.650211 	Learning rate: 0.000500
	Average Loss: 3483.599630 	Learning rate: 0.000500
	Average Loss: 3484.293889 	Learning rate: 0.000500
	Average L

Create single label Dataset:   0%|          | 0/4734 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/18935 [00:00<?, ?it/s]



Create single label Dataset:   0%|          | 0/4734 [00:00<?, ?it/s]



  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 1462972.304897 	Learning rate: 0.001000
	Average Loss: 8741089.412104 	Learning rate: 0.001000
	Average Loss: 8032.461654 	Learning rate: 0.001000
	Average Loss: 3716.511769 	Learning rate: 0.001000
	Average Loss: 3716.597864 	Learning rate: 0.001000
	Average Loss: 3716.890528 	Learning rate: 0.001000
	Average Loss: 3717.439830 	Learning rate: 0.001000
	Average Loss: 3716.191445 	Learning rate: 0.001000
	Average Loss: 3718.362800 	Learning rate: 0.001000
	Average Loss: 3717.877165 	Learning rate: 0.000500
	Average Loss: 3716.773910 	Learning rate: 0.000500
	Average Loss: 3717.617223 	Learning rate: 0.000500
	Average Loss: 3718.677333 	Learning rate: 0.000500
	Average Loss: 15783.890537 	Learning rate: 0.000500
	Average Loss: 4850.719524 	Learning rate: 0.000500
	Average Loss: 4900.826959 	Learning rate: 0.000500
	Average Loss: 4901.201754 	Learning rate: 0.000500
	Average Loss: 4902.160042 	Learning rate: 0.000500
	Average Loss: 4901.609258 	Learning rate: 0.000500
	Aver

Create single label Dataset:   0%|          | 0/4734 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/18935 [00:00<?, ?it/s]



Create single label Dataset:   0%|          | 0/4734 [00:00<?, ?it/s]



  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 422098.568548 	Learning rate: 0.001000
	Average Loss: 3523.255658 	Learning rate: 0.001000
	Average Loss: 3523.739886 	Learning rate: 0.001000
	Average Loss: 3523.318287 	Learning rate: 0.001000
	Average Loss: 3524.172401 	Learning rate: 0.001000
	Average Loss: 3523.955031 	Learning rate: 0.001000
	Average Loss: 154664.452837 	Learning rate: 0.001000
	Average Loss: 3670.957149 	Learning rate: 0.001000
	Average Loss: 3670.996710 	Learning rate: 0.001000
	Average Loss: 3671.504874 	Learning rate: 0.000500
	Average Loss: 3670.323322 	Learning rate: 0.000500
	Average Loss: 3670.574160 	Learning rate: 0.000500
	Average Loss: 3670.429950 	Learning rate: 0.000500
	Average Loss: 3671.502355 	Learning rate: 0.000500
	Average Loss: 3669.410777 	Learning rate: 0.000500
	Average Loss: 3670.728084 	Learning rate: 0.000500
	Average Loss: 3671.495995 	Learning rate: 0.000500
	Average Loss: 3671.188421 	Learning rate: 0.000500
	Average Loss: 3672.010565 	Learning rate: 0.000500
	Average

Create single label Dataset:   0%|          | 0/4734 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/18936 [00:00<?, ?it/s]



Create single label Dataset:   0%|          | 0/4733 [00:00<?, ?it/s]



  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 22480956.378198 	Learning rate: 0.001000
	Average Loss: 4189.189225 	Learning rate: 0.001000
	Average Loss: 4186.871941 	Learning rate: 0.001000
	Average Loss: 12830.335259 	Learning rate: 0.001000
	Average Loss: 6218.521969 	Learning rate: 0.001000
	Average Loss: 6208.011375 	Learning rate: 0.001000
	Average Loss: 6478.335999 	Learning rate: 0.001000
	Average Loss: 6280.108699 	Learning rate: 0.001000
	Average Loss: 6279.863121 	Learning rate: 0.001000
	Average Loss: 6279.202062 	Learning rate: 0.000500
	Average Loss: 6280.877349 	Learning rate: 0.000500
	Average Loss: 6280.672682 	Learning rate: 0.000500
	Average Loss: 6279.118839 	Learning rate: 0.000500
	Average Loss: 6278.209795 	Learning rate: 0.000500
	Average Loss: 6278.357263 	Learning rate: 0.000500
	Average Loss: 6278.095699 	Learning rate: 0.000500
	Average Loss: 6278.795455 	Learning rate: 0.000500
	Average Loss: 6278.934365 	Learning rate: 0.000500
	Average Loss: 6279.929854 	Learning rate: 0.000500
	Averag

Create single label Dataset:   0%|          | 0/4733 [00:00<?, ?it/s]

In [11]:
print(f"Mean\t: {np.mean([x.cpu() for x in all_auc_roc_micro])}")
print(f" SD \t: {np.std([x.cpu() for x in all_auc_roc_micro]) }")
print(f"Mean\t: {np.mean([x.cpu() for x in all_auc_pr_micro])}")
print(f" SD \t: {np.std([x.cpu() for x in all_auc_pr_micro]) }")

Mean	: 0.9333194494247437
 SD 	: 0.0014133455697447062
Mean	: 0.907570481300354
 SD 	: 0.0023374252486974


In [12]:
print(f"Mean\t: {np.mean([x for x in all_auc_roc_macro])}")
print(f" SD \t: {np.std([x for x in all_auc_roc_macro]) }")
print(f"Mean\t: {np.mean([x for x in all_auc_pr_macro])}")
print(f" SD \t: {np.std([x for x in all_auc_pr_macro]) }")

Mean	: 0.9331997494948538
 SD 	: 0.001337751375321144
Mean	: 0.9074487259513454
 SD 	: 0.0026486777480110115
