# Assymetric loss on ResMLP-multilabel

In [14]:
import copy
import numpy as np
import pandas as pd
import sys
sys.path.insert(0,'../../')
import torch
import torchmetrics.classification
from torch import cuda
from maldi2resistance.model.MultilabelResMLP import MultilabelResMLP
from maldi2resistance.data.ms_data import MS_Data
assert cuda.is_available()
assert cuda.device_count() > 0
print(cuda.get_device_name(cuda.current_device()))

DEVICE = torch.device("cuda")
SEED = 42
torch.manual_seed(SEED)

save_folder = "results_ResMLP-multilabel/assymetric_loss_ResMLP-multilabel_5cv_DRIAMS-ABCD"

NVIDIA RTX 2000 Ada Generation Laptop GPU


### Load the Dataset

In [15]:
ms_data = MS_Data(
    root_dir="/home/youngjunpark/Data/MS_data",
    #sites=["DRIAMS-A"],
    #years=[2015,2016,2017,2018],
    bin_size=1,
)
ms_data.loading_type = "memory"
ms_data



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [16]:
len(ms_data.label_stats.columns)

38

In [17]:
model = MultilabelResMLP(input_dim= 18000, output_dim= len(ms_data.selected_antibiotics), hidden_dim=256)
model = model.to(DEVICE)

model_state = copy.deepcopy(model.state_dict()) 

In [18]:
from torchinfo import summary

print(summary(model))

Layer (type:depth-idx)                        Param #
MultilabelResMLP                              --
├─Linear: 1-1                                 4,608,256
├─ResMLP: 1-2                                 --
│    └─Sequential: 2-1                        --
│    │    └─ResBlock: 3-1                     66,304
│    │    └─ResBlock: 3-2                     66,304
│    │    └─ResBlock: 3-3                     66,304
│    │    └─ResBlock: 3-4                     66,304
│    │    └─ResBlock: 3-5                     66,304
│    │    └─Linear: 3-6                       9,766
Total params: 4,949,542
Trainable params: 4,949,542
Non-trainable params: 0


In [19]:
from torch.utils.data import DataLoader
from maldi2resistance.loss.asymmetricLoss import AsymmetricLoss
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

print("Start training ...")

model.train()

batch_size = 64
fig_path = Path(f"./{save_folder}/figures")
fig_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(f"./{save_folder}/csv")
csv_path.mkdir(parents=True, exist_ok=True)

loss_per_batch = []

class_weights_negative = torch.tensor((1 - (ms_data.label_stats.loc["negative"] / ms_data.label_stats.loc["n_sum"])).values, device=DEVICE)
class_weights_positive = torch.tensor((1 - (ms_data.label_stats.loc["positive"] / ms_data.label_stats.loc["n_sum"])).values, device=DEVICE)

criterion = AsymmetricLoss()

gen = torch.Generator()

for fold, (train_data, test_data) in enumerate(ms_data.getK_fold(n_splits=5, shuffle=True, random_state= SEED)):
    
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True, generator= gen.manual_seed(SEED))
    test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=True, drop_last=True, generator= gen.manual_seed(SEED))
    
    model.load_state_dict(model_state)
    model.train()
    
    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad = True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    for epoch in tqdm(range(30)):
        overall_loss = 0
        
        for batch_idx, (x, y) in enumerate(train_loader):
    
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            
            optimizer.zero_grad()
    
            output = model(x)
        
            loss = criterion(output, y)
            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)
            
            overall_loss += current_loss_value
            
            loss.backward()
            optimizer.step()
    
        scheduler.step()
        with tqdm.external_write_mode():
            print(f"\tAverage Loss: {overall_loss / (batch_idx*batch_size):.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")
    
    
    print(f"Finished Fold {fold}")
    
    model.eval()
    
    DEVICE = torch.device("cpu")
    test_features, test_labels = next(iter(test_loader))
    test_features = test_features.to(DEVICE)
    test_labels = test_labels.to(DEVICE)
    model = model.to(DEVICE)
    
    ml_roc = MultiLabelRocNan()
    output = model(test_features)
    
    
    ml_roc.compute(output,test_labels,ms_data.selected_antibiotics, create_csv=f"./{csv_path}/fold-{fold}_ROC.csv")
    fig_, ax_ = ml_roc()
    
    plt.savefig(fig_path / f"fold-{fold}_ROC.png", transparent=True, format= "png", bbox_inches = "tight")
    plt.close()
    
    ml_pr = MultiLabelPRNan()
    ml_pr.compute(output,test_labels,ms_data.selected_antibiotics, create_csv=f"./{csv_path}/fold-{fold}_PrecisionRecall.csv")
    
    fig_, ax_ = ml_pr()
    
    plt.savefig(fig_path / f"fold-{fold}_PrecisionRecall.png", transparent=True, format= "png", bbox_inches = "tight")
    plt.close()
    
    DEVICE = torch.device("cuda")
    model = model.to(DEVICE)

Start training ...


  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 1.064477 	Learning rate: 0.001000
	Average Loss: 0.807175 	Learning rate: 0.001000
	Average Loss: 0.765778 	Learning rate: 0.001000
	Average Loss: 0.725346 	Learning rate: 0.001000
	Average Loss: 0.693184 	Learning rate: 0.001000
	Average Loss: 0.667477 	Learning rate: 0.001000
	Average Loss: 0.645144 	Learning rate: 0.001000
	Average Loss: 0.622297 	Learning rate: 0.001000
	Average Loss: 0.601262 	Learning rate: 0.001000
	Average Loss: 0.575184 	Learning rate: 0.000500
	Average Loss: 0.522150 	Learning rate: 0.000500
	Average Loss: 0.503210 	Learning rate: 0.000500
	Average Loss: 0.491922 	Learning rate: 0.000500
	Average Loss: 0.477062 	Learning rate: 0.000500
	Average Loss: 0.464822 	Learning rate: 0.000500
	Average Loss: 0.450823 	Learning rate: 0.000500
	Average Loss: 0.445680 	Learning rate: 0.000500
	Average Loss: 0.435135 	Learning rate: 0.000500
	Average Loss: 0.422208 	Learning rate: 0.000500
	Average Loss: 0.412306 	Learning rate: 0.000250
	Average Loss: 0.380

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 1.066409 	Learning rate: 0.001000
	Average Loss: 0.804638 	Learning rate: 0.001000
	Average Loss: 0.756193 	Learning rate: 0.001000
	Average Loss: 0.722141 	Learning rate: 0.001000
	Average Loss: 0.692981 	Learning rate: 0.001000
	Average Loss: 0.659148 	Learning rate: 0.001000
	Average Loss: 0.639730 	Learning rate: 0.001000
	Average Loss: 0.615998 	Learning rate: 0.001000
	Average Loss: 0.592792 	Learning rate: 0.001000
	Average Loss: 0.573009 	Learning rate: 0.000500
	Average Loss: 0.515923 	Learning rate: 0.000500
	Average Loss: 0.495814 	Learning rate: 0.000500
	Average Loss: 0.479922 	Learning rate: 0.000500
	Average Loss: 0.467187 	Learning rate: 0.000500
	Average Loss: 0.454824 	Learning rate: 0.000500
	Average Loss: 0.442509 	Learning rate: 0.000500
	Average Loss: 0.428345 	Learning rate: 0.000500
	Average Loss: 0.418942 	Learning rate: 0.000500
	Average Loss: 0.410996 	Learning rate: 0.000500
	Average Loss: 0.408427 	Learning rate: 0.000250
	Average Loss: 0.366

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 1.064455 	Learning rate: 0.001000
	Average Loss: 0.811657 	Learning rate: 0.001000
	Average Loss: 0.760288 	Learning rate: 0.001000
	Average Loss: 0.720791 	Learning rate: 0.001000
	Average Loss: 0.686975 	Learning rate: 0.001000
	Average Loss: 0.663214 	Learning rate: 0.001000
	Average Loss: 0.638107 	Learning rate: 0.001000
	Average Loss: 0.613518 	Learning rate: 0.001000
	Average Loss: 0.595168 	Learning rate: 0.001000
	Average Loss: 0.573737 	Learning rate: 0.000500
	Average Loss: 0.513218 	Learning rate: 0.000500
	Average Loss: 0.497736 	Learning rate: 0.000500
	Average Loss: 0.484675 	Learning rate: 0.000500
	Average Loss: 0.469383 	Learning rate: 0.000500
	Average Loss: 0.456914 	Learning rate: 0.000500
	Average Loss: 0.444045 	Learning rate: 0.000500
	Average Loss: 0.435231 	Learning rate: 0.000500
	Average Loss: 0.421657 	Learning rate: 0.000500
	Average Loss: 0.412053 	Learning rate: 0.000500
	Average Loss: 0.401912 	Learning rate: 0.000250
	Average Loss: 0.371

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 1.054797 	Learning rate: 0.001000
	Average Loss: 0.805970 	Learning rate: 0.001000
	Average Loss: 0.757969 	Learning rate: 0.001000
	Average Loss: 0.722591 	Learning rate: 0.001000
	Average Loss: 0.695748 	Learning rate: 0.001000
	Average Loss: 0.669738 	Learning rate: 0.001000
	Average Loss: 0.649645 	Learning rate: 0.001000
	Average Loss: 0.627665 	Learning rate: 0.001000
	Average Loss: 0.601393 	Learning rate: 0.001000
	Average Loss: 0.582563 	Learning rate: 0.000500
	Average Loss: 0.522204 	Learning rate: 0.000500
	Average Loss: 0.503983 	Learning rate: 0.000500
	Average Loss: 0.488661 	Learning rate: 0.000500
	Average Loss: 0.472055 	Learning rate: 0.000500
	Average Loss: 0.462211 	Learning rate: 0.000500
	Average Loss: 0.448155 	Learning rate: 0.000500
	Average Loss: 0.437263 	Learning rate: 0.000500
	Average Loss: 0.423703 	Learning rate: 0.000500
	Average Loss: 0.420721 	Learning rate: 0.000500
	Average Loss: 0.403763 	Learning rate: 0.000250
	Average Loss: 0.367

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 1.067515 	Learning rate: 0.001000
	Average Loss: 0.802435 	Learning rate: 0.001000
	Average Loss: 0.756631 	Learning rate: 0.001000
	Average Loss: 0.718792 	Learning rate: 0.001000
	Average Loss: 0.691911 	Learning rate: 0.001000
	Average Loss: 0.662396 	Learning rate: 0.001000
	Average Loss: 0.640493 	Learning rate: 0.001000
	Average Loss: 0.617187 	Learning rate: 0.001000
	Average Loss: 0.599806 	Learning rate: 0.001000
	Average Loss: 0.578983 	Learning rate: 0.000500
	Average Loss: 0.518782 	Learning rate: 0.000500
	Average Loss: 0.500323 	Learning rate: 0.000500
	Average Loss: 0.486121 	Learning rate: 0.000500
	Average Loss: 0.470882 	Learning rate: 0.000500
	Average Loss: 0.459951 	Learning rate: 0.000500
	Average Loss: 0.445953 	Learning rate: 0.000500
	Average Loss: 0.435010 	Learning rate: 0.000500
	Average Loss: 0.424728 	Learning rate: 0.000500
	Average Loss: 0.416488 	Learning rate: 0.000500
	Average Loss: 0.404256 	Learning rate: 0.000250
	Average Loss: 0.368

In [22]:
micro = []
macro = []

for fold in range(0,5):
    csv = pd.read_csv(f"./{csv_path}/fold-{fold}_ROC.csv")
    micro.append(csv[csv["class"] == "micro"]["ROCAUC"])
    macro.append(csv[csv["class"] == "macro"]["ROCAUC"])

print(f"micro-Mean\t: {np.mean(micro)}")
print(f"micro-SD \t: {np.std(micro) }")
print(f"macro-Mean\t: {np.mean(macro)}")
print(f"macro-SD \t: {np.std(macro) }")

micro-Mean	: 0.9269583940505981
micro-SD 	: 0.002265571656398741
macro-Mean	: 0.8971553805627321
macro-SD 	: 0.0025078155352352086


In [23]:
micro = []
macro = []

for fold in range(0,5):
    csv = pd.read_csv(f"./{csv_path}/fold-{fold}_PrecisionRecall.csv")
    micro.append(csv[csv["class"] == "micro"]["PrecisionRecallAUC"])
    macro.append(csv[csv["class"] == "macro"]["PrecisionRecallAUC"])

print(f"Mean\t: {np.mean(micro)}")
print(f" SD \t: {np.std(micro) }")
print(f"Mean\t: {np.mean(macro)}")
print(f" SD \t: {np.std(macro) }")

Mean	: 0.8549234509468079
 SD 	: 0.003929648149179095
Mean	: 0.7637221794379385
 SD 	: 0.00431929700551127
