# Fully connected feedforward network implementing a loss mask

In [None]:
import torchmetrics.classification
from torch import cuda
# Fail fast with a readable message when no GPU is usable: later cells place
# the model and every training batch on torch.device("cuda").
assert cuda.is_available(), "CUDA is not available; this notebook requires a CUDA-capable GPU"
assert cuda.device_count() > 0, "CUDA reports zero visible devices"

In [None]:
# Show the name of the GPU that the current CUDA device index refers to.
current_gpu_index = cuda.current_device()
print(cuda.get_device_name(current_gpu_index))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
# Device on which the model and the training batches are placed.
DEVICE = torch.device("cuda")
# Single seed reused below for torch's global RNG, the k-fold split and the
# DataLoader generator.
SEED = 76436278

# Seed torch's global random number generator. As the last expression of the
# cell, the returned Generator object is what the notebook displays.
torch.manual_seed(SEED)

<torch._C.Generator at 0x7ff6e03aefb0>

### Load the Dataset

In [None]:
from src.maldi2resistance.data.driams import Driams

import os

# Root directory of the DRIAMS data. The location can be overridden with the
# DRIAMS_ROOT environment variable so the notebook also runs on other machines;
# the default is the path this notebook was originally executed with.
DRIAMS_ROOT = os.environ.get("DRIAMS_ROOT", "/home/jan/Uni/master/data/Driams")

# Main dataset, built with the loader's default settings.
driams = Driams(
    root_dir=DRIAMS_ROOT,
)

# Second dataset restricted to site "UMG" and the years 2020/2021, using the
# same antibiotic selection as `driams` so both expose the same label columns.
umg = Driams(
    root_dir=DRIAMS_ROOT,
    bin_size=1,
    sites=["UMG"],
    years=[2020, 2021],
    antibiotics=driams.selected_antibiotics,
)

# Switch both datasets to in-memory loading (the cell output shows a
# "Loading Spectra into Memory" progress bar for each of them).
driams.loading_type = "memory"
umg.loading_type = "memory"

# Last expression of the cell: display the dataset's label statistics.
driams



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/73745 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
# Number of columns in the label statistics table, i.e. one per antibiotic
# listed in the table displayed above.
len(driams.label_stats.columns)

38

In [None]:
import copy
from maldi2resistance.model.MultilabelResMLP import MultilabelResMLP

# Build the multilabel residual MLP with one output per selected antibiotic
# and place it on the training device right away.
model = MultilabelResMLP(
    input_dim=18000,
    hidden_dim=256,
    output_dim=len(driams.selected_antibiotics),
).to(DEVICE)

# Snapshot of the freshly initialised weights; the cross-validation loop
# reloads it so that every fold starts from the same initial state.
model_state = copy.deepcopy(model.state_dict())

In [None]:
from torchinfo import summary

# Print torchinfo's per-layer parameter counts for the model.
print(summary(model))

Layer (type:depth-idx)                        Param #
MultilabelResMLP                              --
├─Linear: 1-1                                 4,608,256
├─ResMLP: 1-2                                 --
│    └─Sequential: 2-1                        --
│    │    └─ResBlock: 3-1                     66,304
│    │    └─ResBlock: 3-2                     66,304
│    │    └─ResBlock: 3-3                     66,304
│    │    └─ResBlock: 3-4                     66,304
│    │    └─ResBlock: 3-5                     66,304
│    │    └─Linear: 3-6                       9,766
Total params: 4,949,542
Trainable params: 4,949,542
Non-trainable params: 0


In [None]:
from torch.utils.data import DataLoader

In [None]:
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

print("Start training ...")

# Hyper-parameters of the cross-validation run.
batch_size = 128
n_splits = 5
n_epochs = 30

# Output locations. Both directories are created up front: the figures are
# saved into `fig_path`, the per-fold metric CSV files into `csv_path`.
fig_path = Path("./kfold/figures")
fig_path.mkdir(parents=True, exist_ok=True)
csv_path = Path("./kfold/csv")
csv_path.mkdir(parents=True, exist_ok=True)

# Loss value of every optimisation step, accumulated over all folds and epochs.
loss_per_batch = []

# Per-antibiotic weights: one minus the share of negative (resp. positive)
# labels among all labelled samples.
# NOTE(review): these tensors are not passed to the criterion below and are
# not used anywhere else in this cell -- confirm whether they are still needed.
class_weights_negative = torch.tensor(
    (1 - (driams.label_stats.loc["negative"] / driams.label_stats.loc["n_sum"])).values,
    device=DEVICE,
)
class_weights_positive = torch.tensor(
    (1 - (driams.label_stats.loc["positive"] / driams.label_stats.loc["n_sum"])).values,
    device=DEVICE,
)

criterion = MaskedBCE()

# Generator handed to the DataLoaders; it is re-seeded for every loader so the
# shuffling order is reproducible.
gen = torch.Generator()

for fold, (train_data, test_data) in enumerate(
    driams.getK_fold(n_splits=n_splits, shuffle=True, random_state=SEED)
):
    # The UMG samples are added to the training split only; evaluation uses the
    # held-out DRIAMS fold alone.
    train_data_umg = torch.utils.data.ConcatDataset([train_data, umg])
    train_loader = DataLoader(
        train_data_umg,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        generator=gen.manual_seed(SEED),
    )
    # One single batch containing the whole test split.
    test_loader = DataLoader(
        test_data,
        batch_size=len(test_data),
        shuffle=True,
        drop_last=True,
        generator=gen.manual_seed(SEED),
    )

    # Reset to the initial weights so the folds are independent of each other.
    model.load_state_dict(model_state)
    model.train()

    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad=True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    for epoch in tqdm(range(n_epochs)):
        overall_loss = 0
        batches_seen = 0

        for x, y in train_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            optimizer.zero_grad()

            output = model(x)

            loss = criterion(output, y)
            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)

            overall_loss += current_loss_value
            batches_seen += 1

            loss.backward()
            optimizer.step()

        # Read the learning rate *before* stepping the scheduler so the log
        # line reports the rate that was actually used during this epoch.
        epoch_lr = scheduler.get_last_lr()[0]
        scheduler.step()

        # drop_last=True means every batch holds exactly `batch_size` samples,
        # so the number of samples seen is the batch count times the batch
        # size. (The previous code used the last batch *index*, which is one
        # less than the batch count.) The per-sample normalisation of the
        # logged value is kept as it was.
        samples_seen = batches_seen * batch_size
        average_loss = overall_loss / samples_seen if samples_seen else float("nan")
        with tqdm.external_write_mode():
            print(f"\tAverage Loss: {average_loss:.6f} \tLearning rate: {epoch_lr:.6f}")

    print(f"Finished Fold {fold}")

    model.eval()

    # Evaluation runs on the CPU (presumably because the whole test split is
    # processed as one batch -- TODO confirm). A local device is used instead of
    # reassigning the global DEVICE, so an error during evaluation cannot leave
    # DEVICE pointing at the CPU.
    eval_device = torch.device("cpu")
    test_features, test_labels = next(iter(test_loader))
    test_features = test_features.to(eval_device)
    test_labels = test_labels.to(eval_device)
    model = model.to(eval_device)

    # No gradients are needed for inference; skipping autograd avoids keeping
    # the graph for the full test split in memory.
    with torch.no_grad():
        output = model(test_features)

    # ROC curves: compute the metrics (written to the per-fold CSV that the
    # summary cells read back) and save the figure.
    ml_roc = MultiLabelRocNan()
    ml_roc.compute(output, test_labels, driams.selected_antibiotics, create_csv=f"./kfold/csv/fold-{fold}_ROC.csv")
    fig_, ax_ = ml_roc()

    plt.savefig(fig_path / f"fold-{fold}_ROC.png", transparent=True, format="png", bbox_inches="tight")
    plt.close()

    # Precision-recall curves: same procedure as for ROC.
    ml_pr = MultiLabelPRNan()
    ml_pr.compute(output, test_labels, driams.selected_antibiotics, create_csv=f"./kfold/csv/fold-{fold}_PrecisionRecall.csv")

    fig_, ax_ = ml_pr()

    plt.savefig(fig_path / f"fold-{fold}_PrecisionRecall.png", transparent=True, format="png", bbox_inches="tight")
    plt.close()

    # Move the model back to the training device for the next fold.
    model = model.to(DEVICE)

Start training ...


  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.000703 	Learning rate: 0.001000
	Average Loss: 0.000579 	Learning rate: 0.001000
	Average Loss: 0.000543 	Learning rate: 0.001000
	Average Loss: 0.000522 	Learning rate: 0.001000
	Average Loss: 0.000503 	Learning rate: 0.001000
	Average Loss: 0.000489 	Learning rate: 0.001000
	Average Loss: 0.000474 	Learning rate: 0.001000
	Average Loss: 0.000461 	Learning rate: 0.001000
	Average Loss: 0.000451 	Learning rate: 0.001000
	Average Loss: 0.000439 	Learning rate: 0.000500
	Average Loss: 0.000407 	Learning rate: 0.000500
	Average Loss: 0.000398 	Learning rate: 0.000500
	Average Loss: 0.000389 	Learning rate: 0.000500
	Average Loss: 0.000383 	Learning rate: 0.000500
	Average Loss: 0.000375 	Learning rate: 0.000500
	Average Loss: 0.000371 	Learning rate: 0.000500
	Average Loss: 0.000364 	Learning rate: 0.000500
	Average Loss: 0.000358 	Learning rate: 0.000500
	Average Loss: 0.000352 	Learning rate: 0.000500
	Average Loss: 0.000349 	Learning rate: 0.000250
	Average Loss: 0.000

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.000703 	Learning rate: 0.001000
	Average Loss: 0.000579 	Learning rate: 0.001000
	Average Loss: 0.000545 	Learning rate: 0.001000
	Average Loss: 0.000523 	Learning rate: 0.001000
	Average Loss: 0.000504 	Learning rate: 0.001000
	Average Loss: 0.000489 	Learning rate: 0.001000
	Average Loss: 0.000476 	Learning rate: 0.001000
	Average Loss: 0.000464 	Learning rate: 0.001000
	Average Loss: 0.000453 	Learning rate: 0.001000
	Average Loss: 0.000442 	Learning rate: 0.000500
	Average Loss: 0.000410 	Learning rate: 0.000500
	Average Loss: 0.000400 	Learning rate: 0.000500
	Average Loss: 0.000392 	Learning rate: 0.000500
	Average Loss: 0.000385 	Learning rate: 0.000500
	Average Loss: 0.000377 	Learning rate: 0.000500
	Average Loss: 0.000371 	Learning rate: 0.000500
	Average Loss: 0.000365 	Learning rate: 0.000500
	Average Loss: 0.000359 	Learning rate: 0.000500
	Average Loss: 0.000353 	Learning rate: 0.000500
	Average Loss: 0.000349 	Learning rate: 0.000250
	Average Loss: 0.000

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.000703 	Learning rate: 0.001000
	Average Loss: 0.000578 	Learning rate: 0.001000
	Average Loss: 0.000545 	Learning rate: 0.001000
	Average Loss: 0.000523 	Learning rate: 0.001000
	Average Loss: 0.000504 	Learning rate: 0.001000
	Average Loss: 0.000489 	Learning rate: 0.001000
	Average Loss: 0.000476 	Learning rate: 0.001000
	Average Loss: 0.000464 	Learning rate: 0.001000
	Average Loss: 0.000451 	Learning rate: 0.001000
	Average Loss: 0.000442 	Learning rate: 0.000500
	Average Loss: 0.000407 	Learning rate: 0.000500
	Average Loss: 0.000399 	Learning rate: 0.000500
	Average Loss: 0.000391 	Learning rate: 0.000500
	Average Loss: 0.000382 	Learning rate: 0.000500
	Average Loss: 0.000375 	Learning rate: 0.000500
	Average Loss: 0.000370 	Learning rate: 0.000500
	Average Loss: 0.000364 	Learning rate: 0.000500
	Average Loss: 0.000356 	Learning rate: 0.000500
	Average Loss: 0.000350 	Learning rate: 0.000500
	Average Loss: 0.000347 	Learning rate: 0.000250
	Average Loss: 0.000

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.000701 	Learning rate: 0.001000
	Average Loss: 0.000580 	Learning rate: 0.001000
	Average Loss: 0.000546 	Learning rate: 0.001000
	Average Loss: 0.000523 	Learning rate: 0.001000
	Average Loss: 0.000504 	Learning rate: 0.001000
	Average Loss: 0.000490 	Learning rate: 0.001000
	Average Loss: 0.000478 	Learning rate: 0.001000
	Average Loss: 0.000463 	Learning rate: 0.001000
	Average Loss: 0.000452 	Learning rate: 0.001000
	Average Loss: 0.000443 	Learning rate: 0.000500
	Average Loss: 0.000410 	Learning rate: 0.000500
	Average Loss: 0.000400 	Learning rate: 0.000500
	Average Loss: 0.000392 	Learning rate: 0.000500
	Average Loss: 0.000385 	Learning rate: 0.000500
	Average Loss: 0.000377 	Learning rate: 0.000500
	Average Loss: 0.000371 	Learning rate: 0.000500
	Average Loss: 0.000364 	Learning rate: 0.000500
	Average Loss: 0.000359 	Learning rate: 0.000500
	Average Loss: 0.000355 	Learning rate: 0.000500
	Average Loss: 0.000348 	Learning rate: 0.000250
	Average Loss: 0.000

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.000707 	Learning rate: 0.001000
	Average Loss: 0.000580 	Learning rate: 0.001000
	Average Loss: 0.000547 	Learning rate: 0.001000
	Average Loss: 0.000525 	Learning rate: 0.001000
	Average Loss: 0.000507 	Learning rate: 0.001000
	Average Loss: 0.000491 	Learning rate: 0.001000
	Average Loss: 0.000478 	Learning rate: 0.001000
	Average Loss: 0.000466 	Learning rate: 0.001000
	Average Loss: 0.000453 	Learning rate: 0.001000
	Average Loss: 0.000444 	Learning rate: 0.000500
	Average Loss: 0.000411 	Learning rate: 0.000500
	Average Loss: 0.000401 	Learning rate: 0.000500
	Average Loss: 0.000394 	Learning rate: 0.000500
	Average Loss: 0.000386 	Learning rate: 0.000500
	Average Loss: 0.000379 	Learning rate: 0.000500
	Average Loss: 0.000373 	Learning rate: 0.000500
	Average Loss: 0.000368 	Learning rate: 0.000500
	Average Loss: 0.000362 	Learning rate: 0.000500
	Average Loss: 0.000356 	Learning rate: 0.000500
	Average Loss: 0.000351 	Learning rate: 0.000250
	Average Loss: 0.000

In [None]:
import pandas
# Gather the micro- and macro-averaged ROC AUC entries from the CSV file that
# the cross-validation cell wrote for each of the five folds.
micro, macro = [], []

for fold in range(5):
    fold_roc = pandas.read_csv(f"./kfold/csv/fold-{fold}_ROC.csv")
    row_kind = fold_roc["class"]
    micro.append(fold_roc.loc[row_kind == "micro", "ROCAUC"])
    macro.append(fold_roc.loc[row_kind == "macro", "ROCAUC"])

In [None]:
import numpy as np

# Mean and spread over the folds of the values currently held in `micro`
# (the micro-averaged ROC AUC when the cells are run in order).
# np.std is called with its defaults, i.e. without a ddof correction.
micro_mean = np.mean(micro)
micro_sd = np.std(micro)

print(f"Mean\t: {micro_mean}")
print(f" SD \t: {micro_sd}")

Mean	: 0.9277553915977478
 SD 	: 0.0016598027579010678


In [None]:
# Mean and spread over the folds of the values currently held in `macro`
# (the macro-averaged ROC AUC when the cells are run in order).
# np.std is called with its defaults, i.e. without a ddof correction.
macro_mean = np.mean(macro)
macro_sd = np.std(macro)

print(f"Mean\t: {macro_mean}")
print(f" SD \t: {macro_sd}")

Mean	: 0.8965143812330145
 SD 	: 0.0022145633294949176


In [None]:
import pandas
# Gather the micro- and macro-averaged precision-recall AUC entries from the
# CSV file that the cross-validation cell wrote for each of the five folds.
# This rebinds `micro` and `macro`, replacing the ROC values collected earlier.
micro, macro = [], []

for fold in range(5):
    fold_pr = pandas.read_csv(f"./kfold/csv/fold-{fold}_PrecisionRecall.csv")
    row_kind = fold_pr["class"]
    micro.append(fold_pr.loc[row_kind == "micro", "PrecisionRecallAUC"])
    macro.append(fold_pr.loc[row_kind == "macro", "PrecisionRecallAUC"])

In [None]:
import numpy as np

# Mean and spread over the folds of the values currently held in `micro`
# (the micro-averaged precision-recall AUC when the cells are run in order).
# np.std is called with its defaults, i.e. without a ddof correction.
micro_mean = np.mean(micro)
micro_sd = np.std(micro)

print(f"Mean\t: {micro_mean}")
print(f" SD \t: {micro_sd}")

Mean	: 0.8562406182289124
 SD 	: 0.004599530274893811


In [None]:
# Mean and spread over the folds of the values currently held in `macro`
# (the macro-averaged precision-recall AUC when the cells are run in order).
# np.std is called with its defaults, i.e. without a ddof correction.
macro_mean = np.mean(macro)
macro_sd = np.std(macro)

print(f"Mean\t: {macro_mean}")
print(f" SD \t: {macro_sd}")

Mean	: 0.7589248250973852
 SD 	: 0.005488268265847772


In [None]:
# Display the label statistics of the UMG subset: the rows "positive",
# "negative" and "n_sum" per antibiotic, as shown in the output below.
umg.label_stats

Unnamed: 0,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,...,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
positive,3765,4,27294,0,2,0,0,5287,4,11877,...,0,5937,62,13379,0,529,88,3577,3768,1933
negative,2199,27,21207,0,3,0,0,60,37,26028,...,0,10743,393,24634,0,16094,154,14045,2197,31047
n_sum,5964,31,48501,0,5,0,0,5347,41,37905,...,0,16680,455,38013,0,16623,242,17622,5965,32980
