# Fully connected feedforward network (residual MLP) trained on single-antibiotic (spectrum, drug fingerprint) pairs

In [None]:
import torchmetrics.classification
from torch import cuda
# Fail fast when no CUDA device is visible: DEVICE below is set to "cuda".
assert cuda.is_available() and cuda.device_count() > 0

In [None]:
# Report which GPU this run uses.
active_device_index = cuda.current_device()
print(cuda.get_device_name(active_device_index))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
# Training device and the global seed used for weight initialisation and
# for the DataLoader generator in the cross-validation loop.
DEVICE = torch.device("cuda")
SEED = 76436278

# Trailing ';' suppresses the returned Generator's repr in the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7efeceb9afb0>

### Load the Dataset

In [None]:
from src.maldi2resistance.data.driams import Driams

# Root directory of the DRIAMS/UMG data. Machine-specific: adjust before running elsewhere.
DRIAMS_ROOT = "/home/jan/Uni/master/data/Driams"

driams = Driams(
    root_dir=DRIAMS_ROOT,
)
# UMG 2020/2021, restricted to the same antibiotic selection as DRIAMS.
umg = Driams(
    root_dir=DRIAMS_ROOT,
    bin_size=1,
    sites=["UMG"],
    years=[2020, 2021],
    antibiotics=driams.selected_antibiotics,
)

driams.loading_type = "memory"
umg.loading_type = "memory"

# Both datasets are concatenated for training a model whose input width is driams.n_bins.
assert umg.n_bins == driams.n_bins, "DRIAMS and UMG spectra must have the same number of bins"

driams



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/73745 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
# Number of antibiotic columns in the label statistics table.
n_label_columns = len(driams.label_stats.columns)
n_label_columns

38

In [None]:
from multimodal_amr.models.modules import ResMLP
from torch import nn


class Residual_AMR_Classifier(nn.Module):
    """Residual MLP that scores a (spectrum, drug fingerprint) pair with one logit.

    Args:
        config: Hyper-parameter dict. Required keys: ``conv_out_size``,
            ``species_embedding_dim``, ``sample_embedding_dim``,
            ``drug_embedding_dim``, ``n_hidden_layers``. Optional key:
            ``p_dropout`` (dropout passed to ``ResMLP``, default 0.2).
        n_input_spectrum: Number of features in each input spectrum.
        n_input_drug: Number of features in each drug fingerprint.
    """

    def __init__(self, config, n_input_spectrum, n_input_drug):
        super().__init__()
        self.config = config

        # Sample embedding: maps the spectrum embedding (plus a species
        # embedding, if configured) to sample_embedding_dim.
        if config["species_embedding_dim"] == 0 and config["conv_out_size"] == config["sample_embedding_dim"]:
            self.sample_emb = nn.Identity()
        elif config["species_embedding_dim"] == 0:
            self.sample_emb = nn.Linear(
                config["conv_out_size"],
                config["sample_embedding_dim"],
            )
        else:
            self.sample_emb = nn.Linear(
                config["conv_out_size"] + config["species_embedding_dim"],
                config["sample_embedding_dim"],
            )

        self.spectrum_emb = nn.Linear(n_input_spectrum, config["conv_out_size"])

        # Drug embedding: linear projection of the fingerprint input.
        self.drug_emb = nn.Linear(
            n_input_drug, config["drug_embedding_dim"]
        )

        # Output network with a single output unit (one logit per input row).
        self.net = ResMLP(
            config["n_hidden_layers"],
            config["sample_embedding_dim"] + config["drug_embedding_dim"],
            1,
            p_dropout=config.get("p_dropout", 0.2),
        )

    def forward(self, spectrum, fingerprint):
        """Return one logit per row for paired spectra and drug fingerprints."""
        spectrum_embedding = self.spectrum_emb(spectrum)
        if self.config["species_embedding_dim"] == 0:
            # Bring the spectrum embedding to sample_embedding_dim so the
            # concatenation matches the ResMLP input width. This is an Identity
            # when conv_out_size == sample_embedding_dim.
            sample_embedding = self.sample_emb(spectrum_embedding)
        else:
            # NOTE(review): forward() receives no species input, so the
            # species-aware sample_emb cannot be applied here; kept as before.
            sample_embedding = spectrum_embedding
        dr_emb = self.drug_emb(fingerprint)

        return self.net(torch.cat([dr_emb, sample_embedding], dim=1))




# Hyper-parameters for Residual_AMR_Classifier: embedding widths and ResMLP depth.
conf = dict(
    conv_out_size=512,
    species_embedding_dim=0,
    sample_embedding_dim=512,
    drug_embedding_dim=512,
    n_hidden_layers=5,
)
        

In [None]:
import copy

# Input widths: one feature per spectrum bin, and a 1024-feature drug input.
model = Residual_AMR_Classifier(
    config=conf,
    n_input_spectrum=driams.n_bins,
    n_input_drug=1024,
).to(DEVICE)

# Snapshot of the freshly initialised weights; each CV fold reloads it.
model_state = copy.deepcopy(model.state_dict())

In [None]:
from torchinfo import summary

model_summary = summary(model)
print(model_summary)

Layer (type:depth-idx)                        Param #
Residual_AMR_Classifier                       --
├─Identity: 1-1                               --
├─Linear: 1-2                                 9,216,512
├─Linear: 1-3                                 524,800
├─ResMLP: 1-4                                 --
│    └─Sequential: 2-1                        --
│    │    └─ResBlock: 3-1                     1,051,648
│    │    └─ResBlock: 3-2                     1,051,648
│    │    └─ResBlock: 3-3                     1,051,648
│    │    └─ResBlock: 3-4                     1,051,648
│    │    └─ResBlock: 3-5                     1,051,648
│    │    └─Linear: 3-6                       1,025
Total params: 15,000,577
Trainable params: 15,000,577
Non-trainable params: 0


In [None]:
from torch.utils.data import DataLoader

In [None]:
from torchmetrics.utilities.compute import auc
from torchmetrics.classification import BinaryAUROC, BinaryPrecisionRecallCurve
from maldi2resistance.data.driams import DriamsSingleAntibiotic
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# Cross-validation settings.
N_SPLITS = 5
N_EPOCHS = 30
batch_size = 128
# Each test fold is scored in one batch on the CPU; training stays on DEVICE.
# A separate constant avoids reassigning the global DEVICE inside the loop.
EVAL_DEVICE = torch.device("cpu")


def _auc_roc(scores, labels):
    """AUROC of ``scores`` vs. binary ``labels`` using a fresh metric (no state carried over)."""
    roc_metric = BinaryAUROC()
    roc_metric.update(scores, labels.int())
    return roc_metric.compute()


def _auc_pr(scores, labels):
    """Area under the precision-recall curve using a fresh metric (no state carried over)."""
    pr_metric = BinaryPrecisionRecallCurve()
    pr_metric.update(scores, labels.int())
    precision, recall, _ = pr_metric.compute()
    return auc(recall, precision)


print("Start training ...")

fig_path = Path("./kfold/figures")
fig_path.mkdir(parents=True, exist_ok=True)

loss_per_batch = []

gen = torch.Generator()

all_auc_roc_micro = []
all_auc_pr_micro = []

all_auc_roc_macro = []
all_auc_pr_macro = []

for fold, (train_data, test_data) in enumerate(driams.getK_fold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)):
    # Train on the DRIAMS training split plus all UMG data; test on the DRIAMS split only.
    train_data_umg = torch.utils.data.ConcatDataset([train_data, umg])

    train_dataset_single_antibiotic = DriamsSingleAntibiotic(driams=train_data_umg, use_morganFingerprint4Drug=True)
    test_dataset_single_antibiotic = DriamsSingleAntibiotic(driams=test_data, use_morganFingerprint4Drug=True)

    train_loader = DataLoader(train_dataset_single_antibiotic, batch_size=batch_size, shuffle=True, drop_last=True, generator=gen.manual_seed(SEED))
    # No shuffling: the whole fold is one batch, and its sample order must match
    # the antibiotic-index loader built for the macro metrics below.
    test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=len(test_dataset_single_antibiotic), shuffle=False)

    # Every fold starts from the same initial weights.
    model.load_state_dict(model_state)
    model.train()

    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad=True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    for epoch in tqdm(range(N_EPOCHS), leave=False, position=1):
        overall_loss = 0.0
        n_batches = 0

        for x, y, pos in train_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            pos = pos.to(DEVICE)

            optimizer.zero_grad()

            output = model(x, pos)

            loss = F.binary_cross_entropy_with_logits(torch.squeeze(output), y)
            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)

            overall_loss += current_loss_value
            n_batches += 1

            loss.backward()
            optimizer.step()

        scheduler.step()
        with tqdm.external_write_mode():
            # `loss` is already a mean over its batch, so average over batches, not samples.
            print(f"\tAverage Loss: {overall_loss / max(n_batches, 1):.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")

    print(f"Finished Fold {fold}")

    # ---- Evaluation ----
    model.eval()
    model = model.to(EVAL_DEVICE)

    test_features, test_labels, test_fingerprints = next(iter(test_loader))
    test_features = test_features.to(EVAL_DEVICE)
    test_labels = test_labels.to(EVAL_DEVICE)
    test_fingerprints = test_fingerprints.to(EVAL_DEVICE)

    # Scoring only: do not build an autograd graph for the full test fold.
    with torch.no_grad():
        output = torch.squeeze(model(test_features, test_fingerprints))

    # Micro metrics: all (spectrum, antibiotic) pairs pooled.
    all_auc_roc_micro.append(_auc_roc(output, test_labels))
    all_auc_pr_micro.append(_auc_pr(output, test_labels))

    # Same fold again, yielding the antibiotic index instead of the fingerprint.
    # Assumes DriamsSingleAntibiotic enumerates samples in the same order for both
    # settings; the label comparison below is a cheap check of that assumption.
    index_dataset = DriamsSingleAntibiotic(driams=test_data, use_morganFingerprint4Drug=False)
    index_loader = DataLoader(index_dataset, batch_size=len(index_dataset), shuffle=False)
    _, index_labels, test_pos = next(iter(index_loader))
    test_pos = test_pos.to(EVAL_DEVICE)
    assert torch.equal(index_labels.to(test_labels), test_labels), "test loaders are not aligned"

    # Macro metrics: mean over antibiotics whose test labels contain both classes.
    macro_aucroc = 0.0
    macro_aucpr = 0.0
    n_not_empty = 0

    for pos in range(len(driams.selected_antibiotics)):
        is_current = test_pos == pos
        out_part = output[is_current]
        label_part = test_labels[is_current].int()

        n_positive = int((label_part == 1).sum())
        n_negative = int((label_part == 0).sum())
        if n_positive == 0 or n_negative == 0:
            # AUROC / AUPR are undefined without both classes.
            continue

        macro_aucroc += _auc_roc(out_part, label_part).item()
        macro_aucpr += _auc_pr(out_part, label_part).item()
        n_not_empty += 1

    all_auc_roc_macro.append(macro_aucroc / n_not_empty if n_not_empty else float("nan"))
    all_auc_pr_macro.append(macro_aucpr / n_not_empty if n_not_empty else float("nan"))

    model = model.to(DEVICE)
    

Start training ...


Create single label Dataset:   0%|          | 0/118369 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.002513 	Learning rate: 0.001000
	Average Loss: 0.002010 	Learning rate: 0.001000
	Average Loss: 0.001838 	Learning rate: 0.001000
	Average Loss: 0.001713 	Learning rate: 0.001000
	Average Loss: 0.001623 	Learning rate: 0.001000
	Average Loss: 0.001547 	Learning rate: 0.001000
	Average Loss: 0.001488 	Learning rate: 0.001000
	Average Loss: 0.001434 	Learning rate: 0.001000
	Average Loss: 0.001376 	Learning rate: 0.001000
	Average Loss: 0.001326 	Learning rate: 0.000500
	Average Loss: 0.001179 	Learning rate: 0.000500
	Average Loss: 0.001129 	Learning rate: 0.000500
	Average Loss: 0.001094 	Learning rate: 0.000500
	Average Loss: 0.001057 	Learning rate: 0.000500
	Average Loss: 0.001026 	Learning rate: 0.000500
	Average Loss: 0.000996 	Learning rate: 0.000500
	Average Loss: 0.000967 	Learning rate: 0.000500
	Average Loss: 0.000938 	Learning rate: 0.000500
	Average Loss: 0.000911 	Learning rate: 0.000500
	Average Loss: 0.000885 	Learning rate: 0.000250
	Average Loss: 0.000

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/118369 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.002521 	Learning rate: 0.001000
	Average Loss: 0.002017 	Learning rate: 0.001000
	Average Loss: 0.001845 	Learning rate: 0.001000
	Average Loss: 0.001726 	Learning rate: 0.001000
	Average Loss: 0.001637 	Learning rate: 0.001000
	Average Loss: 0.001563 	Learning rate: 0.001000
	Average Loss: 0.001497 	Learning rate: 0.001000
	Average Loss: 0.001441 	Learning rate: 0.001000
	Average Loss: 0.001386 	Learning rate: 0.001000
	Average Loss: 0.001335 	Learning rate: 0.000500
	Average Loss: 0.001184 	Learning rate: 0.000500
	Average Loss: 0.001138 	Learning rate: 0.000500
	Average Loss: 0.001100 	Learning rate: 0.000500
	Average Loss: 0.001064 	Learning rate: 0.000500
	Average Loss: 0.001032 	Learning rate: 0.000500
	Average Loss: 0.001003 	Learning rate: 0.000500
	Average Loss: 0.000972 	Learning rate: 0.000500
	Average Loss: 0.000944 	Learning rate: 0.000500
	Average Loss: 0.000919 	Learning rate: 0.000500
	Average Loss: 0.000891 	Learning rate: 0.000250
	Average Loss: 0.000

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/118369 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.002533 	Learning rate: 0.001000
	Average Loss: 0.002025 	Learning rate: 0.001000
	Average Loss: 0.001854 	Learning rate: 0.001000
	Average Loss: 0.001730 	Learning rate: 0.001000
	Average Loss: 0.001639 	Learning rate: 0.001000
	Average Loss: 0.001569 	Learning rate: 0.001000
	Average Loss: 0.001502 	Learning rate: 0.001000
	Average Loss: 0.001446 	Learning rate: 0.001000
	Average Loss: 0.001391 	Learning rate: 0.001000
	Average Loss: 0.001340 	Learning rate: 0.000500
	Average Loss: 0.001193 	Learning rate: 0.000500
	Average Loss: 0.001147 	Learning rate: 0.000500
	Average Loss: 0.001112 	Learning rate: 0.000500
	Average Loss: 0.001077 	Learning rate: 0.000500
	Average Loss: 0.001048 	Learning rate: 0.000500
	Average Loss: 0.001014 	Learning rate: 0.000500
	Average Loss: 0.000989 	Learning rate: 0.000500
	Average Loss: 0.000961 	Learning rate: 0.000500
	Average Loss: 0.000935 	Learning rate: 0.000500
	Average Loss: 0.000912 	Learning rate: 0.000250
	Average Loss: 0.000

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/118369 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.002516 	Learning rate: 0.001000
	Average Loss: 0.002016 	Learning rate: 0.001000
	Average Loss: 0.001849 	Learning rate: 0.001000
	Average Loss: 0.001727 	Learning rate: 0.001000
	Average Loss: 0.001640 	Learning rate: 0.001000
	Average Loss: 0.001561 	Learning rate: 0.001000
	Average Loss: 0.001499 	Learning rate: 0.001000
	Average Loss: 0.001440 	Learning rate: 0.001000
	Average Loss: 0.001387 	Learning rate: 0.001000
	Average Loss: 0.001338 	Learning rate: 0.000500
	Average Loss: 0.001190 	Learning rate: 0.000500
	Average Loss: 0.001145 	Learning rate: 0.000500
	Average Loss: 0.001106 	Learning rate: 0.000500
	Average Loss: 0.001074 	Learning rate: 0.000500
	Average Loss: 0.001040 	Learning rate: 0.000500
	Average Loss: 0.001009 	Learning rate: 0.000500
	Average Loss: 0.000982 	Learning rate: 0.000500
	Average Loss: 0.000952 	Learning rate: 0.000500
	Average Loss: 0.000925 	Learning rate: 0.000500
	Average Loss: 0.000900 	Learning rate: 0.000250
	Average Loss: 0.000

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/118369 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.002516 	Learning rate: 0.001000
	Average Loss: 0.002018 	Learning rate: 0.001000
	Average Loss: 0.001848 	Learning rate: 0.001000
	Average Loss: 0.001727 	Learning rate: 0.001000
	Average Loss: 0.001635 	Learning rate: 0.001000
	Average Loss: 0.001561 	Learning rate: 0.001000
	Average Loss: 0.001504 	Learning rate: 0.001000
	Average Loss: 0.001444 	Learning rate: 0.001000
	Average Loss: 0.001388 	Learning rate: 0.001000
	Average Loss: 0.001337 	Learning rate: 0.000500
	Average Loss: 0.001188 	Learning rate: 0.000500
	Average Loss: 0.001143 	Learning rate: 0.000500
	Average Loss: 0.001104 	Learning rate: 0.000500
	Average Loss: 0.001069 	Learning rate: 0.000500
	Average Loss: 0.001038 	Learning rate: 0.000500
	Average Loss: 0.001008 	Learning rate: 0.000500
	Average Loss: 0.000978 	Learning rate: 0.000500
	Average Loss: 0.000950 	Learning rate: 0.000500
	Average Loss: 0.000922 	Learning rate: 0.000500
	Average Loss: 0.000898 	Learning rate: 0.000250
	Average Loss: 0.000

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

In [None]:
import numpy as np

# Per-fold micro AUROC values (tensors moved to CPU for numpy).
roc_micro_per_fold = [x.cpu() for x in all_auc_roc_micro]
print(f"Mean\t: {np.mean(roc_micro_per_fold)}")
print(f" SD \t: {np.std(roc_micro_per_fold) }")

Mean	: 0.925898551940918
 SD 	: 0.002435918664559722


In [None]:
# Per-fold micro AUPR values (tensors moved to CPU for numpy).
pr_micro_per_fold = [x.cpu() for x in all_auc_pr_micro]
print(f"Mean\t: {np.mean(pr_micro_per_fold)}")
print(f" SD \t: {np.std(pr_micro_per_fold) }")

Mean	: 0.85932856798172
 SD 	: 0.006041183602064848


In [None]:
import numpy as np

# Per-fold macro AUROC values (plain floats).
roc_macro_per_fold = list(all_auc_roc_macro)
print(f"Mean\t: {np.mean(roc_macro_per_fold)}")
print(f" SD \t: {np.std(roc_macro_per_fold) }")

Mean	: 0.925540589345129
 SD 	: 0.0035438631055176298


In [None]:
# Per-fold macro AUPR values (plain floats).
pr_macro_per_fold = list(all_auc_pr_macro)
print(f"Mean\t: {np.mean(pr_macro_per_fold)}")
print(f" SD \t: {np.std(pr_macro_per_fold) }")

Mean	: 0.8594055869077382
 SD 	: 0.005953843206352955
