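# Fully connected feed-forward network with a masked loss

Training runs in three phases: (1) pre-train the encoder/decoder to reconstruct the spectrum and predict the species; (2) swap the reconstruction head for an antimicrobial-resistance (AMR) head and train only that head, masking missing labels out of the loss; (3) unfreeze the whole network and fine-tune it end to end.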

In [None]:
import torchmetrics.classification
from torch import cuda
# Fail fast without a GPU: DEVICE is hard-coded to "cuda" in the configuration cell below.
assert cuda.is_available()
assert cuda.device_count() > 0

In [None]:
# Record which GPU produced the results below.
print(cuda.get_device_name(cuda.current_device()))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
DEVICE = torch.device("cuda")
SEED = 76436278

# Seed torch's global RNG, which drives weight initialisation and DataLoader shuffling.
# The trailing semicolon suppresses the returned Generator's repr in the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7f96d556e470>

### Load the Dataset

In [None]:
import os

from src.maldi2resistance.data.driams import Driams

# Dataset root. Set the DRIAMS_ROOT environment variable to point elsewhere instead of
# editing this cell; the default is the original author's local path.
DRIAMS_ROOT = os.environ.get("DRIAMS_ROOT", "/home/jan/Uni/master/data/Driams")

driams = Driams(
    root_dir=DRIAMS_ROOT,
    label_includes_species=True,
)

driams.loading_type = "memory"

driams

  self.__meta[key] = pd.read_csv(root_dir / f"{site}/id/{year}/{year}_clean.csv")
100%|██████████| 55780/55780 [00:14<00:00, 3850.57it/s]


Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Aztreonam,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Rifampicin,Teicoplanin,Tetracycline,Tobramycin,Vancomycin
Number resistant:,1068,13366,8578,21966,628,618,4223,7383,2338,3470,8659,5855,9338,310,4381,2874,7405,427,5468,2303,1326,3620,3481,7560,4217,5194,570,1271,1205,5537,12431,7616,486,580,244,3534,1707,227
Number susceptible:,20941,24992,4194,4905,456,457,5813,31567,4382,24566,28464,8368,36822,1262,9841,15784,24590,21740,9044,10184,4803,8498,22662,31717,17989,27228,4656,3603,7031,7740,4286,31308,2305,14964,8486,10376,16809,20540
Number data points:,22009,38358,12772,26871,1084,1075,10036,38950,6720,28036,37123,14223,46160,1572,14222,18658,31995,22167,14512,12487,6129,12118,26143,39277,22206,32422,5226,4874,8236,13277,16717,38924,2791,15544,8730,13910,18516,20767


### Encoder

In [None]:
class Encoder(nn.Module):
    """Fully connected encoder: input -> hidden -> hidden -> latent.

    Each of the three linear layers, including the last one, is followed by
    LeakyReLU(0.2), so the returned latent code is itself an activation.

    Args:
        input_dim: length of the flattened input vector.
        hidden_dim: width of the two hidden layers.
        latent_dim: length of the latent code returned by ``forward``.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()

        # Attribute names are kept as-is: they define the state_dict keys.
        self.input = nn.Linear(input_dim, hidden_dim)
        self.layer_1 = nn.Linear(hidden_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, latent_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)
        # No manual `self.training` assignment: nn.Module initialises it and
        # train()/eval() manage it.

    def forward(self, x):
        h_ = self.LeakyReLU(self.input(x))
        h_ = self.LeakyReLU(self.layer_1(h_))
        h_ = self.LeakyReLU(self.layer_2(h_))
        return h_

### Decoder

In [None]:
class Decoder(nn.Module):
    """Two-headed decoder.

    A shared trunk (latent -> hidden -> hidden, LeakyReLU(0.2) after each layer)
    feeds two independent linear heads that return raw, un-activated values:

    * ``layer_2``: ``output_dim`` values (the main output), and
    * ``layer_3``: ``species_out_dim`` values (the species head).

    ``forward`` returns ``(output, species)``.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim, species_out_dim):
        super().__init__()
        self.output_dim = output_dim

        # Shared trunk.
        self.input = nn.Linear(latent_dim, hidden_dim)
        self.layer_1 = nn.Linear(hidden_dim, hidden_dim)
        # Heads. Names are fixed: later cells replace `layer_2`, and they form state_dict keys.
        self.layer_2 = nn.Linear(hidden_dim, output_dim)
        self.layer_3 = nn.Linear(hidden_dim, species_out_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        trunk = self.LeakyReLU(self.input(x))
        trunk = self.LeakyReLU(self.layer_1(trunk))
        return self.layer_2(trunk), self.layer_3(trunk)

### Combine Encoder and Decoder

In [None]:
class Model(nn.Module):
    """Chain an encoder and a decoder.

    Args:
        Encoder: module mapping the input to a latent code.
        Decoder: module mapping the latent code to an ``(output, species)`` pair.

    ``forward`` returns ``(latent, output, species)``.
    """

    def __init__(self, Encoder, Decoder):
        super().__init__()
        # Capitalised attribute names are referenced later (model.Encoder, model.Decoder.*).
        self.Encoder = Encoder
        self.Decoder = Decoder

    def forward(self, x):
        code = self.Encoder(x)
        decoded, species_out = self.Decoder(code)
        return code, decoded, species_out


In [None]:
# Architecture sizes. SPECTRUM_DIM matches the flattened input length used by the
# training loops (x.view(batch_size, 18000)).
SPECTRUM_DIM = 18000
HIDDEN_DIM = 4096
LATENT_DIM = 2048

encoder = Encoder(input_dim=SPECTRUM_DIM, hidden_dim=HIDDEN_DIM, latent_dim=LATENT_DIM)
decoder = Decoder(
    latent_dim=LATENT_DIM,
    hidden_dim=HIDDEN_DIM,
    output_dim=SPECTRUM_DIM,  # phase 1 reconstructs the input spectrum
    species_out_dim=len(driams.species),
)

model = Model(Encoder=encoder, Decoder=decoder)
model.to(DEVICE)

Model(
  (Encoder): Encoder(
    (input): Linear(in_features=18000, out_features=4096, bias=True)
    (layer_1): Linear(in_features=4096, out_features=4096, bias=True)
    (layer_2): Linear(in_features=4096, out_features=2048, bias=True)
    (LeakyReLU): LeakyReLU(negative_slope=0.2)
  )
  (Decoder): Decoder(
    (input): Linear(in_features=2048, out_features=4096, bias=True)
    (layer_1): Linear(in_features=4096, out_features=4096, bias=True)
    (layer_2): Linear(in_features=4096, out_features=18000, bias=True)
    (layer_3): Linear(in_features=4096, out_features=583, bias=True)
    (LeakyReLU): LeakyReLU(negative_slope=0.2)
  )
)

In [None]:
# All species names known to the dataset; their count sizes the decoder's species head (layer_3).
driams.species

['Abiotrophia defectiva',
 'Achromobacter insolitus',
 'Achromobacter spanius',
 'Achromobacter xylosoxidans',
 'Acidovorax temperans',
 'Acinetobacter baumannii',
 'Acinetobacter bereziniae',
 'Acinetobacter calcoaceticus',
 'Acinetobacter guillouiae',
 'Acinetobacter haemolyticus',
 'Acinetobacter johnsonii',
 'Acinetobacter junii',
 'Acinetobacter lwoffii',
 'Acinetobacter nosocomialis',
 'Acinetobacter parvus',
 'Acinetobacter pittii',
 'Acinetobacter radioresistens',
 'Acinetobacter sp',
 'Acinetobacter ursingii',
 'Actinobacillus ureae',
 'Actinomyces funkei',
 'Actinomyces meyeri',
 'Actinomyces neuii',
 'Actinomyces odontolyticus',
 'Actinomyces oris',
 'Actinomyces turicensis',
 'Actinotignum sanguinis',
 'Actinotignum schaalii',
 'Aerococcus christensenii',
 'Aerococcus sanguinicola',
 'Aerococcus urinae',
 'Aerococcus viridans',
 'Aeromonas caviae',
 'Aeromonas encheleia',
 'Aeromonas hydrophila',
 'Aeromonas ichthiosmia',
 'Aeromonas media',
 'Aeromonas salmonicida',
 'Aero

In [None]:
from sklearn.preprocessing import LabelBinarizer

# One-hot encoder fitted on the dataset's species names.
# NOTE(review): `lb` is not used by any visible cell; the training loops take the species
# targets directly from the dataset.
lb = LabelBinarizer().fit(driams.species)

In [None]:
# Number of species classes, i.e. the decoder's species_out_dim.
len(driams.species)

583

In [None]:
from torch.utils.data import DataLoader

In [None]:
batch_size = 128

# 80/20 random split. A dedicated, seeded generator keeps the split reproducible
# independently of the global torch RNG.
train_size = int(0.8 * len(driams))
test_size = len(driams) - train_size
gen = torch.Generator().manual_seed(SEED)
train_dataset, test_dataset = torch.utils.data.random_split(
    driams, [train_size, test_size], generator=gen
)
# NOTE(review): test_dataset is never evaluated in the visible cells.

# drop_last=True: the training loops reshape with the fixed `batch_size`, so a shorter
# final batch would break x.view(batch_size, ...).
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
from tqdm.auto import tqdm
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# Phase 1: pre-train encoder + decoder jointly on
#   * spectrum reconstruction (MSE between decoder output and input), and
#   * species prediction (cross-entropy on the species head).
SPECTRUM_LOSS_WEIGHT = 0.4
SPECIES_LOSS_WEIGHT = 0.6

print("Start training ...")
model.train()

optimizer = Adam(model.parameters(), lr=1e-3, amsgrad = True)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
loss_per_batch = []

spectra_criterion = nn.MSELoss()
# Targets are passed as float tensors, which CrossEntropyLoss interprets as class
# probabilities (presumably one-hot species vectors from the dataset; TODO confirm).
species_criterion = nn.CrossEntropyLoss()

for epoch in tqdm(range(30)):
    overall_loss = 0.0
    n_batches = 0

    for x, y, species in train_loader:
        # drop_last=True on the loader guarantees exactly `batch_size` rows per batch.
        x = x.view(batch_size, 18000).to(DEVICE)
        species = species.float().to(DEVICE)

        optimizer.zero_grad()
        latent, output, out_species = model(x)

        loss_spectrum = spectra_criterion(output, x)
        loss_species = species_criterion(out_species, species)
        loss = loss_spectrum * SPECTRUM_LOSS_WEIGHT + loss_species * SPECIES_LOSS_WEIGHT

        current_loss_value = loss.item()
        loss_per_batch.append(current_loss_value)
        overall_loss += current_loss_value
        n_batches += 1

        loss.backward()
        optimizer.step()

    scheduler.step()
    # Each loss.item() is already a mean over its batch, so average over batches only
    # (max(..., 1) guards against an empty loader).
    average_loss = overall_loss / max(n_batches, 1)
    with tqdm.external_write_mode():
        print(f"\tAverage Loss: {average_loss:.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")


print("Finish")

Start training ...


  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.007388 	Learning rate: 0.001000
	Average Loss: 0.002951 	Learning rate: 0.001000
	Average Loss: 0.002121 	Learning rate: 0.001000
	Average Loss: 0.001667 	Learning rate: 0.001000
	Average Loss: 0.001458 	Learning rate: 0.001000
	Average Loss: 0.001240 	Learning rate: 0.001000
	Average Loss: 0.001061 	Learning rate: 0.001000
	Average Loss: 0.001139 	Learning rate: 0.001000
	Average Loss: 0.000955 	Learning rate: 0.001000
	Average Loss: 0.000806 	Learning rate: 0.000500
	Average Loss: 0.000483 	Learning rate: 0.000500
	Average Loss: 0.000381 	Learning rate: 0.000500
	Average Loss: 0.000358 	Learning rate: 0.000500
	Average Loss: 0.000311 	Learning rate: 0.000500
	Average Loss: 0.000292 	Learning rate: 0.000500
	Average Loss: 0.000259 	Learning rate: 0.000500
	Average Loss: 0.000254 	Learning rate: 0.000500
	Average Loss: 0.000230 	Learning rate: 0.000500
	Average Loss: 0.000222 	Learning rate: 0.000500
	Average Loss: 0.000212 	Learning rate: 0.000250
	Average Loss: 0.000

In [None]:
# Phase 2 setup: replace the spectrum-reconstruction head with an AMR head that emits one
# logit per selected antibiotic. The new layer reuses the trunk width of the head it
# replaces and is created directly on DEVICE.
model.Decoder.layer_2 = nn.Linear(
    model.Decoder.layer_2.in_features, len(driams.selected_antibiotics)
).to(DEVICE)
# Keep the recorded output size in sync with the new head.
model.Decoder.output_dim = len(driams.selected_antibiotics)

In [None]:
# Phase 2 setup: freeze everything except the new AMR head (Decoder.layer_2), so the next
# training cell only updates that head.
for frozen_module in (model.Encoder, model.Decoder.input, model.Decoder.layer_1, model.Decoder.layer_3):
    frozen_module.requires_grad_(False)

In [None]:
# Move the model to the GPU again: the replacement head above was created on the CPU.
# The repr below confirms the new head's output size.
model.to(DEVICE)

Model(
  (Encoder): Encoder(
    (input): Linear(in_features=18000, out_features=4096, bias=True)
    (layer_1): Linear(in_features=4096, out_features=4096, bias=True)
    (layer_2): Linear(in_features=4096, out_features=2048, bias=True)
    (LeakyReLU): LeakyReLU(negative_slope=0.2)
  )
  (Decoder): Decoder(
    (input): Linear(in_features=2048, out_features=4096, bias=True)
    (layer_1): Linear(in_features=4096, out_features=4096, bias=True)
    (layer_2): Linear(in_features=4096, out_features=38, bias=True)
    (layer_3): Linear(in_features=4096, out_features=583, bias=True)
    (LeakyReLU): LeakyReLU(negative_slope=0.2)
  )
)

In [None]:
from tqdm.auto import tqdm
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# Phase 2: train only the AMR head (Decoder.layer_2); the rest of the network is frozen.
print("Start training ...")
model.train()

# Drop gradients left over from phase 1 and hand the optimizer only trainable parameters.
model.zero_grad(set_to_none=True)
optimizer = Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3, amsgrad = True)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
loss_per_batch = []

# Per-antibiotic class weights: 1 - (class count / n_sum). Assuming n_sum counts all
# labelled samples, the rarer class of each antibiotic gets the larger weight.
class_weights_negative = torch.tensor((1 - (driams.label_stats.loc["negative"] / driams.label_stats.loc["n_sum"])).values, device=DEVICE)
class_weights_positive = torch.tensor((1 - (driams.label_stats.loc["positive"] / driams.label_stats.loc["n_sum"])).values, device=DEVICE)

for epoch in tqdm(range(30)):
    overall_loss = 0.0
    n_batches = 0

    for x, y, _species in train_loader:
        x = x.view(batch_size, 18000).to(DEVICE)
        y = y.view(batch_size, len(driams.selected_antibiotics)).to(DEVICE)

        # Loss mask: NaN marks a missing label (presumably an antibiotic that was not tested).
        # Missing entries get weight 0. Known entries get the positive class weight where
        # y == 1 and the negative class weight where y == 0.
        label_known = ~torch.isnan(y)
        y = torch.nan_to_num(y, 0)
        weight = torch.where(y == 1, class_weights_positive, class_weights_negative) * label_known

        optimizer.zero_grad()
        latent, output, out_species = model(x)

        # NOTE(review): reduction="mean" divides by every entry, masked ones included, so
        # the loss scale depends on how many labels are missing in the batch.
        loss = F.binary_cross_entropy_with_logits(output, y, weight=weight)
        current_loss_value = loss.item()
        loss_per_batch.append(current_loss_value)
        overall_loss += current_loss_value
        n_batches += 1

        loss.backward()
        optimizer.step()

    scheduler.step()
    # Each loss.item() is already a mean over its batch, so average over batches only.
    average_loss = overall_loss / max(n_batches, 1)
    with tqdm.external_write_mode():
        print(f"\tAverage Loss: {average_loss:.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")


print("Finish")

Start training ...


  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.000402 	Learning rate: 0.001000
	Average Loss: 0.000395 	Learning rate: 0.001000
	Average Loss: 0.000392 	Learning rate: 0.001000
	Average Loss: 0.000389 	Learning rate: 0.001000
	Average Loss: 0.000389 	Learning rate: 0.001000
	Average Loss: 0.000384 	Learning rate: 0.001000
	Average Loss: 0.000385 	Learning rate: 0.001000
	Average Loss: 0.000383 	Learning rate: 0.001000
	Average Loss: 0.000381 	Learning rate: 0.001000
	Average Loss: 0.000381 	Learning rate: 0.000500
	Average Loss: 0.000363 	Learning rate: 0.000500
	Average Loss: 0.000362 	Learning rate: 0.000500
	Average Loss: 0.000360 	Learning rate: 0.000500
	Average Loss: 0.000360 	Learning rate: 0.000500
	Average Loss: 0.000360 	Learning rate: 0.000500
	Average Loss: 0.000359 	Learning rate: 0.000500
	Average Loss: 0.000359 	Learning rate: 0.000500
	Average Loss: 0.000359 	Learning rate: 0.000500
	Average Loss: 0.000359 	Learning rate: 0.000500
	Average Loss: 0.000358 	Learning rate: 0.000250
	Average Loss: 0.000

In [None]:
# Phase 3 setup: make the modules frozen for phase 2 trainable again, for end-to-end fine-tuning.
for thawed_module in (model.Encoder, model.Decoder.input, model.Decoder.layer_1, model.Decoder.layer_3):
    thawed_module.requires_grad_(True)

In [None]:
from tqdm.auto import tqdm
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# Phase 3: fine-tune the whole network end to end on the masked AMR loss plus the species
# loss, with a smaller learning rate and a faster decay.
print("Start training ...")
model.train()

optimizer = Adam(model.parameters(), lr=5e-5, amsgrad = True)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
loss_per_batch = []

# Defined here too, so this cell does not depend on the phase-1 cell having run.
species_criterion = nn.CrossEntropyLoss()

# Per-antibiotic class weights: 1 - (class count / n_sum). Assuming n_sum counts all
# labelled samples, the rarer class of each antibiotic gets the larger weight.
class_weights_negative = torch.tensor((1 - (driams.label_stats.loc["negative"] / driams.label_stats.loc["n_sum"])).values, device=DEVICE)
class_weights_positive = torch.tensor((1 - (driams.label_stats.loc["positive"] / driams.label_stats.loc["n_sum"])).values, device=DEVICE)

for epoch in tqdm(range(30)):
    overall_loss = 0.0
    n_batches = 0

    for x, y, species in train_loader:
        x = x.view(batch_size, 18000).to(DEVICE)
        y = y.view(batch_size, len(driams.selected_antibiotics)).to(DEVICE)
        species = species.float().to(DEVICE)

        # Loss mask: NaN marks a missing label (presumably an antibiotic that was not tested).
        # Missing entries get weight 0. Known entries get the positive class weight where
        # y == 1 and the negative class weight where y == 0.
        label_known = ~torch.isnan(y)
        y = torch.nan_to_num(y, 0)
        weight = torch.where(y == 1, class_weights_positive, class_weights_negative) * label_known

        optimizer.zero_grad()
        latent, output, out_species = model(x)

        # NOTE(review): reduction="mean" divides by every entry, masked ones included.
        amr_loss = F.binary_cross_entropy_with_logits(output, y, weight=weight)
        loss_species = species_criterion(out_species, species)
        loss = amr_loss + loss_species

        current_loss_value = loss.item()
        loss_per_batch.append(current_loss_value)
        overall_loss += current_loss_value
        n_batches += 1

        loss.backward()
        optimizer.step()

    scheduler.step()
    # Each loss.item() is already a mean over its batch, so average over batches only.
    average_loss = overall_loss / max(n_batches, 1)
    with tqdm.external_write_mode():
        print(f"\tAverage Loss: {average_loss:.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")


print("Finish")

Start training ...


  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.000382 	Learning rate: 0.000050
	Average Loss: 0.000366 	Learning rate: 0.000050
	Average Loss: 0.000361 	Learning rate: 0.000050
	Average Loss: 0.000356 	Learning rate: 0.000050
	Average Loss: 0.000346 	Learning rate: 0.000025
	Average Loss: 0.000329 	Learning rate: 0.000025
	Average Loss: 0.000322 	Learning rate: 0.000025
	Average Loss: 0.000318 	Learning rate: 0.000025
	Average Loss: 0.000316 	Learning rate: 0.000025
	Average Loss: 0.000313 	Learning rate: 0.000013
	Average Loss: 0.000306 	Learning rate: 0.000013
	Average Loss: 0.000303 	Learning rate: 0.000013
	Average Loss: 0.000300 	Learning rate: 0.000013
	Average Loss: 0.000299 	Learning rate: 0.000013
	Average Loss: 0.000297 	Learning rate: 0.000006
	Average Loss: 0.000293 	Learning rate: 0.000006
	Average Loss: 0.000292 	Learning rate: 0.000006
	Average Loss: 0.000290 	Learning rate: 0.000006
	Average Loss: 0.000289 	Learning rate: 0.000006
	Average Loss: 0.000288 	Learning rate: 0.000003
	Average Loss: 0.000

In [None]:
# Export the fine-tuned model via TorchScript so it can be reloaded with torch.jit.load
# without these class definitions.
# NOTE(review): the model is still in train() mode here. It contains no dropout or
# batch-norm layers, so this does not change its outputs.
model_scripted = torch.jit.script(model)
model_scripted.save('./model.pt')