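# Fully connected feedforward network implementing a loss mask

The model works on spectrum–antibiotic pairs:

- Each binned MALDI-TOF spectrum is projected with a linear layer.
- Each antibiotic molecule is encoded with a chemprop MPNN.
- The two embeddings are concatenated and fed to a residual MLP, which outputs one resistance logit per pair.

Training uses an asymmetric loss, and evaluation uses 5-fold cross-validation with micro and macro AUROC / AUPRC. Note: despite the title, the masked loss (`MaskedBCE`) is imported but not used in this notebook.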

In [None]:
import torchmetrics.classification
from torch import cuda
# Training below runs on a CUDA device; fail early with a readable message if none is present.
assert cuda.is_available(), "CUDA is not available; this notebook requires a GPU."
assert cuda.device_count() > 0, "No CUDA device found."

In [None]:
active_device_index = cuda.current_device()
print(cuda.get_device_name(active_device_index))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
# Global configuration: the seed used for torch's RNG and the training device.
SEED = 76436278
DEVICE = torch.device("cuda")

torch.manual_seed(SEED)

<torch._C.Generator at 0x7f016823efb0>

### Load the Dataset

In [None]:
from src.maldi2resistance.data.driams import Driams

import os

# Dataset location. Override it with the DRIAMS_ROOT environment variable instead of
# editing an absolute, machine-specific path in the notebook.
DRIAMS_ROOT = os.environ.get("DRIAMS_ROOT", "/home/jan/Uni/master/data/Driams")

driams = Driams(
    root_dir=DRIAMS_ROOT,
)

# "memory" loading keeps the spectra in RAM (see "Loading Spectra into Memory" in the output).
driams.loading_type = "memory"

driams



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [None]:
n_label_columns = len(driams.label_stats.columns)
n_label_columns

38

In [None]:
from chemprop.models import MPNN
from chemprop.nn import BondMessagePassing, NormAggregation, RegressionFFN
from multimodal_amr.models.modules import ResMLP
from torch import nn


class Residual_AMR_Classifier(nn.Module):
    """Classifier producing one resistance logit per (spectrum, drug) pair.

    The binned MALDI-TOF spectrum is projected with a single linear layer and the
    drug molecule is encoded with a chemprop MPNN; the concatenation of both
    embeddings is fed to a residual MLP with one output unit.

    Args:
        config: hyper-parameter dict; reads ``conv_out_size``,
            ``species_embedding_dim``, ``sample_embedding_dim``,
            ``drug_embedding_dim`` and ``n_hidden_layers``.
        n_input_spectrum: number of spectrum bins (input width of ``spectrum_emb``).
        n_input_drug: input width of ``drug_emb`` (which ``forward`` does not use).
    """

    def __init__(self, config, n_input_spectrum, n_input_drug):
        super().__init__()
        self.config = config

        # Sample embedding.
        # NOTE(review): not applied in forward(); kept so the module/state_dict layout is unchanged.
        if config["species_embedding_dim"] == 0 and config["conv_out_size"] == config["sample_embedding_dim"]:
            self.sample_emb = nn.Identity()
        elif config["species_embedding_dim"] == 0:
            self.sample_emb = nn.Linear(
                config["conv_out_size"],
                config["sample_embedding_dim"],
            )
        else:
            self.sample_emb = nn.Linear(
                config["conv_out_size"] + config["species_embedding_dim"],
                config["sample_embedding_dim"],
            )

        # MALDI-TOF spectrum embedding: one linear projection of the binned spectrum.
        self.spectrum_emb = nn.Linear(n_input_spectrum, config["conv_out_size"])

        # Drug encoder: chemprop message passing + aggregation. forward() only uses the
        # hidden encoding at i=0, so the FFN head is constructed but not applied.
        mp = BondMessagePassing()
        agg = NormAggregation()
        ffn = RegressionFFN()
        self.mpnn = MPNN(mp, agg, ffn)

        # NOTE(review): not used in forward(). Kept because building it draws from the RNG;
        # removing it would change the initialisation of the output network below.
        self.drug_emb = nn.Linear(
            n_input_drug, config["drug_embedding_dim"]
        )

        # Output network. Its input width must equal what forward() concatenates:
        # the spectrum embedding (conv_out_size) plus the MPNN fingerprint
        # (mp.output_dim; 300 with chemprop's default hidden size, the value that
        # used to be hard-coded together with sample_embedding_dim).
        drug_fingerprint_dim = mp.output_dim
        self.net = ResMLP(
            config["n_hidden_layers"],
            config["conv_out_size"] + drug_fingerprint_dim,
            1,
            p_dropout=0.2,
        )

    def forward(self, spectrum, drug):
        """Return the output of the residual MLP for a batch of (spectrum, drug) pairs.

        Args:
            spectrum: 2-D batch of binned spectra whose last dim is ``n_input_spectrum``.
            drug: object exposing ``.bmg`` (the chemprop batch mol graph for the same batch).
        """
        spectrum_embedding = self.spectrum_emb(spectrum)

        # i=0: representation before any FFN layer (chemprop MPNN.encoding) — confirm
        # against the installed chemprop version.
        dr_emb = self.mpnn.encoding(drug.bmg, i=0)

        combined_emb = torch.cat((spectrum_embedding, dr_emb), dim=1)
        return self.net(combined_emb)




# Hyper-parameters consumed by Residual_AMR_Classifier.
conf = dict(
    conv_out_size=512,
    species_embedding_dim=0,
    sample_embedding_dim=512,
    drug_embedding_dim=512,
    n_hidden_layers=5,
)
        

In [None]:
from maldi2resistance.model.dualBranch import DualBranchOneHot
import copy
from maldi2resistance.model.singleBranchMlp import SingleBranchMLP

# Build the model and move it to the training device (Module.to returns the module itself).
model = Residual_AMR_Classifier(
    config=conf,
    n_input_spectrum=driams.n_bins,
    n_input_drug=1024,
).to(DEVICE)

# Snapshot of the freshly initialised weights; every cross-validation fold reloads it.
model_state = copy.deepcopy(model.state_dict())

In [None]:
from torch.utils.data import DataLoader

In [None]:
from maldi2resistance.loss.asymmetricLoss import AsymmetricLoss
from maldi2resistance.data.chemprop import collate, CachedChempropCollate
from torchmetrics.utilities.compute import auc
from torchmetrics.classification import BinaryAUROC, BinaryPrecisionRecallCurve, BinaryAveragePrecision
from maldi2resistance.data.driams import DriamsSingleAntibiotic
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

print("Start training ...")
model.train()

batch_size = 128
n_epochs = 30
fig_path = Path("./kfold/figures")
fig_path.mkdir(parents=True, exist_ok=True)

# Evaluation runs on the CPU. This uses its own name instead of reassigning the global
# DEVICE, so an exception during evaluation can no longer leave DEVICE pointing at the CPU.
eval_device = torch.device("cpu")

loss_per_batch = []

gen = torch.Generator()

all_auc_roc_micro = []
all_auc_pr_micro = []

all_auc_roc_macro = []
all_auc_pr_macro = []

# NOTE(review): the logged losses spike to ~1e7 and then plateau, and fold AUROCs are
# close to 0.5 — consider gradient clipping and/or spectrum normalisation.


def _train_one_epoch(model, loader, criterion, optimizer, device, loss_log):
    """Run one optimisation pass over ``loader``; return summed loss / number of samples."""
    overall_loss = 0.0
    n_samples = 0
    for (x, y), drug in tqdm(loader, total=len(loader), leave=False, position=1):
        x = x.to(device)
        y = y.to(device)
        drug.bmg.to(device)  # return value unused, as in the original: the graph batch is moved in place

        optimizer.zero_grad()
        output = model(x, drug)

        # Squeeze only the trailing output unit so a final batch of one sample keeps its batch axis.
        loss = criterion(output.squeeze(-1), y)
        current_loss_value = loss.item()
        loss_log.append(current_loss_value)
        overall_loss += current_loss_value
        n_samples += len(y)

        loss.backward()
        optimizer.step()
    # Count the samples actually seen (the old batch_idx * batch_size dropped one batch
    # and ignored a short final batch).
    return overall_loss / max(n_samples, 1)


def _predict(model, loader, device):
    """Return ``(logits, int labels)`` for every sample of ``loader``, in loader order."""
    outputs, labels = [], []
    with torch.no_grad():  # inference only: do not build an autograd graph
        for (x, label), drug in tqdm(loader, leave=False):
            x = x.to(device)
            drug.bmg.to(device)
            outputs.append(model(x, drug).cpu())
            labels.append(label.to(device))
    return torch.cat(outputs).squeeze(-1), torch.cat(labels).int()


def _antibiotic_positions(dataset, loader_batch_size):
    """Antibiotic index of every sample in ``dataset``, in unshuffled dataset order."""
    # With the default collate, dataset items batch as (spectra, labels, positions).
    loader = DataLoader(dataset, batch_size=loader_batch_size, shuffle=False)
    return torch.cat([positions for _, _, positions in loader])


def _macro_scores(probabilities, labels, positions, n_antibiotics):
    """Mean per-antibiotic AUROC and average precision.

    Antibiotics whose test samples do not contain both classes are skipped (both
    metrics are undefined there). Returns ``(nan, nan)`` when none qualifies.
    """
    roc_sum = 0.0
    pr_sum = 0.0
    n_not_empty = 0
    for pos in range(n_antibiotics):
        mask = positions == pos
        label_part = labels[mask]
        if torch.unique(label_part).numel() < 2:
            continue
        prob_part = probabilities[mask]
        roc_sum += BinaryAUROC()(prob_part, label_part).item()
        pr_sum += BinaryAveragePrecision()(prob_part, label_part).item()
        n_not_empty += 1
    if n_not_empty == 0:
        return float("nan"), float("nan")
    return roc_sum / n_not_empty, pr_sum / n_not_empty


for fold, (train_data, test_data) in enumerate(driams.getK_fold(n_splits=5, shuffle=True, random_state=SEED)):

    train_dataset_single_antibiotic = DriamsSingleAntibiotic(driams=train_data)
    test_dataset_single_antibiotic = DriamsSingleAntibiotic(driams=test_data)
    cached_collate = CachedChempropCollate(driams=driams)

    train_loader = DataLoader(train_dataset_single_antibiotic, batch_size=batch_size, shuffle=True, generator=gen.manual_seed(SEED), collate_fn=cached_collate.collate)
    # No shuffling at evaluation time: predictions must stay aligned, sample by sample,
    # with the antibiotic positions used for the macro scores below.
    test_loader = DataLoader(test_dataset_single_antibiotic, batch_size=batch_size, shuffle=False, collate_fn=cached_collate.collate)

    # Every fold starts from the same initial weights.
    model.load_state_dict(model_state)
    model.train()

    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad=True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
    criterion = AsymmetricLoss()

    for epoch in tqdm(range(n_epochs), leave=False, position=1):
        average_loss = _train_one_epoch(model, train_loader, criterion, optimizer, DEVICE, loss_per_batch)
        scheduler.step()
        with tqdm.external_write_mode():
            print(f"\tAverage Loss: {average_loss:.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")

    print(f"Finished Fold {fold}")

    model.eval()
    model = model.to(eval_device)
    output, test_labels = _predict(model, test_loader, eval_device)
    probabilities = torch.sigmoid(output)

    # Micro scores: all (spectrum, antibiotic) pairs of the fold pooled together.
    all_auc_roc_micro.append(BinaryAUROC()(probabilities, test_labels))
    all_auc_pr_micro.append(BinaryAveragePrecision()(probabilities, test_labels))

    # Macro scores: per-antibiotic scores averaged over antibiotics. Positions are read
    # from the same dataset in the same unshuffled order as the predictions (previously
    # they came from an independently shuffled loader and did not line up).
    test_pos = _antibiotic_positions(test_dataset_single_antibiotic, batch_size)
    assert test_pos.numel() == output.numel(), "antibiotic positions and predictions are misaligned"
    macro_aucroc, macro_aucpr = _macro_scores(probabilities, test_labels, test_pos, len(driams.selected_antibiotics))

    all_auc_roc_macro.append(macro_aucroc)
    all_auc_pr_macro.append(macro_aucpr)

    model = model.to(DEVICE)
    

Start training ...


Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 139055.620875 	Learning rate: 0.001000


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 37030.436278 	Learning rate: 0.001000


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 9974638.215883 	Learning rate: 0.001000


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 165842.265382 	Learning rate: 0.001000


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.335697 	Learning rate: 0.001000


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.299280 	Learning rate: 0.001000


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.350263 	Learning rate: 0.001000


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.315879 	Learning rate: 0.001000


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.326092 	Learning rate: 0.001000


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.293555 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.330615 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.338355 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.344408 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.341407 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.357620 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.338004 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.354415 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.378784 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.347828 	Learning rate: 0.000500


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.352769 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.348274 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.370670 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.403286 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.348798 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.349672 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.353312 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.351771 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.317052 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.340115 	Learning rate: 0.000250


  0%|          | 0/4409 [00:00<?, ?it/s]

	Average Loss: 39.341271 	Learning rate: 0.000125
Finished Fold 0


  0%|          | 0/1098 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 835199.407442 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 1815914.486909 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 1720.086699 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 2039.481267 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 22.066131 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 137.280333 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 30.879323 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.293085 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.294814 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.294792 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.291171 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.318145 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.299330 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.294983 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.296515 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.299646 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.304286 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.336359 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.308363 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.309921 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.284756 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.310281 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.309264 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.309453 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.313827 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.312734 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.290862 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.296525 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.328054 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 29.308062 	Learning rate: 0.000125
Finished Fold 1


  0%|          | 0/1108 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 1081103.190091 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 676209.528058 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 97.365811 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 36.050487 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 31.774566 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 32.118225 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 32.089515 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 32.071850 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 32.020334 	Learning rate: 0.001000


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 31.990291 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 31.848686 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 31.824385 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 30.783178 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 62.385336 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 28035.508457 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 424.901902 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 2202.545352 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 131461.184291 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 51085.075098 	Learning rate: 0.000500


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 308.383921 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 87.826010 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 63.753110 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 38.220870 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 56.557073 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 79.628617 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 82.711352 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 48.182352 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 33.184071 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 33.173456 	Learning rate: 0.000250


  0%|          | 0/4400 [00:00<?, ?it/s]

	Average Loss: 33.215396 	Learning rate: 0.000125
Finished Fold 2


  0%|          | 0/1107 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 899089.202511 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.504189 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.474247 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.496059 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.494189 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.503069 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.506380 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.474161 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.479165 	Learning rate: 0.001000


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.492702 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.479742 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.457558 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.507253 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.504087 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.508447 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.496455 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.502593 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.506191 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.488712 	Learning rate: 0.000500


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.491210 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.476713 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.492424 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.480911 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.483702 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.478889 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.492607 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.501100 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.467047 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.502178 	Learning rate: 0.000250


  0%|          | 0/4405 [00:00<?, ?it/s]

	Average Loss: 34.482810 	Learning rate: 0.000125
Finished Fold 3


  0%|          | 0/1103 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/44624 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 16869256.350213 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 135364.158900 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 33.536365 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 27.206093 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 27.259001 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 27.368041 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 27.032347 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 27.827213 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 26.391996 	Learning rate: 0.001000


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 25.826381 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 22.748074 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 22.755450 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 23.446880 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 22.848280 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 22.758036 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 22.493826 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 22.366291 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 22.203178 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 21.866099 	Learning rate: 0.000500


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 2134.780559 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 91.913842 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 21.938095 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 17.830834 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 22.231026 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 18.702363 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 34.859804 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 349.876781 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 909.992202 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 110.204738 	Learning rate: 0.000250


  0%|          | 0/4415 [00:00<?, ?it/s]

	Average Loss: 82.413040 	Learning rate: 0.000125
Finished Fold 4


  0%|          | 0/1093 [00:00<?, ?it/s]

Create single label Dataset:   0%|          | 0/11156 [00:00<?, ?it/s]

In [None]:
import numpy as np

# Micro AUROC over the CV folds; fold scores are tensors, so bring them to the CPU first.
micro_roc_scores = [score.cpu() for score in all_auc_roc_micro]
print(f"Mean\t: {np.mean(micro_roc_scores)}")
print(f" SD \t: {np.std(micro_roc_scores)}")

Mean	: 0.5404237508773804
 SD 	: 0.13672615587711334


In [None]:
# Micro AUPRC over the CV folds; fold scores are tensors, so bring them to the CPU first.
micro_pr_scores = [score.cpu() for score in all_auc_pr_micro]
print(f"Mean\t: {np.mean(micro_pr_scores)}")
print(f" SD \t: {np.std(micro_pr_scores)}")

Mean	: 0.3191229999065399
 SD 	: 0.13807132840156555


In [None]:
import numpy as np

# Macro AUROC over the CV folds (plain Python floats).
macro_roc_scores = list(all_auc_roc_macro)
print(f"Mean\t: {np.mean(macro_roc_scores)}")
print(f" SD \t: {np.std(macro_roc_scores)}")

Mean	: 0.5410671176094758
 SD 	: 0.13525482388377916


In [None]:
# Macro AUPRC over the CV folds (plain Python floats).
macro_pr_scores = list(all_auc_pr_macro)
print(f"Mean\t: {np.mean(macro_pr_scores)}")
print(f" SD \t: {np.std(macro_pr_scores)}")

Mean	: 0.31918808896290624
 SD 	: 0.13349533981816297
