# Dual-branch network (one-hot antibiotic encoding) trained on single-antibiotic samples with 5-fold cross-validation

In [None]:
import torchmetrics.classification
from torch import cuda
# Fail fast when no CUDA device is usable: the cells below place the model
# and the training batches on the GPU.
assert cuda.is_available() and cuda.device_count() > 0

In [None]:
# Report which GPU is currently selected.
active_device_index = cuda.current_device()
print(cuda.get_device_name(active_device_index))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
# Seed shared by the global torch RNG, the cross-validation split and the
# DataLoader generators used further down.
SEED = 76436278

# Device on which the model is trained.
DEVICE = torch.device("cuda")

torch.manual_seed(SEED)

<torch._C.Generator at 0x7f6700356fb0>

### Load the Dataset

In [None]:
import os

# Import from the same package root as every other cell in this notebook.
# Mixing `src.maldi2resistance` and `maldi2resistance` loads the package twice
# under two different module names.
from maldi2resistance.data.driams import Driams

# Location of the DRIAMS dataset. Set the DRIAMS_ROOT environment variable to
# run on another machine; the default keeps the original location.
DRIAMS_ROOT = os.environ.get("DRIAMS_ROOT", "/home/jan/Uni/master/data/Driams")

driams = Driams(
    root_dir=DRIAMS_ROOT,
)

# Load the spectra into memory (see the progress bar in the cell output).
driams.loading_type = "memory"

# Last expression: display the dataset summary.
driams

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
# Number of columns in the label statistics table
# (presumably one per antibiotic; the summary above lists 38 of them).
len(driams.label_stats.columns)

38

In [None]:
from maldi2resistance.model.dualBranch import DualBranchOneHot
import copy
from maldi2resistance.model.singleBranchMlp import SingleBranchMLP

# The drug input is sized to the number of selected antibiotics.
# NOTE(review): 18000 is presumably the length of a spectrum vector; confirm.
drug_input_dim = len(driams.selected_antibiotics)
model = DualBranchOneHot(input_dim_spectrum=18000, input_dim_drug=drug_input_dim)
model = model.to(DEVICE)

# Snapshot of the freshly initialised weights; the training cell reloads it
# at the start of every cross-validation fold.
model_state = copy.deepcopy(model.state_dict())

In [None]:
from torch.utils.data import DataLoader

In [None]:
from torchmetrics.utilities.compute import auc
from torchmetrics.classification import BinaryAUROC, BinaryPrecisionRecallCurve
from maldi2resistance.data.driams import DriamsSingleAntibiotic
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# k-fold cross-validation: train a freshly initialised model per fold, then
# compute micro and macro ROC-AUC / PR-AUC on the held-out split.

N_SPLITS = 5
N_EPOCHS = 30
LEARNING_RATE = 1e-3
# Evaluation runs on the CPU because the whole test split is fed as one batch.
EVAL_DEVICE = torch.device("cpu")


def _pr_auc(preds, target):
    """Return the area under the precision-recall curve for one set of predictions.

    A new metric object is created on every call. `BinaryPrecisionRecallCurve`
    accumulates everything passed to `update`, so a reused instance would mix
    the samples of earlier calls into the curve.
    """
    curve = BinaryPrecisionRecallCurve()
    curve.update(preds, target)
    precision, recall, _ = curve.compute()
    return auc(recall, precision)


print("Start training ...")
model.train()

batch_size = 128
fig_path = Path("./kfold/figures")
fig_path.mkdir(parents=True, exist_ok=True)

loss_per_batch = []

gen = torch.Generator()

# One entry per fold. The micro lists hold 0-dim tensors, the macro lists floats.
all_auc_roc_micro = []
all_auc_pr_micro = []

all_auc_roc_macro = []
all_auc_pr_macro = []

for fold, (train_data, test_data) in enumerate(driams.getK_fold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)):

    train_dataset_single_antibiotic = DriamsSingleAntibiotic(driams=train_data)
    test_dataset_single_antibiotic = DriamsSingleAntibiotic(driams=test_data)

    train_loader = DataLoader(train_dataset_single_antibiotic, batch_size=batch_size, shuffle=True, drop_last=True, generator=gen.manual_seed(SEED))
    # The whole test split is delivered as a single batch.
    test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=len(test_dataset_single_antibiotic), shuffle=True, drop_last=True, generator=gen.manual_seed(SEED))

    # Start every fold from the same initial weights, on the training device.
    # Moving the model here (not only at the end of the previous fold) keeps
    # the cell re-runnable after an interrupted evaluation.
    model = model.to(DEVICE)
    model.load_state_dict(model_state)
    model.train()

    optimizer = Adam(model.parameters(), lr=LEARNING_RATE, amsgrad=True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    for epoch in tqdm(range(N_EPOCHS), leave=False, position=1):
        overall_loss = 0.0
        n_batches = 0

        for x, y, pos in train_loader:

            x = x.to(DEVICE)
            y = y.to(DEVICE)
            pos = pos.to(DEVICE)

            optimizer.zero_grad()

            output = model(x, pos)

            loss = F.binary_cross_entropy_with_logits(output, y)
            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)

            overall_loss += current_loss_value
            n_batches += 1

            loss.backward()
            optimizer.step()

        scheduler.step()
        with tqdm.external_write_mode():
            # Each batch loss is already a mean over the batch (default
            # reduction), so the epoch average divides by the batch count only.
            print(f"\tAverage Loss: {overall_loss / max(n_batches, 1):.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")


    print(f"Finished Fold {fold}")

    model.eval()
    model = model.to(EVAL_DEVICE)

    test_features, test_labels, test_pos = next(iter(test_loader))
    test_features = test_features.to(EVAL_DEVICE)
    test_labels = test_labels.to(EVAL_DEVICE)
    test_pos = test_pos.to(EVAL_DEVICE)

    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        output = model(test_features, test_pos)

    test_targets = test_labels.int()

    # Micro metrics: all (sample, antibiotic) pairs pooled together.
    all_auc_roc_micro.append(BinaryAUROC()(output, test_targets))
    all_auc_pr_micro.append(_pr_auc(output, test_targets))

    # Macro metrics: one score per antibiotic, averaged over the antibiotics
    # that have both classes present in this test split.
    per_drug_aucroc = []
    per_drug_aucpr = []

    for drug_idx in range(len(driams.selected_antibiotics)):
        drug_mask = test_pos == drug_idx
        drug_targets = test_targets[drug_mask]

        n_positive = int((drug_targets == 1).sum())
        n_negative = int((drug_targets == 0).sum())
        if n_positive == 0 or n_negative == 0:
            # AUCs are undefined without both classes.
            continue

        drug_output = output[drug_mask]

        # Fresh metric objects per antibiotic, so nothing from the pooled
        # test set or from other antibiotics leaks into the score.
        per_drug_aucroc.append(BinaryAUROC()(drug_output, drug_targets).item())
        per_drug_aucpr.append(_pr_auc(drug_output, drug_targets).item())

    macro_aucroc = sum(per_drug_aucroc) / len(per_drug_aucroc)
    macro_aucpr = sum(per_drug_aucpr) / len(per_drug_aucpr)

    all_auc_roc_macro.append(macro_aucroc)
    all_auc_pr_macro.append(macro_aucpr)

# Leave the model on the training device, as the original cell did.
model = model.to(DEVICE)
    

Start training ...


Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.002510 	Learning rate: 0.001000
	Average Loss: 0.002118 	Learning rate: 0.001000
	Average Loss: 0.001957 	Learning rate: 0.001000
	Average Loss: 0.001845 	Learning rate: 0.001000
	Average Loss: 0.001762 	Learning rate: 0.001000
	Average Loss: 0.001691 	Learning rate: 0.001000
	Average Loss: 0.001632 	Learning rate: 0.001000
	Average Loss: 0.001579 	Learning rate: 0.001000
	Average Loss: 0.001535 	Learning rate: 0.001000
	Average Loss: 0.001504 	Learning rate: 0.000500
	Average Loss: 0.001387 	Learning rate: 0.000500
	Average Loss: 0.001352 	Learning rate: 0.000500
	Average Loss: 0.001332 	Learning rate: 0.000500
	Average Loss: 0.001312 	Learning rate: 0.000500
	Average Loss: 0.001290 	Learning rate: 0.000500
	Average Loss: 0.001268 	Learning rate: 0.000500
	Average Loss: 0.001251 	Learning rate: 0.000500
	Average Loss: 0.001231 	Learning rate: 0.000500
	Average Loss: 0.001220 	Learning rate: 0.000500
	Average Loss: 0.001200 	Learning rate: 0.000250
	Average Loss: 0.001

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.002516 	Learning rate: 0.001000
	Average Loss: 0.002126 	Learning rate: 0.001000
	Average Loss: 0.001970 	Learning rate: 0.001000
	Average Loss: 0.001864 	Learning rate: 0.001000
	Average Loss: 0.001776 	Learning rate: 0.001000
	Average Loss: 0.001716 	Learning rate: 0.001000
	Average Loss: 0.001645 	Learning rate: 0.001000
	Average Loss: 0.001601 	Learning rate: 0.001000
	Average Loss: 0.001557 	Learning rate: 0.001000
	Average Loss: 0.001514 	Learning rate: 0.000500
	Average Loss: 0.001396 	Learning rate: 0.000500
	Average Loss: 0.001367 	Learning rate: 0.000500
	Average Loss: 0.001338 	Learning rate: 0.000500
	Average Loss: 0.001318 	Learning rate: 0.000500
	Average Loss: 0.001296 	Learning rate: 0.000500
	Average Loss: 0.001279 	Learning rate: 0.000500
	Average Loss: 0.001259 	Learning rate: 0.000500
	Average Loss: 0.001238 	Learning rate: 0.000500
	Average Loss: 0.001224 	Learning rate: 0.000500
	Average Loss: 0.001205 	Learning rate: 0.000250
	Average Loss: 0.001

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.002516 	Learning rate: 0.001000
	Average Loss: 0.002131 	Learning rate: 0.001000
	Average Loss: 0.001972 	Learning rate: 0.001000
	Average Loss: 0.001856 	Learning rate: 0.001000
	Average Loss: 0.001776 	Learning rate: 0.001000
	Average Loss: 0.001701 	Learning rate: 0.001000
	Average Loss: 0.001643 	Learning rate: 0.001000
	Average Loss: 0.001596 	Learning rate: 0.001000
	Average Loss: 0.001553 	Learning rate: 0.001000
	Average Loss: 0.001510 	Learning rate: 0.000500
	Average Loss: 0.001393 	Learning rate: 0.000500
	Average Loss: 0.001363 	Learning rate: 0.000500
	Average Loss: 0.001339 	Learning rate: 0.000500
	Average Loss: 0.001320 	Learning rate: 0.000500
	Average Loss: 0.001298 	Learning rate: 0.000500
	Average Loss: 0.001271 	Learning rate: 0.000500
	Average Loss: 0.001254 	Learning rate: 0.000500
	Average Loss: 0.001237 	Learning rate: 0.000500
	Average Loss: 0.001217 	Learning rate: 0.000500
	Average Loss: 0.001205 	Learning rate: 0.000250
	Average Loss: 0.001

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.002505 	Learning rate: 0.001000
	Average Loss: 0.002119 	Learning rate: 0.001000
	Average Loss: 0.001960 	Learning rate: 0.001000
	Average Loss: 0.001849 	Learning rate: 0.001000
	Average Loss: 0.001763 	Learning rate: 0.001000
	Average Loss: 0.001698 	Learning rate: 0.001000
	Average Loss: 0.001639 	Learning rate: 0.001000
	Average Loss: 0.001593 	Learning rate: 0.001000
	Average Loss: 0.001547 	Learning rate: 0.001000
	Average Loss: 0.001508 	Learning rate: 0.000500
	Average Loss: 0.001384 	Learning rate: 0.000500
	Average Loss: 0.001359 	Learning rate: 0.000500
	Average Loss: 0.001335 	Learning rate: 0.000500
	Average Loss: 0.001309 	Learning rate: 0.000500
	Average Loss: 0.001291 	Learning rate: 0.000500
	Average Loss: 0.001268 	Learning rate: 0.000500
	Average Loss: 0.001254 	Learning rate: 0.000500
	Average Loss: 0.001231 	Learning rate: 0.000500
	Average Loss: 0.001216 	Learning rate: 0.000500
	Average Loss: 0.001199 	Learning rate: 0.000250
	Average Loss: 0.001

Create single label Dataset:   0%|          | 0/44624 [00:01<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.002509 	Learning rate: 0.001000
	Average Loss: 0.002128 	Learning rate: 0.001000
	Average Loss: 0.001974 	Learning rate: 0.001000
	Average Loss: 0.001866 	Learning rate: 0.001000
	Average Loss: 0.001778 	Learning rate: 0.001000
	Average Loss: 0.001706 	Learning rate: 0.001000
	Average Loss: 0.001646 	Learning rate: 0.001000
	Average Loss: 0.001597 	Learning rate: 0.001000
	Average Loss: 0.001551 	Learning rate: 0.001000
	Average Loss: 0.001511 	Learning rate: 0.000500
	Average Loss: 0.001401 	Learning rate: 0.000500
	Average Loss: 0.001366 	Learning rate: 0.000500
	Average Loss: 0.001347 	Learning rate: 0.000500
	Average Loss: 0.001322 	Learning rate: 0.000500
	Average Loss: 0.001303 	Learning rate: 0.000500
	Average Loss: 0.001279 	Learning rate: 0.000500
	Average Loss: 0.001262 	Learning rate: 0.000500
	Average Loss: 0.001242 	Learning rate: 0.000500
	Average Loss: 0.001230 	Learning rate: 0.000500
	Average Loss: 0.001211 	Learning rate: 0.000250
	Average Loss: 0.001

In [None]:
import numpy as np

# Micro ROC-AUC across folds (the per-fold scores are 0-dim tensors).
roc_micro_per_fold = [fold_score.cpu() for fold_score in all_auc_roc_micro]
print(f"Mean\t: {np.mean(roc_micro_per_fold)}")
print(f" SD \t: {np.std(roc_micro_per_fold)}")

Mean	: 0.9237663149833679
 SD 	: 0.0025326688773930073


In [None]:
# Micro PR-AUC across folds (the per-fold scores are 0-dim tensors).
pr_micro_per_fold = [fold_score.cpu() for fold_score in all_auc_pr_micro]
print(f"Mean\t: {np.mean(pr_micro_per_fold)}")
print(f" SD \t: {np.std(pr_micro_per_fold)}")

Mean	: 0.8487552404403687
 SD 	: 0.006737273186445236


In [None]:
import numpy as np

# Macro ROC-AUC across folds (the per-fold scores are plain floats).
roc_macro_mean = np.mean(all_auc_roc_macro)
roc_macro_sd = np.std(all_auc_roc_macro)
print(f"Mean\t: {roc_macro_mean}")
print(f" SD \t: {roc_macro_sd}")

Mean	: 0.8945469326094578
 SD 	: 0.0015447874055492755


In [None]:
# Macro PR-AUC across folds (the per-fold scores are plain floats).
pr_macro_mean = np.mean(all_auc_pr_macro)
pr_macro_sd = np.std(all_auc_pr_macro)
print(f"Mean\t: {pr_macro_mean}")
print(f" SD \t: {pr_macro_sd}")

Mean	: 0.8549697016414843
 SD 	: 0.006542385950118931
