In [7]:
# Imports and global settings for sampling latents from a trained diffusion model.
# NOTE(review): sys, os, plt, DataLoader and lightning (lt) are not used in the
# cells visible here — drop them if nothing else in the notebook needs them.
import sys
# NOTE(review): leftover absolute local path; remove it or make it configurable.
#sys.path.append('/home/npopkov/dll24')
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch
# Let float32 matmuls trade precision for speed (see the
# torch.set_float32_matmul_precision docs for what 'medium' permits).
torch.set_float32_matmul_precision('medium')

from denoising_diffusion_pytorch import Unet, GaussianDiffusion

import lightning as lt

import tables
# Close every file still tracked by PyTables' private open-files registry,
# presumably handles left over from earlier runs in this kernel — TODO confirm
# this is still needed: the HDF5 file below is opened with h5py, not PyTables.
tables.file._open_files.close_all()

class LatentDataset(Dataset):
    """In-memory torch ``Dataset`` of equally-shaped arrays read from an HDF5-like mapping.

    Every entry of ``h5_file`` (anything exposing ``keys()`` and ``__getitem__``,
    e.g. an open ``h5py.File``) is loaded eagerly into one float32 tensor of
    shape ``(num_entries, *sample_shape)``. Global statistics (min, max, mean,
    std) are computed once over the raw data, and the data is then min-max
    normalized.

    Attributes:
        shape: shape of a single sample (taken from the first entry).
        data: float32 tensor holding all samples, in the currently applied
            transformation.
        min, max, std, mean: 0-d tensors with statistics of the raw data.
    """

    # Name of the transformation currently applied to ``self.data``;
    # None means the data is still raw.
    _applied = None

    def __init__(self, h5_file):
        keys = list(h5_file.keys())
        if not keys:
            raise ValueError('h5_file contains no datasets')
        # np.shape reads ``.shape`` directly when the object has it, so the first
        # dataset is not loaded into memory just to learn its shape.
        self.shape = np.shape(h5_file[keys[0]])
        self.data = self.createData(h5_file)
        self.min = self.data.min()
        self.max = self.data.max()
        self.std = self.data.std()
        self.mean = self.data.mean()
        self.transform('normalize')

    def createData(self, h5_file):
        """Load every entry of ``h5_file`` and stack them into one float32 tensor.

        Returns:
            Tensor of shape ``(number_of_keys, *sample_shape)``.

        Raises:
            ValueError: (from ``np.stack``) if the entries differ in shape or
                there are none.
        """
        samples = [np.asarray(h5_file[key]) for key in h5_file.keys()]
        return torch.as_tensor(np.stack(samples)).float()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def _params(self, kind):
        """Return the ``(shift, scale)`` pair used by transformation ``kind``.

        Raises:
            ValueError: if ``kind`` is not 'normalize' or 'standardize'.
        """
        if kind == 'normalize':
            shift, scale = self.min, self.max - self.min
        elif kind == 'standardize':
            shift, scale = self.mean, self.std
        else:
            raise ValueError('Unknown transformation type')
        # Constant data (zero range/std) or a single element (NaN std) would
        # otherwise produce inf/NaN; fall back to a pure shift in that case.
        if not bool(scale > 0):
            scale = torch.ones_like(scale)
        return shift, scale

    def transform(self, type: str = 'normalize'):
        """Re-express ``self.data`` using transformation ``type``.

        ``type`` is 'normalize' (min-max to [0, 1]) or 'standardize' (zero mean,
        unit std). Any previously applied transformation is undone first, so
        the result is always computed from the raw data and repeated calls are
        idempotent.

        Raises:
            ValueError: if ``type`` is unknown (``self.data`` is left untouched).
        """
        shift, scale = self._params(type)  # validate before mutating data
        raw = self.inverse_transform(self.data)
        self.data = (raw - shift) / scale
        self._applied = type

    def inverse_transform(self, data):
        """Map ``data`` from the currently applied transformation back to raw units."""
        if self._applied is None:
            return data
        shift, scale = self._params(self._applied)
        return data * scale + shift

    def unflatten(self, data):
        """Reshape ``data`` to the shape of a single sample."""
        return data.reshape(self.shape)
    

# Load every latent from the HDF5 file into memory. The `with` block closes the
# file even if LatentDataset raises; closing afterwards is safe because
# LatentDataset copies all data into a tensor in its constructor.
LATENTS_PATH = '256encodesamp.hdf5'
with h5py.File(LATENTS_PATH, 'r') as hdf:
    dataset = LatentDataset(hdf)



In [8]:

# Rebuild the U-Net with the hyperparameters the checkpoint was trained with
# (they must match for load_state_dict to accept the weights), then restore it.
model = Unet(
    dim = 64,
    channels = 1,
    dim_mults = (1, 2, 4),
    flash_attn = False,
)

# map_location='cpu' lets a checkpoint saved from a GPU load on any machine;
# the model is moved to the target device in a later cell.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load('model.pt', map_location='cpu'))

# NOTE(review): image_size must agree with the spatial size of the training
# latents — assumed to be 256 here; confirm against dataset.shape.
diffusion = GaussianDiffusion(
    model,
    image_size = 256,
    timesteps = 1000,  # number of diffusion steps
)

In [9]:
# Use the GPU when one is available, otherwise fall back to the CPU
# (sampling still works there, just much more slowly).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
diffusion.to(DEVICE)
diffusion.device

device(type='cuda', index=0)

In [10]:
# Draw a batch of samples from the trained model. Seeding the global torch RNG
# makes the draw reproducible under Restart & Run All.
SEED = 0
N_SAMPLES = 6
torch.manual_seed(SEED)
diffusion.eval()
sample = diffusion.sample(batch_size=N_SAMPLES)

sampling loop time step: 100%|██████████| 1000/1000 [01:34<00:00, 10.63it/s]


In [11]:
# Map the samples back to the raw latent scale and move them to the CPU before
# saving, so 'sample.pt' can be loaded on machines without a GPU.
samples_raw = dataset.inverse_transform(sample.detach().cpu())
torch.save(samples_raw, 'sample.pt')