In [1]:
import torch,math,joblib, os
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

import sys
sys.path.append('codes/Python/mimic-iii/imputation/diffusion/LSTS/Github')

from utils import load_data,create_sequences,ICPDataset
from util import calc_diffusion_step_embedding
from torch.optim.lr_scheduler import ExponentialLR
from util import calc_diffusion_hyperparams
from torch.utils.data import DataLoader
from tqdm import tqdm
# from vae import VAE1D
from typing import Tuple

import numpy as np
import matplotlib.pyplot as plt
from deap import base, creator, tools, algorithms
from fastdtw import fastdtw
import random
import multiprocessing
import warnings

# Load data

In [2]:
train_data, val_data, test_data = load_data(time_minutes=250, look_back_minutes=25, full=1, remote=1)

# Every split is windowed the same way: 2 context chunks before and after 12 missing chunks.
_sequence_kwargs = dict(num_prev_chunks=2, num_missing_chunks=12, num_next_chunks=2)
train_dataset, val_dataset, test_dataset = (
    ICPDataset(create_sequences(split, **_sequence_kwargs))
    for split in (train_data, val_data, test_data)
)

# Create data loaders; only the training loader shuffles.
batch_size = 32  # Adjust based on your hardware capabilities
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader, test_dataloader = (
    DataLoader(dataset, batch_size=batch_size) for dataset in (val_dataset, test_dataset)
)

In [None]:
# Print the fraction of all records that falls in each split (train, val, test).
# Summing the lengths avoids building a concatenated copy of the whole dataset,
# and also works when the splits are arrays rather than lists.
_split_sizes = [len(train_data), len(val_data), len(test_data)]
_total_size = sum(_split_sizes)
for _split_size in _split_sizes:
    print(_split_size / _total_size)

0.8970850070100894
0.03812972535932084
0.0647852676305898


# Define Models

In [3]:
def swish(x):
    """Swish (SiLU) activation, applied element-wise: x * sigmoid(x)."""
    gate = torch.sigmoid(x)
    return x * gate

def frequency_loss(x_reconstructed, condition, weight=1.0):
    """
    Mean-squared error between the magnitude spectra of two signals.

    Both inputs are transformed with a real FFT along the last axis. Phase is
    discarded and the MSE of the magnitudes is scaled by `weight`.

    Args:
        x_reconstructed: Reconstructed signal x_0, shape [B, 1, L].
        condition: Conditional wave (e.g., prev_context or next_context), shape [B, 1, L].
        weight: Multiplier applied to the loss.
    Returns:
        Scalar loss tensor.
    """
    recon_mag, cond_mag = (
        torch.fft.rfft(signal, dim=-1).abs() for signal in (x_reconstructed, condition)
    )
    return weight * F.mse_loss(recon_mag, cond_mag)
    
def reconstruct_signal(x_t, noise_pred, alpha_bar_t):
    """
    Reconstruct x_0 from x_t and the predicted noise (inverse of q(x_t | x_0)).

    x_0 = (x_t - sqrt(1 - alpha_bar_t) * noise_pred) / sqrt(alpha_bar_t)

    Args:
        x_t: Noisy input at time step t, shape [B, 1, L].
        noise_pred: Predicted noise, shape [B, 1, L].
        alpha_bar_t: Cumulative product of noise scales. It may be a Python
            number or a tensor broadcastable against x_t. Non-tensor values
            are converted to x_t's dtype and device; tensors are used as-is.
    Returns:
        Reconstructed signal x_0, shape [B, 1, L].
    """
    if not torch.is_tensor(alpha_bar_t):
        # A plain float has no .sqrt(); promote it so the formula works.
        alpha_bar_t = torch.as_tensor(alpha_bar_t, dtype=x_t.dtype, device=x_t.device)
    return (x_t - (1 - alpha_bar_t).sqrt() * noise_pred) / alpha_bar_t.sqrt()

def weighted_mse_loss(pred, target, weight, weights=1.0):
    """
    Mean of the element-wise squared error after multiplying by a weight mask.

    Args:
        pred: Predicted segment [B, 1, L]
        target: Ground truth segment [B, 1, L]
        weight: Weight mask, broadcast against the squared error [B, 1, L]
        weights: Scalar multiplier applied to the final mean (default 1.0).

    Returns:
        Weighted MSE loss (scalar tensor).
    """
    squared_error = (pred - target).pow(2)
    masked = squared_error * weight
    return weights * masked.mean()

def scale_signal(signal: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Center each signal on its mean and divide by half of its peak-to-peak range.

    After scaling, each signal has mean 0 and a peak-to-peak span of (almost
    exactly) 2. Centering uses the mean rather than the midrange, so the values
    are NOT confined to [-1, 1] in general. For example, [0, 0, 0, 10] maps to
    [-0.5, -0.5, -0.5, 1.5]. A 1e-8 term in the divisor keeps constant signals
    finite (they map to all zeros).

    Args:
        signal (torch.Tensor): Input signal of shape (batch_size, 1, signal_length).

    Returns:
        scaled_signal (torch.Tensor): Mean-centered signal divided by (range / 2 + 1e-8).
        original_mean (torch.Tensor): Per-signal mean, shape (batch_size, 1, 1).
        original_range (torch.Tensor): Per-signal max - min, shape (batch_size, 1, 1).
    """
    original_mean = signal.mean(dim=2, keepdim=True)
    centered = signal - original_mean
    original_range = centered.amax(dim=2, keepdim=True) - centered.amin(dim=2, keepdim=True)
    half_span = original_range / 2 + 1e-8
    return centered / half_span, original_mean, original_range

def unscale_signal(scaled_signal: torch.Tensor, original_mean: torch.Tensor, original_range: torch.Tensor) -> torch.Tensor:
    """
    Undo `scale_signal`: multiply by half the original range, then add the mean back.

    The 1e-8 stabiliser used when scaling is not re-applied here. The round
    trip is therefore exact only up to that tiny epsilon.

    Args:
        scaled_signal (torch.Tensor): Signal in the scaled (mean-0) domain.
        original_mean (torch.Tensor): Per-signal mean returned by `scale_signal`.
        original_range (torch.Tensor): Per-signal max - min returned by `scale_signal`.

    Returns:
        original_signal (torch.Tensor): Signal mapped back to the original units.
    """
    half_range = original_range / 2
    return scaled_signal * half_range + original_mean


def calc_batch_diffusion_step_embedding(diffusion_steps, embed_dims=(64, 64)):
    """
    Sinusoidally embed each column of diffusion_steps and concatenate the results.

    Parameters:
    diffusion_steps (torch.tensor, shape=(batch_size, len(embed_dims))):
                                Per-column diffusion steps for the batch (any device).
    embed_dims (tuple):         Embedding width for each column (default=(64, 64)).
                                Each column yields 2 * (embed_dim // 2) features
                                ([sin | cos] halves), so odd widths lose one feature.

    Returns:
    torch.tensor, shape=(batch_size, sum(2 * (d // 2) for d in embed_dims)):
                    Concatenated embeddings, on the same device as diffusion_steps.
    """
    batch_size, num_columns = diffusion_steps.shape
    assert len(embed_dims) == num_columns, "Embed dimensions must match the number of columns in diffusion_steps"

    log_base = torch.log(torch.tensor(10000.0))
    embeddings = []
    for col_idx, embed_dim in enumerate(embed_dims):
        col_steps = diffusion_steps[:, col_idx].unsqueeze(-1)  # (batch_size, 1)
        half_dim = embed_dim // 2
        # half_dim == 1 would divide by zero, giving 0 * -inf = NaN; use a single
        # unit frequency instead.
        denom = max(half_dim - 1, 1)
        # Frequencies are computed on the CPU (as before), then moved to the input's
        # device so CUDA inputs do not hit a device-mismatch error.
        freqs = torch.exp(torch.arange(half_dim) * -(log_base / denom)).to(diffusion_steps.device)
        angles = col_steps * freqs  # (batch_size, half_dim)
        embeddings.append(torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1))

    return torch.cat(embeddings, dim=-1)

class Conv(nn.Module):
    """Weight-normalised 1-D convolution with Kaiming-normal initialisation.

    With stride 1 the layer uses padding='same', which keeps the sequence length.
    PyTorch rejects 'same' for strided convolutions, so for stride > 1 the
    symmetric padding dilation * (kernel_size - 1) // 2 is used instead.
    The state_dict layout (conv.weight_g, conv.weight_v, conv.bias) is unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1):
        super(Conv, self).__init__()
        self.padding = dilation * (kernel_size - 1) // 2
        conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            stride=stride,
            padding='same' if stride == 1 else self.padding)
        # Initialise BEFORE applying weight_norm. Afterwards `conv.weight` is recomputed
        # from weight_g / weight_v by a forward pre-hook, so an in-place init on it
        # would be silently discarded.
        nn.init.kaiming_normal_(conv.weight)
        self.conv = nn.utils.weight_norm(conv)

    def forward(self, x):
        return self.conv(x)

class ZeroConv1d(nn.Module):
    """1x1 convolution whose weight and bias start at zero, so it initially outputs zeros."""

    def __init__(self, in_channel, out_channel):
        super(ZeroConv1d, self).__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
        for param in (self.conv.weight, self.conv.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        return self.conv(x)

In [4]:
class Residual_block(nn.Module):
    """One gated, dilated residual layer of the imputer.

    The hidden state is processed as follows:
    1. Shift it by a projection of the diffusion-step embedding.
    2. Modulate it with FiLM parameters (gamma, beta) generated from the
       chunk-number embedding. The generator ends in a ReLU, so gamma and beta
       are non-negative.
    3. Add a projection of the four stacked context chunks.
    4. Pass it through a dilated convolution and a sigmoid/tanh gate.

    forward() returns the residual output, scaled by sqrt(0.5), and a skip output.
    """

    def __init__(self, in_channels, res_channels, skip_channels, dilation, diffusion_embed_dim_out):
        super(Residual_block, self).__init__()
        # in_channels is not used by this block; it is accepted for interface symmetry.
        self.res_channels = res_channels

        # Submodules are registered in the same order as before: that order fixes the
        # random-initialisation sequence and the state_dict layout.
        self.fc_t = nn.Linear(diffusion_embed_dim_out, res_channels)  # step embedding -> per-channel shift
        self.dilated_conv = Conv(res_channels, res_channels, kernel_size=3, dilation=dilation)
        self.chunk_number_embedding = nn.Embedding(12, 16)  # indices 0..11 -> 16-d code
        self.cond_conv = Conv(4, res_channels, kernel_size=1)  # 4 stacked context chunks -> res_channels
        self.film_generator = nn.Sequential(
            nn.Conv1d(16, 2 * res_channels, kernel_size=1),
            nn.ReLU()
        )
        # Not used in forward(); kept so existing checkpoints keep the same keys.
        self.conditioner_projection_prev = nn.Conv1d(1, res_channels, kernel_size=1)
        self.conditioner_projection_next = nn.Conv1d(1, res_channels, kernel_size=1)
        self.output_projection = nn.Conv1d(res_channels, 2 * res_channels, kernel_size=1)
        self.residual_projection = nn.Conv1d(res_channels, res_channels, kernel_size=1)
        self.skip_projection = nn.Conv1d(res_channels, skip_channels, kernel_size=1)

    def forward(self, x, cond_ref, cond_ctl, diffusion_step_embed):
        """Apply the block.

        Args:
            x: hidden state [B, res_channels, L].
            cond_ref: 4-tuple (prev 1, prev 2, next 1, next 2). The tuple is stacked
                on dim 1 into the 4 input channels of cond_conv.
            cond_ctl: chunk-number indices for the embedding (cast to long).
            diffusion_step_embed: step embedding reshaped to [B, res_channels, 1]
                after fc_t.
        Returns:
            (new hidden state, skip output).
        """
        batch, channels, _ = x.shape
        assert channels == self.res_channels

        h = x + self.fc_t(diffusion_step_embed).view(batch, self.res_channels, 1)

        chunk_code = self.chunk_number_embedding(cond_ctl.long()).unsqueeze(-1)
        gamma, beta = torch.chunk(self.film_generator(chunk_code), 2, dim=1)
        h = gamma * h + beta  # FiLM modulation

        prev_1, prev_2, next_1, next_2 = cond_ref
        context = torch.stack([prev_1, prev_2, next_1, next_2], dim=1)
        h = h + swish(self.cond_conv(context))

        h = self.output_projection(swish(self.dilated_conv(h)))
        gate, filt = torch.chunk(h, 2, dim=1)
        gated = torch.sigmoid(gate) * torch.tanh(filt)

        return (x + self.residual_projection(gated)) * math.sqrt(0.5), self.skip_projection(gated)
          

class Residual_group(nn.Module):
    """Stack of `num_res_layers` Residual_block layers sharing one step-embedding MLP.

    Block n uses dilation 2 ** (n % dilation_cycle). The skip outputs of all
    blocks are summed and scaled by sqrt(1 / num_res_layers).
    """

    def __init__(self, in_channels, res_channels, skip_channels, num_res_layers, dilation_cycle, diffusion_embed_dim_in, diffusion_embed_dim_mid, diffusion_embed_dim_out):
        super(Residual_group, self).__init__()

        self.num_res_layers = num_res_layers

        # Two-layer MLP applied to the sinusoidal diffusion-step embedding.
        self.fc_t1 = nn.Linear(diffusion_embed_dim_in, diffusion_embed_dim_mid)
        self.fc_t2 = nn.Linear(diffusion_embed_dim_mid, diffusion_embed_dim_out)

        blocks = []
        for n in range(num_res_layers):
            blocks.append(Residual_block(
                in_channels,
                res_channels,
                skip_channels,
                dilation=2 ** (n % dilation_cycle),
                diffusion_embed_dim_out=diffusion_embed_dim_out
            ))
        self.residual_blocks = nn.ModuleList(blocks)

    def forward(self, x, cond_ref, cond_ctl, diffusion_steps):
        step_embed = calc_diffusion_step_embedding(diffusion_steps.float(), self.fc_t1.in_features)
        for layer in (self.fc_t1, self.fc_t2):
            step_embed = swish(layer(step_embed))

        skips = []
        for block in self.residual_blocks:
            x, skip = block(x, cond_ref, cond_ctl, step_embed)
            skips.append(skip)

        return sum(skips) * math.sqrt(1.0 / self.num_res_layers)

class DiffWaveImputer(nn.Module):
    """DiffWave-style noise predictor for chunk imputation.

    The pipeline has three stages:
    1. A 1x1 Conv + ReLU lifts the input to res_channels.
    2. A Residual_group computes the summed skip features.
    3. A 1x1 Conv + ReLU, followed by a zero-initialised 1x1 projection, maps the
       result to out_channels. Because the last layer starts at zero, the
       untrained model outputs zeros.
    """

    def __init__(self, in_channels, res_channels, skip_channels, out_channels, num_res_layers, dilation_cycle, diffusion_embed_dim_in, diffusion_embed_dim_mid, diffusion_embed_dim_out):
        super(DiffWaveImputer, self).__init__()

        self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU())

        self.residual_layer = Residual_group(
            in_channels=in_channels,
            res_channels=res_channels,
            skip_channels=skip_channels,
            num_res_layers=num_res_layers,
            dilation_cycle=dilation_cycle,
            diffusion_embed_dim_in=diffusion_embed_dim_in,
            diffusion_embed_dim_mid=diffusion_embed_dim_mid,
            diffusion_embed_dim_out=diffusion_embed_dim_out
        )

        self.final_conv = nn.Sequential(
            Conv(skip_channels, skip_channels, kernel_size=1),
            nn.ReLU(),
            ZeroConv1d(skip_channels, out_channels),
        )

    def forward(self, noise, cond_ref, cond_ctl, diffusion_steps):
        hidden = self.init_conv(noise)
        skip_sum = self.residual_layer(hidden, cond_ref, cond_ctl, diffusion_steps)
        return self.final_conv(skip_sum)

# Define Hyperparameters

In [None]:
# Fix the torch RNG (CPU and CUDA) so weight init, noise sampling and shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

diffwave = DiffWaveImputer(
    in_channels = 1,
    out_channels = 1,
    num_res_layers = 8,
    res_channels = 64,
    skip_channels = 64,
    dilation_cycle = 7,
    diffusion_embed_dim_in = 64,
    diffusion_embed_dim_mid = 128,
    diffusion_embed_dim_out = 128
).cuda()

if torch.cuda.device_count() > 1:
    print(f"Let's use {torch.cuda.device_count()} GPUs!")
    diffwave = nn.DataParallel(diffwave)
optimizer = torch.optim.Adam(diffwave.parameters(), lr=1e-3)  # , weight_decay=1e-4)
scheduler = ExponentialLR(optimizer, gamma=0.95)
epochs = 50

diffusion_config = {
    "T": 200,
    "beta_0": 0.0001,
    "beta_T": 0.02
}  # basic hyperparameters
diffusion_hyperparams = calc_diffusion_hyperparams(**diffusion_config)  # dictionary of all diffusion hyperparameters
# as_tensor avoids the copy-construct warning torch.tensor() raises when a value is already a tensor.
diffusion_hyperparams = {k: torch.as_tensor(v).cuda() for k, v in diffusion_hyperparams.items()}

# Training process

In [None]:
best_loss = float("inf")
loss_fn = nn.MSELoss()
overlap_length = 25
# Reverse diffusion costs T network evaluations per batch, so each epoch only the
# first `val_batches` validation batches are sampled.
val_batches = 1

T = diffusion_hyperparams["T"]
Alpha, Alpha_bar, Sigma = diffusion_hyperparams["Alpha"], diffusion_hyperparams["Alpha_bar"], diffusion_hyperparams["Sigma"]
num_diffusion_steps = int(T)


def _prepare_batch(prev_context_pc, next_context_pc, target_pc):
    """Move one batch to the GPU and scale the target and its four context chunks.

    Each input goes through unsqueeze(-1).permute(0, 2, 1) and `scale_signal`.
    Assuming the inputs are [B, L], this yields [B, 1, L] tensors (TODO confirm
    against ICPDataset). The previous/next context tensors are split into samples
    [0:250] and [250:500].

    Returns:
        audio: scaled target.
        target_mean, target_range: statistics needed by `unscale_signal`.
        cond_data: 4-tuple (prev 1, prev 2, next 1, next 2) of scaled context
            chunks with the channel axis squeezed out.
    """
    def _scale(chunk):
        return scale_signal(chunk.unsqueeze(-1).permute(0, 2, 1).to('cuda'))

    audio, target_mean, target_range = _scale(target_pc)
    context_chunks = (
        prev_context_pc[:, 0:250], prev_context_pc[:, 250:500],
        next_context_pc[:, 0:250], next_context_pc[:, 250:500],
    )
    cond_data = tuple(_scale(chunk)[0].squeeze(1) for chunk in context_chunks)
    return audio, target_mean, target_range, cond_data


@torch.no_grad()
def _sample(cond_data, chunk_number, size):
    """Draw x_0 (in the scaled domain) with the DDPM reverse process, starting from N(0, I)."""
    batch = size[0]
    x = torch.normal(0, 1, size=size).cuda()
    for t in range(num_diffusion_steps - 1, -1, -1):
        diffusion_steps = (t * torch.ones((batch, 1))).cuda()
        epsilon_theta = diffwave(x, cond_data, chunk_number, diffusion_steps.view(batch, 1))
        x = (x - (1 - Alpha[t]) / torch.sqrt(1 - Alpha_bar[t]) * epsilon_theta) / torch.sqrt(Alpha[t])
        if t > 0:
            x = x + Sigma[t] * torch.normal(0, 1, size=size).cuda()  # add the variance term to x_{t-1}
    return x


for i in range(epochs):

    diffwave.train()
    train_loss = 0

    for prev_context_pc, next_context_pc, target_pc, patient_idx, chunk_number, start_pos, enc_pos in tqdm(train_dataloader):
        # NOTE(review): only chunk_number[0] is passed on, so the whole batch is
        # conditioned on one chunk-number embedding — confirm this is intended.
        chunk_number = chunk_number[0].to('cuda')
        audio, _, _, cond_data = _prepare_batch(prev_context_pc, next_context_pc, target_pc)

        B, C, L = audio.shape  # B is batchsize, C=1, L is audio length
        diffusion_steps = torch.randint(num_diffusion_steps, size=(B, 1, 1)).cuda()  # uniform over 0..T-1

        z = torch.normal(0, 1, size=audio.shape).cuda()
        transformed_X = torch.sqrt(Alpha_bar[diffusion_steps]) * audio + torch.sqrt(1 - Alpha_bar[diffusion_steps]) * z  # compute x_t from q(x_t|x_0)

        optimizer.zero_grad()
        epsilon_theta = diffwave(transformed_X, cond_data, chunk_number, diffusion_steps.view(B, 1))  # predict epsilon
        total_loss = loss_fn(epsilon_theta, z)  # noise-prediction loss
        total_loss.backward()
        train_loss += total_loss.item()
        optimizer.step()
    scheduler.step()

    diffwave.eval()
    val_loss = 0.0
    n_val = 0
    for batch_idx, batch in enumerate(tqdm(val_dataloader, total=min(val_batches, len(val_dataloader)))):
        if batch_idx >= val_batches:
            break
        prev_context_pc, next_context_pc, target_pc, patient_idx, chunk_number, start_pos, enc_pos = batch
        chunk_number = chunk_number[0].to('cuda')
        audio, target_mean, target_range, cond_data = _prepare_batch(prev_context_pc, next_context_pc, target_pc)

        x = _sample(cond_data, chunk_number, audio.shape)
        reconstructed_unscale = unscale_signal(x, target_mean, target_range)
        target = target_pc.unsqueeze(-1).permute(0, 2, 1).to('cuda')
        # Compare each reconstruction with its own target. The old code broadcast
        # sample 0's reconstruction against every target in the batch.
        val_loss += loss_fn(reconstructed_unscale, target).item()
        n_val += 1
    val_loss /= max(n_val, 1)

    print("iteration: {} \t train_loss: {:.4f} \t val_loss: {:.4f} ".format(i, train_loss/len(train_dataloader), val_loss))
    torch.save(diffwave.state_dict(), "stable_diffusion_autodl.pt")
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(diffwave.state_dict(), "stable_diffusion_autodl_best.pt")

# Use PIDM to impute the missing segments of the validation set (`val_dataloader`)

In [None]:
# Ensure inference mode even if the training cell was skipped (e.g. weights reloaded).
diffwave.eval()

data_ori = []
data_imp = []
data_pre = []
data_next = []

T, Alpha, Alpha_bar, Sigma = diffusion_hyperparams["T"], diffusion_hyperparams["Alpha"], diffusion_hyperparams["Alpha_bar"], diffusion_hyperparams["Sigma"]
num_diffusion_steps = int(T)

for prev_context_pc, next_context_pc, target_pc, patient_idx, chunk_number, start_pos, enc_pos in tqdm(val_dataloader):
    # NOTE(review): only chunk_number[0] is passed on, so the whole batch is
    # conditioned on one chunk-number embedding — confirm this is intended.
    chunk_number = chunk_number[0].to('cuda')

    target_pc_scaled_signal, target_pc_original_mean, target_pc_original_range = scale_signal(target_pc.unsqueeze(-1).permute(0, 2, 1).to('cuda'))
    context_chunks = (
        prev_context_pc[:, 0:250], prev_context_pc[:, 250:500],
        next_context_pc[:, 0:250], next_context_pc[:, 250:500],
    )
    cond_data = tuple(
        scale_signal(chunk.unsqueeze(-1).permute(0, 2, 1).to('cuda'))[0].squeeze(1)
        for chunk in context_chunks
    )

    B, C, L = target_pc_scaled_signal.shape  # B is batchsize, C=1, L is audio length
    size = (B, 1, L)

    # DDPM reverse process from pure noise.
    with torch.no_grad():
        x = torch.normal(0, 1, size=size).cuda()
        for t in range(num_diffusion_steps - 1, -1, -1):
            diffusion_steps = (t * torch.ones((B, 1))).cuda()
            epsilon_theta = diffwave(x, cond_data, chunk_number, diffusion_steps.view(B, 1))
            x = (x - (1 - Alpha[t]) / torch.sqrt(1 - Alpha_bar[t]) * epsilon_theta) / torch.sqrt(Alpha[t])
            if t > 0:
                x = x + Sigma[t] * torch.normal(0, 1, size=size).cuda()  # add the variance term to x_{t-1}

    reconstructed_unscaled = unscale_signal(x, target_pc_original_mean, target_pc_original_range)

    # Accumulate on the CPU so GPU memory does not grow with the number of batches.
    data_ori.append(target_pc.unsqueeze(-1).permute(0, 2, 1).cpu())
    data_imp.append(reconstructed_unscaled.cpu())
    data_pre.append(prev_context_pc)
    data_next.append(next_context_pc)

100%|██████████| 5633/5633 [1:36:54<00:00,  1.03s/it]


In [None]:
# Keep the per-batch lists under *_t aliases, then join each one along the batch axis.
data_ori_t, data_imp_t, data_pre_t, data_next_t = data_ori, data_imp, data_pre, data_next
data_ori_, data_imp_, data_pre_, data_next_ = (
    torch.cat(batches, dim=0) for batches in (data_ori_t, data_imp_t, data_pre_t, data_next_t)
)

# Use DPA to reconstruct the sequences

In [None]:
# Rebuild each full ground-truth sequence (2 prev + 12 missing + 2 next chunks)
# by overlap-adding the chunks and averaging where they overlap.
data_ori_tensor = data_ori_[:, 0, :].cpu().detach()
data_imp_tensor = data_imp_[:, 0, :].cpu().detach()
data_pre_tensor = data_pre_[:, :].cpu().detach()
data_next_tensor = data_next_[:, :].cpu().detach()

num_segments = 12        # missing chunks per sequence
segment_length = 250     # samples per chunk
overlap = 25             # samples shared by consecutive chunks
num_context_chunks = 4   # 2 previous + 2 next context chunks
assert data_ori_tensor.size(0) % num_segments == 0, (
    f"{data_ori_tensor.size(0)} rows cannot be grouped into sequences of {num_segments} segments"
)
num_sequences = data_ori_tensor.size(0) // num_segments
num_chunks_total = num_segments + num_context_chunks
original_length = segment_length * num_chunks_total - overlap * (num_chunks_total - 1)  # 3625

# Group consecutive rows into sequences: (num_sequences, num_segments, ...).
data_ori_reshaped = data_ori_tensor.view(num_sequences, num_segments, -1)
data_pre_reshaped = data_pre_tensor.view(num_sequences, num_segments, -1)
data_next_reshaped = data_next_tensor.view(num_sequences, num_segments, -1)

# Only the context attached to segment 0 is used. Presumably all segments of a
# sequence share the same prev/next context — verify against create_sequences.
# Each [*, 2 * segment_length] context is split into (num_sequences, 2, segment_length).
data_pre_reshape = data_pre_reshaped[:, 0, :].view(num_sequences, 2, -1)
data_next_reshaped = data_next_reshaped[:, 0, :]  # rebinds the name to the 2-D slice, as before
data_next_reshape = data_next_reshaped.view(num_sequences, 2, -1)

# Chunk order along dim 1: prev context, missing segments, next context.
data_ori_reshape_cat = torch.cat((data_pre_reshape, data_ori_reshaped, data_next_reshape), dim=1)

# Overlap-add the chunks and count how many chunks cover each position.
reconstructed_ori = torch.zeros((num_sequences, original_length))
weights = torch.zeros(original_length)
segment_stride = segment_length - overlap
for i in range(num_chunks_total):
    start = i * segment_stride
    end = start + segment_length
    reconstructed_ori[:, start:end] += data_ori_reshape_cat[:, i, :]
    weights[start:end] += 1
weights = weights.unsqueeze(0)  # Broadcast across all sequences
reconstructed_ori /= weights    # Average the overlapping contributions

num_segments = 12
segment_length = 250
retain_length = 225  # Points kept from each imputed segment
compare_length = 25  # Leading points searched for the best transition
overlap = segment_length - retain_length
# The retained window starts at most at index compare_length, so it must still fit.
assert compare_length + retain_length <= segment_length


def _best_transition_index(compare_region, last_point, last_derivative):
    """Return the index in `compare_region` after which the next segment is kept.

    If the previous slope is positive, consider positions where the signal rises
    to the next sample. Otherwise, consider positions where it falls. Among those,
    pick the one whose value is closest to `last_point`. If none match, fall back
    to the closest point overall.
    """
    distance = (compare_region - last_point).abs()
    slopes = compare_region.diff()  # slopes[k] = region[k + 1] - region[k]
    matching = ((slopes > 0) if last_derivative > 0 else (slopes < 0)).nonzero(as_tuple=True)[0]
    if matching.numel() > 0:
        return matching[distance[matching].argmin()].item()
    return distance.argmin().item()


# Reshape tensors for processing
data_imp_reshaped = data_imp_tensor.view(-1, num_segments, segment_length)
num_sequences = data_imp_reshaped.size(0)
reconstructed_imp = torch.zeros((num_sequences, retain_length * num_segments))
# Stitch each sequence: keep retain_length (225) points per segment, starting just
# after the transition point that best continues the previous segment.
for seq in tqdm(range(num_sequences)):
    # Seed the level and slope with the last two samples of the preceding context chunk.
    last_point = data_pre_reshape[seq, 1, -1]
    last_second_point = data_pre_reshape[seq, 1, -2]
    last_derivatives = last_point - last_second_point
    start = 0
    for i in range(num_segments):
        segment = data_imp_reshaped[seq, i, :]
        compare_region = segment[:compare_length]
        best_transition_point = _best_transition_index(compare_region, last_point, last_derivatives)
        reconstructed_imp[seq, start:start + retain_length] = segment[best_transition_point + 1:best_transition_point + 1 + retain_length]
        # Update the level and slope for the next segment's comparison.
        last_point = reconstructed_imp[seq, start + retain_length - 1]
        last_second_point = reconstructed_imp[seq, start + retain_length - 2]
        last_derivatives = last_point - last_second_point
        start += retain_length
