In [235]:
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as tvf
import torchvision.transforms as tvtfms
import operator as op
from PIL import Image
from torch import nn
from timm import create_model

# For type hinting later on
import collections
import typing

In [2]:
# Backbone built without pretrained weights: the trained weights come from the
# checkpoint loaded below. `num_classes=0` asks timm for the backbone without
# its classifier, so it outputs features that feed the custom head.
net = create_model("vit_tiny_patch16_224", pretrained=False, num_classes=0, in_chans=3)

In [3]:
# Custom classification head. Layer order and sizes must match the checkpoint
# exactly, because each module's position in this `nn.Sequential` becomes part
# of its state_dict key.
# 192 must equal the backbone's output feature width (assumed to be ViT-Tiny's
# embedding size — TODO confirm against `net.num_features`); 37 is the number
# of output classes (presumably the 37 Oxford-IIIT Pet breeds loaded below).
head = nn.Sequential(
    nn.BatchNorm1d(192),
    nn.Dropout(0.25),
    nn.Linear(192, 512, bias=False),
    nn.ReLU(inplace=True),
    nn.BatchNorm1d(512),
    nn.Dropout(0.5),
    nn.Linear(512, 37, bias=False)
)

In [4]:
# Backbone at index 0, head at index 1: state_dict keys start with "0." / "1.".
model = nn.Sequential(net, head)

In [5]:
# map_location="cpu": tensors saved from a GPU are materialised on the CPU, so
# this also works on machines without CUDA (the model above lives on the CPU).
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
state = torch.load("models/MyModel.pth", map_location="cpu")

In [None]:
# Loading the checkpoint directly is expected to fail: its key names do not
# match this model's (compare the two key listings below). Show the error
# instead of leaving a failing cell, so the notebook survives Restart & Run All.
try:
    model.load_state_dict(state)
except RuntimeError as err:
    print(f"Direct load failed: {str(err).partition(chr(10))[0]}")

In [7]:
list(model.state_dict().keys())[:5]

In [8]:
list(state.keys())[:5]

In [9]:
def copy_weight(name, parameter, state_dict):
    """
    Copy one tensor from `state_dict` into `parameter`, in place.

    A `name` whose first character is "0" (the backbone, child 0 of the
    outer `nn.Sequential`) is rewritten from "0.<rest>" to "0.model.<rest>"
    before the lookup. If the resulting key is absent from `state_dict`, or
    the stored tensor's shape differs from `parameter`'s, a message is
    printed and `parameter` is left untouched.

    Args:
        name (`str`):
            Key of `parameter` in the target model's state_dict.
        parameter (`torch.Tensor`):
            Tensor to overwrite in place.
        state_dict (mapping):
            Source weights keyed by (rewritten) name.
    """
    if name[0] == "0":
        # Insert "model." right after the two-character "0." prefix.
        name = f"{name[:2]}model.{name[2:]}"

    if name not in state_dict:
        print(f'{name} is not in the state_dict, skipping.')
        return

    source = state_dict[name]
    if source.shape != parameter.shape:
        print(f'Shape mismatch at layer: {name}, skipping')
        return

    parameter.copy_(source)

In [10]:
def apply_weights(
    input_model: nn.Module,
    input_weights: typing.Mapping[str, torch.Tensor],
    application_function: typing.Callable[[str, torch.Tensor, typing.Mapping[str, torch.Tensor]], None],
) -> None:
    """
    Applies the weights in `input_weights` to `input_model` through a
    per-tensor function, potentially remapping names or skipping entries.

    `application_function` is called once for every entry of
    `input_model.state_dict()`, in order, and is expected to modify the given
    tensor in place (e.g. with `Tensor.copy_`). The state dict is then loaded
    back into `input_model`.

    Args:
        input_model (`nn.Module`):
            The model that weights should be applied to.
        input_weights (mapping of `str` to `torch.Tensor`):
            The source weights, e.g. a trained model's `state_dict()` as
            returned by `torch.load`; a `dict` or `OrderedDict` works.
        application_function (`typing.Callable`):
            Called as `application_function(name, tensor, input_weights)` for
            each `(name, tensor)` in `input_model.state_dict()`. Its return
            value is ignored.

    Returns:
        None. `input_model` is updated in place.
    """
    model_dict = input_model.state_dict()
    for name, parameter in model_dict.items():
        application_function(name, parameter, input_weights)
    # Write the edited tensors back into the model. state_dict() returns
    # references to the module's tensors, so in-place edits usually already
    # took effect; loading makes the update explicit.
    input_model.load_state_dict(model_dict)

In [11]:
# Copy every checkpoint tensor into `model` (backbone keys get the "model."
# segment inserted); missing keys and shape mismatches are reported and skipped.
apply_weights(input_model=model, input_weights=state, application_function=copy_weight)

In [121]:
from fastai.vision.data import PILImage
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files
import fastai.vision.augment as fastai_aug

import numpy as np

In [13]:
# Locate the pets images via fastai's `untar_data` and pick one sample file.
path = untar_data(URLs.PETS)/'images'
fname = get_image_files(path)[0]
fname

In [14]:
# Open the same file with plain PIL and with fastai's PILImage wrapper.
im_pil = Image.open(fname)
im_fastai = PILImage.create(fname)

In [15]:
# Both loaders must yield identical pixel data.
assert (np.array(im_pil) == np.array(im_fastai)).all()

In [16]:
# fastai's CropPad vs torchvision's CenterCrop at the same 460x460 target.
crop_fastai = fastai_aug.CropPad((460,460))
crop_torch = tvtfms.CenterCrop((460,460))

In [21]:
# The two centre crops agree pixel-for-pixel on this sample image.
# NOTE(review): this may not hold for every image size. The libraries can
# round the centre offset differently when (image size - 460) is odd, and they
# split padding differently for images smaller than 460. Confirm before
# relying on CenterCrop as a drop-in replacement.
assert (np.array(crop_fastai(im_fastai)) == np.array(crop_torch(im_pil))).all()

In [142]:
def crop(image:typing.Union[Image.Image, torch.Tensor], size:typing.Union[int, typing.Tuple[int,int]]) -> typing.Union[Image.Image, torch.Tensor]:
    """
    Centre-crops `image` to `size`, the cropping half of fastai's `CropPad`.

    Any dimension in which the image is smaller than `size` keeps its
    original length; padding must be performed afterwards (see `pad`).

    Image dimensions are read without relying on fastai: PIL images are
    measured with `Image.size` (plain PIL images have no `.shape`; that
    attribute is patched in by fastai), tensors with `shape[-2:]`.

    Args:
        image (`PIL.Image` or `torch.Tensor`):
            An image to perform cropping on. Tensors are expected as
            `[..., H, W]`.
        size (`int` or `tuple` of integers):
            A size to crop to, in the form (width, height). An int means a
            square crop.

    Returns:
        The cropped image, of the same kind as `image`.
    """
    if isinstance(size, int):
        size = (size, size)
    target_w, target_h = size

    if isinstance(image, Image.Image):
        width, height = image.size
    else:
        height, width = image.shape[-2:]

    # Corners of the centred window, clamped to the image bounds. Floor
    # division places the window the same way fastai does for odd differences.
    x0 = max((width - target_w) // 2, 0)
    y0 = max((height - target_h) // 2, 0)
    x1 = min(x0 + target_w, width)
    y1 = min(y0 + target_h, height)

    if isinstance(image, Image.Image):
        return image.crop((x0, y0, x1, y1))
    # torchvision's crop takes (top, left, height, width).
    return tvf.crop(image, y0, x0, y1 - y0, x1 - x0)

In [236]:
def pad(image, size:typing.Union[int, typing.Tuple[int,int]]) -> typing.Union[Image.Image, torch.Tensor]:
    """
    Zero-pads `image` up to `size`, the padding half of fastai's `CropPad`.

    Meant to run after `crop`: dimensions already at least as large as
    `size` receive no padding. When the missing extent is odd, the extra
    pixel goes on the left/top side, matching the floor-divided centre offset.

    Image dimensions are read without relying on fastai: PIL images are
    measured with `Image.size`, tensors with `shape[-2:]`.

    Args:
        image (`PIL.Image` or `torch.Tensor`):
            An image to perform padding on. Tensors are expected as
            `[..., H, W]`.
        size (`int` or `tuple` of integers):
            A size to pad to, in the form (width, height). An int means a
            square target.

    Returns:
        The padded image, of the same kind as `image`.
    """
    if isinstance(size, int):
        size = (size, size)
    target_w, target_h = size

    if isinstance(image, Image.Image):
        width, height = image.size
    else:
        height, width = image.shape[-2:]

    # Offset of the centred target window; negative when the image is smaller.
    offset_x = (width - target_w) // 2
    offset_y = (height - target_h) // 2

    pad_left = max(-offset_x, 0)
    pad_top = max(-offset_y, 0)
    # The remainder of the missing extent goes right/bottom. Width terms must
    # use the target width and height terms the target height; mixing them
    # only works for square sizes.
    pad_right = max(target_w - width + offset_x, 0)
    pad_bottom = max(target_h - height + offset_y, 0)

    # torchvision's 4-element padding order is [left, top, right, bottom].
    return tvf.pad(
        image,
        [pad_left, pad_top, pad_right, pad_bottom],
        padding_mode="constant"
    )

In [237]:
size = (460,460)
tfmd_img = pad(crop(im_pil, size),size)

In [238]:
(np.array(tfmd_img) == crop_fastai(im_fastai)).all()

In [240]:
def gpu_crop(
    batch:torch.Tensor, 
    size:typing.Union[int, typing.Tuple[int,int]]
):
    """
    Resizes every image in `batch` to `size` with a bilinear sampling grid.
    This matches fastai's `RandomResizedCropGPU` in validation mode
    (`split_idx=1`) on the square inputs used in this notebook.

    Despite the name, nothing is cut away. An identity affine grid spans the
    whole image, so each image is resampled to `size` and the aspect ratio is
    not preserved. Large downscales are first pre-shrunk with area
    interpolation, using the same heuristic as fastai, to reduce aliasing.
    NOTE(review): fastai may additionally centre-crop images whose aspect
    ratio falls outside its `ratio` range. This function does not; confirm if
    non-square inputs are expected.

    Args:
        batch (`torch.Tensor`):
            A batch of images of shape `NxCxHxW`.
        size (`int` or `tuple` of integers):
            Output size in the form (height, width). An int gives a square
            output; any two-element sequence (e.g. a list) is accepted.

    Returns:
        A `torch.Tensor` of shape `NxCx(height)x(width)`.
    """
    if isinstance(size, int):
        size = (size, size)
    # `batch.shape[:2] + size` below needs a tuple (a list would raise).
    size = tuple(size)

    # One identity affine matrix per image, keeping only the top 2x3 part
    # that `affine_grid` expects.
    affine_matrix = torch.eye(3, device=batch.device).float()
    affine_matrix = affine_matrix.unsqueeze(0)
    affine_matrix = affine_matrix.expand(batch.size(0), 3, 3)
    affine_matrix = affine_matrix.contiguous()[:,:2]

    # Sampling grid of shape (N, height, width, 2) with (x, y) in [-1, 1].
    coords = F.affine_grid(
        affine_matrix, batch.shape[:2] + size, align_corners=True
    )

    # How far the grid zooms in (1.0 for the identity transform).
    top_range, bottom_range = coords.min(), coords.max()
    zoom = 1/(bottom_range - top_range).item()*2

    # Downscale factor with a 2x margin. Kept exactly as fastai computes it,
    # including the second term, which divides the input width by
    # `coords.shape[-1]` (the size-2 (x, y) axis). Changing it would break
    # numerical parity with fastai.
    resizing_limit = min(
        batch.shape[-2]/coords.shape[-2],
        batch.shape[-1]/coords.shape[-1]
    )/2

    # Big downscale and little zoom: pre-shrink with area averaging so the
    # bilinear sampling below does not alias.
    if resizing_limit > 1 and resizing_limit > zoom:
        batch = F.interpolate(
            batch, 
            scale_factor=1/resizing_limit, 
            mode='area', 
            recompute_scale_factor=True
        )
    return F.grid_sample(batch, coords, mode='bilinear', padding_mode='reflection', align_corners=True)

In [245]:
# fastai augmentations
tt_fastai = fastai_aug.ToTensor()
i2f_fastai = fastai_aug.IntToFloatTensor()
rrc_fastai = fastai_aug.RandomResizedCropGPU((224,224))

# torchvision augmentations
tt_torch = tvtfms.ToTensor()

# apply fastai augmentations
base_im_fastai = crop_fastai(im_fastai)
result_im_fastai = rrc_fastai(
    i2f_fastai(
        tt_fastai(base_im_fastai).unsqueeze(0)
    ), split_idx=1
)

# apply torchvision augmentations
result_im_tv = gpu_crop(tt_torch(tfmd_img).unsqueeze(0), (224,224))

In [246]:
torch.allclose(result_im_fastai, result_im_tv)

In [247]:
# ImageNet per-channel mean and std, the statistics fastai's `imagenet_stats` provides
norm_torch = tvtfms.Normalize([0.485, 0.456, 0.406], [0.229,0.224,0.225])

In [249]:
# fastai augmentations; `cuda=False` keeps the normalisation stats on the CPU,
# where these tensors live
norm_fastai = fastai_aug.Normalize.from_stats(*fastai_aug.imagenet_stats, cuda=False)
# apply fastai augmentations (same pipeline as before, plus normalisation last)
base_im_fastai = crop_fastai(im_fastai)
result_im_fastai = norm_fastai(
    rrc_fastai(
        i2f_fastai(
            tt_fastai(base_im_fastai).unsqueeze(0)
        ), split_idx=1
    )
)

# apply torchvision augmentations: resize with `gpu_crop`, then normalise
result_im_tv = norm_torch(gpu_crop(tt_torch(tfmd_img).unsqueeze(0), (224,224)))

In [250]:
# True when the full preprocessing pipelines agree within float tolerance
torch.allclose(result_im_fastai, result_im_tv)

In [None]:
import typing
from PIL import Image
import torchvision.transforms.functional as tvf

In [None]:
def crop(image:typing.Union[Image.Image, torch.Tensor], size:typing.Union[int, typing.Tuple[int,int]]) -> typing.Union[Image.Image, torch.Tensor]:
    """
    Centre-crops `image` to `size`, the cropping half of fastai's `CropPad`.

    Any dimension in which the image is smaller than `size` keeps its
    original length; padding must be performed afterwards (see `pad`).

    Image dimensions are read without relying on fastai: PIL images are
    measured with `Image.size` (plain PIL images have no `.shape`; that
    attribute is patched in by fastai), tensors with `shape[-2:]`.

    Args:
        image (`PIL.Image` or `torch.Tensor`):
            An image to perform cropping on. Tensors are expected as
            `[..., H, W]`.
        size (`int` or `tuple` of integers):
            A size to crop to, in the form (width, height). An int means a
            square crop.

    Returns:
        The cropped image, of the same kind as `image`.
    """
    if isinstance(size, int):
        size = (size, size)
    target_w, target_h = size

    if isinstance(image, Image.Image):
        width, height = image.size
    else:
        height, width = image.shape[-2:]

    # Corners of the centred window, clamped to the image bounds. Floor
    # division places the window the same way fastai does for odd differences.
    x0 = max((width - target_w) // 2, 0)
    y0 = max((height - target_h) // 2, 0)
    x1 = min(x0 + target_w, width)
    y1 = min(y0 + target_h, height)

    if isinstance(image, Image.Image):
        return image.crop((x0, y0, x1, y1))
    # torchvision's crop takes (top, left, height, width).
    return tvf.crop(image, y0, x0, y1 - y0, x1 - x0)

In [None]:
def pad(image, size:typing.Union[int, typing.Tuple[int,int]]) -> typing.Union[Image.Image, torch.Tensor]:
    """
    Zero-pads `image` up to `size`, the padding half of fastai's `CropPad`.

    Meant to run after `crop`: dimensions already at least as large as
    `size` receive no padding. When the missing extent is odd, the extra
    pixel goes on the left/top side, matching the floor-divided centre offset.

    Image dimensions are read without relying on fastai: PIL images are
    measured with `Image.size`, tensors with `shape[-2:]`.

    Args:
        image (`PIL.Image` or `torch.Tensor`):
            An image to perform padding on. Tensors are expected as
            `[..., H, W]`.
        size (`int` or `tuple` of integers):
            A size to pad to, in the form (width, height). An int means a
            square target.

    Returns:
        The padded image, of the same kind as `image`.
    """
    if isinstance(size, int):
        size = (size, size)
    target_w, target_h = size

    if isinstance(image, Image.Image):
        width, height = image.size
    else:
        height, width = image.shape[-2:]

    # Offset of the centred target window; negative when the image is smaller.
    offset_x = (width - target_w) // 2
    offset_y = (height - target_h) // 2

    pad_left = max(-offset_x, 0)
    pad_top = max(-offset_y, 0)
    # The remainder of the missing extent goes right/bottom. Width terms must
    # use the target width and height terms the target height; mixing them
    # only works for square sizes.
    pad_right = max(target_w - width + offset_x, 0)
    pad_bottom = max(target_h - height + offset_y, 0)

    # torchvision's 4-element padding order is [left, top, right, bottom].
    return tvf.pad(
        image,
        [pad_left, pad_top, pad_right, pad_bottom],
        padding_mode="constant"
    )

In [None]:
def gpu_crop(
    batch:torch.Tensor, 
    size:typing.Union[int, typing.Tuple[int,int]]
):
    """
    Resizes every image in `batch` to `size` with a bilinear sampling grid.
    This matches fastai's `RandomResizedCropGPU` in validation mode
    (`split_idx=1`) on the square inputs used in this notebook.

    Despite the name, nothing is cut away. An identity affine grid spans the
    whole image, so each image is resampled to `size` and the aspect ratio is
    not preserved. Large downscales are first pre-shrunk with area
    interpolation, using the same heuristic as fastai, to reduce aliasing.
    NOTE(review): fastai may additionally centre-crop images whose aspect
    ratio falls outside its `ratio` range. This function does not; confirm if
    non-square inputs are expected.

    Args:
        batch (`torch.Tensor`):
            A batch of images of shape `NxCxHxW`.
        size (`int` or `tuple` of integers):
            Output size in the form (height, width). An int gives a square
            output; any two-element sequence (e.g. a list) is accepted.

    Returns:
        A `torch.Tensor` of shape `NxCx(height)x(width)`.
    """
    if isinstance(size, int):
        size = (size, size)
    # `batch.shape[:2] + size` below needs a tuple (a list would raise).
    size = tuple(size)

    # One identity affine matrix per image, keeping only the top 2x3 part
    # that `affine_grid` expects.
    affine_matrix = torch.eye(3, device=batch.device).float()
    affine_matrix = affine_matrix.unsqueeze(0)
    affine_matrix = affine_matrix.expand(batch.size(0), 3, 3)
    affine_matrix = affine_matrix.contiguous()[:,:2]

    # Sampling grid of shape (N, height, width, 2) with (x, y) in [-1, 1].
    coords = F.affine_grid(
        affine_matrix, batch.shape[:2] + size, align_corners=True
    )

    # How far the grid zooms in (1.0 for the identity transform).
    top_range, bottom_range = coords.min(), coords.max()
    zoom = 1/(bottom_range - top_range).item()*2

    # Downscale factor with a 2x margin. Kept exactly as fastai computes it,
    # including the second term, which divides the input width by
    # `coords.shape[-1]` (the size-2 (x, y) axis). Changing it would break
    # numerical parity with fastai.
    resizing_limit = min(
        batch.shape[-2]/coords.shape[-2],
        batch.shape[-1]/coords.shape[-1]
    )/2

    # Big downscale and little zoom: pre-shrink with area averaging so the
    # bilinear sampling below does not alias.
    if resizing_limit > 1 and resizing_limit > zoom:
        batch = F.interpolate(
            batch, 
            scale_factor=1/resizing_limit, 
            mode='area', 
            recompute_scale_factor=True
        )
    return F.grid_sample(batch, coords, mode='bilinear', padding_mode='reflection', align_corners=True)