In [None]:
import torch

# Device used for model training. Prefer the GPU, but fall back to the CPU
# when CUDA is not available so the notebook still runs on such machines.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed for torch's global random number generator (set right below).
SEED = 76436278

torch.manual_seed(SEED)

<torch._C.Generator at 0x7f23e51fa330>

In [None]:
import gc
from maldi2resistance.metric.PrecisionRecall import MultiLabelPRNan
from matplotlib import pyplot as plt
from maldi2resistance.metric.ROC import MultiLabelRocNan
from pathlib import Path
from tqdm.auto import tqdm
from maldi2resistance.loss.maskedLoss import MaskedBCE
from torch.optim.lr_scheduler import StepLR
from torch.optim import Adam
from torch.utils.data import DataLoader
from maldi2resistance.model.MLP import AeBasedMLP
from maldi2resistance.data.driams import Driams

# Bin-size sweep: for every evaluated bin size, load the DRIAMS data with that
# bin size, train an AeBasedMLP (or reload its checkpoint), and write ROC and
# precision-recall results for a held-out 20 % split into
# ./results/binSize@<bin size>/.
gen = torch.Generator()
batch_size = 128
# When False, no training happens; the model weights are read from the
# checkpoint stored in the result directory of the respective bin size.
retrain = False

# Evaluation runs on the CPU with the complete test split as a single batch
# (presumably to avoid exhausting GPU memory -- TODO confirm). A dedicated
# name is used instead of temporarily rebinding the global DEVICE, so an
# exception in the middle of an iteration cannot leave DEVICE set to the CPU.
eval_device = torch.device("cpu")

for i in tqdm(range(1, 18001), desc="Bin size", position=0, leave=True):

    # Only bin sizes that divide 18000 without remainder are evaluated
    # (presumably so that bin starts and stops line up with the spectrum
    # range -- TODO confirm that 18000 is the width of that range).
    if 18000 % i != 0:
        continue

    # Per-bin-size output directory; also creates ./results when missing.
    path = Path("./results") / f"binSize@{i}"
    path.mkdir(parents=True, exist_ok=True)

    driams = Driams(
        root_dir="/home/jan/Uni/master/data/Driams",
        bin_size=i,
    )
    try:
        driams.loading_type = "memory"
    except FileNotFoundError:
        # Switching to in-memory loading failed because files are missing
        # (presumably no preprocessed spectra exist for this bin size yet):
        # run the preprocessing once and retry.
        driams.preprocess()
        driams.loading_type = "memory"

    model = AeBasedMLP(
        input_dim=driams.n_bins,
        output_dim=len(driams.selected_antibiotics),
        hidden_dim=4096,
        latent_dim=2048,
    )
    model.to(DEVICE)

    # 80/20 train/test split. The generator is re-seeded with SEED on every
    # iteration, so each bin size is split with the same generator state.
    train_size = int(0.8 * len(driams))
    test_size = len(driams) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(
        driams, [train_size, test_size], generator=gen.manual_seed(SEED)
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )

    optimizer = Adam(model.parameters(), lr=1e-3, amsgrad=True)
    # Halve the learning rate every 10 epochs.
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    # Per-label weights: one minus the fraction of negative (respectively
    # positive) entries among the "n_sum" entries of that label.
    class_weights_negative = torch.tensor(
        (1 - (driams.label_stats.loc["negative"] / driams.label_stats.loc["n_sum"])).values,
        device=DEVICE,
    )
    class_weights_positive = torch.tensor(
        (1 - (driams.label_stats.loc["positive"] / driams.label_stats.loc["n_sum"])).values,
        device=DEVICE,
    )

    criterion = MaskedBCE(
        class_weights_positive=class_weights_positive,
        class_weights_negative=class_weights_negative,
    )

    if retrain:
        for epoch in tqdm(range(30), desc="Trainings epoch", leave=False, position=1):

            for batch_idx, (x, y) in enumerate(train_loader):

                x = x.to(DEVICE, non_blocking=True)
                y = y.to(DEVICE, non_blocking=True)

                optimizer.zero_grad()
                output = model(x)

                loss = criterion(output, y)

                loss.backward()
                optimizer.step()

            scheduler.step()

        # Store the weights together with the antibiotic names that belong
        # to the output columns.
        torch.save({
                'model_state_dict': model.state_dict(),
                'selected_antibiotics': driams.selected_antibiotics
                }, path / 'model.pt')

        # Drop the references to the last training batch before evaluation.
        del x, y, loss
    else:
        # map_location remaps the stored tensors onto the current device, so
        # a checkpoint written on a different device can still be loaded.
        checkpoint = torch.load(path / 'model.pt', map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])

    # Switch to evaluation mode on BOTH paths. Previously this only happened
    # after retraining, so a model restored from a checkpoint was evaluated
    # while still in training mode.
    model.eval()

    # Release training-only objects and cached GPU memory before the
    # full-batch evaluation.
    del train_loader, train_dataset, class_weights_negative, class_weights_positive
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.empty_cache()

    # batch_size=test_size makes the loader yield the complete test split as
    # one batch.
    test_loader = DataLoader(test_dataset, batch_size=test_size, shuffle=True)
    test_features, test_labels = next(iter(test_loader))
    test_features = test_features.to(eval_device)
    test_labels = test_labels.to(eval_device)
    model = model.to(eval_device)

    # Inference only: without no_grad() autograd would record a graph for the
    # whole test split, which costs memory and is never used.
    with torch.no_grad():
        output = model(test_features)

    ml_roc = MultiLabelRocNan()
    ml_roc.compute(output, test_labels, driams.selected_antibiotics, create_csv=path / "ROC_results.csv")
    fig_, ax_ = ml_roc()

    plt.savefig(path / "ROC_results.png", transparent=True, format="png", bbox_inches="tight")
    plt.close()

    ml_pr = MultiLabelPRNan()
    ml_pr.compute(output, test_labels, driams.selected_antibiotics, create_csv=path / "PR_results.csv")

    fig_, ax_ = ml_pr()

    plt.savefig(path / "PR_results.png", transparent=True, format="png", bbox_inches="tight")
    plt.close()

    # Free memory before the next bin size is processed.
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.empty_cache()
    

Bin size:   0%|          | 0/18000 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



  0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



  0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



  0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



  0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



  0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



  0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



  0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



  0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



  0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



  0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



  0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



  0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



  0%|          | 0/55780 [00:00<?, ?it/s]

Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]



Loading Spectra into Memory:   0%|          | 0/55780 [00:00<?, ?it/s]