In [None]:
!gdown 10ZR4X57PoqXunGMFig4cWo9Ef7X1uz3m -O astroswin.zip && unzip -qq astroswin.zip && rm -rf astroswin.zip
!gdown 1v1HvfuoQPMprDa5FgRvjRMlwIwt5mp84 -O dataset.zip && unzip -qq dataset.zip && rm -f dataset.zip

In [None]:
import os
import random
import torch
import torch.nn.functional as F

from gc import collect
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms.v2 import RandomCrop, GaussianBlur, Compose, RandomRotation
from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor, Trainer, TrainingArguments

In [None]:
# Lift PIL's decompression-bomb pixel limit so very large source images can be opened.
# NOTE(review): this disables a safety check — only appropriate for trusted image files.
Image.MAX_IMAGE_PIXELS = None

# Training runs on the GPU (the fp16 setting in TrainingArguments below presumably needs CUDA too).
device = torch.device('cuda')

In [None]:
def random_downsample(image, base_dim: int = 256):
    """Randomly shrink a PIL image by an integer factor.

    The factor is drawn uniformly from ``1 .. min(width, height) // base_dim``, so an image whose
    short side was at least ``base_dim`` keeps a short side of at least ``base_dim`` afterwards.
    Images whose short side is below ``2 * base_dim`` get factor 1 (a same-size resized copy).

    Args:
        image: PIL image (uses ``.width``, ``.height`` and ``.resize``).
        base_dim: smallest short side allowed after downscaling; keep it equal to the crop size
            applied afterwards (the notebook's transform crops 256x256).

    Returns:
        The bicubically resized PIL image.
    """
    min_dim = min(image.width, image.height)
    downscale_range = min_dim // base_dim
    # Only consume a random draw when there is an actual choice (keeps the RNG stream unchanged).
    downscale_mul = random.randint(1, downscale_range) if downscale_range > 1 else 1
    return image.resize(
        (image.width // downscale_mul, image.height // downscale_mul),
        Image.Resampling.BICUBIC
    )

# Augmentation applied to every clean image before it is corrupted:
# RGB conversion -> random integer downscale -> random rotation in [-90, 90] degrees -> 256x256 crop.
rotator = RandomRotation(90)
crop = RandomCrop((256, 256))


def _to_rgb(img):
    """Convert a PIL image to 3-channel RGB."""
    return img.convert('RGB')


transform = Compose([
    _to_rgb,
    lambda img: random_downsample(img),
    lambda img: rotator(img),
    lambda img: crop(img),
])

In [None]:
def sharp_loss(pred: torch.Tensor, target: torch.Tensor, alpha: float = 0.7):
    """Blend of pixel-wise L1 and L1 between Laplacian edge maps.

    Args:
        pred: predicted images, 4-D ``(B, C, H, W)`` (required by ``conv2d``).
        target: reference images with the same shape.
        alpha: weight of the pixel L1 term; ``1 - alpha`` weights the edge term.

    Returns:
        Scalar ``alpha * L1(pred, target) + (1 - alpha) * L1(lap(pred), lap(target))``.

    The 3x3 Laplacian is applied to every channel and the responses are summed into a
    single edge map (one output channel).
    """
    def _laplacian(x: torch.Tensor) -> torch.Tensor:
        # Build the kernel for x's own channel count, dtype and device so 1/3/4-channel and
        # fp16/fp32/fp64 inputs all work (pred and target may differ in dtype).
        kernel = torch.tensor(
            [[0, 1, 0],
             [1, -4, 1],
             [0, 1, 0]],
            dtype=x.dtype,
            device=x.device,
        ).view(1, 1, 3, 3).repeat(1, x.shape[1], 1, 1)
        return F.conv2d(x, kernel, padding=1)

    content_loss = F.l1_loss(pred, target)
    edge_loss = F.l1_loss(_laplacian(pred), _laplacian(target))
    return alpha * content_loss + (1 - alpha) * edge_loss

def gradient_loss(pred: torch.Tensor, target: torch.Tensor, temperature: float = 1):
    """Edge-weighted L1 loss on horizontal and vertical finite differences.

    Every per-pixel gradient error is weighted by
    ``sigmoid((|grad(target)| - 0.5) / temperature)``, so strong target edges (in either
    direction) count more than flat regions.

    Args:
        pred: predicted images; dims 2 and 3 are treated as H and W (4-D ``(B, C, H, W)``).
        target: reference images with the same shape.
        temperature: softness of the edge mask; smaller values make it closer to a hard
            threshold at ``|grad| = 0.5``.

    Returns:
        Scalar: mean of the weighted x-term and the weighted y-term.
    """
    pred_grad_x = pred[:, :, :, :-1] - pred[:, :, :, 1:]
    pred_grad_y = pred[:, :, :-1, :] - pred[:, :, 1:, :]
    target_grad_x = target[:, :, :, :-1] - target[:, :, :, 1:]
    target_grad_y = target[:, :, :-1, :] - target[:, :, 1:, :]

    # Use the gradient magnitude: the sign only reflects the arbitrary difference direction.
    mask_x = torch.sigmoid((target_grad_x.abs() - 0.5) / temperature)
    mask_y = torch.sigmoid((target_grad_y.abs() - 0.5) / temperature)

    # reduction='none' keeps per-pixel errors so the mask actually weights each pixel
    # (a mean-reduced scalar would make the mask a mere global scale factor).
    loss_x = mask_x * F.l1_loss(pred_grad_x, target_grad_x, reduction='none')
    loss_y = mask_y * F.l1_loss(pred_grad_y, target_grad_y, reduction='none')
    return (loss_x.mean() + loss_y.mean()) / 2

def hist_loss(pred: torch.Tensor, target: torch.Tensor, bandwidth: float = 0.1, bins: int = 256) -> torch.Tensor:
    """L1 distance between soft per-channel intensity histograms of ``pred`` and ``target``.

    Each value contributes to every bin centre in ``linspace(0, 1, bins)`` with a triangular
    weight ``clamp(1 - |x - centre| / bandwidth, 0, 1)``, which keeps the histogram differentiable.
    Counts are divided by the number of spatial positions.

    Args:
        pred: predicted images ``(B, C, ...)``; values are expected roughly in [0, 1].
        target: reference images with the same shape.
        bandwidth: half-width of the triangular kernel.
        bins: number of histogram bins.

    Returns:
        Scalar mean absolute difference between the two ``(B, C, bins)`` histograms.

    Memory is ``O(B * C * H * W * bins)`` per input, since all pixel/bin distances are materialised.
    """
    def differentiable_histogram(x: torch.Tensor, bins: int, bandwidth: float) -> torch.Tensor:
        batch_size, channels = x.shape[0], x.shape[1]
        bin_centers = torch.linspace(0, 1, bins, device=x.device)  # [bins]
        # reshape (not view) so non-contiguous inputs (e.g. transposed/sliced tensors) work.
        x_flat = x.reshape(batch_size, channels, -1, 1)  # [B, C, N, 1]
        distances = torch.abs(x_flat - bin_centers)     # [B, C, N, bins]
        weights = torch.clamp(1 - distances / bandwidth, 0, 1)
        hist = torch.sum(weights, dim=2)  # [B, C, bins]
        return hist / x_flat.shape[2]  # normalise by the number of positions

    hist_pred = differentiable_histogram(pred, bins, bandwidth)
    hist_target = differentiable_histogram(target, bins, bandwidth)
    return F.l1_loss(hist_pred, hist_target)

In [None]:
def motion_blur(image: torch.Tensor, length: int = 15, angle: float = 0.0) -> torch.Tensor:
    """Apply a linear motion blur to every channel of ``image``.

    Args:
        image: 4-D tensor ``(B, C, H, W)``; channels are filtered independently (``groups=C``).
        length: blur length in pixels; rounded up to the next odd number so the kernel has a centre.
        angle: direction of motion in degrees from the horizontal (positive angles tilt toward
            increasing row index).

    Returns:
        Tensor with the same shape, dtype and device as ``image``.
    """
    import math

    def _get_motion_kernel(length: int, angle_deg: float) -> torch.Tensor:
        theta = math.radians(angle_deg)
        size = length | 1  # force an odd size so the kernel has a centre pixel
        center = size // 2
        kernel = torch.zeros((size, size))
        cos_t, sin_t = math.cos(theta), math.sin(theta)

        # Rasterise the line by stepping along its dominant axis. Stepping only along x leaves
        # steep lines (|tan| > 1) with gaps, or even just the centre pixel.
        for step in range(-center, center + 1):
            if abs(cos_t) >= abs(sin_t):
                dx = step
                dy = round(math.tan(theta) * dx)
            else:
                dy = step
                dx = round(cos_t / sin_t * dy)
            row, col = center + dy, center + dx
            if 0 <= row < size and 0 <= col < size:
                kernel[row, col] = 1.0

        return kernel / kernel.sum()

    channels = image.shape[1]
    kernel = _get_motion_kernel(length, angle)
    size = kernel.shape[-1]
    weight = kernel.to(dtype=image.dtype, device=image.device).view(1, 1, size, size).repeat(channels, 1, 1, 1)
    return F.conv2d(
        image,
        weight,
        padding=size // 2,
        groups=channels,
    )

def bokeh_blur(image: torch.Tensor, radius: int = 3) -> torch.Tensor:
    """Blur every channel of ``image`` with a uniform disc ("bokeh") kernel.

    Args:
        image: 4-D tensor ``(B, C, H, W)``; channels are filtered independently (``groups=C``).
        radius: disc radius in pixels; the kernel is ``(2 * radius + 1)`` square and includes
            every offset with ``dx**2 + dy**2 <= radius**2``.

    Returns:
        Tensor with the same shape, dtype and device as ``image``.
    """
    kernel_size = 2 * radius + 1
    offsets = torch.arange(-radius, radius + 1, dtype=torch.float32)
    y, x = torch.meshgrid(offsets, offsets, indexing='ij')
    kernel = ((x**2 + y**2) <= radius**2).to(torch.float32)
    kernel /= kernel.sum()

    channels = image.shape[1]
    # Match the image's channel count, dtype and device (depthwise conv needs C output filters).
    weight = kernel.to(dtype=image.dtype, device=image.device).view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1)
    return F.conv2d(
        image,
        weight,
        padding=kernel_size // 2,
        groups=channels,
    )

def anisotropic_gaussian_blur(
    image: torch.Tensor,
    sigma_x: float = 1.0,
    sigma_y: float = 1.0,
    angle: float = 0.0
) -> torch.Tensor:
    """Blur every channel of ``image`` with a rotated, anisotropic Gaussian kernel.

    Args:
        image: 4-D tensor ``(B, C, H, W)``; channels are filtered independently (``groups=C``).
        sigma_x: standard deviation along the horizontal (column) axis before rotation.
        sigma_y: standard deviation along the vertical (row) axis before rotation.
        angle: rotation of the kernel axes, in degrees.

    Returns:
        Tensor with the same shape, dtype and device as ``image``.
    """
    import math

    def _get_rotated_gaussian_kernel(sigma_x, sigma_y, angle_deg, kernel_size):
        theta = math.radians(angle_deg)
        half = kernel_size // 2
        # Integer offsets -half..half so the kernel is centred on its middle pixel.
        # (Writing `-kernel_size // 2` would parse as `(-kernel_size) // 2` and give an
        # off-centre, non-integer grid.)
        offsets = torch.arange(-half, half + 1, dtype=torch.float32)
        # 'xy' indexing: x varies along columns and y along rows, so sigma_x is horizontal.
        x, y = torch.meshgrid(offsets, offsets, indexing='xy')

        x_rot = x * math.cos(theta) + y * math.sin(theta)
        y_rot = -x * math.sin(theta) + y * math.cos(theta)

        kernel = torch.exp(-(x_rot**2 / (2 * sigma_x**2) + y_rot**2 / (2 * sigma_y**2)))
        return kernel / kernel.sum()

    # Cover +-3 sigma of the wider axis; "| 1" forces an odd size so the kernel has a centre.
    kernel_size = int(2 * 3 * max(sigma_x, sigma_y) + 1) | 1
    channels = image.shape[1]
    kernel = _get_rotated_gaussian_kernel(sigma_x, sigma_y, angle, kernel_size)
    weight = kernel.to(dtype=image.dtype, device=image.device).view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1)

    return F.conv2d(
        image,
        weight,
        padding=kernel_size // 2,
        groups=channels,
    )

In [None]:
gaussian_blur = GaussianBlur(kernel_size=3, sigma=2.0)


def _random_motion_blur(x):
    """Motion blur with a random length in [2, 5] and a random angle in [0, 360] degrees."""
    length = random.randint(2, 5)
    angle = random.randint(0, 360)
    return motion_blur(x, length=length, angle=angle)


def _random_anisotropic_blur(x):
    """Anisotropic Gaussian blur with random integer sigmas in [1, 3] and a random angle."""
    sigma_x = random.randint(1, 3)
    sigma_y = random.randint(1, 3)
    angle = random.randint(0, 360)
    return anisotropic_gaussian_blur(x, sigma_x=sigma_x, sigma_y=sigma_y, angle=angle)


def _random_bokeh_blur(x):
    """Disc (bokeh) blur with a random radius in [3, 6]."""
    return bokeh_blur(x, radius=random.randint(3, 6))


# Pool sampled (with replacement) by AstroDataset._corrupt. Each entry maps an image tensor to a
# blurred one and re-draws its random parameters on every call.
blur_functions = [
    lambda x: gaussian_blur(x),
    _random_motion_blur,
    _random_anisotropic_blur,
    _random_bokeh_blur,
]

In [None]:
class AstroDataset(Dataset):
    """Self-supervised restoration dataset built from a directory of clean images.

    Each item is a dict with:
      * ``'labels'``: the clean image after ``transform`` and ``processor``;
      * ``'pixel_values'``: the same tensor after random blurs and Gaussian noise, or left
        untouched with probability ``corrupt_dropout``.

    Args:
        root_dir: directory scanned (non-recursively) for image files.
        processor: callable used as ``processor(image, return_tensors='pt')['pixel_values']``.
        transform: callable applied to the RGB PIL image before ``processor``.
        corrupt_dropout: probability of returning an uncorrupted input.
        noise_base_amount: upper bound on the Gaussian noise std (the actual std per item is
            ``random.random() * noise_base_amount``).
        **kwargs: forwarded to ``Dataset.__init__``.
    """

    # Accepted file suffixes, matched case-insensitively.
    IMAGE_SUFFIXES = ('jpg', 'jpeg', 'png', 'tif', 'tiff')
    # Images whose shorter side is below this are skipped (matches the notebook's 256x256 crop).
    MIN_SIDE = 256

    def __init__(
        self,
        root_dir,
        processor,
        transform,
        corrupt_dropout=0.05,
        noise_base_amount=0.2,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.root_dir = root_dir
        # sorted() gives a deterministic index -> file mapping (os.listdir order is arbitrary).
        self.image_paths = [
            os.path.join(root_dir, name)
            for name in sorted(os.listdir(root_dir))
            if name.lower().endswith(self.IMAGE_SUFFIXES)
            and self._min_side(os.path.join(root_dir, name)) >= self.MIN_SIDE
        ]
        self.corrupt_dropout = corrupt_dropout
        self.noise_base_amount = noise_base_amount
        self.processor = processor
        self.transform = transform

    @staticmethod
    def _min_side(path):
        """Return the shorter side of the image at ``path``, closing the file afterwards."""
        # Image.open is lazy (reads the header only); the context manager releases the
        # file handle instead of leaving one open per scanned file.
        with Image.open(path) as img:
            return min(img.size)

    def __len__(self):
        return len(self.image_paths)

    def _corrupt(self, image: torch.Tensor):
        """Degrade ``image`` with random blurs plus Gaussian noise.

        With probability ``corrupt_dropout`` the input is returned unchanged. Otherwise
        2..len(blur_functions) operators are drawn with replacement from the global
        ``blur_functions`` list and applied in sequence, then noise with std
        ``random.random() * noise_base_amount`` is added (no clamping).
        """
        if random.random() < self.corrupt_dropout:
            return image
        blur_composition = random.choices(blur_functions, k=random.randint(2, len(blur_functions)))
        for blur in blur_composition:
            image = blur(image)
        # randn_like keeps the noise on the image's dtype/device.
        return image + torch.randn_like(image) * random.random() * self.noise_base_amount

    def __getitem__(self, idx):
        """Return ``{'pixel_values': corrupted, 'labels': clean}`` for image ``idx``.

        The processor output is assumed to carry a leading batch dim of 1, which is dropped.
        """
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        image = self.transform(image)
        image_tensor = self.processor(image, return_tensors='pt')['pixel_values']
        corrupted_tensor = self._corrupt(image_tensor.clone().detach())
        return {
            'pixel_values': corrupted_tensor[0],
            'labels': image_tensor[0],
        }


class AstroSwin2SR(Swin2SRForImageSuperResolution):
    """Swin2SR backbone with its upsampling head replaced by one same-resolution 3x3 conv.

    The conv maps the backbone's ``last_hidden_state`` (``config.embed_dim`` channels) to image
    channels, so the output keeps the backbone's spatial size (restoration, not upscaling).
    """

    def __init__(self, config):
        super().__init__(config)
        # The parent only defines `upsample` for some upsampler types; drop it when present.
        if hasattr(self, 'upsample'):
            del self.upsample
        # Derive channel counts from the config instead of magic numbers (60 / 3 for the
        # lightweight RGB checkpoints this notebook fine-tunes).
        out_channels = getattr(config, 'num_channels_out', None) or config.num_channels
        self.resample = torch.nn.Conv2d(
            config.embed_dim, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )

    def forward(self, pixel_values: torch.Tensor, labels: torch.Tensor = None):
        """Run the backbone and the replacement head.

        Args:
            pixel_values: input image batch.
            labels: accepted so ``Trainer`` recognises a label column; unused here (the loss
                is computed by the trainer).

        Returns:
            ``{'outputs': tensor}`` with the predicted images.

        NOTE(review): the parent head presumably also undoes input normalisation and crops
        window padding; this override does neither — confirm that is intended.
        """
        output = self.swin2sr(pixel_values=pixel_values)
        output = self.resample(output.last_hidden_state)
        return {'outputs': output}

In [None]:
class TrainerWithCustomLoss(Trainer):
    """``Trainer`` that optimises a weighted image-restoration loss instead of a model-provided loss.

    The model must return a mapping with an ``'outputs'`` tensor, and every batch must carry
    ``'labels'`` of the same shape.
    Loss = 2.0 * sharp_loss + 1.0 * gradient_loss(temperature=0.5) + 1.0 * hist_loss(bandwidth=0.05).
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the weighted restoration loss for one batch.

        Args:
            model: the (possibly wrapped) model; called as ``model(**inputs)``.
            inputs: batch dict; ``'labels'`` is popped and the rest is forwarded to the model.
            return_outputs: if True, also return the predicted tensor.
            **kwargs: extra arguments newer ``Trainer`` versions pass; ignored.

        Returns:
            ``loss``, or ``(loss, outputs)`` where ``outputs`` is the predicted image tensor.

        Raises:
            ValueError: if the batch has no ``'labels'``. The model returns no loss of its own,
                so there is nothing meaningful to fall back to.
        """
        labels = inputs.pop('labels', None)
        if labels is None:
            raise ValueError(
                "TrainerWithCustomLoss needs 'labels' in every batch: the model returns only "
                "'outputs' and no loss of its own."
            )
        outputs = model(**inputs)['outputs']
        # The base Trainer's `past_index` / `self._past` capture is intentionally not reproduced:
        # `outputs` is a single image tensor, not a tuple carrying model memory.
        h_loss = hist_loss(outputs, labels, bandwidth=5e-2)
        g_loss = gradient_loss(outputs, labels, temperature=0.5)
        d_loss = sharp_loss(outputs, labels)
        loss = d_loss * 2.0 + g_loss * 1.0 + h_loss * 1.0
        return (loss, outputs) if return_outputs else loss

In [None]:
# Model weights and the matching image processor both come from the 'astroswin_v7' checkpoint
# (a local directory or Hub id resolved by from_pretrained).
base_checkpoint = 'astroswin_v7'
processor = Swin2SRImageProcessor.from_pretrained(base_checkpoint)
aswin = AstroSwin2SR.from_pretrained(base_checkpoint).to(device)

In [None]:
# Collect unreachable Python objects first, then release PyTorch's cached-but-unused CUDA
# memory, so training starts with as much free GPU memory as possible.
collect()
torch.cuda.empty_cache()

In [None]:
# Both splits share the same augmentation pipeline and corruption settings.
# NOTE(review): the eval split is also randomly transformed and corrupted on every access,
# so eval loss varies between evaluations — consider a fixed-seed eval set.
dataset_kwargs = dict(corrupt_dropout=0.1, noise_base_amount=0.1)
train_dataset = AstroDataset('hq_images/train', processor, transform, **dataset_kwargs)
eval_dataset = AstroDataset('hq_images/test', processor, transform, **dataset_kwargs)

In [None]:
# Effective batch per device = 1 sample x 4 accumulation steps.
args = TrainingArguments(
    output_dir='astro_model',
    report_to='none',
    # optimisation
    num_train_epochs=4,
    learning_rate=1e-5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=False,
    fp16=True,
    # evaluation / logging cadence
    per_device_eval_batch_size=1,
    eval_strategy='steps',
    eval_steps=100,
    logging_strategy='steps',
    logging_steps=15,
)

trainer = TrainerWithCustomLoss(
    args=args,
    model=aswin,
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
)

In [None]:
# Run fine-tuning with the custom loss; the returned TrainOutput is shown as the cell output.
trainer.train()

In [None]:
# Save the fine-tuned model and its image processor into the same directory so both can be
# reloaded with from_pretrained.
# NOTE(review): training started from 'astroswin_v7' but the result is saved as 'astroswin_v1'
# — confirm the intended version name (risk of confusing or overwriting an older v1).
aswin.save_pretrained('astroswin_v1')
processor.save_pretrained('astroswin_v1')

In [None]:
# Bundle the saved model + processor directory into a single archive.
!zip -r astroswin_v1.zip astroswin_v1/