In [1]:
import os
import cv2
import torch
import torch.nn as nn
import numpy as np
import pandas as pd

from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader

!pip install timm
import timm

!pip install -q -U albumentations
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2



In [2]:
# Inference configuration shared by every cell below.
CFG = {
    # timm architecture names; each is built with timm.create_model and needs
    # matching checkpoint file(s) in `save_path`.
    "models" : ("tf_efficientnet_b3_ns", "tf_efficientnet_b4", "tf_efficientnet_b5_ns"),
    # (height, width) passed to A.CenterCrop in the test transform.
    "crop_size" : (380, 380),
    # DataLoader batch size for the test set.
    "test_bs" : 32,
    # True -> ensemble k_fold per-fold checkpoints "<model>_[k].pth";
    # False -> load a single "<model>.pth" per architecture.
    "cross_valid" : True,
    "k_fold" : 5,
    "num_workers" : 8,
    # Size of the classifier head (num_classes passed to timm.create_model).
    "num_classes" : 18,
    # Prefer the GPU but fall back to CPU so the notebook still runs on a
    # CUDA-less host instead of failing at the first `.to(device)`.
    "device" : torch.device("cuda" if torch.cuda.is_available() else "cpu"),
}

In [3]:
# Test-time preprocessing, applied deterministically (p=1.0) to every image:
#   1. centre-crop to CFG["crop_size"] = (height, width)
#   2. CLAHE contrast equalisation
#   3. normalise with the commonly used ImageNet channel mean/std
#   4. convert the image to a torch tensor (ToTensorV2)
transforms = dict(
    test=A.Compose(
        [
            A.CenterCrop(height=CFG["crop_size"][0], width=CFG["crop_size"][1], p=1.0),
            A.CLAHE(p=1.0),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                p=1.0,
            ),
            ToTensorV2(p=1.0),
        ]
    ),
)

In [4]:
class MaskDataset(Dataset):
    """Image dataset backed by a DataFrame of file references.

    With ``exist_label=True`` each row must provide ``filepath`` (passed
    directly to ``cv2.imread``) and ``label``; items are ``(image, label)``.
    With ``exist_label=False`` each row must provide ``ImageID`` (a file name
    inside ``img_dir``); items are the image only.

    Args:
        df: DataFrame with the columns described above.
        exist_label: whether rows carry a ``label`` to return alongside the image.
        transforms: optional callable invoked as ``transforms(image=img)["image"]``
            on the RGB image (e.g. an albumentations ``Compose``).
        img_dir: directory holding unlabeled images. When ``None`` (default)
            it falls back to ``os.path.join(test_dir, "images")`` using the
            notebook-level ``test_dir``, looked up at access time.

    Raises:
        FileNotFoundError: from ``__getitem__`` when the image cannot be read.
    """

    def __init__(self, df, exist_label, transforms=None, img_dir=None):
        self.df = df
        self.transforms = transforms
        self.exist_label = exist_label
        self.img_dir = img_dir

    def __len__(self):
        return len(self.df)

    def _resolve_path(self, index: int):
        """Return the image path for row ``index``."""
        row = self.df.iloc[index]
        if self.exist_label:
            return row["filepath"]
        # The global fallback keeps old call sites working; it is only read here,
        # so the class no longer requires `test_dir` to exist when an explicit
        # img_dir is given.
        img_dir = self.img_dir if self.img_dir is not None else os.path.join(test_dir, "images")
        return os.path.join(img_dir, row["ImageID"])

    def __getitem__(self, index: int):
        path = self._resolve_path(index)

        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None rather
            # than raising; fail here with the offending path instead of inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transforms:
            img = self.transforms(image=img)["image"]

        if self.exist_label:
            return img, self.df.iloc[index]["label"]
        return img

In [5]:
# Input/checkpoint locations. The defaults match the original /opt/ml layout;
# the TEST_DIR / CHECKPOINT_DIR environment variables override them so the
# notebook can run elsewhere without editing code.
test_dir = os.environ.get("TEST_DIR", "/opt/ml/input/data/eval")
save_path = os.environ.get("CHECKPOINT_DIR", "/opt/ml/code/checkpoints")

test_csv = os.path.join(test_dir, "info.csv")
# NOTE(review): `df` is not used by the visible inference cell, which re-reads
# the same CSV into `submission`; kept in case other cells rely on it.
df = pd.read_csv(test_csv)

In [8]:
# Unlabeled test set; shuffle=False keeps prediction order aligned with the
# rows of `submission`.
submission = pd.read_csv(os.path.join(test_dir, 'info.csv'))

test_dataset = MaskDataset(submission, exist_label=False, transforms=transforms["test"])
test_iter = DataLoader(test_dataset, batch_size=CFG["test_bs"], shuffle=False, num_workers=CFG["num_workers"])


def predict_logits(model, loader, device, desc=None):
    """Run ``model`` in eval mode over every batch of ``loader``.

    Returns the per-batch model outputs concatenated along axis 0 as a NumPy
    array (one row per dataset item, assuming outputs are (batch, classes)).
    """
    model.eval()
    batches = []
    with torch.no_grad():
        for images in tqdm(loader, desc=desc):
            batches.append(model(images.float().to(device)).cpu().numpy())
    return np.concatenate(batches, axis=0)


def checkpoint_paths(model_name):
    """Checkpoint files to ensemble for ``model_name`` according to CFG."""
    if CFG["cross_valid"]:
        return [os.path.join(save_path, f'{model_name}_[{k}].pth') for k in range(CFG["k_fold"])]
    return [os.path.join(save_path, f'{model_name}.pth')]


model_probs = []
for model_name in CFG["models"]:
    model = timm.create_model(model_name, num_classes=CFG["num_classes"])
    model.to(CFG["device"])
    fold_logits = []
    for ckpt in checkpoint_paths(model_name):
        # map_location lets checkpoints saved on one device load on CFG["device"].
        model.load_state_dict(torch.load(ckpt, map_location=CFG["device"]))
        fold_logits.append(predict_logits(model, test_iter, CFG["device"], desc=os.path.basename(ckpt)))
    # Average (not sum) the fold logits: softmax of a sum over k folds equals
    # softmax at temperature 1/k, which over-sharpens each model's probabilities
    # and skews the cross-model soft vote below.
    mean_logits = torch.from_numpy(np.mean(fold_logits, axis=0))
    model_probs.append(torch.softmax(mean_logits, dim=1).numpy())

# Soft vote across architectures: sum class probabilities, take the top class.
all_predictions = np.sum(model_probs, axis=0).argmax(axis=-1)

submission['ans'] = all_predictions
out_dir = '/opt/ml/code/submissions'
# to_csv does not create missing parent directories.
os.makedirs(out_dir, exist_ok=True)
submission.to_csv(os.path.join(out_dir, 'ensemble_efficientNet.csv'), index=False)

100%|██████████| 394/394 [00:53<00:00,  7.38it/s]
100%|██████████| 394/394 [00:53<00:00,  7.41it/s]
100%|██████████| 394/394 [00:54<00:00,  7.23it/s]
100%|██████████| 394/394 [00:52<00:00,  7.51it/s]
100%|██████████| 394/394 [00:53<00:00,  7.43it/s]
100%|██████████| 394/394 [00:58<00:00,  6.79it/s]
100%|██████████| 394/394 [00:58<00:00,  6.75it/s]
100%|██████████| 394/394 [00:57<00:00,  6.81it/s]
100%|██████████| 394/394 [00:57<00:00,  6.80it/s]
100%|██████████| 394/394 [00:57<00:00,  6.80it/s]
100%|██████████| 394/394 [01:09<00:00,  5.66it/s]
100%|██████████| 394/394 [01:09<00:00,  5.69it/s]
100%|██████████| 394/394 [01:09<00:00,  5.69it/s]
100%|██████████| 394/394 [01:09<00:00,  5.70it/s]
100%|██████████| 394/394 [01:08<00:00,  5.72it/s]
