In [None]:
from pathlib import Path
import matplotlib.style
import matplotlib as mpl
import torch
from tqdm.auto import tqdm

DEVICE = torch.device("cpu")
SEED = 76436278

torch.manual_seed(SEED)
mpl.style.use("default")

checkpoint = torch.load('./model.pt')


In [None]:
from maldi2resistance.data.driams import Driams

driams = Driams(
    root_dir="/home/jan/Uni/master/data/Driams",
    # antibiotics= ['Ciprofloxacin', 'Ceftriaxone', "Cefepime", "Piperacillin-Tazobactam", "Tobramycin"]
)

driams.loading_type = "memory"

train_size = int(0.8 * len(driams))
test_size = len(driams) - train_size

driams

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
from maldi2resistance.data.driams import DriamsSingleAntibiotic
from torch.utils.data import DataLoader

gen = torch.Generator()


batch_size = 128

train_size = int(0.8 * len(driams))
test_size = len(driams) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(driams, [train_size, test_size], generator=gen.manual_seed(SEED))

test_dataset_single_antibiotic = DriamsSingleAntibiotic(driams= test_dataset, use_morganFingerprint4Drug= True)


Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

In [None]:
driams.label_stats

Unnamed: 0,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,...,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
positive,1068,13366,8578,21966,628,618,4223,7383,2338,3470,...,1205,5537,12431,7616,486,580,244,3534,1707,227
negative,20941,24992,4194,4905,456,457,5813,31567,4382,24566,...,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
n_sum,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,...,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
len(driams.selected_antibiotics)

38

In [None]:
from torch.utils.data import DataLoader

test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=512, shuffle=True)

In [None]:
from multimodal_amr.models.modules import ResMLP
from torch import nn


class Residual_AMR_Classifier(nn.Module):
    """Overall model definition."""

    def __init__(self, config, n_input_spectrum, n_input_drug):
        super().__init__()
        self.config = config

        # Sample embedding
        if config["species_embedding_dim"]==0 and config["conv_out_size"]==config["sample_embedding_dim"]:
            self.sample_emb = nn.Identity()
        elif config["species_embedding_dim"]==0:
            self.sample_emb = nn.Linear(
                config["conv_out_size"],
                config["sample_embedding_dim"],
            )
        else:
            self.sample_emb = nn.Linear(
                config["conv_out_size"] + config["species_embedding_dim"],
                config["sample_embedding_dim"],
            )

        # Maldi-tof spectrum embedding
        # self.spectrum_emb = Conv1d_Block(output_dim=config["conv_out_size"])
        self.spectrum_emb = nn.Linear(n_input_spectrum ,config["conv_out_size"])

        # Drugs layers
        # if config["drug_emb_type"] == "vae_embedding" or config["drug_emb_type"] == "gnn_embedding":
        #     self.drug_emb = nn.Identity()
        # elif config["drug_emb_type"] == "fingerprint":
        self.drug_emb = nn.Linear(
            n_input_drug, config["drug_embedding_dim"]
        )

        # Output network
        self.net = ResMLP(
            config["n_hidden_layers"],
            config["sample_embedding_dim"] + config["drug_embedding_dim"],
            1,
            p_dropout=0.2,
        )

    def forward(self, spectrum, fingerprint):
        spectrum_embedding = self.spectrum_emb(spectrum)
        dr_emb = self.drug_emb(fingerprint)

        return self.net(torch.cat([dr_emb, spectrum_embedding], dim=1))




conf = {
    "conv_out_size": 512,
    "species_embedding_dim":0,
    "sample_embedding_dim":512,
    "drug_embedding_dim":512,
    "n_hidden_layers": 5,
}
        

In [None]:
from maldi2resistance.model.MLP import AeBasedMLP

model = Residual_AMR_Classifier(config= conf,n_input_spectrum = 18000, n_input_drug= 1024)
model.load_state_dict(checkpoint['model_state_dict'])

model.to(DEVICE)

Residual_AMR_Classifier(
  (sample_emb): Identity()
  (spectrum_emb): Linear(in_features=18000, out_features=512, bias=True)
  (drug_emb): Linear(in_features=1024, out_features=512, bias=True)
  (net): ResMLP(
    (net): Sequential(
      (0): ResBlock(
        (block): Sequential(
          (0): ReLU()
          (1): Linear(in_features=1024, out_features=1024, bias=True)
          (2): Dropout(p=0.2, inplace=False)
          (3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): ResBlock(
        (block): Sequential(
          (0): ReLU()
          (1): Linear(in_features=1024, out_features=1024, bias=True)
          (2): Dropout(p=0.2, inplace=False)
          (3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): ResBlock(
        (block): Sequential(
          (0): ReLU()
          (1): Linear(in_features=1024, out_features=1024, bias=True)
          (2): Dropou

In [None]:
model = model.to(DEVICE)

In [None]:
output = []
test_labels = []

for test_features, labels, test_pos in tqdm(test_loader, leave=False):
    test_features = test_features.to(DEVICE)
    test_pos = test_pos.to(DEVICE)
    test_labels.append(labels)
    
    result = model(test_features, test_pos).detach().cpu()
    
    output.append(result)


  0%|          | 0/277 [00:00<?, ?it/s]

In [None]:
output = torch.cat(output).squeeze()
test_labels = torch.cat(test_labels).int()

In [None]:
from torchmetrics.classification import BinaryROC

metric = BinaryROC()
metric.update(output, test_labels)
metric.compute()

(tensor([0.0000, 0.0010, 0.0013,  ..., 1.0000, 1.0000, 1.0000]),
 tensor([0.0000, 0.1489, 0.1807,  ..., 1.0000, 1.0000, 1.0000]),
 tensor([1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 6.9361e-24, 3.5459e-24,
         4.4026e-25]))

In [None]:
from torchmetrics.classification import BinaryAUROC

auRoc = BinaryAUROC()
auc_roc = auRoc(torch.sigmoid(output), test_labels)
auc_roc

tensor(0.9175)

In [None]:
test_dataset_single_antibiotic = DriamsSingleAntibiotic(driams= test_dataset, use_morganFingerprint4Drug= False)
test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=len(test_dataset_single_antibiotic), shuffle=True)

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

In [None]:
output.shape

torch.Size([141408])

In [None]:
test_pos.shape

torch.Size([141408])

In [None]:
_, _, test_pos = next(iter(test_loader))

In [None]:
roc_auc = {}
macro_aucroc = 0

for pos, antibiotic in enumerate(driams.selected_antibiotics):
    out_part = output[test_pos == pos]
    label_part = test_labels[test_pos == pos]
    au_roc = auRoc(out_part, label_part.int())
    roc_auc[antibiotic] = au_roc
    macro_aucroc += au_roc.item()
    
    
macro_aucroc = macro_aucroc / len(driams.selected_antibiotics)
print(macro_aucroc)

0.917578105863772


In [None]:
roc_auc

{'Amikacin': tensor(0.9166),
 'Amoxicillin-Clavulanic acid': tensor(0.9206),
 'Ampicillin': tensor(0.9249),
 'Ampicillin-Amoxicillin': tensor(0.9189),
 'Aztreonam': tensor(0.9378),
 'Benzylpenicillin': tensor(0.9326),
 'Cefazolin': tensor(0.9040),
 'Cefepime': tensor(0.9245),
 'Cefpodoxime': tensor(0.9241),
 'Ceftazidime': tensor(0.9169),
 'Ceftriaxone': tensor(0.9193),
 'Cefuroxime': tensor(0.9175),
 'Ciprofloxacin': tensor(0.9141),
 'Clarithromycin': tensor(0.9335),
 'Clindamycin': tensor(0.9195),
 'Colistin': tensor(0.9123),
 'Cotrimoxazole': tensor(0.9153),
 'Ertapenem': tensor(0.9172),
 'Erythromycin': tensor(0.9132),
 'Fosfomycin': tensor(0.9134),
 'Fosfomycin-Trometamol': tensor(0.9146),
 'Fusidic acid': tensor(0.9010),
 'Gentamicin': tensor(0.9143),
 'Imipenem': tensor(0.9193),
 'Levofloxacin': tensor(0.9229),
 'Meropenem': tensor(0.9221),
 'Mupirocin': tensor(0.9040),
 'Nitrofurantoin': tensor(0.9083),
 'Norfloxacin': tensor(0.9155),
 'Oxacillin': tensor(0.9127),
 'Penicillin'