### Imports

In [None]:
from dataclasses import dataclass, field
from functools import lru_cache
from math import pi
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dist
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm.auto import trange
from tqdm import tqdm

# Default colormap for image-like plots (imshow / pcolormesh) in this notebook
plt.rcParams["image.cmap"] = "inferno"
# First CUDA GPU when available, otherwise the CPU.
# NOTE(review): DEVICE is not referenced in the cells shown here; models and data stay on the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

### Simulator Setup

In [None]:
def get_meshgrid(resolution, nx, ny, device=None):
    """
    Return coordinate grids (X, Y) for an nx-by-ny image centred on the origin.

    Args:
        resolution: spacing between adjacent pixel centres (used for both axes).
        nx: number of pixels along x.
        ny: number of pixels along y.
        device: optional torch device on which the grids are created.

    Returns:
        Tuple (X, Y) of tensors with shape (ny, nx). X varies along the last
        axis and Y along the first, i.e. the layout numpy.meshgrid(x, y)
        produces with its default 'xy' indexing.
    """
    def pixel_centres(count):
        # `count` pixel-centre coordinates spaced by `resolution`, symmetric about zero
        return torch.linspace(-1, 1, int(count), device=device) * (count - 1) * resolution / 2

    xs = pixel_centres(nx)
    ys = pixel_centres(ny)

    # torch.meshgrid with 'ij' indexing puts its first argument on axis 0,
    # so (ys, xs) yields rows indexed by y and columns indexed by x.
    grid_y, grid_x = torch.meshgrid((ys, xs), indexing='ij')
    return grid_x, grid_y


@dataclass
class Simulator:
    """
    Toy simulator of a gamma-ray counts map around the Galactic Center.

    A simulation combines four diffuse templates (pi0/bremsstrahlung, inverse
    Compton, Fermi bubbles, dark matter), each scaled by a normalization, with
    two populations of point sources (Galactic-Center and disk). The summed
    expected-count map ``mu`` is blurred with a Gaussian PSF and Poisson
    counts are drawn from the blurred map.
    """

    extent: float = 15.0
    """Observation size [deg]"""

    n_pix: int = 64
    """Number of pixels"""

    d_gc: float = 8.3
    """Distance to galactic center [kpc]"""

    sigma_pi0: float = 0.25
    """Length scale associated with spatial variations in pi0 emission [deg]."""

    pi0_disk_height: float = 5.0
    """Disk height for pi0 emission [deg]."""

    pi0_disk_radius: float = 20.0
    """Disk radius for pi0 emission [deg]."""

    disk_height_ic: float = 5.0
    """Disk height for IC emission [deg]."""

    bubble_smoothing_scale: float = 0.6
    """Smoothing scale for Fermi bubble template [deg]."""

    ps_disk_height: float = 0.3
    """Disk scale height for disk-correlated point sources [kpc]."""

    ps_disk_radius: float = 5.0
    """Disk scale radius for disk-correlated point sources [kpc]."""

    dm_dist_concentration: float = 0.5
    """Steepness of DM emission."""

    dm_dist_scale: float = 5.0
    """Spatial scale of DM emission [deg]."""

    containment_radius: float = 0.8 / 3
    """Very approximate 68% containment radius for PSF [deg]."""

    def __post_init__(self):
        """Precompute the pixel grid, the PSF kernel and the pi0 pixel-pixel kernel."""
        # Image grid: pixel centres spanning [-extent, extent] on both axes
        self.resolution = 2 * self.extent / self.n_pix
        self.X, self.Y = get_meshgrid(self.resolution, self.n_pix, self.n_pix)

        # Gaussian kernel (5x5 pixels) for the point spread function used to
        # blur the observation. It is normalized to unit sum so the blur only
        # redistributes flux between neighbouring pixels; an unnormalized
        # kernel would inflate the total expected counts by kernel.sum().
        X_k, Y_k = get_meshgrid(self.resolution, 5, 5)
        psf_kernel = torch.exp(-(X_k**2 + Y_k**2) / (2 * self.containment_radius**2))
        self.psf_kernel = psf_kernel / psf_kernel.sum()

        # Squared-exponential kernel between every pair of pixels
        # (shape (n_pix**2, n_pix**2)), used to correlate the pi0 emission.
        pts = torch.stack([self.X.flatten(), self.Y.flatten()])
        d2s = ((pts[:, :, None] - pts[:, None, :]) ** 2).sum(0)
        self.kernel_pi0 = (-d2s / (2 * self.sigma_pi0**2)).exp()

    def sample_pi0_template(self):
        """
        Sample a random, spatially correlated pi0 emission map of shape (n_pix, n_pix).
        """
        # Smooth white noise with the pixel-pixel kernel, then exponentiate and
        # rescale by hand to make it look more realistic.
        # NOTE: multiplying by the kernel itself (not a Cholesky factor) means
        # this is not an exact draw from a GP with covariance kernel_pi0.
        emission = (self.kernel_pi0 @ torch.randn(self.n_pix**2)).reshape(
            self.n_pix, self.n_pix
        )
        emission = 50 * torch.exp(emission / 8)
        # Confine the emission to the disk with Gaussian-shaped envelopes in X and Y
        return (
            emission
            * torch.exp(-((self.X / self.pi0_disk_radius) ** 2))
            * torch.exp(-((self.Y / self.pi0_disk_height) ** 2))
        )

    def template_ic(self):
        """
        Gets IC emission, which is smooth and fixed: exponential fall-off in |Y|.
        """
        return 25 * torch.exp(-self.Y.abs() / self.disk_height_ic)

    def template_bubbles(self):
        """
        Gets Fermi bubbles emission, which is smooth and fixed.
        """
        # Lower boundary of the northern bubble and upper boundary of the southern one
        Y_norths = 10.5 * (torch.cosh((-self.X - 1) / 10.5) - 1) + 1
        Y_souths = -8.7 * (torch.cosh((-self.X + 1.7) / 8.7) - 1) - 1
        # Apply some hacky smoothing to the edges of the bubbles
        emission = torch.zeros([self.n_pix, self.n_pix])
        emission[self.Y > 0] = torch.sigmoid(
            (self.Y[self.Y > 0] - Y_norths[self.Y > 0]) / self.bubble_smoothing_scale
        )
        emission[self.Y < 0] = torch.sigmoid(
            (Y_souths[self.Y < 0] - self.Y[self.Y < 0]) / self.bubble_smoothing_scale
        )
        return 3 * emission

    def template_dm(self):
        """
        Computes smooth emission from DM annihilation as a Gamma profile in
        the angular distance to the Galactic Center.
        """
        dm_dist = dist.Gamma(self.dm_dist_concentration, 1 / self.dm_dist_scale)
        # Distance to galactic center
        rs = torch.sqrt(self.X**2 + self.Y**2)
        # NOTE(review): for dm_dist_concentration < 1 this density diverges at
        # r = 0; with an odd n_pix a pixel centre lies at (or extremely close
        # to) the origin, so the template may become inf/huge there.
        return 100 * dm_dist.log_prob(rs).exp()

    def _sample_ps_gc_xy(self, n: int):
        """Sample n point-source positions radially symmetric about the GC [deg]."""
        # Drawing r from Gamma(a + 1) rather than Gamma(a) accounts for the 2D
        # area element r dr, so the surface density of sources follows the
        # same radial profile as template_dm.
        ps_gc_dist = dist.Gamma(self.dm_dist_concentration + 1, 1 / self.dm_dist_scale)
        rs = ps_gc_dist.sample((n,))
        angles = torch.rand(n) * 2 * pi
        xs = rs * torch.cos(angles)
        ys = rs * torch.sin(angles)
        return xs, ys

    def _sample_ps_disk_xy(self, n: int):
        """Sample n disk point-source positions [deg] from exponential profiles."""
        # Convert kpc scales to degrees (small-angle approximation at distance d_gc)
        scale_radius = self.ps_disk_radius / self.d_gc * 180 / pi
        scale_height = self.ps_disk_height / self.d_gc * 180 / pi

        rs = dist.Exponential(1 / scale_radius).sample((n,))
        ys = dist.Exponential(1 / scale_height).sample((n,))

        # Randomly put pulsars above or below the galactic plane
        ys *= 2 * ((torch.rand(n) > 0.5).float() - 0.5)

        # Project radial coordinate
        angles = 2 * pi * torch.rand(n)
        xs = rs * torch.cos(angles)

        return xs, ys

    def _sample_pss_in_image(self, pos_sampler, n: int):
        """
        Draw n positions with pos_sampler, redrawing any that fall outside the image.

        Args:
            pos_sampler: callable taking an int count and returning (xs, ys).
            n: number of positions.
        """
        xs, ys = pos_sampler(int(n))
        while True:
            idx_oob = (xs.abs() > self.extent) | (ys.abs() > self.extent)
            # Python int so the sampler receives a plain size, not a tensor
            n_oob = int(idx_oob.sum())
            if n_oob == 0:
                break
            xs[idx_oob], ys[idx_oob] = pos_sampler(n_oob)

        return xs, ys

    def _pixelate_pss(self, xs, ys, fluxes):
        """
        Creates pixelated map by histogramming point source positions and fluxes.

        Rows are binned in y and columns in x, matching the layout of self.X / self.Y.
        """
        # Map onto pixel grid
        bins = torch.linspace(-self.extent, self.extent, self.n_pix + 1)
        return torch.histogramdd(
            torch.stack([ys, xs], 1), bins=(bins, bins), weight=fluxes
        ).hist

    def sample(self) -> dict:
        """
        Simulate all the emission components with random template normalizations.

        Each of A_ic, A_pi0, A_bubbles, A_dm is drawn from U(0, 1) as a
        1-element tensor; see `testsample` for the keys of the returned trace.
        """
        # Arguments are evaluated left to right, so the normalizations are
        # drawn in the order A_ic, A_pi0, A_bubbles, A_dm.
        return self.testsample(
            A_ic=torch.rand(1),
            A_pi0=torch.rand(1),
            A_bubbles=torch.rand(1),
            A_dm=torch.rand(1),
        )

    def testsample(self, A_ic, A_pi0, A_bubbles, A_dm) -> dict:
        """
        Simulate all the emission components for given template normalizations.

        Args:
            A_ic, A_pi0, A_bubbles, A_dm: normalizations multiplying the IC,
                pi0, Fermi-bubble and DM templates (floats or tensors that
                broadcast against an (n_pix, n_pix) map).

        Returns:
            dict with the normalizations, "template_pi0", "n_ps_gc",
            "n_ps_disk" (1-element tensors), the pixelated point-source maps
            "flux_ps_gc" and "flux_ps_disk", the expected counts "mu" and the
            Poisson counts "img"; all maps have shape (n_pix, n_pix).
        """
        # Template normalizations
        trace = {
            "A_ic": A_ic,
            "A_pi0": A_pi0,
            "A_bubbles": A_bubbles,
            "A_dm": A_dm,
        }

        # Sample pi0 template
        trace["template_pi0"] = self.sample_pi0_template()

        # Number of point sources (kept as tensors in the trace for collation)
        trace["n_ps_gc"] = torch.randint(low=200, high=800, size=(1,))
        trace["n_ps_disk"] = torch.randint(low=1000, high=4000, size=(1,))
        n_ps_gc = int(trace["n_ps_gc"])
        n_ps_disk = int(trace["n_ps_disk"])

        # Sample fluxes
        ps_log_flux_mean = 1.0
        ps_log_flux_scale = 1.2
        dist_ps_fluxes = dist.LogNormal(ps_log_flux_mean, ps_log_flux_scale)
        ps_gc_fluxes = dist_ps_fluxes.sample((n_ps_gc,))
        ps_disk_fluxes = dist_ps_fluxes.sample((n_ps_disk,))
        # Sample positions
        xs_ps_gc, ys_ps_gc = self._sample_pss_in_image(self._sample_ps_gc_xy, n_ps_gc)
        xs_ps_disk, ys_ps_disk = self._sample_pss_in_image(self._sample_ps_disk_xy, n_ps_disk)
        # Pixelate
        trace["flux_ps_gc"] = self._pixelate_pss(xs_ps_gc, ys_ps_gc, ps_gc_fluxes)
        trace["flux_ps_disk"] = self._pixelate_pss(xs_ps_disk, ys_ps_disk, ps_disk_fluxes)

        trace["mu"] = (
            trace["A_ic"] * self.template_ic()
            + trace["A_pi0"] * trace["template_pi0"]
            + trace["A_dm"] * self.template_dm()
            + trace["A_bubbles"] * self.template_bubbles()
            + trace["flux_ps_gc"]
            + trace["flux_ps_disk"]
        )

        # Apply PSF ("same" padding keeps the map size) and sample Poisson noise
        pad = self.psf_kernel.shape[-1] // 2
        mu_blurred = torch.nn.functional.conv2d(
            trace["mu"][None, None, :, :], self.psf_kernel[None, None, :, :], padding=pad,
        )[0, 0, :, :]

        trace["img"] = dist.Poisson(mu_blurred).sample()

        return trace

    def sample_batch(self, n) -> dict:
        """Simulate n independent samples and collate them into a batched dict."""
        samples = [self.sample() for _ in range(n)]
        return torch.utils.data.default_collate(samples)

### Example simulation
Example observation with each emission component plotted separately. The number in each subplot's title is that component's mean flux per pixel.

(The component panels all use the same color scale, `vmax=60`, which makes the fainter components hard to see; the total and observation panels use `vmax=200`. Lower `vmax` to make a component more visible, or set `vmax=None` to let matplotlib choose the scale automatically.)

In [None]:
sim = Simulator()
sample = sim.sample()

# Per-component flux maps: diffuse templates scaled by the sampled
# normalizations, as they enter sample["mu"].
flux_pi0 = sample["A_pi0"] * sample["template_pi0"]
flux_ic = sample["A_ic"] * sim.template_ic()
flux_bubbles = sample["A_bubbles"] * sim.template_bubbles()
flux_dm = sample["A_dm"] * sim.template_dm()

# (title, map, vmax) for each panel, in row-major order of the 2x4 grid.
# The number appended to each title is the map's mean value per pixel.
panels = [
    ("Galactic Center point sources", sample["flux_ps_gc"], 60),
    ("Disk point sources", sample["flux_ps_disk"], 60),
    (r"$\pi^0$/brem. diffuse", flux_pi0, 60),
    ("Inverse Compton diffuse", flux_ic, 60),
    ("Fermi bubbles", flux_bubbles, 60),
    ("Dark matter", flux_dm, 60),
    ("Total", sample["mu"], 200),
    ("Observation", sample["img"], 200),
]

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for ax, (title, panel_map, vmax) in zip(axes.flatten(), panels):
    im = ax.pcolormesh(sim.X, sim.Y, panel_map, vmin=0, vmax=vmax)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    ax.set_title(f"{title}: {panel_map.mean().item():.2f}")
    ax.set_aspect("equal")

fig.tight_layout()

### Dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
    """
    Map-style dataset that simulates a fresh sample on every access.

    The index given to ``__getitem__`` is ignored: every lookup calls
    ``simulator.sample()``. An "epoch" therefore has ``size`` items, but their
    content is regenerated each time they are requested.
    """

    def __init__(self, simulator, size: int = 1000):
        self.simulator = simulator
        self.size = size  # how many simulations make up one epoch

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        # The returned dict carries "img" (what the network sees) alongside the
        # latent quantities such as "A_dm" that one may want to infer.
        return self.simulator.sample()

### U-Net

In [None]:
class UNet(nn.Module):
    """
    Three-level U-Net for per-pixel segmentation.

    ``forward`` maps (batch, in_channels, H, W) to sigmoid probabilities of
    shape (batch, out_channels, H, W). Because the encoder applies three 2x
    max-pools and the decoder concatenates each upsampled map with the
    matching encoder output, H and W must be divisible by 8.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Contracting path
        self.enc1 = self.conv_block(in_channels, 32)
        self.enc2 = self.conv_block(32, 64)
        self.enc3 = self.conv_block(64, 128)
        self.enc4 = self.conv_block(128, 128)  # bottleneck

        # Shared 2x downsampling between encoder levels
        self.pool = nn.MaxPool2d(2)

        # Expanding path: each dec block's input width is
        # (skip-connection channels) + (upsampled channels)
        self.upconv3 = self.upconv_block(128, 128)
        self.dec3 = self.conv_block(128 + 128, 64)
        self.upconv2 = self.upconv_block(64, 32)
        self.dec2 = self.conv_block(64 + 32, 32)
        self.upconv1 = self.upconv_block(32, 16)
        self.dec1 = self.conv_block(32 + 16, 16)
        self.out = nn.Conv2d(16, out_channels, 1)

    def forward(self, x):
        # Encoder; keep each level's output for the skip connections
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))
        bottom = self.enc4(self.pool(skip3))

        # Decoder: upsample, concatenate with the skip (skip first), refine
        h = self.dec3(torch.cat([skip3, self.upconv3(bottom)], dim=1))
        h = self.dec2(torch.cat([skip2, self.upconv2(h)], dim=1))
        h = self.dec1(torch.cat([skip1, self.upconv1(h)], dim=1))

        return torch.sigmoid(self.out(h))

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 convolutions ('same' padding), each followed by ReLU."""
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def upconv_block(self, in_channels, out_channels):
        """Bilinear 2x upsampling followed by a 3x3 convolution and ReLU."""
        layers = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

# Pixels whose summed point-source flux exceeds this value are labelled as
# containing a point source when building the segmentation targets.
PS_FLUX_THRESHOLD = 0.8

# Simulator that generates the training data (also used by the validation cell).
simulator = Simulator()

# UNet: 1 input channel (counts image) -> 2 output channels
# (GC point-source mask, disk point-source mask)
model = UNet(1, 2)

# Optimizer (Adam) and loss. The model already applies a sigmoid, so plain BCE
# on probabilities is used (BCEWithLogitsLoss would require raw logits).
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()

# Dataset draws a fresh simulation for every index, so each epoch sees new data
train_dataset = Dataset(simulator, size=1000)
train_dataloader = DataLoader(train_dataset, batch_size=50, shuffle=True)
loss_list = []  # per-batch training loss
num_epochs = 10

model.train()
for epoch in range(num_epochs):
    print(epoch + 1)
    running_loss = 0.0
    for batch_idx, sample in enumerate(train_dataloader):
        print(f"Batch {batch_idx+1}/{len(train_dataloader)}")
        inputs = sample["img"].unsqueeze(1)  # (batch, 1, n_pix, n_pix)

        # Binary masks marking pixels with point-source flux above threshold,
        # stacked to match the model's two output channels
        targets = torch.stack(
            [
                (sample["flux_ps_gc"] > PS_FLUX_THRESHOLD).float(),
                (sample["flux_ps_disk"] > PS_FLUX_THRESHOLD).float(),
            ],
            dim=1,
        )

        optimizer.zero_grad()

        # forward pass, binary cross-entropy loss, backpropagation and update
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        loss_list.append(batch_loss)

    # print average loss for the epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_dataloader)}")

print("Finished Training")

In [None]:
# Draw a fresh simulation the network has not seen and segment it.
validation_sample = simulator.sample()

# Flux threshold used to build the training targets (same value as in training).
target_threshold = 0.8
# Probability above which a pixel is predicted to contain a point source.
threshold = 0.95

observation = validation_sample["img"]  # (n_pix, n_pix) counts map

model.eval()
with torch.no_grad():
    # Add batch and channel axes -> (1, 1, n_pix, n_pix); take the single
    # batch element of the (1, 2, n_pix, n_pix) output.
    model_predictions = model(observation[None, None])[0]
binary_predictions = (model_predictions > threshold).float()

# Ground-truth masks, in the same channel order as the model output
ground_truths = [
    (validation_sample["flux_ps_gc"] > target_threshold).float(),
    (validation_sample["flux_ps_disk"] > target_threshold).float(),
]
component_names = ["GC point sources", "Disk point sources"]

# One row per point-source population: observation | prediction | ground truth
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(15, 10))

for row, (name, prediction, ground_truth) in enumerate(
    zip(component_names, binary_predictions, ground_truths)
):
    # origin="lower" puts +Y at the top, as in the pcolormesh sky maps
    axes[row, 0].imshow(observation, origin="lower")
    axes[row, 0].set_title('Observation')

    axes[row, 1].imshow(prediction, origin="lower")
    axes[row, 1].set_title(f'Prediction ({name})')

    axes[row, 2].imshow(ground_truth, origin="lower")
    axes[row, 2].set_title(f'Ground Truth ({name})')

    for ax in axes[row]:
        ax.axis('off')  # Hide axes for better visualization

plt.tight_layout()
plt.show()