In [1]:
import torch
from PIL import Image
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from torchvision.models.resnet import resnet18
from torchvision.models import resnet101
from efficientnet_pytorch import EfficientNet
from torchvision import datasets, transforms

In [2]:
t = transforms.Compose([transforms.Resize((450, 600)),
                transforms.Pad((0, 0, 0, 150), fill = 0, padding_mode = 'constant'),
                transforms.Resize((512, 512)),
                transforms.ToTensor(),
                transforms.Normalize(mean=0.1354949, std=0.18222201)])

In [3]:
path = '/data/kpusteln/fetal/fetal_extracted/40_3_206.png'
ps = 0.0344111999999999
im = Image.open(path)
tensor = t(im)
tensor = tensor.unsqueeze(0)

In [4]:
class EffNet(torch.nn.Module):
    def __init__(self, out_features = 7, use_pretrained = False, extract = True, freeze = False, unfreeze_last_layers = False):
        super(EffNet, self).__init__()
        self.out_features = out_features
        self.extract = extract
        self.sigmoid = torch.nn.Sigmoid()
        self.backbone = EfficientNet.from_pretrained('efficientnet-b6', in_channels = 1, num_classes=self.out_features).float()
        self.fc = torch.nn.Linear(in_features=2305, out_features=out_features, bias=True).float()
        if use_pretrained:
            model = torch.load('/data/kpusteln/Fetal-RL/swin-transformer/output/effnet_cls/default/ckpt_epoch_4.pth')['model']
            for key in list(model.keys()):
                if 'backbone' in key:
                    model[key.replace('backbone.', '')] = model.pop(key) # remove prefix backbone.
            self.backbone.load_state_dict(model)
        if self.extract:    ## extract features for the transformer, ignore last layer
            self.backbone._fc = torch.nn.Identity()
        if freeze:
            for param in self.backbone.parameters():
                    param.requires_grad = False
                
        if unfreeze_last_layers:
            for param in self.backbone._blocks[44:].parameters():
                    param.requires_grad = True
                
    def count_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(self, x, ps):
        x = self.backbone(x)
        ps = ps.reshape(-1, 1).float()
        x = torch.cat((x, ps), dim = 1)
        if self.out_features == 1:
            x = self.fc(x)
            x = self.sigmoid(x)
        return x
    
model = EffNet(out_features = 1, use_pretrained = False, extract = True, freeze = True, unfreeze_last_layers = False)
checkpoint = torch.load('/data/kpusteln/Fetal-RL/swin-transformer/output/effnet_reg_v2/default/ckpt_epoch_89.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=False)
model.eval()
output = model(tensor, torch.tensor([ps]) )
output[0]


Loaded pretrained weights for efficientnet-b6


tensor([0.4479], grad_fn=<SelectBackward0>)

In [6]:
def denormalize(value, min, max):
    value = value * (max - min) + min
    return value

In [7]:
denormalize(output[0], 1.93, 7.48)

tensor([4.4161], grad_fn=<AddBackward0>)