In [3]:
import torch
from torch.utils.data import DataLoader
import os

In [4]:
import concurrent.futures
import itertools

In [8]:
class Fitting():
    def __init__(
        self,
        datatype,
        default_shift=1,
    ):
        self.device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
        print(f'Using device: {self.device}')
        self.datatype = datatype
        if self.datatype=='SHO':
            from data_sho import damped_sho_np as func
            from data_sho import DataGenerator
        elif self.datatype=='SineGaussian':
            from data_sinegaussian import sine_gaussian_np as func
            from data_sinegaussian import DataGenerator
        elif self.datatype=='LIGO':
            pass # TODO: Implement LIGO data handling
        else:
            raise ValueError(f'Unknown datatype: {self.datatype}')
        self.func = func
        self.datadir = f'/ceph/submit/data/user/k/kyoon/KYoonStudy/ssm_regression/{self.datatype}'
        self.savedir = f'/ceph/submit/data/user/k/kyoon/KYoonStudy/ssm_regression/fitresults'
        self.modeldir = os.path.join(self.datadir, 'models')
        self.test_dict = torch.load(os.path.join(self.datadir, 'test.pt'))
        self.test_data = DataGenerator(self.test_dict)
        self.test_dataloader = DataLoader(
            self.test_data,
            batch_size=1,
            shuffle=False
        )
        self.num_points=200
        self.n_repeats=10
        self.sigma = 0.4
        self.t_vals = torch.linspace(start=-1, end=10, steps=self.num_points).to(dtype=torch.float32)

    def lmfit(self, max_events=None, max_workers=4):
        from lmfit import Model, Parameters
        model = Model(self.func)
        params = Parameters()
        params.add('shift', value=1.0, vary=False) # Default shift value
        if self.datatype == 'SHO':
            params.add('omega_0', min=0.1, max=1.9)
            params.add('beta', min=0., max=0.5)
        elif self.datatype == 'SineGaussian':
            params.add('f_0', min=0.1, max=1.9)
            params.add('tau', min=1., max=4.)
        elif self.datatype == 'LIGO':
            pass # TODO: Implement LIGO model parameter hints
        else:
            raise ValueError(f'Unknown datatype: {self.datatype}')

        def fit_one(args):
            idx, (theta_u, theta_s, data_u, data_s) = args
            y = data_u[0][0].to(device='cpu')
            t_vals_np = self.t_vals.numpy()
            y_np = y.numpy()
            result = model.fit(y_np, params, t=t_vals_np)
            return idx, result.fit_report()

        # Prepare data iterator (limit events if max_events is set)
        data_iter = enumerate(self.test_dataloader)
        if max_events is not None:
            data_iter = itertools.islice(data_iter, max_events)

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(fit_one, args) for args in data_iter]
            for future in concurrent.futures.as_completed(futures):
                idx, report = future.result()
                print(f'Fit result for event {idx}: {report}')

    def bilby(self, nlive=1000, sampler='dynesty', max_events=None, max_workers=4):
        import bilby
        from bilby.core.prior import Uniform
        from bilby.core.likelihood import GaussianLikelihood
        priors = {}
        if self.datatype=='SHO':
            priors['omega_0'] = Uniform(0.1, 1.9, name='omega_0', latex_label='$\omega_0$')
            priors['beta'] = Uniform(0, 0.5, name='beta', latex_label='$\beta$')
            injection_parameters = dict(omega_0=1., beta=0.3, shift=1)
        elif self.datatype=='SineGaussian':
            priors['f_0'] = Uniform(0.1, 1.9, name='f_0', latex_label='$f_0$')
            priors['tau'] = Uniform(1., 4., name='tau', latex_label='$\tau$')
            injection_parameters = dict(f_0=1., tau=2.5, shift=1)
        def bilby_one(args):
            idx, (theta_u, theta_s, data_u, data_s) = args
            y = data_u[0][0].to(device='cpu')
            t_vals_np = self.t_vals.numpy()
            y_np = y.numpy()
            log_l = GaussianLikelihood(t_vals_np, y, self.func, sigma=self.sigma)
            result = bilby.run_sampler(
                likelihood=log_l, priors=priors, sampler=sampler,
                nlive=nlive, npool=4, save=True, clean=True,
                injection_parameters=injection_parameters,
                output_dir=self.savedir,
                label=self.datatype
            )
            return idx, result.fit_report()
    
        data_iter = enumerate(self.test_dataloader)
        if max_events is not None:
            data_iter = itertools.islice(data_iter, max_events)

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(bilby_one, args) for args in data_iter]
            for future in concurrent.futures.as_completed(futures):
                idx, report = future.result()
                print(f'Fit result for event {idx}: {report}')

In [None]:
fitting = Fitting(datatype='SHO')
fitting.bilby()

  self.test_dict = torch.load(os.path.join(self.datadir, 'test.pt'))
14:25 bilby INFO    : Running for label 'SHO', output will be saved to 'outdir'
14:25 bilby INFO    : Running for label 'SHO', output will be saved to 'outdir'
14:25 bilby INFO    : Running for label 'SHO', output will be saved to 'outdir'
14:25 bilby INFO    : Running for label 'SHO', output will be saved to 'outdir'


Using device: cuda:3


14:25 bilby INFO    : Analysis priors:
14:25 bilby INFO    : Analysis priors:
14:25 bilby INFO    : Analysis priors:
14:25 bilby INFO    : Analysis priors:
14:25 bilby INFO    : omega_0=Uniform(minimum=0.1, maximum=1.9, name='omega_0', latex_label='$\\omega_0$', unit=None, boundary=None)
14:25 bilby INFO    : omega_0=Uniform(minimum=0.1, maximum=1.9, name='omega_0', latex_label='$\\omega_0$', unit=None, boundary=None)
14:25 bilby INFO    : omega_0=Uniform(minimum=0.1, maximum=1.9, name='omega_0', latex_label='$\\omega_0$', unit=None, boundary=None)
14:25 bilby INFO    : omega_0=Uniform(minimum=0.1, maximum=1.9, name='omega_0', latex_label='$\\omega_0$', unit=None, boundary=None)
14:25 bilby INFO    : beta=Uniform(minimum=0, maximum=0.5, name='beta', latex_label='$\x08eta$', unit=None, boundary=None)
14:25 bilby INFO    : beta=Uniform(minimum=0, maximum=0.5, name='beta', latex_label='$\x08eta$', unit=None, boundary=None)
14:25 bilby INFO    : beta=Uniform(minimum=0, maximum=0.5, name='b