# Fully connected feedforward network implementing a loss mask

In [None]:
import torchmetrics.classification
from torch import cuda

# Fail fast when no CUDA device is usable: later cells build
# torch.device("cuda") and move the model and batches onto it.
assert cuda.is_available()
assert cuda.device_count() >= 1

In [None]:
# Report which GPU this kernel will train on.
active_gpu_index = cuda.current_device()
print(cuda.get_device_name(active_gpu_index))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
# Global configuration: reproducibility seed and the training device.
SEED = 76436278
DEVICE = torch.device("cuda")

# Seed PyTorch's random number generators; the returned Generator is the
# cell's displayed output.
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f87a3506fb0>

### Load the Dataset

In [None]:
import os

# Import from the same package path as every other cell (maldi2resistance.*).
# Mixing "src.maldi2resistance" and "maldi2resistance" loads the module twice,
# yielding two distinct Driams classes.
from maldi2resistance.data.driams import Driams

# Root directory of the DRIAMS data. Set the DRIAMS_ROOT environment variable to
# point elsewhere instead of editing this cell; the default is the original path.
DRIAMS_ROOT = os.environ.get("DRIAMS_ROOT", "/home/jan/Uni/master/data/Driams")

driams = Driams(
    root_dir=DRIAMS_ROOT,
)

# Load all spectra into RAM up front (the cell output shows the
# "Loading Spectra into Memory" progress bar).
driams.loading_type = "memory"

driams



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
# Number of label-statistics columns (the displayed table has one column per antibiotic).
label_columns = driams.label_stats.columns
len(label_columns)

38

In [None]:
import copy

from maldi2resistance.model.dualBranch import DualBranchOneHot
from maldi2resistance.model.singleBranchMlp import SingleBranchMLP

# Input size of the spectrum branch.
SPECTRUM_DIM = 18000
# The drug branch takes one input per selected antibiotic.
n_antibiotics = len(driams.selected_antibiotics)

model = DualBranchOneHot(input_dim_spectrum=SPECTRUM_DIM, input_dim_drug=n_antibiotics).to(DEVICE)

# Snapshot of the initial weights; the k-fold loop reloads it so that every
# fold starts from the same initialisation.
model_state = copy.deepcopy(model.state_dict())

In [None]:
from torch.utils.data import DataLoader

In [None]:
from maldi2resistance.loss.asymmetricLoss import AsymmetricLoss
from torchmetrics.classification import BinaryAUROC, BinaryAveragePrecision
from maldi2resistance.data.driams import DriamsSingleAntibiotic
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

print("Start training ...")
model.train()

batch_size = 128
fig_path = Path("./kfold/figures")
fig_path.mkdir(parents=True, exist_ok=True)

# Loss value of every training batch, accumulated over all folds and epochs.
loss_per_batch = []

# Re-seeded at the start of every fold so each fold's training shuffle is reproducible.
gen = torch.Generator()

# One entry per fold.
#   micro: every (spectrum, antibiotic) pair of the test split pooled together.
#   macro: unweighted mean of the per-antibiotic scores.
all_auc_roc_micro = []
all_auc_pr_micro = []

all_auc_roc_macro = []
all_auc_pr_macro = []

# The whole test split is scored as a single batch, so evaluation runs on the CPU.
# DEVICE itself is never reassigned; it remains the training device for later cells.
EVAL_DEVICE = torch.device("cpu")


def _has_both_classes(labels):
    """Return True if the 0/1 label tensor contains at least one 0 and one 1.

    AUROC and average precision are undefined when only one class is present.
    """
    if labels.numel() == 0:
        return False
    # minlength=2 avoids an IndexError when no positive label is present.
    counts = torch.bincount(labels.int(), minlength=2)
    return counts[0].item() > 0 and counts[1].item() > 0


def _macro_scores(output, labels, positions, n_antibiotics):
    """Average the per-antibiotic AUROC and average precision.

    ``positions`` holds the antibiotic index of each sample. Antibiotics with
    no test samples, or with only one class, are skipped. Returns a tuple
    ``(macro_auroc, macro_average_precision)``. It is ``(nan, nan)`` when no
    antibiotic can be scored, instead of dividing by zero.
    """
    roc_scores = []
    pr_scores = []
    for antibiotic_idx in range(n_antibiotics):
        selected = positions == antibiotic_idx
        out_part = output[selected]
        label_part = labels[selected]
        if not _has_both_classes(label_part):
            continue
        roc_scores.append(BinaryAUROC()(out_part, label_part.int()).item())
        pr_scores.append(BinaryAveragePrecision()(out_part, label_part.int()).item())

    if not roc_scores:
        print("Warning: no antibiotic in this test split has both classes; macro scores set to NaN")
        return float("nan"), float("nan")
    return sum(roc_scores) / len(roc_scores), sum(pr_scores) / len(pr_scores)


for fold, (train_data, test_data) in enumerate(driams.getK_fold(n_splits=5, shuffle=True, random_state=SEED)):

    train_dataset_single_antibiotic = DriamsSingleAntibiotic(driams=train_data)
    test_dataset_single_antibiotic = DriamsSingleAntibiotic(driams=test_data)

    train_loader = DataLoader(
        train_dataset_single_antibiotic,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        generator=gen.manual_seed(SEED),
    )
    # The whole test split forms one batch. Sample order does not affect AUROC or AP,
    # so no shuffling is needed.
    test_loader = DataLoader(
        test_dataset_single_antibiotic,
        batch_size=len(test_dataset_single_antibiotic),
        shuffle=False,
    )

    # Every fold starts from the same initial weights.
    model.load_state_dict(model_state)
    model.train()

    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad=True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
    metric = AsymmetricLoss()

    for epoch in tqdm(range(30), leave=False, position=1):
        overall_loss = 0.0
        n_batches = 0

        for x, y, pos in train_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            pos = pos.to(DEVICE)

            optimizer.zero_grad()
            output = model(x, pos)
            loss = metric(output, y)

            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)
            overall_loss += current_loss_value
            n_batches += 1

            loss.backward()
            optimizer.step()

        scheduler.step()

        # drop_last=True guarantees every counted batch holds exactly batch_size samples.
        # Dividing by the batch count (not the last batch index) fixes the former
        # off-by-one error, which also crashed when only one batch existed.
        n_samples = n_batches * batch_size
        average_loss = overall_loss / n_samples if n_samples else float("nan")
        with tqdm.external_write_mode():
            print(f"\tAverage Loss: {average_loss:.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")

    print(f"Finished Fold {fold}")

    model.eval()
    model = model.to(EVAL_DEVICE)

    test_features, test_labels, test_pos = next(iter(test_loader))
    test_features = test_features.to(EVAL_DEVICE)
    test_labels = test_labels.to(EVAL_DEVICE)
    test_pos = test_pos.to(EVAL_DEVICE)

    # Inference only: avoid building an autograd graph for the large single test batch.
    with torch.no_grad():
        output = model(test_features, test_pos)

    auc_roc = BinaryAUROC()(output, test_labels.int())
    all_auc_roc_micro.append(auc_roc)

    aucPC = BinaryAveragePrecision()(output, test_labels.int())
    all_auc_pr_micro.append(aucPC)

    macro_aucroc, macro_aucpr = _macro_scores(
        output, test_labels, test_pos, len(driams.selected_antibiotics)
    )
    all_auc_roc_macro.append(macro_aucroc)
    all_auc_pr_macro.append(macro_aucpr)

    # Back onto the training device for the next fold and for later cells.
    model = model.to(DEVICE)
    

Start training ...


Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 148.447481 	Learning rate: 0.001000
	Average Loss: 93.102000 	Learning rate: 0.001000
	Average Loss: 1887.965994 	Learning rate: 0.001000
	Average Loss: 3259.843283 	Learning rate: 0.001000
	Average Loss: 485.612477 	Learning rate: 0.001000
	Average Loss: 4.699589 	Learning rate: 0.001000
	Average Loss: 218.451229 	Learning rate: 0.001000
	Average Loss: 139.064096 	Learning rate: 0.001000
	Average Loss: 428.892707 	Learning rate: 0.001000
	Average Loss: 654.423848 	Learning rate: 0.000500
	Average Loss: 19.724870 	Learning rate: 0.000500
	Average Loss: 2.812002 	Learning rate: 0.000500
	Average Loss: 2.424202 	Learning rate: 0.000500
	Average Loss: 2.249009 	Learning rate: 0.000500
	Average Loss: 2.581475 	Learning rate: 0.000500
	Average Loss: 2.622135 	Learning rate: 0.000500
	Average Loss: 2.454060 	Learning rate: 0.000500
	Average Loss: 2.387503 	Learning rate: 0.000500
	Average Loss: 2.304877 	Learning rate: 0.000500
	Average Loss: 2.244080 	Learning rate: 0.000250


Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 92.156180 	Learning rate: 0.001000
	Average Loss: 104.153304 	Learning rate: 0.001000
	Average Loss: 22.747265 	Learning rate: 0.001000
	Average Loss: 29.306089 	Learning rate: 0.001000
	Average Loss: 42.432945 	Learning rate: 0.001000
	Average Loss: 14.276404 	Learning rate: 0.001000
	Average Loss: 286.515749 	Learning rate: 0.001000
	Average Loss: 712.583538 	Learning rate: 0.001000
	Average Loss: 536.568741 	Learning rate: 0.001000
	Average Loss: 343.894724 	Learning rate: 0.000500
	Average Loss: 166.852997 	Learning rate: 0.000500
	Average Loss: 111.928305 	Learning rate: 0.000500
	Average Loss: 100.045232 	Learning rate: 0.000500
	Average Loss: 92.427112 	Learning rate: 0.000500
	Average Loss: 80.819885 	Learning rate: 0.000500
	Average Loss: 73.442439 	Learning rate: 0.000500
	Average Loss: 66.647311 	Learning rate: 0.000500
	Average Loss: 58.900851 	Learning rate: 0.000500
	Average Loss: 47.129283 	Learning rate: 0.000500
	Average Loss: 42.112546 	Learning rate: 0

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 486.110061 	Learning rate: 0.001000
	Average Loss: 398.972569 	Learning rate: 0.001000
	Average Loss: 10.376561 	Learning rate: 0.001000
	Average Loss: 0.211402 	Learning rate: 0.001000
	Average Loss: 0.229970 	Learning rate: 0.001000
	Average Loss: 0.181952 	Learning rate: 0.001000
	Average Loss: 0.167455 	Learning rate: 0.001000
	Average Loss: 0.160007 	Learning rate: 0.001000
	Average Loss: 0.155773 	Learning rate: 0.001000
	Average Loss: 0.173615 	Learning rate: 0.000500
	Average Loss: 0.156610 	Learning rate: 0.000500
	Average Loss: 0.156802 	Learning rate: 0.000500
	Average Loss: 0.153296 	Learning rate: 0.000500
	Average Loss: 0.210436 	Learning rate: 0.000500
	Average Loss: 0.165139 	Learning rate: 0.000500
	Average Loss: 0.152894 	Learning rate: 0.000500
	Average Loss: 0.151129 	Learning rate: 0.000500
	Average Loss: 0.152522 	Learning rate: 0.000500
	Average Loss: 0.150396 	Learning rate: 0.000500
	Average Loss: 0.152221 	Learning rate: 0.000250
	Average Loss: 

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 74.177166 	Learning rate: 0.001000
	Average Loss: 58.231340 	Learning rate: 0.001000
	Average Loss: 34.672745 	Learning rate: 0.001000
	Average Loss: 26.688590 	Learning rate: 0.001000
	Average Loss: 19.504710 	Learning rate: 0.001000
	Average Loss: 13.190866 	Learning rate: 0.001000
	Average Loss: 1.097837 	Learning rate: 0.001000
	Average Loss: 0.775897 	Learning rate: 0.001000
	Average Loss: 0.311124 	Learning rate: 0.001000
	Average Loss: 0.300673 	Learning rate: 0.000500
	Average Loss: 0.298963 	Learning rate: 0.000500
	Average Loss: 0.297848 	Learning rate: 0.000500
	Average Loss: 0.297654 	Learning rate: 0.000500
	Average Loss: 0.296647 	Learning rate: 0.000500
	Average Loss: 0.296620 	Learning rate: 0.000500
	Average Loss: 0.295640 	Learning rate: 0.000500
	Average Loss: 0.295195 	Learning rate: 0.000500
	Average Loss: 0.294429 	Learning rate: 0.000500
	Average Loss: 0.293281 	Learning rate: 0.000500
	Average Loss: 0.293738 	Learning rate: 0.000250
	Average Loss:

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 77.575466 	Learning rate: 0.001000
	Average Loss: 107.021661 	Learning rate: 0.001000
	Average Loss: 25.351369 	Learning rate: 0.001000
	Average Loss: 46.052911 	Learning rate: 0.001000
	Average Loss: 16.382259 	Learning rate: 0.001000
	Average Loss: 51.996259 	Learning rate: 0.001000
	Average Loss: 46.735512 	Learning rate: 0.001000
	Average Loss: 44.475905 	Learning rate: 0.001000
	Average Loss: 40.595642 	Learning rate: 0.001000
	Average Loss: 113.588576 	Learning rate: 0.000500
	Average Loss: 171.120546 	Learning rate: 0.000500
	Average Loss: 120.182841 	Learning rate: 0.000500
	Average Loss: 74.216461 	Learning rate: 0.000500
	Average Loss: 67.340764 	Learning rate: 0.000500
	Average Loss: 63.405377 	Learning rate: 0.000500
	Average Loss: 58.711430 	Learning rate: 0.000500
	Average Loss: 32.671251 	Learning rate: 0.000500
	Average Loss: 21.736968 	Learning rate: 0.000500
	Average Loss: 17.166390 	Learning rate: 0.000500
	Average Loss: 16.434923 	Learning rate: 0.000

In [None]:
import numpy as np

# Micro AUROC across folds: mean and standard deviation (numpy default ddof=0).
roc_micro_per_fold = [fold_score.cpu() for fold_score in all_auc_roc_micro]
print(f"Mean\t: {np.mean(roc_micro_per_fold)}")
print(f" SD \t: {np.std(roc_micro_per_fold)}")

Mean	: 0.6829784512519836
 SD 	: 0.09357023239135742


In [None]:
# Micro average precision across folds: mean and standard deviation (ddof=0).
pr_micro_per_fold = [fold_score.cpu() for fold_score in all_auc_pr_micro]
print(f"Mean\t: {np.mean(pr_micro_per_fold)}")
print(f" SD \t: {np.std(pr_micro_per_fold)}")

Mean	: 0.4856565594673157
 SD 	: 0.12486594915390015


In [None]:
import numpy as np

# Macro AUROC across folds: mean and standard deviation (ddof=0).
roc_macro_per_fold = list(all_auc_roc_macro)
print(f"Mean\t: {np.mean(roc_macro_per_fold)}")
print(f" SD \t: {np.std(roc_macro_per_fold)}")

Mean	: 0.5908076456698932
 SD 	: 0.05749029500469532


In [None]:
# Macro average precision across folds: mean and standard deviation (ddof=0).
pr_macro_per_fold = list(all_auc_pr_macro)
print(f"Mean\t: {np.mean(pr_macro_per_fold)}")
print(f" SD \t: {np.std(pr_macro_per_fold)}")

Mean	: 0.36621746913971087
 SD 	: 0.048557828012267525
