In [1]:
import inspect

import pandas as pd
import numpy as np
# import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# from transformers import ViTFeatureExtractor, ViTModel, ViTConfig, AutoConfig

from PIL import Image


In [2]:
# Need to redefine the local class to load the weights.
# The checkpoint is unpickled as a whole nn.Module, so this class is looked up by
# name and its forward() is the one used at inference; attribute names below must
# match those of the saved object.
class model_final(nn.Module):
    """Two-branch regressor fusing a transformer backbone with a ResNet backbone.

    The transformer branch's per-patch hidden states are flattened and projected
    to 2048 features; they are concatenated with the ResNet branch's features
    (2048 more, since linear1 expects 4096 inputs) and mapped to one output per
    sample through two linear layers.

    Args:
        model_trans_top: transformer backbone whose output exposes
            ``last_hidden_state`` (presumably a Hugging Face ViTModel — TODO confirm).
        trans_layer_norm: normalization module meant for the patch states
            (see the review note in forward()).
        model_Res: ResNet feature extractor; forward() reshapes its output to
            (batch, channels), so the trailing dims must have size 1
            (assumed (B, 2048, 1, 1) — TODO confirm).
        dp_rate: dropout probability, shared by every dropout site in forward().
    """
    def __init__(self, model_trans_top, trans_layer_norm, model_Res, dp_rate = 0.3):
        super().__init__()
        # All the trans model layers
        self.model_trans_top = model_trans_top
        self.trans_layer_norm = trans_layer_norm
        self.trans_flatten = nn.Flatten()
        # 150528 is the flattened patch-state width; presumably 196 patches x 768
        # hidden dims (ViT-base at 224x224) — TODO confirm against the backbone config.
        self.trans_linear = nn.Linear(150528, 2048)

        # All the ResNet model
        self.model_Res = model_Res

        # Fusion head: one dropout module reused at each site, then
        # 4096 (2048 trans + 2048 ResNet) -> 500 -> 1.
        self.dropout = nn.Dropout(dp_rate)
        self.linear1 = nn.Linear(4096, 500)
        self.linear2 = nn.Linear(500,1)

    def forward(self, trans_b, res_b):
        """Run both backbones on their own preprocessed batch and fuse the features.

        Args:
            trans_b: image batch for the transformer branch.
            res_b: image batch for the ResNet branch (same samples, same order).

        Returns:
            Tensor of shape (batch, 1). Dropout is applied at three points, so
            call model.eval() before inference for deterministic outputs.
        """
        # Get intermediate outputs using hidden layer
        result_trans = self.model_trans_top(trans_b)
        # Drop token 0 (the classification token) and keep the last hidden state of every patch
        patch_state = result_trans.last_hidden_state[:,1:,:]
        # NOTE(review): this layer-norm output is overwritten on the next line, which
        # flattens patch_state rather than the normalized tensor, so trans_layer_norm
        # has no effect. Left as-is because the saved weights were presumably trained
        # with this exact forward — confirm against the training code before fixing.
        result_trans = self.trans_layer_norm(patch_state)
        result_trans = self.trans_flatten(patch_state)
        result_trans = self.dropout(result_trans)
        result_trans = self.trans_linear(result_trans)

        result_res = self.model_Res(res_b)
        # squeeze() would also drop the batch dim when the batch size is 1, hence
        # the explicit reshape to (batch, channels) instead:
        # result_res = result_res.squeeze()
        result_res = torch.reshape(result_res, (result_res.shape[0], result_res.shape[1]))

        # Concatenate both feature vectors along the feature axis.
        result_merge = torch.cat((result_trans, result_res),1)
        result_merge = self.dropout(result_merge)
        # NOTE(review): no nonlinearity between linear1 and linear2, so the two
        # layers compose to a single affine map.
        result_merge = self.linear1(result_merge)
        result_merge = self.dropout(result_merge)
        result_merge = self.linear2(result_merge)

        return result_merge

device = torch.device('cpu')

# The checkpoint is a whole pickled nn.Module (not a state_dict): unpickling
# resolves model_final by name, so that class must be defined in this notebook.
# PyTorch >= 2.6 defaults torch.load(weights_only=True), which rejects such
# pickles, so opt out explicitly whenever the installed torch has the argument
# (older versions without it already unpickle fully).
# SECURITY: full unpickling can execute arbitrary code — only load trusted files.
MODEL_PATH = 'model_1228'
load_kwargs = {"map_location": device}
if "weights_only" in inspect.signature(torch.load).parameters:
    load_kwargs["weights_only"] = False
model = torch.load(MODEL_PATH, **load_kwargs)

In [5]:
# Test metadata: one row per image Id; each image lives at test/<Id>.jpg.
df = pd.read_csv("test.csv")
df["file_path"] = df["Id"].map(lambda image_id: "test/" + image_id + ".jpg")

In [6]:
class petDataset_pred(Dataset):
    """Inference dataset yielding each image under two preprocessing pipelines.

    Args:
        dataframe: frame with a ``file_path`` column of image paths. Rows are
            served in frame order, addressed by position (the frame's index
            labels are ignored).
        trans_transform: callable applied to the PIL image for the transformer
            branch; if None, the RGB PIL image is returned unchanged.
        res_transform: callable applied to the PIL image for the ResNet branch;
            if None, the RGB PIL image is returned unchanged.
    """
    def __init__(self, dataframe, trans_transform=None, res_transform=None):
        # A plain list makes self.images[idx] positional; indexing the Series
        # directly would be a label lookup and fail on a non-0..n-1 index.
        self.images = dataframe["file_path"].tolist()
        self.trans_transform = trans_transform
        self.res_transform = res_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return (transformer_input, resnet_input) for the image at position idx."""
        img_path = self.images[idx]
        # convert() loads the pixels into a new image, so the file can be closed
        # right away; forcing RGB keeps the 3-channel Normalize steps valid for
        # grayscale / RGBA / palette inputs.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        image_trans = self.trans_transform(image) if self.trans_transform is not None else image
        image_res = self.res_transform(image) if self.res_transform is not None else image
        return image_trans, image_res

# Transformer branch: resize straight to 224x224 (aspect ratio not preserved),
# then map pixels from [0, 1] to [-1, 1]. A bare Resize(224) only fixes the
# shorter edge, so non-square photos would give tensors of differing width that
# the DataLoader cannot stack into a batch and that do not match the fixed-size
# transformer input. Presumably this mirrors the ViT feature extractor used in
# training (see the commented transformers import) — confirm.
trans_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# ResNet branch: resize shorter edge to 256, center-crop 224, ImageNet mean/std.
res_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


test_ds = petDataset_pred(df, trans_transform=trans_transform, res_transform=res_transform)
# shuffle=False keeps batches in the same order as the rows of df.
test_dl = DataLoader(test_ds, batch_size=2, shuffle=False)


In [8]:
# Inference must run in eval mode: model_final.forward applies dropout, which in
# train mode randomly perturbs every prediction. no_grad skips autograd bookkeeping.
model.eval()
batch_preds = []
with torch.no_grad():
    for x_trans, x_res in test_dl:
        result = model(x_trans, x_res)
        batch_preds.append(result.cpu().numpy().reshape(-1))  # (B, 1) -> (B,)
# Concatenate once instead of np.append per batch (which copies every time).
output = np.concatenate(batch_preds) if batch_preds else np.array([])
assert len(output) == len(df), "expected one prediction per test row"

# Predictions follow df (test.csv) row order; attach them to the submission by Id
# instead of assuming sample_submission.csv lists the Ids in the same order.
pred_by_id = pd.Series(output, index=df["Id"].to_numpy())
df_submission = pd.read_csv('sample_submission.csv')
df_submission["Pawpularity"] = df_submission["Id"].map(pred_by_id)
assert df_submission["Pawpularity"].notna().all(), "some submission Ids have no prediction"
df_submission.to_csv("submission.csv", index=False)