# Fully connected feedforward network implementing a loss mask

In [None]:
import torchmetrics.classification
from torch import cuda
# Fail fast when no CUDA device is visible: DEVICE below is hard-wired to "cuda".
# An explicit raise (unlike `assert`) is not stripped when Python runs with -O.
if not cuda.is_available() or cuda.device_count() == 0:
    raise RuntimeError("This notebook requires a CUDA-capable GPU, but none is available.")

In [None]:
# Show which GPU the notebook is running on.
active_device_index = cuda.current_device()
print(cuda.get_device_name(active_device_index))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
# Global seed shared by torch, the k-fold split and the DataLoader generators.
SEED = 76436278
# All training happens on the default CUDA device.
DEVICE = torch.device("cuda")

torch.manual_seed(SEED)

<torch._C.Generator at 0x7f079dffefb0>

### Load the Dataset

In [None]:
from src.maldi2resistance.data.driams import Driams

DRIAMS_ROOT = "/home/jan/Uni/master/data/Driams"

# Full DRIAMS collection (default settings).
driams = Driams(root_dir=DRIAMS_ROOT)

# Additional UMG spectra, restricted to the antibiotics selected for DRIAMS.
umg = Driams(
    root_dir=DRIAMS_ROOT,
    bin_size=1,
    sites=["UMG"],
    years=[2020, 2021],
    antibiotics=driams.selected_antibiotics,
)

# Switch both datasets to in-memory loading (DRIAMS first, then UMG).
for _dataset in (driams, umg):
    _dataset.loading_type = "memory"


driams



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/73745 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
# Number of antibiotics listed in the label statistics table.
n_label_columns = len(driams.label_stats.columns)
n_label_columns

38

In [None]:
from chemprop.models import MPNN
from chemprop.nn import BondMessagePassing, NormAggregation, RegressionFFN
from multimodal_amr.models.modules import ResMLP
from torch import nn


class Residual_AMR_Classifier(nn.Module):
    """Residual-MLP classifier over a spectrum embedding and a Chemprop drug encoding.

    The spectrum is projected by a single linear layer and then passed through the
    sample embedding. The drug molecule is encoded by a Chemprop MPNN. The concatenation
    of both vectors is fed to a ``ResMLP`` with a single output.

    Args:
        config: hyper-parameter dict. Reads ``conv_out_size``, ``species_embedding_dim``,
            ``sample_embedding_dim``, ``drug_embedding_dim`` and ``n_hidden_layers``.
        n_input_spectrum: number of input features of the spectrum projection.
        n_input_drug: input size of the ``drug_emb`` projection (not used in ``forward``).
    """

    def __init__(self, config, n_input_spectrum, n_input_drug):
        super().__init__()
        self.config = config

        # Sample embedding: maps the spectrum embedding to ``sample_embedding_dim``.
        if config["species_embedding_dim"] == 0 and config["conv_out_size"] == config["sample_embedding_dim"]:
            self.sample_emb = nn.Identity()
        elif config["species_embedding_dim"] == 0:
            self.sample_emb = nn.Linear(
                config["conv_out_size"],
                config["sample_embedding_dim"],
            )
        else:
            # NOTE(review): forward() never concatenates a species embedding, so this
            # branch's input size would not match; only species_embedding_dim == 0 works.
            self.sample_emb = nn.Linear(
                config["conv_out_size"] + config["species_embedding_dim"],
                config["sample_embedding_dim"],
            )

        # MALDI-TOF spectrum embedding (a Conv1d block was used here previously).
        self.spectrum_emb = nn.Linear(n_input_spectrum, config["conv_out_size"])

        # Drug encoder: Chemprop MPNN built from default components.
        mp = BondMessagePassing()
        agg = NormAggregation()
        ffn = RegressionFFN()
        self.mpnn = MPNN(mp, agg, ffn)

        # Kept so the parameter set / state_dict keys stay unchanged; not used in forward().
        self.drug_emb = nn.Linear(
            n_input_drug, config["drug_embedding_dim"]
        )

        # Output network on [sample embedding | drug encoding]. The drug width is taken
        # from the message-passing module instead of a hard-coded 300.
        self.net = ResMLP(
            config["n_hidden_layers"],
            config["sample_embedding_dim"] + mp.output_dim,
            1,
            p_dropout=0.2,
        )

    def forward(self, spectrum, drug):
        """Return the ResMLP output for a batch of spectra and their batched drug graphs."""
        # Apply the sample embedding so the head input really has sample_embedding_dim
        # features (it is an Identity when conv_out_size == sample_embedding_dim).
        sample_embedding = self.sample_emb(self.spectrum_emb(spectrum))

        # NOTE(review): with i=0 chemprop's encoding() appears to return the MPNN
        # fingerprint before any predictor-FFN layer — confirm for the installed version.
        drug_embedding = self.mpnn.encoding(drug.bmg, i=0)

        combined_emb = torch.cat((sample_embedding, drug_embedding), dim=1)
        return self.net(combined_emb)




# Hyper-parameters for Residual_AMR_Classifier (insertion order matches the keys' usage).
conf = dict(
    conv_out_size=512,
    species_embedding_dim=0,
    sample_embedding_dim=512,
    drug_embedding_dim=512,
    n_hidden_layers=5,
)
        

In [None]:
from maldi2resistance.model.dualBranch import DualBranchOneHot
import copy
from maldi2resistance.model.singleBranchMlp import SingleBranchMLP

model = Residual_AMR_Classifier(
    config=conf,
    n_input_spectrum=driams.n_bins,
    n_input_drug=1024,
).to(DEVICE)

# Snapshot of the freshly initialised weights; each CV fold restarts from this copy.
model_state = copy.deepcopy(model.state_dict())

In [None]:
from torch.utils.data import DataLoader

In [None]:
from maldi2resistance.data.chemprop import collate, CachedChempropCollate
from torchmetrics.utilities.compute import auc
from torchmetrics.classification import BinaryAUROC, BinaryPrecisionRecallCurve
from maldi2resistance.data.driams import DriamsSingleAntibiotic
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR


def _train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` and return the mean loss per batch.

    Every batch loss is also appended to the global ``loss_per_batch`` list.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for (x, y), drug in tqdm(loader, total=len(loader), leave=False, position=1):
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        # NOTE(review): the return value is discarded, so this relies on chemprop's
        # BatchMolGraph.to() moving its tensors in place — confirm for the installed version.
        drug.bmg.to(DEVICE)

        optimizer.zero_grad()
        # reshape_as instead of squeeze(): squeeze() turns a size-1 final batch into a
        # 0-d tensor whose shape no longer matches ``y``.
        logits = model(x, drug).reshape_as(y)
        loss = F.binary_cross_entropy_with_logits(logits, y)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_per_batch.append(batch_loss)
        total_loss += batch_loss
        n_batches += 1

    # binary_cross_entropy_with_logits already averages over the batch (reduction="mean"),
    # so the epoch average is the mean over batches, not a further division by batch_size.
    return total_loss / max(n_batches, 1)


@torch.no_grad()
def _predict(model, loader):
    """Return ``(logits, labels)`` for every sample of ``loader``, in loader order, on the CPU.

    Runs under ``torch.no_grad()`` so no autograd graph is built for the test set.
    """
    model.eval()
    logits, labels = [], []
    for (x, label), drug in tqdm(loader, leave=False):
        x = x.to(DEVICE)
        drug.bmg.to(DEVICE)
        logits.append(model(x, drug).reshape(-1).cpu())
        labels.append(label.reshape(-1).cpu())
    return torch.cat(logits), torch.cat(labels).int()


def _auroc_and_aupr(probs, labels):
    """Return ``(AUROC, AUPR)`` tensors for one set of probabilities and binary labels.

    Fresh metric objects are built on every call so no state carries over between the
    micro-averaged score and the per-antibiotic scores.
    """
    roc = BinaryAUROC()
    roc.update(probs, labels)
    pr_curve = BinaryPrecisionRecallCurve()
    pr_curve.update(probs, labels)
    precision, recall, _ = pr_curve.compute()
    return roc.compute(), auc(recall, precision)


def _antibiotic_positions(dataset, loader_batch_size):
    """Return the antibiotic index of every sample in ``dataset``, in dataset order."""
    # NOTE(review): assumes each default-collated batch is (spectrum, label, antibiotic_pos),
    # matching the three-way unpacking previously applied to this dataset.
    loader = DataLoader(dataset, batch_size=loader_batch_size, shuffle=False)
    return torch.cat([batch[2].reshape(-1) for batch in loader])


def _macro_scores(probs, labels, positions, antibiotics):
    """Average the per-antibiotic AUROC and AUPR over antibiotics with both classes present.

    Returns ``(nan, nan)`` when no antibiotic has both a positive and a negative sample.
    """
    roc_sum = 0.0
    pr_sum = 0.0
    n_scored = 0
    for pos, _ in enumerate(antibiotics):
        mask = positions == pos
        label_part = labels[mask]
        n_positive = int(label_part.sum())
        # AUROC / AUPR are undefined unless both classes occur.
        if n_positive == 0 or n_positive == label_part.numel():
            continue
        au_roc, au_pr = _auroc_and_aupr(probs[mask], label_part)
        roc_sum += au_roc.item()
        pr_sum += au_pr.item()
        n_scored += 1
    if n_scored == 0:
        return float("nan"), float("nan")
    return roc_sum / n_scored, pr_sum / n_scored


print("Start training ...")

batch_size = 128
fig_path = Path("./kfold/figures")
fig_path.mkdir(parents=True, exist_ok=True)

loss_per_batch = []

gen = torch.Generator()

all_auc_roc_micro = []
all_auc_pr_micro = []

all_auc_roc_macro = []
all_auc_pr_macro = []
cached_collate = CachedChempropCollate(driams = driams)

for fold, (train_data, test_data) in enumerate(driams.getK_fold(n_splits=5, shuffle=True, random_state=SEED)):

    train_data_umg = torch.utils.data.ConcatDataset([train_data, umg])

    train_dataset_single_antibiotic = DriamsSingleAntibiotic(driams=train_data_umg)
    test_dataset_single_antibiotic = DriamsSingleAntibiotic(driams=test_data)

    train_loader = DataLoader(
        train_dataset_single_antibiotic,
        batch_size=batch_size,
        shuffle=True,
        generator=gen.manual_seed(SEED),
        collate_fn=cached_collate.collate,
    )
    # No shuffling: predictions must stay in dataset order so they line up with the
    # per-sample antibiotic positions used for the macro scores below.
    test_loader = DataLoader(
        test_dataset_single_antibiotic,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=cached_collate.collate,
    )

    # Every fold starts from the same initial weights and a fresh optimiser/scheduler.
    model.load_state_dict(model_state)
    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad=True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    for _epoch in tqdm(range(30), leave=False, position=1):
        epoch_loss = _train_one_epoch(model, train_loader, optimizer)
        scheduler.step()
        with tqdm.external_write_mode():
            print(f"\tAverage Loss: {epoch_loss:.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")

    print(f"Finished Fold {fold}")

    output, test_labels = _predict(model, test_loader)
    probs = torch.sigmoid(output)

    # Micro scores: all (spectrum, antibiotic) pairs pooled together.
    auc_roc, aucPC = _auroc_and_aupr(probs, test_labels)
    all_auc_roc_micro.append(auc_roc)
    all_auc_pr_micro.append(aucPC)

    # Macro scores: mean of the per-antibiotic scores.
    test_pos = _antibiotic_positions(test_dataset_single_antibiotic, batch_size)
    if test_pos.numel() != output.numel():
        raise RuntimeError(
            f"Fold {fold}: {output.numel()} predictions but {test_pos.numel()} antibiotic positions"
        )
    macro_aucroc, macro_aucpr = _macro_scores(probs, test_labels, test_pos, driams.selected_antibiotics)
    all_auc_roc_macro.append(macro_aucroc)
    all_auc_pr_macro.append(macro_aucpr)
    

Start training ...


Create single label Dataset:   0%|          | 0/118369 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.002670 	Learning rate: 0.001000


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.002084 	Learning rate: 0.001000


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001888 	Learning rate: 0.001000


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001762 	Learning rate: 0.001000


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001661 	Learning rate: 0.001000


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001581 	Learning rate: 0.001000


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001517 	Learning rate: 0.001000


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001460 	Learning rate: 0.001000


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001400 	Learning rate: 0.001000


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001349 	Learning rate: 0.000500


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001216 	Learning rate: 0.000500


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001174 	Learning rate: 0.000500


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001138 	Learning rate: 0.000500


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001105 	Learning rate: 0.000500


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001078 	Learning rate: 0.000500


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001051 	Learning rate: 0.000500


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001025 	Learning rate: 0.000500


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.001000 	Learning rate: 0.000500


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.000983 	Learning rate: 0.000500


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.000924 	Learning rate: 0.000250


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.000852 	Learning rate: 0.000250


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.000833 	Learning rate: 0.000250


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.000819 	Learning rate: 0.000250


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.000805 	Learning rate: 0.000250


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.000788 	Learning rate: 0.000250


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.000777 	Learning rate: 0.000250


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.000767 	Learning rate: 0.000250


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.000753 	Learning rate: 0.000250


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.000742 	Learning rate: 0.000250


  0%|          | 0/9335 [00:00<?, ?it/s]

	Average Loss: 0.000728 	Learning rate: 0.000125
Finished Fold 0


  0%|          | 0/1098 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/118369 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.002664 	Learning rate: 0.001000


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.002078 	Learning rate: 0.001000


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001883 	Learning rate: 0.001000


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001755 	Learning rate: 0.001000


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001657 	Learning rate: 0.001000


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001578 	Learning rate: 0.001000


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001507 	Learning rate: 0.001000


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001449 	Learning rate: 0.001000


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001392 	Learning rate: 0.001000


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001342 	Learning rate: 0.000500


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001201 	Learning rate: 0.000500


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001159 	Learning rate: 0.000500


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001120 	Learning rate: 0.000500


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001088 	Learning rate: 0.000500


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001056 	Learning rate: 0.000500


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.001024 	Learning rate: 0.000500


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000997 	Learning rate: 0.000500


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000971 	Learning rate: 0.000500


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000946 	Learning rate: 0.000500


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000921 	Learning rate: 0.000250


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000824 	Learning rate: 0.000250


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000799 	Learning rate: 0.000250


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000781 	Learning rate: 0.000250


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000763 	Learning rate: 0.000250


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000748 	Learning rate: 0.000250


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000733 	Learning rate: 0.000250


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000718 	Learning rate: 0.000250


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000704 	Learning rate: 0.000250


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000690 	Learning rate: 0.000250


  0%|          | 0/9325 [00:00<?, ?it/s]

	Average Loss: 0.000676 	Learning rate: 0.000125
Finished Fold 1


  0%|          | 0/1108 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/118369 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.002690 	Learning rate: 0.001000


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.002087 	Learning rate: 0.001000


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001892 	Learning rate: 0.001000


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001763 	Learning rate: 0.001000


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001668 	Learning rate: 0.001000


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001587 	Learning rate: 0.001000


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001521 	Learning rate: 0.001000


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001462 	Learning rate: 0.001000


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001409 	Learning rate: 0.001000


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001358 	Learning rate: 0.000500


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001222 	Learning rate: 0.000500


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001177 	Learning rate: 0.000500


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001143 	Learning rate: 0.000500


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001111 	Learning rate: 0.000500


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001080 	Learning rate: 0.000500


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001050 	Learning rate: 0.000500


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.001022 	Learning rate: 0.000500


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.000996 	Learning rate: 0.000500


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.000973 	Learning rate: 0.000500


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.000946 	Learning rate: 0.000250


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.000851 	Learning rate: 0.000250


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.000827 	Learning rate: 0.000250


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.000810 	Learning rate: 0.000250


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.000792 	Learning rate: 0.000250


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.000776 	Learning rate: 0.000250


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.000762 	Learning rate: 0.000250


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.000746 	Learning rate: 0.000250


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.000732 	Learning rate: 0.000250


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.000718 	Learning rate: 0.000250


  0%|          | 0/9326 [00:00<?, ?it/s]

	Average Loss: 0.000707 	Learning rate: 0.000125
Finished Fold 2


  0%|          | 0/1107 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/118369 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.002674 	Learning rate: 0.001000


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.002092 	Learning rate: 0.001000


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001896 	Learning rate: 0.001000


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001765 	Learning rate: 0.001000


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001671 	Learning rate: 0.001000


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001593 	Learning rate: 0.001000


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001524 	Learning rate: 0.001000


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001464 	Learning rate: 0.001000


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001412 	Learning rate: 0.001000


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001364 	Learning rate: 0.000500


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001228 	Learning rate: 0.000500


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001185 	Learning rate: 0.000500


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001151 	Learning rate: 0.000500


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001117 	Learning rate: 0.000500


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001086 	Learning rate: 0.000500


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001057 	Learning rate: 0.000500


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001028 	Learning rate: 0.000500


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.001000 	Learning rate: 0.000500


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.000976 	Learning rate: 0.000500


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.000949 	Learning rate: 0.000250


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.000856 	Learning rate: 0.000250


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.000829 	Learning rate: 0.000250


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.000810 	Learning rate: 0.000250


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.000794 	Learning rate: 0.000250


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.000778 	Learning rate: 0.000250


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.000764 	Learning rate: 0.000250


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.000748 	Learning rate: 0.000250


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.000736 	Learning rate: 0.000250


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.000719 	Learning rate: 0.000250


  0%|          | 0/9330 [00:00<?, ?it/s]

	Average Loss: 0.000709 	Learning rate: 0.000125
Finished Fold 3


  0%|          | 0/1103 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/118369 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.002671 	Learning rate: 0.001000


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.002092 	Learning rate: 0.001000


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001896 	Learning rate: 0.001000


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001767 	Learning rate: 0.001000


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001669 	Learning rate: 0.001000


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001589 	Learning rate: 0.001000


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001525 	Learning rate: 0.001000


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001462 	Learning rate: 0.001000


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001410 	Learning rate: 0.001000


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001356 	Learning rate: 0.000500


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001223 	Learning rate: 0.000500


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001181 	Learning rate: 0.000500


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001147 	Learning rate: 0.000500


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001112 	Learning rate: 0.000500


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001083 	Learning rate: 0.000500


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001055 	Learning rate: 0.000500


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001027 	Learning rate: 0.000500


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.001002 	Learning rate: 0.000500


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.000977 	Learning rate: 0.000500


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.000954 	Learning rate: 0.000250


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.000862 	Learning rate: 0.000250


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.000837 	Learning rate: 0.000250


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.000820 	Learning rate: 0.000250


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.000803 	Learning rate: 0.000250


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.000788 	Learning rate: 0.000250


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.000772 	Learning rate: 0.000250


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.000757 	Learning rate: 0.000250


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.000745 	Learning rate: 0.000250


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.000730 	Learning rate: 0.000250


  0%|          | 0/9340 [00:00<?, ?it/s]

	Average Loss: 0.000718 	Learning rate: 0.000125
Finished Fold 4


  0%|          | 0/1093 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

In [None]:
import numpy as np

# Micro AUROC across folds (stored as tensors, so move each to the CPU first).
roc_micro_scores = [score.cpu() for score in all_auc_roc_micro]
print(f"Mean\t: {np.mean(roc_micro_scores)}")
print(f" SD \t: {np.std(roc_micro_scores)}")

Mean	: 0.9247110486030579
 SD 	: 0.002133152214810252


In [None]:
# Micro AUPR across folds (stored as tensors, so move each to the CPU first).
pr_micro_scores = [score.cpu() for score in all_auc_pr_micro]
print(f"Mean\t: {np.mean(pr_micro_scores)}")
print(f" SD \t: {np.std(pr_micro_scores)}")

Mean	: 0.8556210398674011
 SD 	: 0.006246367935091257


In [None]:
import numpy as np

# Macro AUROC across folds.
roc_macro_scores = list(all_auc_roc_macro)
print(f"Mean\t: {np.mean(roc_macro_scores)}")
print(f" SD \t: {np.std(roc_macro_scores)}")

Mean	: 0.9248690570655622
 SD 	: 0.0026863199029906037


In [None]:
# Macro AUPR across folds.
pr_macro_scores = list(all_auc_pr_macro)
print(f"Mean\t: {np.mean(pr_macro_scores)}")
print(f" SD \t: {np.std(pr_macro_scores)}")

Mean	: 0.8556290149688721
 SD 	: 0.006075328960571071
