# Importación de librerías

In [None]:
import os
from scipy.io import loadmat
import numpy as np
import torch as th
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import math
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable

# Carga de datos

In [None]:
DATA_PATH = 'shhs3'
dbpath = os.path.abspath(DATA_PATH)

# Per-subject recordings: every .mat file in the folder except the
# sequence-length index, which is loaded on its own.  Selecting by name is
# safer than dropping the first sorted entry, which silently goes wrong as
# soon as another file (e.g. a hidden one) happens to sort first.
listdir = sorted(
    name for name in os.listdir(dbpath)
    if name.lower().endswith('.mat') and name != 'sequenceLengths.mat'
)
lengths = loadmat(os.path.join(dbpath, 'sequenceLengths.mat'))['sequenceLengths'].squeeze()

subj = [loadmat(os.path.join(dbpath, name)) for name in listdir]
if not subj:
    raise FileNotFoundError(f"No subject .mat files found in {dbpath}")

# Concatenate every subject into one long 1-D signal and one aligned label
# vector (both flattened in C order).
feat_data = np.concatenate([record['SaO2'].reshape(-1) for record in subj])          # SaO2
target_data = np.concatenate([record['targetA0H3'].reshape(-1) for record in subj])  # 'targetA0H3' labels

# Binarise the labels: any value above 1 is clipped to 1 (vectorised; the
# concatenated array is a fresh copy, so `subj` is not modified).
target_data[target_data > 1] = 1

def strided_app(a, L, S):  # Window len = L, Stride len/stepsize = S
    """Return every length-``L`` window of 1-D array ``a``, one every ``S`` samples.

    Row ``i`` of the result equals ``a[i*S : i*S + L]``.  The result is a
    zero-copy view built with ``as_strided``: it shares memory with ``a`` and
    overlapping rows alias the same elements, so writing into it also
    modifies ``a``.

    Args:
        a: One-dimensional NumPy array.
        L: Window length (>= 1).
        S: Stride between consecutive window starts (>= 1).

    Returns:
        Array of shape ``(max(0, (a.size - L) // S + 1), L)``; it has zero
        rows when ``a`` is shorter than one window.

    Raises:
        ValueError: if ``a`` is not 1-D or ``L``/``S`` is smaller than 1.
    """
    if a.ndim != 1:
        raise ValueError(f"strided_app expects a 1-D array, got shape {a.shape}")
    if L < 1 or S < 1:
        raise ValueError(f"window length and stride must be >= 1 (got L={L}, S={S})")
    # Clamp at 0: when a.size < L the raw formula can go negative, which
    # as_strided rejects with an unhelpful "negative dimensions" error.
    nrows = max(0, (a.size - L) // S + 1)
    step = a.strides[0]
    return np.lib.stride_tricks.as_strided(
        a, shape=(nrows, L), strides=(S * step, step))

window = 200   # samples per window; the models' first Linear layers expect 200 inputs
stride = 1

# NOTE(review): feat_data / target_data are all subjects concatenated, so some
# windows straddle two subjects; and with stride=1 neighbouring windows share
# window-1 samples, so a random split places near-duplicate windows in train,
# val and test.  A per-subject split would avoid both — confirm intent.
feat_windowed_data = strided_app(feat_data, window, stride)
target_windowed_data = strided_app(target_data, window, stride)

dataset = TensorDataset(
    th.from_numpy(feat_windowed_data),
    th.from_numpy(target_windowed_data),
)

generator = th.Generator().manual_seed(42)
traindata, valdata, testdata = th.utils.data.random_split(
    dataset, [0.7, 0.2, 0.1], generator=generator
)

BATCH_SIZE = 512

trainloader = DataLoader(
    traindata, batch_size=BATCH_SIZE, pin_memory=False, drop_last=False, shuffle=True
)

# Evaluation does not benefit from shuffling; a fixed order keeps val/test
# metrics reproducible and avoids consuming the global RNG every epoch.
valloader = DataLoader(
    valdata, batch_size=BATCH_SIZE, pin_memory=False, drop_last=False, shuffle=False
)

testloader = DataLoader(
    testdata, batch_size=BATCH_SIZE, pin_memory=False, drop_last=False, shuffle=False
)

# Definición del algoritmo

Autoencoder

In [None]:
class AutoencoderAp(nn.Module):
    """Fully connected autoencoder for 200-sample target (``Ap``) windows.

    ``forward`` applies only the encoder ``c1`` (200 -> 175 -> 125 -> 100,
    Tanh-bounded code).  The decoder ``feature_extractor``
    (100 -> 125 -> 175 -> 200, Sigmoid output) is registered but never called
    here; presumably it is kept so checkpoints that include its weights load
    with strict key matching.

    Args:
        window: Unused; layer widths are hard-coded for 200-sample inputs.
        dropout: Dropout probability after each of the first two encoder blocks.
    """

    def __init__(self, window=100, dropout=0.2):
        super().__init__()

        # Encoder.  Layers are appended in the same order as a hand-written
        # Sequential would list them, so submodule indices (c1.0 ... c1.10),
        # state_dict keys and default-init RNG draws are all unchanged.
        encoder = []
        for n_in, n_out in ((200, 175), (175, 125)):
            encoder += [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
        encoder += [nn.Linear(125, 100), nn.BatchNorm1d(100), nn.Tanh()]
        self.c1 = nn.Sequential(*encoder)

        # Decoder (indices feature_extractor.0 ... feature_extractor.7).
        decoder = []
        for n_in, n_out in ((100, 125), (125, 175)):
            decoder += [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU()]
        decoder += [nn.Linear(175, 200), nn.Sigmoid()]
        self.feature_extractor = nn.Sequential(*decoder)

        # The custom initialisation below is intentionally NOT applied in this
        # class; PyTorch's default layer initialisation is kept.

    def _initialize_submodules(self):
        """Re-draw Linear/Conv1d weights from N(0, 1/fan_in); biases untouched."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                fan_in = module.weight.size(1)
            elif isinstance(module, nn.Conv1d):
                fan_in = module.kernel_size[0] * module.in_channels
            else:
                continue
            nn.init.normal_(module.weight, 0, math.sqrt(1.0 / fan_in))

    def forward(self, x):
        """Return the 100-dim encoder code of ``x`` (the decoder is not applied)."""
        return self.c1(x)

Generador y discriminador

In [None]:
class AutoencoderSaO2(nn.Module):
    """Fully connected autoencoder for 200-sample SaO2 windows.

    The encoder ``c1`` (200 -> 175 -> 125 -> 100) produces a Tanh-bounded
    code; the decoder ``feature_extractor`` (100 -> 125 -> 175 -> 200) maps it
    to 200 values in [0, 1] through a Sigmoid.  ``forward`` returns
    ``(reconstruction, code)``.

    Args:
        window: Unused; layer widths are hard-coded for 200-sample inputs.
        dropout: Dropout probability after each of the first two encoder blocks.
    """

    def __init__(self, window=100, dropout=0.5):
        super().__init__()

        # Encoder.  Layers are appended in the original order so Sequential
        # indices, state_dict keys and parameters() order are unchanged —
        # c1's 12 parameters come first in parameters().
        encoder = []
        for n_in, n_out in ((200, 175), (175, 125)):
            encoder += [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
        encoder += [nn.Linear(125, 100), nn.BatchNorm1d(100), nn.Tanh()]
        self.c1 = nn.Sequential(*encoder)

        # Decoder (no dropout).
        decoder = []
        for n_in, n_out in ((100, 125), (125, 175)):
            decoder += [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU()]
        decoder += [nn.Linear(175, 200), nn.Sigmoid()]
        self.feature_extractor = nn.Sequential(*decoder)

        self._initialize_submodules()

    def _initialize_submodules(self):
        """Re-draw Linear/Conv1d weights from N(0, 1/fan_in); biases untouched."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                fan_in = module.weight.size(1)
            elif isinstance(module, nn.Conv1d):
                fan_in = module.kernel_size[0] * module.in_channels
            else:
                continue
            nn.init.normal_(module.weight, 0, math.sqrt(1.0 / fan_in))

    def forward(self, x):
        """Encode ``x`` and decode it back; return ``(reconstruction, code)``."""
        code = self.c1(x)
        return self.feature_extractor(code), code

class Discriminador(nn.Module):
    """MLP discriminator: 100-dim code -> probability (Sigmoid) of being "real".

    Architecture: 100 -> 50 -> 25 -> 1, with BatchNorm after every Linear
    (including the final one) and ReLU between hidden layers.

    Args:
        window: Unused.
        dropout: Unused (no dropout layers are created).
    """

    def __init__(self, window=50, dropout=0.5):
        super().__init__()

        stack = []
        for n_in, n_out in ((100, 50), (50, 25)):
            stack += [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU()]
        # NOTE(review): BatchNorm1d(1) on the single output logit re-centres
        # each batch's logits in train mode.  Real and fake codes are scored
        # in separate batches during training, so this may limit how well D
        # can separate them.  Left unchanged to stay checkpoint-compatible.
        stack += [nn.Linear(25, 1), nn.BatchNorm1d(1), nn.Sigmoid()]
        self.feature_extractor = nn.Sequential(*stack)

        self._initialize_submodules()

    def _initialize_submodules(self):
        """Re-draw Linear/Conv1d weights from N(0, 1/fan_in); biases untouched."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                fan_in = module.weight.size(1)
            elif isinstance(module, nn.Conv1d):
                fan_in = module.kernel_size[0] * module.in_channels
            else:
                continue
            nn.init.normal_(module.weight, 0, math.sqrt(1.0 / fan_in))

    def forward(self, x):
        """Return the per-sample "real" probability, shape ``(batch, 1)``."""
        return self.feature_extractor(x)

Carga de modelos

In [None]:
modelSaO2 = AutoencoderSaO2()
modelAp = AutoencoderAp()
discriminador = Discriminador()

# Load the already-trained target autoencoder.  Its encoder supplies the
# "real" codes for the discriminator and no optimiser updates it, so freeze
# it to stop gradients from being computed/accumulated in its parameters.
# map_location='cpu': the checkpoint may have been saved from a GPU run; the
#   models are moved to the training device afterwards.
# weights_only=True: the file is expected to hold a plain state_dict, so no
#   arbitrary objects need to be unpickled.
modelAp.load_state_dict(
    th.load('models/modelAp.pth', map_location='cpu', weights_only=True)
)
modelAp.requires_grad_(False)

# Optimizadores

In [None]:
# Three Adam optimisers for the adversarial autoencoder:
#   G - encoder `c1` of modelSaO2 only, trained to fool the discriminator
#       (same tensors, in the same order, as the first 12 parameters of
#       modelSaO2 — c1 holds 3 Linear + 3 BatchNorm layers);
#   R - the whole SaO2 autoencoder, trained on reconstruction;
#   D - the discriminator.
optimizer_G = optim.Adam(modelSaO2.c1.parameters(), lr=2e-4)
optimizer_R = optim.Adam(modelSaO2.parameters(), lr=2e-4)
optimizer_D = optim.Adam(discriminador.parameters(), lr=2e-4)

# Step decay every 10 epochs.  Each scheduler must wrap its own optimiser:
# attaching scheduler_R to optimizer_D would decay D twice and never R.
scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, 10, 0.9)
scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, 10, 0.9)
scheduler_R = optim.lr_scheduler.StepLR(optimizer_R, 10, 0.5)

# Entrenamiento y validación

In [None]:
device = th.device("cuda" if th.cuda.is_available() else "cpu")

num_epochs = 50
path_save = 'models/'
os.makedirs(path_save, exist_ok=True)

modelAp.to(device)
modelSaO2.to(device)
discriminador.to(device)


def _zero_all_grads():
    """Clear the gradients held by the G, R and D optimizers."""
    optimizer_G.zero_grad()
    optimizer_R.zero_grad()
    optimizer_D.zero_grad()


for epoch in range(num_epochs):

    # ---------------- TRAINING ----------------
    modelSaO2.train()
    discriminador.train()
    modelAp.eval()  # pretrained reference encoder; never updated here

    D_loss_sum = 0.0
    D_real_prob = 0.0
    D_fake_prob = 0.0
    SaO2_loss_sum = 0.0
    loss_sum = 0.0
    reconstruction_loss = 0.0
    n_train_batches = 0

    for SaO2_data, Ap_data in trainloader:
        if SaO2_data.shape[0] < 2:
            # BatchNorm1d raises in train mode on a single-sample batch; with
            # drop_last=False the last batch can end up with one window.
            continue

        # Cast once: the models' Linear layers are float32, and the L1 target
        # must match the reconstruction dtype.
        SaO2_data = SaO2_data.to(device, dtype=th.float32)
        Ap_data = Ap_data.to(device, dtype=th.float32)

        batch_size = SaO2_data.shape[0]
        ones_label = th.ones(batch_size, 1, device=device)
        zeros_label = th.zeros(batch_size, 1, device=device)

        _zero_all_grads()

        # 1) Discriminator step: codes of the frozen Ap encoder are "real",
        #    codes of the SaO2 encoder are "fake".
        with th.no_grad():
            true = modelAp(Ap_data)
        _, embeddings_ = modelSaO2(SaO2_data)
        D_real = discriminador(true)
        # detach: the D step must not backpropagate into the SaO2 encoder.
        D_fake = discriminador(embeddings_.detach())

        D_loss = (F.binary_cross_entropy(D_real, ones_label)
                  + F.binary_cross_entropy(D_fake, zeros_label))
        D_loss.backward()
        optimizer_D.step()
        _zero_all_grads()

        D_real_prob += D_real.mean().item()
        D_fake_prob += D_fake.mean().item()

        # 2) Generator step: update only the encoder (optimizer_G) so that the
        #    discriminator labels its codes as real.
        _, embeddings_ = modelSaO2(SaO2_data)
        SaO2_loss = F.binary_cross_entropy(discriminador(embeddings_), ones_label)
        SaO2_loss.backward()
        optimizer_G.step()
        _zero_all_grads()

        # 3) Reconstruction step: update the whole SaO2 autoencoder
        #    (optimizer_R) to reproduce the target window from the SaO2 window.
        reconstructions, _ = modelSaO2(SaO2_data)
        loss_reconstruction = F.l1_loss(reconstructions, Ap_data)
        loss_reconstruction.backward()
        optimizer_R.step()
        _zero_all_grads()

        # Accumulate Python floats: summing live tensors would chain every
        # batch's autograd graph into the running total for the whole epoch.
        D_loss_sum += D_loss.item()
        SaO2_loss_sum += SaO2_loss.item()
        reconstruction_loss += loss_reconstruction.item()
        loss_sum += loss_reconstruction.item() + SaO2_loss.item()
        n_train_batches += 1

    scheduler_G.step()
    scheduler_R.step()
    scheduler_D.step()

    denom = max(n_train_batches, 1)
    print(
        f"Epoch {epoch + 1}/{num_epochs} | train "
        f"D={D_loss_sum / denom:.4f} G={SaO2_loss_sum / denom:.4f} "
        f"rec={reconstruction_loss / denom:.4f} total={loss_sum / denom:.4f} "
        f"D(real)={D_real_prob / denom:.3f} D(fake)={D_fake_prob / denom:.3f}"
    )

    th.save(modelSaO2.state_dict(), os.path.join(path_save, f'modelSaO2_epoch_{epoch}.pth'))
    th.save(discriminador.state_dict(), os.path.join(path_save, f'discriminador_epoch_{epoch}.pth'))

    # ---------------- VALIDATION ----------------
    modelSaO2.eval()
    discriminador.eval()
    modelAp.eval()

    D_loss_sum = 0
    SaO2_loss_sum = 0
    loss_sum = 0
    D_real_prob = 0
    D_fake_prob = 0
    reconstruction_loss = 0

    with th.no_grad():

      # Iterate over the validation batches.
      for SaO2_data, Ap_data in valloader:

          SaO2_data = SaO2_data.to(device, dtype=th.float32)
          Ap_data = Ap_data.to(device, dtype=th.float32)

          batch_size = SaO2_data.shape[0]
          ones_label = th.ones(batch_size, 1, device=device)
          zeros_label = th.zeros(batch_size, 1, device=device)

          # Discriminator: real (Ap codes) vs fake (SaO2 codes).  In eval mode
          # the SaO2 autoencoder is deterministic, so one pass serves both the
          # discriminator and the reconstruction loss.
          true = modelAp(Ap_data)
          reconstructions, embeddings_ = modelSaO2(SaO2_data)
          D_real = discriminador(true)
          D_fake = discriminador(embeddings_)

          D_loss_real = F.binary_cross_entropy(D_real, ones_label)
          D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
          D_loss = D_loss_real + D_loss_fake

          D_real_prob += D_real.mean().item()
          D_fake_prob += D_fake.mean().item()

          # Generator (adversarial) and reconstruction losses.
          loss_reconstruction = F.l1_loss(reconstructions, Ap_data)
          SaO2_loss = F.binary_cross_entropy(D_fake, ones_label)

          net_loss = loss_reconstruction + SaO2_loss

          loss_sum += net_loss.item()
          reconstruction_loss += loss_reconstruction.item()

          D_loss_sum += D_loss
          SaO2_loss_sum += SaO2_loss