# Download Data

In [None]:
# Each line downloads one archive with gdown, unzips it quietly into the working
# directory and then deletes the archive.
!gdown 1b5GsQwGFhvlb3NFgtJfGVtNQR8Cs7Znd -O astroswin.zip && unzip -qq astroswin.zip && rm -rf astroswin.zip # the most recent checkpoint of aswin
!gdown 15AC-BMDLuafKRs-b9CaC3jYmi4gIkLVY -O dataset.zip && unzip -qq dataset.zip && rm -f dataset.zip # manually mined linear data

# Training Pipeline

In [None]:
import cv2
import numpy as np
import os
import random
import torch
import torch.nn.functional as F

from torch.utils.data import Dataset
from torchvision.transforms.v2 import RandomCrop, GaussianBlur, Compose, RandomRotation
from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor, Trainer, TrainingArguments

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Divisor, keyed by numpy dtype name, that `init_process` uses to rescale an image
# whose values are not already inside [0, 1].
# 'uint16' is included because 16-bit TIFFs loaded with cv2.IMREAD_UNCHANGED come
# back as uint16 arrays; without the entry the lookup raised KeyError.
# NOTE(review): the float entries assume float images carry integer-range values
# (up to 2**32 - 1 / 2**16 - 1) — confirm against the data.
dtype2factor = {
    'float32': float(2**32 - 1),
    'float16': float(2**16 - 1),
    'uint16': int(2**16 - 1),
    'uint8': int(2**8 - 1),
}

In [None]:
def random_downsample(image: np.ndarray):
    """Randomly shrink *image* by an integer factor of 1 or 2.

    The factor 2 is only eligible when the shorter of the first two axes is at
    least twice ``base_dim`` (256), so the result never drops below 256 on that
    side because of the downsampling. Resizing uses Lanczos interpolation.
    """
    base_dim = 256
    shortest_side = min(image.shape[:2])
    max_factor = shortest_side // base_dim

    if max_factor > 1:
        factor = random.randint(1, min(2, max_factor))
    else:
        factor = 1

    ratio = 1 / factor
    return cv2.resize(
        image,
        dsize=None,
        fx=ratio,
        fy=ratio,
        interpolation=cv2.INTER_LANCZOS4,
    )

def init_process(image: np.ndarray, processor: Swin2SRImageProcessor) -> np.ndarray:
    """Run a raw image array through *processor* and return its ``pixel_values``.

    A 2-D (single-channel) image is replicated to three channels, the last axis
    is reversed (cv2 loads BGR, so this yields RGB), and the processor is asked
    to rescale by ``1 / dtype2factor[dtype]`` unless every value already lies
    in [0, 1], in which case rescaling is switched off.
    """
    if image.ndim == 2:
        image = np.repeat(image[:, :, np.newaxis], 3, axis=-1)
    image = np.flip(image, axis=2)

    already_unit_range = image.min() >= 0 and image.max() <= 1
    if already_unit_range:
        rescale_kwargs = {'do_rescale': False}
    else:
        rescale_kwargs = {'rescale_factor': 1 / dtype2factor[str(image.dtype)]}

    processed = processor(image, return_tensors='np', **rescale_kwargs)
    return processed['pixel_values']

def to_tensor(image: np.ndarray) -> torch.Tensor:
    """Reverse the three axes of *image* and wrap the result in a float32 tensor."""
    reversed_axes = np.transpose(image, (2, 1, 0))
    return torch.Tensor(reversed_axes)

# Random rotation of up to +/-90 degrees; the canvas grows to fit the rotated image
# and the uncovered corners are filled with 0.6.
rotator = RandomRotation(degrees=90, expand=True, fill=0.6)
# Random 256x256 crop.
crop = RandomCrop(size=(256, 256))

In [None]:
def sharp_loss(pred: torch.Tensor, target: torch.Tensor, alpha: float = 0.7):
    """Blend an L1 content loss with an L1 loss on Laplacian edge maps.

    Args:
        pred: predicted images, shape ``[B, C, H, W]``.
        target: reference images, same shape as *pred*.
        alpha: weight of the content term; the edge term gets ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * L1(pred, target) + (1 - alpha) * L1(edges)``.

    The edge map has a single channel: the 3x3 Laplacian is applied to every
    input channel and the responses are summed. The kernel is built for
    whatever channel count, dtype and device the input has (previously it was
    hard-coded to 3 float32 channels).
    """
    content_loss = F.l1_loss(pred, target)

    laplacian = torch.tensor(
        [
            [0, 1, 0],
            [1, -4, 1],
            [0, 1, 0],
        ],
        dtype=torch.float32,
    ).view(1, 1, 3, 3)

    def _edges(x: torch.Tensor) -> torch.Tensor:
        # One output channel that sums the per-channel Laplacian responses.
        kernel = laplacian.repeat(1, x.shape[1], 1, 1).to(device=x.device, dtype=x.dtype)
        return F.conv2d(x, kernel, padding=1)

    edge_loss = F.l1_loss(_edges(pred), _edges(target))

    return alpha * content_loss + (1 - alpha) * edge_loss

def gradient_loss(pred: torch.Tensor, target: torch.Tensor, temperature: float = 1):
    """L1 loss on horizontal/vertical finite differences, weighted per pixel.

    Args:
        pred: predicted images, shape ``[B, C, H, W]``.
        target: reference images, same shape as *pred*.
        temperature: softness of the sigmoid weighting mask.

    Returns:
        Scalar tensor: the average of the masked x- and y-gradient losses.

    The mask is ``sigmoid((target_gradient - 0.5) / temperature)``, so larger
    target gradients receive larger weights. The L1 term uses
    ``reduction='none'`` so the mask really weights each pixel; with the default
    ``'mean'`` reduction the mask only multiplied an already-averaged scalar.
    NOTE(review): the mask is driven by the signed gradient, not its magnitude —
    confirm that is intended.
    """
    pred_grad_x = pred[:, :, :, :-1] - pred[:, :, :, 1:]
    pred_grad_y = pred[:, :, :-1, :] - pred[:, :, 1:, :]
    target_grad_x = target[:, :, :, :-1] - target[:, :, :, 1:]
    target_grad_y = target[:, :, :-1, :] - target[:, :, 1:, :]

    mask_x = torch.sigmoid((target_grad_x - 0.5) / temperature)
    mask_y = torch.sigmoid((target_grad_y - 0.5) / temperature)

    loss_x = mask_x * F.l1_loss(pred_grad_x, target_grad_x, reduction='none')
    loss_y = mask_y * F.l1_loss(pred_grad_y, target_grad_y, reduction='none')
    return (loss_x.mean() + loss_y.mean()) / 2

def differentiable_histogram(
    x: torch.Tensor, bins: int = 256, bandwidth: float = 0.01
) -> torch.Tensor:
    """Soft per-channel histogram of *x* over ``bins`` centres spread on [0, 1].

    Every value contributes a triangular weight ``clamp(1 - |v - centre| / bandwidth, 0, 1)``
    to each bin, which keeps the histogram differentiable. The result has shape
    ``[B, C, bins]`` and is divided by ``H * W``.
    """
    n_batch, n_channels = x.shape[0], x.shape[1]
    centers = torch.linspace(0, 1, bins, device=x.device)  # [bins]

    values = x.reshape(n_batch, n_channels, -1).unsqueeze(-1)  # [B, C, H*W, 1]
    soft_counts = (1 - (values - centers).abs() / bandwidth).clamp(0, 1)  # [B, C, H*W, bins]

    pixels_per_channel = x.shape[2] * x.shape[3]
    return soft_counts.sum(dim=2) / pixels_per_channel  # [B, C, bins]

def hist_loss(pred: torch.Tensor, target: torch.Tensor, bins: int = 256, bandwidth: float = 0.1) -> torch.Tensor:
    """L1 distance between the soft histograms of *pred* and *target*."""
    soft_hists = [
        differentiable_histogram(tensor, bins=bins, bandwidth=bandwidth)
        for tensor in (pred, target)
    ]
    return F.l1_loss(*soft_hists)

In [None]:
def motion_blur(image: torch.Tensor, length: int = 15, angle: float = 0.0) -> torch.Tensor:
    """Blur *image* with a normalized straight-line kernel.

    Args:
        image: batch of images, shape ``[B, C, H, W]``.
        length: kernel side; forced to the next odd value so it has a centre.
        angle: direction of the line in degrees.

    Returns:
        Blurred tensor with the same shape as *image* (depthwise convolution
        with ``padding = length // 2``).

    Changes from the previous version:
      * The line is rasterized along its dominant axis. Stepping only over
        columns left steep lines with gaps, and a 90 degree line collapsed to a
        single pixel, i.e. no blur at all. Shallow angles give the same kernel
        as before.
      * The kernel is replicated per input channel (it was hard-coded to 3
        while ``groups`` followed the image) and moved to the image's device
        and dtype.
    """
    def _get_motion_kernel(length: int, angle_deg: float) -> torch.Tensor:
        import math
        angle = math.radians(angle_deg)
        sin_a, cos_a = math.sin(angle), math.cos(angle)
        kernel = torch.zeros((length, length))
        center = length // 2
        # Step along whichever axis the line covers faster, so every step lands
        # on an adjacent cell.
        steep = abs(sin_a) > abs(cos_a)

        for step in range(length):
            offset = step - center
            if steep:
                row = step
                col = center + round(cos_a / sin_a * offset)
            else:
                col = step
                row = center + round(math.tan(angle) * offset)
            if 0 <= row < length and 0 <= col < length:
                kernel[row, col] = 1.0
        # The centre cell is always set, so the sum is never zero.
        return kernel / kernel.sum()

    length |= 1  # make odd to have an odd-sized kernel
    channels = image.shape[1]
    kernel = _get_motion_kernel(length, angle).to(device=image.device, dtype=image.dtype)
    return F.conv2d(
        image,
        kernel.view(1, 1, length, length).repeat(channels, 1, 1, 1),
        padding=length // 2,
        groups=channels,
    )

def bokeh_blur(image: torch.Tensor, radius: int = 3) -> torch.Tensor:
    """Blur *image* with a normalized, uniform disk kernel of the given radius.

    Args:
        image: batch of images, shape ``[B, C, H, W]``.
        radius: disk radius in pixels; the kernel side is ``2 * radius + 1``.

    Returns:
        Blurred tensor with the same shape as *image* (depthwise convolution).

    The kernel is replicated once per input channel (it used to be hard-coded
    to 3 while ``groups`` followed the image) and moved to the image's device
    and dtype. Integer coordinates keep the ``<= radius**2`` test exact.
    """
    kernel_size = 2 * radius + 1
    coords = torch.arange(-radius, radius + 1)
    y, x = torch.meshgrid(coords, coords, indexing='ij')

    # Uniform weight inside the disk; the centre is always included, so sum > 0.
    kernel = ((x**2 + y**2) <= radius**2).to(torch.float32)
    kernel = kernel / kernel.sum()
    kernel = kernel.to(device=image.device, dtype=image.dtype)

    channels = image.shape[1]
    return F.conv2d(
        image,
        kernel.view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1),
        padding=kernel_size // 2,
        groups=channels,
    )

def anisotropic_gaussian_blur(
    image: torch.Tensor,
    sigma_x: float = 1.0,
    sigma_y: float = 1.0,
    angle: float = 0.0
) -> torch.Tensor:
    """Blur *image* with a rotated, axis-dependent Gaussian kernel.

    Args:
        image: batch of images, shape ``[B, C, H, W]``.
        sigma_x: standard deviation along the kernel's (rotated) x axis, i.e.
            the horizontal image axis when ``angle`` is 0.
        sigma_y: standard deviation along the (rotated) y axis.
        angle: rotation of the kernel in degrees.

    Returns:
        Blurred tensor with the same shape as *image* (depthwise convolution).

    Fixes relative to the previous version:
      * ``-kernel_size//2`` parses as ``(-kernel_size)//2``, so the sampling
        grid ran e.g. from -4 to 3 and the kernel was off-centre, shifting the
        image. The grid is now symmetric around zero.
      * x now varies along columns and y along rows, matching ``bokeh_blur``.
      * The kernel is replicated per input channel instead of a hard-coded 3
        and moved to the image's device and dtype.
    """
    def _get_rotated_gaussian_kernel(sigma_x, sigma_y, angle_deg, kernel_size):
        import math
        angle = math.radians(angle_deg)
        half = kernel_size // 2
        coords = torch.linspace(-half, half, kernel_size)
        # 'ij' indexing: the first grid varies along rows (y), the second along columns (x).
        y, x = torch.meshgrid(coords, coords, indexing='ij')

        x_rot = x * math.cos(angle) + y * math.sin(angle)
        y_rot = -x * math.sin(angle) + y * math.cos(angle)

        kernel = torch.exp(-(x_rot**2 / (2 * sigma_x**2) + y_rot**2 / (2 * sigma_y**2)))
        return kernel / kernel.sum()

    # Cover about +/- 3 sigma of the wider axis.
    kernel_size = int(2 * 3 * max(sigma_x, sigma_y) + 1) | 1  # make odd to have an odd-sized kernel
    kernel = _get_rotated_gaussian_kernel(sigma_x, sigma_y, angle, kernel_size)
    kernel = kernel.to(device=image.device, dtype=image.dtype)

    channels = image.shape[1]
    return F.conv2d(
        image,
        kernel.view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1),
        padding=kernel_size // 2,
        groups=channels,
    )

In [None]:
# Fixed 3x3 Gaussian blur with sigma 1.0.
gaussian_blur = GaussianBlur(kernel_size=3, sigma=1.0)

# Pool of blur augmentations sampled by AstroDataset._corrupt. Every entry maps an
# image tensor to a blurred tensor and draws its random parameters on each call.
blur_functions = [
    # plain Gaussian blur
    lambda x: gaussian_blur(x),
    # line blur; motion_blur rounds length up to odd, so the kernel side is 1 or 3
    lambda x: motion_blur(x, length=random.randint(1, 3), angle=random.randint(0, 360)),
    # Gaussian with equal sigmas; only the angle is randomized
    lambda x: anisotropic_gaussian_blur(
        x,
        sigma_x=1,
        sigma_y=1,
        angle=random.randint(0, 360),
    ),
    # disk-shaped blur with a random radius of 1 to 3 pixels
    lambda x: bokeh_blur(x,radius=random.randint(1, 3)),
]

In [None]:
def filter_paths(paths: list, min_size: int = 256) -> list:
    """Keep only image paths whose shorter side is at least *min_size* pixels.

    Args:
        paths: candidate file paths.
        min_size: minimum allowed length of the shorter image side.

    Returns:
        The accepted paths, in their original order.

    Files whose names do not end in jpg/jpeg/png/tif/tiff are dropped; the
    match is case-insensitive. TIFFs are measured through ``cv2.imread`` and
    everything else through PIL. TIFFs that OpenCV cannot decode are skipped
    with a warning instead of failing on ``None.shape``. PIL files are opened
    in a ``with`` block so their handles are closed.
    """
    import warnings

    from PIL import Image
    # Disable PIL's decompression-bomb limit so very large images can be opened.
    Image.MAX_IMAGE_PIXELS = None

    image_suffixes = ('jpg', 'png', 'tiff', 'tif', 'jpeg')
    tiff_suffixes = ('tiff', 'tif')

    res = []
    for path in paths:
        lowered = path.lower()
        if not lowered.endswith(image_suffixes):
            continue
        if lowered.endswith(tiff_suffixes):
            data = cv2.imread(path, cv2.IMREAD_UNCHANGED)
            if data is None:
                warnings.warn(f'filter_paths: could not read {path!r}, skipping')
                continue
            shortest_side = min(data.shape[:2])
        else:
            with Image.open(path) as img:
                shortest_side = min(img.size)
        if shortest_side >= min_size:
            res.append(path)
    return res

In [None]:
T = -0.51082562376  # ln(0.6), where 0.6 is a desired expected value of autostretched image


def autostretch_torch(image: torch.Tensor, eps: float = 1e-2):
    """Min-max normalize *image* and apply a rational scale plus gamma stretch.

    Steps: normalize to [0, 1] over the whole tensor, multiply each pixel by
    ``(1 + eps) / (pixel + eps)``, and raise the result to a gamma derived from
    ``T`` and the normalized mean.

    Returns a dict with the stretched image under ``'labels'`` and the values
    needed to undo the stretch: ``'scale'`` (per-pixel), ``'gamma'``, ``'min'``
    and ``'max'``.
    """
    lowest, highest = image.min(), image.max()
    normalized = (image - lowest) / (highest - lowest)

    avg = normalized.mean()
    avg_after_scale = ((1 + eps) / (avg + eps)) * avg
    gamma = T / torch.log(avg_after_scale)

    gain = (1 + eps) / (normalized + eps)
    stretched = (normalized * gain) ** gamma

    return dict(labels=stretched, scale=gain, gamma=gamma, min=lowest, max=highest)

In [None]:
class AstroDataset(Dataset):
    """Dataset yielding (degraded input, clean target) pairs from a folder of images.

    Each sample is read with OpenCV, passed through ``transform``, optionally
    auto-stretched (``is_linear=True``) and blurred by a random chain of
    ``blur_functions`` to form ``pixel_values``. One shared noise pattern is
    then added to both ``pixel_values`` and ``labels``.
    """

    def __init__(
        self,
        root_dir: str,
        processor: Swin2SRImageProcessor,
        transform: Compose,
        is_linear: bool = False,
        corrupt_dropout: float = 0.05,
        noise_base_amount: float = 0.2,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.root_dir = root_dir
        self.processor = processor
        self.transform = transform
        self.is_linear = is_linear
        # probability of leaving a sample un-blurred
        self.corrupt_dropout = corrupt_dropout
        # upper bound on the random noise amplitude
        self.noise_base_amount = noise_base_amount

        candidates = [os.path.join(root_dir, name) for name in os.listdir(root_dir)]
        self.image_paths = filter_paths(candidates)

    def __len__(self):
        return len(self.image_paths)

    def _corrupt(self, image: torch.Tensor):
        """Apply between 2 and ``len(blur_functions)`` random blurs, drawn with replacement."""
        if random.random() < self.corrupt_dropout:
            return image

        n_blurs = random.randint(2, len(blur_functions))
        for apply_blur in random.choices(blur_functions, k=n_blurs):
            image = apply_blur(image)
        return image

    def _get_noise_pattern(self, image: torch.Tensor):
        """Gaussian noise shaped like *image*, scaled by a random fraction of ``noise_base_amount``."""
        noise = torch.randn(image.shape)
        strength = random.random()
        return noise * strength * self.noise_base_amount

    def __getitem__(self, idx):
        raw = cv2.imread(self.image_paths[idx], cv2.IMREAD_UNCHANGED)
        image = self.transform(raw)

        if self.is_linear:
            stretch_output = autostretch_torch(image)
        else:
            stretch_output = {'labels': image}

        # Blur a detached copy (with a temporary batch axis) so the labels stay clean.
        clean_batch = stretch_output['labels'].unsqueeze(0).clone().detach()
        corrupted_tensor = self._corrupt(clean_batch)[0]

        item = {'pixel_values': corrupted_tensor, **stretch_output}

        # The same noise goes onto the input and the target.
        noise_pattern = self._get_noise_pattern(image)
        for key in ('pixel_values', 'labels'):
            item[key] += noise_pattern
        return item

class AstroSwin2SR(Swin2SRForImageSuperResolution):
    """Swin2SR backbone with the upsampling head replaced by a single 3x3 convolution.

    The convolution maps the backbone's 60-channel ``last_hidden_state`` to 3
    output channels and keeps the spatial size (stride 1, padding 1).
    """

    def __init__(self, config):
        super().__init__(config)
        # Drop the parent's upsample module; `resample` is used in its place.
        del self.upsample
        self.resample = torch.nn.Conv2d(
            in_channels=60,
            out_channels=3,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )

    def forward(self, pixel_values: torch.Tensor, labels: torch.Tensor = None):
        """Return ``{'outputs': restored}``; ``labels`` is accepted but not used here."""
        features = self.swin2sr(pixel_values=pixel_values).last_hidden_state
        return {'outputs': self.resample(features)}

In [None]:
class TrainerWithCustomLoss(Trainer):
    """Trainer that scores the model's ``'outputs'`` tensor with a combined image loss."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.process_state = []

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute ``2 * sharp_loss + 2 * gradient_loss + hist_loss`` on the prediction.

        Args:
            model: model returning a dict with an ``'outputs'`` tensor.
            inputs: batch dict. ``'labels'`` is used as the target; ``'scale'``,
                ``'gamma'``, ``'min'`` and ``'max'`` are removed before the
                forward pass because the model does not accept them.
            return_outputs: when true, return ``(loss, model_output_dict)``.
            **kwargs: extra arguments passed by ``Trainer``; ignored.

        Raises:
            ValueError: if the batch has no labels and the model returned no loss.

        The fallback branch and the returned outputs now use the model's result
        dict. Previously they used the extracted tensor, so ``outputs[0]``
        became the first sample of the batch, and ``Trainer.prediction_step``
        sliced the batch with ``outputs[1:]``.
        """
        def _calculate_loss(outputs: torch.Tensor, labels: torch.Tensor):
            h_loss = hist_loss(outputs, labels, bins=256, bandwidth=5e-2)
            g_loss = gradient_loss(outputs, labels, temperature=0.5)
            d_loss = sharp_loss(outputs, labels, alpha=0.65)
            return d_loss * 2.0 + g_loss * 2.0 + h_loss * 1.0

        labels = inputs.pop('labels', None)
        # Auto-stretch metadata produced by the dataset is not a model input.
        for key in ('scale', 'gamma', 'min', 'max'):
            inputs.pop(key, None)

        result = model(**inputs)
        outputs = result['outputs']

        # Save past state if it exists
        # NOTE(review): this model returns no past state, so a non-negative
        # past_index would store a slice of the prediction — confirm it is never set.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = _calculate_loss(outputs, labels)
        else:
            if not isinstance(result, dict) or 'loss' not in result:
                raise ValueError(
                    'No labels were provided and the model did not return a loss; '
                    'cannot compute a training loss.'
                )
            loss = result['loss'].mean()
        return (loss, result) if return_outputs else loss

In [None]:
# Load the model weights and the matching image processor from the local checkpoint
# directory (presumably extracted from astroswin.zip — verify), then move the model
# to the selected device.
aswin = AstroSwin2SR.from_pretrained('aswin-1.2-checkpoint-900').to(device)
processor = Swin2SRImageProcessor.from_pretrained('aswin-1.2-checkpoint-900')

In [None]:
# Per-sample preprocessing applied by AstroDataset to the array returned by cv2.imread.
transform = Compose([
    # channel fix-up + range rescaling through the HF processor -> 'pixel_values' array
    lambda img: init_process(img, processor),
    # take the first element and reverse its axes
    # (presumably (C, H, W) -> (W, H, C) so cv2.resize sees channels last — TODO confirm)
    lambda img: img[0].transpose(2, 1, 0),
    # random 1x or 2x Lanczos downsampling
    lambda img: random_downsample(img),
    # reverse the axes again and convert to a float32 tensor
    lambda img: to_tensor(img),
    #lambda img: rotator(img),  -- excluded from linear data train pipeline, because random rotation affects autostretch in a very bad way
    # random 256x256 crop
    lambda img: crop(img),
])

In [None]:
# Linear (unstretched) data: auto-stretch is enabled, 5% of samples skip the blur
# chain, and additive noise is disabled (noise_base_amount=0.0).
train_dataset = AstroDataset('linear_data/train', processor, transform, is_linear=True, corrupt_dropout=0.05, noise_base_amount=0.0)
eval_dataset = AstroDataset('linear_data/test', processor, transform, is_linear=True, corrupt_dropout=0.05, noise_base_amount=0.0)

In [None]:
# Training configuration: batch size 1 with 4-step gradient accumulation, fp16 mixed
# precision, evaluation and logging every 30 steps, a checkpoint every 60 steps.
args = TrainingArguments(
    output_dir='astro_model',
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    learning_rate=2e-4,
    gradient_checkpointing=False,
    gradient_accumulation_steps=4,
    num_train_epochs=12,
    logging_steps=30,
    logging_strategy='steps',
    eval_steps=30,
    eval_strategy='steps',
    save_strategy='steps',
    save_steps=60,
    fp16=True,
    report_to='none',
    # keep 'scale', 'gamma', 'min', 'max' in the batch; compute_loss pops them itself
    remove_unused_columns=False,
)

# Trainer subclass whose compute_loss applies the custom combined image loss.
trainer = TrainerWithCustomLoss(
    model=aswin,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

In [None]:
# Run the fine-tuning loop configured above.
trainer.train()

In [None]:
# Save the fine-tuned weights and the processor config side by side, so the
# directory can be loaded again with from_pretrained.
aswin.save_pretrained('aswin-1.3-checkpoint-300')
processor.save_pretrained('aswin-1.3-checkpoint-300')

In [None]:
# Pack the saved checkpoint directory into a single archive.
!zip -r astroswin_v1_3_checkpoint300.zip aswin-1.3-checkpoint-300/

# Inference Example

In [None]:
def tensor_to_pil(tensor: torch.Tensor):
    """Convert the first image of a ``[B, C, H, W]`` tensor to an ``[H, W, C]`` array.

    Despite the name, the result is a float32 numpy array clipped to [0, 1],
    not a PIL image. The tensor is detached and moved to the CPU first, so
    tensors that require grad or live on the GPU no longer make ``.numpy()``
    raise.
    """
    np_values = tensor.detach().cpu().numpy()[0]
    np_values = np.clip(np_values, 0, 1)
    np_values = np.moveaxis(np_values, source=0, destination=-1)
    return np_values.astype(np.float32)

@torch.no_grad()
def terminate_blur(
    image: np.ndarray, model: AstroSwin2SR, processor: Swin2SRImageProcessor, window: int = 256
):
    """Run *model* over *image* in overlapping tiles and blend the results.

    The image is processed (without rescaling) into a tensor and copied into a
    zero canvas whose sides are the next multiple of ``window`` above the image
    size. Each ``window`` tile is extended by ``pad`` (32) pixels on every side
    that has room, passed through the model, and accumulated with a weight mask
    that fades linearly across the padded borders. The accumulated sum is
    divided by the accumulated weights at the end.

    Returns a CPU tensor cropped to ``image.shape[0] x image.shape[1]`` in its
    last two dimensions. Uses the module-level ``device``.
    """
    from tqdm import tqdm
    from gc import collect

    def create_weight_mask(size, overlap):
        # Ones in the middle, linear 0..1 ramps of width `overlap` on every border.
        mask = torch.ones(1, 1, size + 2 * overlap, size + 2 * overlap)
        fade = torch.linspace(0, 1, overlap)

        # vertical borders
        mask[..., :overlap, :] *= fade.view(1, 1, -1, 1)
        mask[..., -overlap:, :] *= fade.flip(0).view(1, 1, -1, 1)

        # horizontal borders
        mask[..., :, :overlap] *= fade.view(1, 1, 1, -1)
        mask[..., :, -overlap:] *= fade.flip(0).view(1, 1, 1, -1)
        return mask.to(device)

    pad = 32
    # Canvas size: the next multiple of `window` strictly above each image side.
    pad_based_width, pad_based_height = (image.shape[1] // window + 1) * window, (image.shape[0] // window + 1) * window

    img_tensor = processor(image, do_rescale=False, return_tensors='pt').pixel_values.to(device)
    pad_based_img = torch.zeros(1, 3, pad_based_height, pad_based_width).to(device)
    pad_based_img[:, :, :img_tensor.shape[-2], :img_tensor.shape[-1]] += img_tensor
    target = torch.zeros_like(pad_based_img)
    weight_sum = torch.zeros_like(target)
    weight_mask = create_weight_mask(window, pad)

    for x in tqdm(range(0, pad_based_width, window)):
        for y in range(0, pad_based_height, window):
            # calculate coordinates
            x_from = max(0, x - pad)
            y_from = max(0, y - pad)
            x_to = x + min(window + pad, pad_based_width - x)
            y_to = y + min(window + pad, pad_based_height - y)
            # pass patch through model
            patch_tensor = pad_based_img[:, :, y_from:y_to, x_from:x_to]
            outputs = model(pixel_values=patch_tensor)['outputs'].detach()
            # apply mask
            mask_x_from = 0 if x_from - pad >= 0 else pad # 0 if the mask and the patch fit completely; pad if it has to be cropped
            mask_y_from = 0 if y_from - pad >= 0 else pad
            mask_x_to = window + 2 * pad if x_to + pad <= pad_based_width else window + pad # win+2pad if the mask fits completely on the right; win+pad if it has to be cropped
            mask_y_to = window + 2 * pad if y_to + pad <= pad_based_height else window + pad
            # add patch
            cropped_mask = weight_mask[:, :, mask_y_from:mask_y_to, mask_x_from:mask_x_to]
            target[:, :, y_from:y_to, x_from:x_to] += outputs * cropped_mask
            weight_sum[:, :, y_from:y_to, x_from:x_to] += cropped_mask
            # remove tmp tensor
            del outputs
        # forced memory clean up
        collect()
        torch.cuda.empty_cache()
    # Normalize by the accumulated weights; the clamp avoids division by zero.
    target /= weight_sum.clamp(min=1e-6)
    return target.detach().cpu()[:, :, :image.shape[0], :image.shape[1]]

In [None]:
# Read the image as stored (bit depth and channel count untouched).
img = cv2.imread('test.tif', cv2.IMREAD_UNCHANGED)

# Stretch the linear data, then run the tiled model over the stretched image.
# NOTE(review): training reverses the channel order in init_process (BGR -> RGB);
# here the cv2 array is fed to the model as loaded — confirm this is intended.
stretch_res = autostretch_torch(torch.Tensor(img))
processed_tensor = terminate_blur(stretch_res['labels'].numpy(), aswin, processor)

# Invert the stretch: undo gamma, undo the per-pixel scale, undo min-max normalization.
scale, gamma = stretch_res['scale'], stretch_res['gamma']
im_min, im_max = stretch_res['min'], stretch_res['max']
linear_tensor = (processed_tensor[0].permute(1,2,0) ** (1/gamma) / scale) * (im_max - im_min) + im_min
res = np.clip(linear_tensor.numpy(), 0, 1)

# Write an uncompressed float32 TIFF.
cv2.imwrite('test_processed.tif', res.astype(np.float32), [cv2.IMWRITE_TIFF_COMPRESSION, 0])