This tutorial takes a hands-on, open-ended approach—think of it as your chance to get more comfortable with PyTorch by building your own flow matching code from scratch. We provide boilerplate code for data loading and point you to useful velocity models in existing PyTorch libraries, but the implementation details are yours to explore!

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/florpi/heidelberg_generative_lectures/blob/main/4_galaxy_images.ipynb)

Let's first write our data loading code. We will use galaxy images from the AstroCLIP dataset (https://arxiv.org/abs/2310.03024).

In [None]:

import torch
import matplotlib.pyplot as plt
from typing import List, Optional
from lightning import LightningDataModule
from torch.utils.data import DataLoader
from datasets import load_dataset

from torch.utils.data.dataloader import default_collate
from torchvision.transforms import CenterCrop
import numpy as np


class ToRGB:
    """
    Transformation from raw image data (nanomaggies) to the rgb values displayed
    at the legacy viewer https://www.legacysurvey.org/viewer

    Code copied from
    https://github.com/legacysurvey/imagine/blob/master/map/views.py

    Args:
        scales: optional dict ``band -> (rgb_plane, scale)`` whose entries
            override the built-in per-band table below.
        m: offset added to every scaled band image.
        Q: softening parameter of the arcsinh stretch.
        bands: band name of each input channel, in channel order. Defaults to
            ``["g", "r", "z"]``.
    """

    def __init__(self, scales=None, m=0.03, Q=20, bands=None):
        if bands is None:
            # Fresh list per instance instead of a shared mutable default.
            bands = ["g", "r", "z"]
        rgb_scales = {
            "u": (2, 1.5),
            "g": (2, 6.0),
            "r": (1, 3.4),
            "i": (0, 1.0),
            "z": (0, 2.2),
        }
        if scales is not None:
            rgb_scales.update(scales)

        self.rgb_scales = rgb_scales
        self.m = m
        self.Q = Q
        self.bands = bands
        self.axes, self.scales = zip(*(rgb_scales[band] for band in bands))

        # Rearrange scales to correspond to image channels after swapping.
        # NOTE(review): indexing by plane only works when every plane index is
        # < len(bands) (true for g/r/z); __call__ does not read these attributes.
        self.scales = [self.scales[i] for i in self.axes]

    def __call__(self, imgs):
        """Map one multi-band image to an (H, W, 3) float32 RGB array in [0, 1].

        Args:
            imgs: image with one channel per entry of ``self.bands``, given
                either channel-first (C x H x W) or channel-last (H x W x C).

        Returns:
            ``np.float32`` array of shape (H, W, 3), clipped to [0, 1].
        """
        # Check image shape and set to C x H x W
        if imgs.shape[0] != len(self.bands):
            imgs = np.transpose(imgs, (2, 0, 1))

        # Mean over bands of the scaled, offset, non-negative band images.
        I = 0
        for img, band in zip(imgs, self.bands):
            plane, scale = self.rgb_scales[band]
            img = np.maximum(0, img * scale + self.m)
            I = I + img
        I /= len(self.bands)

        # arcsinh stretch using the configured softening parameter
        # (the constructor's Q, instead of a hard-coded 20).
        Q = self.Q
        fI = np.arcsinh(Q * I) / np.sqrt(Q)
        # Guard the division below against pixels with exactly zero intensity.
        I += (I == 0.0) * 1e-6
        H, W = I.shape
        rgb = np.zeros((H, W, 3), np.float32)
        for img, band in zip(imgs, self.bands):
            plane, scale = self.rgb_scales[band]
            rgb[:, :, plane] = (img * scale + self.m) * fI / I

        rgb = np.clip(rgb, 0, 1)
        return rgb


class AstroClipCollator:
    """Collate AstroCLIP samples and turn their raw images into RGB tensors.

    Samples are first merged with ``default_collate``. The ``"image"`` entry
    is then converted image-by-image with ``ToRGB``, stacked, moved to the
    channels-first layout and center-cropped.

    Args:
        center_crop: side length passed to ``CenterCrop``.
        bands: band names forwarded to ``ToRGB``; defaults to ``["g", "r", "z"]``.
        m: offset forwarded to ``ToRGB``.
        Q: arcsinh softening parameter forwarded to ``ToRGB``.
    """

    def __init__(
        self,
        center_crop: int = 144,
        bands: Optional[List[str]] = None,
        m: float = 0.03,
        Q: int = 20,
    ):
        if bands is None:
            # Fresh list per instance instead of a shared mutable default.
            bands = ["g", "r", "z"]
        self.center_crop = CenterCrop(center_crop)
        self.to_rgb = ToRGB(bands=bands, m=m, Q=Q)

    def _process_images(self, images):
        """Convert each image to RGB and return a (B, 3, H, W) cropped tensor.

        ``self.to_rgb`` returns an (H, W, 3) array per image, so stacking
        yields (B, H, W, 3).
        """
        rgb = torch.stack([torch.as_tensor(self.to_rgb(img)) for img in images])
        # (B, H, W, C) -> (B, C, H, W), PyTorch's channels-first convention.
        # permute(0, 3, 2, 1) would additionally swap H and W (a transpose).
        return self.center_crop(rgb.permute(0, 3, 1, 2))

    def __call__(self, samples):
        """Collate a list of samples and replace ``"image"`` with RGB tensors."""
        batch = default_collate(samples)
        batch["image"] = self._process_images(batch["image"])
        return batch

In [None]:
# Directory where the dataset will be downloaded and cached. Replace None with
# a path of your choice; None uses the Hugging Face `datasets` default cache.
cache_dir = None
train_data = load_dataset(
    "mhsotoudeh/astroclip",
    split="train",
    cache_dir=cache_dir,
    # streaming=True,  # streams instead of downloading; note a streamed
    #                  # (iterable) dataset cannot be used with shuffle=True below
)

# Index into the dataset as PyTorch tensors rather than Python objects.
train_data = train_data.with_format("torch")

In [None]:

# Shuffled mini-batches; AstroClipCollator turns each batch's raw images into
# center-cropped RGB tensors.
batch_size = 16
train_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=AstroClipCollator(
        center_crop=144, bands=["g", "r", "z"], m=0.03, Q=20
    ),
)

Let's now get a quick example batch from our data loader.

In [None]:
# Draw one batch from a fresh pass over the loader (shuffle=True, so the
# samples differ between runs).
example_batch = next(iter(train_loader))

Now let's focus on the image content and plot it.

In [None]:
# Presumably (16, 3, 144, 144): batch_size images, 3 RGB channels, 144x144 crop.
example_batch['image'].shape

In [None]:
# PyTorch stores images as Channels x Height x Width, while matplotlib's
# imshow expects Height x Width x Channels, so move the channel axis last
# before plotting the first image of the batch.
plt.imshow(example_batch['image'][0].permute(1, 2, 0))

You can play around with two velocity models: a CNN-based one (a UNet) and a transformer-based one (the Vision Transformer from torchvision). Here is an example of a UNet that can be used for this problem:

In [None]:
from diffusers import UNet2DModel

# Convolutional UNet with four resolution levels built only from the
# DownBlock2D / UpBlock2D block types; maps a 3-channel image to a
# 3-channel output of the same size.
velocity_model = UNet2DModel(
    sample_size=144,  # matches the 144-pixel center crop used by the loader
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(64, 64, 128, 256),
    down_block_types=("DownBlock2D",) * 4,
    up_block_types=("UpBlock2D",) * 4,
)

Have fun :P