In [11]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torchvision.models import efficientnet_v2_l
from torchvision import datasets, transforms
from PIL import Image
    

In [35]:


class EffnetV2_L(torch.nn.Module):
    def __init__(self, out_features = 7, in_channels = 1):
        super(EffnetV2_L, self).__init__()
        
        
        self.out_features = out_features
        self.in_channels = in_channels
        self.model = efficientnet_v2_l(weights = 'EfficientNet_V2_L_Weights.IMAGENET1K_V1')
        self.model.features[0] = torch.nn.Conv2d(self.in_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        #self.model.classifier = torch.nn.Identity()
        self.model.classifier = torch.nn.Sequential(nn.Dropout(0.4), nn.Linear(1280, self.out_features))
        
        
    def count_params(self):
        
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
        
    def forward(self, x):
        features = self.model.features(x)
        features = self.model.avgpool(features)
        return self.model(x), features
        

        
model = EffnetV2_L(out_features = 7, in_channels = 1)
#test_tensor = torch.rand(16, 1, 448, 448).cuda()

# model = EffnetV2_L(out_features = 7, in_channels = 1).cuda()

# print(model.count_params())

# print(model(test_tensor).shape)


In [80]:
checkpoint = torch.load('/data/kpusteln/Fetal-RL/swin-transformer/output/effnetv2_cls/default/ckpt_epoch_34.pth', map_location='cpu')

In [81]:
msg = model.load_state_dict(checkpoint['model'], strict=False)

In [82]:
msg

<All keys matched successfully>

In [83]:
transform = transforms.Compose([transforms.Resize((450, 600)),
                transforms.Pad((0, 0, 0, 150), fill = 0, padding_mode = 'constant'),
                transforms.Resize((512, 512)),
                transforms.ToTensor(),
                transforms.Normalize(mean=0.1354949, std=0.18222201)])

In [84]:
img = Image.open('/data/kpusteln/fetal/fetal_extracted/68_2_86.png')

In [85]:
img = transform(img)

In [86]:
model.eval()
output, features = model(img.unsqueeze(0))

In [87]:
features = features.squeeze()

In [88]:
output.argmax()

tensor(4)

In [89]:
torch.nn.functional.softmax(output)

  """Entry point for launching an IPython kernel.


tensor([[7.9751e-06, 6.1487e-05, 2.8229e-04, 2.4531e-01, 7.5207e-01, 4.1704e-04,
         1.8470e-03]], grad_fn=<SoftmaxBackward0>)