In [25]:
import torch.utils.data as data
from PIL import Image
import torch
from datasets.one_hot_encode import one_hot_encode

import pandas as pd
import os

from torchvision import transforms

In [27]:
def soft_label(label):
    if 1801 <= label:
        soft = label * 0.002631578947 - 2.239473683

    elif 1600 < label and label < 1801:
        soft = label / 100. - 15.51
    
    elif label <= 1600:
        soft = label * 0.003105590062 - 4.472049689

    return soft

In [109]:
class MultiAtmaDataset_v2(data.Dataset):
    """画像とtargetを返すdataset
    """
    def __init__(self, data_dir, img_name_df, target_df, year_df, mate_df, tech_df, trans, target_scale=None):
        self.data_dir = data_dir

        self.img_name = list(img_name_df)
        self.label = list(target_df.values)
        self.year = list(year_df.values)
        self.mate = list(mate_df.values)
        self.tech = list(tech_df.values)

        self.trans = trans
        self.target_scale = target_scale

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, idx):
        img_path = os.path.join(self.data_dir, self.img_name[idx] + '.jpg')
        img = Image.open(img_path)
        img = self.trans(img)

        tar = self.label[idx]
        year = self.year[idx]
        mate = self.mate[idx]
        tech = self.tech[idx]

        print('year: ', year)

        if self.target_scale is not None:
            tar = tar * self.target_scale

        tar = torch.tensor(list(tar)).clone().detach()

        ce_tar = torch.tensor(tar, dtype=torch.long).clone().detach()
        soft_tar = torch.tensor(soft_label(year)).clone().detach()

        mate = torch.tensor(list(mate)).clone().detach()
        tech = torch.tensor(list(tech)).clone().detach()

        return img, tar, ce_tar, soft_tar, mate, tech

In [110]:
# pandas データフレーム 読み込み
DATA_DIR = '/home/junya/Documents/dataset_atmaCup11'

img_path = os.path.join(DATA_DIR, 'photos')

train_df = pd.read_csv(os.path.join(DATA_DIR, 'train.csv'))
mate_df = pd.read_csv(os.path.join(DATA_DIR, 'materials.csv'))
tech_df = pd.read_csv(os.path.join(DATA_DIR, 'techniques.csv'))

In [111]:
mate_df

Unnamed: 0,name,object_id
0,ink,002bff09b09998d0be65
1,paper,002bff09b09998d0be65
2,pencil,002bff09b09998d0be65
3,watercolor (paint),00309fb1ef05416f9c1f
4,paper,00309fb1ef05416f9c1f
...,...,...
9076,ink,ffe49bba69d06446de7e
9077,paper,ffe49bba69d06446de7e
9078,paper,ffe77db10be3400bed53
9079,ink,ffe77db10be3400bed53


In [112]:
def target_encoder(target, origin, num=2):
        """class: 10 -> 2 + other class 1 -> 3"""
        class_names = [class_name for class_name in origin.name.value_counts().index]
        use_class_names = class_names[:num]
        use_class_names.append('other')

        not_use_class_names = class_names[num:]

        # other making
        for name in not_use_class_names:
            target.loc[target[name] == 1, 'other'] = 1
            target = target.drop(name, axis=1)

        target = target.fillna(0)
        return target

In [113]:
encoded_mate_df = target_encoder(one_hot_encode(mate_df), mate_df, num=5)
encoded_tech_df = target_encoder(one_hot_encode(tech_df), tech_df, num=2)

In [114]:
encoded_mate_df

Unnamed: 0,object_id,chalk,ink,paper,pencil,watercolor (paint),other
0,002bff09b09998d0be65,0,1,1,1,0,0.0
1,00309fb1ef05416f9c1f,0,0,1,0,1,0.0
2,003a1562e97f79ba96dc,0,0,1,1,0,0.0
3,004890880e8e7431147b,1,0,1,0,0,0.0
4,00718c32602425f504c1,1,0,1,0,0,0.0
...,...,...,...,...,...,...,...
3931,ffa3259fff8e6f3818a1,0,1,1,1,1,0.0
3932,ffd4d361756587883e48,0,0,1,1,0,0.0
3933,ffd794b7b311b7b7fd92,0,1,1,1,1,0.0
3934,ffe49bba69d06446de7e,0,1,1,0,0,0.0


In [115]:

        
# train, materials, techniques を紐付ける
unit_mate_df = train_df.merge(encoded_mate_df, on='object_id', how='left')
unit_mate_df.loc[unit_mate_df.paper.isnull(), 'other'] = 1
unit_mate_df = unit_mate_df.fillna(0)

unit_tech_df = train_df.merge(encoded_tech_df, on='object_id', how='left')
unit_tech_df.loc[unit_tech_df.pen.isnull(), 'other'] = 1
unit_tech_df = unit_tech_df.fillna(0)

assert len(unit_tech_df) == len(unit_mate_df), "mate, tech is not same size"

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(256),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [116]:
train_dataset = MultiAtmaDataset_v2(
                data_dir = img_path,
                img_name_df = train_df.object_id,
                target_df = train_df.drop('object_id', axis=1).drop('sorting_date', axis=1).drop('art_series_id', axis=1),
                year_df= train_df.drop('object_id', axis=1).drop('target', axis=1).drop('art_series_id', axis=1),
                mate_df = unit_mate_df.drop('object_id', axis=1).drop('sorting_date', axis=1).drop('art_series_id', axis=1).drop('target', axis=1),
                tech_df = unit_tech_df.drop('object_id', axis=1).drop('sorting_date', axis=1).drop('art_series_id', axis=1).drop('target', axis=1),
                trans = train_transforms,
            )

In [118]:
for i in range(0, 10):
    print(train_dataset[i])

year:  [1631]
(tensor([[[1.9749, 1.9749, 1.9920,  ..., 1.9920, 1.9920, 1.9920],
         [1.9749, 1.9749, 1.9920,  ..., 1.9920, 1.9920, 1.9920],
         [1.9407, 1.9407, 1.9578,  ..., 2.0092, 1.9920, 1.9920],
         ...,
         [1.9407, 1.9407, 1.9407,  ..., 1.9920, 1.9920, 1.9920],
         [1.9407, 1.9407, 1.9407,  ..., 1.9920, 1.9920, 1.9920],
         [1.9407, 1.9407, 1.9407,  ..., 1.9920, 1.9920, 1.9920]],

        [[2.0434, 2.0434, 2.0609,  ..., 2.0609, 2.0609, 2.0609],
         [2.0434, 2.0434, 2.0609,  ..., 2.0609, 2.0609, 2.0609],
         [2.0084, 2.0084, 2.0084,  ..., 2.0609, 2.0609, 2.0609],
         ...,
         [2.0434, 2.0434, 2.0434,  ..., 2.0434, 2.0434, 2.0434],
         [2.0434, 2.0434, 2.0434,  ..., 2.0434, 2.0434, 2.0434],
         [2.0434, 2.0434, 2.0434,  ..., 2.0434, 2.0434, 2.0434]],

        [[1.9777, 1.9777, 1.9951,  ..., 2.0648, 2.0648, 2.0648],
         [1.9603, 1.9603, 1.9777,  ..., 2.0648, 2.0648, 2.0648],
         [1.9080, 1.9080, 1.9254,  ..., 2.0

