# Fully connected feedforward network implementing a loss mask

In [None]:
import torchmetrics.classification
from torch import cuda
# The notebook trains on the GPU, so stop right away if CUDA is unusable.
cuda_available = cuda.is_available()
assert cuda_available
visible_gpus = cuda.device_count()
assert visible_gpus > 0

In [None]:
# Report which GPU is currently selected.
current_index = cuda.current_device()
print(cuda.get_device_name(current_index))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
# Reproducibility: fix torch's global seed before the model weights are initialised.
SEED = 76436278

# Default compute device for training.
torch_device_name = "cuda"
DEVICE = torch.device(torch_device_name)

torch.manual_seed(SEED)

<torch._C.Generator at 0x7f7e614dafd0>

### Load the Dataset

In [None]:
import os

# Import from the installed package, like every other maldi2resistance import in this
# notebook, so the package is not loaded a second time under the `src.` prefix.
from maldi2resistance.data.driams import Driams

# Root directory of the DRIAMS data set. Set the DRIAMS_ROOT environment variable to
# point elsewhere instead of editing the notebook; the fallback is the original path.
DRIAMS_ROOT = os.environ.get("DRIAMS_ROOT", "/home/jan/Uni/master/data/Driams")

driams = Driams(
    root_dir=DRIAMS_ROOT,
)

# Keep the spectra in memory.
driams.loading_type = "memory"

driams



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
# Number of columns in the label statistics table.
driams.label_stats.shape[1]

38

In [None]:
import copy
from maldi2resistance.model.MultilabelResMLP import MultilabelResMLP

# One output unit per selected antibiotic.
n_antibiotics = len(driams.selected_antibiotics)
model = MultilabelResMLP(input_dim=18000, output_dim=n_antibiotics, hidden_dim=256).to(DEVICE)

# Snapshot of the freshly initialised weights; the k-fold loop reloads it so that
# every fold starts from the same initialisation.
model_state = copy.deepcopy(model.state_dict())

In [None]:
from torchinfo import summary

# Layer-by-layer parameter counts of the network.
model_summary = summary(model)
print(model_summary)

Layer (type:depth-idx)                        Param #
MultilabelResMLP                              --
├─Linear: 1-1                                 4,608,256
├─ResMLP: 1-2                                 --
│    └─Sequential: 2-1                        --
│    │    └─ResBlock: 3-1                     66,304
│    │    └─ResBlock: 3-2                     66,304
│    │    └─ResBlock: 3-3                     66,304
│    │    └─ResBlock: 3-4                     66,304
│    │    └─ResBlock: 3-5                     66,304
│    │    └─Linear: 3-6                       9,766
Total params: 4,949,542
Trainable params: 4,949,542
Non-trainable params: 0


In [None]:
from torch.utils.data import DataLoader

In [None]:
from maldi2resistance.loss.asymmetricLoss import AsymmetricLoss
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

print("Start training ...")
model.train()

# Training configuration shared by all folds.
batch_size = 128
n_folds = 5
n_epochs = 30

fig_path = Path("./kfold/figures")
fig_path.mkdir(parents=True, exist_ok=True)
# The ROC / PR metrics below write per-fold CSVs into ./kfold/csv, so that directory
# must exist as well (previously only the figures directory was created).
Path("./kfold/csv").mkdir(parents=True, exist_ok=True)

# Loss of every training batch, appended across *all* folds (not reset per fold).
loss_per_batch = []

# Per-antibiotic weights: 1 - label_stats["negative" / "positive"] / label_stats["n_sum"].
# NOTE(review): neither tensor is passed to `criterion` below — confirm whether still needed.
class_weights_negative = torch.tensor((1 - (driams.label_stats.loc["negative"] / driams.label_stats.loc["n_sum"])).values, device=DEVICE)
class_weights_positive = torch.tensor((1 - (driams.label_stats.loc["positive"] / driams.label_stats.loc["n_sum"])).values, device=DEVICE)

criterion = AsymmetricLoss()

# One generator, re-seeded with SEED whenever a fold's loaders are built, so each fold's
# training shuffles start from the same RNG state.
gen = torch.Generator()


def _train_fold(train_loader, optimizer, scheduler):
    """Train the global `model` in place for `n_epochs` epochs on one fold's training split.

    Every batch loss is appended to `loss_per_batch`. After each epoch the average loss per
    training sample and the learning rate used for that epoch are printed.
    """
    for _epoch in tqdm(range(n_epochs)):
        overall_loss = 0.0
        n_samples = 0

        for x, y in train_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            optimizer.zero_grad()

            output = model(x)
            loss = criterion(output, y)

            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)
            overall_loss += current_loss_value
            n_samples += x.shape[0]

            loss.backward()
            optimizer.step()

        # Read the LR *before* stepping: after scheduler.step() get_last_lr() already
        # returns the LR of the next epoch, not the one this epoch trained with.
        epoch_lr = scheduler.get_last_lr()[0]
        scheduler.step()

        # Normalise by the number of samples actually seen. The zero-based batch index
        # times batch_size undercounts by one batch. max(..., 1) guards an empty loader.
        # NOTE(review): assumes `criterion` returns a per-batch *sum* — TODO confirm.
        average_loss = overall_loss / max(n_samples, 1)
        with tqdm.external_write_mode():
            print(f"\tAverage Loss: {average_loss:.6f} \tLearning rate: {epoch_lr:.6f}")


def _evaluate_fold(fold, test_loader):
    """Score the global `model` on one fold's test split; write ROC / PR CSVs and figures.

    The whole test split is evaluated as a single batch on the CPU. The model is moved
    back to `DEVICE` afterwards, even if evaluation raises.
    """
    eval_device = torch.device("cpu")
    model.eval()
    try:
        test_features, test_labels = next(iter(test_loader))
        test_features = test_features.to(eval_device)
        test_labels = test_labels.to(eval_device)
        model.to(eval_device)

        # Evaluation needs no gradients; without no_grad the full-test-set forward pass
        # would also record an autograd graph.
        with torch.no_grad():
            output = model(test_features)

        ml_roc = MultiLabelRocNan()
        ml_roc.compute(output, test_labels, driams.selected_antibiotics, create_csv=f"./kfold/csv/fold-{fold}_ROC.csv")
        fig_, ax_ = ml_roc()

        plt.savefig(fig_path / f"fold-{fold}_ROC.png", transparent=True, format="png", bbox_inches="tight")
        plt.close()

        ml_pr = MultiLabelPRNan()
        ml_pr.compute(output, test_labels, driams.selected_antibiotics, create_csv=f"./kfold/csv/fold-{fold}_PrecisionRecall.csv")
        fig_, ax_ = ml_pr()

        plt.savefig(fig_path / f"fold-{fold}_PrecisionRecall.png", transparent=True, format="png", bbox_inches="tight")
        plt.close()
    finally:
        # Hand the model back on the training device for the next fold.
        model.to(DEVICE)


for fold, (train_data, test_data) in enumerate(driams.getK_fold(n_splits=n_folds, shuffle=True, random_state=SEED)):

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True, generator=gen.manual_seed(SEED))
    test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=True, drop_last=True, generator=gen.manual_seed(SEED))

    # Every fold starts from the same initial weights with a fresh optimizer and LR schedule.
    model.load_state_dict(model_state)
    model.train()

    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad=True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    _train_fold(train_loader, optimizer, scheduler)

    print(f"Finished Fold {fold}")

    _evaluate_fold(fold, test_loader)

Start training ...


  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 1.094219 	Learning rate: 0.001000
	Average Loss: 0.802585 	Learning rate: 0.001000
	Average Loss: 0.753606 	Learning rate: 0.001000
	Average Loss: 0.717094 	Learning rate: 0.001000
	Average Loss: 0.683592 	Learning rate: 0.001000
	Average Loss: 0.657317 	Learning rate: 0.001000
	Average Loss: 0.632067 	Learning rate: 0.001000
	Average Loss: 0.612966 	Learning rate: 0.001000
	Average Loss: 0.589904 	Learning rate: 0.001000
	Average Loss: 0.569795 	Learning rate: 0.000500
	Average Loss: 0.508794 	Learning rate: 0.000500
	Average Loss: 0.492212 	Learning rate: 0.000500
	Average Loss: 0.469805 	Learning rate: 0.000500
	Average Loss: 0.457860 	Learning rate: 0.000500
	Average Loss: 0.444729 	Learning rate: 0.000500
	Average Loss: 0.429302 	Learning rate: 0.000500
	Average Loss: 0.419763 	Learning rate: 0.000500
	Average Loss: 0.407271 	Learning rate: 0.000500
	Average Loss: 0.400664 	Learning rate: 0.000500
	Average Loss: 0.388390 	Learning rate: 0.000250
	Average Loss: 0.353

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 1.098228 	Learning rate: 0.001000
	Average Loss: 0.806557 	Learning rate: 0.001000
	Average Loss: 0.754883 	Learning rate: 0.001000
	Average Loss: 0.723656 	Learning rate: 0.001000
	Average Loss: 0.688899 	Learning rate: 0.001000
	Average Loss: 0.665098 	Learning rate: 0.001000
	Average Loss: 0.636128 	Learning rate: 0.001000
	Average Loss: 0.610860 	Learning rate: 0.001000
	Average Loss: 0.588265 	Learning rate: 0.001000
	Average Loss: 0.569548 	Learning rate: 0.000500
	Average Loss: 0.511250 	Learning rate: 0.000500
	Average Loss: 0.491688 	Learning rate: 0.000500
	Average Loss: 0.475398 	Learning rate: 0.000500
	Average Loss: 0.462153 	Learning rate: 0.000500
	Average Loss: 0.452408 	Learning rate: 0.000500
	Average Loss: 0.438925 	Learning rate: 0.000500
	Average Loss: 0.424690 	Learning rate: 0.000500
	Average Loss: 0.416789 	Learning rate: 0.000500
	Average Loss: 0.405170 	Learning rate: 0.000500
	Average Loss: 0.396150 	Learning rate: 0.000250
	Average Loss: 0.361

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 1.092171 	Learning rate: 0.001000
	Average Loss: 0.805152 	Learning rate: 0.001000
	Average Loss: 0.753743 	Learning rate: 0.001000
	Average Loss: 0.717670 	Learning rate: 0.001000
	Average Loss: 0.687805 	Learning rate: 0.001000
	Average Loss: 0.666397 	Learning rate: 0.001000
	Average Loss: 0.638857 	Learning rate: 0.001000
	Average Loss: 0.613550 	Learning rate: 0.001000
	Average Loss: 0.594861 	Learning rate: 0.001000
	Average Loss: 0.568728 	Learning rate: 0.000500
	Average Loss: 0.513556 	Learning rate: 0.000500
	Average Loss: 0.493872 	Learning rate: 0.000500
	Average Loss: 0.477132 	Learning rate: 0.000500
	Average Loss: 0.467412 	Learning rate: 0.000500
	Average Loss: 0.453913 	Learning rate: 0.000500
	Average Loss: 0.439765 	Learning rate: 0.000500
	Average Loss: 0.429556 	Learning rate: 0.000500
	Average Loss: 0.423044 	Learning rate: 0.000500
	Average Loss: 0.407619 	Learning rate: 0.000500
	Average Loss: 0.399938 	Learning rate: 0.000250
	Average Loss: 0.366

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 1.095144 	Learning rate: 0.001000
	Average Loss: 0.798912 	Learning rate: 0.001000
	Average Loss: 0.751827 	Learning rate: 0.001000
	Average Loss: 0.717684 	Learning rate: 0.001000
	Average Loss: 0.687108 	Learning rate: 0.001000
	Average Loss: 0.655264 	Learning rate: 0.001000
	Average Loss: 0.630236 	Learning rate: 0.001000
	Average Loss: 0.610595 	Learning rate: 0.001000
	Average Loss: 0.590243 	Learning rate: 0.001000
	Average Loss: 0.566199 	Learning rate: 0.000500
	Average Loss: 0.513150 	Learning rate: 0.000500
	Average Loss: 0.486880 	Learning rate: 0.000500
	Average Loss: 0.472080 	Learning rate: 0.000500
	Average Loss: 0.461586 	Learning rate: 0.000500
	Average Loss: 0.444641 	Learning rate: 0.000500
	Average Loss: 0.430437 	Learning rate: 0.000500
	Average Loss: 0.419644 	Learning rate: 0.000500
	Average Loss: 0.409540 	Learning rate: 0.000500
	Average Loss: 0.396813 	Learning rate: 0.000500
	Average Loss: 0.393307 	Learning rate: 0.000250
	Average Loss: 0.352

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 1.087677 	Learning rate: 0.001000
	Average Loss: 0.801259 	Learning rate: 0.001000
	Average Loss: 0.755974 	Learning rate: 0.001000
	Average Loss: 0.719177 	Learning rate: 0.001000
	Average Loss: 0.688713 	Learning rate: 0.001000
	Average Loss: 0.667648 	Learning rate: 0.001000
	Average Loss: 0.637471 	Learning rate: 0.001000
	Average Loss: 0.618134 	Learning rate: 0.001000
	Average Loss: 0.584078 	Learning rate: 0.001000
	Average Loss: 0.568627 	Learning rate: 0.000500
	Average Loss: 0.513555 	Learning rate: 0.000500
	Average Loss: 0.487214 	Learning rate: 0.000500
	Average Loss: 0.475739 	Learning rate: 0.000500
	Average Loss: 0.459793 	Learning rate: 0.000500
	Average Loss: 0.449749 	Learning rate: 0.000500
	Average Loss: 0.436355 	Learning rate: 0.000500
	Average Loss: 0.420212 	Learning rate: 0.000500
	Average Loss: 0.410846 	Learning rate: 0.000500
	Average Loss: 0.402833 	Learning rate: 0.000500
	Average Loss: 0.390166 	Learning rate: 0.000250
	Average Loss: 0.358

In [None]:
import pandas

# Per-fold micro / macro ROC-AUC, read back from the CSVs written during k-fold training.
micro = []
macro = []

for fold in range(0, 5):
    csv = pandas.read_csv(f"./kfold/csv/fold-{fold}_ROC.csv")
    # .item() turns the single matching row into a plain float and raises if the summary
    # row is missing or duplicated, instead of silently appending an empty or multi-row Series.
    micro.append(csv.loc[csv["class"] == "micro", "ROCAUC"].item())
    macro.append(csv.loc[csv["class"] == "macro", "ROCAUC"].item())

In [None]:
import numpy as np

# Mean and population SD (np.std default, ddof=0) of the per-fold micro ROC-AUC.
for label, value in (("Mean", np.mean(micro)), (" SD ", np.std(micro))):
    print(f"{label}\t: {value}")

Mean	: 0.9263352870941162
 SD 	: 0.001802019512499272


In [None]:
# Mean and population SD (np.std default, ddof=0) of the per-fold macro ROC-AUC.
for label, value in (("Mean", np.mean(macro)), (" SD ", np.std(macro))):
    print(f"{label}\t: {value}")

Mean	: 0.8968129201939232
 SD 	: 0.002148241864872247


In [None]:
import pandas

# Per-fold micro / macro precision-recall AUC, read back from the CSVs written during k-fold training.
micro = []
macro = []

for fold in range(0, 5):
    csv = pandas.read_csv(f"./kfold/csv/fold-{fold}_PrecisionRecall.csv")
    # .item() turns the single matching row into a plain float and raises if the summary
    # row is missing or duplicated, instead of silently appending an empty or multi-row Series.
    micro.append(csv.loc[csv["class"] == "micro", "PrecisionRecallAUC"].item())
    macro.append(csv.loc[csv["class"] == "macro", "PrecisionRecallAUC"].item())

In [None]:
import numpy as np

# Mean and population SD (np.std default, ddof=0) of the per-fold micro PR-AUC.
for label, value in (("Mean", np.mean(micro)), (" SD ", np.std(micro))):
    print(f"{label}\t: {value}")

Mean	: 0.8506577134132385
 SD 	: 0.005357928484170269


In [None]:
# Mean and population SD (np.std default, ddof=0) of the per-fold macro PR-AUC.
for label, value in (("Mean", np.mean(macro)), (" SD ", np.std(macro))):
    print(f"{label}\t: {value}")

Mean	: 0.7626372038533813
 SD 	: 0.0058476071452286905
