# In [None]:

from torchvision import models
import torch
import torch.nn as nn

# :

# In [None]:


def return_efficientnet(size='small', dev='cpu', in_channels=1, out_channels=3, use_pretrained=True):
    """Build a torchvision EfficientNetV2 with a new stem conv and output layer.

    Args:
        size: 'small', 'medium' or 'large'; selects efficientnet_v2_s,
            efficientnet_v2_m or efficientnet_v2_l respectively.
        dev: Device the returned model is moved to (anything ``Module.to``
            accepts).
        in_channels: Number of input channels of the replacement stem conv.
        out_channels: Number of outputs of the replacement final linear layer.
        use_pretrained: If True the builder is called with the IMAGENET1K_V1
            weights; otherwise it is called with no arguments.

    Returns:
        The modified model, moved to ``dev``.

    Raises:
        ValueError: If ``size`` is not one of the supported names.
    """
    if size == 'small':
        builder = models.efficientnet_v2_s
        weights = models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
    elif size == 'medium':
        builder = models.efficientnet_v2_m
        weights = models.EfficientNet_V2_M_Weights.IMAGENET1K_V1
    elif size == 'large':
        builder = models.efficientnet_v2_l
        weights = models.EfficientNet_V2_L_Weights.IMAGENET1K_V1
    else:
        # An explicit raise (rather than assert) still fires under `python -O`.
        raise ValueError(
            f"size must be 'small', 'medium' or 'large', got {size!r}")

    net = builder(weights=weights) if use_pretrained else builder()

    # Read the widths from the layers being replaced so the new stem and head
    # always agree with the rest of the selected variant.
    old_stem = net.features[0][0]
    old_head = net.classifier[-1]

    # The replacement layers are freshly initialised; any pretrained weights
    # of the original stem conv and classifier are not carried over.
    net.features[0][0] = nn.Conv2d(
        in_channels, old_stem.out_channels,
        kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    net.classifier[-1] = nn.Linear(
        old_head.in_features, out_channels, bias=True)

    # One move at the end covers both the backbone and the new layers.
    return net.to(dev)


def return_swin(size='small', dev='cpu', in_channels=1, out_channels=3, use_pretrained=True):
    """Build a torchvision Swin Transformer V2 with a new patch-embedding conv and head.

    Args:
        size: 'small', 'medium' or 'large'; selects swin_v2_t, swin_v2_s or
            swin_v2_b respectively.
        dev: Device the returned model is moved to (anything ``Module.to``
            accepts).
        in_channels: Number of input channels of the replacement first conv.
        out_channels: Number of outputs of the replacement ``head`` layer.
        use_pretrained: If True the builder is called with the IMAGENET1K_V1
            weights; otherwise it is called with no arguments.

    Returns:
        The modified model, moved to ``dev``.

    Raises:
        ValueError: If ``size`` is not one of the supported names.
    """
    if size == 'small':
        builder = models.swin_v2_t
        weights = models.Swin_V2_T_Weights.IMAGENET1K_V1
    elif size == 'medium':
        builder = models.swin_v2_s
        weights = models.Swin_V2_S_Weights.IMAGENET1K_V1
    elif size == 'large':
        builder = models.swin_v2_b
        weights = models.Swin_V2_B_Weights.IMAGENET1K_V1
    else:
        # An explicit raise (rather than assert) still fires under `python -O`.
        raise ValueError(
            f"size must be 'small', 'medium' or 'large', got {size!r}")

    net = builder(weights=weights) if use_pretrained else builder()

    # Read the widths from the layers being replaced so the new conv and head
    # always agree with the rest of the selected variant.
    old_embed = net.features[0][0]
    old_head = net.head

    # The replacement layers are freshly initialised; any pretrained weights
    # of the original first conv and head are not carried over.
    net.features[0][0] = nn.Conv2d(
        in_channels, old_embed.out_channels,
        kernel_size=(4, 4), stride=(4, 4))
    net.head = nn.Linear(old_head.in_features, out_channels, bias=True)

    # One move at the end covers both the backbone and the new layers.
    return net.to(dev)


def return_resnet(size='small', dev='cpu', in_channels=1, out_channels=3, use_pretrained=True):
    """Build a torchvision ResNet with a new first conv and fully connected layer.

    Args:
        size: 'small', 'medium' or 'large'; selects resnet50, resnet101 or
            resnet152 respectively.
        dev: Device the returned model is moved to (anything ``Module.to``
            accepts).
        in_channels: Number of input channels of the replacement ``conv1``.
        out_channels: Number of outputs of the replacement ``fc`` layer.
        use_pretrained: If True the builder is called with the IMAGENET1K_V1
            weights; otherwise it is called with no arguments.

    Returns:
        The modified model, moved to ``dev``.

    Raises:
        ValueError: If ``size`` is not one of the supported names.
    """
    if size == 'small':
        builder = models.resnet50
        weights = models.ResNet50_Weights.IMAGENET1K_V1
    elif size == 'medium':
        builder = models.resnet101
        weights = models.ResNet101_Weights.IMAGENET1K_V1
    elif size == 'large':
        builder = models.resnet152
        weights = models.ResNet152_Weights.IMAGENET1K_V1
    else:
        # An explicit raise (rather than assert) still fires under `python -O`.
        raise ValueError(
            f"size must be 'small', 'medium' or 'large', got {size!r}")

    net = builder(weights=weights) if use_pretrained else builder()

    # Read the widths from the layers being replaced so the new conv and fc
    # always agree with the rest of the selected variant.
    old_conv = net.conv1
    old_fc = net.fc

    # NOTE(review): this stem is a 4x4 / stride-4 conv with a bias term, unlike
    # the 7x7 / stride-2 / padding-3 / bias-free conv that return_resnext
    # installs. It is kept as-is so existing checkpoints still load; confirm
    # the difference is intentional.
    # The replacement layers are freshly initialised; any pretrained weights
    # of the original conv1 and fc are not carried over.
    net.conv1 = nn.Conv2d(
        in_channels, old_conv.out_channels,
        kernel_size=(4, 4), stride=(4, 4))
    net.fc = nn.Linear(old_fc.in_features, out_channels, bias=True)

    # One move at the end covers both the backbone and the new layers.
    return net.to(dev)


def return_resnext(size='small', dev='cpu', in_channels=1, out_channels=3, use_pretrained=True):
    """Build a torchvision ResNeXt with a new first conv and fully connected layer.

    Args:
        size: 'small', 'medium' or 'large'; selects resnext50_32x4d,
            resnext101_64x4d or resnext101_32x8d respectively.
        dev: Device the returned model is moved to (anything ``Module.to``
            accepts).
        in_channels: Number of input channels of the replacement ``conv1``.
        out_channels: Number of outputs of the replacement ``fc`` layer.
        use_pretrained: If True the builder is called with the IMAGENET1K_V1
            weights; otherwise it is called with no arguments.

    Returns:
        The modified model, moved to ``dev``.

    Raises:
        ValueError: If ``size`` is not one of the supported names.
    """
    if size == 'small':
        builder = models.resnext50_32x4d
        weights = models.ResNeXt50_32X4D_Weights.IMAGENET1K_V1
    elif size == 'medium':
        builder = models.resnext101_64x4d
        weights = models.ResNeXt101_64X4D_Weights.IMAGENET1K_V1
    elif size == 'large':
        builder = models.resnext101_32x8d
        weights = models.ResNeXt101_32X8D_Weights.IMAGENET1K_V1
    else:
        # An explicit raise (rather than assert) still fires under `python -O`.
        raise ValueError(
            f"size must be 'small', 'medium' or 'large', got {size!r}")

    net = builder(weights=weights) if use_pretrained else builder()

    # Read the widths from the layers being replaced so the new conv and fc
    # always agree with the rest of the selected variant.
    old_conv = net.conv1
    old_fc = net.fc

    # The replacement layers are freshly initialised; any pretrained weights
    # of the original conv1 and fc are not carried over.
    net.conv1 = nn.Conv2d(
        in_channels, old_conv.out_channels,
        kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    net.fc = nn.Linear(old_fc.in_features, out_channels, bias=True)

    # One move at the end covers both the backbone and the new layers.
    return net.to(dev)


def return_convnext(size='small', dev='cpu', in_channels=1, out_channels=3, use_pretrained=True):
    """Build a torchvision ConvNeXt with a new stem conv and output layer.

    Args:
        size: 'small', 'medium' or 'large'; selects convnext_tiny,
            convnext_small or convnext_base respectively.
        dev: Device the returned model is moved to (anything ``Module.to``
            accepts).
        in_channels: Number of input channels of the replacement stem conv.
        out_channels: Number of outputs of the replacement final linear layer.
        use_pretrained: If True the builder is called with the IMAGENET1K_V1
            weights; otherwise it is called with no arguments.

    Returns:
        The modified model, moved to ``dev``.

    Raises:
        ValueError: If ``size`` is not one of the supported names.
    """
    if size == 'small':
        builder = models.convnext_tiny
        weights = models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
    elif size == 'medium':
        builder = models.convnext_small
        weights = models.ConvNeXt_Small_Weights.IMAGENET1K_V1
    elif size == 'large':
        builder = models.convnext_base
        weights = models.ConvNeXt_Base_Weights.IMAGENET1K_V1
    else:
        # An explicit raise (rather than assert) still fires under `python -O`.
        raise ValueError(
            f"size must be 'small', 'medium' or 'large', got {size!r}")

    net = builder(weights=weights) if use_pretrained else builder()

    # Read the widths from the layers being replaced instead of hard-coding
    # them per size: the stem conv must feed the norm layer that follows it,
    # and the head must accept the backbone's feature width. Hard-coded values
    # that disagree with the chosen variant produce a model that fails on
    # forward.
    old_stem = net.features[0][0]
    old_head = net.classifier[-1]

    # The replacement layers are freshly initialised; any pretrained weights
    # of the original stem conv and classifier are not carried over.
    net.features[0][0] = nn.Conv2d(
        in_channels, old_stem.out_channels,
        kernel_size=(4, 4), stride=(4, 4))
    net.classifier[-1] = nn.Linear(
        old_head.in_features, out_channels, bias=True)

    # One move at the end covers both the backbone and the new layers.
    return net.to(dev)