In [2]:
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from torch import optim
from torch import nn

from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torchvision import transforms

import torchvision.models as models

import random
from glob import glob
import pandas as pd
import numpy as np
from PIL import Image

def seed_everything(seed): # seed 고정
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    
def extract_day(images):
    day = int(images.split('.')[-2][-2:])
    return day

def make_day_array(images):
    day_array = np.array([extract_day(x) for x in images])
    return day_array

def make_combination(length, species, data_frame, direct_name):
    before_file_path = []
    after_file_path = []
    time_delta = []

    for i in range(length):
        
        direct = random.randrange(0,len(direct_name))
        temp = data_frame[data_frame['version'] == direct_name[direct]]
    
        sample = temp[temp['species'] == species].sample(2)
        after = sample[sample['day'] == max(sample['day'])].reset_index(drop=True)
        before = sample[sample['day'] == min(sample['day'])].reset_index(drop=True)

        before_file_path.append(before.iloc[0]['file_name'])
        after_file_path.append(after.iloc[0]['file_name'])
        delta = int(after.iloc[0]['day'] - before.iloc[0]['day'])
        time_delta.append(delta)

    combination_df = pd.DataFrame({
        'before_file_path': before_file_path,
        'after_file_path': after_file_path,
        'time_delta': time_delta,
    })

    combination_df['species'] = species

    return combination_df



## 학습 데이터 만들기

베이스라인 코드 덧글에서 'qwopqwop' 님과 '네네넹' 님이 논의 하신 것과 같이 기존의 코드를 돌리면 **병목 현상**이 발생되어서 학습에 많은 어려움이 있었습니다. 이러한 부분을 없애고자 사전에 모든 이미지에 대한 기본적인 전처리를 수행했고 이를 .np 형으로 저장을 해서 가져오는 방식을 택했습니다.

또한 기존의 Baseline 코드와 다른점이라고 한다면 랜덤한 경우의 수를 가져오는 것이 아닌   
**동일 품종 동일 식물(같은 폴더)**


를 가져온다는 것 입니다.

In [3]:
seed_everything(2048)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device

'cuda:0'

In [4]:
root_path = './drive/MyDrive/Colab Notebooks/224size_train'

# BC 폴더와 LT 폴더에 있는 하위 폴더를 저장한다.
bc_direct = glob(root_path + '/BC/*')
bc_direct_name = [x[-5:] for x in bc_direct]
lt_direct = glob(root_path + '/LT/*')
lt_direct_name = [x[-5:] for x in lt_direct]

# 하위 폴더에 있는 이미지들을 하위 폴더 이름과 매칭시켜서 저장한다.
bc_images = {key : glob(name + '/*.png') for key,name in zip(bc_direct_name, bc_direct)}
lt_images = {key : glob(name + '/*.png') for key,name in zip(lt_direct_name, lt_direct)}

# 하위 폴더에 있는 이미지들에서 날짜 정보만 따로 저장한다.
bc_dayes = {key : make_day_array(bc_images[key]) for key in bc_direct_name}
lt_dayes = {key : make_day_array(lt_images[key]) for key in lt_direct_name}

bc_dfs = []
for i in bc_direct_name:
    bc_df = pd.DataFrame({
        'file_name':bc_images[i],
        'day':bc_dayes[i],
        'species':'bc',
        'version':i
    })
    bc_dfs.append(bc_df)
    
lt_dfs = []
for i in lt_direct_name:
    lt_df = pd.DataFrame({
        'file_name':lt_images[i],
        'day':lt_dayes[i],
        'species':'lt',
        'version':i
    })
    lt_dfs.append(lt_df)

bc_dataframe = pd.concat(bc_dfs).reset_index(drop=True)
lt_dataframe = pd.concat(lt_dfs).reset_index(drop=True)
total_dataframe = pd.concat([bc_dataframe, lt_dataframe]).reset_index(drop=True)


bc_combination = make_combination(5000, 'bc', total_dataframe, bc_direct_name)
lt_combination = make_combination(5000, 'lt', total_dataframe, lt_direct_name)

bc_train = bc_combination.iloc[:4500]
bc_valid = bc_combination.iloc[4500:]

lt_train = lt_combination.iloc[:4500]
lt_valid = lt_combination.iloc[4500:]

train_set = pd.concat([bc_train, lt_train])
valid_set = pd.concat([bc_valid, lt_valid])

In [6]:
!pip install timm

Collecting timm
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[?25l[K     |▉                               | 10 kB 28.1 MB/s eta 0:00:01[K     |█▊                              | 20 kB 22.0 MB/s eta 0:00:01[K     |██▋                             | 30 kB 16.5 MB/s eta 0:00:01[K     |███▌                            | 40 kB 14.2 MB/s eta 0:00:01[K     |████▍                           | 51 kB 8.3 MB/s eta 0:00:01[K     |█████▏                          | 61 kB 8.2 MB/s eta 0:00:01[K     |██████                          | 71 kB 8.8 MB/s eta 0:00:01[K     |███████                         | 81 kB 9.8 MB/s eta 0:00:01[K     |███████▉                        | 92 kB 9.7 MB/s eta 0:00:01[K     |████████▊                       | 102 kB 7.9 MB/s eta 0:00:01[K     |█████████▋                      | 112 kB 7.9 MB/s eta 0:00:01[K     |██████████▍                     | 122 kB 7.9 MB/s eta 0:00:01[K     |███████████▎                    | 133 kB 7.9 MB/s eta 0:00:01[K    

In [7]:
import timm
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)

['adv_inception_v3',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_tiny',
 'convit_base',
 'convit_small',
 'convit_tiny',
 'cspdarknet53',
 'cspresnet50',
 'cspresnext50',
 'deit_base_distilled_patch16_224',
 'deit_base_distilled_patch16_384',
 'deit_base_patch16_224',
 'deit_base_patch16_384',
 'deit_small_distilled_patch16_224',
 'deit_small_patch16_224',
 'deit_tiny_distilled_patch16_224',
 'deit_tiny_patch16_224',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenetblur121d',
 'dla34',
 'dla46_c',
 'dla46x_c',
 'dla60',
 'dla60_res2net',
 'dla60_res2next',
 'dla60x',
 'dla60x_c',
 'dla102',
 'dla102x',
 'dla102x2',
 'dla169',
 'dm_nfnet_f0',
 'dm_nfnet_f1',
 'dm_nfnet_f2',
 'dm_nfnet_f3',
 'dm_nfnet_f4',
 'dm_nfnet_f5',
 'dm_nfnet_f6',
 'dpn68',
 'dpn

train_img_file_names = zip(train_set['before_file_path'], train_set['after_file_path'])
val_img_file_names = zip(valid_set['before_file_path'], valid_set['after_file_path'])

train_before = []
train_after = []
val_before = []
val_after = []

transform = transforms.Compose([
    transforms.ToTensor()
])

for before, after in train_img_file_names:
    before_image = Image.open(before)
    after_image = Image.open(after)

    before_image = transform(before_image)
    after_image = transform(after_image)

    train_before.append(before_image)
    train_after.append(after_image)


for before, after in val_img_file_names:
    before_image = Image.open(before)
    after_image = Image.open(after)

    # validation transform x
    val_before.append(before_image)
    val_after.append(after_image)


torch_train_before = np.zeros((4500,3,224, 224))
torch_train_after = np.zeros((4500,3,224, 224))


for i in range(4500):
    torch_train_before[i] = train_before[i].numpy()
    torch_train_after[i] = train_after[i].numpy()

np.save("train_before.npy", torch_train_before)
np.save("train_after.npy", torch_train_after)
np.save("train_time_delta.npy", np.array(train_set['time_delta']))

In [8]:
class KistDataset(Dataset):
    def __init__(self, combination_df, is_valid= None, is_test= None):

        self.combination_df = combination_df
        self.is_valid = is_valid
        self.is_test = is_test
        if is_valid == None and is_test == None:
            self.transform = transforms.Compose([                                    
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomAffine((-20, 20)),
            transforms.RandomRotation(90),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transforms.Compose([
            transforms.ToTensor(),                                     
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

    def __getitem__(self, idx):
        before_image = Image.open(self.combination_df.iloc[idx]['before_file_path'])
        after_image = Image.open(self.combination_df.iloc[idx]['after_file_path'])

        before_image = self.transform(before_image)
        after_image = self.transform(after_image)
        if self.is_test:
            return before_image, after_image
        time_delta = self.combination_df.iloc[idx]['time_delta']
        return before_image, after_image, time_delta

    def __len__(self):
        return len(self.combination_df)

In [12]:
class CompareCNN(nn.Module):

    def __init__(self, model_name):
        super(CompareCNN, self).__init__()
        self.regnet = model = timm.create_model(model_name, pretrained=True, num_classes=1)
        
    def forward(self, input):
        x = self.regnet(input)
        return x


class CompareNet(nn.Module):

    def __init__(self, model_name):
        super(CompareNet, self).__init__()
        self.before_net = CompareCNN(model_name)
        self.after_net = CompareCNN(model_name)

    def forward(self, before_input, after_input):
        before = self.before_net(before_input)
        after = self.after_net(after_input)
        delta = after - before
        return delta

In [14]:
import numpy as np
import json
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import random
import gc

            
lr = 1e-5
epochs = 10
batch_size = 64
valid_batch_size = 50

train_dataset = KistDataset(train_set)
valid_dataset = KistDataset(valid_set)

optimizer = optim.Adam(model.parameters(), lr=lr)

train_data_loader = DataLoader(train_dataset,
                               batch_size=batch_size,
                               shuffle=True)

valid_data_loader = DataLoader(valid_dataset,
                               batch_size=valid_batch_size)

In [15]:
# 사용할 모델의 이름 
model_names = ['regnetx_004','regnetx_004', 'regnetx_004', 'regnetx_004', 'regnetx_004']

# 사용할 시드 값
seeds = [428, 124, 333, 777, 1205]

# 모델 번호
model_number = 0

for i, name in enumerate(model_names):
    seed_everything(seeds[i])
    model = CompareNet(name).to(device)    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    model_number += 1
    
    # 모델 학습에 있어서 초기 428 seed 값은 사용하지 않았음
    if i == 0:
        train_data_loader = DataLoader(train_dataset,
                               batch_size=batch_size,
                               shuffle=True)
        continue
        
    # 총 10 에폭을 수행함
    for epoch in tqdm(range(0,epochs)):
        for step, (before_image, after_image, time_delta) in tqdm(enumerate(train_data_loader)):
            torch.cuda.empty_cache()
            before_image = before_image.to(device)
            after_image = after_image.to(device)
            time_delta = time_delta.to(device)

            optimizer.zero_grad()
            logit = model(before_image, after_image)
            train_loss = (torch.sum(torch.abs(logit.squeeze(1).float() - time_delta.float())) /
                          torch.LongTensor([batch_size]).squeeze(0).to(device))
            train_loss.backward()
            optimizer.step()

            if step % 15 == 0:
                print('\n=====================loss=======================')
                print(f'\n=====================EPOCH: {epoch}=======================')
                print(f'\n=====================step: {step}=======================')
                print('MAE_loss : ', train_loss.detach().cpu().numpy())

        if epoch % 5 == 0 and epoch != 0:
            torch.save(model.state_dict(), './drive/MyDrive/Colab Notebooks/ensemble/{}_{}_model_{}epoch.pt'.format(model_number,name,epoch))
            
    gc.enable()
    del model
    gc.collect()

  0%|          | 0/10 [00:00<?, ?it/s]

0it [00:00, ?it/s]




MAE_loss :  14.433638



MAE_loss :  13.120515



MAE_loss :  14.870891



MAE_loss :  13.526396



MAE_loss :  15.668597



MAE_loss :  12.745159



MAE_loss :  12.088287



MAE_loss :  11.453802



MAE_loss :  12.3417845



MAE_loss :  11.216797


0it [00:00, ?it/s]




MAE_loss :  11.106665



MAE_loss :  11.0048065



MAE_loss :  12.50445



MAE_loss :  11.663544



MAE_loss :  11.888184



MAE_loss :  9.208008



MAE_loss :  10.859445



MAE_loss :  11.247517



MAE_loss :  9.591719



MAE_loss :  9.149471


0it [00:00, ?it/s]




MAE_loss :  11.37799



MAE_loss :  10.156313



MAE_loss :  10.633503



MAE_loss :  10.57412



MAE_loss :  11.872185



MAE_loss :  9.13434



MAE_loss :  8.527725



MAE_loss :  6.8313723



MAE_loss :  10.277488



MAE_loss :  7.162163


0it [00:00, ?it/s]




MAE_loss :  7.8508186



MAE_loss :  8.748753



MAE_loss :  11.0478735



MAE_loss :  8.376371



MAE_loss :  8.36662



MAE_loss :  6.8518453



MAE_loss :  6.574192



MAE_loss :  6.6496706



MAE_loss :  7.306038



MAE_loss :  9.049722


0it [00:00, ?it/s]




MAE_loss :  6.3521233



MAE_loss :  7.0977



MAE_loss :  8.240974



MAE_loss :  5.5914903



MAE_loss :  6.9566035



MAE_loss :  6.053133



MAE_loss :  5.595493



MAE_loss :  5.961575



MAE_loss :  5.599123



MAE_loss :  5.9261093


0it [00:00, ?it/s]




MAE_loss :  4.013786



MAE_loss :  6.8665743



MAE_loss :  6.505702



MAE_loss :  5.012963



MAE_loss :  4.8927755



MAE_loss :  4.70012



MAE_loss :  4.8499594



MAE_loss :  5.558633



MAE_loss :  4.441949



MAE_loss :  5.328311


0it [00:00, ?it/s]




MAE_loss :  5.278494



MAE_loss :  5.064928



MAE_loss :  5.530699



MAE_loss :  5.4052625



MAE_loss :  4.3423176



MAE_loss :  4.2630615



MAE_loss :  4.0031843



MAE_loss :  4.464344



MAE_loss :  4.7609634



MAE_loss :  3.1897297


0it [00:00, ?it/s]




MAE_loss :  4.2800736



MAE_loss :  3.3440862



MAE_loss :  4.5942807



MAE_loss :  3.4666395



MAE_loss :  3.7197416



MAE_loss :  3.7320437



MAE_loss :  3.262002



MAE_loss :  3.6637146



MAE_loss :  4.0380125



MAE_loss :  3.8490992


0it [00:00, ?it/s]




MAE_loss :  2.5957956



MAE_loss :  3.6182966



MAE_loss :  2.9541974



MAE_loss :  3.521336



MAE_loss :  2.7864616



MAE_loss :  3.215442



MAE_loss :  3.6293917



MAE_loss :  2.8623838



MAE_loss :  2.6551902



MAE_loss :  2.6881528


0it [00:00, ?it/s]




MAE_loss :  2.5057635



MAE_loss :  2.3401518



MAE_loss :  3.4274375



MAE_loss :  2.1699889



MAE_loss :  2.6928716



MAE_loss :  2.707445



MAE_loss :  2.1589413



MAE_loss :  2.4724922



MAE_loss :  2.1272392



MAE_loss :  2.4249334


  0%|          | 0/10 [00:00<?, ?it/s]

0it [00:00, ?it/s]




MAE_loss :  14.880985



MAE_loss :  12.886756



MAE_loss :  11.872726



MAE_loss :  12.757771



MAE_loss :  11.855996



MAE_loss :  10.857054



MAE_loss :  13.335977



MAE_loss :  13.201436



MAE_loss :  11.653377



MAE_loss :  10.480963


0it [00:00, ?it/s]




MAE_loss :  12.77522



MAE_loss :  11.716919



MAE_loss :  11.30046



MAE_loss :  9.4601145



MAE_loss :  11.999257



MAE_loss :  10.018851



MAE_loss :  11.9074135



MAE_loss :  11.642612



MAE_loss :  10.553629



MAE_loss :  9.429338


0it [00:00, ?it/s]




MAE_loss :  10.174096



MAE_loss :  9.100248



MAE_loss :  8.605492



MAE_loss :  8.505547



MAE_loss :  10.243315



MAE_loss :  10.836063



MAE_loss :  8.485964



MAE_loss :  8.4610615



MAE_loss :  6.588314



MAE_loss :  6.4645014


0it [00:00, ?it/s]




MAE_loss :  7.9916973



MAE_loss :  7.970025



MAE_loss :  8.468186



MAE_loss :  8.487115



MAE_loss :  7.2999277



MAE_loss :  8.320757



MAE_loss :  6.65921



MAE_loss :  7.5802774



MAE_loss :  6.556383



MAE_loss :  6.8811336


0it [00:00, ?it/s]




MAE_loss :  6.8124967



MAE_loss :  7.380749



MAE_loss :  6.1222754



MAE_loss :  6.991432



MAE_loss :  5.060362



MAE_loss :  5.7776794



MAE_loss :  6.869132



MAE_loss :  7.1110315



MAE_loss :  4.9964323



MAE_loss :  6.9156218


0it [00:00, ?it/s]




MAE_loss :  5.1511393



MAE_loss :  6.5473866



MAE_loss :  4.692341



MAE_loss :  5.5152183



MAE_loss :  6.05606



MAE_loss :  5.026151



MAE_loss :  4.242159



MAE_loss :  4.615722



MAE_loss :  4.3612194



MAE_loss :  5.6517315


0it [00:00, ?it/s]




MAE_loss :  5.2899785



MAE_loss :  5.517019



MAE_loss :  3.9716277



MAE_loss :  5.0402603



MAE_loss :  4.039096



MAE_loss :  4.645382



MAE_loss :  3.9369223



MAE_loss :  4.684808



MAE_loss :  3.7120757



MAE_loss :  3.4806914


0it [00:00, ?it/s]




MAE_loss :  3.6723566



MAE_loss :  4.6731405



MAE_loss :  3.3471675



MAE_loss :  3.0058272



MAE_loss :  4.00508



MAE_loss :  3.2695878



MAE_loss :  2.9629889



MAE_loss :  3.955708



MAE_loss :  3.3659961



MAE_loss :  3.0509892


0it [00:00, ?it/s]




MAE_loss :  3.0454154



MAE_loss :  2.7169027



MAE_loss :  2.8640645



MAE_loss :  3.1166737



MAE_loss :  2.789978



MAE_loss :  3.5034018



MAE_loss :  2.8169856



MAE_loss :  2.5807266



MAE_loss :  3.061562



MAE_loss :  2.9603803


0it [00:00, ?it/s]




MAE_loss :  3.3527718



MAE_loss :  2.2503521



MAE_loss :  2.617072



MAE_loss :  2.978019



MAE_loss :  2.1763444



MAE_loss :  2.3977036



MAE_loss :  2.6832252



MAE_loss :  2.3395765



MAE_loss :  2.6013894



MAE_loss :  1.8242068


  0%|          | 0/10 [00:00<?, ?it/s]

0it [00:00, ?it/s]




MAE_loss :  14.832726



MAE_loss :  12.288546



MAE_loss :  13.22726



MAE_loss :  14.101533



MAE_loss :  12.156921



MAE_loss :  13.154152



MAE_loss :  12.14418



MAE_loss :  11.547744



MAE_loss :  11.689422



MAE_loss :  13.264391


0it [00:00, ?it/s]




MAE_loss :  12.322777



MAE_loss :  11.357498



MAE_loss :  11.301292



MAE_loss :  11.462343



MAE_loss :  11.806598



MAE_loss :  10.0129595



MAE_loss :  11.755214



MAE_loss :  10.266693



MAE_loss :  9.503956



MAE_loss :  10.365322


0it [00:00, ?it/s]




MAE_loss :  10.804485



MAE_loss :  8.254049



MAE_loss :  8.9447155



MAE_loss :  7.9270077



MAE_loss :  8.199396



MAE_loss :  9.448723



MAE_loss :  8.3685055



MAE_loss :  7.776609



MAE_loss :  8.920429



MAE_loss :  7.836729


0it [00:00, ?it/s]




MAE_loss :  9.773331



MAE_loss :  9.333189



MAE_loss :  7.6092057



MAE_loss :  7.084316



MAE_loss :  9.45157



MAE_loss :  7.4157953



MAE_loss :  6.2691116



MAE_loss :  7.130585



MAE_loss :  6.669841



MAE_loss :  6.450243


0it [00:00, ?it/s]




MAE_loss :  5.6387615



MAE_loss :  7.8741293



MAE_loss :  6.521439



MAE_loss :  7.0564404



MAE_loss :  6.023898



MAE_loss :  6.72042



MAE_loss :  6.134516



MAE_loss :  6.62178



MAE_loss :  5.92107



MAE_loss :  7.116184


0it [00:00, ?it/s]




MAE_loss :  5.627391



MAE_loss :  6.921339



MAE_loss :  6.231287



MAE_loss :  6.0680866



MAE_loss :  4.105653



MAE_loss :  5.905496



MAE_loss :  4.673892



MAE_loss :  4.39719



MAE_loss :  4.7483964



MAE_loss :  5.331362


0it [00:00, ?it/s]




MAE_loss :  4.2431707



MAE_loss :  5.0132017



MAE_loss :  4.4036493



MAE_loss :  3.2982159



MAE_loss :  4.727189



MAE_loss :  3.3128285



MAE_loss :  3.7041266



MAE_loss :  3.474833



MAE_loss :  4.4773026



MAE_loss :  3.6186957


0it [00:00, ?it/s]




MAE_loss :  3.57367



MAE_loss :  3.8809228



MAE_loss :  3.2941186



MAE_loss :  5.0306826



MAE_loss :  3.4850893



MAE_loss :  2.6264148



MAE_loss :  3.1521995



MAE_loss :  2.7860093



MAE_loss :  2.1868637



MAE_loss :  3.2704782


0it [00:00, ?it/s]




MAE_loss :  3.7163408



MAE_loss :  4.127738



MAE_loss :  3.0400155



MAE_loss :  3.5544713



MAE_loss :  2.6238275



MAE_loss :  2.9740443



MAE_loss :  2.3425007



MAE_loss :  3.507806



MAE_loss :  2.6265457



MAE_loss :  2.9784386


0it [00:00, ?it/s]




MAE_loss :  2.519244



MAE_loss :  2.9886627



MAE_loss :  2.3685975



MAE_loss :  2.4731348



MAE_loss :  2.833786



MAE_loss :  2.9570131



MAE_loss :  2.942935



MAE_loss :  2.9677963



MAE_loss :  1.7322154



MAE_loss :  2.14314


  0%|          | 0/10 [00:00<?, ?it/s]

0it [00:00, ?it/s]




MAE_loss :  13.837757



MAE_loss :  12.295903



MAE_loss :  13.732836



MAE_loss :  11.662688



MAE_loss :  10.281417



MAE_loss :  13.179346



MAE_loss :  11.149647



MAE_loss :  12.711059



MAE_loss :  13.4007845



MAE_loss :  12.40498


0it [00:00, ?it/s]




MAE_loss :  10.904346



MAE_loss :  11.213653



MAE_loss :  13.94555



MAE_loss :  11.056



MAE_loss :  11.906194



MAE_loss :  10.885515



MAE_loss :  11.349301



MAE_loss :  9.813658



MAE_loss :  11.028507



MAE_loss :  7.4079356


0it [00:00, ?it/s]




MAE_loss :  10.85848



MAE_loss :  8.569674



MAE_loss :  8.238718



MAE_loss :  8.532713



MAE_loss :  10.798027



MAE_loss :  7.962713



MAE_loss :  9.382872



MAE_loss :  9.600601



MAE_loss :  9.1477585



MAE_loss :  7.7331123


0it [00:00, ?it/s]




MAE_loss :  10.342746



MAE_loss :  10.28463



MAE_loss :  9.761845



MAE_loss :  9.757914



MAE_loss :  7.368633



MAE_loss :  7.732397



MAE_loss :  6.5820303



MAE_loss :  8.49048



MAE_loss :  6.258092



MAE_loss :  7.0821924


0it [00:00, ?it/s]




MAE_loss :  5.2934175



MAE_loss :  8.5330925



MAE_loss :  6.065634



MAE_loss :  6.083915



MAE_loss :  6.109201



MAE_loss :  6.4716454



MAE_loss :  5.927745



MAE_loss :  5.9750423



MAE_loss :  5.3997464



MAE_loss :  5.211914


0it [00:00, ?it/s]




MAE_loss :  5.5424094



MAE_loss :  5.909378



MAE_loss :  5.2039156



MAE_loss :  4.670663



MAE_loss :  4.6078625



MAE_loss :  7.165651



MAE_loss :  4.923547



MAE_loss :  6.7181835



MAE_loss :  4.5856657



MAE_loss :  4.308134


0it [00:00, ?it/s]




MAE_loss :  4.909341



MAE_loss :  5.715341



MAE_loss :  4.5401845



MAE_loss :  4.4663196



MAE_loss :  4.558964



MAE_loss :  3.8399615



MAE_loss :  4.3507767



MAE_loss :  5.013825



MAE_loss :  4.5566006



MAE_loss :  3.834906


0it [00:00, ?it/s]




MAE_loss :  4.3172784



MAE_loss :  2.978692



MAE_loss :  3.5547073



MAE_loss :  3.964697



MAE_loss :  3.9233847



MAE_loss :  3.7992754



MAE_loss :  4.548097



MAE_loss :  3.2544656



MAE_loss :  3.4973025



MAE_loss :  2.6144097


0it [00:00, ?it/s]




MAE_loss :  4.0517173



MAE_loss :  3.0463202



MAE_loss :  4.0285864



MAE_loss :  3.2985852



MAE_loss :  4.008483



MAE_loss :  3.8033407



MAE_loss :  3.1173582



MAE_loss :  2.9517503



MAE_loss :  3.5547175



MAE_loss :  2.7980402


0it [00:00, ?it/s]




MAE_loss :  3.2071419



MAE_loss :  2.7069733



MAE_loss :  2.3080606



MAE_loss :  3.0451503



MAE_loss :  2.7530727



MAE_loss :  2.3927174



MAE_loss :  2.5795074



MAE_loss :  2.8166702



MAE_loss :  3.0468626



MAE_loss :  3.2785957


In [16]:
test_set = pd.read_csv('./drive/MyDrive/Colab Notebooks/224size_test/test_data.csv')
test_set['l_root'] = test_set['before_file_path'].map(lambda x: './drive/MyDrive/Colab Notebooks/224size_test/' + x.split('_')[1] + '/' + x.split('_')[2])
test_set['r_root'] = test_set['after_file_path'].map(lambda x: './drive/MyDrive/Colab Notebooks/224size_test/' + x.split('_')[1] + '/' + x.split('_')[2])
test_set['before_file_path'] = test_set['l_root'] + '/' + test_set['before_file_path'] + '.png'
test_set['after_file_path'] = test_set['r_root'] + '/' + test_set['after_file_path'] + '.png'


test_dataset = KistDataset(test_set, is_test=True)
test_data_loader = DataLoader(test_dataset,
                               batch_size=64)

test_set

Unnamed: 0,idx,before_file_path,after_file_path,l_root,r_root
0,0,./drive/MyDrive/Colab Notebooks/224size_test/L...,./drive/MyDrive/Colab Notebooks/224size_test/L...,./drive/MyDrive/Colab Notebooks/224size_test/L...,./drive/MyDrive/Colab Notebooks/224size_test/L...
1,1,./drive/MyDrive/Colab Notebooks/224size_test/L...,./drive/MyDrive/Colab Notebooks/224size_test/L...,./drive/MyDrive/Colab Notebooks/224size_test/L...,./drive/MyDrive/Colab Notebooks/224size_test/L...
2,2,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...
3,3,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...
4,4,./drive/MyDrive/Colab Notebooks/224size_test/L...,./drive/MyDrive/Colab Notebooks/224size_test/L...,./drive/MyDrive/Colab Notebooks/224size_test/L...,./drive/MyDrive/Colab Notebooks/224size_test/L...
...,...,...,...,...,...
3955,3955,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...
3956,3956,./drive/MyDrive/Colab Notebooks/224size_test/L...,./drive/MyDrive/Colab Notebooks/224size_test/L...,./drive/MyDrive/Colab Notebooks/224size_test/L...,./drive/MyDrive/Colab Notebooks/224size_test/L...
3957,3957,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...
3958,3958,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...,./drive/MyDrive/Colab Notebooks/224size_test/B...


In [17]:
model_names = ['regnetx_004','regnetx_004', 'regnetx_004', 'regnetx_004', 'regnetx_004', 'regnetx_004', 'regnetx_004', 'regnetx_004', 'regnetx_004']

count = 0

# 앙상블 모델들은 ensemble 폴더 안에 위치함
# 10Epoch의 결과들을 먼저 ensemble 폴더에 넣어서 예측 값을 구함
# 10Epoch에 대한 예측을 구하고 나면 40Epoch에 대한 예측 값을 구함
for pt in glob("./drive/MyDrive/Colab Notebooks/ensemble/*.pt"):
    # 모델을 불러온다.
    model = CompareNet(model_names[count]).to(device)
    count+=1
    # 학습된 가중치 값을 모델에 입력
    model.load_state_dict(torch.load(pt))
    # 모델을 평가 모드 상태로 전환
    model.eval()
    
    test_value = []
    with torch.no_grad():
        for test_before, test_after in tqdm(test_data_loader):
            test_before = test_before.to(device)
            test_after = test_after.to(device)
            logit = model(test_before, test_after)
            value = logit.squeeze(1).detach().cpu().float()

            test_value.extend(value)
    
    # submission 형식을 불러온다.
    submission = pd.read_csv('./drive/MyDrive/Colab Notebooks/sample_submission.csv')
    
    # 예측한 값들은 텐서 형태로 변환 시켜준다.
    predict = torch.FloatTensor(test_value)

    # 음수의 값을 갖는 모든 값들을 1 Day 차이가 발생하도록 바꿔줌
    temp_predict = predict.numpy()
    temp_predict[np.where(temp_predict<1)] = 1

    # 모델별로 예측 값을 predict 폴더 안에 저장함
    submission['time_delta'] = temp_predict
    submission.to_csv('./drive/MyDrive/Colab Notebooks/predict/{}_{}.csv'.format(count, model_names[count-1]), index=False)

  0%|          | 0/62 [00:00<?, ?it/s]

  0%|          | 0/62 [00:00<?, ?it/s]

  0%|          | 0/62 [00:00<?, ?it/s]

  0%|          | 0/62 [00:00<?, ?it/s]

In [18]:
# 예측 csv 이름을 갖고옴
predict = glob("./drive/MyDrive/Colab Notebooks/predict/*.csv")

# 첫 번째에 해당하는 csv 파일을 불러옴
p1 = pd.read_csv(predict[0])

# 나머지 8개의 csv 파일에 대해서 결과 값들을 모두 더해줌
for p in predict[1:]:
    temp = pd.read_csv(p)
    
    p1['time_delta'] += temp['time_delta']

# 9개의 모델들에 대한 평균 값을 구하고 이를 최종 예측 값으로 저장함
p1['time_delta'] = p1['time_delta'] / 9
p1.to_csv('./drive/MyDrive/Colab Notebooks/predict/result.csv', index=False)

In [19]:
# submission 형식을 불러온다.
submission = pd.read_csv('./drive/MyDrive/Colab Notebooks/sample_submission.csv')

# 예측한 값들은 텐서 형태로 변환 시켜준다.
predict = torch.FloatTensor(test_value)

# 음수의 값을 갖는 모든 값들을 1 Day 차이가 발생하도록 바꿔줌
temp_predict = predict.numpy()
temp_predict[np.where(temp_predict<1)] = 1

In [20]:
# 모델의 예측 값을 저장함
submission['time_delta'] = temp_predict
submission.to_csv('./drive/MyDrive/Colab Notebooks/regnetx_004_ensemble.csv', index=False)