# Residual MLP with a chemprop message-passing drug encoder, trained on single-antibiotic samples (5-fold cross-validation on DRIAMS)

In [None]:
import torchmetrics.classification
from torch import cuda
# Fail fast with a readable message: the cells below create tensors and the
# model on a CUDA device.
assert cuda.is_available(), "CUDA is not available; this notebook requires a GPU."
assert cuda.device_count() > 0, "No CUDA device was found."

In [None]:
# Report the name of the GPU that is currently selected.
active_gpu_index = cuda.current_device()
print(cuda.get_device_name(active_gpu_index))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
# Reproducibility and hardware settings used by the cells below.
SEED = 76436278
DEVICE = torch.device("cuda")

# Seed PyTorch's global random number generator; the returned generator
# object is shown as the cell output.
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fdd6c0f6f70>

### Load the Dataset

In [None]:
from src.maldi2resistance.data.driams import Driams

import os

# Location of the DRIAMS data. Set the DRIAMS_ROOT environment variable to
# point somewhere else; the fallback is the path this notebook was originally
# run with, so existing setups keep working unchanged.
DRIAMS_ROOT = os.environ.get("DRIAMS_ROOT", "/home/jan/Uni/master/data/Driams")

driams = Driams(
    root_dir=DRIAMS_ROOT,
)

# Presumably keeps all spectra in RAM (the cell output shows a
# "Loading Spectra into Memory" progress bar) -- confirm against Driams.
driams.loading_type = "memory"

# Last expression: display the dataset summary as the cell output.
driams

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
# Number of columns in `driams.label_stats`; displayed as the cell output.
len(driams.label_stats.columns)

38

In [None]:
from chemprop.models import MPNN
from chemprop.nn import BondMessagePassing, NormAggregation, RegressionFFN
from multimodal_amr.models.modules import ResMLP
from torch import nn


class Residual_AMR_Classifier(nn.Module):
    """Score a (spectrum, drug) pair with a residual MLP.

    The spectrum is projected by a linear layer, the drug's molecular graph is
    encoded by a chemprop message-passing network, and the concatenation of
    both embeddings is passed to a ``ResMLP`` that emits a single value per
    pair (used as a logit by the training loop).

    Args:
        config: Mapping with the keys ``conv_out_size``,
            ``species_embedding_dim``, ``sample_embedding_dim``,
            ``drug_embedding_dim`` and ``n_hidden_layers``.
        n_input_spectrum: Number of input features of one spectrum.
        n_input_drug: Input width of the ``drug_emb`` layer. That layer is
            kept for state-dict compatibility but is not used in ``forward``.
    """

    # Width of the chemprop encoding when the message-passing module does not
    # expose ``output_dim`` (this is the value that used to be hard-coded).
    _FALLBACK_DRUG_ENCODING_DIM = 300

    def __init__(self, config, n_input_spectrum, n_input_drug):
        super().__init__()
        self.config = config

        # Sample embedding
        if config["species_embedding_dim"]==0 and config["conv_out_size"]==config["sample_embedding_dim"]:
            self.sample_emb = nn.Identity()
        elif config["species_embedding_dim"]==0:
            self.sample_emb = nn.Linear(
                config["conv_out_size"],
                config["sample_embedding_dim"],
            )
        else:
            self.sample_emb = nn.Linear(
                config["conv_out_size"] + config["species_embedding_dim"],
                config["sample_embedding_dim"],
            )

        # Maldi-tof spectrum embedding
        self.spectrum_emb = nn.Linear(n_input_spectrum, config["conv_out_size"])

        # Drug encoder: chemprop message passing over the molecular graph.
        mp = BondMessagePassing()
        agg = NormAggregation()
        ffn = RegressionFFN()
        self.mpnn = MPNN(mp, agg, ffn)

        # Width of what `self.mpnn.encoding(..., i=0)` returns. Read it from the
        # message-passing module instead of hard-coding the chemprop default,
        # so changing the encoder's hidden size cannot silently desynchronise
        # the output network.
        self.drug_encoding_dim = getattr(
            mp, "output_dim", self._FALLBACK_DRUG_ENCODING_DIM
        )

        # Not used in forward(); retained so the parameter set (and therefore
        # seeded initialisation and saved state dicts) stays the same.
        self.drug_emb = nn.Linear(
            n_input_drug, config["drug_embedding_dim"]
        )

        # Output network
        self.net = ResMLP(
            config["n_hidden_layers"],
            config["sample_embedding_dim"] + self.drug_encoding_dim,
            1,
            p_dropout=0.2,
        )

    def forward(self, spectrum, drug):
        """Return the network output for a batch of spectra and drugs.

        Args:
            spectrum: Batch of spectra accepted by ``spectrum_emb``.
            drug: Object exposing the batched molecular graph as ``drug.bmg``.
        """
        sample_embedding = self.spectrum_emb(spectrum)

        # Project to `sample_embedding_dim`, which is the width `self.net` was
        # built for. Without a species embedding this is either an identity or
        # a conv_out_size -> sample_embedding_dim linear layer. With a species
        # embedding configured, `sample_emb` expects extra inputs that this
        # forward() does not receive, so it is skipped (previous behaviour).
        if self.config["species_embedding_dim"] == 0:
            sample_embedding = self.sample_emb(sample_embedding)

        # i=0 asks chemprop for the encoding before any predictor layer.
        dr_emb = self.mpnn.encoding(drug.bmg, i=0)

        combined_emb = torch.cat((sample_embedding, dr_emb), dim=1)

        return self.net(combined_emb)




# Hyperparameters handed to Residual_AMR_Classifier.
conf = dict(
    conv_out_size=512,
    species_embedding_dim=0,
    sample_embedding_dim=512,
    drug_embedding_dim=512,
    n_hidden_layers=5,
)
        

In [None]:
from maldi2resistance.model.dualBranch import DualBranchOneHot
import copy
from maldi2resistance.model.singleBranchMlp import SingleBranchMLP

# Build the classifier and place it on the training device in one step.
model = Residual_AMR_Classifier(conf, driams.n_bins, 1024).to(DEVICE)

# Independent copy of the freshly initialised weights; the training cell
# reloads it at the start of every fold.
initial_weights = model.state_dict()
model_state = copy.deepcopy(initial_weights)

In [None]:
from torch.utils.data import DataLoader

In [None]:
from maldi2resistance.data.chemprop import collate
from torchmetrics.utilities.compute import auc
from torchmetrics.classification import BinaryAUROC, BinaryPrecisionRecallCurve
from maldi2resistance.data.driams import DriamsSingleAntibiotic
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

print("Start training ...")
model.train()

batch_size = 128
n_epochs = 30
fig_path = Path("./kfold/figures")
fig_path.mkdir(parents=True, exist_ok=True)

loss_per_batch = []

gen = torch.Generator()

# One entry per fold. Micro: all test samples pooled. Macro: mean over antibiotics.
all_auc_roc_micro = []
all_auc_pr_micro = []

all_auc_roc_macro = []
all_auc_pr_macro = []

# Evaluation runs on the CPU. A dedicated name is used so the global DEVICE is
# never rebound inside the loop (an interrupted run used to leave DEVICE == cpu).
EVAL_DEVICE = torch.device("cpu")

for fold, (train_data, test_data) in enumerate(driams.getK_fold(n_splits=5, shuffle=True, random_state= SEED)):

    train_dataset_single_antibiotic = DriamsSingleAntibiotic(driams= train_data, prepeare4chemprop=True)
    test_dataset_single_antibiotic = DriamsSingleAntibiotic(driams= test_data, prepeare4chemprop= True)

    train_loader = DataLoader(train_dataset_single_antibiotic, batch_size=batch_size, shuffle=True, generator= gen.manual_seed(SEED), collate_fn=collate)
    # The test loader is NOT shuffled: the per-antibiotic metrics below rely on
    # predictions being in dataset order. The pooled metrics do not depend on order.
    test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=batch_size, shuffle=False, collate_fn=collate)

    # Every fold starts from the same initial weights on the training device.
    model = model.to(DEVICE)
    model.load_state_dict(model_state)
    model.train()

    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad = True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    for epoch in tqdm(range(n_epochs), leave= False, position=1):
        overall_loss = 0

        for batch_idx, data in tqdm(enumerate(train_loader), total=len(train_loader), leave=False, position=1):
            (x,y), pos = data

            x = x.to(DEVICE)
            y = y.to(DEVICE)
            # Return value is not used: presumably the batched graph is moved
            # in place -- confirm against chemprop's BatchMolGraph.to.
            pos.bmg.to(DEVICE)

            optimizer.zero_grad()

            output = model(x, pos)

            # Reshape to the label shape instead of squeeze(), which would also
            # drop the batch dimension of a final batch holding one sample.
            loss = F.binary_cross_entropy_with_logits(output.reshape(y.shape), y)
            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)

            overall_loss += current_loss_value

            loss.backward()
            optimizer.step()

        scheduler.step()
        with tqdm.external_write_mode():
            # Each batch loss is already a mean over its samples, so the epoch
            # average is the sum divided by the number of batches.
            average_loss = overall_loss / max(len(train_loader), 1)
            print(f"\tAverage Loss: {average_loss:.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")

    print(f"Finished Fold {fold}")

    model.eval()
    model = model.to(EVAL_DEVICE)

    fold_outputs = []
    fold_labels = []

    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        for ((x, label), drug) in tqdm(test_loader, leave=False):
            x = x.to(EVAL_DEVICE)
            drug.bmg.to(EVAL_DEVICE)
            label = label.to(EVAL_DEVICE)

            result = model(x, drug).detach().cpu()

            fold_labels.append(label)
            fold_outputs.append(result.reshape(label.shape))

    output = torch.cat(fold_outputs)
    test_labels = torch.cat(fold_labels).int()
    probabilities = torch.sigmoid(output)

    # Micro metrics: every (sample, antibiotic) pair of the fold pooled together.
    auRoc = BinaryAUROC()
    auc_roc = auRoc(probabilities, test_labels)
    all_auc_roc_micro.append(auc_roc)

    metric = BinaryPrecisionRecallCurve()
    metric.update(probabilities, test_labels)
    precision, recall, thresholds = metric.compute()
    aucPC = auc(recall, precision)

    all_auc_pr_micro.append(aucPC)

    # Macro metrics: one score per antibiotic, averaged over the antibiotics
    # whose test labels contain both classes.
    macro_aucroc = 0
    macro_aucpr = 0

    n_not_empty = 0

    # Antibiotic index of every test sample. Loaded unshuffled so it lines up
    # with `output`. NOTE(review): this assumes DriamsSingleAntibiotic yields
    # the samples in the same order with and without `prepeare4chemprop` --
    # please confirm.
    test_dataset_single_antibiotic = DriamsSingleAntibiotic(driams= test_data)
    test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=len(test_dataset_single_antibiotic), shuffle=False)
    _, _, test_pos = next(iter(test_loader))
    assert test_pos.shape[0] == output.shape[0], "antibiotic indices and predictions differ in length"

    for pos, antibiotic in enumerate(driams.selected_antibiotics):
        selected = test_pos == pos
        out_part = probabilities[selected]
        label_part = test_labels[selected]

        # minlength=2 guarantees both class counts exist, even for an empty slice.
        occurrences = torch.bincount(label_part, minlength=2)
        if occurrences[0].item() == 0 or occurrences[1].item() == 0:
            continue

        # Fresh metric objects per antibiotic: calling update() on the pooled
        # `metric` would mix in the micro data and all earlier antibiotics.
        au_roc = BinaryAUROC()(out_part, label_part)

        antibiotic_pr_curve = BinaryPrecisionRecallCurve()
        antibiotic_pr_curve.update(out_part, label_part)
        precision, recall, thresholds = antibiotic_pr_curve.compute()
        aucPC = auc(recall, precision)

        n_not_empty +=1
        macro_aucroc += au_roc.item()
        macro_aucpr += aucPC.item()

    if n_not_empty > 0:
        macro_aucroc = macro_aucroc / n_not_empty
        macro_aucpr = macro_aucpr / n_not_empty
    else:
        # No antibiotic had both classes in this fold's test split.
        macro_aucroc = float("nan")
        macro_aucpr = float("nan")

    all_auc_roc_macro.append(macro_aucroc)
    all_auc_pr_macro.append(macro_aucpr)

    # Leave the model on the training device, as the next fold expects.
    model = model.to(DEVICE)
    

Start training ...


Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

	Average Loss: 0.001087 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.001039 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000997 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000961 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000923 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000889 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000858 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000824 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000799 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000763 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000665 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000638 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000616 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000598 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000580 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000561 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000545 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000529 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000516 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 0.000500 	Learning rate: 0.000125
Finished Fold 0


  0%|          | 0/1098 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.002730 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.002213 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001972 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001793 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001652 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001542 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001447 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001374 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001305 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001244 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001086 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001038 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000996 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000957 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000919 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000883 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000852 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000818 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000791 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000759 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000659 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000630 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000612 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000591 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000570 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000556 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000540 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000523 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000507 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000494 	Learning rate: 0.000125
Finished Fold 1


  0%|          | 0/1108 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.002748 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.002230 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001987 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001815 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001674 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001570 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001479 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001404 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001336 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001264 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001107 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001060 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.001017 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000973 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000938 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000906 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000868 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000834 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000805 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000769 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000669 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000637 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000619 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000598 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000583 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000562 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000546 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000529 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000514 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 0.000498 	Learning rate: 0.000125
Finished Fold 2


  0%|          | 0/1107 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.002725 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.002208 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.002010 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001838 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001712 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001615 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001531 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001465 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001399 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001381 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001210 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001222 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001149 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001099 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001072 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001042 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.001019 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.000992 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.000989 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.000962 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.000886 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.000876 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.000856 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.000836 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.000828 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.000808 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.000806 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.000787 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.000784 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 0.000762 	Learning rate: 0.000125
Finished Fold 3


  0%|          | 0/1103 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.002756 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.002246 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.001997 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.001812 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.001676 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.001566 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.001475 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.001400 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.001335 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.001265 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.001119 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.001067 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.001025 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000987 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000952 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000914 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000884 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000850 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000816 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000791 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000693 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000662 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000640 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000619 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000601 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000583 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000566 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000554 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000535 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 0.000522 	Learning rate: 0.000125
Finished Fold 4


  0%|          | 0/1093 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

In [None]:
import numpy as np

# Micro ROC AUC per fold, moved to the CPU once and summarised across folds.
roc_micro_scores = [x.cpu() for x in all_auc_roc_micro]
print(f"Mean\t: {np.mean(roc_micro_scores)}")
print(f" SD \t: {np.std(roc_micro_scores)}")

Mean	: 0.9148151278495789
 SD 	: 0.0027252810541540384


In [None]:
# Micro precision-recall AUC per fold, moved to the CPU once and summarised across folds.
pr_micro_scores = [x.cpu() for x in all_auc_pr_micro]
print(f"Mean\t: {np.mean(pr_micro_scores)}")
print(f" SD \t: {np.std(pr_micro_scores)}")

Mean	: 0.8438540697097778
 SD 	: 0.00489542493596673


In [None]:
import numpy as np

# Macro ROC AUC per fold, summarised across folds.
roc_macro_scores = list(all_auc_roc_macro)
print(f"Mean\t: {np.mean(roc_macro_scores)}")
print(f" SD \t: {np.std(roc_macro_scores)}")

Mean	: 0.9147285473974127
 SD 	: 0.0030281215308687473


In [None]:
# Macro precision-recall AUC per fold, summarised across folds.
pr_macro_scores = list(all_auc_pr_macro)
print(f"Mean\t: {np.mean(pr_macro_scores)}")
print(f" SD \t: {np.std(pr_macro_scores)}")

Mean	: 0.8438540577888489
 SD 	: 0.004895424853139714
