In [64]:
import torch
import torch.nn as nn
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score
from importlib import import_module
from torch.utils.data import Dataset,DataLoader
from tqdm import tqdm
from transform import get_tta_transform
from PIL import Image
from utils import *

## valid dataset

In [86]:
class CustomDataset(Dataset):
    def __init__(self):
        super(CustomDataset, self).__init__()
        self.data_dir = "/opt/ml/input/data/train/crop_images/"
        self.info_path = "/opt/ml/input/data/train/train.csv"
        self.k = 5
        self.seed = 1997
        self.k_index= 4
        self.train = False

        self.folders = None
        self.set_k_fold(self.info_path, k_index=self.k_index, k=self.k, seed=self.seed)

        self.input_files = []
        self.images = []
        self.masks = []
        self.genders = []
        self.ages = []
        self.labels = []
        

        self.num_classes = None

        ### prepare images and labels
        if self.train:
            print("Train Data Loading...")
            age_count = [0,0,0]
        else:
            print("Test Data Loading...")

        for directory in tqdm(self.folders):
            image_dir = os.path.join(self.data_dir, directory)
            ID, GENDER, RACE, real_AGE = directory.split('_')

            if GENDER == "male":
                GENDER = 0
            elif GENDER =="female":
                GENDER = 1

            # fix gender label error
            if self.train:
                if GENDER == "male":
                    GENDER = 0.95
                elif GENDER =="female":
                    GENDER = 0.05

                if ID in ['006359', '006360', '006361', '006362', '006363', '006364', '001498-1', '004432']:
                    GENDER = 0.05 if GENDER==0.95 else 0.95
                elif ID in ['003724', '003421', '003399', '001200', '005223', '001270', '006226', '000664']:
                    GENDER = 0.5
            
            else:
                if GENDER == "male":
                    GENDER = 1
                elif GENDER =="female":
                    GENDER = 0
                if ID in ['006359', '006360', '006361', '006362', '006363', '006364', '001498-1', '004432']:
                    GENDER = 0 if GENDER==1 else 1

            if self.train:
                if int(real_AGE) < 30:
                    AGE = 0
                    age_count[0] += 1
                elif int(real_AGE) < 60:
                    AGE = 1
                    age_count[1] += 1
                else:
                    AGE = 2
                    age_count[2] += 1
            else:
                if int(real_AGE) < 30:
                    AGE = 0
                elif int(real_AGE) < 60:
                    AGE = 1
                else:
                    AGE = 2
        
        
            file_list = [f for f in os.listdir(image_dir) if f[0] != '.']
            for file in file_list:
                self.input_files.append(os.path.join(image_dir, file))
                
                if file[0:4] == "mask":
                    MASK = 0
                elif file[0:9] == "incorrect":
                    MASK = 1
                elif file[0:6] == "normal":
                    MASK = 2
                else:
                    raise

                self.images.append(np.array(Image.open(os.path.join(image_dir, file))))    
                self.masks.append(MASK)
                self.genders.append(GENDER)
                self.ages.append(AGE)
                self.labels.append(MASK*6 + GENDER*3 + AGE)

        if self.train:
            self.age_weight = [1/i for i in age_count]
            self.class_weight = [self.age_weight[a] for a in self.ages]
                    
    def __len__(self):
        return len(self.labels)
    
    
    def __getitem__(self, idx):
        ### load image
        image = self.images[idx]
        if self.transform:
            image = self.transform(image=image)['image']
        image.type(torch.float32)
        
        ### load label
        label = torch.tensor(self.labels[idx], dtype=torch.long)

    
        
        return image, label

    def set_k_fold(self, info_path, k_index=None, k=5, seed=1997):
        """
            output: train_folder, valid_folder
        """

        if not k_index in range(k):
            raise Exception('n_splits에 맞는 index를 입력해주세요')

        train_info = pd.read_csv(info_path)

        ### age/gender 동일 비율로 K Fold진행
        new_age = np.array(train_info['age'])
        new_gender = np.array(train_info['gender'])
        str_for_split = new_age.astype(str)+new_gender

        SFK = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
        for idx, (train_index, valid_index) in enumerate(SFK.split(train_info, str_for_split)):
            if idx == k_index:
                self.folders = train_info['path'][train_index] if self.train else train_info['path'][valid_index]
                break

    def set_transform(self, transform):
        self.transform = transform

In [87]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_module = getattr(import_module("model"), "CustomModel")  # default: BaseModel

model = model_module().to(device)
model.load_state_dict(torch.load("./model/ensemble4/best_acc.pth", map_location=device))
model = torch.nn.DataParallel(model)

In [88]:
dataset_module = getattr(import_module("dataset"), 'CustomDataset')  # default: BaseAugmentation
val_set = CustomDataset()

transform_module = getattr(import_module("transform"), 'Augmentation_384')
val_transform = transform_module(
        train=False
    )
val_set.set_transform(val_transform)


  1%|▏         | 7/540 [00:00<00:08, 63.86it/s]

Test Data Loading...


100%|██████████| 540/540 [00:10<00:00, 53.90it/s]


In [94]:
val_loader = DataLoader(
        val_set,
        batch_size=2,
        num_workers=2,
        shuffle=False,
        drop_last=False,
    )


In [118]:
def tta(tta_transforms, model, inputs):
    m_out_list = [[],[],[]]
    g_out_list = []
    a_out_list = [[],[],[]]
    for transformer in tta_transforms: # custom transforms or e.g. tta.aliases.d4_transform() 
                    # augment image
        augmented_image = transformer.augment_image(inputs)
        m, g, a = model(augmented_image)
                    
        m_tensor = nn.functional.softmax(m).cpu()
        g_tensor = torch.sigmoid(g).cpu()
        a_tensor = nn.functional.softmax(a).cpu()
                    # save results
        m_out_list[0].append(m_tensor[:,0])
        m_out_list[1].append(m_tensor[:,1])
        m_out_list[2].append(m_tensor[:,2])
        g_out_list.append(g_tensor)
        a_out_list[0].append(a_tensor[:,0])
        a_out_list[1].append(a_tensor[:,1])
        a_out_list[2].append(a_tensor[:,2])
                    
            # reduce results as you want, e.g mean/max/min
    m_out_list[0] = torch.tensor(m_out_list)
    print(a)
    raise
    m1_result = torch.mean(torch.tensor(m_out_list[0]),1)
    m2_result = torch.mean(torch.tensor(m_out_list[1]),1)
    m3_result = torch.mean(torch.tensor(m_out_list[2]),1)
    a1_result = torch.mean(torch.tensor(a_out_list[0]),1)
    a2_result = torch.mean(torch.tensor(a_out_list[1]),1)
    a3_result = torch.mean(torch.tensor(a_out_list[2]),1)
    g_result = torch.mean(torch.tensor(g_out_list),1)

    m_outs = torch.tensor([m1_result, m2_result, m3_result])
    a_outs = torch.tensor([a1_result, a2_result, a3_result])

    return m_outs, torch.tensor(g_result).cpu(), a_outs

In [119]:
ans = []
target = []
with torch.no_grad():
    for idx, (images, labels) in enumerate(tqdm(val_loader)):
        images = images.to(device)


        m_outs, g_outs, a_outs = tta(tta_transforms, model, images)

        tta_m_preds = torch.unsqueeze(torch.argmax(m_outs, dim=-1),0).cpu()
        tta_a_preds = torch.unsqueeze(torch.argmax(a_outs, dim=-1),0).cpu()
        if g_outs >= 0.5 : tta_g_preds = 1
        else: tta_g_preds= 0
        tta_g_preds = torch.unsqueeze(torch.tensor(tta_g_preds),0).cpu()
        tta_preds = label_encoder(tta_m_preds, tta_g_preds, tta_a_preds)
        ans.append(tta_preds[0].item())
        
        target.append(labels)
    target = torch.cat(target)
    ans = torch.tensor(ans)

  m_tensor = nn.functional.softmax(m).cpu()
  a_tensor = nn.functional.softmax(a).cpu()
  0%|          | 0/1890 [00:01<?, ?it/s]


ValueError: only one element tensors can be converted to Python scalars

In [95]:
f1_score(target,ans,average='macro')

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


nan