In [1]:
import os
from collections import defaultdict

import numpy as np
import torch
from omegaconf import OmegaConf

from models.ddpm_models import ContextUnet, DDPM
from utils.metric_dataloader import MetricDataPreprocessor
from utils.tshae_utils import load_tshae_model

In [2]:
class LatentEncoder:
    """Encode individual runs into the latent space of a trained TSHAE model.

    The constructor loads the OmegaConf YAML stored under
    ``<checkpoint_path>/.hydra/config.yaml``, the model weights stored in
    ``<checkpoint_path>/tshae_best_model.pt`` and the train/test/val datasets
    produced by ``MetricDataPreprocessor`` from the ``data_preprocessor``
    section of that config.

    Attributes:
        device: ``device`` as passed in, or ``"cpu"`` when CUDA is unavailable.
        config: the loaded OmegaConf configuration.
        model: the loaded TSHAE model, moved to ``device``.
        train_dataset, test_dataset, val_dataset: datasets returned by
            ``MetricDataPreprocessor.get_datasets()`` (in that order).
    """

    def __init__(self, checkpoint_path: str = 'best_models/FD003/tshae/', device: str ='cuda'):
        """Load config, model and datasets from ``checkpoint_path``.

        Args:
            checkpoint_path: directory holding ``.hydra/config.yaml`` and
                ``tshae_best_model.pt``; a trailing slash is optional.
            device: preferred torch device; ignored in favour of ``"cpu"``
                when ``torch.cuda.is_available()`` is false.
        """
        # os.path.join keeps the paths valid when the directory is given
        # without a trailing slash (plain string concatenation did not).
        config_path = os.path.join(checkpoint_path, ".hydra/config.yaml")
        model_path = os.path.join(checkpoint_path, "tshae_best_model.pt")
        self.device = device if torch.cuda.is_available() else "cpu"
        self.config = self._get_config(config_path=config_path)
        print(model_path)
        self.model = load_tshae_model(model_path=model_path).to(self.device)
        self.train_dataset, self.test_dataset, self.val_dataset = self._get_datasets()

    def _get_config(self, config_path: str):
        """Return the OmegaConf configuration stored at ``config_path``."""
        return OmegaConf.load(config_path)

    # Backward-compatible alias for the original (misspelled) method name.
    _get_congig = _get_config

    def _get_datasets(self):
        """Build and return ``(train, test, val)`` datasets from the config."""
        preproc = MetricDataPreprocessor(**self.config.data_preprocessor)
        train_dataset, test_dataset, val_dataset = preproc.get_datasets()
        return train_dataset, test_dataset, val_dataset

    def get_run_ids(self):
        """Return the ``ids`` attribute of every split.

        Returns:
            dict with keys ``'train_ids'``, ``'test_ids'`` and ``'val_ids'``.
        """
        return {
            'train_ids': self.train_dataset.ids,
            'test_ids': self.test_dataset.ids,
            'val_ids': self.val_dataset.ids,
        }

    def encode_run_id(self, run_id: int, split: str = None) -> np.ndarray:
        """Encode one run and return its latent representation.

        Args:
            run_id: id of the run to encode.
            split: optional split name, one of ``'train'``, ``'test'`` or
                ``'val'``. Run ids are not unique across splits, so without
                this argument a run id shared by several splits always
                resolves to the first match in train -> test -> val order
                (the original behaviour, kept as the default).

        Returns:
            The second value returned by the model (``z``) as a NumPy array
            on the CPU.

        Raises:
            ValueError: if ``split`` is not a known split name.
            KeyError: if ``run_id`` is not present in the searched split(s).
        """
        datasets = {
            'train': self.train_dataset,
            'test': self.test_dataset,
            'val': self.val_dataset,
        }
        if split is None:
            candidates = list(datasets.values())
        elif split in datasets:
            candidates = [datasets[split]]
        else:
            raise ValueError(
                f"Unknown split {split!r}; expected one of {sorted(datasets)}"
            )

        for dataset in candidates:
            if run_id in dataset.ids:
                # get_run also returns the true RUL, which is not needed here.
                x, _true_rul = dataset.get_run(run_id)
                break
        else:
            raise KeyError('No such run_id in datasets!')

        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            _rul_hat, z, *_rest = self.model(x.to(self.device))
        return z.detach().cpu().numpy()

# Build the encoder from its default checkpoint directory
# ('best_models/FD003/tshae/'); the constructor prints the model path.
le = LatentEncoder()
# Show which run ids belong to the train, test and val splits.
print(le.get_run_ids())


best_models/FD003/tshae/tshae_best_model.pt
{'train_ids': array([ 1,  2,  3,  5,  6,  7,  8,  9, 11, 13, 14, 15, 16, 17, 19, 20, 21,
       23, 24, 25, 26, 27, 28, 29, 32, 34, 35, 36, 37, 38, 40, 41, 42, 43,
       46, 47, 48, 49, 50, 51, 52, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
       64, 65, 66, 67, 68, 69, 71, 72, 74, 75, 78, 79, 81, 82, 84, 85, 86,
       87, 88, 89, 91, 92, 93, 94, 95, 96, 97, 98, 99]), 'test_ids': array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,
        27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
        79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
        92,  93,  94,  95,  96,  97,  98,  99, 100]), 'val_ids': arra

In [3]:
# Encode run 1 into the latent space. Run id 1 is listed under both
# 'train_ids' and 'test_ids'; encode_run_id checks the train split first,
# so this is the train run.
z = le.encode_run_id(1)

In [4]:
class DiffusionGen:
    """Generate sensor windows from latent points with a conditional DDPM.

    A trained TSHAE model is used twice: its decoder reconstructs windows
    directly from the latent points, and its full forward pass re-encodes the
    DDPM samples so that the sample closest to each latent point can be chosen.

    Attributes:
        device: ``device`` as passed in, or ``"cpu"`` when CUDA is unavailable.
        config: OmegaConf configuration loaded from
            ``<checkpoint_path>/.hydra/config.yaml``.
        ddpm_model: the DDPM in eval mode, moved to ``device``.
        tshae_model: the TSHAE model loaded from ``tshae_model_path``.
    """

    # Size of one sample requested from the DDPM sampler (passed as ``size``).
    SAMPLE_SIZE = (1, 32, 32)
    # Number of leading columns of each generated sample that are kept.
    N_SENSORS = 21

    def __init__(self, checkpoint_path = 'best_models/FD003/ddpm/', tshae_model_path='best_models/FD003/tshae/tshae_best_model.pt', device='cuda'):
        """Load the config, the DDPM and the TSHAE model.

        Args:
            checkpoint_path: directory holding ``.hydra/config.yaml``; a
                trailing slash is optional.
            tshae_model_path: path of the TSHAE checkpoint.
            device: preferred torch device; replaced by ``"cpu"`` when
                ``torch.cuda.is_available()`` is false.
        """
        # os.path.join keeps the path valid without a trailing slash.
        config_path = os.path.join(checkpoint_path, ".hydra/config.yaml")
        self.device = device if torch.cuda.is_available() else "cpu"
        self.config = self._get_config(config_path=config_path)
        self.ddpm_model = self._get_model().to(self.device)
        self.tshae_model = load_tshae_model(model_path=tshae_model_path).to(self.device)

    def _get_config(self, config_path):
        """Return the OmegaConf configuration stored at ``config_path``."""
        return OmegaConf.load(config_path)

    # Backward-compatible alias for the original (misspelled) method name.
    _get_congig = _get_config

    def _get_model(self):
        """Build the DDPM from the config and load its trained weights.

        Returns:
            The DDPM in eval mode (not yet moved to ``self.device``).
        """
        train_cfg = self.config.diffusion.ddpm_train
        drop_prob = self.config.diffusion.ddpm_model.drop_prob

        ddpm = DDPM(
            nn_model=ContextUnet(
                in_channels=1,
                n_feat=train_cfg.n_feat,
                z_dim=train_cfg.z_dim,
            ),
            betas=(1e-4, 0.02),
            n_T=train_cfg.n_T,
            device=self.device,
            drop_prob=drop_prob,
        )
        # NOTE(review): weights come from the path stored in the config, not
        # from ``checkpoint_path`` given to the constructor - confirm that
        # this is intended.
        # map_location lets a checkpoint saved on a GPU be loaded when this
        # object has fallen back to the CPU.
        state_dict = torch.load(
            self.config.diffusion.checkpoint_ddpm.path,
            map_location=self.device,
        )
        ddpm.load_state_dict(state_dict)
        ddpm.eval()
        return ddpm

    @staticmethod
    def _select_samples(distances, mode, quantile):
        """Pick one sample index per latent point.

        Args:
            distances: tensor of shape ``[num_z, num_samples]``.
            mode: ``"quantile"`` draws uniformly among the samples whose
                distance is strictly below the per-row ``quantile``; any
                other value selects the nearest sample.
            quantile: quantile level in ``[0, 1]`` used in quantile mode.

        Returns:
            int64 tensor of shape ``[num_z]`` on ``distances.device``.
        """
        nearest = torch.argmin(distances, dim=-1)
        if mode != "quantile":
            return nearest

        limits = torch.quantile(distances, quantile, dim=1, keepdim=True, interpolation='linear')
        below_limit = (distances < limits).cpu().numpy()
        picks = nearest.cpu().numpy().copy()
        for row, mask in enumerate(below_limit):
            candidates = np.flatnonzero(mask)
            # No distance is strictly below the limit when quantile == 0 or
            # when distances tie; keep the nearest sample instead of failing.
            if candidates.size:
                picks[row] = np.random.choice(candidates, 1)[0]
        return torch.from_numpy(picks).to(distances.device)

    @staticmethod
    def _stitch_windows(windows):
        """Join windows into one sequence: the whole first window followed by
        the last row of every later window.

        Assumes consecutive windows are shifted by one time step - TODO confirm.
        """
        return np.concatenate([windows[0], windows[1:, -1, :]], axis=0)

    def generate_from_latent(self, z_space, num_samples=4, w=0.5, quantile=0.25, mode='best'):
        """Generate sensor windows for a trajectory of latent points.

        For every latent point ``num_samples`` windows are drawn from the
        DDPM, re-encoded with the TSHAE model and one of them is kept
        according to ``mode``.

        Args:
            z_space: array-like or tensor of shape ``[num_z, z_dim]``.
            num_samples: number of DDPM samples drawn per latent point.
            w: guidance weight forwarded to the sampler as ``guide_w``.
            quantile: quantile level used when ``mode == "quantile"``.
            mode: ``"quantile"`` for a random pick among the closest samples;
                any other value keeps the closest sample.

        Returns:
            ``defaultdict`` with NumPy arrays under the keys ``'z'``,
            ``'z_diff'``, ``'x_diff_samples'``, ``'rul_hat_diff'``,
            ``'sensors_diff_reconstructed'`` and
            ``'sensors_tshae_reconstructed'``.

        Raises:
            ValueError: if ``z_space`` is not a non-empty 2-D array.
        """
        history = defaultdict(dict)

        with torch.no_grad():
            z = torch.as_tensor(z_space, dtype=torch.float32, device=self.device)
            if z.dim() != 2 or z.shape[0] == 0:
                raise ValueError(
                    f"z_space must be a non-empty 2-D array, got shape {tuple(z.shape)}"
                )
            num_z = z.shape[0]

            x_tshae_samples = self.tshae_model.decoder(z)
            x_hat_diffusion, _ = self.ddpm_model.sample_cmapss(
                n_sample=num_samples,
                size=self.SAMPLE_SIZE,
                device=z.device,
                z_space_contexts=z,
                guide_w=w,
            )
            # Drop the channel axis and keep only the sensor columns.
            x_hat_diffusion = x_hat_diffusion.squeeze(1)[:, :, :self.N_SENSORS]

            # Re-encode the generated windows to compare them in latent space.
            rul_hat_diff, z_diff, *_ = self.tshae_model(x_hat_diffusion)
            # The sampler output is treated as sample-major, hence the
            # reshape to [num_samples, num_z, ...] before swapping the axes.
            z_diff = z_diff.reshape(num_samples, num_z, -1).permute(1, 0, 2)
            rul_hat_diff = rul_hat_diff.reshape(num_samples, num_z, 1).permute(1, 0, 2)

            # [num_z, num_samples]; the sample axis is kept even when
            # num_samples == 1 so that argmin always runs over samples.
            distances = torch.norm(z.unsqueeze(1) - z_diff, dim=-1)
            best_samples = self._select_samples(distances, mode, quantile)

            window_shape = x_hat_diffusion.shape[1:]
            x_diff_samples = x_hat_diffusion.reshape(num_samples, num_z, *window_shape).permute(1, 0, 2, 3)

            rows = torch.arange(num_z, device=distances.device)
            x_diff_samples = x_diff_samples[rows, best_samples]
            z_diff = z_diff[rows, best_samples]
            rul_hat_diff = rul_hat_diff[rows, best_samples]

            x_diff_np = x_diff_samples.detach().cpu().numpy()
            x_tshae_np = x_tshae_samples.detach().cpu().numpy()

            history['z'] = z.squeeze().detach().cpu().numpy()
            history["z_diff"] = z_diff.detach().cpu().numpy()
            history["x_diff_samples"] = x_diff_np
            history["rul_hat_diff"] = rul_hat_diff.detach().cpu().numpy()
            history["sensors_diff_reconstructed"] = self._stitch_windows(x_diff_np)
            history["sensors_tshae_reconstructed"] = self._stitch_windows(x_tshae_np)
        return history

# Build the generator from its default DDPM and TSHAE checkpoints.
dg = DiffusionGen()
# Generate windows for the encoded run with the defaults
# (num_samples=4, w=0.5, mode='best').
h = dg.generate_from_latent(z)


sampling timestep 100

In [5]:
# Compare the latent trajectory length with the stitched reconstructions:
# the first window contributes all 32 rows and every later latent point adds
# one row, so 228 latent points give 32 + 227 = 259 rows.
print(z.shape)
print(h['sensors_diff_reconstructed'].shape)
print(h['sensors_tshae_reconstructed'].shape)

(228, 2)
(259, 21)
(259, 21)
