# Hyper Parameters

In [1]:
notebookName = "BaseModel"
nepochs = 10
batch_size = 4
learning_rate = 0.001

SEED = 20180724
target_layer = ['road_block', 'walkway', 'road_divider', 'traffic_light']
class_names = ['None'] + target_layer
MAX_OBJECTS = 1
MAX_POINTS = 30
MAX_POINT_CLOUDS = 1200

In [2]:
LIDAR_PC_SHAPE = [4, MAX_POINT_CLOUDS] # [x,y,z,intensity] x num_of_points
MAP_OBJECT_SHAPE = [MAX_OBJECTS, MAX_POINTS, 2] # num_of_objects x num_of_points x [x, y]
MAP_LAYER_SHAPE = [len(class_names), MAX_OBJECTS] # num_of_class x num_of_objects
PATCH_SIZE = [-1, -1]

# Base Setting

In [3]:
import os
from pathlib import Path
import time

from nuscenes.nuscenes import NuScenes
from utils.custom_lidar_api import CustomLidarApi
from utils.custom_map_api_expansion import CustomNuScenesMap

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.nn import Module
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import TensorDataset, DataLoader

import warnings
warnings.filterwarnings(action='ignore')

%matplotlib inline 

In [4]:
locations = ['singapore-onenorth', 'singapore-hollandvillage', 'singapore-queenstown', 'boston-seaport']
version = 'v1.0-trainval'
dataroot = 'E:/datasets/nuscenes'

In [5]:
torch.manual_seed(SEED)
np.random.seed(SEED)

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [7]:

PATH = Path(f"./models/{notebookName}")
if os.path.isdir(PATH):
    dir_list = os.listdir(PATH)
    num_files = 0
    while True:
        if os.path.isfile(str(PATH / f"{num_files}")):
            print(num_files)
            num_files += 1
        else:
            break
else:
    os.mkdir(PATH)
    num_files = 0
num_files

0

In [8]:
class_dict = dict()

for i, name in enumerate(target_layer):
    class_dict[name] = i+1
class_array = np.eye(len(class_names))

# Load Nusc, Map Api, and Ldr Api

In [9]:
nusc = NuScenes(version=version, dataroot=dataroot, verbose=True)

Loading NuScenes tables for version v1.0-trainval...
Loading nuScenes-lidarseg...
32 category,
8 attribute,
4 visibility,
64386 instance,
12 sensor,
10200 calibrated_sensor,
2631083 ego_pose,
68 log,
850 scene,
34149 sample,
2631083 sample_data,
1166187 sample_annotation,
4 map,
34149 lidarseg,
Done loading in 61.213 seconds.
Reverse indexing ...
Done reverse indexing in 9.0 seconds.


In [10]:
map_api = dict([])
for location in locations:
    map_api[location] = CustomNuScenesMap(dataroot = dataroot, map_name= location, target_layer_names=target_layer, max_objs=MAX_OBJECTS, max_points=MAX_POINTS)

In [11]:
ldr_api = CustomLidarApi(nusc)

In [12]:
# get all sample token
sample_tokens = []
for scene in nusc.scene:
    token = scene['first_sample_token']
    while token != scene['last_sample_token']:
        sample_tokens.append(token)
        sample = nusc.get('sample', token)
        token = sample['next']
print(len(sample_tokens))

33299


In [13]:
def train_val_split(sample_tokens, ratio = 0.1, shuffle = True):
    index = np.array(range(len(sample_tokens)))
    index = np.random.choice(index.shape[0], index.shape[0], replace = False)
    
    valid_num = int(index.shape[0] * ratio)
    
    valid_tokens = sample_tokens[:valid_num]
    train_tokens = sample_tokens[valid_num:]
    
    return train_tokens, valid_tokens

train_tokens, valid_tokens = train_val_split(sample_tokens)

print(len(sample_tokens))
print(len(train_tokens))
print(len(valid_tokens))

33299
29970
3329


# Customized Dataset

In [14]:
# LIDAR_PC_SHAPE = [4, MAX_POINT_CLOUDS] # [x,y,z,intensity] x num_of_points
# MAP_OBJECT_SHAPE = [MAX_OBJECTS, MAX_POINTS, 2] # num_of_objects x num_of_points x [x, y]
# MAP_LAYER_SHAPE = [len(class_names), MAX_OBJECTS] # num_of_class x num_of_objects
class NusceneDataset(Dataset):
    def __init__(self, tokens, nusc, map_api, ldr_api, train = True):
        self.tokens = tokens
        self.nusc = nusc
        self.map_api = map_api
        self.ldr_api = ldr_api
        self.train = train
        
        self.length = len(tokens)
    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        token = self.tokens[idx]
        
        sample = self.nusc.get('sample', token)
        scene = self.nusc.get('scene', sample['scene_token'])
        log_meta = self.nusc.get('log', scene['log_token'])
        
        location = log_meta['location']
        sample_data = self.nusc.get('sample_data', sample['data']['LIDAR_TOP'])
        
        pc = self.ldr_api.get_lidar_from_keyframe(token, max_points = LIDAR_PC_SHAPE[1], car_coord = True)
        ego = self.ldr_api.get_egopose_from_keyframe(token)
        structures = self.map_api[log_meta['location']].get_closest_structures(ego, \
                                                                               patch = PATCH_SIZE,\
                                                                               global_coord=False, \
                                                                               mode = 'intersect'\
                                                                              )
        if(len(structures) == 0):
            print(log_meta['location'])
            print(PATCH_SIZE)
            print(token)
            print(ego)
            print(len(structures))
            print(structures)
            print(self.map_api[log_meta['location']].structures[:3])
            print(self.map_api[log_meta['location']])
        
        X = torch.Tensor(pc.points)
        if self.train:
            classes, objects = self.get_label(structures)
        else:
            classes, objects = self.get_label(list())
        
        return X, classes, objects
    
    def get_label(self, structures):
        classes = list(map(lambda x: torch.Tensor(class_array[class_dict[x["layer"]], :]).reshape(MAP_LAYER_SHAPE[0], MAP_LAYER_SHAPE[1]), structures))
        classes = torch.cat(classes, axis = 1)
        
        objects = list(map(lambda x: torch.Tensor(x['nodes'].reshape(MAP_OBJECT_SHAPE[0], MAP_OBJECT_SHAPE[1], MAP_OBJECT_SHAPE[2])), structures))
        objects = torch.cat(objects, axis = 0)
        
        return classes, objects
        

train_dataset = NusceneDataset(tokens = train_tokens, nusc = nusc, map_api = map_api, ldr_api = ldr_api)
valid_dataset = NusceneDataset(tokens = valid_tokens, nusc = nusc, map_api = map_api, ldr_api = ldr_api)

train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
valid_loader = DataLoader(valid_dataset, batch_size = batch_size, shuffle = False)

In [15]:
# def get_label(structures):
#     classes = list(map(lambda x: torch.Tensor(class_array[class_dict[x["class"]], :]).reshape(MAP_LAYER_SHAPE[0], MAP_LAYER_SHAPE[1]), structures))
#     classes = torch.cat(classes, axis = 1)

#     objects = list(map(lambda x: torch.Tensor(x['nodes'].reshape(MAP_OBJECT_SHAPE[0], MAP_OBJECT_SHAPE[1], MAP_OBJECT_SHAPE[2])), structures))
#     objects = torch.cat(objects, axis = 0)
#     return classes, objects

# %load_ext line_profiler

# token = "54ea5283e6d643b4b54fe489bd373c16"

# sample = nusc.get('sample', token)
# scene = nusc.get('scene', sample['scene_token'])
# log_meta = nusc.get('log', scene['log_token'])
# location = log_meta['location']
# sample_data = nusc.get('sample_data', sample['data']['LIDAR_TOP'])
# pc = ldr_api.get_lidar_from_keyframe(token, max_points = LIDAR_PC_SHAPE[1], car_coord = True)
# ego = ldr_api.get_egopose_from_keyframe(token)
# mode = 'intersect'
# %lprun -f  map_api[location].get_closest_structures map_api[location].get_closest_structures(ego, patch = [200, 200],mode = mode,global_coord=False)


In [17]:
for X, c, y in tqdm(train_loader):
    pass

  0%|          | 0/7493 [00:00<?, ?it/s]

intersected 1386
calc dist 1386
sort dist 1386
slice 1
transform 1
intersected 1386
calc dist 1386
sort dist 1386
slice 1
transform 1
intersected 1954
calc dist 1954
sort dist 1954
slice 1
transform 1
intersected 1385
calc dist 1385
sort dist 1385
slice 1
transform 1
intersected 1953
calc dist 1953
sort dist 1953
slice 1
transform 1
intersected 1952
calc dist 1952
sort dist 1952
slice 1
transform 1
intersected 1951
calc dist 1951
sort dist 1951
slice 1
transform 1
intersected 1950
calc dist 1950
sort dist 1950
slice 1
transform 1
intersected 1111
calc dist 1111
sort dist 1111
slice 1
transform 1
intersected 1950
calc dist 1950
sort dist 1950
slice 1
transform 1
intersected 1950
calc dist 1950
sort dist 1950
slice 1
transform 1
intersected 1950
calc dist 1950
sort dist 1950
slice 1
transform 1
intersected 1950
calc dist 1950
sort dist 1950
slice 1
transform 1
intersected 1950
calc dist 1950
sort dist 1950
slice 1
transform 1
intersected 1385
calc dist 1385
sort dist 1385
slice 1
transfo

calc dist 1932
sort dist 1932
slice 1
transform 1
intersected 1931
calc dist 1931
sort dist 1931
slice 1
transform 1
intersected 1931
calc dist 1931
sort dist 1931
slice 1
transform 1
intersected 1756
calc dist 1756
sort dist 1756
slice 1
transform 1
intersected 1931
calc dist 1931
sort dist 1931
slice 1
transform 1
intersected 1107
calc dist 1107
sort dist 1107
slice 1
transform 1
intersected 1931
calc dist 1931
sort dist 1931
slice 1
transform 1
intersected 1380
calc dist 1380
sort dist 1380
slice 1
transform 1
intersected 1931
calc dist 1931
sort dist 1931
slice 1
transform 1
intersected 1930
calc dist 1930
sort dist 1930
slice 1
transform 1
intersected 1929
calc dist 1929
sort dist 1929
slice 1
transform 1
intersected 1928
calc dist 1928
sort dist 1928
slice 1
transform 1
intersected 1928
calc dist 1928
sort dist 1928
slice 1
transform 1
intersected 1107
calc dist 1107
sort dist 1107
slice 1
transform 1
intersected 1756
calc dist 1756
sort dist 1756
slice 1
transform 1
intersected 

intersected 1373
calc dist 1373
sort dist 1373
slice 1
transform 1
intersected 1372
calc dist 1372
sort dist 1372
slice 1
transform 1
intersected 1908
calc dist 1908
sort dist 1908
slice 1
transform 1
intersected 1371
calc dist 1371
sort dist 1371
slice 1
transform 1
intersected 1908
calc dist 1908
sort dist 1908
slice 1
transform 1
intersected 1908
calc dist 1908
sort dist 1908
slice 1
transform 1
intersected 1370
calc dist 1370
sort dist 1370
slice 1
transform 1
intersected 1098
calc dist 1098
sort dist 1098
slice 1
transform 1
intersected 1749
calc dist 1749
sort dist 1749
slice 1
transform 1
intersected 1907
calc dist 1907
sort dist 1907
slice 1
transform 1
intersected 1907
calc dist 1907
sort dist 1907
slice 1
transform 1
intersected 1907
calc dist 1907
sort dist 1907
slice 1
transform 1
intersected 1748
calc dist 1748
sort dist 1748
slice 1
transform 1
intersected 1907
calc dist 1907
sort dist 1907
slice 1
transform 1
intersected 1747
calc dist 1747
sort dist 1747
slice 1
transfo

intersected 1089
calc dist 1089
sort dist 1089
slice 1
transform 1
intersected 1367
calc dist 1367
sort dist 1367
slice 1
transform 1
intersected 1892
calc dist 1892
sort dist 1892
slice 1
transform 1
intersected 1366
calc dist 1366
sort dist 1366
slice 1
transform 1
intersected 1892
calc dist 1892
sort dist 1892
slice 1
transform 1
intersected 1891
calc dist 1891
sort dist 1891
slice 1
transform 1
intersected 1741
calc dist 1741
sort dist 1741
slice 1
transform 1
intersected 1890
calc dist 1890
sort dist 1890
slice 1
transform 1
intersected 1890
calc dist 1890
sort dist 1890
slice 1
transform 1
intersected 1889
calc dist 1889
sort dist 1889
slice 1
transform 1
intersected 1889
calc dist 1889
sort dist 1889
slice 1
transform 1
intersected 1889
calc dist 1889
sort dist 1889
slice 1
transform 1
intersected 1366
calc dist 1366
sort dist 1366
slice 1
transform 1
intersected 1888
calc dist 1888
sort dist 1888
slice 1
transform 1
intersected 1887
calc dist 1887
sort dist 1887
slice 1
transfo

intersected 1858
calc dist 1858
sort dist 1858
slice 1
transform 1
intersected 1736
calc dist 1736
sort dist 1736
slice 1
transform 1
intersected 1360
calc dist 1360
sort dist 1360
slice 1
transform 1
intersected 1857
calc dist 1857
sort dist 1857
slice 1
transform 1
intersected 1856
calc dist 1856
sort dist 1856
slice 1
transform 1
intersected 1856
calc dist 1856
sort dist 1856
slice 1
transform 1
intersected 1359
calc dist 1359
sort dist 1359
slice 1
transform 1
intersected 1358
calc dist 1358
sort dist 1358
slice 1
transform 1
intersected 1856
calc dist 1856
sort dist 1856
slice 1
transform 1
intersected 1084
calc dist 1084
sort dist 1084
slice 1
transform 1
intersected 1855
calc dist 1855
sort dist 1855
slice 1
transform 1
intersected 1357
calc dist 1357
sort dist 1357
slice 1
transform 1
intersected 1854
calc dist 1854
sort dist 1854
slice 1
transform 1
intersected 1853
calc dist 1853
sort dist 1853
slice 1
transform 1
intersected 1083
calc dist 1083
sort dist 1083
slice 1
transfo

calc dist 1352
sort dist 1352
slice 1
transform 1
intersected 1821
calc dist 1821
sort dist 1821
slice 1
transform 1
intersected 1821
calc dist 1821
sort dist 1821
slice 1
transform 1
intersected 1074
calc dist 1074
sort dist 1074
slice 1
transform 1
intersected 1351
calc dist 1351
sort dist 1351
slice 1
transform 1
intersected 1732
calc dist 1732
sort dist 1732
slice 1
transform 1
intersected 1821
calc dist 1821
sort dist 1821
slice 1
transform 1
intersected 1351
calc dist 1351
sort dist 1351
slice 1
transform 1
intersected 1821
calc dist 1821
sort dist 1821
slice 1
transform 1
intersected 1732
calc dist 1732
sort dist 1732
slice 1
transform 1
intersected 1821
calc dist 1821
sort dist 1821
slice 1
transform 1
intersected 1074
calc dist 1074
sort dist 1074
slice 1
transform 1
intersected 1351
calc dist 1351
sort dist 1351
slice 1
transform 1
intersected 1350
calc dist 1350
sort dist 1350
slice 1
transform 1
intersected 1821
calc dist 1821
sort dist 1821
slice 1
transform 1
intersected 

intersected 1725
calc dist 1725
sort dist 1725
slice 1
transform 1
intersected 1725
calc dist 1725
sort dist 1725
slice 1
transform 1
intersected 1790
calc dist 1790
sort dist 1790
slice 1
transform 1
intersected 1789
calc dist 1789
sort dist 1789
slice 1
transform 1
intersected 1344
calc dist 1344
sort dist 1344
slice 1
transform 1
intersected 1065
calc dist 1065
sort dist 1065
slice 1
transform 1
intersected 1064
calc dist 1064
sort dist 1064
slice 1
transform 1
intersected 1344
calc dist 1344
sort dist 1344
slice 1
transform 1
intersected 1344
calc dist 1344
sort dist 1344
slice 1
transform 1
intersected 1789
calc dist 1789
sort dist 1789
slice 1
transform 1
intersected 1788
calc dist 1788
sort dist 1788
slice 1
transform 1
intersected 1788
calc dist 1788
sort dist 1788
slice 1
transform 1
intersected 1788
calc dist 1788
sort dist 1788
slice 1
transform 1
intersected 1724
calc dist 1724
sort dist 1724
slice 1
transform 1
intersected 1723
calc dist 1723
sort dist 1723
slice 1
transfo

calc dist 1765
sort dist 1765
slice 1
transform 1
intersected 1333
calc dist 1333
sort dist 1333
slice 1
transform 1
intersected 1715
calc dist 1715
sort dist 1715
slice 1
transform 1
intersected 1765
calc dist 1765
sort dist 1765
slice 1
transform 1
intersected 1765
calc dist 1765
sort dist 1765
slice 1
transform 1
intersected 1765
calc dist 1765
sort dist 1765
slice 1
transform 1
intersected 1765
calc dist 1765
sort dist 1765
slice 1
transform 1
intersected 1765
calc dist 1765
sort dist 1765
slice 1
transform 1
intersected 1714
calc dist 1714
sort dist 1714
slice 1
transform 1
intersected 1764
calc dist 1764
sort dist 1764
slice 1
transform 1
intersected 1764
calc dist 1764
sort dist 1764
slice 1
transform 1
intersected 1714
calc dist 1714
sort dist 1714
slice 1
transform 1
intersected 1714
calc dist 1714
sort dist 1714
slice 1
transform 1
intersected 1764
calc dist 1764
sort dist 1764
slice 1
transform 1
intersected 1764
calc dist 1764
sort dist 1764
slice 1
transform 1
intersected 

intersected 1704
calc dist 1704
sort dist 1704
slice 1
transform 1
intersected 1736
calc dist 1736
sort dist 1736
slice 1
transform 1
intersected 1703
calc dist 1703
sort dist 1703
slice 1
transform 1
intersected 1323
calc dist 1323
sort dist 1323
slice 1
transform 1
intersected 1048
calc dist 1048
sort dist 1048
slice 1
transform 1
intersected 1047
calc dist 1047
sort dist 1047
slice 1
transform 1
intersected 1735
calc dist 1735
sort dist 1735
slice 1
transform 1
intersected 1734
calc dist 1734
sort dist 1734
slice 1
transform 1
intersected 1734
calc dist 1734
sort dist 1734
slice 1
transform 1
intersected 1047
calc dist 1047
sort dist 1047
slice 1
transform 1
intersected 1046
calc dist 1046
sort dist 1046
slice 1
transform 1
intersected 1323
calc dist 1323
sort dist 1323
slice 1
transform 1
intersected 1733
calc dist 1733
sort dist 1733
slice 1
transform 1
intersected 1702
calc dist 1702
sort dist 1702
slice 1
transform 1
intersected 1323
calc dist 1323
sort dist 1323
slice 1
transfo

intersected 1696
calc dist 1696
sort dist 1696
slice 1
transform 1
intersected 1040
calc dist 1040
sort dist 1040
slice 1
transform 1
intersected 1701
calc dist 1701
sort dist 1701
slice 1
transform 1
intersected 1700
calc dist 1700
sort dist 1700
slice 1
transform 1
intersected 1700
calc dist 1700
sort dist 1700
slice 1
transform 1
intersected 1310
calc dist 1310
sort dist 1310
slice 1
transform 1
intersected 1310
calc dist 1310
sort dist 1310
slice 1
transform 1
intersected 1699
calc dist 1699
sort dist 1699
slice 1
transform 1
intersected 1696
calc dist 1696
sort dist 1696
slice 1
transform 1
intersected 1309
calc dist 1309
sort dist 1309
slice 1
transform 1
intersected 1699
calc dist 1699
sort dist 1699
slice 1
transform 1
intersected 1699
calc dist 1699
sort dist 1699
slice 1
transform 1
intersected 1699
calc dist 1699
sort dist 1699
slice 1
transform 1
intersected 1699
calc dist 1699
sort dist 1699
slice 1
transform 1
intersected 1699
calc dist 1699
sort dist 1699
slice 1
transfo

calc dist 1668
sort dist 1668
slice 1
transform 1
intersected 1668
calc dist 1668
sort dist 1668
slice 1
transform 1
intersected 1668
calc dist 1668
sort dist 1668
slice 1
transform 1
intersected 1667
calc dist 1667
sort dist 1667
slice 1
transform 1
intersected 1666
calc dist 1666
sort dist 1666
slice 1
transform 1
intersected 1665
calc dist 1665
sort dist 1665
slice 1
transform 1
intersected 1690
calc dist 1690
sort dist 1690
slice 1
transform 1
intersected 1031
calc dist 1031
sort dist 1031
slice 1
transform 1
intersected 1665
calc dist 1665
sort dist 1665
slice 1
transform 1
intersected 1664
calc dist 1664
sort dist 1664
slice 1
transform 1
intersected 1030
calc dist 1030
sort dist 1030
slice 1
transform 1
intersected 1689
calc dist 1689
sort dist 1689
slice 1
transform 1
intersected 1689
calc dist 1689
sort dist 1689
slice 1
transform 1
intersected 1029
calc dist 1029
sort dist 1029
slice 1
transform 1
intersected 1028
calc dist 1028
sort dist 1028
slice 1
transform 1
intersected 

KeyboardInterrupt: 

# Model

In [None]:
# LIDAR_PC_SHAPE = [4, MAX_POINT_CLOUDS] # [x,y,z,intensity] x num_of_points
# MAP_OBJECT_SHAPE = [MAX_OBJECTS, MAX_POINTS, 2] # num_of_objects x num_of_points x [x, y]
# MAP_LAYER_SHAPE = [len(class_names), MAX_OBJECTS] # num_of_class x num_of_objects
class BaseModel(torch.nn.Module):
    def __init__(self, input_shape = [4, 3500], class_shape = [4, 1], object_shape = [1, 30, 2]):
        super().__init__()
        self.input_shape =input_shape
        self.object_shape = object_shape
        self.class_shape = class_shape
        
        self.conv1 = nn.Conv1d(input_shape[0], 64, kernel_size=1)
        self.batch1 = nn.BatchNorm1d(64)
        
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1)
        self.batch2 = nn.BatchNorm1d(64)
        
        self.fc1 = nn.Linear(64 * input_shape[1], 1024)
        self.fc2 = nn.Linear(1024, 512)
        
        self.fc_obj = nn.Linear(512, np.prod(object_shape))
        self.fc_cls = nn.Linear(512, np.prod(class_shape))
        
    def forward(self, x):
        input_shape = self.input_shape
        object_shape = self.object_shape
        class_shape = self.class_shape
        
        x = self.conv1(x)
        x = F.relu(x)
        x = self.batch1(x)
        
        x = self.conv2(x)
        x = F.relu(x)
        x = self.batch2(x)
        
        x = x.reshape(-1, input_shape[1] * 64)
        
        x = self.fc1(x)
        x = F.relu(x)
        
        x = self.fc2(x)
        x = F.relu(x)
        
        obj_ = self.fc_obj(x)
        cls_ = self.fc_cls(x)
        
        obj_ = obj_.reshape(-1, object_shape[0], object_shape[1], object_shape[2])
        cls_ = cls_.reshape(-1, class_shape[0], class_shape[1])
        return cls_, obj_

# Loss

In [None]:
class BaseLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.pose_loss = nn.MSELoss()
        self.class_loss = nn.CrossEntropyLoss()
        pass
    def forward(self, c_hat, y_hat, c, y):
        
        loss = self.pose_loss(y, y_hat) + self.class_loss(c_hat, c.argmax(axis = 1))
        return loss

# Model Compile

In [None]:
model = BaseModel(LIDAR_PC_SHAPE, MAP_LAYER_SHAPE, MAP_OBJECT_SHAPE)
model.to(device)

loss_func= BaseLoss()
optimizer = optim.SGD(model.parameters(), lr = 0.001)

# Train Module

In [None]:
def train(epoch, progress_log):
    model.train()  # 신경망을 학습 모드로 전환

    # 데이터로더에서 미니배치를 하나씩 꺼내 학습을 수행
    mean_loss = 0
    data_num = 0
    
    for X, c, y in progress_log:
        
        X = X.to(device)
        c = c.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()  # 경사를 0으로 초기화
        c_hat, y_hat = model(X)  # 데이터를 입력하고 출력을 계산
        loss = loss_func(c_hat, y_hat, c, y)  # 출력과 훈련 데이터 정답 간의 오차를 계산
        
        loss.backward()  # 오차를 역전파 계산
        optimizer.step()  # 역전파 계산한 값으로 가중치를 수정
        
        mean_loss += loss
        data_num += X.shape[0]
        
    mean_loss /= data_num
    
    return mean_loss

# Valid Module

In [None]:
def valid(epoch, progress_log):
    model.eval()  # 신경망을 학습 모드로 전환

    # 데이터로더에서 미니배치를 하나씩 꺼내 학습을 수행
    mean_loss = 0
    data_num = 0
    
    with torch.no_grad():
        for X, c, y in progress_log:

            X = X.to(device)
            c = c.to(device)
            y = y.to(device)

            c_hat, y_hat = model(X)  # 데이터를 입력하고 출력을 계산
            loss = loss_func(c_hat, y_hat, c, y)  # 출력과 훈련 데이터 정답 간의 오차를 계산

            mean_loss += loss
            data_num += X.shape[0]
        
    mean_loss /= data_num
    
    return mean_loss

# Test Module

In [None]:
def test(epoch, progress_log):
    model.eval()  # 신경망을 학습 모드로 전환

    # 데이터로더에서 미니배치를 하나씩 꺼내 학습을 수행
    C_hat = []
    Y_hat = []
    
    with torch.no_grad():
        for X, _, _ in progress_log:

            X = X.to(device)

            c_hat, y_hat = model(X)  # 데이터를 입력하고 출력을 계산
            C_hat.append(c_hat)
            Y_hat.append(y_hat)
        
    C_hat = np.concatenate(C_hat)
    Y_hat = np.concatenate(Y_hat)
    
    return C_hat, Y_hat

# Fit

In [None]:
train_loss_list = []
valid_loss_list = []

patience_count = 0
min_valid_loss = np.inf
checkpoint_name = ""

if not os.path.isdir(f"./models/{notebookName}/model-{num_files}_checkpoint/"):
    os.mkdir(f"./models/{notebookName}/model-{num_files}_checkpoint/")
    
prog_epoch = tqdm(range(0, nepochs), position = 0, desc = 'EPOCH')
for epoch in prog_epoch:
    print( "-------------------------------------------------------")
    print(f"|EPOCH: {epoch+1}/{nepochs}")
    prog_train = tqdm(train_loader, desc = 'TRAIN', leave = False)
    prog_valid = tqdm(valid_loader, desc = 'VALID', leave = False)

    train_loss = train(epoch, prog_train)
    valid_loss = valid(prog_valid)
    
    if valid_loss < min_valid_loss:
        print(f"|{epoch+1}-th model is checked!, *model-{epoch}-{valid_loss}.pth*")
        min_valid_loss= valid_loss
        checkpoint_name = f"./models/{notebookName}/model-{num_files}_checkpoint/model-{epoch}-{valid_loss}.pth"
        torch.save(model.state_dict(), checkpoint_name)
    else:
        patience_count+=1
        if(patience_count > max_patience_count):
            break
    
    train_loss_list.append(train_loss)
    valid_loss_list.append(valid_loss)
    
    print(f"|TRAIN: loss={train_loss:.6f}|")
    print(f"|VALID: loss={valid_loss:.6f}|")


history = dict()
history['train_loss'] = train_loss_list
history['valid_loss'] = valid_loss_list

In [None]:
plt.figure(figsize = (16,6))
plt.subplot(2,1,1)
plt.plot(history['train_loss'], label = 'train')
plt.plot(history['valid_loss'], label = 'valid')
plt.ylabel('loss')

plt.subplot(2,1,2)
plt.plot(history['train_score'], label = 'train')
plt.plot(history['valid_score'], label = 'valid')
plt.ylabel('gpsloss')