# Fully connected feed-forward network (residual MLP) trained with a masked BCE loss

In [None]:
import torchmetrics.classification
from torch import cuda
# Training below runs on the GPU: stop immediately if no CUDA device is visible.
assert cuda.is_available() and cuda.device_count() > 0

In [None]:
# Report which GPU the notebook will train on.
active_gpu_index = cuda.current_device()
print(cuda.get_device_name(active_gpu_index))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
SEED = 76436278
DEVICE = torch.device("cuda")

# Seed torch's global RNG before the model is built so its initial weights are reproducible.
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f2ee00f2fb0>

### Load the Dataset

In [None]:
import os

# Import from the installed package, like every other maldi2resistance import in this
# notebook; mixing `src.maldi2resistance` and `maldi2resistance` can load two copies.
from maldi2resistance.data.driams import Driams

# Location of the DRIAMS data. Override it with the DRIAMS_ROOT environment variable
# instead of editing the notebook; the default is the original author's local path.
DRIAMS_ROOT = os.environ.get("DRIAMS_ROOT", "/home/jan/Uni/master/data/Driams")

driams = Driams(
    root_dir=DRIAMS_ROOT,
)

# The cell output below shows the spectra being loaded into memory in this mode.
driams.loading_type = "memory"

driams

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
# Number of antibiotic columns in the label statistics table.
driams.label_stats.shape[1]

38

In [None]:
import copy
from maldi2resistance.model.MultilabelResMLP import MultilabelResMLP

# One output per selected antibiotic.
n_model_outputs = len(driams.selected_antibiotics)
model = MultilabelResMLP(input_dim=6000, output_dim=n_model_outputs, hidden_dim=256).to(DEVICE)

# Snapshot the freshly initialised weights so every cross-validation fold restarts from them.
model_state = copy.deepcopy(model.state_dict())

In [None]:
from torchinfo import summary

# Layer-by-layer parameter overview of the network.
model_overview = summary(model)
print(model_overview)

Layer (type:depth-idx)                        Param #
MultilabelResMLP                              --
├─Linear: 1-1                                 1,536,256
├─ResMLP: 1-2                                 --
│    └─Sequential: 2-1                        --
│    │    └─ResBlock: 3-1                     66,304
│    │    └─ResBlock: 3-2                     66,304
│    │    └─ResBlock: 3-3                     66,304
│    │    └─ResBlock: 3-4                     66,304
│    │    └─ResBlock: 3-5                     66,304
│    │    └─Linear: 3-6                       9,766
Total params: 1,877,542
Trainable params: 1,877,542
Non-trainable params: 0


In [None]:
from torch.utils.data import DataLoader

In [None]:
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# 5-fold cross-validation. Every fold restarts from the same initial weights
# (`model_state`), trains for N_EPOCHS, then writes ROC and precision-recall
# CSVs and figures for its held-out split.

print("Start training ...")
model.train()

batch_size = 128
N_FOLDS = 5
N_EPOCHS = 30

fig_path = Path("./kfold/figures")
fig_path.mkdir(parents=True, exist_ok=True)
# The metric objects write their CSVs into this directory; make sure it exists.
csv_path = Path("./kfold/csv")
csv_path.mkdir(parents=True, exist_ok=True)

# Evaluation runs on the CPU because the whole test split is fed as a single batch.
# Use a dedicated name rather than rebinding the global DEVICE inside the loop.
EVAL_DEVICE = torch.device("cpu")

# Per-batch training losses, accumulated over all folds.
loss_per_batch = []

# Per-antibiotic class weights: 1 - (class count / n_sum) from the label statistics.
class_weights_negative = torch.tensor(
    (1 - (driams.label_stats.loc["negative"] / driams.label_stats.loc["n_sum"])).values,
    device=DEVICE,
)
class_weights_positive = torch.tensor(
    (1 - (driams.label_stats.loc["positive"] / driams.label_stats.loc["n_sum"])).values,
    device=DEVICE,
)

# Bug fix: the positive weights were previously passed the negative-class weights.
criterion = MaskedBCE(
    class_weights_positive=class_weights_positive,
    class_weights_negative=class_weights_negative,
)

gen = torch.Generator()

for fold, (train_data, test_data) in enumerate(driams.getK_fold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)):

    # Re-seeding the shared generator makes each fold's batch order reproducible.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True, generator=gen.manual_seed(SEED))
    # The entire test split is a single batch.
    test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=True, drop_last=True, generator=gen.manual_seed(SEED))

    model.load_state_dict(model_state)
    model.train()

    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad=True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    for epoch in tqdm(range(N_EPOCHS)):
        overall_loss = 0.0
        n_batches = 0

        for x, y in train_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            optimizer.zero_grad()

            output = model(x)
            loss = criterion(output, y)

            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)
            overall_loss += current_loss_value
            n_batches += 1

            loss.backward()
            optimizer.step()

        scheduler.step()
        # `loss.item()` is one scalar per batch, so average over the number of batches.
        # The previous divisor (batch_idx * batch_size) undercounted batches by one
        # (batch_idx is zero-based) and additionally divided by the batch size.
        average_loss = overall_loss / max(n_batches, 1)
        with tqdm.external_write_mode():
            print(f"\tAverage Loss: {average_loss:.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")

    print(f"Finished Fold {fold}")

    model.eval()
    model = model.to(EVAL_DEVICE)

    test_features, test_labels = next(iter(test_loader))
    test_features = test_features.to(EVAL_DEVICE)
    test_labels = test_labels.to(EVAL_DEVICE)

    # No autograd graph is needed for evaluation of the full-size test batch.
    with torch.no_grad():
        output = model(test_features)

    ml_roc = MultiLabelRocNan()
    ml_roc.compute(output, test_labels, driams.selected_antibiotics, create_csv=f"./kfold/csv/fold-{fold}_ROC.csv")
    fig_, ax_ = ml_roc()

    plt.savefig(fig_path / f"fold-{fold}_ROC.png", transparent=True, format="png", bbox_inches="tight")
    plt.close()

    ml_pr = MultiLabelPRNan()
    ml_pr.compute(output, test_labels, driams.selected_antibiotics, create_csv=f"./kfold/csv/fold-{fold}_PrecisionRecall.csv")
    fig_, ax_ = ml_pr()

    plt.savefig(fig_path / f"fold-{fold}_PrecisionRecall.png", transparent=True, format="png", bbox_inches="tight")
    plt.close()

    # Back to the training device before the next fold builds its optimizer.
    model = model.to(DEVICE)

Start training ...


  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.000261 	Learning rate: 0.001000
	Average Loss: 0.000217 	Learning rate: 0.001000
	Average Loss: 0.000203 	Learning rate: 0.001000
	Average Loss: 0.000193 	Learning rate: 0.001000
	Average Loss: 0.000184 	Learning rate: 0.001000
	Average Loss: 0.000176 	Learning rate: 0.001000
	Average Loss: 0.000167 	Learning rate: 0.001000
	Average Loss: 0.000161 	Learning rate: 0.001000
	Average Loss: 0.000154 	Learning rate: 0.001000
	Average Loss: 0.000148 	Learning rate: 0.000500
	Average Loss: 0.000131 	Learning rate: 0.000500
	Average Loss: 0.000126 	Learning rate: 0.000500
	Average Loss: 0.000120 	Learning rate: 0.000500
	Average Loss: 0.000116 	Learning rate: 0.000500
	Average Loss: 0.000113 	Learning rate: 0.000500
	Average Loss: 0.000109 	Learning rate: 0.000500
	Average Loss: 0.000105 	Learning rate: 0.000500
	Average Loss: 0.000103 	Learning rate: 0.000500
	Average Loss: 0.000099 	Learning rate: 0.000500
	Average Loss: 0.000097 	Learning rate: 0.000250
	Average Loss: 0.000

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.000260 	Learning rate: 0.001000
	Average Loss: 0.000216 	Learning rate: 0.001000
	Average Loss: 0.000203 	Learning rate: 0.001000
	Average Loss: 0.000192 	Learning rate: 0.001000
	Average Loss: 0.000183 	Learning rate: 0.001000
	Average Loss: 0.000177 	Learning rate: 0.001000
	Average Loss: 0.000169 	Learning rate: 0.001000
	Average Loss: 0.000162 	Learning rate: 0.001000
	Average Loss: 0.000155 	Learning rate: 0.001000
	Average Loss: 0.000149 	Learning rate: 0.000500
	Average Loss: 0.000132 	Learning rate: 0.000500
	Average Loss: 0.000127 	Learning rate: 0.000500
	Average Loss: 0.000122 	Learning rate: 0.000500
	Average Loss: 0.000119 	Learning rate: 0.000500
	Average Loss: 0.000116 	Learning rate: 0.000500
	Average Loss: 0.000112 	Learning rate: 0.000500
	Average Loss: 0.000108 	Learning rate: 0.000500
	Average Loss: 0.000105 	Learning rate: 0.000500
	Average Loss: 0.000104 	Learning rate: 0.000500
	Average Loss: 0.000100 	Learning rate: 0.000250
	Average Loss: 0.000

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.000261 	Learning rate: 0.001000
	Average Loss: 0.000216 	Learning rate: 0.001000
	Average Loss: 0.000203 	Learning rate: 0.001000
	Average Loss: 0.000192 	Learning rate: 0.001000
	Average Loss: 0.000183 	Learning rate: 0.001000
	Average Loss: 0.000178 	Learning rate: 0.001000
	Average Loss: 0.000169 	Learning rate: 0.001000
	Average Loss: 0.000162 	Learning rate: 0.001000
	Average Loss: 0.000156 	Learning rate: 0.001000
	Average Loss: 0.000150 	Learning rate: 0.000500
	Average Loss: 0.000133 	Learning rate: 0.000500
	Average Loss: 0.000128 	Learning rate: 0.000500
	Average Loss: 0.000122 	Learning rate: 0.000500
	Average Loss: 0.000119 	Learning rate: 0.000500
	Average Loss: 0.000114 	Learning rate: 0.000500
	Average Loss: 0.000111 	Learning rate: 0.000500
	Average Loss: 0.000108 	Learning rate: 0.000500
	Average Loss: 0.000105 	Learning rate: 0.000500
	Average Loss: 0.000101 	Learning rate: 0.000500
	Average Loss: 0.000099 	Learning rate: 0.000250
	Average Loss: 0.000

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.000259 	Learning rate: 0.001000
	Average Loss: 0.000217 	Learning rate: 0.001000
	Average Loss: 0.000202 	Learning rate: 0.001000
	Average Loss: 0.000192 	Learning rate: 0.001000
	Average Loss: 0.000183 	Learning rate: 0.001000
	Average Loss: 0.000174 	Learning rate: 0.001000
	Average Loss: 0.000169 	Learning rate: 0.001000
	Average Loss: 0.000164 	Learning rate: 0.001000
	Average Loss: 0.000155 	Learning rate: 0.001000
	Average Loss: 0.000147 	Learning rate: 0.000500
	Average Loss: 0.000132 	Learning rate: 0.000500
	Average Loss: 0.000126 	Learning rate: 0.000500
	Average Loss: 0.000122 	Learning rate: 0.000500
	Average Loss: 0.000117 	Learning rate: 0.000500
	Average Loss: 0.000114 	Learning rate: 0.000500
	Average Loss: 0.000110 	Learning rate: 0.000500
	Average Loss: 0.000108 	Learning rate: 0.000500
	Average Loss: 0.000105 	Learning rate: 0.000500
	Average Loss: 0.000100 	Learning rate: 0.000500
	Average Loss: 0.000100 	Learning rate: 0.000250
	Average Loss: 0.000

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.000262 	Learning rate: 0.001000
	Average Loss: 0.000217 	Learning rate: 0.001000
	Average Loss: 0.000204 	Learning rate: 0.001000
	Average Loss: 0.000194 	Learning rate: 0.001000
	Average Loss: 0.000184 	Learning rate: 0.001000
	Average Loss: 0.000176 	Learning rate: 0.001000
	Average Loss: 0.000169 	Learning rate: 0.001000
	Average Loss: 0.000163 	Learning rate: 0.001000
	Average Loss: 0.000155 	Learning rate: 0.001000
	Average Loss: 0.000148 	Learning rate: 0.000500
	Average Loss: 0.000133 	Learning rate: 0.000500
	Average Loss: 0.000126 	Learning rate: 0.000500
	Average Loss: 0.000123 	Learning rate: 0.000500
	Average Loss: 0.000118 	Learning rate: 0.000500
	Average Loss: 0.000115 	Learning rate: 0.000500
	Average Loss: 0.000113 	Learning rate: 0.000500
	Average Loss: 0.000108 	Learning rate: 0.000500
	Average Loss: 0.000105 	Learning rate: 0.000500
	Average Loss: 0.000103 	Learning rate: 0.000500
	Average Loss: 0.000100 	Learning rate: 0.000250
	Average Loss: 0.000

In [None]:
import pandas

# Micro- and macro-averaged ROC AUC of every cross-validation fold (one-row Series each).
roc_tables = [pandas.read_csv(f"./kfold/csv/fold-{fold}_ROC.csv") for fold in range(5)]
micro = [table.loc[table["class"] == "micro", "ROCAUC"] for table in roc_tables]
macro = [table.loc[table["class"] == "macro", "ROCAUC"] for table in roc_tables]

In [None]:
import numpy as np

# Mean and (population) standard deviation across folds.
micro_mean, micro_sd = np.mean(micro), np.std(micro)
print(f"Mean\t: {micro_mean}")
print(f" SD \t: {micro_sd}")

Mean	: 0.915527069568634
 SD 	: 0.0034874926685825387


In [None]:
# Mean and (population) standard deviation across folds.
macro_mean, macro_sd = np.mean(macro), np.std(macro)
print(f"Mean\t: {macro_mean}")
print(f" SD \t: {macro_sd}")

Mean	: 0.8909191806065409
 SD 	: 0.002691388177127918


In [None]:
import pandas

# Micro- and macro-averaged precision-recall AUC of every fold (one-row Series each).
pr_tables = [pandas.read_csv(f"./kfold/csv/fold-{fold}_PrecisionRecall.csv") for fold in range(5)]
micro = [table.loc[table["class"] == "micro", "PrecisionRecallAUC"] for table in pr_tables]
macro = [table.loc[table["class"] == "macro", "PrecisionRecallAUC"] for table in pr_tables]

In [None]:
import numpy as np

# Mean and (population) standard deviation across folds.
micro_mean, micro_sd = np.mean(micro), np.std(micro)
print(f"Mean\t: {micro_mean}")
print(f" SD \t: {micro_sd}")

Mean	: 0.8436816096305847
 SD 	: 0.005116129942464538


In [None]:
# Mean and (population) standard deviation across folds.
macro_mean, macro_sd = np.mean(macro), np.std(macro)
print(f"Mean\t: {macro_mean}")
print(f" SD \t: {macro_sd}")

Mean	: 0.7321679060396395
 SD 	: 0.0054336295732402146
