# Residual MLP classifier on single-antibiotic DRIAMS samples, trained with an asymmetric loss

In [None]:
import torchmetrics.classification
from torch import cuda
# Fail fast with a readable message: the notebook later sets DEVICE = torch.device("cuda").
assert cuda.is_available(), "CUDA is not available; this notebook requires a CUDA-capable GPU."
assert cuda.device_count() > 0, "PyTorch does not see any CUDA device."

In [None]:
# Report the name of the GPU that PyTorch currently targets.
active_gpu_index = cuda.current_device()
print(cuda.get_device_name(active_gpu_index))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
# Seed for the global PyTorch RNG; the training cell reuses it for the fold split
# and for the DataLoader generator.
SEED = 76436278

# Device the model is placed on.
DEVICE = torch.device("cuda")

# Last expression of the cell: seeds the RNG and displays the returned generator.
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f3a1c10afb0>

### Load the Dataset

In [None]:
import os

# Import through the same package path (`maldi2resistance`) that the training cell
# uses, so the package is not loaded twice under two different module names.
from maldi2resistance.data.driams import Driams

# Location of the DRIAMS data. Set the DRIAMS_ROOT environment variable to point
# somewhere else; the fallback is the path this notebook was originally run with.
DRIAMS_ROOT = os.environ.get("DRIAMS_ROOT", "/home/jan/Uni/master/data/Driams")

driams = Driams(
    root_dir=DRIAMS_ROOT,
)

# Select the "memory" loading type (the cell output shows the spectra being
# loaded into memory afterwards).
driams.loading_type = "memory"

# Last expression of the cell: display the dataset's label statistics.
driams



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
# Number of columns in the label statistics (one per antibiotic in the output above).
label_columns = driams.label_stats.columns
len(label_columns)

38

In [None]:
from multimodal_amr.models.modules import ResMLP
from torch import nn


class Residual_AMR_Classifier(nn.Module):
    """Residual MLP that scores one (spectrum, drug) pair with a single output value.

    The spectrum and the drug representation are each projected by a linear
    layer, concatenated along the feature dimension, and passed through a
    ``ResMLP`` with one output unit.

    Args:
        config: Mapping with the keys ``conv_out_size``, ``species_embedding_dim``,
            ``sample_embedding_dim``, ``drug_embedding_dim`` and ``n_hidden_layers``.
        n_input_spectrum: Number of input features of the spectrum.
        n_input_drug: Number of input features of the drug representation.
    """

    def __init__(self, config, n_input_spectrum, n_input_drug):
        super().__init__()
        self.config = config

        # Sample embedding: maps the spectrum embedding (width ``conv_out_size``)
        # to ``sample_embedding_dim``. Without a species embedding and with equal
        # widths no projection is needed.
        if config["species_embedding_dim"] == 0 and config["conv_out_size"] == config["sample_embedding_dim"]:
            self.sample_emb = nn.Identity()
        elif config["species_embedding_dim"] == 0:
            self.sample_emb = nn.Linear(
                config["conv_out_size"],
                config["sample_embedding_dim"],
            )
        else:
            # Sized for a concatenated (spectrum, species) input. ``forward`` takes
            # no species input, so this layer is not applied there.
            self.sample_emb = nn.Linear(
                config["conv_out_size"] + config["species_embedding_dim"],
                config["sample_embedding_dim"],
            )

        # Spectrum projection to ``conv_out_size`` features.
        self.spectrum_emb = nn.Linear(n_input_spectrum, config["conv_out_size"])

        # Drug projection: always a linear layer on the drug input features.
        self.drug_emb = nn.Linear(
            n_input_drug, config["drug_embedding_dim"]
        )

        # Output network: consumes the concatenated embeddings, emits one value.
        self.net = ResMLP(
            config["n_hidden_layers"],
            config["sample_embedding_dim"] + config["drug_embedding_dim"],
            1,
            p_dropout=0.2,
        )

    def forward(self, spectrum, fingerprint):
        """Return the output of the residual network for a batch of pairs.

        Args:
            spectrum: Batch of spectra, ``n_input_spectrum`` features per row.
            fingerprint: Batch of drug representations, ``n_input_drug``
                features per row.

        Returns:
            The ``ResMLP`` output for the concatenated embeddings (one output
            unit per row).
        """
        spectrum_embedding = self.spectrum_emb(spectrum)

        # Apply the sample embedding so the width matches what ``self.net`` was
        # built for (``sample_embedding_dim``). Previously this layer was created
        # but never used, which only worked when it happened to be Identity.
        # With a species embedding configured the layer expects extra species
        # features that this forward pass does not receive, so it stays skipped.
        if self.config["species_embedding_dim"] == 0:
            spectrum_embedding = self.sample_emb(spectrum_embedding)

        dr_emb = self.drug_emb(fingerprint)

        return self.net(torch.cat([dr_emb, spectrum_embedding], dim=1))




# Hyperparameters passed as `config` to Residual_AMR_Classifier.
conf = dict(
    conv_out_size=512,
    species_embedding_dim=0,
    sample_embedding_dim=512,
    drug_embedding_dim=512,
    n_hidden_layers=5,
)
        

In [None]:
import copy

# Build the classifier and place it on the training device.
# NOTE(review): 1024 is presumably the length of the Morgan fingerprint used as
# drug input -- confirm against DriamsSingleAntibiotic.
model = Residual_AMR_Classifier(
    config=conf,
    n_input_spectrum=driams.n_bins,
    n_input_drug=1024,
).to(DEVICE)

# Deep copy of the freshly initialised weights; the training cell reloads this
# snapshot at the start of every fold.
model_state = copy.deepcopy(model.state_dict())

In [None]:
from torchinfo import summary

# Per-layer parameter overview of the model.
model_overview = summary(model)
print(model_overview)

Layer (type:depth-idx)                        Param #
Residual_AMR_Classifier                       --
├─Identity: 1-1                               --
├─Linear: 1-2                                 9,216,512
├─Linear: 1-3                                 524,800
├─ResMLP: 1-4                                 --
│    └─Sequential: 2-1                        --
│    │    └─ResBlock: 3-1                     1,051,648
│    │    └─ResBlock: 3-2                     1,051,648
│    │    └─ResBlock: 3-3                     1,051,648
│    │    └─ResBlock: 3-4                     1,051,648
│    │    └─ResBlock: 3-5                     1,051,648
│    │    └─Linear: 3-6                       1,025
Total params: 15,000,577
Trainable params: 15,000,577
Non-trainable params: 0


In [None]:
from torch.utils.data import DataLoader

In [None]:
from maldi2resistance.loss.asymmetricLoss import AsymmetricLoss
from torchmetrics.classification import BinaryAUROC, BinaryPrecisionRecallCurve, BinaryAveragePrecision
from maldi2resistance.data.driams import DriamsSingleAntibiotic
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

print("Start training ...")
model.train()

# --- Experiment settings -------------------------------------------------
batch_size = 128
N_SPLITS = 5
N_EPOCHS = 30
LEARNING_RATE = 1e-3
LR_STEP_SIZE = 10
LR_GAMMA = 0.5

fig_path = Path("./kfold/figures")
fig_path.mkdir(parents=True, exist_ok=True)

# Loss value of every training batch, across all folds and epochs.
loss_per_batch = []

gen = torch.Generator()

# One entry per fold. The "micro" lists hold the tensors returned by the
# metrics on the whole test fold; the "macro" lists hold Python floats
# averaged over the antibiotics.
all_auc_roc_micro = []
all_auc_pr_micro = []

all_auc_roc_macro = []
all_auc_pr_macro = []

criterion = AsymmetricLoss()

# The whole test fold is evaluated in a single batch on the CPU. A local
# variable is used for this instead of temporarily rebinding the global
# DEVICE, so an interrupted cell cannot leave DEVICE pointing at the CPU.
eval_device = torch.device("cpu")

for fold, (train_data, test_data) in enumerate(driams.getK_fold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)):

    train_dataset_single_antibiotic = DriamsSingleAntibiotic(driams=train_data, use_morganFingerprint4Drug=True)
    test_dataset_single_antibiotic = DriamsSingleAntibiotic(driams=test_data, use_morganFingerprint4Drug=True)

    train_loader = DataLoader(train_dataset_single_antibiotic, batch_size=batch_size, shuffle=True, drop_last=True, generator=gen.manual_seed(SEED))
    # The test fold is read as one batch and must NOT be shuffled: the
    # per-antibiotic evaluation below reads the antibiotic indices from a
    # second pass over the same fold and relies on both passes yielding the
    # samples in the same order.
    test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=len(test_dataset_single_antibiotic), shuffle=False)

    # Every fold starts from the same initial weights.
    model.load_state_dict(model_state)
    model.train()

    optimizer = Adam(model.parameters(), lr=LEARNING_RATE, amsgrad=True)
    scheduler = StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

    for epoch in tqdm(range(N_EPOCHS), leave=False, position=1):
        overall_loss = 0
        n_batches = 0

        for batch_idx, (x, y, pos) in enumerate(train_loader):

            x = x.to(DEVICE)
            y = y.to(DEVICE)
            pos = pos.to(DEVICE)

            optimizer.zero_grad()

            output = model(x, pos)
            # The network ends in a single output unit. Give the prediction
            # exactly the shape of the target so an element-wise loss cannot
            # silently broadcast a (B, 1) prediction against a (B,) target
            # into a (B, B) matrix. This is a no-op if the shapes already
            # agree and raises if the element counts differ.
            # NOTE(review): the evaluation code squeezes the output before
            # comparing it with the labels, which suggests the labels are 1-D.
            output = output.reshape(y.shape)

            loss = criterion(output, y)
            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)

            overall_loss += current_loss_value

            loss.backward()
            optimizer.step()

            n_batches = batch_idx + 1

        scheduler.step()

        # drop_last=True, so every batch holds exactly `batch_size` samples.
        # (The previous divisor `batch_idx * batch_size` missed one batch.)
        n_samples_seen = max(n_batches, 1) * batch_size
        with tqdm.external_write_mode():
            print(f"\tAverage Loss: {overall_loss / n_samples_seen:.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")

    print(f"Finished Fold {fold}")

    # ------------------------------ evaluation ---------------------------
    model.eval()
    model = model.to(eval_device)

    test_features, test_labels, test_pos = next(iter(test_loader))
    test_features = test_features.to(eval_device)
    test_labels = test_labels.to(eval_device)
    test_pos = test_pos.to(eval_device)

    # No gradients are needed for evaluation; this avoids keeping an autograd
    # graph for the whole test fold.
    with torch.no_grad():
        output = model(test_features, test_pos)
    output = output.reshape(test_labels.shape)

    # Metrics over all test samples of the fold.
    auRoc = BinaryAUROC()
    auc_roc = auRoc(output, test_labels.int())
    all_auc_roc_micro.append(auc_roc)

    auPR = BinaryAveragePrecision()
    aucPC = auPR(output, test_labels.int())
    all_auc_pr_micro.append(aucPC)

    macro_aucroc = 0
    macro_aucpr = 0

    n_not_empty = 0

    # Second pass over the same test fold with use_morganFingerprint4Drug=False:
    # the third element is then compared against the antibiotic index below.
    # Unshuffled, like the first pass, so that `test_pos` lines up with
    # `output` and `test_labels`.
    # NOTE(review): assumes the dataset yields its items in the same order
    # regardless of `use_morganFingerprint4Drug` -- confirm.
    test_dataset_single_antibiotic = DriamsSingleAntibiotic(driams=test_data, use_morganFingerprint4Drug=False)
    test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=len(test_dataset_single_antibiotic), shuffle=False)
    _, _, test_pos = next(iter(test_loader))
    assert test_pos.shape[0] == output.shape[0], "antibiotic indices and predictions differ in length"

    for pos, antibiotic in enumerate(driams.selected_antibiotics):
        selected = test_pos == pos
        out_part = output[selected]
        label_part = test_labels[selected].int()

        # Both metrics need at least one sample of each class; skip the
        # antibiotic otherwise (this also covers an empty selection).
        has_positive = bool((label_part == 1).any())
        has_negative = bool((label_part == 0).any())
        if not (has_positive and has_negative):
            continue

        au_roc = auRoc(out_part, label_part)
        aucPC = auPR(out_part, label_part)

        n_not_empty += 1
        macro_aucroc += au_roc.item()
        macro_aucpr += aucPC.item()

    if n_not_empty > 0:
        macro_aucroc = macro_aucroc / n_not_empty
        macro_aucpr = macro_aucpr / n_not_empty
    else:
        # No antibiotic had both classes in this fold: record NaN instead of
        # failing with a ZeroDivisionError.
        macro_aucroc = float("nan")
        macro_aucpr = float("nan")

    all_auc_roc_macro.append(macro_aucroc)
    all_auc_pr_macro.append(macro_aucpr)

    # Back to the training device for the next fold.
    model = model.to(DEVICE)
    

Start training ...


Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 2209494.393890 	Learning rate: 0.001000
	Average Loss: 8723.120919 	Learning rate: 0.001000
	Average Loss: 8723.276512 	Learning rate: 0.001000
	Average Loss: 8722.070379 	Learning rate: 0.001000
	Average Loss: 11763.130147 	Learning rate: 0.001000
	Average Loss: 8795.937832 	Learning rate: 0.001000
	Average Loss: 8793.880705 	Learning rate: 0.001000
	Average Loss: 8794.521214 	Learning rate: 0.001000
	Average Loss: 8799.687889 	Learning rate: 0.001000
	Average Loss: 8817.275475 	Learning rate: 0.000500
	Average Loss: 8817.141012 	Learning rate: 0.000500
	Average Loss: 8818.104612 	Learning rate: 0.000500
	Average Loss: 8818.447693 	Learning rate: 0.000500
	Average Loss: 8817.765620 	Learning rate: 0.000500
	Average Loss: 8818.705084 	Learning rate: 0.000500
	Average Loss: 8818.528231 	Learning rate: 0.000500
	Average Loss: 8818.677807 	Learning rate: 0.000500
	Average Loss: 8817.138518 	Learning rate: 0.000500
	Average Loss: 8819.512163 	Learning rate: 0.000500
	Average

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 49757602.040199 	Learning rate: 0.001000
	Average Loss: 5063.600838 	Learning rate: 0.001000
	Average Loss: 5072.111927 	Learning rate: 0.001000
	Average Loss: 5065.084525 	Learning rate: 0.001000
	Average Loss: 5061.284233 	Learning rate: 0.001000
	Average Loss: 5093.652338 	Learning rate: 0.001000
	Average Loss: 5098.596947 	Learning rate: 0.001000
	Average Loss: 5098.872983 	Learning rate: 0.001000
	Average Loss: 5097.765848 	Learning rate: 0.001000
	Average Loss: 5096.920339 	Learning rate: 0.000500
	Average Loss: 5097.489843 	Learning rate: 0.000500
	Average Loss: 5096.935558 	Learning rate: 0.000500
	Average Loss: 5095.083372 	Learning rate: 0.000500
	Average Loss: 5201.352692 	Learning rate: 0.000500
	Average Loss: 5355.644072 	Learning rate: 0.000500
	Average Loss: 5354.490986 	Learning rate: 0.000500
	Average Loss: 5353.756243 	Learning rate: 0.000500
	Average Loss: 5354.623191 	Learning rate: 0.000500
	Average Loss: 5354.684161 	Learning rate: 0.000500
	Average

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 11988414.514455 	Learning rate: 0.001000
	Average Loss: 6088.407414 	Learning rate: 0.001000
	Average Loss: 6088.449762 	Learning rate: 0.001000
	Average Loss: 6087.852195 	Learning rate: 0.001000
	Average Loss: 6088.009527 	Learning rate: 0.001000
	Average Loss: 6087.635738 	Learning rate: 0.001000
	Average Loss: 6087.210660 	Learning rate: 0.001000
	Average Loss: 6087.237097 	Learning rate: 0.001000
	Average Loss: 6086.463920 	Learning rate: 0.001000
	Average Loss: 6511.695877 	Learning rate: 0.000500
	Average Loss: 6654.751311 	Learning rate: 0.000500
	Average Loss: 6654.777966 	Learning rate: 0.000500
	Average Loss: 6655.269997 	Learning rate: 0.000500
	Average Loss: 6655.005836 	Learning rate: 0.000500
	Average Loss: 6654.304838 	Learning rate: 0.000500
	Average Loss: 6654.749492 	Learning rate: 0.000500
	Average Loss: 6654.793215 	Learning rate: 0.000500
	Average Loss: 6655.215184 	Learning rate: 0.000500
	Average Loss: 6655.308871 	Learning rate: 0.000500
	Average

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 80882519.163807 	Learning rate: 0.001000
	Average Loss: 9154.201510 	Learning rate: 0.001000
	Average Loss: 9155.652150 	Learning rate: 0.001000
	Average Loss: 9154.054740 	Learning rate: 0.001000
	Average Loss: 9154.477213 	Learning rate: 0.001000
	Average Loss: 9154.242537 	Learning rate: 0.001000
	Average Loss: 9154.892205 	Learning rate: 0.001000
	Average Loss: 9155.004305 	Learning rate: 0.001000
	Average Loss: 9155.775551 	Learning rate: 0.001000
	Average Loss: 9154.654671 	Learning rate: 0.000500
	Average Loss: 9154.846846 	Learning rate: 0.000500
	Average Loss: 9155.263035 	Learning rate: 0.000500
	Average Loss: 9153.513297 	Learning rate: 0.000500
	Average Loss: 9154.971265 	Learning rate: 0.000500
	Average Loss: 9154.276493 	Learning rate: 0.000500
	Average Loss: 9154.175221 	Learning rate: 0.000500
	Average Loss: 9154.424662 	Learning rate: 0.000500
	Average Loss: 9154.757800 	Learning rate: 0.000500
	Average Loss: 9153.975650 	Learning rate: 0.000500
	Average

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 108633753.838258 	Learning rate: 0.001000
	Average Loss: 5766.448417 	Learning rate: 0.001000
	Average Loss: 5766.790915 	Learning rate: 0.001000
	Average Loss: 5763.495870 	Learning rate: 0.001000
	Average Loss: 5753.896579 	Learning rate: 0.001000
	Average Loss: 38912.593132 	Learning rate: 0.001000
	Average Loss: 6843.751165 	Learning rate: 0.001000
	Average Loss: 6835.388288 	Learning rate: 0.001000
	Average Loss: 6840.538775 	Learning rate: 0.001000
	Average Loss: 7040.426174 	Learning rate: 0.000500
	Average Loss: 7074.200444 	Learning rate: 0.000500
	Average Loss: 7074.572746 	Learning rate: 0.000500
	Average Loss: 7074.091525 	Learning rate: 0.000500
	Average Loss: 7074.452573 	Learning rate: 0.000500
	Average Loss: 7074.303145 	Learning rate: 0.000500
	Average Loss: 7074.443724 	Learning rate: 0.000500
	Average Loss: 7074.755022 	Learning rate: 0.000500
	Average Loss: 7074.829990 	Learning rate: 0.000500
	Average Loss: 7074.177893 	Learning rate: 0.000500
	Avera

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

In [None]:
import numpy as np

# ROC-AUC over the whole test fold: mean and standard deviation across folds.
roc_micro_per_fold = [score.cpu() for score in all_auc_roc_micro]
print(f"Mean\t: {np.mean(roc_micro_per_fold)}")
print(f" SD \t: {np.std(roc_micro_per_fold)}")

Mean	: 0.5468364953994751
 SD 	: 0.09235317260026932


In [None]:
# Average precision over the whole test fold: mean and standard deviation across folds.
pr_micro_per_fold = [score.cpu() for score in all_auc_pr_micro]
print(f"Mean\t: {np.mean(pr_micro_per_fold)}")
print(f" SD \t: {np.std(pr_micro_per_fold)}")

Mean	: 0.27030566334724426
 SD 	: 0.051002245396375656


In [None]:
import numpy as np

# Per-antibiotic averaged ROC-AUC: mean and standard deviation across folds.
roc_macro_per_fold = list(all_auc_roc_macro)
print(f"Mean\t: {np.mean(roc_macro_per_fold)}")
print(f" SD \t: {np.std(roc_macro_per_fold)}")

Mean	: 0.5460756709701136
 SD 	: 0.09287175332011391


In [None]:
# Per-antibiotic averaged average precision: mean and standard deviation across folds.
pr_macro_per_fold = list(all_auc_pr_macro)
print(f"Mean\t: {np.mean(pr_macro_per_fold)}")
print(f" SD \t: {np.std(pr_macro_per_fold)}")

Mean	: 0.27287272795250544
 SD 	: 0.05081944088174901
