In [15]:
import torch
from torchvision import models, datasets,transforms
import pandas as pd
import os 
from PIL import Image
from tqdm import tqdm 
import torch.nn as nn
from importlib import import_module

In [16]:
# Use three pre-trained model to ensemble
class MyEnsemble(nn.Module):
    def __init__(self, mask_model, gender_model, age_model):
        super(MyEnsemble, self).__init__()

        self.mask_model   = mask_model
        self.gender_model = gender_model
        self.age_model    = age_model
        
            
    def encode_multi_class(self,mask_label, gender_label, age_label) -> int:
        return mask_label * 6 + gender_label * 3 + age_label

    def forward(self, x):

        self.Ax = self.mask_model(x.clone())
        self.Bx = self.gender_model(x.clone())
        self.Cx = self.age_model(x.clone())
        
        _, mask_label   = torch.max(self.Ax, 1)
        _, gender_label = torch.max(self.Bx, 1)
        _, age_label    = torch.max(self.Cx, 1)
        print(mask_label,gender_label,age_label)
        return self.encode_multi_class(mask_label, gender_label, age_label)

In [17]:

# Mask Model
mask_model = models.resnet50(pretrained=False)

num_ftrs = mask_model.fc.in_features
mask_model.fc = nn.Linear(num_ftrs,3)
mask_model.load_state_dict(torch.load("models/mask_only.pth"))

# Gender Model
gender_model = models.resnet50(pretrained=False)

num_ftrs = gender_model.fc.in_features
gender_model.fc = nn.Linear(num_ftrs,2)
gender_model.load_state_dict(torch.load("models/gender_only.pth"))

# Age Model
age_model = models.resnet50(pretrained=False)


num_ftrs = age_model.fc.in_features
age_model.fc = nn.Linear(num_ftrs,3)
age_model.load_state_dict(torch.load("models/age_only.pth"))

<All keys matched successfully>

In [18]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_ft = MyEnsemble(mask_model, gender_model, age_model)
model_ft = model_ft.to(device);

In [42]:

# transform = transforms.Compose([
#                     transforms.Resize((224,224)),
#                     transforms.ToTensor(),
#                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# dataset_module = getattr(import_module("dataset"), "MaskBaseDatasetWithOnly")  # default: BaseAugmentation
# image_datasets = dataset_module(
#                             data_dir="../input/data/train/images",
#                             transform=transform,
#                             face_crop=True
#                             )

# inputs, labels = next(iter(image_datasets))
# model(inputs.unsqueeze(0).to(device))

In [19]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, info ,transforms=None):
        self.root_dir = root_dir
        self.df = pd.read_csv(info)
        self.transforms = transforms
        self.filename = self.df['ImageID'].values
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir,self.filename[idx])
        img = Image.open(img_path)
        
        if self.transforms:
            img = self.transforms(img)
        
        return self.filename[idx], img

    def __len__(self):
        return len(self.filename)

In [20]:
transform = transforms.Compose([
                    transforms.Resize((224,224)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

image_datasets =  CustomDataset('../input/data/eval/images','../input/data/eval/info.csv',transform)
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=1,
                                             shuffle=False, num_workers=1)
dataset_sizes = len(image_datasets)
# class_names = image_datasets.classes

In [21]:
res = {'ImageID' : [],'ans' : []}

for (file_name , inputs) in tqdm(dataloaders) :
    inputs = inputs.to(device)
    preds = model_ft(inputs)
    
    res['ImageID'].append(file_name[0])
    res['ans'].append(preds.item())

  0%|          | 4/12600 [00:00<26:24,  7.95it/s]

tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')


  0%|          | 10/12600 [00:00<17:25, 12.04it/s]

tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')


  0%|          | 13/12600 [00:00<14:55, 14.06it/s]

tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')


  0%|          | 19/12600 [00:00<11:51, 17.67it/s]

tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')


  0%|          | 25/12600 [00:01<10:18, 20.33it/s]

tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')


  0%|          | 28/12600 [00:01<09:52, 21.23it/s]

tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')


  0%|          | 34/12600 [00:01<09:19, 22.45it/s]

tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')


  0%|          | 40/12600 [00:01<09:04, 23.07it/s]

tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')


  0%|          | 43/12600 [00:01<09:00, 23.24it/s]

tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')


  0%|          | 49/12600 [00:02<08:55, 23.46it/s]

tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')


  0%|          | 54/12600 [00:02<09:28, 22.05it/s]

tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')
tensor([0], device='cuda:0') tensor([1], device='cuda:0') tensor([1], device='cuda:0')





KeyboardInterrupt: 

In [12]:
ans = pd.DataFrame(data=res)

In [14]:
ans['ans'].unique()

array([4, 3])