# Fully connected feedforward network implementing a loss mask

In [None]:
import torchmetrics.classification
from torch import cuda
# Abort early unless a CUDA runtime is usable and at least one GPU is visible;
# the cells below place the model and the training batches on the GPU.
assert cuda.is_available()
assert cuda.device_count() >= 1

In [None]:
# Report the name of the currently selected CUDA device.
active_device_index = cuda.current_device()
print(cuda.get_device_name(active_device_index))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
# Single seed reused for torch's global RNG, the k-fold split and the
# DataLoader shuffling generator further down.
SEED = 76436278

# Device on which the model is trained.
DEVICE = torch.device("cuda")

torch.manual_seed(SEED)

<torch._C.Generator at 0x7fcd00326fb0>

### Load the Dataset

In [None]:
import os

from src.maldi2resistance.data.driams import Driams

# Location of the DRIAMS dataset on disk. Set the DRIAMS_ROOT environment
# variable to run the notebook on another machine; the previous hard-coded
# location is kept as the default so existing setups behave exactly as before.
# NOTE(review): this cell imports from `src.maldi2resistance` while later cells
# import from `maldi2resistance` — confirm both resolve to the same package.
DRIAMS_ROOT = os.environ.get("DRIAMS_ROOT", "/home/jan/Uni/master/data/Driams")

driams = Driams(
    root_dir=DRIAMS_ROOT,
)

# Selects the dataset's "memory" loading type (the captured output shows
# "Loading Spectra into Memory").
driams.loading_type = "memory"

# Last expression of the cell: renders the dataset summary table.
driams



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
# Number of columns in the label statistics table (one per antibiotic header
# in the summary printed above).
label_columns = driams.label_stats.columns
len(label_columns)

38

In [None]:
import copy
from maldi2resistance.model.singleBranchMlp import SingleBranchMLP

# One output per selected antibiotic; the network input is 18000 wide.
n_antibiotics = len(driams.selected_antibiotics)
model = SingleBranchMLP(input_dim=18000, output_dim=n_antibiotics).to(DEVICE)

# Snapshot of the freshly initialised weights; the training loop reloads it
# at the start of every fold so that each fold starts from the same state.
model_state = copy.deepcopy(model.state_dict())

In [None]:
from torchinfo import summary

# A single all-zero sample is enough for torchinfo to trace the layer shapes.
dummy_input = torch.zeros((1, 18000)).to(DEVICE)
model_overview = summary(model, input_data=dummy_input)
print(model_overview)

Layer (type:depth-idx)                   Output Shape              Param #
SingleBranchMLP                          [1, 38]                   --
├─Sequential: 1-1                        [1, 38]                   --
│    └─Linear: 2-1                       [1, 512]                  9,216,512
│    └─GELU: 2-2                         [1, 512]                  --
│    └─Dropout: 2-3                      [1, 512]                  --
│    └─LayerNorm: 2-4                    [1, 512]                  1,024
│    └─Linear: 2-5                       [1, 256]                  131,328
│    └─GELU: 2-6                         [1, 256]                  --
│    └─Dropout: 2-7                      [1, 256]                  --
│    └─LayerNorm: 2-8                    [1, 256]                  512
│    └─Linear: 2-9                       [1, 128]                  32,896
│    └─GELU: 2-10                        [1, 128]                  --
│    └─Dropout: 2-11                     [1, 128]                

In [None]:
from torch.utils.data import DataLoader

In [None]:
from maldi2resistance.loss.asymmetricLoss import AsymmetricLoss
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# 5-fold cross-validation: for every fold the model is reset to its initial
# weights, trained for 30 epochs, and evaluated on the held-out fold. ROC and
# precision-recall curves are saved as PNG figures and CSV tables per fold.
print("Start training ...")
model.train()

batch_size = 128
fig_path = Path("./kfold/figures")
fig_path.mkdir(parents=True, exist_ok=True)

# The metric objects are asked to write CSV files into this directory, so make
# sure it exists as well (only the figure directory was created before).
csv_path = Path("./kfold/csv")
csv_path.mkdir(parents=True, exist_ok=True)

# Loss value of every optimisation step across all folds.
loss_per_batch = []

# Per-label weights (1 - share of negative / positive labels).
# NOTE(review): not consumed anywhere in this cell — confirm whether they are
# still needed.
class_weights_negative = torch.tensor((1 - (driams.label_stats.loc["negative"] / driams.label_stats.loc["n_sum"])).values, device=DEVICE)
class_weights_positive = torch.tensor((1 - (driams.label_stats.loc["positive"] / driams.label_stats.loc["n_sum"])).values, device=DEVICE)

criterion = AsymmetricLoss()

gen = torch.Generator()

# Evaluation is done on the CPU. A dedicated constant is used instead of
# temporarily overwriting the global DEVICE, so that DEVICE stays correct even
# if evaluation raises halfway through.
EVAL_DEVICE = torch.device("cpu")

for fold, (train_data, test_data) in enumerate(driams.getK_fold(n_splits=5, shuffle=True, random_state= SEED)):

    # The shared generator is re-seeded for each loader so that shuffling is
    # reproducible. The test loader yields the whole fold as a single batch.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True, generator= gen.manual_seed(SEED))
    test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=True, drop_last=True, generator= gen.manual_seed(SEED))

    # Start every fold from the same initial weights.
    model.load_state_dict(model_state)
    model.train()

    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad = True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    # With drop_last=True every batch is full, so the number of samples seen
    # per epoch is exactly len(train_loader) * batch_size. The previous
    # denominator used the last zero-based batch index, which was one batch
    # short and zero for a single-batch fold. max(..., 1) avoids a division by
    # zero should a fold be smaller than one batch.
    samples_per_epoch = max(len(train_loader) * batch_size, 1)

    for epoch in tqdm(range(30)):
        overall_loss = 0

        for x, y in train_loader:

            x = x.to(DEVICE)
            y = y.to(DEVICE)

            optimizer.zero_grad()

            output = model(x)

            loss = criterion(output, y)
            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)

            overall_loss += current_loss_value

            loss.backward()
            optimizer.step()

        scheduler.step()
        with tqdm.external_write_mode():
            # NOTE(review): the per-sample normalisation presumes the criterion
            # returns a batch-summed loss — confirm against AsymmetricLoss.
            print(f"\tAverage Loss: {overall_loss / samples_per_epoch:.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")


    print(f"Finished Fold {fold}")

    model.eval()

    test_features, test_labels = next(iter(test_loader))
    test_features = test_features.to(EVAL_DEVICE)
    test_labels = test_labels.to(EVAL_DEVICE)
    model = model.to(EVAL_DEVICE)

    # Inference only: skip autograd bookkeeping for the full test fold.
    with torch.no_grad():
        output = model(test_features)

    ml_roc = MultiLabelRocNan()
    ml_roc.compute(output,test_labels,driams.selected_antibiotics, create_csv=f"./kfold/csv/fold-{fold}_ROC.csv")
    fig_, ax_ = ml_roc()

    plt.savefig(fig_path / f"fold-{fold}_ROC.png", transparent=True, format= "png", bbox_inches = "tight")
    plt.close()

    ml_pr = MultiLabelPRNan()
    ml_pr.compute(output,test_labels,driams.selected_antibiotics, create_csv=f"./kfold/csv/fold-{fold}_PrecisionRecall.csv")

    fig_, ax_ = ml_pr()

    plt.savefig(fig_path / f"fold-{fold}_PrecisionRecall.png", transparent=True, format= "png", bbox_inches = "tight")
    plt.close()

    # Move the model back to the training device for the next fold.
    model = model.to(DEVICE)

Start training ...


  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.966885 	Learning rate: 0.001000
	Average Loss: 0.813091 	Learning rate: 0.001000
	Average Loss: 0.774205 	Learning rate: 0.001000
	Average Loss: 0.744872 	Learning rate: 0.001000
	Average Loss: 0.716957 	Learning rate: 0.001000
	Average Loss: 0.696193 	Learning rate: 0.001000
	Average Loss: 0.680086 	Learning rate: 0.001000
	Average Loss: 0.666550 	Learning rate: 0.001000
	Average Loss: 0.660815 	Learning rate: 0.001000
	Average Loss: 0.640449 	Learning rate: 0.000500
	Average Loss: 0.606576 	Learning rate: 0.000500
	Average Loss: 0.596344 	Learning rate: 0.000500
	Average Loss: 0.582541 	Learning rate: 0.000500
	Average Loss: 0.579889 	Learning rate: 0.000500
	Average Loss: 0.571993 	Learning rate: 0.000500
	Average Loss: 0.564334 	Learning rate: 0.000500
	Average Loss: 0.558652 	Learning rate: 0.000500
	Average Loss: 0.556793 	Learning rate: 0.000500
	Average Loss: 0.545536 	Learning rate: 0.000500
	Average Loss: 0.539553 	Learning rate: 0.000250
	Average Loss: 0.510

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.964183 	Learning rate: 0.001000
	Average Loss: 0.811396 	Learning rate: 0.001000
	Average Loss: 0.766032 	Learning rate: 0.001000
	Average Loss: 0.734811 	Learning rate: 0.001000
	Average Loss: 0.719692 	Learning rate: 0.001000
	Average Loss: 0.697253 	Learning rate: 0.001000
	Average Loss: 0.689385 	Learning rate: 0.001000
	Average Loss: 0.669050 	Learning rate: 0.001000
	Average Loss: 0.657856 	Learning rate: 0.001000
	Average Loss: 0.644109 	Learning rate: 0.000500
	Average Loss: 0.611141 	Learning rate: 0.000500
	Average Loss: 0.598690 	Learning rate: 0.000500
	Average Loss: 0.589562 	Learning rate: 0.000500
	Average Loss: 0.576896 	Learning rate: 0.000500
	Average Loss: 0.575818 	Learning rate: 0.000500
	Average Loss: 0.569853 	Learning rate: 0.000500
	Average Loss: 0.560876 	Learning rate: 0.000500
	Average Loss: 0.547750 	Learning rate: 0.000500
	Average Loss: 0.549703 	Learning rate: 0.000500
	Average Loss: 0.542982 	Learning rate: 0.000250
	Average Loss: 0.515

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.959463 	Learning rate: 0.001000
	Average Loss: 0.810404 	Learning rate: 0.001000
	Average Loss: 0.771605 	Learning rate: 0.001000
	Average Loss: 0.739027 	Learning rate: 0.001000
	Average Loss: 0.719869 	Learning rate: 0.001000
	Average Loss: 0.705216 	Learning rate: 0.001000
	Average Loss: 0.685003 	Learning rate: 0.001000
	Average Loss: 0.672421 	Learning rate: 0.001000
	Average Loss: 0.659272 	Learning rate: 0.001000
	Average Loss: 0.647694 	Learning rate: 0.000500
	Average Loss: 0.611573 	Learning rate: 0.000500
	Average Loss: 0.596812 	Learning rate: 0.000500
	Average Loss: 0.589870 	Learning rate: 0.000500
	Average Loss: 0.582205 	Learning rate: 0.000500
	Average Loss: 0.572366 	Learning rate: 0.000500
	Average Loss: 0.569149 	Learning rate: 0.000500
	Average Loss: 0.559153 	Learning rate: 0.000500
	Average Loss: 0.552231 	Learning rate: 0.000500
	Average Loss: 0.546557 	Learning rate: 0.000500
	Average Loss: 0.544498 	Learning rate: 0.000250
	Average Loss: 0.512

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.962802 	Learning rate: 0.001000
	Average Loss: 0.808297 	Learning rate: 0.001000
	Average Loss: 0.770110 	Learning rate: 0.001000
	Average Loss: 0.737073 	Learning rate: 0.001000
	Average Loss: 0.715456 	Learning rate: 0.001000
	Average Loss: 0.695935 	Learning rate: 0.001000
	Average Loss: 0.676981 	Learning rate: 0.001000
	Average Loss: 0.670249 	Learning rate: 0.001000
	Average Loss: 0.653473 	Learning rate: 0.001000
	Average Loss: 0.650115 	Learning rate: 0.000500
	Average Loss: 0.607648 	Learning rate: 0.000500
	Average Loss: 0.594189 	Learning rate: 0.000500
	Average Loss: 0.589106 	Learning rate: 0.000500
	Average Loss: 0.582350 	Learning rate: 0.000500
	Average Loss: 0.571027 	Learning rate: 0.000500
	Average Loss: 0.571032 	Learning rate: 0.000500
	Average Loss: 0.562945 	Learning rate: 0.000500
	Average Loss: 0.552683 	Learning rate: 0.000500
	Average Loss: 0.546921 	Learning rate: 0.000500
	Average Loss: 0.537419 	Learning rate: 0.000250
	Average Loss: 0.516

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.963313 	Learning rate: 0.001000
	Average Loss: 0.815284 	Learning rate: 0.001000
	Average Loss: 0.770452 	Learning rate: 0.001000
	Average Loss: 0.738039 	Learning rate: 0.001000
	Average Loss: 0.719154 	Learning rate: 0.001000
	Average Loss: 0.700497 	Learning rate: 0.001000
	Average Loss: 0.689994 	Learning rate: 0.001000
	Average Loss: 0.668365 	Learning rate: 0.001000
	Average Loss: 0.663919 	Learning rate: 0.001000
	Average Loss: 0.648544 	Learning rate: 0.000500
	Average Loss: 0.610932 	Learning rate: 0.000500
	Average Loss: 0.598699 	Learning rate: 0.000500
	Average Loss: 0.588973 	Learning rate: 0.000500
	Average Loss: 0.578308 	Learning rate: 0.000500
	Average Loss: 0.576338 	Learning rate: 0.000500
	Average Loss: 0.570048 	Learning rate: 0.000500
	Average Loss: 0.559760 	Learning rate: 0.000500
	Average Loss: 0.556271 	Learning rate: 0.000500
	Average Loss: 0.549930 	Learning rate: 0.000500
	Average Loss: 0.541579 	Learning rate: 0.000250
	Average Loss: 0.513

In [None]:
import pandas

# Micro- and macro-averaged ROC AUC of every fold, stored as plain floats so
# that np.mean / np.std operate on a flat list of numbers.
micro = []
macro = []

for fold in range(5):  # one CSV per fold written by the 5-fold training run
    roc_table = pandas.read_csv(f"./kfold/csv/fold-{fold}_ROC.csv")

    for average_name, collected in (("micro", micro), ("macro", macro)):
        matches = roc_table.loc[roc_table["class"] == average_name, "ROCAUC"]
        # Exactly one summary row per averaging mode is expected; fail loudly
        # instead of silently aggregating an empty or ambiguous selection.
        if len(matches) != 1:
            raise ValueError(
                f"expected exactly one '{average_name}' row in fold {fold}, found {len(matches)}"
            )
        collected.append(float(matches.iloc[0]))

In [None]:
import numpy as np

# Mean and standard deviation (numpy default, ddof=0) of the micro-averaged
# ROC AUC values collected from the per-fold CSV files.
micro_mean = np.mean(micro)
micro_sd = np.std(micro)
print(f"Mean\t: {micro_mean}")
print(f" SD \t: {micro_sd}")

Mean	: 0.9326603293418885
 SD 	: 0.0022185096511297725


In [None]:
# Mean and standard deviation (numpy default, ddof=0) of the macro-averaged
# ROC AUC values collected from the per-fold CSV files.
macro_mean = np.mean(macro)
macro_sd = np.std(macro)
print(f"Mean\t: {macro_mean}")
print(f" SD \t: {macro_sd}")

Mean	: 0.8972347617149353
 SD 	: 0.0022192076993228895


In [None]:
import pandas

# Micro- and macro-averaged precision-recall AUC of every fold, stored as
# plain floats so that np.mean / np.std operate on a flat list of numbers.
micro = []
macro = []

for fold in range(5):  # one CSV per fold written by the 5-fold training run
    pr_table = pandas.read_csv(f"./kfold/csv/fold-{fold}_PrecisionRecall.csv")

    for average_name, collected in (("micro", micro), ("macro", macro)):
        matches = pr_table.loc[pr_table["class"] == average_name, "PrecisionRecallAUC"]
        # Exactly one summary row per averaging mode is expected; fail loudly
        # instead of silently aggregating an empty or ambiguous selection.
        if len(matches) != 1:
            raise ValueError(
                f"expected exactly one '{average_name}' row in fold {fold}, found {len(matches)}"
            )
        collected.append(float(matches.iloc[0]))

In [None]:
import numpy as np

# Mean and standard deviation (numpy default, ddof=0) of the micro-averaged
# precision-recall AUC values collected from the per-fold CSV files.
micro_mean = np.mean(micro)
micro_sd = np.std(micro)
print(f"Mean\t: {micro_mean}")
print(f" SD \t: {micro_sd}")

Mean	: 0.865790331363678
 SD 	: 0.00572370106873237


In [None]:
# Mean and standard deviation (numpy default, ddof=0) of the macro-averaged
# precision-recall AUC values collected from the per-fold CSV files.
macro_mean = np.mean(macro)
macro_sd = np.std(macro)
print(f"Mean\t: {macro_mean}")
print(f" SD \t: {macro_sd}")

Mean	: 0.7588865059771035
 SD 	: 0.005404624405759706
