In [1]:
import pickle
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
import datetime
import os
from sklearn.model_selection import train_test_split
with open('datasets.pickle', 'rb') as f:
    datasets = pickle.load(f)

In [2]:
IN_FEATURES = datasets[0].shape[0]
TEST_RATE = 0.2
LEARNING_RATE = 1e-4
NUM_EPOCHS = 500
BATCH_SIZE = 1024
EARLY_BIRD = 0
TEST_SET_NUM = int(len(datasets)*TEST_RATE)
BEST_MODEL_PATH = 'best_model.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
d_time = datetime.datetime.now()
folder_name = "runs/"+d_time.strftime("%Y%m%d%H%M%S")
os.mkdir(folder_name)
writer = SummaryWriter(log_dir=folder_name)
pre_test_loss = 10000.

train_dataset, test_dataset = train_test_split(datasets, test_size=TEST_RATE, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)


In [3]:
class Discriminator(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.in_features = in_features
        self.out_classes = 3
        self.down = nn.Sequential(            
            self._block_down(self.in_features, self.in_features*2),           
            self._block_down(self.in_features*2, self.in_features*2),           
            self._block_down(self.in_features*2, self.in_features),           
            self._block_down(self.in_features, self.in_features//2),           
            self._block_down(self.in_features//2, self.in_features//4),           
            self._block_down(self.in_features//4, self.out_classes),                       
            # nn.Tanh()
        )
        self.up = nn.Sequential(
            self._block_down(self.out_classes,self.in_features//4),
            self._block_down(self.in_features//4,self.in_features//2),
            self._block_down(self.in_features//2,self.in_features),
            self._block_down(self.in_features,self.in_features*2),
            self._block_down(self.in_features*2,self.in_features*2),
            self._block_down(self.in_features*2,self.in_features*2),
            self._block_down(self.in_features*2,self.in_features),
            nn.Tanh() # 여기만 다시 활성화            
        )
        self.essens=None
    
    def _block_down(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Linear(in_channels,out_channels),
            nn.ReLU(True),
            nn.Dropout(0.5)
        )
    def _block_up(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Linear(in_channels,out_channels),
            nn.ReLU(True),
            nn.Dropout(0.5)
        )
    def forward(self, x):
        self.essens = self.down(x)
        x2 = self.up(self.essens)
        return x2

In [4]:
critic = Discriminator(IN_FEATURES).to(device)

opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
print(len(train_loader))
critic.to(device)
for epoch in range(NUM_EPOCHS):
    train_loss = 0.
    for data in train_loader:
        data = data.to(device)
        output = critic(data.float())
        label = data.float()
        loss = torch.sum(torch.square(label-output))
        critic.zero_grad()
        train_loss+=loss
        loss.backward()        
        opt_critic.step()
    test_loss = 0.
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            output = critic(data.float())
            label = data.float()
            loss = torch.sum(torch.square(label-output))    
            test_loss+=loss       
    writer.add_scalar("TRAIN_LOSS",train_loss.item()/(len(datasets)-TEST_SET_NUM),epoch)
    writer.add_scalar("TEST_LOSS",test_loss.item()/TEST_SET_NUM,epoch)
    print("EPOCHS: ", epoch,"TRAIN_LOSS",train_loss.item()/(len(datasets)-TEST_SET_NUM),"TEST_LOSS",test_loss.item()/TEST_SET_NUM ) 
    if pre_test_loss>test_loss and epoch>50:
        pre_test_loss= test_loss
        torch.save(critic.state_dict(), BEST_MODEL_PATH)
    if pre_test_loss<test_loss and epoch>50:
        EARLY_BIRD +=1
        if EARLY_BIRD >100:
            break


EPOCHS:  0 TRAIN_LOSS 465.0870560460185 TEST_LOSS 454.86417378161127
EPOCHS:  1 TRAIN_LOSS 455.0456781140407 TEST_LOSS 452.20233715781126
EPOCHS:  2 TRAIN_LOSS 452.6162907260344 TEST_LOSS 449.93937879819373
EPOCHS:  3 TRAIN_LOSS 451.2155993242125 TEST_LOSS 449.19753417984
EPOCHS:  4 TRAIN_LOSS 450.6993026647388 TEST_LOSS 448.7433935631351
EPOCHS:  5 TRAIN_LOSS 450.28393693066306 TEST_LOSS 448.32238037836726
EPOCHS:  6 TRAIN_LOSS 449.90879366933257 TEST_LOSS 448.1000318532999
EPOCHS:  7 TRAIN_LOSS 449.632133008972 TEST_LOSS 447.7931409227464
EPOCHS:  8 TRAIN_LOSS 449.3853518997936 TEST_LOSS 447.68871206490576
EPOCHS:  9 TRAIN_LOSS 449.2898172812981 TEST_LOSS 447.49814189083685
EPOCHS:  10 TRAIN_LOSS 449.1672511624857 TEST_LOSS 447.40350636129136
EPOCHS:  11 TRAIN_LOSS 449.09025385751625 TEST_LOSS 447.33010636503883
EPOCHS:  12 TRAIN_LOSS 448.9642400981828 TEST_LOSS 447.253208751538
EPOCHS:  13 TRAIN_LOSS 448.90433109840455 TEST_LOSS 447.1414348974761
EPOCHS:  14 TRAIN_LOSS 448.866257155

KeyboardInterrupt: 