# Model Design

In [None]:
import torch
from torch import nn
import sys
sys.path.append(".")
from src.models.model_utils import Conv_block, BLSTM, Up_conv
from src.utils.utils import get_config
from configs.seed import *


class Generator(nn.Module):
    """1-D encoder/decoder generator with skip connections.

    The network stacks ``depth`` ``Conv_block`` encoders (each followed by
    max-pooling and dropout), optionally runs a ``BLSTM`` on the bottleneck,
    and then applies ``depth`` decoder stages.  Every decoder stage runs an
    ``Up_conv``, concatenates the matching encoder output along the channel
    axis and applies another ``Conv_block``.  A final 1x1 convolution maps
    the result to ``chout`` channels.

    Args:
        cfgs: Nested config mapping.  Reads ``cfgs['model']['baseline']``
            keys ``depth``, ``use_se_block``, ``initial_features``, ``chin``,
            ``chout``, ``use_lstm`` and ``use_gated_block``, plus
            ``cfgs['data']['gn_input']``.

    NOTE(review): ``Conv_block``, ``Up_conv`` and ``BLSTM`` are defined in
    ``src.models.model_utils``.  This class presumably relies on
    ``Conv_block`` preserving the signal length and ``Up_conv`` doubling it,
    so the input length should be divisible by ``2 ** depth`` for the skip
    concatenation to line up - confirm against those helpers.
    """

    def __init__(self, cfgs):
        super(Generator, self).__init__()
        self.cfgs = cfgs
        self.depth = self.cfgs['model']['baseline']['depth']
        self.use_se_block = self.cfgs['model']['baseline']['use_se_block']
        self.initial_features = self.cfgs['model']['baseline']['initial_features']
        # Channel width doubles at every encoder level.
        self.features = [self.initial_features * (2**i) for i in range(self.depth)]
        self.img_ch = self.cfgs['model']['baseline']['chin']
        self.output_ch = self.cfgs['model']['baseline']['chout']
        self.use_lstm = self.cfgs['model']['baseline']['use_lstm']
        self.use_gated_block = self.cfgs['model']['baseline']['use_gated_block']
        self.gn_in = self.cfgs['data']['gn_input']

        if self.gn_in:
            # Single-group normalisation of the input and output signals.  The
            # channel counts follow the configured ``chin`` / ``chout`` instead
            # of a hard-coded 1, so multi-channel configs no longer fail in
            # ``forward``; for ``chin == chout == 1`` nothing changes.
            self.first_gn = nn.GroupNorm(1, self.img_ch)
            # Not referenced by ``forward`` in this class; kept so the module's
            # parameter names stay the same for existing checkpoints.
            self.film_gn = nn.GroupNorm(1, 1)
            self.last_gn = nn.GroupNorm(1, self.output_ch)

        # Encoding path
        self.encoders = nn.ModuleList()
        # Empty container, not populated or used in this class.
        self.encoder_films = nn.ModuleList()

        self.encoders.append(Conv_block(self.img_ch, self.features[0], use_gated_block=self.use_gated_block, use_se_block=self.use_se_block))
        for i in range(1, self.depth):
            self.encoders.append(Conv_block(self.features[i-1], self.features[i], use_gated_block=self.use_gated_block, use_se_block=self.use_se_block))

        # lstm layer
        if self.use_lstm:
            self.lstm = BLSTM(self.features[-1], bi = True)

        # Decoding path (built from the deepest level back to the top)
        self.up_convs = nn.ModuleList()
        self.decoders = nn.ModuleList()

        for i in range(self.depth-1, -1, -1):
            # Width after this stage: one level shallower, except at the top
            # level where it stays at features[0] (== initial_features).
            stage_ch = self.features[i-1] if i != 0 else self.features[0]
            self.up_convs.append(Up_conv(self.features[i], stage_ch, use_gated_block=self.use_gated_block))
            # The decoder block sees the skip connection (features[i] channels)
            # concatenated with the up-convolved tensor (stage_ch channels).
            self.decoders.append(Conv_block(self.features[i] + stage_ch, stage_ch, use_gated_block=self.use_gated_block, use_se_block=self.use_se_block))

        self.Maxpool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.Dropout = nn.Dropout(p=0.2)
        self.Conv_1x1 = nn.Conv1d(self.features[0], self.output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x, seg_film = None, mask_tensor=None):
        """Run the generator on a batch of signals.

        Args:
            x: Input tensor laid out as (batch, ``chin``, length), as required
                by the 1-D pooling and convolution layers used here.
            seg_film: Accepted for interface compatibility; not used by this
                implementation.
            mask_tensor: Accepted for interface compatibility; not used by
                this implementation.

        Returns:
            Tensor with ``chout`` channels produced by the final 1x1
            convolution (group-normalised when ``gn_input`` is enabled).
        """
        if self.gn_in:
            x = self.first_gn(x)

        # encoding path: remember each pre-pooling activation for the decoder
        skip_connections = []
        for encoder in self.encoders:
            x = encoder(x)
            skip_connections.append(x)
            x = self.Maxpool(x)
            x = self.Dropout(x)

        # lstm layer
        if self.use_lstm:
            # (batch, channels, length) -> (length, batch, channels) and back.
            x = x.permute(2, 0, 1)
            x, _ = self.lstm(x)
            x = x.permute(1, 2, 0)

        # decoding path: deepest skip connection is consumed first
        for up_conv, decoder, skip in zip(self.up_convs, self.decoders, reversed(skip_connections)):
            x = up_conv(x)
            x = torch.cat((skip, x), dim=1)
            x = self.Dropout(x)
            x = decoder(x)

        x = self.Conv_1x1(x)
        if self.gn_in:
            x = self.last_gn(x)

        return x


class Discriminator(nn.Module):
    """Convolutional critic that scores a 1-D signal with a single value.

    The signal goes through ``depth`` ``Conv_block`` encoders (each followed
    by max-pooling and dropout), is flattened, and is mapped to one value by
    a linear layer.  When ``cfgs['train']['hinge']`` is false a sigmoid is
    appended so the output can be used with a binary cross-entropy loss.

    Args:
        cfgs: Nested config mapping.  Reads the same
            ``cfgs['model']['baseline']`` keys as the generator, plus
            ``cfgs['data']['gn_input']`` and ``cfgs['train']['hinge']``.

    NOTE: the linear head expects exactly ``128 * 128`` (16384) flattened
    features per sample, so the configured width/depth and the input length
    must produce that size or ``forward`` fails with a shape error.
    """

    def __init__(self, cfgs):
        super(Discriminator, self).__init__()
        self.cfgs = cfgs
        self.depth = self.cfgs['model']['baseline']['depth']
        self.use_se_block = self.cfgs['model']['baseline']['use_se_block']
        self.initial_features = self.cfgs['model']['baseline']['initial_features']
        # Channel width doubles at every encoder level.
        self.features = [self.initial_features * (2**i) for i in range(self.depth)]
        self.img_ch = self.cfgs['model']['baseline']['chin']
        self.output_ch = self.cfgs['model']['baseline']['chout']
        self.use_lstm = self.cfgs['model']['baseline']['use_lstm']
        self.use_gated_block = self.cfgs['model']['baseline']['use_gated_block']
        self.gn_input = self.cfgs['data']['gn_input']
        self.use_hinge = self.cfgs['train']['hinge']

        self.encoders = nn.ModuleList()

        if self.gn_input:
            # Normalise over the configured number of input channels instead
            # of a hard-coded 1, so it matches the first encoder block below;
            # identical to the previous layer when ``chin == 1``.
            self.first_gn = nn.GroupNorm(1, self.img_ch)

        self.encoders.append(Conv_block(self.img_ch, self.features[0], use_gated_block=self.use_gated_block, use_se_block=self.use_se_block))

        for i in range(1, self.depth):
            self.encoders.append(Conv_block(self.features[i-1], self.features[i], use_gated_block=self.use_gated_block, use_se_block=self.use_se_block))

        self.Maxpool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.Dropout = nn.Dropout(p=0.2)

        # Hinge training uses the raw score; otherwise squash to (0, 1).
        head = [nn.Linear(128 * 128, 1)]
        if not self.use_hinge:
            head.append(nn.Sigmoid())
        self.adv_layer = nn.Sequential(*head)

    def forward(self, x):
        """Score a batch of signals.

        Args:
            x: Input tensor laid out as (batch, ``chin``, length).

        Returns:
            Tensor of shape (batch, 1): a raw score in hinge mode, otherwise
            a sigmoid output.
        """
        if self.gn_input:
            x = self.first_gn(x)

        for encoder in self.encoders:
            x = encoder(x)
            x = self.Maxpool(x)
            x = self.Dropout(x)

        # Collapse channels and length into one feature vector per sample.
        out = x.view(x.size(0), -1)
        out = self.adv_layer(out)

        return out


if __name__ == "__main__":
    # Smoke test: build both networks from the project config and print the
    # output shape each one produces for a random single-channel signal.
    cfgs = get_config()

    generator = Generator(cfgs)
    signal = torch.randn(1, 1, 128 * 8)
    # The same tensor is passed as ``seg_film``; the generator ignores it.
    generated = generator(signal, signal)
    print(generated.shape)

    discriminator = Discriminator(cfgs)
    signal = torch.randn(1, 1, 128 * 8)
    score = discriminator(signal)
    print(score.shape)

torch.Size([1, 1, 1024])
torch.Size([1, 1])


# Model Training

In [None]:
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
import warnings
warnings.simplefilter("ignore", UserWarning)
import sys
sys.path.append(".")
from src.metrics.metrics import CustomLoss, DisLoss, GenLoss
from src.dataloader.dataset import get_loader
from src.models.cpppg import Generator, Discriminator
# from src.models.model_utils import DiscriminatorV2
from src.utils.utils import plot_result, depadding
from src.utils.postprocess import moving_average_batch
import random
from torch.autograd import Variable
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
from configs.seed import *


class AdversarialTrainer:
    """Train a ``Generator`` against a ``Discriminator`` and track the run.

    Each training batch performs one generator update (reconstruction loss
    from ``CustomLoss`` plus a weighted adversarial term) followed by one
    discriminator update.  After every epoch the generator is evaluated on
    the validation and test loaders, the learning-rate schedulers are
    stepped on the validation loss, the best generator weights are saved and
    a sample prediction is plotted.

    Args:
        tracking: Experiment tracker.  Must provide ``train()``,
            ``validate()`` and ``test()`` context managers as well as
            ``log_metrics``, ``log_image`` and ``log_model``.
        cfgs: Nested config mapping.  This class reads
            ``cfgs['train']`` keys ``clip_value``, ``epoch_n``, ``ckpt``,
            ``lr`` and ``hinge``; the whole mapping is forwarded to the
            models, the loss and ``get_loader``.

    NOTE: ``clip_value`` is read from the config but no gradient clipping is
    applied anywhere in this class.
    """

    # Exclusive upper bound on the test-sample index drawn by ``show_result``.
    _PREVIEW_POOL = 60

    def __init__(self, tracking, cfgs):
        self.tracking = tracking
        self.cfgs = cfgs
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.clip_value = self.cfgs['train']['clip_value']
        self.epoch_n = self.cfgs['train']['epoch_n']
        self.BEST_LOSS = np.inf
        # Defined up front so ``save_checkpoint`` is safe to call before the
        # first validation pass (inf < inf is False, so nothing is saved).
        self.val_loss_epoch = np.inf
        self.ckpt = self.cfgs['train']['ckpt']

        self.model = Generator(cfgs).to(self.device)
        self.criterion = CustomLoss(self.cfgs).to(self.device)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), betas=(0.9, 0.999), weight_decay=0.005, lr=self.cfgs['train']['lr'])
        self.discriminator = Discriminator(cfgs).to(self.device)
        if self.cfgs['train']['hinge']:
            self.gen_loss = GenLoss().to(self.device)
            self.discriminator_loss = DisLoss().to(self.device)
        else:
            self.discriminator_loss = torch.nn.BCELoss().to(self.device)

        # The discriminator learns at a third of the generator's rate.
        self.discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), betas=(0.9, 0.999), weight_decay=0.005, lr=self.cfgs['train']['lr']/3)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.9, patience=8, verbose=True)
        self.d_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.discriminator_optimizer, factor=0.9, patience=8, verbose=True)

        self.train_loader, self.val_loader, self.test_loader = get_loader(self.cfgs)
        # self.load_weights()

    def load_weights(self):
        """Best-effort restore of the generator weights from ``self.ckpt``.

        A missing checkpoint file is the expected first-run case.  Any other
        failure (unreadable file, mismatching state dict, ...) is reported
        with its cause instead of being silently mislabelled, and training
        continues with the weights currently held by the model.
        """
        try:
            state = torch.load(self.ckpt, map_location=self.device)
            self.model.load_state_dict(state)
            # self.discriminator.load_state_dict(torch.load("checkpoints/discriminator.pth", map_location=self.device))
        except FileNotFoundError:
            print('FIRST TRAINING')
        except Exception as err:
            # ``Exception`` (not a bare except) so Ctrl-C / SystemExit still propagate.
            print(f"COULD NOT LOAD CHECKPOINT {self.ckpt!r}: {err!r} - CONTINUING WITH CURRENT WEIGHTS")
        else:
            print("SUCCESSFULLY LOAD TRAINED MODELS !")

    def train_epoch(self):
        """Run one adversarial training epoch over ``self.train_loader``.

        Returns:
            Tuple of per-batch means over the epoch: (total generator loss
            including the adversarial term, cosine loss, peak-to-peak loss,
            MSE loss).
        """
        self.model.train()
        self.discriminator.train()
        use_hinge = self.cfgs['train']['hinge']
        train_loss_epoch = 0
        train_cosin_loss_epoch = 0
        train_peak_to_peak_loss = 0
        train_mse_loss = 0
        # Weight of the adversarial term in the generator loss.
        theta = 0.1

        for src_signal, ref_signal, seg_film, mask in tqdm(self.train_loader):
            src_signal = src_signal.to(self.device)
            seg_film = seg_film.to(self.device)
            ref_signal = ref_signal.to(self.device)
            out = self.model(src_signal, seg_film, mask.to(self.device))

            # Real / fake targets for the BCE (non-hinge) losses.
            batch_size = src_signal.shape[0]
            valid = torch.ones(batch_size, 1, device=self.device)
            fake = torch.zeros(batch_size, 1, device=self.device)

            # ---- generator update ----
            self.optimizer.zero_grad()
            total_loss, cosin_loss, peak_to_peak_loss, mse_loss = self.criterion(out, ref_signal)

            if use_hinge:
                g_loss = self.gen_loss(self.discriminator(out))
            else:
                g_loss = self.discriminator_loss(self.discriminator(out), valid)

            # Out-of-place add: never mutates the tensor returned by the criterion.
            total_loss = total_loss + theta * g_loss
            train_loss_epoch += total_loss.item()
            train_cosin_loss_epoch += cosin_loss.item()
            train_peak_to_peak_loss += peak_to_peak_loss.item()
            train_mse_loss += mse_loss.item()
            total_loss.backward()
            self.optimizer.step()

            # ---- discriminator update ----
            # zero_grad also discards the gradients the generator step left
            # on the discriminator parameters.
            self.discriminator_optimizer.zero_grad()
            if use_hinge:
                d_loss = self.discriminator_loss(self.discriminator(out.detach()), self.discriminator(ref_signal))
            else:
                real_loss = self.discriminator_loss(self.discriminator(ref_signal), valid)
                fake_loss = self.discriminator_loss(self.discriminator(out.detach()), fake)
                d_loss = (real_loss + fake_loss) / 2

            # The generator output is detached and this graph is used once,
            # so there is no reason to retain it.
            d_loss.backward()
            self.discriminator_optimizer.step()

        n_batches = len(self.train_loader)
        return train_loss_epoch / n_batches, train_cosin_loss_epoch / n_batches, train_peak_to_peak_loss / n_batches, train_mse_loss / n_batches

    def _evaluate(self, loader):
        """Return the mean (total, cosine, peak-to-peak, MSE) losses of the generator over ``loader``."""
        self.model.eval()
        loss_sum = 0
        cosin_sum = 0
        peak_to_peak_sum = 0
        mse_sum = 0
        with torch.no_grad():
            for src_signal, ref_signal, seg_film, mask in tqdm(loader):
                src_signal = src_signal.to(self.device)
                seg_film = seg_film.to(self.device)
                out = self.model(src_signal, seg_film, mask.to(self.device))
                ref_signal = ref_signal.to(self.device)
                total_loss, cosin_loss, peak_to_peak_loss, mse_loss = self.criterion(out, ref_signal)
                loss_sum += total_loss.item()
                cosin_sum += cosin_loss.item()
                peak_to_peak_sum += peak_to_peak_loss.item()
                mse_sum += mse_loss.item()
        n_batches = len(loader)
        return loss_sum / n_batches, cosin_sum / n_batches, peak_to_peak_sum / n_batches, mse_sum / n_batches

    def val_epoch(self):
        """Evaluate on the validation loader.

        Also stores the mean total loss in ``self.val_loss_epoch`` (the same
        value that is returned and logged), which ``save_checkpoint`` compares
        against ``self.BEST_LOSS``.

        Returns:
            Tuple of means: (total loss, cosine loss, peak-to-peak loss, MSE loss).
        """
        means = self._evaluate(self.val_loader)
        self.val_loss_epoch = means[0]
        return means

    def test_epoch(self):
        """Evaluate on the test loader.

        Returns:
            Tuple of means: (total loss, cosine loss, peak-to-peak loss, MSE loss).
        """
        return self._evaluate(self.test_loader)

    def show_result(self, epoch):
        """Plot the generator output for one random test sample and log the image.

        The sample index is drawn from the first ``_PREVIEW_POOL`` test
        samples, or from the whole test set when it is smaller than that.

        Args:
            epoch: Epoch number, used only in the logged image name.
        """
        dataset = self.test_loader.dataset
        # Cap the pool at the dataset size so small test sets do not raise IndexError.
        random_idx = random.randrange(min(self._PREVIEW_POOL, len(dataset)))
        # Fetch the sample once so all four fields come from the same __getitem__ call.
        sample = dataset[random_idx]
        src_signal, ref_signal, seg_film, mask = sample[0], sample[1], sample[2], sample[3]
        src_signal = src_signal.unsqueeze(dim=0)
        mask = mask.unsqueeze(dim=0)
        # Inference only: no autograd graph is needed for a preview plot.
        with torch.no_grad():
            out = self.model(src_signal.to(self.device), seg_film.unsqueeze(dim=0).to(self.device), mask.to(self.device))
        real_sr_signal = src_signal.clone()
        ref_signal = ref_signal.reshape(-1).cpu().detach().numpy()
        plot_result(real_sr_signal.reshape(-1), out.reshape(-1).cpu().detach().numpy(), ref_signal.reshape(-1))
        # Hard-coded path - presumably the file written by plot_result; confirm in src.utils.utils.
        self.tracking.log_image(image_data="src/experiments/results/result.jpg", name=f"Result at epoch: {epoch}")

    def save_checkpoint(self):
        """Save and log the generator weights when the validation loss improved."""
        if self.val_loss_epoch < self.BEST_LOSS:
            self.BEST_LOSS = self.val_loss_epoch
            torch.save(self.model.state_dict(), self.ckpt)
            # torch.save(self.discriminator.state_dict(), "checkpoints/discriminator.pth")
            self.tracking.log_model("model", self.ckpt)

    def training_experiment(self):
        """Run the full train / validate / test loop for ``self.epoch_n`` epochs."""
        print("BEGIN TRAINING ...")
        for epoch in range(1, self.epoch_n+1):
            with self.tracking.train():
                train_loss_epoch, train_cosin_loss_epoch, peak_to_peak_loss, train_mse_loss = self.train_epoch()
                self.tracking.log_metrics({
                    "total loss": train_loss_epoch,
                    "mse loss": train_mse_loss,
                    "cosin similarity": train_cosin_loss_epoch,
                    "peak-to-peak loss": peak_to_peak_loss
                }, epoch=epoch)

            with self.tracking.validate():
                val_loss_epoch, val_cosin_loss_epoch, peak_to_peak_loss, val_mse_loss = self.val_epoch()
                # Both schedulers are driven by the generator's validation loss.
                self.scheduler.step(val_loss_epoch)
                self.d_scheduler.step(val_loss_epoch)
                self.save_checkpoint()
                self.tracking.log_metrics({
                    "total loss": val_loss_epoch,
                    "mse loss": val_mse_loss,
                    "cosin similarity": val_cosin_loss_epoch,
                    "peak-to-peak loss": peak_to_peak_loss
                }, epoch=epoch)

            with self.tracking.test():
                test_loss_epoch, test_cosin_loss_epoch, test_peak_to_peak_loss, test_mse_loss = self.test_epoch()
                self.tracking.log_metrics({
                    "total loss": test_loss_epoch,
                    "mse loss": test_mse_loss,
                    "cosin similarity": test_cosin_loss_epoch,
                    "peak-to-peak loss": test_peak_to_peak_loss
                }, epoch=epoch)

            self.show_result(epoch)

            print("EPOCH: ", epoch, " - TRAIN_LOSS: ", train_loss_epoch, " || VAL_LOSS: ", val_loss_epoch, " || TEST_LOSS: ", test_loss_epoch)


In [None]:
from comet_ml import Experiment
import sys
sys.path.append(".")
from src.utils.utils import get_config
from configs.seed import *


if __name__=="__main__":
    # Entry point: load the project config, open a Comet experiment and run
    # the adversarial training loop.
    cfgs = get_config()

    # Strip surrounding whitespace: a key file saved by an editor normally
    # ends with a newline, which must not become part of the API key.
    with open('configs/experiment_apikey.txt','r') as f:
        api_key = f.read().strip()

    tracking = Experiment(
        api_key = api_key,
        project_name = "PPG Data v2 - Window Based",
        workspace = "maxph2211",
    )
    tracking.log_parameters(cfgs)
    trainer = AdversarialTrainer(tracking, cfgs)

    trainer.training_experiment()
    print("DONE!")