# Fully connected feedforward network implementing a loss mask

In [None]:
import torchmetrics.classification
from torch import cuda
# Fail fast when no CUDA device is usable: later cells pin DEVICE to "cuda".
assert cuda.is_available()
assert cuda.device_count() > 0

In [None]:
# Show the name of the currently selected CUDA device.
print(cuda.get_device_name(cuda.current_device()))


NVIDIA GeForce RTX 3060 Ti


In [None]:
import torch
import torch.nn as nn

In [None]:
# Global configuration shared by the following cells.
DEVICE = torch.device("cuda")  # all tensors and the model are moved to this device
SEED = 76436278  # fixed seed, reused below for the train/test split generator

# Seed torch's default RNG; the returned Generator is displayed as the cell output.
torch.manual_seed(SEED)

<torch._C.Generator at 0x7ff5acd5e3f0>

### Load the Dataset

In [None]:
from src.maldi2resistance.data.driams import Driams

# Build the DRIAMS dataset object from a local directory.
# NOTE(review): absolute, user-specific path -- the notebook only runs on a machine
# with this exact layout; consider a configurable data directory.
driams = Driams(
    root_dir="/home/jan/Uni/master/data/Driams",
)

# presumably serves samples from memory instead of re-reading them from disk -- TODO confirm against Driams
driams.loading_type = "memory"

# Bare last expression: display the dataset's rich representation.
driams

100%|██████████| 27446/27446 [00:06<00:00, 3951.60it/s]


Antibiotic:,Amikacin,Amoxicillin-Clavulanic acid,Ampicillin,Ampicillin-Amoxicillin,Benzylpenicillin,Cefazolin,Cefepime,Cefpodoxime,Ceftazidime,Ceftriaxone,Cefuroxime,Ciprofloxacin,Clarithromycin,Clindamycin,Colistin,Cotrimoxazole,Ertapenem,Erythromycin,Fosfomycin,Fosfomycin-Trometamol,Fusidic acid,Gentamicin,Imipenem,Levofloxacin,Meropenem,Mupirocin,Nitrofurantoin,Norfloxacin,Oxacillin,Penicillin,Piperacillin-Tazobactam,Polymyxin B,Tetracycline,Tobramycin
Number resistant:,267,6120,8578,5928,618,1032,2456,677,1649,3122,2412,3629,260,1625,837,3806,204,2047,2264,371,1072,1706,2592,1315,1291,549,954,563,2021,5042,2732,486,1140,385
Number susceptible:,9321,13875,4194,1440,457,1708,15856,1272,13796,14520,3977,20191,1061,4066,4516,8729,11251,4273,10161,1349,3300,16625,15481,5889,9629,2053,2422,3269,3255,1623,15571,2305,5538,4974
Number data points:,9588,19995,12772,7368,1075,2740,18312,1949,15445,17642,6389,23820,1321,5691,5353,12535,11455,6320,12425,1720,4372,18331,18073,7204,10920,2602,3376,3832,5276,6665,18303,2791,6678,5359


### Encoder

In [None]:
class Encoder(nn.Module):
    """Fully connected encoder mapping an input vector to a latent vector.

    Three linear layers (input_dim -> hidden_dim -> hidden_dim -> latent_dim),
    each followed by a LeakyReLU with negative slope 0.2.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of the two hidden layers.
        latent_dim: Size of the returned latent representation.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()

        # Layer creation order is kept so seeded weight initialisation is unchanged.
        self.input = nn.Linear(input_dim, hidden_dim)
        self.layer_1 = nn.Linear(hidden_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, latent_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)
        # nn.Module.__init__ already sets self.training = True, so it is not
        # assigned by hand here.

    def forward(self, x):
        """Return the latent representation of ``x``.

        Args:
            x: Tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor whose last dimension equals ``latent_dim``.
        """
        h_ = self.LeakyReLU(self.input(x))
        h_ = self.LeakyReLU(self.layer_1(h_))
        h_ = self.LeakyReLU(self.layer_2(h_))
        return h_

### Decoder

In [None]:
class Decoder(nn.Module):
    """Fully connected head mapping a latent vector to per-output probabilities.

    Three linear layers (latent_dim -> hidden_dim -> hidden_dim -> output_dim).
    The first two are followed by a LeakyReLU with negative slope 0.2, the last
    one by a sigmoid, so every output lies in [0, 1].

    Args:
        latent_dim: Size of the incoming latent vector.
        hidden_dim: Width of the two hidden layers.
        output_dim: Number of outputs (also stored as ``self.output_dim``).
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.output_dim = output_dim
        self.input = nn.Linear(latent_dim, hidden_dim)
        self.layer_1 = nn.Linear(hidden_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, output_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

    def reparameterization(self, mean, var):
        """Sample ``mean + var * eps`` with ``eps`` drawn from a standard normal.

        ``var`` is multiplied with the noise directly, i.e. it acts as a scale
        factor for the noise. The method is not called by ``forward``.

        Args:
            mean: Location tensor.
            var: Scale tensor, broadcastable with ``mean``.

        Returns:
            The sampled tensor.
        """
        # randn_like already creates the noise on var's device and dtype, so no
        # transfer to a global device is needed (that broke CPU usage).
        epsilon = torch.randn_like(var)
        return mean + var * epsilon

    def forward(self, x):
        """Map a latent tensor ``x`` to sigmoid-activated outputs.

        Args:
            x: Tensor whose last dimension equals ``latent_dim``.

        Returns:
            Tensor whose last dimension equals ``output_dim``, values in [0, 1].
        """
        h_ = self.LeakyReLU(self.input(x))
        h_ = self.LeakyReLU(self.layer_1(h_))

        output = torch.sigmoid(self.layer_2(h_))
        return output

### Combine Encoder and Decoder

In [None]:
class Model(nn.Module):
    """Chains an encoder and a decoder module.

    Args:
        Encoder: Module mapping the input to a latent tensor.
        Decoder: Module mapping the latent tensor to the output.
    """

    def __init__(self, Encoder, Decoder):
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, var):
        """Sample ``mean + var * eps`` with ``eps`` drawn from a standard normal.

        ``var`` is multiplied with the noise directly, i.e. it acts as a scale
        factor for the noise. The method is not called by ``forward``.

        Args:
            mean: Location tensor.
            var: Scale tensor, broadcastable with ``mean``.

        Returns:
            The sampled tensor.
        """
        # randn_like already creates the noise on var's device and dtype, so no
        # transfer to a global device is needed (that broke CPU usage).
        epsilon = torch.randn_like(var)
        return mean + var * epsilon

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder.

        Args:
            x: Input tensor accepted by the encoder.

        Returns:
            Tuple ``(latent, output)``: the encoder output and the decoder
            output computed from it.
        """
        latent = self.Encoder(x)
        output = self.Decoder(latent)

        return latent, output

In [None]:
# Number of columns in the label statistics table.
len(driams.label_stats.columns)

34

In [None]:
# Layer sizes of the network, named once instead of being repeated as literals.
INPUT_DIM = 18000
HIDDEN_DIM = 4096
LATENT_DIM = 2048

encoder = Encoder(
    input_dim=INPUT_DIM,
    hidden_dim=HIDDEN_DIM,
    latent_dim=LATENT_DIM,
)
# One output per selected antibiotic.
decoder = Decoder(
    latent_dim=LATENT_DIM,
    hidden_dim=HIDDEN_DIM,
    output_dim=len(driams.selected_antibiotics),
)

model = Model(Encoder=encoder, Decoder=decoder)
# Move the parameters to DEVICE; as last expression this also displays the model.
model.to(DEVICE)

Model(
  (Encoder): Encoder(
    (input): Linear(in_features=18000, out_features=4096, bias=True)
    (layer_1): Linear(in_features=4096, out_features=4096, bias=True)
    (layer_2): Linear(in_features=4096, out_features=2048, bias=True)
    (LeakyReLU): LeakyReLU(negative_slope=0.2)
  )
  (Decoder): Decoder(
    (input): Linear(in_features=2048, out_features=4096, bias=True)
    (layer_1): Linear(in_features=4096, out_features=4096, bias=True)
    (layer_2): Linear(in_features=4096, out_features=34, bias=True)
    (LeakyReLU): LeakyReLU(negative_slope=0.2)
  )
)

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Dedicated, seeded generator so the train/test split is reproducible.
gen = torch.Generator()
gen.manual_seed(SEED)


batch_size = 128

# 80 % of the samples for training, the remainder for testing.
train_size = int(0.8 * len(driams))
test_size = len(driams) - train_size
# NOTE(review): test_dataset is not used by any cell visible in this notebook section.
train_dataset, test_dataset = torch.utils.data.random_split(driams, [train_size, test_size], generator=gen)

# drop_last=True guarantees full batches; the training loop reshapes with a fixed batch_size.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
from tqdm.auto import tqdm
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

print("Start training ...")
model.train()

optimizer = Adam(model.parameters(), lr=1e-3, amsgrad = True)
# Halve the learning rate every 10 epochs.
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
loss_per_batch = []

# Per-label class weights: each class is weighted by 1 - (its share of the
# label's data points), so the rarer class of a label gets the larger weight.
label_stats = driams.label_stats
class_weights_negative = torch.tensor(
    (1 - (label_stats.loc["negative"] / label_stats.loc["n_sum"])).values, device=DEVICE
)
class_weights_positive = torch.tensor(
    (1 - (label_stats.loc["positive"] / label_stats.loc["n_sum"])).values, device=DEVICE
)

# Take the sizes from the objects that define them instead of repeating literals.
n_features = model.Encoder.input.in_features
n_labels = len(driams.selected_antibiotics)
n_batches = len(train_loader)

for epoch in tqdm(range(30)):
    overall_loss = 0

    for batch_idx, (x, y) in enumerate(train_loader):
        x = x.view(batch_size, n_features).to(DEVICE)
        y = y.view(batch_size, n_labels).to(DEVICE)

        # Augmentation: average the first and second half of the batch
        # element-wise and append the averaged samples to the batch.
        first_half, second_half = torch.chunk(x, 2)
        combined_features = (first_half + second_half) / 2
        x = torch.cat((x, combined_features), dim=0)

        # Same for the labels; an average of 0.5 (exactly one of the two
        # samples labelled 1) is treated as 1. NaN in either sample stays NaN.
        first_half, second_half = torch.chunk(y, 2)
        combined_labels = (first_half + second_half) / 2
        combined_labels[combined_labels == 0.5] = 1
        y = torch.cat((y, combined_labels), dim=0)

        # Loss mask: the positive class weight where y == 1, the negative class
        # weight where y == 0, and 0 where the label is NaN (NaN propagates
        # through the arithmetic and is zeroed afterwards).
        # assumes labels are 0, 1 or NaN (missing) -- TODO confirm against Driams
        weight = class_weights_positive * y + class_weights_negative * (1 - y)
        weight = torch.nan_to_num(weight, nan=0.0)
        # The targets must be finite even where they are masked out.
        y = torch.nan_to_num(y, nan=0.0)

        optimizer.zero_grad()

        latent, output = model(x)

        # Use the masked, class-weighted loss. It was previously computed and
        # then overwritten by an unweighted BCE, which trained missing labels
        # as zeros.
        loss = F.binary_cross_entropy(output, y, weight=weight.to(output.dtype))
        current_loss_value = loss.item()
        loss_per_batch.append(current_loss_value)

        overall_loss += current_loss_value

        loss.backward()
        optimizer.step()

    scheduler.step()
    with tqdm.external_write_mode():
        # The loss is already mean-reduced per batch, so average over the number
        # of batches (the old divisor batch_idx*batch_size was off by one and
        # zero for a single batch).
        print(f"\tAverage Loss: {overall_loss / max(n_batches, 1):.6f} \tLearning rate: {scheduler.get_last_lr()[0]:.6f}")


print("Finish")

Start training ...


  0%|          | 0/30 [00:00<?, ?it/s]

	Average Loss: 0.001191 	Learning rate: 0.001000
	Average Loss: 0.001033 	Learning rate: 0.001000
	Average Loss: 0.000978 	Learning rate: 0.001000
	Average Loss: 0.000920 	Learning rate: 0.001000
	Average Loss: 0.000887 	Learning rate: 0.001000
	Average Loss: 0.000861 	Learning rate: 0.001000
	Average Loss: 0.000843 	Learning rate: 0.001000
	Average Loss: 0.000821 	Learning rate: 0.001000
	Average Loss: 0.000786 	Learning rate: 0.001000
	Average Loss: 0.000767 	Learning rate: 0.000500
	Average Loss: 0.000685 	Learning rate: 0.000500
	Average Loss: 0.000663 	Learning rate: 0.000500
	Average Loss: 0.000646 	Learning rate: 0.000500
	Average Loss: 0.000638 	Learning rate: 0.000500
	Average Loss: 0.000632 	Learning rate: 0.000500
	Average Loss: 0.000619 	Learning rate: 0.000500
	Average Loss: 0.000600 	Learning rate: 0.000500
	Average Loss: 0.000579 	Learning rate: 0.000500
	Average Loss: 0.000564 	Learning rate: 0.000500
	Average Loss: 0.000558 	Learning rate: 0.000250
	Average Loss: 0.000

In [None]:
# Switch the model to evaluation mode; the returned module is displayed as cell output.
model.eval()

Model(
  (Encoder): Encoder(
    (input): Linear(in_features=18000, out_features=4096, bias=True)
    (layer_1): Linear(in_features=4096, out_features=4096, bias=True)
    (layer_2): Linear(in_features=4096, out_features=2048, bias=True)
    (LeakyReLU): LeakyReLU(negative_slope=0.2)
  )
  (Decoder): Decoder(
    (input): Linear(in_features=2048, out_features=4096, bias=True)
    (layer_1): Linear(in_features=4096, out_features=4096, bias=True)
    (layer_2): Linear(in_features=4096, out_features=34, bias=True)
    (LeakyReLU): LeakyReLU(negative_slope=0.2)
  )
)

In [None]:
# Compile the trained model to TorchScript and write it to ./model.pt
# (relative to the notebook's working directory).
model_scripted = torch.jit.script(model)
model_scripted.save('./model.pt')