In [3]:
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms
import matplotlib.pyplot as plt
import timm
import albumentations as A
import warnings
warnings.filterwarnings("ignore")

path_model = '../models/_ckpt.pt'

In [6]:
class BackboneNet(nn.Module):
    """Two-branch regressor: a timm image backbone plus an MLP over metadata features.

    The backbone's pooled features and a 128-d metadata embedding are concatenated,
    passed through dropout, and mapped to ``out_dim`` values by a linear head.

    Args:
        timm_model_name: name passed to ``timm.create_model``. The created model must
            expose a ``classifier`` attribute with ``in_features`` (EfficientNet-style head).
        len_meta_x: number of metadata features per sample (input width of the MLP).
        out_dim: number of outputs of the final linear layer.
        pretrained: forwarded to ``timm.create_model``.
        freezepretrained: if True, set ``requires_grad = False`` on every parameter
            of the image backbone.

    NOTE: the metadata branch uses ``BatchNorm1d``, which rejects a batch of size 1
    while the module is in training mode.
    """

    def __init__(self, timm_model_name, len_meta_x, out_dim, pretrained=True, freezepretrained=False):
        super().__init__()

        # Image branch: swap the classification head for Identity so the backbone
        # returns its feature vector; remember its width for the fused head.
        self.efficientmodel = timm.create_model(timm_model_name, pretrained=pretrained)
        n_features = self.efficientmodel.classifier.in_features
        self.efficientmodel.classifier = nn.Identity()

        # Metadata branch: len_meta_x -> 512 -> 128.
        self.metamodel = nn.Sequential(
            nn.Linear(len_meta_x, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )
        n_features += 128  # metadata embedding is appended to the image features

        # Fused head. The attribute names efficientmodel / metamodel / dropout / fc are
        # the state_dict key prefixes of saved checkpoints, so they must not be renamed.
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(n_features, out_dim)

        # Optionally freeze the image backbone so only the metadata MLP and head train.
        if freezepretrained:
            for param in self.efficientmodel.parameters():
                param.requires_grad = False

    def forward(self, image, meta_x):
        """Run both branches, fuse them, and apply the linear head.

        Args:
            image: batch for the timm backbone (presumably (B, 3, H, W) — TODO confirm).
            meta_x: metadata batch of shape (B, len_meta_x).

        Returns:
            Tensor of shape (B,) when ``out_dim == 1``, otherwise (B, out_dim).
        """
        image_features = self.efficientmodel(image)
        meta_features = self.metamodel(meta_x)
        # Both branches must share the batch dimension to be concatenated on dim 1.
        x = torch.cat((image_features, meta_features), dim=1)
        x = self.dropout(x)
        x = self.fc(x)
        # Drop only the trailing output dim: (B, 1) -> (B,) so predictions line up
        # with a (B,) target. A bare .squeeze() would also remove the batch dim when
        # B == 1 (yielding a 0-d tensor, or (out_dim,) when out_dim > 1).
        return x.squeeze(-1)

MODEL_NAME = 'tf_efficientnet_b1_ns'
# Metadata vector width; must match the input width of the checkpoint's metamodel,
# otherwise the strict load_state_dict below fails on a shape mismatch.
N_META_FEATURES = 12

# pretrained=False: every weight is overwritten from the checkpoint's model_state_dict,
# so downloading the timm pretrained weights would be wasted work.
# out_dim=1: a single regression output (the checkpoint reports RMSE metrics).
model = BackboneNet(
    timm_model_name=MODEL_NAME,
    len_meta_x=N_META_FEATURES,
    out_dim=1,
    pretrained=False,
    freezepretrained=False,
)

In [10]:
# map_location="cpu": if the checkpoint was saved with CUDA tensors, this still loads
# it on a CPU-only machine (the model above lives on CPU); move to a GPU explicitly if needed.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load(path_model, map_location="cpu")
checkpoint.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss', 'val_loss', 'rmse', 'val_rmse', 'train_lr'])

In [18]:
val_rmse = checkpoint['val_rmse']  # validation RMSE stored alongside the weights
val_rmse

18.684137693578954

In [19]:
# strict=True (the default) raises if the checkpoint has missing or unexpected keys;
# the returned report is displayed to confirm the match.
load_report = model.load_state_dict(checkpoint['model_state_dict'], strict=True)
load_report

<All keys matched successfully>

In [20]:
# Display the module tree to inspect the architecture that received the checkpoint weights.
model

BackboneNet(
  (efficientmodel): EfficientNet(
    (conv_stem): Conv2dSame(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (act1): SiLU(inplace=True)
    (blocks): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): SiLU(inplace=True)
            (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
          )
          (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, t