In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pylab as plt
import numpy as np
import torch
import swyft
import swyft.lightning as sl
import lensing_model
from pytorch_lightning import loggers as pl_loggers
from pyrofit.lensing.sources import SersicSource

import sys
sys.path.append('../16-swyft_unet/scripts/')
from plot import *

## Problem-specific analysis components

In [3]:
# Number of nearest neighbours requested from lensing_model.get_kNN_idx (its `k` argument).
KNN = 3

# Standard deviation of the Gaussian pixel noise added to the noise-free image in Model.fast.
SIGMA = 0.02

# Side length (in pixels) of the square image; the source grid uses the same value.
NPIX_IMG = 40
NPIX_SRC = NPIX_IMG

class Model(sl.SwyftModel):
    """Lensing simulator with a Sersic source and uniform box priors.

    * ``prior`` draws ``z_src`` (x_src, y_src, phi_src, q_src, index, r_e, I_e)
      and ``z_lens`` (x, y, phi, q, r_ein, slope) uniformly inside a box.
    * ``slow`` renders the noise-free image ``mu`` plus coordinate grids and
      kNN indices via ``lensing_model``.
    * ``fast`` adds Gaussian pixel noise with standard deviation ``SIGMA``.
    """

    def slow(self, pars):
        """Run the (expensive) image simulation for one parameter set.

        Args:
            pars: mapping with entries ``'z_lens'`` (6 values) and ``'z_src'``
                (7 values), unpacked in the order listed in the class docstring.

        Returns:
            ``sl.SampleStore`` with CPU tensors ``mu``, ``kNN_idx``, ``X``, ``Y``,
            ``Xsrc`` and ``Ysrc``.
        """
        torch.cuda.set_device(0)
        # The simulator is run with CUDA as the process-wide default tensor type.
        # Restore the CPU default in ``finally`` so that a failing simulation does
        # not leave the rest of the notebook with a CUDA default.
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        try:
            x, y, phi, q, r_ein, slope = pars['z_lens']
            # The sampled slope is discarded and fixed to 2.0 for the simulation.
            # NOTE(review): z_lens still carries the sampled slope, so that
            # component does not influence the image - confirm this is intended.
            slope = 2.0
            x_src, y_src, phi_src, q_src, index, r_e, I_e = pars['z_src']
            img, coords = lensing_model.image_generator_sersic(x, y, phi, q, r_ein, slope, x_src, y_src, phi_src, q_src, index, r_e, I_e)
            X, Y, Xsrc, Ysrc = coords
            kNN_idx = lensing_model.get_kNN_idx(X/5, -Y/5, Xsrc, Ysrc, k = KNN)  # TODO: Need to sort out strange 1/5 and -1/5 factors
        finally:
            torch.set_default_tensor_type(torch.FloatTensor)
        return sl.SampleStore(mu = img.cpu(), kNN_idx = kNN_idx.cpu(), X = X.cpu(), Y = Y.cpu(), Xsrc = Xsrc.cpu(), Ysrc = Ysrc.cpu())

    def fast(self, d):
        """Add Gaussian noise (standard deviation ``SIGMA``) to ``d['mu']``.

        Returns:
            ``sl.SampleStore`` with the noisy image under ``img``.
        """
        img = d['mu'] + torch.randn_like(d['mu'])*SIGMA
        return sl.SampleStore(img=img)

    def prior(self, N, bounds = None):
        """Draw ``N`` joint prior samples of ``z_src`` and ``z_lens``.

        Args:
            N: number of samples.
            bounds: optional mapping from parameter name to an object exposing
                ``.low`` and ``.high``; parameters without an entry use their
                default box.
        """
        # Source first, then lens: keeps the order in which random numbers are consumed.
        src_samples = self.prior_sersic(N, bounds = bounds)
        lens_samples = self.prior_lens(N, bounds = bounds)
        return sl.SampleStore(**src_samples, **lens_samples)

    # Draw from source prior
#     def prior_src(self, N, bounds = None):
#         if bounds is None or 'z_src' not in bounds:
#             R = lensing_model.RandomSource()
#             z_src = torch.stack([R().cpu() for _ in range(N)])
#         else:
#             n = 3
#             l, h = bounds['z_src'].low, bounds['z_src'].high
#             R = lensing_model.RandomSource()
#             z_src = []
#             for _ in range(N):
#                 rnd = sum([R().cpu()-R().cpu() for _ in range(n)])
#                 rnd -= rnd.min()
#                 rnd /= rnd.max()
#                 z_src.append(l+rnd*h)
#             z_src = torch.stack(z_src)
#         return sl.SampleStore(z_src=z_src)

    @staticmethod
    def _uniform_box_draw(key, N, bounds, default_low, default_high):
        """Draw ``N`` uniform samples for parameter ``key`` as a float tensor.

        Uses ``bounds[key].low`` / ``bounds[key].high`` when ``bounds`` has an
        entry for ``key``; otherwise falls back to the default box. The fallback
        for a missing key matters when bounds were derived for only a subset of
        the parameters (e.g. lens-only inference).
        """
        if bounds is not None and key in bounds:
            low = bounds[key].low
            high = bounds[key].high
        else:
            low, high = default_low, default_high
        draw = np.array([np.random.uniform(low=low, high=high) for _ in range(N)])
        return sl.SampleStore(**{key: torch.tensor(draw).float()})

    def prior_sersic(self, N, bounds = None):
        """Draw ``N`` samples of the 7 Sersic source parameters ``z_src``."""
        return self._uniform_box_draw(
            'z_src', N, bounds,
            default_low = np.array([-0.1, -0.1, 0, 0., 0.5, 0.1, 0.]),
            default_high = np.array([0.1, 0.1, 1.5, 1., 4.0, 2.5, 4.]),
        )

    def prior_lens(self, N, bounds = None):
        """Draw ``N`` samples of the 6 lens parameters ``z_lens``."""
        return self._uniform_box_draw(
            'z_lens', N, bounds,
            default_low = np.array([-0.2, -0.2, 0, 0.2, 1.0, 1.5]),
            default_high = np.array([0.2, 0.2, 1.5, 0.9, 2.0, 2.5]),
        )
    
# Instantiate the simulator and draw three samples; the result is discarded,
# so this only checks that sampling runs (the trailing `;` suppresses cell output).
m = Model()
m.sample(3);

  return (indices.unsqueeze(-1) // strides) % shape
100%|██████████| 3/3 [00:03<00:00,  1.04s/it]
100%|██████████| 3/3 [00:00<00:00, 2137.77it/s]


## Definition of target image

In [4]:
# Draw 10 samples from the model to act as candidate target observations.
# NOTE(review): no random seed is set in the visible cells, so these targets
# change between kernel sessions unless the save/load lines below are used.
s_targets = m.sample(10)
# torch.save(s_targets, "test_targets_sersic.pt")
# s_targets = torch.load("test_targets_sersic.pt")

100%|██████████| 10/10 [00:00<00:00, 36.89it/s]
100%|██████████| 10/10 [00:00<00:00, 14650.03it/s]


In [5]:
# for k, v in s_targets.items():
#     print(k, v.shape)

In [6]:
# for v in ['img', 'mu', 'X', 'Y' ,  'Xsrc', 'Ysrc']:
#     print(v)
#     plt_imshow(s_targets[v])

In [7]:
# Pick one of the sampled targets for inspection (index 6 is hard-coded here).
# NOTE(review): the workflow cell later reassigns `s0 = s_targets[TARGET]`,
# so anything derived from this `s0` may refer to a different target.
s0 = s_targets[6]
# print(s0['z_lens'])
# print(s0['z_src'])
# plt.figure(figsize=(15, 5))
# plt.subplot(1, 2, 1)
# plt.imshow(s0['img'].cpu())
# plt.colorbar()
# plt.subplot(1, 2, 2)
# Sersic source object; the workflow cell calls it to render source-plane images for plots.
sersic = SersicSource()
# src = sersic(X=s0['X'].cpu(), Y=s0['Y'].cpu(), x=s0['z_src'][0].cpu(), y=s0['z_src'][1].cpu(), phi=s0['z_src'][2].cpu(), q=s0['z_src'][3].cpu(), index=s0['z_src'][4].cpu(), r_e=s0['z_src'][5].cpu(), I_e=s0['z_src'][6].cpu())
# plt.imshow(src)
# plt.colorbar()

In [8]:
# rec = lensing_model.deproject_idx(s0['img'].unsqueeze(0), s0['kNN_idx'].unsqueeze(0)).mean(axis=1).squeeze(0)
# plt.imshow(rec)
# plt.colorbar()

In [9]:
class LensNetwork(sl.SwyftModule):
    """Ratio estimator for the lens parameters ``z_lens`` given an image.

    The image is standardized online, compressed by a small CNN into a
    16-component feature vector and passed, together with ``z_lens``, to
    ``sl.RatioEstimatorMLP1d``.
    """

    def __init__(self):
        super().__init__()
        image_shapes = dict(img = (NPIX_IMG, NPIX_IMG))
        self.online_z_score = swyft.networks.OnlineDictStandardizingLayer(image_shapes)

        # Three conv + max-pool stages followed by an MLP head that ends in 16 features.
        conv_stages = [
            torch.nn.Conv2d(1, 10, 3),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(10, 20, 3),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(20, 40, 3),
            torch.nn.MaxPool2d(2),
        ]
        mlp_head = [
            torch.nn.Flatten(),
            torch.nn.LazyLinear(128),
            torch.nn.ReLU(),
            torch.nn.LazyLinear(256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 16),
        ]
        self.CNN = torch.nn.Sequential(*conv_stages, *mlp_head)

        # 16 matches the CNN output width; 6 matches the number of z_lens components.
        self.c = sl.RatioEstimatorMLP1d(16, 6, hidden_features = 256)

    def forward(self, x, z):
        """Return ``dict(z_lens=...)`` with the ratio estimator output.

        Args:
            x: mapping containing the image under ``'img'``
               (presumably batched as (batch, NPIX_IMG, NPIX_IMG) - confirm).
            z: mapping containing the lens parameters under ``'z_lens'``.
        """
        # Only the image enters the network; other entries of x are ignored.
        standardized = self.online_z_score(dict(img = x['img']))['img']
        # unsqueeze(1) inserts the single channel axis expected by the first Conv2d.
        features = self.CNN(standardized.unsqueeze(1)).squeeze(1)
        return dict(z_lens = self.c(features, z['z_lens']))

In [10]:
# class SourceNetwork(sl.SwyftModule):
#     def __init__(self):
#         super().__init__()
#         self.l = torch.nn.Linear(10, 10)
#         self.reg1d = sl.RatioEstimatorGaussian1d(momentum = 0.1)
#         self.L = torch.nn.Linear(NPIX_SRC**2, NPIX_SRC**2)
        
#     def get_img_rec(self, x):
#         x_img = x['img']
#         x_kNN_idx = x['kNN_idx']
#         x_src_rec = lensing_model.deproject_idx(x_img, x_kNN_idx)[:,:,:,:].mean(dim=1)
#         x_src_rec = self.L(x_src_rec.view(-1, NPIX_SRC*NPIX_SRC)).view(-1, NPIX_SRC, NPIX_SRC)*0 + x_src_rec
#         return x_src_rec
    
#     def forward(self, x, z):
#         x_img_rec = self.get_img_rec(x)
#         z_src = z['z_src']
#         x_img_rec, z_src = sl.equalize_tensors(x_img_rec, z_src)
#         w = self.reg1d(x_img_rec, z_src)
#         return dict(z_src = w)

class SourceNetwork(sl.SwyftModule):
    """Ratio estimator for the Sersic source parameters ``z_src`` given an image.

    Same image pipeline as the lens network in this notebook: online
    standardization, a small CNN producing 16 features, then
    ``sl.RatioEstimatorMLP1d`` evaluated against ``z_src``.
    """

    def __init__(self):
        super().__init__()
        image_shapes = dict(img = (NPIX_IMG, NPIX_IMG))
        self.online_z_score = swyft.networks.OnlineDictStandardizingLayer(image_shapes)

        # Feature extractor: conv/pool stages first, then the fully connected head.
        feature_layers = []
        for in_channels, out_channels in ((1, 10), (10, 20), (20, 40)):
            feature_layers.append(torch.nn.Conv2d(in_channels, out_channels, 3))
            feature_layers.append(torch.nn.MaxPool2d(2))
        feature_layers.extend([
            torch.nn.Flatten(),
            torch.nn.LazyLinear(128),
            torch.nn.ReLU(),
            torch.nn.LazyLinear(256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 16),
        ])
        self.CNN = torch.nn.Sequential(*feature_layers)

        # 16 matches the CNN output width; 7 matches the number of z_src components.
        self.c = sl.RatioEstimatorMLP1d(16, 7, hidden_features = 256)

    def forward(self, x, z):
        """Return ``dict(z_src=...)`` with the ratio estimator output.

        Args:
            x: mapping containing the image under ``'img'``
               (presumably batched as (batch, NPIX_IMG, NPIX_IMG) - confirm).
            z: mapping containing the source parameters under ``'z_src'``.
        """
        # Only the image is used as data; any other entries of x are dropped.
        standardized = self.online_z_score(dict(img = x['img']))['img']
        # Add the channel axis for Conv2d, then compress to the feature vector.
        features = self.CNN(standardized.unsqueeze(1)).squeeze(1)
        return dict(z_src = self.c(features, z['z_src']))

# Workflow

In [11]:
# Image-only view of the current `s0`, used as the conditioning observation for source inference.
# NOTE(review): this is built from whichever `s0` is defined when this cell runs
# (s_targets[6] from the earlier cell), while the workflow cell reassigns
# `s0 = s_targets[TARGET]` - confirm both refer to the intended target.
s0_img = dict(img = s0['img'])

In [12]:
# Full-size run (inactive): Ntrain1, R1, ME = 5000, 2, 3
# Active settings: a small, fast configuration.
Ntrain1 = 50   # number of training simulations drawn per round
R1 = 1         # number of rounds (iterations of the workflow loop)
ME = 2         # max epochs passed to the lens-network trainer

TARGET = 3           # index into s_targets of the observation to analyse
tag = 'VSersic01'    # label used in the logger name and training-data cache file names
INFER_SOURCE = True  # also train and evaluate the source network in each round

In [16]:
bounds = None   # no bounds in the first round; later rounds use the bounds derived below
results = []
s0 = s_targets[TARGET]
# Rebuild the image-only observation from the target selected above so that
# lens inference (p1), source inference (p2) and the plots all refer to the
# same observation. Previously s0_img could stem from a different s0.
s0_img = dict(img = s0['img'])


def render_sersic_source(sample):
    """Evaluate the Sersic source for `sample` on its X/Y grid.

    `sample['z_src']` is read in the order x, y, phi, q, index, r_e, I_e.
    """
    z_src = sample['z_src']
    return sersic(
        X=sample['X'].cpu(), Y=sample['Y'].cpu(),
        x=z_src[0].cpu(), y=z_src[1].cpu(), phi=z_src[2].cpu(), q=z_src[3].cpu(),
        index=z_src[4].cpu(), r_e=z_src[5].cpu(), I_e=z_src[6].cpu(),
    )


for i in range(R1):
    tbl = pl_loggers.TensorBoardLogger("lightning_logs", name = 'lensing_%s'%tag)#, default_hp_metric=True)
    # s1: img, lens, src ~ p(img|lens, src)p(lens)p(src)
    # Cached on disk; the file name encodes tag, target index, sample count and round.
    s1 = sl.file_cache(lambda: m.sample(Ntrain1, bounds = bounds), './train_data_%s_%i_%i_%i.pt'%(tag, TARGET, Ntrain1, i))
    
    # r1: p(z_lens|img)/p(z_lens)
    r1 = LensNetwork()
    
    # d1: split img vs z_lens
    # TODO: Specify x_keys = ['img'], z_keys=['z_lens']
    d1 = sl.SwyftDataModule(store = s1, model = m, batch_size = 128)
    
    # Train r1 with d1
    t1 = sl.SwyftTrainer(accelerator = 'gpu', gpus=1, max_epochs = ME, logger = tbl)
    t1.fit(r1, d1)
    t1.test(r1, d1)
    
    # p1: z_lens ~ p(z_lens|img_obs)  --  these are weighted samples
    p1 = t1.infer(r1, d1, condition_x = s0)

    if INFER_SOURCE:
        # r2: p(src|z_lens, img)/p(src)
        r2 = SourceNetwork()

        # d2: split (img, kNN_idx) vs src
        # TODO: Specify x_keys = ['img', 'kNN_idx'], z_keys=['src']
        d2 = sl.SwyftDataModule(store = s1, model = m, batch_size = 16)

        # Train r2 with d2
        # NOTE(review): max_epochs is hard-coded to 2 here while t1 uses ME - confirm this is intended.
        t2 = sl.SwyftTrainer(accelerator = 'gpu', gpus=1, max_epochs = 2, logger = tbl)
        t2.fit(r2, d2)

        # s3: img, lens, src ~ p(img|lens, src)p(lens|img_obs)
        s3 = m.sample(100, bounds = bounds, effective_prior = {'z_lens': p1})

        # d3: split (img, kNN) vs (z_lens, src)
        # TODO: Specify x_keys = ['img', 'kNN_idx'], z_keys=['z_lens', 'src']
        d3 = sl.SwyftDataModule(store = s3, model = m, batch_size = 16)

        # p2: src ~ p(src|img_obs) = \int dlens p(src|lens, img_obs)*p(lens|img_obs)  --  weighted samples
        p2 = t2.infer(r2, d3, condition_x = s0_img)

        # Rectangle Bounds
        all_inference = dict(**p1, **p2)
    else:
        all_inference = p1
        p2 = None
        
    # Bounds for the next round, derived from every parameter group inferred above.
    bounds = sl.get_1d_rect_bounds(all_inference, th = 1e-6)
    results.append(dict(p1=p1, p2=p2, bounds = bounds))
    
    # Making nice plots
    
    z = p1.sample(10000)['z_lens'].numpy()
    zr = p1.sample(10000000, replacement = False)['z_lens'].numpy()
    for k in range(6):
        fig = plt.figure(dpi = 100)
        plt.hist(zr[:,k], density = True, bins = 20, color = 'r')
        plt.hist(z[:,k], density = True, bins = 20, color = 'b')
        # Vertical line marks the target's true parameter value.
        plt.axvline(s0['z_lens'][k], color='r')
        tbl.experiment.add_figure("posterior/%i"%k, fig)
        
    fig = plt.figure()
    key = 'z_lens'
    for k in range(6):
        l, h = bounds[key].low[k], bounds[key].high[k]
        plt.plot([k, k], [l, h], 'k')
        plt.scatter(k, s0[key][k], marker='o', color='r')
    tbl.experiment.add_figure("bounds", fig)

    # First eight training images of this round.
    for k in range(8):
        fig = plt.figure()
        img = s1[k]['img']
        plt.imshow(img)
        tbl.experiment.add_figure("train_data/%i"%k, fig)
        
    # Matching source-plane images for the same eight training samples.
    for k in range(8):
        fig = plt.figure()
        img = render_sersic_source(s1[k])
        plt.imshow(img)
        tbl.experiment.add_figure("train_data_src/%i"%k, fig)
        
    fig = plt.figure()
    plt.imshow(s0['img'])
    tbl.experiment.add_figure("target/image", fig)
    fig = plt.figure()
    src = render_sersic_source(s0)
    plt.imshow(src)
    tbl.experiment.add_figure("target/source", fig)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

  | Name           | Type                         | Params
----------------------------------------------------------------
0 | online_z_score | OnlineDictStandardizingLayer | 0     
1 | CNN            | Sequential                   | 13.3 K
2 | c              | RatioEstimatorMLP1d          | 1.6 M 
----------------------------------------------------------------
1.6 M     Trainable params
0         Non-trainable params
1.6 M     Total params
6.535     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'hp/JS-div': 8.318975448608398, 'hp/KL-div': 0.056465357542037964}
--------------------------------------------------------------------------------
printing self! <swyft.lightning.components.SwyftTrainer object at 0x148cf093d490> <bound method Trainer.predict of <swyft.lightning.components.SwyftTrainer object at 0x148cf093d490>>


Predicting: 1it [00:00, ?it/s]

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

  | Name           | Type                         | Params
----------------------------------------------------------------
0 | online_z_score | OnlineDictStandardizingLayer | 0     
1 | CNN            | Sequential                   | 13.3 K
2 | c              | RatioEstimatorMLP1d          | 1.9 M 
----------------------------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.615     Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

100%|██████████| 100/100 [00:02<00:00, 37.33it/s]
100%|██████████| 100/100 [00:00<00:00, 20971.52it/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


printing self! <swyft.lightning.components.SwyftTrainer object at 0x148cf06e6a00> <bound method Trainer.predict of <swyft.lightning.components.SwyftTrainer object at 0x148cf06e6a00>>


Predicting: 3it [00:00, ?it/s]