# Fully connected feedforward network implementing a loss mask

In [None]:
import torchmetrics.classification
from torch import cuda
# Fail fast: the rest of the notebook places the model and batches on a CUDA device.
cuda_ready = cuda.is_available()
assert cuda_ready
visible_gpu_count = cuda.device_count()
assert visible_gpu_count > 0

In [None]:
# Report the name of the currently selected CUDA device.
active_device_index = cuda.current_device()
print(cuda.get_device_name(active_device_index))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
# Run-wide configuration: one seed shared by torch, the k-fold split and the
# data-loader generators, and the device on which the model is trained.
SEED = 76436278
DEVICE = torch.device("cuda")

# Seed torch's global RNG; the returned generator is what the cell displays.
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fde5d9eafb0>

### Load the Dataset

In [None]:
import os

# NOTE(review): the other cells import from `maldi2resistance...` without the
# `src.` prefix — confirm which package path is intended.
from src.maldi2resistance.data.driams import Driams

# Root directory of the DRIAMS data. Set the DRIAMS_ROOT environment variable
# to run on another machine; without it the original location is used.
DRIAMS_ROOT = os.environ.get("DRIAMS_ROOT", "/home/jan/Uni/master/data/Driams")

driams = Driams(
    root_dir=DRIAMS_ROOT,
)

# Switch the dataset to its "memory" loading mode (presumably this is what
# pulls all spectra into RAM — see the progress bar in the cell output).
driams.loading_type = "memory"

# Last expression of the cell: display the dataset's own representation.
driams



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
# Number of columns in the label statistics table (displayed as the cell result).
label_columns = driams.label_stats.columns
len(label_columns)

38

In [None]:
from maldi2resistance.model.MLP import AeBasedMLP
import copy

# One output unit per selected antibiotic.
n_outputs = len(driams.selected_antibiotics)

model = AeBasedMLP(
    input_dim=18000,
    output_dim=n_outputs,
    hidden_dim=4096,
    latent_dim=2048,
).to(DEVICE)

# Independent copy of the freshly initialised weights, kept so the model can
# later be reset to this exact starting point.
model_state = copy.deepcopy(model.state_dict())

In [None]:
from torchinfo import summary

# Print the layer-by-layer parameter overview produced by torchinfo.
model_summary = summary(model)
print(model_summary)

Layer (type:depth-idx)                   Param #
AeBasedMLP                               --
├─Encoder: 1-1                           --
│    └─Linear: 2-1                       73,732,096
│    └─Linear: 2-2                       16,781,312
│    └─Linear: 2-3                       8,390,656
│    └─LeakyReLU: 2-4                    --
├─Decoder: 1-2                           --
│    └─Linear: 2-5                       8,392,704
│    └─Linear: 2-6                       16,781,312
│    └─Linear: 2-7                       155,686
│    └─LeakyReLU: 2-8                    --
Total params: 124,233,766
Trainable params: 124,233,766
Non-trainable params: 0


In [None]:
from torch.utils.data import DataLoader

In [None]:
from maldi2resistance.loss.asymmetricLoss import AsymmetricLoss
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# 5-fold cross-validation: for every fold the model is reset to the weights in
# `model_state`, trained for 30 epochs, and evaluated on the held-out fold.
# ROC and precision-recall figures/CSVs are written below ./kfold/.
print("Start training ...")
model.train()

batch_size = 128
fig_path = Path("./kfold/figures")
fig_path.mkdir(parents=True, exist_ok=True)
# CSV paths inside ./kfold/csv are handed to the metric helpers further down;
# create the directory up front so those writes cannot fail on a missing folder.
csv_path = Path("./kfold/csv")
csv_path.mkdir(parents=True, exist_ok=True)

# Loss of every optimisation step across all folds.
loss_per_batch = []

criterion = AsymmetricLoss()

gen = torch.Generator()

# Evaluation is done on the CPU (presumably because the whole test fold is
# pushed through the model as one batch). A dedicated name is used so the
# global DEVICE is never reassigned, even if evaluation raises.
eval_device = torch.device("cpu")

for fold, (train_data, test_data) in enumerate(driams.getK_fold(n_splits=5, shuffle=True, random_state=SEED)):

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True, generator=gen.manual_seed(SEED))
    # batch_size == len(test_data): the whole test fold arrives as a single batch.
    test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=True, drop_last=True, generator=gen.manual_seed(SEED))

    # Start every fold from the same stored weights.
    model.load_state_dict(model_state)
    model.train()

    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad=True)
    # Halve the learning rate every 10 epochs.
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    # drop_last=True means every training batch holds exactly `batch_size`
    # samples, so this is the exact number of samples seen per epoch.
    samples_per_epoch = len(train_loader) * batch_size

    for epoch in tqdm(range(30)):
        overall_loss = 0

        for x, y in train_loader:

            x = x.to(DEVICE)
            y = y.to(DEVICE)

            optimizer.zero_grad()

            output = model(x)

            loss = criterion(output, y)
            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)

            overall_loss += current_loss_value

            loss.backward()
            optimizer.step()

        scheduler.step()
        with tqdm.external_write_mode():
            # Per-sample average over all batches of the epoch.
            # NOTE(review): dividing by batch_size assumes the criterion returns a
            # per-batch sum — confirm against AsymmetricLoss.
            # The learning rate is read after scheduler.step(), so the printed
            # value is the one that applies to the next epoch.
            print(f"\tAverage Loss: {overall_loss / samples_per_epoch:.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")


    print(f"Finished Fold {fold}")

    model.eval()

    test_features, test_labels = next(iter(test_loader))
    test_features = test_features.to(eval_device)
    test_labels = test_labels.to(eval_device)
    model = model.to(eval_device)

    ml_roc = MultiLabelRocNan()
    # Inference only: no autograd graph is needed for the metric computation.
    with torch.no_grad():
        output = model(test_features)


    ml_roc.compute(output,test_labels,driams.selected_antibiotics, create_csv=f"./kfold/csv/fold-{fold}_ROC.csv")
    fig_, ax_ = ml_roc()

    plt.savefig(fig_path / f"fold-{fold}_ROC.png", transparent=True, format= "png", bbox_inches = "tight")
    plt.close()

    ml_pr = MultiLabelPRNan()
    ml_pr.compute(output,test_labels,driams.selected_antibiotics, create_csv=f"./kfold/csv/fold-{fold}_PrecisionRecall.csv")

    fig_, ax_ = ml_pr()

    plt.savefig(fig_path / f"fold-{fold}_PrecisionRecall.png", transparent=True, format= "png", bbox_inches = "tight")
    plt.close()

    # Back to the training device for the next fold.
    model = model.to(DEVICE)

Start training ...


  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.953779 	Learning rate: 0.001000
	Average Loss: 0.814933 	Learning rate: 0.001000
	Average Loss: 0.766495 	Learning rate: 0.001000
	Average Loss: 0.733904 	Learning rate: 0.001000
	Average Loss: 0.700965 	Learning rate: 0.001000
	Average Loss: 0.686782 	Learning rate: 0.001000
	Average Loss: 0.656412 	Learning rate: 0.001000
	Average Loss: 0.643268 	Learning rate: 0.001000
	Average Loss: 0.621132 	Learning rate: 0.001000
	Average Loss: 0.601237 	Learning rate: 0.000500
	Average Loss: 0.539066 	Learning rate: 0.000500
	Average Loss: 0.518192 	Learning rate: 0.000500
	Average Loss: 0.496925 	Learning rate: 0.000500
	Average Loss: 0.490261 	Learning rate: 0.000500
	Average Loss: 0.472453 	Learning rate: 0.000500
	Average Loss: 0.449708 	Learning rate: 0.000500
	Average Loss: 0.440909 	Learning rate: 0.000500
	Average Loss: 0.415383 	Learning rate: 0.000500
	Average Loss: 0.406101 	Learning rate: 0.000500
	Average Loss: 0.393210 	Learning rate: 0.000250
	Average Loss: 0.325

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.947678 	Learning rate: 0.001000
	Average Loss: 0.810949 	Learning rate: 0.001000
	Average Loss: 0.763670 	Learning rate: 0.001000
	Average Loss: 0.719293 	Learning rate: 0.001000
	Average Loss: 0.699873 	Learning rate: 0.001000
	Average Loss: 0.678231 	Learning rate: 0.001000
	Average Loss: 0.656942 	Learning rate: 0.001000
	Average Loss: 0.631371 	Learning rate: 0.001000
	Average Loss: 0.617470 	Learning rate: 0.001000
	Average Loss: 0.609772 	Learning rate: 0.000500
	Average Loss: 0.541890 	Learning rate: 0.000500
	Average Loss: 0.524694 	Learning rate: 0.000500
	Average Loss: 0.502937 	Learning rate: 0.000500
	Average Loss: 0.486336 	Learning rate: 0.000500
	Average Loss: 0.469775 	Learning rate: 0.000500
	Average Loss: 0.458391 	Learning rate: 0.000500
	Average Loss: 0.435914 	Learning rate: 0.000500
	Average Loss: 0.424075 	Learning rate: 0.000500
	Average Loss: 0.398682 	Learning rate: 0.000500
	Average Loss: 0.392577 	Learning rate: 0.000250
	Average Loss: 0.326

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.947124 	Learning rate: 0.001000
	Average Loss: 0.812953 	Learning rate: 0.001000
	Average Loss: 0.769211 	Learning rate: 0.001000
	Average Loss: 0.732486 	Learning rate: 0.001000
	Average Loss: 0.706379 	Learning rate: 0.001000
	Average Loss: 0.682778 	Learning rate: 0.001000
	Average Loss: 0.662349 	Learning rate: 0.001000
	Average Loss: 0.639401 	Learning rate: 0.001000
	Average Loss: 0.620141 	Learning rate: 0.001000
	Average Loss: 0.605238 	Learning rate: 0.000500
	Average Loss: 0.538841 	Learning rate: 0.000500
	Average Loss: 0.521176 	Learning rate: 0.000500
	Average Loss: 0.504375 	Learning rate: 0.000500
	Average Loss: 0.486979 	Learning rate: 0.000500
	Average Loss: 0.473004 	Learning rate: 0.000500
	Average Loss: 0.458764 	Learning rate: 0.000500
	Average Loss: 0.443586 	Learning rate: 0.000500
	Average Loss: 0.424785 	Learning rate: 0.000500
	Average Loss: 0.416305 	Learning rate: 0.000500
	Average Loss: 0.401919 	Learning rate: 0.000250
	Average Loss: 0.333

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.938971 	Learning rate: 0.001000
	Average Loss: 0.817253 	Learning rate: 0.001000
	Average Loss: 0.771441 	Learning rate: 0.001000
	Average Loss: 0.730599 	Learning rate: 0.001000
	Average Loss: 0.708160 	Learning rate: 0.001000
	Average Loss: 0.683862 	Learning rate: 0.001000
	Average Loss: 0.659497 	Learning rate: 0.001000
	Average Loss: 0.640780 	Learning rate: 0.001000
	Average Loss: 0.620271 	Learning rate: 0.001000
	Average Loss: 0.607355 	Learning rate: 0.000500
	Average Loss: 0.545694 	Learning rate: 0.000500
	Average Loss: 0.522304 	Learning rate: 0.000500
	Average Loss: 0.509019 	Learning rate: 0.000500
	Average Loss: 0.497623 	Learning rate: 0.000500
	Average Loss: 0.476326 	Learning rate: 0.000500
	Average Loss: 0.460789 	Learning rate: 0.000500
	Average Loss: 0.446439 	Learning rate: 0.000500
	Average Loss: 0.430214 	Learning rate: 0.000500
	Average Loss: 0.411332 	Learning rate: 0.000500
	Average Loss: 0.400351 	Learning rate: 0.000250
	Average Loss: 0.334

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.936848 	Learning rate: 0.001000
	Average Loss: 0.812015 	Learning rate: 0.001000
	Average Loss: 0.771022 	Learning rate: 0.001000
	Average Loss: 0.734882 	Learning rate: 0.001000
	Average Loss: 0.709874 	Learning rate: 0.001000
	Average Loss: 0.683270 	Learning rate: 0.001000
	Average Loss: 0.665590 	Learning rate: 0.001000
	Average Loss: 0.638524 	Learning rate: 0.001000
	Average Loss: 0.619011 	Learning rate: 0.001000
	Average Loss: 0.604071 	Learning rate: 0.000500
	Average Loss: 0.543324 	Learning rate: 0.000500
	Average Loss: 0.523247 	Learning rate: 0.000500
	Average Loss: 0.510604 	Learning rate: 0.000500
	Average Loss: 0.492698 	Learning rate: 0.000500
	Average Loss: 0.475976 	Learning rate: 0.000500
	Average Loss: 0.453908 	Learning rate: 0.000500
	Average Loss: 0.441506 	Learning rate: 0.000500
	Average Loss: 0.427852 	Learning rate: 0.000500
	Average Loss: 0.415252 	Learning rate: 0.000500
	Average Loss: 0.401485 	Learning rate: 0.000250
	Average Loss: 0.333

In [None]:
import pandas
# Gather the micro- and macro-averaged ROC AUC rows from each fold's CSV.
micro, macro = [], []

for fold in range(5):
    fold_roc = pandas.read_csv(f"./kfold/csv/fold-{fold}_ROC.csv")
    row_class = fold_roc["class"]
    micro.append(fold_roc.loc[row_class == "micro", "ROCAUC"])
    macro.append(fold_roc.loc[row_class == "macro", "ROCAUC"])

In [None]:
import numpy as np

# Mean and (population, ddof=0) standard deviation of the micro-averaged values over the folds.
micro_mean = np.mean(micro)
micro_sd = np.std(micro)
print(f"Mean\t: {micro_mean}")
print(f" SD \t: {micro_sd}")

Mean	: 0.9301816463470459
 SD 	: 0.0008667747575244399


In [None]:
# Mean and (population, ddof=0) standard deviation of the macro-averaged values over the folds.
macro_mean = np.mean(macro)
macro_sd = np.std(macro)
print(f"Mean\t: {macro_mean}")
print(f" SD \t: {macro_sd}")

Mean	: 0.8994799805314917
 SD 	: 0.0023500148577302944


In [None]:
import pandas
# Gather the micro- and macro-averaged precision-recall AUC rows from each fold's CSV.
# Note: this rebinds `micro` / `macro`, replacing the ROC values collected earlier.
micro, macro = [], []

for fold in range(5):
    fold_pr = pandas.read_csv(f"./kfold/csv/fold-{fold}_PrecisionRecall.csv")
    row_class = fold_pr["class"]
    micro.append(fold_pr.loc[row_class == "micro", "PrecisionRecallAUC"])
    macro.append(fold_pr.loc[row_class == "macro", "PrecisionRecallAUC"])

In [None]:
import numpy as np

# Mean and (population, ddof=0) standard deviation of the micro-averaged values over the folds.
micro_mean = np.mean(micro)
micro_sd = np.std(micro)
print(f"Mean\t: {micro_mean}")
print(f" SD \t: {micro_sd}")

Mean	: 0.8684738755226136
 SD 	: 0.004694932920371596


In [None]:
# Mean and (population, ddof=0) standard deviation of the macro-averaged values over the folds.
macro_mean = np.mean(macro)
macro_sd = np.std(macro)
print(f"Mean\t: {macro_mean}")
print(f" SD \t: {macro_sd}")

Mean	: 0.782263079913039
 SD 	: 0.005540890263325638
