In [3]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from custom_dataset import CustomDataset
from rnn_model import RNNModel

# 수정된 PaddingCollate 클래스
class PaddingCollate:
    @staticmethod
    def __call__(batch):
        time = [item['data'][0] for item in batch]
        pin1 = [item['data'][1] for item in batch]
        pin2 = [item['data'][2] for item in batch]

        max_length_pin1 = max(len(seq) for seq in pin1)
        max_length_pin2 = max(len(seq) for seq in pin2)
        max_length = max(max_length_pin1, max_length_pin2)

        padded_pin1 = []
        padded_pin2 = []

        for seq1, seq2 in zip(pin1, pin2):
            # 리스트를 텐서로 변환하여 패딩 수행
            padded_pin1.append(PaddingCollate.pad_sequence(torch.tensor(seq1), max_length))
            padded_pin2.append(PaddingCollate.pad_sequence(torch.tensor(seq2), max_length))

        time_tensor = torch.tensor(time, dtype=torch.float)
        padded_pin1_tensor = torch.stack(padded_pin1)
        padded_pin2_tensor = torch.stack(padded_pin2)

        return time_tensor, padded_pin1_tensor, padded_pin2_tensor

    @staticmethod
    def pad_sequence(sequence, max_length):
        padding_length = max_length - sequence.size(0)
        padded_sequence = torch.nn.functional.pad(sequence, (0, padding_length))
        return padded_sequence


# 저장된 데이터셋 파일 이름
file_name = 'concatenated_dataset'+'.pth'

# 데이터셋 로드
dataset = torch.load(file_name)
print(dataset)

# DataLoader에 로드된 데이터셋 사용
batch_size = 64  # 배치 크기 설정
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=PaddingCollate())

# 학습 반복 횟수 설정
num_epochs = 100

# 모델 생성
input_size = 2  # 입력 차원 설정 (pin1과 pin2)
hidden_size = 64  # 은닉 상태의 크기 설정
output_size = 1  # 출력 차원 설정
num_layers = 2
model = RNNModel(input_size, hidden_size, output_size, num_layers)

# 손실 함수와 옵티마이저 정의
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 모델 학습
for epoch in range(num_epochs):
    total_loss = 0
    for i, (time, pin1, pin2) in enumerate(dataloader):
        # 모델 입력 준비
        inputs = torch.stack([pin1, pin2], dim=2)  # pin1과 pin2를 합쳐서 입력으로 사용
        inputs = inputs.float()  # 데이터 타입을 float으로 변환
        targets = time.float()  # 타겟은 측정 시간으로 설정

        # Forward pass
        outputs = model(inputs)

        # Loss 계산 및 Backpropagation
        loss = criterion(outputs.squeeze(), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Epoch마다 손실 출력
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(dataloader)}')

torch.save(model.state_dict(), 'rnn_model.pth')
print("모델 저장이 완료되었습니다.")

<torch.utils.data.dataset.ConcatDataset object at 0x00000203AB49C590>
Epoch [1/100], Loss: 2.032498776912689
Epoch [2/100], Loss: 0.9291759580373764
Epoch [3/100], Loss: 0.12778213678393513
Epoch [4/100], Loss: 0.22079689428210258


  padded_pin1.append(PaddingCollate.pad_sequence(torch.tensor(seq1), max_length))
  padded_pin2.append(PaddingCollate.pad_sequence(torch.tensor(seq2), max_length))


Epoch [5/100], Loss: 0.11263717338442802
Epoch [6/100], Loss: 0.028834990691393614
Epoch [7/100], Loss: 0.06898840796202421
Epoch [8/100], Loss: 0.04088587174192071
Epoch [9/100], Loss: 0.01132285944186151
Epoch [10/100], Loss: 0.026359519455581903
Epoch [11/100], Loss: 0.019150303909555078
Epoch [12/100], Loss: 0.012974777957424521
Epoch [13/100], Loss: 0.024822188075631857
Epoch [14/100], Loss: 0.011853378149680793
Epoch [15/100], Loss: 0.009989956975914538
Epoch [16/100], Loss: 0.010450383299030364
Epoch [17/100], Loss: 0.010011848760768771
Epoch [18/100], Loss: 0.01082534552551806
Epoch [19/100], Loss: 0.01150921720545739
Epoch [20/100], Loss: 0.009916511364281178
Epoch [21/100], Loss: 0.009912413661368191
Epoch [22/100], Loss: 0.009175224753562361
Epoch [23/100], Loss: 0.008520250616129488
Epoch [24/100], Loss: 0.010456001269631088
Epoch [25/100], Loss: 0.008758321404457092
Epoch [26/100], Loss: 0.008108531008474529
Epoch [27/100], Loss: 0.008152462192811072
Epoch [28/100], Loss: 