In [9]:
import copy
import numpy as np
import pandas as pd
import sys
sys.path.insert(0,'../../')
import torch
import torchmetrics.classification
from torch import cuda
from maldi2resistance.model.MLP import Decoder as MLP
from maldi2resistance.data.ms_data import MS_Data
assert cuda.is_available()
assert cuda.device_count() > 0
print(cuda.get_device_name(cuda.current_device()))

DEVICE = torch.device("cuda")
SEED = 42
torch.manual_seed(SEED)

save_folder = "results_MLP/assymetric_loss_MLP_5cv_DRIAMS-ABCD"

NVIDIA RTX 2000 Ada Generation Laptop GPU


### Load the Dataset

In [10]:
ms_data = MS_Data(
    root_dir="/home/youngjunpark/Data/MS_data",
    #sites=["DRIAMS-D"],
    #years=[2018],
    bin_size=1,
)
ms_data.loading_type = "memory"
ms_data



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]

Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


In [11]:
len(ms_data.label_stats.columns)

38

In [12]:
model = MLP(18000, hidden_dim=1024, output_dim=len(ms_data.selected_antibiotics) )
model = model.to(DEVICE)

model_state = copy.deepcopy(model.state_dict()) 

In [13]:
from torchinfo import summary

print(summary(model))

Layer (type:depth-idx)                   Param #
Decoder                                  --
├─Linear: 1-1                            18,433,024
├─Linear: 1-2                            1,049,600
├─Linear: 1-3                            38,950
├─LeakyReLU: 1-4                         --
Total params: 19,521,574
Trainable params: 19,521,574
Non-trainable params: 0


In [14]:
from torch.utils.data import DataLoader
from maldi2resistance.loss.asymmetricLoss import AsymmetricLoss
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from pathlib import Path
from maldi2resistance.metric.ROC import MultiLabelRocNan
from maldi2resistance.loss.maskedLoss import MaskedBCE
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

print("Start training ...")

model.train()

batch_size = 64
fig_path = Path(f"./{save_folder}/figures")
fig_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(f"./{save_folder}/csv")
csv_path.mkdir(parents=True, exist_ok=True)

loss_per_batch = []

class_weights_negative = torch.tensor((1 - (ms_data.label_stats.loc["negative"] / ms_data.label_stats.loc["n_sum"])).values, device=DEVICE)
class_weights_positive = torch.tensor((1 - (ms_data.label_stats.loc["positive"] / ms_data.label_stats.loc["n_sum"])).values, device=DEVICE)

criterion = AsymmetricLoss()

gen = torch.Generator()

for fold, (train_data, test_data) in enumerate(ms_data.getK_fold(n_splits=5, shuffle=True, random_state= SEED)):
    
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True, generator= gen.manual_seed(SEED))
    test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=True, drop_last=True, generator= gen.manual_seed(SEED))
    
    model.load_state_dict(model_state)
    model.to(DEVICE)
    model.train()
    
    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad = True)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    for epoch in tqdm(range(30)):
        overall_loss = 0
        
        for batch_idx, (x, y) in enumerate(train_loader):
    
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            
            optimizer.zero_grad()
    
            output = model(x)
        
            loss = criterion(output, y)
            current_loss_value = loss.item()
            loss_per_batch.append(current_loss_value)
            
            overall_loss += current_loss_value
            
            loss.backward()
            optimizer.step()
    
        scheduler.step()
        with tqdm.external_write_mode():
            print(f"\tAverage Loss: {overall_loss / (batch_idx*batch_size):.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")
    
    
    print(f"Finished Fold {fold}")
    
    model.eval()
    
    DEVICE = torch.device("cpu")
    test_features, test_labels = next(iter(test_loader))
    test_features = test_features.to(DEVICE)
    test_labels = test_labels.to(DEVICE)
    model = model.to(DEVICE)
    
    ml_roc = MultiLabelRocNan()
    output = model(test_features)
    
    
    ml_roc.compute(output,test_labels,ms_data.selected_antibiotics, create_csv=f"./{csv_path}/fold-{fold}_ROC.csv")
    fig_, ax_ = ml_roc()
    
    plt.savefig(fig_path / f"fold-{fold}_ROC.png", transparent=True, format= "png", bbox_inches = "tight")
    plt.close()
    
    ml_pr = MultiLabelPRNan()
    ml_pr.compute(output,test_labels,ms_data.selected_antibiotics, create_csv=f"./{csv_path}/fold-{fold}_PrecisionRecall.csv")
    
    fig_, ax_ = ml_pr()
    
    plt.savefig(fig_path / f"fold-{fold}_PrecisionRecall.png", transparent=True, format= "png", bbox_inches = "tight")
    plt.close()
    
    DEVICE = torch.device("cuda")
    model = model.to(DEVICE)

Start training ...


  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.889504 	Learning rate: 0.001000
	Average Loss: 0.765793 	Learning rate: 0.001000
	Average Loss: 0.717202 	Learning rate: 0.001000
	Average Loss: 0.683673 	Learning rate: 0.001000
	Average Loss: 0.654544 	Learning rate: 0.001000
	Average Loss: 0.627636 	Learning rate: 0.001000
	Average Loss: 0.602571 	Learning rate: 0.001000
	Average Loss: 0.580782 	Learning rate: 0.001000
	Average Loss: 0.561444 	Learning rate: 0.001000
	Average Loss: 0.533415 	Learning rate: 0.000500
	Average Loss: 0.483339 	Learning rate: 0.000500
	Average Loss: 0.464711 	Learning rate: 0.000500
	Average Loss: 0.443739 	Learning rate: 0.000500
	Average Loss: 0.430573 	Learning rate: 0.000500
	Average Loss: 0.410089 	Learning rate: 0.000500
	Average Loss: 0.394672 	Learning rate: 0.000500
	Average Loss: 0.377592 	Learning rate: 0.000500
	Average Loss: 0.362157 	Learning rate: 0.000500
	Average Loss: 0.344930 	Learning rate: 0.000500
	Average Loss: 0.334285 	Learning rate: 0.000250
	Average Loss: 0.293

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.894467 	Learning rate: 0.001000
	Average Loss: 0.768950 	Learning rate: 0.001000
	Average Loss: 0.720964 	Learning rate: 0.001000
	Average Loss: 0.679301 	Learning rate: 0.001000
	Average Loss: 0.657145 	Learning rate: 0.001000
	Average Loss: 0.625699 	Learning rate: 0.001000
	Average Loss: 0.601231 	Learning rate: 0.001000
	Average Loss: 0.579257 	Learning rate: 0.001000
	Average Loss: 0.553255 	Learning rate: 0.001000
	Average Loss: 0.536238 	Learning rate: 0.000500
	Average Loss: 0.476647 	Learning rate: 0.000500
	Average Loss: 0.456739 	Learning rate: 0.000500
	Average Loss: 0.441950 	Learning rate: 0.000500
	Average Loss: 0.427028 	Learning rate: 0.000500
	Average Loss: 0.405723 	Learning rate: 0.000500
	Average Loss: 0.389735 	Learning rate: 0.000500
	Average Loss: 0.373962 	Learning rate: 0.000500
	Average Loss: 0.357061 	Learning rate: 0.000500
	Average Loss: 0.340813 	Learning rate: 0.000500
	Average Loss: 0.325650 	Learning rate: 0.000250
	Average Loss: 0.287

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.897021 	Learning rate: 0.001000
	Average Loss: 0.769318 	Learning rate: 0.001000
	Average Loss: 0.724153 	Learning rate: 0.001000
	Average Loss: 0.686166 	Learning rate: 0.001000
	Average Loss: 0.656465 	Learning rate: 0.001000
	Average Loss: 0.628621 	Learning rate: 0.001000
	Average Loss: 0.605821 	Learning rate: 0.001000
	Average Loss: 0.586133 	Learning rate: 0.001000
	Average Loss: 0.560795 	Learning rate: 0.001000
	Average Loss: 0.539781 	Learning rate: 0.000500
	Average Loss: 0.484526 	Learning rate: 0.000500
	Average Loss: 0.465524 	Learning rate: 0.000500
	Average Loss: 0.446657 	Learning rate: 0.000500
	Average Loss: 0.430920 	Learning rate: 0.000500
	Average Loss: 0.413989 	Learning rate: 0.000500
	Average Loss: 0.396594 	Learning rate: 0.000500
	Average Loss: 0.382017 	Learning rate: 0.000500
	Average Loss: 0.364561 	Learning rate: 0.000500
	Average Loss: 0.348552 	Learning rate: 0.000500
	Average Loss: 0.334917 	Learning rate: 0.000250
	Average Loss: 0.296

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.896658 	Learning rate: 0.001000
	Average Loss: 0.768835 	Learning rate: 0.001000
	Average Loss: 0.719620 	Learning rate: 0.001000
	Average Loss: 0.680890 	Learning rate: 0.001000
	Average Loss: 0.654141 	Learning rate: 0.001000
	Average Loss: 0.625113 	Learning rate: 0.001000
	Average Loss: 0.603287 	Learning rate: 0.001000
	Average Loss: 0.576801 	Learning rate: 0.001000
	Average Loss: 0.553860 	Learning rate: 0.001000
	Average Loss: 0.532257 	Learning rate: 0.000500
	Average Loss: 0.478788 	Learning rate: 0.000500
	Average Loss: 0.459658 	Learning rate: 0.000500
	Average Loss: 0.438265 	Learning rate: 0.000500
	Average Loss: 0.424461 	Learning rate: 0.000500
	Average Loss: 0.405432 	Learning rate: 0.000500
	Average Loss: 0.392009 	Learning rate: 0.000500
	Average Loss: 0.373114 	Learning rate: 0.000500
	Average Loss: 0.355697 	Learning rate: 0.000500
	Average Loss: 0.343779 	Learning rate: 0.000500
	Average Loss: 0.326753 	Learning rate: 0.000250
	Average Loss: 0.288

  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.895499 	Learning rate: 0.001000
	Average Loss: 0.764957 	Learning rate: 0.001000
	Average Loss: 0.716091 	Learning rate: 0.001000
	Average Loss: 0.683787 	Learning rate: 0.001000
	Average Loss: 0.655700 	Learning rate: 0.001000
	Average Loss: 0.627951 	Learning rate: 0.001000
	Average Loss: 0.606437 	Learning rate: 0.001000
	Average Loss: 0.581104 	Learning rate: 0.001000
	Average Loss: 0.559845 	Learning rate: 0.001000
	Average Loss: 0.537050 	Learning rate: 0.000500
	Average Loss: 0.483390 	Learning rate: 0.000500
	Average Loss: 0.463781 	Learning rate: 0.000500
	Average Loss: 0.445990 	Learning rate: 0.000500
	Average Loss: 0.428080 	Learning rate: 0.000500
	Average Loss: 0.410339 	Learning rate: 0.000500
	Average Loss: 0.394929 	Learning rate: 0.000500
	Average Loss: 0.377987 	Learning rate: 0.000500
	Average Loss: 0.361718 	Learning rate: 0.000500
	Average Loss: 0.345411 	Learning rate: 0.000500
	Average Loss: 0.331669 	Learning rate: 0.000250
	Average Loss: 0.293

In [15]:
micro = []
macro = []

for fold in range(0,5):
    csv = pd.read_csv(f"./{csv_path}/fold-{fold}_ROC.csv")
    micro.append(csv[csv["class"] == "micro"]["ROCAUC"])
    macro.append(csv[csv["class"] == "macro"]["ROCAUC"])

print(f"micro-Mean\t: {np.mean(micro)}")
print(f"micro-SD \t: {np.std(micro) }")
print(f"macro-Mean\t: {np.mean(macro)}")
print(f"macro-SD \t: {np.std(macro) }")

micro-Mean	: 0.935346829891205
micro-SD 	: 0.0018252347835044026
macro-Mean	: 0.9064308100625087
macro-SD 	: 0.0020087933534215694


In [16]:
micro = []
macro = []

for fold in range(0,5):
    csv = pd.read_csv(f"./{csv_path}/fold-{fold}_PrecisionRecall.csv")
    micro.append(csv[csv["class"] == "micro"]["PrecisionRecallAUC"])
    macro.append(csv[csv["class"] == "macro"]["PrecisionRecallAUC"])

print(f"Mean\t: {np.mean(micro)}")
print(f" SD \t: {np.std(micro) }")
print(f"Mean\t: {np.mean(macro)}")
print(f" SD \t: {np.std(macro) }")

Mean	: 0.8726518988609314
 SD 	: 0.0019970157278060706
Mean	: 0.7863881920513354
 SD 	: 0.003981206063662021
