In [None]:
import torch

from torch.cuda import amp
from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor

In [2]:
class AstroSwin2SR(Swin2SRForImageSuperResolution):
    """Swin2SR backbone whose super-resolution upsampler is replaced by one conv head.

    The ``upsample`` submodule built by the parent class is removed. In its place,
    ``resample`` is a 3x3 convolution with stride 1 and padding 1. It maps the
    backbone's ``last_hidden_state`` feature map to 3 output channels while keeping
    that feature map's spatial size, so no upscaling happens in this head.
    """

    def __init__(self, config):
        super().__init__(config)
        # Drop the parent's upsampler; `resample` below replaces it. The guard avoids an
        # AttributeError if the parent did not create `upsample` for this config
        # (NOTE(review): depends on `config.upsampler` — confirm for new checkpoints).
        if hasattr(self, "upsample"):
            del self.upsample
        # The input channels must equal the channel count of the backbone's
        # `last_hidden_state`, which is `config.embed_dim`. The previous hard-coded 60
        # only worked for configs with embed_dim == 60.
        self.resample = torch.nn.Conv2d(
            config.embed_dim, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )

    def forward(self, pixel_values: torch.Tensor, labels: torch.Tensor = None):
        """Run the Swin2SR backbone followed by the replacement conv head.

        Args:
            pixel_values: batch passed unchanged to ``self.swin2sr``
                (presumably ``(batch, channels, height, width)`` — TODO confirm).
            labels: accepted for call-signature compatibility but ignored; no loss
                is computed here.

        Returns:
            dict with a single key ``'outputs'`` holding the output of ``resample``.
        """
        features = self.swin2sr(pixel_values=pixel_values).last_hidden_state
        # NOTE(review): the parent's forward is not called, so any post-processing it
        # performs on its reconstruction is skipped here. Confirm this matches how the
        # checkpoint was trained.
        return {'outputs': self.resample(features)}

In [None]:
# Single source of truth for the checkpoint directory, so the model and the image
# processor are always loaded from the same place.
MODEL_DIR = 'models/astroswin_v1'

# Hard-coded to CUDA (no CPU fallback); `.to(device)` fails if no GPU is available.
device = torch.device('cuda')
aswin = AstroSwin2SR.from_pretrained(MODEL_DIR).eval().to(device)
processor = Swin2SRImageProcessor.from_pretrained(MODEL_DIR)

In [None]:
ONNX_PATH = "aswin_v1_0.onnx"

# Dummy input used only to trace the graph: batch 1, 3 channels, 256x256, fp16 on `device`.
# No dynamic_axes are declared, so the exported graph is traced at this fixed shape.
mock_tensor = torch.rand((1, 3, 256, 256), dtype=torch.float16, device=device)

# `torch.cuda.amp.autocast(...)` is deprecated. `torch.autocast` with an explicit device
# type and dtype is the equivalent (float16 was the old API's default dtype).
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.float16):
    torch.onnx.export(
        aswin,                         # model to export
        (mock_tensor,),                # positional inputs of the model
        ONNX_PATH,                     # filename of the ONNX model
        export_params=True,
        input_names=["pixel_values"],  # name the graph input for ONNX consumers
        output_names=["outputs"],      # name the graph output too (else a generated name is used)
        opset_version=17,
        do_constant_folding=True,
        dynamo=False,                  # select the TorchScript-based exporter
    )