# OPENSOURCE BLURXTERMINATOR

This Jupyter notebook provides all the functions needed to run AstroSWIN, a Russian open-source model for sharpening astrophotography images. The model is a fine-tune of the image-to-image transformer `caidas/swin2SR-lightweight-x2-64` (the original model card is located [here](https://huggingface.co/caidas/swin2SR-lightweight-x2-64/tree/main)).

The model was trained in 7 iterations on data from the AstroBin and ESA/Hubble archives. All exclusive image rights belong to their respective authors, and this model may not be used to reproduce their results or to generate look-alike images.

## Download the model

Training produced two primary model checkpoints: v0.6 sharpens somewhat more aggressively, while v0.7 sharpens more softly. They also differ in how they sharpen the image background, which was the hardest part of the training process.

The model weights are stored on my personal Google Drive and are publicly available for download:
1. v0.6: https://drive.google.com/file/d/1N6s8O9MESdfOz4uCzbrl4GsKGZhJw2iK/view?usp=sharing
2. v0.7: https://drive.google.com/file/d/10ZR4X57PoqXunGMFig4cWo9Ef7X1uz3m/view?usp=sharing

Anyone is free to modify the model's architecture and fine-tune it further.

In [None]:
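# Download one or both checkpoints; the loading cell below expects a folder named astroswin_<version>
!gdown 1N6s8O9MESdfOz4uCzbrl4GsKGZhJw2iK -O astroswin_v0_6.zip && unzip astroswin_v0_6.zip  # v0.6
!gdown 10ZR4X57PoqXunGMFig4cWo9Ef7X1uz3m -O astroswin_v0_7.zip && unzip astroswin_v0_7.zip  # v0.7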

## Initialization

Below we define the required imports and the auxiliary functions used to load the model and run inference.

In [None]:
import numpy as np
import torch

from PIL import Image
from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor

In [None]:
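class AstroSwin2SR(Swin2SRForImageSuperResolution):
    """Swin2SR backbone whose x2 upsampling head is replaced by a same-resolution RGB head."""

    def __init__(self, config):
        super().__init__(config)
        # Drop the original x2 upsampling head...
        del self.upsample
        # ...and map the 60 backbone feature channels straight back to 3 RGB channels;
        # a 3x3 convolution with stride 1 and padding 1 keeps height and width unchanged
        self.resample = torch.nn.Conv2d(60, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self, pixel_values: torch.Tensor, labels: torch.Tensor = None):
        # `labels` is not used by this forward pass
        output = self.swin2sr(pixel_values=pixel_values)
        output = self.resample(output.last_hidden_state)
        return {'outputs': output}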
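The replacement head is what turns the x2 super-resolution model into a same-size sharpening model. The standard convolution output-size formula shows why: a 3x3 kernel with stride 1 and padding 1 leaves height and width unchanged, so each sharpened tile comes out at the resolution of the input tile. The plain-Python sketch below only illustrates that arithmetic; `conv2d_output_size` is a hypothetical helper, not a torch or transformers function.

```python
def conv2d_output_size(size: int, kernel: int, stride: int = 1,
                       padding: int = 0, dilation: int = 1) -> int:
    """Output length of a 2-D convolution along one spatial axis.

    Same formula as documented for torch.nn.Conv2d:
    floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
    """
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# The replacement head: Conv2d(60, 3, kernel_size=3, stride=1, padding=1)
IN_CHANNELS, OUT_CHANNELS, KERNEL = 60, 3, 3

# Patch sizes produced by the tiling code with the default window=256 and 32-pixel overlap
for size in (256, 288, 320):
    print(size, '->', conv2d_output_size(size, KERNEL, stride=1, padding=1))

# Trainable parameters of the new head:
# one 3x3 kernel per (output, input) channel pair plus one bias per output channel
head_params = OUT_CHANNELS * IN_CHANNELS * KERNEL * KERNEL + OUT_CHANNELS
print(head_params)  # 1623
```

The head is tiny compared with the backbone, which is one reason fine-tuning it together with the pretrained Swin2SR layers is practical.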

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
version = 'v0_7'  # softer sharpening; use 'v0_6' for stronger sharpening

aswin = AstroSwin2SR.from_pretrained(f'astroswin_{version}').eval().to(device)
processor = Swin2SRImageProcessor.from_pretrained(f'astroswin_{version}')
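Besides converting the image to a tensor, the processor loaded here also rescales and pads it. Assuming this checkpoint keeps the `Swin2SRImageProcessor` defaults (`rescale_factor=1/255`, `do_pad=True`, `pad_size=8`), pixel values end up in [0, 1] and the bottom and right edges are padded up to a multiple of 8, so the tensor can be slightly larger than the image. The plain-Python sketch below illustrates the size arithmetic under that assumption (the exact padding rule is worth double-checking against your installed transformers version) and checks that the padded tensor always fits on the canvas allocated in `terminate_blur`; both helper names are hypothetical.

```python
def processor_padded_size(size: int, pad_size: int = 8) -> int:
    """Assumed Swin2SRImageProcessor padding rule along one axis: pad the
    bottom/right edge up to the next multiple of pad_size, adding a full
    pad_size even when the side is already divisible."""
    return (size // pad_size + 1) * pad_size


def canvas_size(size: int, window: int = 256) -> int:
    """Per-axis canvas length allocated in terminate_blur (same formula as the notebook)."""
    return (size // window + 1) * window


# Because window (256) is a multiple of pad_size (8), the padded processor
# output never exceeds the canvas, so the copy into pad_based_img always fits
assert all(processor_padded_size(s) <= canvas_size(s) for s in range(1, 10_000))

for side in (1000, 1024):
    print(side, processor_padded_size(side), canvas_size(side))
# 1000 1008 1024
# 1024 1032 1280
```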

## Blur terminating functions

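Below we define a tensor-to-image converter (it returns a `uint8` NumPy array ready for `Image.fromarray`) and an all-in-one function that runs image sharpening. The function pads the image to a multiple of `window`, processes it tile by tile with a 32-pixel overlap between neighboring tiles, and blends the overlapping predictions with a linearly fading weight mask to avoid visible seams.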
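The key idea behind the sharpening function is overlap-and-blend tiling: each tile is predicted with extra context around it, and overlapping predictions are averaged with weights that fade linearly to zero at the tile borders. The 1-D, plain-Python sketch below mirrors what `create_weight_mask` and the final `target / weight_sum` normalization do along one axis; `fade_profile` and `blend_1d` are hypothetical helpers, and the small window and overlap values are chosen only to keep the numbers readable.

```python
def fade_profile(window: int, overlap: int) -> list[float]:
    """1-D slice of the weight mask: linear ramp 0 -> 1, plateau of ones, ramp 1 -> 0.

    The ramp matches torch.linspace(0, 1, overlap): overlap evenly spaced values
    from 0 to 1 inclusive.
    """
    ramp = [i / (overlap - 1) for i in range(overlap)]
    return ramp + [1.0] * window + ramp[::-1]


def blend_1d(tiles, length: int) -> list[float]:
    """Weighted average of overlapping 1-D tiles given as (start, values, weights)."""
    acc = [0.0] * length
    wsum = [0.0] * length
    for start, values, weights in tiles:
        for i, (v, w) in enumerate(zip(values, weights)):
            acc[start + i] += v * w
            wsum[start + i] += w
    # Same guard as weight_sum.clamp(min=1e-6) in the notebook
    return [a / max(w, 1e-6) for a, w in zip(acc, wsum)]


WINDOW, OVERLAP, LENGTH = 8, 4, 16
profile = fade_profile(WINDOW, OVERLAP)  # 4 + 8 + 4 = 16 weights

# Two tiles on a 16-pixel line, cropped the way the notebook crops its mask:
# tile A touches the left border, so its left ramp is cut off;
# tile B touches the right border, so its right ramp is cut off.
tile_a = (0, [1.0] * 12, profile[OVERLAP:])          # covers [0, 12), predicts 1.0
tile_b = (4, [3.0] * 12, profile[:WINDOW + OVERLAP])  # covers [4, 16), predicts 3.0

blended = blend_1d([tile_a, tile_b], LENGTH)
print([round(v, 2) for v in blended])
# [1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.8, 2.0, 2.0, 2.2, 2.5, 3.0, 3.0, 3.0, 3.0, 3.0]
```

Instead of a hard step from 1.0 to 3.0 at a tile border, the blended line changes gradually across the overlap, which is what suppresses seams between tiles in the 2-D case.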

In [None]:
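def tensor_to_pil(tensor: torch.Tensor) -> np.ndarray:
    # Convert a (1, 3, H, W) CPU tensor with values in [0, 1] into an (H, W, 3) uint8 array for Image.fromarray
    np_values = tensor.numpy()[0]                                 # drop the batch dimension -> (3, H, W)
    np_values = np.clip(np_values, 0, 1)                          # clamp out-of-range predictions
    np_values = np.moveaxis(np_values, source=0, destination=-1)  # CHW -> HWC
    return (np_values * 255.0).round().astype(np.uint8)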

@torch.no_grad()
def terminate_blur(
    image: Image.Image, model: AstroSwin2SR, processor: Swin2SRImageProcessor, window: int = 256
):
    from tqdm import tqdm
    from gc import collect

    # Build the weight mask used to blend overlapping tiles
    def create_weight_mask(size, overlap):
        mask = torch.ones(1, 1, size + 2 * overlap, size + 2 * overlap)
        fade = torch.linspace(0, 1, overlap)

        # Top and bottom edges (fade along the height axis)
        mask[..., :overlap, :] *= fade.view(1, 1, -1, 1)
        mask[..., -overlap:, :] *= fade.flip(0).view(1, 1, -1, 1)

        # Left and right edges (fade along the width axis)
        mask[..., :, :overlap] *= fade.view(1, 1, 1, -1)
        mask[..., :, -overlap:] *= fade.flip(0).view(1, 1, 1, -1)
        return mask.to(device)

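    pad = 32  # overlap (extra context) in pixels added around each tile
    # Canvas size: image size rounded up to the next multiple of `window` (always strictly larger than the image)
    pad_based_width, pad_based_height = (image.width // window + 1) * window, (image.height // window + 1) * window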

    img_tensor = processor(image, return_tensors='pt').pixel_values.to(device)
    pad_based_img = torch.zeros(1, 3, pad_based_height, pad_based_width).to(device)
    pad_based_img[:, :, :img_tensor.shape[-2], :img_tensor.shape[-1]] += img_tensor
    target = torch.zeros_like(pad_based_img)
    weight_sum = torch.zeros_like(target)
    weight_mask = create_weight_mask(window, pad)

    for x in tqdm(range(0, pad_based_width, window)):
        for y in range(0, pad_based_height, window):
            # calculate coordinates
            x_from = max(0, x - pad)
            y_from = max(0, y - pad)
            x_to = x + min(window + pad, pad_based_width - x)
            y_to = y + min(window + pad, pad_based_height - y)
            # pass patch through model
            patch_tensor = pad_based_img[:, :, y_from:y_to, x_from:x_to]
            outputs = model(pixel_values=patch_tensor)['outputs'].detach()
            # apply mask
            mask_x_from = 0 if x_from - pad >= 0 else pad  # 0 if the mask and patch fit entirely; pad if the mask must be cropped on the left
            mask_y_from = 0 if y_from - pad >= 0 else pad
            mask_x_to = window + 2 * pad if x_to + pad <= pad_based_width else window + pad  # window + 2*pad if the mask fits entirely on the right; window + pad if it must be cropped
            mask_y_to = window + 2 * pad if y_to + pad <= pad_based_height else window + pad
            # add patch
            cropped_mask = weight_mask[:, :, mask_y_from:mask_y_to, mask_x_from:mask_x_to]
            target[:, :, y_from:y_to, x_from:x_to] += outputs * cropped_mask
            weight_sum[:, :, y_from:y_to, x_from:x_to] += cropped_mask
            # free the patch output before the next iteration
            del outputs
        collect()
        torch.cuda.empty_cache()
    target /= weight_sum.clamp(min=1e-6)
    return Image.fromarray(tensor_to_pil(target.cpu()[:, :, :image.height, :image.width]))
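To check the index bookkeeping in `terminate_blur` without loading the model, the per-axis arithmetic can be replayed in plain Python. The sketch below copies the coordinate formulas from the loop above into a hypothetical helper, `tile_plan`, and confirms that every patch has exactly the same length as its cropped mask, which is what lets `outputs * cropped_mask` broadcast over the channel dimension.

```python
def tile_plan(size: int, window: int = 256, pad: int = 32):
    """Per-axis tiling arithmetic copied from terminate_blur.

    Returns (canvas, tiles), where each tile is
    (patch_from, patch_to, mask_from, mask_to) along one axis.
    """
    canvas = (size // window + 1) * window
    tiles = []
    for start in range(0, canvas, window):
        patch_from = max(0, start - pad)
        patch_to = start + min(window + pad, canvas - start)
        # The mask is cropped on any side where the patch is clipped by the canvas border
        mask_from = 0 if patch_from - pad >= 0 else pad
        mask_to = window + 2 * pad if patch_to + pad <= canvas else window + pad
        tiles.append((patch_from, patch_to, mask_from, mask_to))
    return canvas, tiles


canvas, tiles = tile_plan(1000)
print(canvas)  # 1024
for tile in tiles:
    print(tile)
# (0, 288, 32, 320)
# (224, 544, 0, 320)
# (480, 800, 0, 320)
# (736, 1024, 0, 288)

# Every patch must have the same length as its cropped mask,
# otherwise `outputs * cropped_mask` would fail with a shape mismatch
assert all(p_to - p_from == m_to - m_from for p_from, p_to, m_from, m_to in tiles)
```

The number of model calls is the product of the per-axis tile counts: a 1000 x 600 image with the default `window=256` takes 4 x 3 = 12 calls, since `tile_plan(600)` yields 3 tiles.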

Here is an example of processing an image:

In [None]:
image = Image.open('rgb_GraXpert_1.tiff').convert('RGB')

In [None]:
processed = terminate_blur(image, aswin, processor, window=256)

In [None]:
processed.save('processed_2.tiff')  # the output format is inferred from the file extension