In [None]:
import math
import os
import random

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms.v2 import GaussianBlur, Compose
from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor, Trainer, TrainingArguments

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook can still be
# executed (slowly) on a machine without CUDA instead of failing on `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def _to_rgb(img):
    """Return ``img`` converted to the 'RGB' mode via its ``convert`` method."""
    return img.convert('RGB')


# Pre-processing pipeline handed to the datasets; currently a single step.
transform = Compose([_to_rgb])

In [None]:
def adaptive_bg_homogenity_loss(
    pred: torch.Tensor, target: torch.Tensor, patch_size=32, eps=1e-3
) -> torch.Tensor:
    """Squared mismatch between the patch-level variance of ``pred`` and ``target``.

    Both tensors are average-pooled over non-overlapping
    ``patch_size`` x ``patch_size`` windows, then the variance of the pooled
    map is taken over the two spatial axes, giving one value per
    (batch, channel) entry.  Entries whose target variance is below ``eps``
    are masked out and contribute zero to the result.

    Args:
        pred: Predicted images, indexed as [B, C, H, W].
        target: Reference images with the same layout as ``pred``.
        patch_size: Side length (and stride) of the pooling window.
        eps: Target-variance threshold below which an entry is ignored.

    Returns:
        Scalar tensor: mean over all (batch, channel) entries.
    """
    # NOTE(review): the mask *excludes* low-variance (homogeneous) targets,
    # which looks opposite to what the function name suggests -- confirm intent.
    pooled_pred = torch.nn.functional.avg_pool2d(pred, patch_size, stride=patch_size)
    pooled_target = torch.nn.functional.avg_pool2d(target, patch_size, stride=patch_size)

    var_pred = pooled_pred.var(dim=(2, 3))      # [B, C]
    var_target = pooled_target.var(dim=(2, 3))  # [B, C]

    is_flat = (var_target < eps).float()
    squared_gap = (var_pred - var_target).pow(2)
    return ((1 - is_flat) * squared_gap).mean()

def bright_mask_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Brightness-weighted reconstruction loss.

    Combines two per-pixel terms, each averaged over every element:

    * a squared error weighted by ``sigmoid(1 - target)``, and
    * an absolute error weighted by ``sigmoid(target)``.

    Args:
        pred: Predicted tensor.
        target: Reference tensor, broadcastable against ``pred``.

    Returns:
        Scalar tensor: sum of the two averaged terms.
    """
    residual = pred - target
    dark_weight = torch.sigmoid(1 - target)
    bright_weight = torch.sigmoid(target)

    dark_term = (dark_weight * residual.pow(2)).mean()
    bright_term = (bright_weight * residual.abs()).mean()
    return dark_term + bright_term

def gradient_loss(pred: torch.Tensor, target: torch.Tensor, threshold=0.5):
    """L1 difference of image gradients, restricted to smooth target regions.

    Forward differences are taken along the width and height axes of both
    tensors.  A position contributes only where the *magnitude* of the target
    gradient is below ``threshold``, so strong edges are ignored regardless of
    their direction (rising or falling).

    Args:
        pred: Predicted images, indexed as [B, C, H, W].
        target: Reference images with the same shape as ``pred``.
        threshold: Upper bound on ``|target gradient|`` for a position to count.

    Returns:
        Scalar tensor: mean masked horizontal difference plus mean masked
        vertical difference.
    """
    pred_grad_x = pred[:, :, :, :-1] - pred[:, :, :, 1:]  # [B,C,H,W-1]
    target_grad_x = target[:, :, :, :-1] - target[:, :, :, 1:]
    pred_grad_y = pred[:, :, :-1, :] - pred[:, :, 1:, :]  # [B,C,H-1,W]
    target_grad_y = target[:, :, :-1, :] - target[:, :, 1:, :]

    # Threshold the absolute value: comparing the signed gradient would keep
    # every strongly negative edge while dropping the positive ones.
    mask_x = (target_grad_x.abs() < threshold).float()
    mask_y = (target_grad_y.abs() < threshold).float()

    loss_x = mask_x * torch.abs(pred_grad_x - target_grad_x)
    loss_y = mask_y * torch.abs(pred_grad_y - target_grad_y)
    return loss_x.mean() + loss_y.mean()

def differentiable_histogram(
    x: torch.Tensor, bins: int = 256, bandwidth: float = 0.01
) -> torch.Tensor:
    """Soft (differentiable) per-channel histogram over the range [0, 1].

    Every value votes for each bin with a triangular weight
    ``clamp(1 - |value - centre| / bandwidth, 0, 1)``; the votes are summed
    over all pixels and divided by the pixel count ``H * W``.  Values further
    than ``bandwidth`` from every bin centre contribute nothing.

    Note that an intermediate tensor of shape [B, C, H*W, bins] is
    materialised, so memory grows with both image size and bin count.

    Args:
        x: Input indexed as [B, C, H, W]; need not be contiguous.
        bins: Number of bin centres, evenly spaced on [0, 1].
        bandwidth: Half-width of the triangular voting kernel.

    Returns:
        Tensor of shape [B, C, bins].
    """
    batch_size, channels = x.shape[0], x.shape[1]
    bin_centers = torch.linspace(0, 1, bins, device=x.device)
    # reshape (not view) so sliced / permuted inputs are accepted as well.
    x_flat = x.reshape(batch_size, channels, -1, 1)  # [B, C, H*W, 1]
    distances = torch.abs(x_flat - bin_centers)      # [B, C, H*W, bins]
    weights = torch.clamp(1 - distances / bandwidth, 0, 1)
    hist = torch.sum(weights, dim=2)  # [B, C, bins]
    return hist / (x.shape[2] * x.shape[3])

def hist_loss(pred: torch.Tensor, target: torch.Tensor, bandwidth: float = 0.1) -> torch.Tensor:
    """L1 distance between the soft histograms of ``pred`` and ``target``.

    Args:
        pred: Predicted images.
        target: Reference images.
        bandwidth: Kernel half-width forwarded to ``differentiable_histogram``.

    Returns:
        Scalar tensor: mean absolute difference of the two histograms.
    """
    histograms = [
        differentiable_histogram(images, bandwidth=bandwidth)
        for images in (pred, target)
    ]
    return torch.nn.functional.l1_loss(*histograms)

In [None]:
def motion_blur(image: torch.Tensor, length: int = 15, angle: float = 0.0) -> torch.Tensor:
    """Apply a linear motion blur to every channel of ``image``.

    A one-pixel-wide line through the centre of a square kernel is rasterised
    at ``angle`` degrees, normalised to sum to one, and applied as a depthwise
    convolution with "same" padding.

    Args:
        image: Input batch indexed as [B, C, H, W]; any channel count.
        length: Kernel side length in pixels; even values are bumped to the
            next odd number so the kernel has a centre pixel.
        angle: Direction of the blur line in degrees.

    Returns:
        Blurred tensor with the same shape, dtype and device as ``image``.
    """
    def _get_motion_kernel(length: int, angle_deg: float) -> torch.Tensor:
        """Build the normalised [length, length] line kernel (length made odd)."""
        angle = math.radians(angle_deg)
        length |= 1  # make odd
        kernel = torch.zeros((length, length))
        center = length // 2
        sin_a, cos_a = math.sin(angle), math.cos(angle)
        # Step along the dominant axis.  Stepping only along x would leave
        # near-vertical lines with a single lit pixel, i.e. no blur at all.
        steep = abs(sin_a) > abs(cos_a)

        for i in range(length):
            step = i - center
            if steep:
                row, col = i, center + round(step * cos_a / sin_a)
            else:
                row, col = center + round(math.tan(angle) * step), i
            if 0 <= row < length and 0 <= col < length:
                kernel[row, col] = 1.0

        # The centre pixel is always lit, so the sum is never zero.
        kernel /= kernel.sum()
        return kernel

    kernel = _get_motion_kernel(length, angle)
    kernel_size = kernel.shape[-1]
    channels = image.shape[1]
    # One copy of the kernel per input channel, on the image's device/dtype.
    weight = (
        kernel.to(device=image.device, dtype=image.dtype)
        .view(1, 1, kernel_size, kernel_size)
        .repeat(channels, 1, 1, 1)
    )
    return torch.nn.functional.conv2d(
        image,
        weight,
        padding=kernel_size // 2,
        groups=channels,
    )

def bokeh_blur(image: torch.Tensor, radius: int = 3) -> torch.Tensor:
    """Blur ``image`` with a uniform disc ("bokeh") kernel.

    The kernel is a ``(2 * radius + 1)``-sided square in which every pixel
    within ``radius`` of the centre has equal weight; it is normalised to sum
    to one and applied as a depthwise convolution with "same" padding.

    Args:
        image: Input batch indexed as [B, C, H, W]; any channel count.
        radius: Disc radius in pixels (non-negative integer).

    Returns:
        Blurred tensor with the same shape, dtype and device as ``image``.
    """
    kernel_size = 2 * radius + 1
    offsets = torch.arange(-radius, radius + 1, dtype=torch.float32)
    y, x = torch.meshgrid(offsets, offsets, indexing='ij')

    # Build and normalise in float32, then match the image's device/dtype.
    kernel = ((x**2 + y**2) <= radius**2).float()
    kernel /= kernel.sum()

    channels = image.shape[1]
    weight = (
        kernel.to(device=image.device, dtype=image.dtype)
        .view(1, 1, kernel_size, kernel_size)
        .repeat(channels, 1, 1, 1)  # one kernel per input channel
    )
    return torch.nn.functional.conv2d(
        image,
        weight,
        padding=kernel_size // 2,
        groups=channels,
    )

def anisotropic_gaussian_blur(
    image: torch.Tensor, 
    sigma_x: float = 1.0, 
    sigma_y: float = 1.0, 
    angle: float = 0.0
) -> torch.Tensor:
    """Blur ``image`` with a rotated, axis-wise scaled Gaussian kernel.

    The kernel side is ``int(6 * max(sigma_x, sigma_y) + 1)`` forced odd, the
    sampling grid is centred on zero with unit spacing, and the kernel is
    applied as a depthwise convolution with "same" padding.

    Args:
        image: Input batch indexed as [B, C, H, W]; any channel count.
        sigma_x: Standard deviation along the kernel's first (row) axis before
            rotation -- the grid is built with ``indexing='ij'``.
        sigma_y: Standard deviation along the second (column) axis before
            rotation.
        angle: Rotation of the Gaussian's axes in degrees.

    Returns:
        Blurred tensor with the same shape, dtype and device as ``image``.
    """
    def _get_rotated_gaussian_kernel(sigma_x, sigma_y, angle_deg, kernel_size):
        """Return the normalised [kernel_size, kernel_size] Gaussian kernel."""
        angle = math.radians(angle_deg)
        # Symmetric integer offsets.  `-kernel_size//2` would parse as
        # `(-kernel_size)//2` and shift the Gaussian away from the centre.
        half = kernel_size // 2
        offsets = torch.arange(-half, half + 1, dtype=torch.float32)
        x, y = torch.meshgrid(offsets, offsets, indexing='ij')

        cos_a, sin_a = math.cos(angle), math.sin(angle)
        x_rot = x * cos_a + y * sin_a
        y_rot = -x * sin_a + y * cos_a

        kernel = torch.exp(-(x_rot**2 / (2 * sigma_x**2) + y_rot**2 / (2 * sigma_y**2)))
        return kernel / kernel.sum()

    kernel_size = int(2 * 3 * max(sigma_x, sigma_y) + 1) | 1  # make odd
    kernel = _get_rotated_gaussian_kernel(sigma_x, sigma_y, angle, kernel_size)

    channels = image.shape[1]
    weight = (
        kernel.to(device=image.device, dtype=image.dtype)
        .view(1, 1, kernel_size, kernel_size)
        .repeat(channels, 1, 1, 1)  # [C, 1, k, k]
    )
    return torch.nn.functional.conv2d(
        image,
        weight,
        padding=kernel_size // 2,
        groups=channels,
    )

In [None]:
gaussian_blur = GaussianBlur(kernel_size=13, sigma=3.0)


def _fixed_gaussian_blur(x):
    """Gaussian blur with the fixed module-level ``gaussian_blur`` transform."""
    return gaussian_blur(x)


def _random_motion_blur(x):
    """Motion blur with a random length in [4, 8] and angle in [0, 360]."""
    return motion_blur(x, length=random.randint(4, 8), angle=random.randint(0, 360))


def _random_anisotropic_blur(x):
    """Anisotropic Gaussian blur with random sigmas in [1, 5] and random angle."""
    return anisotropic_gaussian_blur(
        x,
        sigma_x=random.randint(1, 5),
        sigma_y=random.randint(1, 5),
        angle=random.randint(0, 360),
    )


def _random_bokeh_blur(x):
    """Disc blur with a random radius in [3, 12]."""
    return bokeh_blur(x, radius=random.randint(3, 12))


# Pool of degradations; one entry is picked at random for every sample.
blur_functions = [
    _fixed_gaussian_blur,
    _random_motion_blur,
    _random_anisotropic_blur,
    _random_bokeh_blur,
]

In [None]:
class AstroDataset(Dataset):
    """Pairs each image in a directory with a randomly blurred copy.

    Every item is a dict with ``'pixel_values'`` (the blurred image, used as
    model input) and ``'labels'`` (the processed, un-blurred image).
    """

    def __init__(self, root_dir, processor, transform, **kwargs):
        """Index the regular files found directly inside ``root_dir``.

        Args:
            root_dir: Directory containing the source images.
            processor: Callable invoked as ``processor(image, return_tensors='pt')``
                whose result exposes a ``'pixel_values'`` entry.
            transform: Callable applied to the RGB image before ``processor``.
            **kwargs: Forwarded to the parent ``Dataset`` constructor.
        """
        super().__init__(**kwargs)
        self.root_dir = root_dir
        # Sorted so indices are stable across runs and file systems;
        # sub-directories and hidden files (e.g. notebook checkpoints,
        # .DS_Store) are skipped because they cannot be opened as images.
        candidates = (
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if not name.startswith('.')
        )
        self.image_paths = sorted(path for path in candidates if os.path.isfile(path))
        self.processor = processor
        self.transform = transform

    def __len__(self):
        """Number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return it together with a blurred version."""
        img_path = self.image_paths[idx]
        # `convert` returns a fully loaded copy, so the file handle can be
        # released as soon as the `with` block ends.
        with Image.open(img_path) as source:
            image = self.transform(source.convert('RGB'))
        # presumably [1, C, H, W]: the blur functions read channels from dim 1
        image_tensor = self.processor(image, return_tensors='pt')['pixel_values']
        blur_func = random.choice(blur_functions)
        blurry_image = blur_func(image_tensor)

        # Index 0 drops the leading axis of both tensors.
        return {
            'pixel_values': blurry_image[0],
            'labels': image_tensor[0]
        }


class AstroSwin2SR(Swin2SRForImageSuperResolution):
    """Swin2SR backbone with the upsampling head replaced by a same-resolution head.

    The pretrained ``upsample`` module is removed and a single 3x3 convolution
    (``resample``) maps the backbone features back to image channels, so the
    output has the same spatial size as the input.
    """

    def __init__(self, config):
        """Build the backbone from ``config`` and attach the ``resample`` head."""
        super().__init__(config)
        del self.upsample
        # Channel counts come from the config instead of being hard-coded, so
        # checkpoints with a different embedding width also work
        # (the lightweight checkpoint used in this notebook has 60 -> 3).
        self.resample = torch.nn.Conv2d(
            config.embed_dim,
            config.num_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )

    def forward(self, pixel_values: torch.Tensor, labels: torch.Tensor = None):
        """Run the backbone and the ``resample`` head.

        Args:
            pixel_values: Input batch indexed as [B, C, H, W].
            labels: Accepted for interface compatibility; not used here.

        Returns:
            ``{'outputs': tensor}`` with the restored images, cropped to the
            input's height and width.
        """
        height, width = pixel_values.shape[-2:]
        hidden = self.swin2sr(pixel_values=pixel_values).last_hidden_state
        output = self.resample(hidden)
        # The Hugging Face Swin2SR backbone pads its input up to a multiple of
        # the attention window size; crop so the output always lines up with
        # the input.  This is a no-op when no padding was added.
        output = output[..., :height, :width].contiguous()
        return {'outputs': output}

In [None]:
class TrainerWithCustomLoss(Trainer):
    """``Trainer`` that optimises a weighted sum of the image-restoration losses."""

    # Relative weight of each loss term in the total.
    BRIGHT_WEIGHT = 1.5
    GRADIENT_WEIGHT = 0.75
    HIST_WEIGHT = 0.75
    BG_HOMOGENEITY_WEIGHT = 1.5

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the combined loss for one batch.

        Args:
            model: Model returning a dict with an ``'outputs'`` tensor.
            inputs: Batch dict; ``'labels'`` is removed from it and used as
                the target, the remaining entries are passed to ``model``.
            return_outputs: When true, also return the model's output dict.
            **kwargs: Extra arguments supplied by ``Trainer``; ignored.

        Returns:
            ``loss`` or ``(loss, model_outputs)`` depending on ``return_outputs``.

        Raises:
            ValueError: If the batch carries no ``'labels'``; the model does
                not compute a loss of its own, so there is nothing to optimise.
        """
        labels = inputs.pop('labels', None)
        if labels is None:
            raise ValueError(
                "TrainerWithCustomLoss needs 'labels' in every batch to compute the loss."
            )

        model_outputs = model(**inputs)
        prediction = model_outputs['outputs']

        h_loss = hist_loss(prediction, labels)
        g_loss = gradient_loss(prediction, labels)
        b_loss = bright_mask_loss(prediction, labels)
        abgh_loss = adaptive_bg_homogenity_loss(prediction, labels)
        loss = (
            b_loss * self.BRIGHT_WEIGHT
            + g_loss * self.GRADIENT_WEIGHT
            + h_loss * self.HIST_WEIGHT
            + abgh_loss * self.BG_HOMOGENEITY_WEIGHT
        )

        # Hand back the model's dict rather than the bare tensor: Trainer
        # treats non-dict outputs as a (loss, ...) tuple and would slice off
        # the first batch element when collecting predictions.
        return (loss, model_outputs) if return_outputs else loss

In [None]:
# Pretrained checkpoint shared by the model and its image processor.
PRETRAINED_CHECKPOINT = 'caidas/swin2SR-lightweight-x2-64'

processor = Swin2SRImageProcessor.from_pretrained(PRETRAINED_CHECKPOINT)
model = AstroSwin2SR.from_pretrained(PRETRAINED_CHECKPOINT)
model = model.to(device)

In [None]:
# Training and evaluation splits share the same processor and transform.
train_dataset = AstroDataset(
    root_dir='source_images/train',
    processor=processor,
    transform=transform,
)
eval_dataset = AstroDataset(
    root_dir='source_images/test',
    processor=processor,
    transform=transform,
)

In [None]:
args = TrainingArguments(
    output_dir='astro_model',
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    learning_rate=4e-5,
    gradient_checkpointing=True,
    gradient_accumulation_steps=2,
    num_train_epochs=6,
    logging_steps=150,
    logging_strategy='steps',
    eval_steps=150,
    eval_strategy='steps',
    # fp16 mixed precision is only supported on an accelerator; enable it
    # when CUDA is present instead of failing on CPU-only machines.
    fp16=torch.cuda.is_available(),
)

trainer = TrainerWithCustomLoss(model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset)

In [None]:
# Start training with the trainer, arguments and datasets defined in the previous cells.
trainer.train()

In [None]:
# Save the model and its image processor into the same 'astroswin' directory.
model.save_pretrained('astroswin')
processor.save_pretrained('astroswin')