In [18]:
# main.py

import torch
# import data_handler
# import networks
# import trainer
# import utils

import os, time
import scipy.io as sio
from torch.optim import lr_scheduler


# Arguments  ====================================== 반영

dataset = 'LER_data_20191125.xlsx' #2020_LER_20200529_V004.xlsx
datatype = 'p'
mean_model_type = 'mlp'
gan_model_type = 'gan1'
seed = 0
mean_lr = 5e-5
g_lr = 0.0001
d_lr = 0.0005
noise_d = 100
mean_hidden_dim = 100
gan_hidden_dim = 100
batch_size = 32
mean_nepochs = 1000
gan_nepochs = 200
workers = 0
num_of_input = 4
num_of_output = 6

# =====================================================

log_name = 'date_20200624_datatype_{}_model_{}_{}_seed_{}_mean_lr_{}_g_lr_{}_d_lr_{}_mean_hidden_dim_{}_gan_type_{}_gan_hidden_dim_{}_batch_{}_epoch_{}_noise_{}'.format(
    datatype,
    mean_model_type,
    gan_model_type,
    seed,
    mean_lr,
    g_lr,
    d_lr,
    mean_hidden_dim,
    gan_model_type, 
    gan_hidden_dim,
    batch_size, 
    mean_nepochs,
    noise_d
)

print(log_name)

# Dataset

dataset = get_dataset(dataset, datatype)

# loss result

result_dict = {}

kwargs = {'num_workers': workers}

print(torch.cuda.device_count())
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    
print("Inits...")
torch.set_default_tensor_type('torch.cuda.FloatTensor')


# ==================================================================================================
#                                          1. Predict mean 
# ==================================================================================================

mean_train_dataset_loader = meanLoader(dataset.train_X_per_cycle, dataset.train_Y_per_cycle, normalize)

mean_val_dataset_loader = meanLoader(dataset.val_X_per_cycle, dataset.val_Y_per_cycle, normalize)

mean_test_dataset_loader = meanLoader(dataset.test_X_per_cycle, dataset.test_Y_per_cycle, normalize)

# Dataloader

mean_train_iterator = torch.utils.data.DataLoader(mean_train_dataset_loader, batch_size=batch_size, shuffle=False, **kwargs)

mean_val_iterator = torch.utils.data.DataLoader(mean_val_dataset_loader, batch_size=batch_size, shuffle=False, **kwargs)

mean_test_iterator = torch.utils.data.DataLoader(mean_test_dataset_loader, batch_size=batch_size, shuffle=False, **kwargs)

# model

mean_model = get_mean_model(mean_model_type, mean_hidden_dim, num_of_input, num_of_output)
mean_model.apply(init_normal)
mean_model.cuda()

print(mean_model)

optimizer = torch.optim.Adam(mean_model.parameters(), lr=mean_lr)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5) # scheduler

mean_mytrainer = get_mean_trainer(mean_train_iterator, mean_val_iterator, mean_model, mean_model_type, optimizer, exp_lr_scheduler)

for epoch in range(mean_nepochs):
    
    train_loss = mean_mytrainer.train()
    
    val_loss, val_r2 = mean_mytrainer.evaluate()
    
    current_lr = mean_mytrainer.current_lr
    
    if((epoch+1)% 10 == 0):
        print("epoch:{:2d}, lr:{:.6f}, || train_loss:{:.6f}, val_loss:{:.6f}, r2_score:{:.6f}".format(epoch, current_lr, train_loss, val_loss, val_r2))

result_dict['train_loss'] = mean_mytrainer.loss['train_loss']
result_dict['val_loss'] = mean_mytrainer.loss['val_loss']
        
mean_best_model = mean_mytrainer.best_model

# ==================================================================================================
#                                          2. Predict noise
# ==================================================================================================

# predict Y_mean

noise_train_dataset_loader = meanLoader(dataset.train_X, dataset.train_Y_noise, normalize)

noise_val_dataset_loader = meanLoader(dataset.val_X, dataset.val_Y_noise, normalize)

noise_test_dataset_loader = meanLoader(dataset.test_X, dataset.test_Y_noise, normalize)

# Dataloader

noise_train_iterator = torch.utils.data.DataLoader(noise_train_dataset_loader, batch_size=batch_size, shuffle=False, **kwargs)

noise_val_iterator = torch.utils.data.DataLoader(noise_val_dataset_loader, batch_size=batch_size, shuffle=False, **kwargs)

noise_test_iterator = torch.utils.data.DataLoader(noise_test_dataset_loader, batch_size=batch_size, shuffle=False, **kwargs)

# model

generator, discriminator = get_gan_model(gan_model_type, gan_hidden_dim, noise_d, num_of_input, num_of_output)
generator.cuda()
discriminator.cuda()

print(generator, discriminator)

## QQQQ
init_params(generator)
init_params(discriminator)

# scheduler

optimizer_g = torch.optim.Adam(generator.parameters(), lr = g_lr)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr = d_lr)

exp_gan_lr_scheduler = lr_scheduler.StepLR(optimizer_d, step_size=50, gamma=0.5)

# trainer

gan_mytrainer = get_gan_trainer(noise_train_iterator, noise_val_iterator, gan_model_type, generator, discriminator, optimizer_g, optimizer_d, exp_gan_lr_scheduler, noise_d)

for epoch in range(gan_nepochs):
    
    gan_mytrainer.train()
    
    p_real, p_fake = gan_mytrainer.evaluate()
        
    current_d_lr = gan_mytrainer.current_d_lr
    
    if((epoch+1)% 10 == 0):
        print("epoch:{:2d}, lr_d:{:.6f}, || p_real:{:.6f}, p_fake:{:.6f}".format(epoch, current_d_lr, p_real, p_fake))

result_dict['p_real_train'] = gan_mytrainer.prob['p_real_train']
result_dict['p_fake_train'] = gan_mytrainer.prob['p_fake_train']
result_dict['p_real_val'] = gan_mytrainer.prob['p_real_val']
result_dict['p_fake_val'] = gan_mytrainer.prob['p_fake_val']

if not os.path.exists('./result_data'):
    os.makedirs('./result_data')
sio.savemat('./result_data/'+log_name+'.mat', result_dict)
# net.state_dict()
if not os.path.exists('./mean_models'):
    os.makedirs('./mean_models')
torch.save(mean_best_model.state_dict(), './mean_models/'+log_name)
if not os.path.exists('./gen_models'):
    os.makedirs('./gen_models')
torch.save(generator.state_dict(), './gen_models/'+log_name)
if not os.path.exists('./dis_models'):
    os.makedirs('./dis_models')
torch.save(discriminator.state_dict(), './dis_models/'+log_name)

"""
train_dataset_loader = SemiLoader(dataset.train_X, dataset.train_Y, normalize)

val_dataset_loader = SemiLoader(dataset.val_X, dataset.val_Y, normalize)

test_dataset_loader = SemiLoader(dataset.test_X, dataset.test_Y, normalize)

# Dataloader 
kwargs = {'num_workers': workers}

train_iterator = torch.utils.data.DataLoader(train_dataset_loader, batch_size=batch_size, shuffle=True, **kwargs)

val_iterator = torch.utils.data.DataLoader(val_dataset_loader, batch_size=batch_size, shuffle=False, **kwargs)

test_iterator = torch.utils.data.DataLoader(test_dataset_loader, batch_size=batch_size, shuffle=False, **kwargs)

# model
print(torch.cuda.device_count())
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    
print("Inits...")
torch.set_default_tensor_type('torch.cuda.FloatTensor')

# model
myModel = get_model(trainer, unitN, ratio)
myModel = myModel.cuda()

print(myModel)

# optimizer
optimizer = torch.optim.Adam(myModel.parameters(), lr=lr)

# trainer
myTrainer = get_trainer(train_iterator, val_iterator, myModel, trainer, sample_num, lr, lr_patience, lr_factor, lr_min, optimizer)

if trainer == 'bbb':
    evalType = 'bbb'
    
t_classifier = get_evaluator(evalType, batch_size, sample_num)

total_epochs = nepochs

loss_dict = {}

for epoch in range(total_epochs):
    
    myTrainer.train(epoch)
    
    myTrainer.update_lr(epoch, t_classifier)

# complexity cost, log_prior, log_variational_posterior, loss, negative_log_likelihood    
loss_dict = myTrainer.loss    

if not os.path.exists('./loss_data'):
    os.makedirs('./loss_data')
sio.savemat('./loss_data/'+log_name+'.mat', loss_dict)
    
    
train_loss, train_r2, train_sample = t_classifier.evaluate(myModel, train_iterator, mode='test')
val_loss, val_r2, val_sample = t_classifier.evaluate(myModel, val_iterator, mode='test')
test_loss, test_r2, test_sample = t_classifier.evaluate(myModel, test_iterator, mode='test')

result_dict = {}

result_dict['train_loss'] = train_loss
result_dict['train_r2'] = train_r2
result_dict['train_sample'] = train_sample

result_dict['val_loss'] = val_loss
result_dict['val_r2'] = val_r2
result_dict['val_sample'] = val_sample

result_dict['test_loss'] = test_loss
result_dict['test_r2'] = test_r2
result_dict['test_sample'] = test_sample

# result dict
if not os.path.exists('./result_data'):
    os.makedirs('./result_data')
sio.savemat('./result_data/'+log_name+'.mat', result_dict)
# net.state_dict()
if not os.path.exists('./models/trained_model'):
    os.makedirs('./models/trained_model')
torch.save(myModel.state_dict(), './models/trained_model/'+log_name+'1')
"""

date_20200624_datatype_p_model_mlp_gan1_seed_0_mean_lr_5e-05_g_lr_0.0001_d_lr_0.0005_mean_hidden_dim_100_gan_type_gan1_gan_hidden_dim_100_batch_32_epoch_1000_noise_100
[ 0.502  23.3    53.1     1.0902]
[ 0.502  23.3    53.1     1.0902]
[ 0.502  23.3    53.1     1.0902]
[ 0.502  23.3    53.1     1.0902]
[ 0.502  23.3    53.1     1.0902]
[ 0.502  23.3    53.1     1.0902]
[ 0.502  23.3    53.1     1.0902]
[ 0.502  23.3    53.1     1.0902]
[ 0.502  23.3    53.1     1.0902]
[ 0.502  23.3    53.1     1.0902]
[ 0.358  19.5    48.4     1.0898]
[ 0.358  19.5    48.4     1.0898]
[ 0.358  19.5    48.4     1.0898]
[ 0.358  19.5    48.4     1.0898]
[ 0.358  19.5    48.4     1.0898]
[ 0.358  19.5    48.4     1.0898]
[ 0.358  19.5    48.4     1.0898]
[ 0.358  19.5    48.4     1.0898]
[ 0.358  19.5    48.4     1.0898]
[ 0.358  19.5    48.4     1.0898]
[ 0.792  24.7    45.4     1.0923]
[ 0.792  24.7    45.4     1.0923]
[ 0.792  24.7    45.4     1.0923]
[ 0.792  24.7    45.4     1.0923]
[ 0.792  24.7   

[ 0.755  21.1    50.6     0.9359]
[ 0.5   17.7   45.4    0.968]
[ 0.5   17.7   45.4    0.968]
[ 0.5   17.7   45.4    0.968]
[ 0.5   17.7   45.4    0.968]
[ 0.5   17.7   45.4    0.968]
[ 0.5   17.7   45.4    0.968]
[ 0.5   17.7   45.4    0.968]
[ 0.5   17.7   45.4    0.968]
[ 0.5   17.7   45.4    0.968]
[ 0.5   17.7   45.4    0.968]
[ 0.59   19.     49.1     1.0642]
[ 0.59   19.     49.1     1.0642]
[ 0.59   19.     49.1     1.0642]
[ 0.59   19.     49.1     1.0642]
[ 0.59   19.     49.1     1.0642]
[ 0.59   19.     49.1     1.0642]
[ 0.59   19.     49.1     1.0642]
[ 0.59   19.     49.1     1.0642]
[ 0.59   19.     49.1     1.0642]
[ 0.59   19.     49.1     1.0642]
[ 0.605  22.7    47.8     1.0145]
[ 0.605  22.7    47.8     1.0145]
[ 0.605  22.7    47.8     1.0145]
[ 0.605  22.7    47.8     1.0145]
[ 0.605  22.7    47.8     1.0145]
[ 0.605  22.7    47.8     1.0145]
[ 0.605  22.7    47.8     1.0145]
[ 0.605  22.7    47.8     1.0145]
[ 0.605  22.7    47.8     1.0145]
[ 0.605  22.7    47.

[ 0.5142 18.3686 48.5234  0.9812]
[ 0.5142 18.3686 48.5234  0.9812]
[ 0.5142 18.3686 48.5234  0.9812]
[ 0.5142 18.3686 48.5234  0.9812]
[ 0.5142 18.3686 48.5234  0.9812]
[ 0.5142 18.3686 48.5234  0.9812]
[ 0.5142 18.3686 48.5234  0.9812]
[ 0.4813 15.2163 47.2264  1.0308]
[ 0.4813 15.2163 47.2264  1.0308]
[ 0.4813 15.2163 47.2264  1.0308]
[ 0.4813 15.2163 47.2264  1.0308]
[ 0.4813 15.2163 47.2264  1.0308]
[ 0.4813 15.2163 47.2264  1.0308]
[ 0.4813 15.2163 47.2264  1.0308]
[ 0.4813 15.2163 47.2264  1.0308]
[ 0.4813 15.2163 47.2264  1.0308]
[ 0.4813 15.2163 47.2264  1.0308]
[ 0.2647 15.7212 52.8697  1.0279]
[ 0.2647 15.7212 52.8697  1.0279]
[ 0.2647 15.7212 52.8697  1.0279]
[ 0.2647 15.7212 52.8697  1.0279]
[ 0.2647 15.7212 52.8697  1.0279]
[ 0.2647 15.7212 52.8697  1.0279]
[ 0.2647 15.7212 52.8697  1.0279]
[ 0.2647 15.7212 52.8697  1.0279]
[ 0.2647 15.7212 52.8697  1.0279]
[ 0.2647 15.7212 52.8697  1.0279]
[ 0.2274 21.876  47.3243  1.0959]
[ 0.2274 21.876  47.3243  1.0959]
[ 0.2274 21.87

[ 0.1522 16.1579 48.0403  1.0859]
[ 0.5971 22.7102 46.8317  0.9826]
[ 0.5971 22.7102 46.8317  0.9826]
[ 0.5971 22.7102 46.8317  0.9826]
[ 0.5971 22.7102 46.8317  0.9826]
[ 0.5971 22.7102 46.8317  0.9826]
[ 0.5971 22.7102 46.8317  0.9826]
[ 0.5971 22.7102 46.8317  0.9826]
[ 0.5971 22.7102 46.8317  0.9826]
[ 0.5971 22.7102 46.8317  0.9826]
[ 0.5971 22.7102 46.8317  0.9826]
[ 0.6105 24.8909 51.0765  0.9587]
[ 0.6105 24.8909 51.0765  0.9587]
[ 0.6105 24.8909 51.0765  0.9587]
[ 0.6105 24.8909 51.0765  0.9587]
[ 0.6105 24.8909 51.0765  0.9587]
[ 0.6105 24.8909 51.0765  0.9587]
[ 0.6105 24.8909 51.0765  0.9587]
[ 0.6105 24.8909 51.0765  0.9587]
[ 0.6105 24.8909 51.0765  0.9587]
[ 0.6105 24.8909 51.0765  0.9587]
[ 0.3547 20.9343 47.4989  0.9348]
[ 0.3547 20.9343 47.4989  0.9348]
[ 0.3547 20.9343 47.4989  0.9348]
[ 0.3547 20.9343 47.4989  0.9348]
[ 0.3547 20.9343 47.4989  0.9348]
[ 0.3547 20.9343 47.4989  0.9348]
[ 0.3547 20.9343 47.4989  0.9348]
[ 0.3547 20.9343 47.4989  0.9348]
[ 0.3547 20.93

[ 0.3773 20.1585 48.4225  1.0301]
[ 0.3773 20.1585 48.4225  1.0301]
[ 0.3773 20.1585 48.4225  1.0301]
[ 0.3773 20.1585 48.4225  1.0301]
[ 0.3773 20.1585 48.4225  1.0301]
[ 0.3773 20.1585 48.4225  1.0301]
[ 0.3773 20.1585 48.4225  1.0301]
[ 0.3773 20.1585 48.4225  1.0301]
[ 0.383  18.9678 45.9728  1.0243]
[ 0.383  18.9678 45.9728  1.0243]
[ 0.383  18.9678 45.9728  1.0243]
[ 0.383  18.9678 45.9728  1.0243]
[ 0.383  18.9678 45.9728  1.0243]
[ 0.383  18.9678 45.9728  1.0243]
[ 0.383  18.9678 45.9728  1.0243]
[ 0.383  18.9678 45.9728  1.0243]
[ 0.383  18.9678 45.9728  1.0243]
[ 0.383  18.9678 45.9728  1.0243]
[ 0.5458 18.4923 45.1955  1.0241]
[ 0.5458 18.4923 45.1955  1.0241]
[ 0.5458 18.4923 45.1955  1.0241]
[ 0.5458 18.4923 45.1955  1.0241]
[ 0.5458 18.4923 45.1955  1.0241]
[ 0.5458 18.4923 45.1955  1.0241]
[ 0.5458 18.4923 45.1955  1.0241]
[ 0.5458 18.4923 45.1955  1.0241]
[ 0.5458 18.4923 45.1955  1.0241]
[ 0.5458 18.4923 45.1955  1.0241]
[ 0.4619 22.5428 47.7809  0.9887]
[ 0.4619 22.54

epoch:319, lr:0.000006, || train_loss:0.025760, val_loss:0.060359, r2_score:-0.207173
epoch:329, lr:0.000006, || train_loss:0.025698, val_loss:0.060437, r2_score:-0.208730
epoch:339, lr:0.000006, || train_loss:0.025635, val_loss:0.060510, r2_score:-0.210200
epoch:349, lr:0.000006, || train_loss:0.025572, val_loss:0.060585, r2_score:-0.211700
epoch:359, lr:0.000006, || train_loss:0.025510, val_loss:0.060659, r2_score:-0.213181
epoch:369, lr:0.000006, || train_loss:0.025447, val_loss:0.060732, r2_score:-0.214641
epoch:379, lr:0.000006, || train_loss:0.025385, val_loss:0.060808, r2_score:-0.216169
epoch:389, lr:0.000006, || train_loss:0.025323, val_loss:0.060883, r2_score:-0.217663
epoch:399, lr:0.000006, || train_loss:0.025261, val_loss:0.060957, r2_score:-0.219145
epoch:409, lr:0.000003, || train_loss:0.025227, val_loss:0.060994, r2_score:-0.219888
epoch:419, lr:0.000003, || train_loss:0.025196, val_loss:0.061032, r2_score:-0.220633
epoch:429, lr:0.000003, || train_loss:0.025165, val_lo

KeyboardInterrupt: 

In [17]:
import torch
from torch.utils.data import DataLoader, TensorDataset

from torchvision import transforms


import pandas as pd
from pandas import ExcelWriter
from pandas import ExcelFile
import numpy as np

import matplotlib.pyplot as plt
import math
import sklearn
from sklearn.metrics import r2_score
import os

def get_dataset(name, datatype):
    if name == 'LER_data_20191125.xlsx':
        return SEMI_data(name, datatype, num_input=4, num_output=6, num_in_cycle=10, num_of_cycle=270, num_train=230, num_val=20, num_test=20, x_cols="D:G", y_cols="K:P", header=2)
    elif name == 'LER_data_20191107.xlsx':
        return SEMI_data(name, datatype, num_input=4, num_output=6, num_in_cycle=10, num_of_cycle=185, num_train=150, num_val=15 , num_test=20, x_cols="D:G", y_cols="K:P", header=2)
    elif name == '2020_LER_20200529_V004.xlsx':
        return SEMI_data(name, datatype, num_input=4, num_output=9, num_in_cycle=50, num_of_cycle=72, num_train=50, num_val=10, num_test=12, x_cols="D:G", y_cols="H:P", header=0)
        
def load_data(file_path, num_input, num_output, num_in_cycle, num_of_cycle, x_cols, y_cols, header):  
    """
    
     1) 20191107 기준 : num_input, num_output, num_of_cycle = 185, num_in_cycle=10, header=2, usecols="D:G" 확인 필수
     2) num_input, num_output, num_in_cycle, num_of_cycle 새로 추가함
    
    """
    num_total = num_of_cycle*num_in_cycle
    
    data_x = pd.read_excel('./'+file_path, sheet_name='uniformly sampling', usecols=x_cols, nrows=num_total+1, header=header)
    data_y = pd.read_excel('./'+file_path, sheet_name='uniformly sampling', usecols=y_cols, nrows=num_total+1, header=header)
    
    # loader 는 일반 Dataloader 사용
    
    X_all = np.zeros((num_total, num_input))
    X_per_cycle = np.zeros((num_of_cycle, num_input))    
    X_nominal = np.zeros(num_input)
    
    Y_all = np.zeros((num_total, num_output))
    Y_per_cycle = np.zeros((num_of_cycle, num_output))
    Y_nominal = np.zeros(num_output)
    
    # STEP 1: EXCEL FILE to NUMPY
    
    # X_per_cycle
    for i in range(num_of_cycle):
        X_per_cycle[i] = data_x[i*num_in_cycle+1:i*num_in_cycle+2].values

    # X_all
    X_all = np.repeat(X_per_cycle,num_in_cycle,axis=0)
    for i in range(X_all.shape[0]):
        print(X_all[i])


    # Y_all
    for i in range(num_total):
        Y_all[i] = data_y[i+1:i+2].values
    
    # Y_per_cycle
    for i in range(num_of_cycle):
        Y_per_cycle[i] = np.mean(Y_all[i*num_in_cycle:(i+1)*num_in_cycle],axis=0)

        
    print("============ Data load =============")
    print("X data shape: ", X_all.shape)
    print("Y data shape: ", Y_all.shape)  
    print("any nan in X?: ", np.argwhere(np.isnan(X_all)))
    print("any nan in Y?: ", np.argwhere(np.isnan(Y_all)))
      
    return X_all, Y_all, X_per_cycle, Y_per_cycle

def split_data(x, y, num_train, num_val, num_test):
    
    if len(x) == len(y):
        print("Same number of x data and y data")
        len_total = len(x)
    else:
        print("Different number of x data and y data")
    
    x_train, y_train = x[:num_train], y[:num_train]
    x_val, y_val = x[num_train:num_train+num_val], y[num_train:num_train+num_val]
    x_test, y_test = x[num_train+num_val:], y[num_train+num_val:]
    
    print("============= Data split ==============")
    print("train X: {} train Y: {}".format(x_train.shape, y_train.shape))
    print("val X: {} val Y: {}".format(x_val.shape, y_val.shape))
    print("test X: {} test Y: {}".format(x_test.shape, y_test.shape))
    
    return x_train, y_train, x_val, y_val, x_test, y_test
    

class Dataset():   
    def __init__(self, name):
        
        self.train_X = None
        self.train_Y = None
        
        self.val_X = None
        self.val_Y = None
        
        self.test_X = None
        self.test_Y = None
        
        self.train_X_per_cycle = None
        self.val_X_per_cycle = None
        self.test_X_per_cycle = None
              
        self.train_Y_per_cycle = None
        self.val_Y_per_cycle = None
        self.test_Y_per_cycle = None
        
        self.train_Y_mean_predict = None
        self.val_Y_mean_predict = None
        self.test_Y_mean_predict = None
        
        self.train_Y_noise = None
        self.val_Y_noise = None
        self.test_Y_noise = None

class SEMI_data(Dataset):
    def __init__(self, name, datatype, num_input, num_output, num_in_cycle, num_of_cycle, num_train, num_val, num_test, x_cols, y_cols, header):
        super().__init__(name)
        
        # STEP 1: Load data
        
        X_all, Y_all, X_per_cycle, Y_per_cycle = load_data(name, num_input, num_output, num_in_cycle, num_of_cycle, x_cols, y_cols, header)
        
        # N=1(odd), P=0(even)
        if name == '2020_LER_20200529_V004.xlsx':
            if datatype == 'n':
                X_all, Y_all, X_per_cycle, Y_per_cycle = X_all[::2], Y_all[::2], X_per_cycle[::2], Y_per_cycle[::2]    # odd
            
                self.train_X, self.train_Y, self.val_X, self.val_Y, self.test_X, self.test_Y = split_data(X_all, Y_all, num_train*num_in_cycle//2, num_val*num_in_cycle//2, num_test*num_in_cycle//2)
                self.train_X_per_cycle, self.train_Y_per_cycle, self.val_X_per_cycle, self.val_Y_per_cycle, self.test_X_per_cycle, self.test_Y_per_cycle = split_data(X_per_cycle, Y_per_cycle, num_train//2, num_val//2, num_test//2)        
            
            elif datatype == 'p':
                X_all, Y_all, X_per_cycle, Y_per_cycle = X_all[1::2], Y_all[1::2], X_per_cycle[1::2], Y_per_cycle[1::2]   # even   
            
                self.train_X, self.train_Y, self.val_X, self.val_Y, self.test_X, self.test_Y = split_data(X_all, Y_all, num_train*num_in_cycle//2, num_val*num_in_cycle//2, num_test*num_in_cycle//2)
                self.train_X_per_cycle, self.train_Y_per_cycle, self.val_X_per_cycle, self.val_Y_per_cycle, self.test_X_per_cycle, self.test_Y_per_cycle = split_data(X_per_cycle, Y_per_cycle, num_train//2, num_val//2, num_test//2) 
            
            else:    # pn둘다 있을경우
                self.train_X, self.train_Y, self.val_X, self.val_Y, self.test_X, self.test_Y = split_data(X_all, Y_all, num_train*num_in_cycle, num_val*num_in_cycle, num_test*num_in_cycle)
                self.train_X_per_cycle, self.train_Y_per_cycle, self.val_X_per_cycle, self.val_Y_per_cycle, self.test_X_per_cycle, self.test_Y_per_cycle = split_data(X_per_cycle, Y_per_cycle, num_train, num_val, num_test) 
        else:
            self.train_X, self.train_Y, self.val_X, self.val_Y, self.test_X, self.test_Y = split_data(X_all, Y_all, num_train*num_in_cycle, num_val*num_in_cycle, num_test*num_in_cycle)
            self.train_X_per_cycle, self.train_Y_per_cycle, self.val_X_per_cycle, self.val_Y_per_cycle, self.test_X_per_cycle, self.test_Y_per_cycle = split_data(X_per_cycle, Y_per_cycle, num_train, num_val, num_test) 
            
            
        
        # STEP 2: Split data
        
        # OPTIONAL: Split data for Y_mean, Y_noise
        
        
        self.train_Y_mean = np.repeat(self.train_Y_per_cycle, num_in_cycle, axis=0)
        self.val_Y_mean = np.repeat(self.val_Y_per_cycle, num_in_cycle, axis=0)
        self.test_Y_mean = np.repeat(self.test_Y_per_cycle, num_in_cycle, axis=0)
        
        print("train_Y_mean shape", self.train_Y_mean.shape)
        print("val_Y_mean shape", self.val_Y_mean.shape)
        print("test_Y_mean shape", self.test_Y_mean.shape)
        
        self.train_Y_noise = self.train_Y - self.train_Y_mean
        self.val_Y_noise = self.val_Y - self.val_Y_mean
        self.test_Y_noise = self.test_Y - self.test_Y_mean
        
        print("train_Y_noise shape", self.train_Y_noise.shape)
        print("val_Y_noise shape", self.val_Y_noise.shape)
        print("test_Y_noise shape", self.test_Y_noise.shape) 

In [72]:
import numpy as np

###  NUMPY의 홀수, 짝수 프린트하는법

num_in_cycle = 50
num_of_cycle = 200

a = np.arange(60000)
a = a.reshape(10000,6)

print(a)

x = np.empty(6)
y = np.empty(6)

for i in range(num_of_cycle):
    # print(num_in_cycle*(2*i),num_in_cycle*(2*i)+num_in_cycle)
    x = np.vstack((x, a[num_in_cycle*(2*i):num_in_cycle*(2*i)+num_in_cycle]))
    # print(num_in_cycle*(2*i+1),num_in_cycle*(2*i+1)+num_in_cycle)
    y = np.vstack((y, a[num_in_cycle*(2*i+1):num_in_cycle*(2*i+1)+num_in_cycle]))

x = x[1:]
y = y[1:]
    
for i in range(len(x)):
    print(x[i])
for i in range(len(x)):    
    print(y[i])



 
    

[[    0     1     2     3     4     5]
 [    6     7     8     9    10    11]
 [   12    13    14    15    16    17]
 ...
 [59982 59983 59984 59985 59986 59987]
 [59988 59989 59990 59991 59992 59993]
 [59994 59995 59996 59997 59998 59999]]
[0. 1. 2. 3. 4. 5.]
[ 6.  7.  8.  9. 10. 11.]
[12. 13. 14. 15. 16. 17.]
[18. 19. 20. 21. 22. 23.]
[24. 25. 26. 27. 28. 29.]
[30. 31. 32. 33. 34. 35.]
[36. 37. 38. 39. 40. 41.]
[42. 43. 44. 45. 46. 47.]
[48. 49. 50. 51. 52. 53.]
[54. 55. 56. 57. 58. 59.]
[60. 61. 62. 63. 64. 65.]
[66. 67. 68. 69. 70. 71.]
[72. 73. 74. 75. 76. 77.]
[78. 79. 80. 81. 82. 83.]
[84. 85. 86. 87. 88. 89.]
[90. 91. 92. 93. 94. 95.]
[ 96.  97.  98.  99. 100. 101.]
[102. 103. 104. 105. 106. 107.]
[108. 109. 110. 111. 112. 113.]
[114. 115. 116. 117. 118. 119.]
[120. 121. 122. 123. 124. 125.]
[126. 127. 128. 129. 130. 131.]
[132. 133. 134. 135. 136. 137.]
[138. 139. 140. 141. 142. 143.]
[144. 145. 146. 147. 148. 149.]
[150. 151. 152. 153. 154. 155.]
[156. 157. 158. 159. 160. 161.

[8628. 8629. 8630. 8631. 8632. 8633.]
[8634. 8635. 8636. 8637. 8638. 8639.]
[8640. 8641. 8642. 8643. 8644. 8645.]
[8646. 8647. 8648. 8649. 8650. 8651.]
[8652. 8653. 8654. 8655. 8656. 8657.]
[8658. 8659. 8660. 8661. 8662. 8663.]
[8664. 8665. 8666. 8667. 8668. 8669.]
[8670. 8671. 8672. 8673. 8674. 8675.]
[8676. 8677. 8678. 8679. 8680. 8681.]
[8682. 8683. 8684. 8685. 8686. 8687.]
[8688. 8689. 8690. 8691. 8692. 8693.]
[8694. 8695. 8696. 8697. 8698. 8699.]
[9000. 9001. 9002. 9003. 9004. 9005.]
[9006. 9007. 9008. 9009. 9010. 9011.]
[9012. 9013. 9014. 9015. 9016. 9017.]
[9018. 9019. 9020. 9021. 9022. 9023.]
[9024. 9025. 9026. 9027. 9028. 9029.]
[9030. 9031. 9032. 9033. 9034. 9035.]
[9036. 9037. 9038. 9039. 9040. 9041.]
[9042. 9043. 9044. 9045. 9046. 9047.]
[9048. 9049. 9050. 9051. 9052. 9053.]
[9054. 9055. 9056. 9057. 9058. 9059.]
[9060. 9061. 9062. 9063. 9064. 9065.]
[9066. 9067. 9068. 9069. 9070. 9071.]
[9072. 9073. 9074. 9075. 9076. 9077.]
[9078. 9079. 9080. 9081. 9082. 9083.]
[9084. 9085.

[19368. 19369. 19370. 19371. 19372. 19373.]
[19374. 19375. 19376. 19377. 19378. 19379.]
[19380. 19381. 19382. 19383. 19384. 19385.]
[19386. 19387. 19388. 19389. 19390. 19391.]
[19392. 19393. 19394. 19395. 19396. 19397.]
[19398. 19399. 19400. 19401. 19402. 19403.]
[19404. 19405. 19406. 19407. 19408. 19409.]
[19410. 19411. 19412. 19413. 19414. 19415.]
[19416. 19417. 19418. 19419. 19420. 19421.]
[19422. 19423. 19424. 19425. 19426. 19427.]
[19428. 19429. 19430. 19431. 19432. 19433.]
[19434. 19435. 19436. 19437. 19438. 19439.]
[19440. 19441. 19442. 19443. 19444. 19445.]
[19446. 19447. 19448. 19449. 19450. 19451.]
[19452. 19453. 19454. 19455. 19456. 19457.]
[19458. 19459. 19460. 19461. 19462. 19463.]
[19464. 19465. 19466. 19467. 19468. 19469.]
[19470. 19471. 19472. 19473. 19474. 19475.]
[19476. 19477. 19478. 19479. 19480. 19481.]
[19482. 19483. 19484. 19485. 19486. 19487.]
[19488. 19489. 19490. 19491. 19492. 19493.]
[19494. 19495. 19496. 19497. 19498. 19499.]
[19800. 19801. 19802. 19803. 198

[32550. 32551. 32552. 32553. 32554. 32555.]
[32556. 32557. 32558. 32559. 32560. 32561.]
[32562. 32563. 32564. 32565. 32566. 32567.]
[32568. 32569. 32570. 32571. 32572. 32573.]
[32574. 32575. 32576. 32577. 32578. 32579.]
[32580. 32581. 32582. 32583. 32584. 32585.]
[32586. 32587. 32588. 32589. 32590. 32591.]
[32592. 32593. 32594. 32595. 32596. 32597.]
[32598. 32599. 32600. 32601. 32602. 32603.]
[32604. 32605. 32606. 32607. 32608. 32609.]
[32610. 32611. 32612. 32613. 32614. 32615.]
[32616. 32617. 32618. 32619. 32620. 32621.]
[32622. 32623. 32624. 32625. 32626. 32627.]
[32628. 32629. 32630. 32631. 32632. 32633.]
[32634. 32635. 32636. 32637. 32638. 32639.]
[32640. 32641. 32642. 32643. 32644. 32645.]
[32646. 32647. 32648. 32649. 32650. 32651.]
[32652. 32653. 32654. 32655. 32656. 32657.]
[32658. 32659. 32660. 32661. 32662. 32663.]
[32664. 32665. 32666. 32667. 32668. 32669.]
[32670. 32671. 32672. 32673. 32674. 32675.]
[32676. 32677. 32678. 32679. 32680. 32681.]
[32682. 32683. 32684. 32685. 326

[44574. 44575. 44576. 44577. 44578. 44579.]
[44580. 44581. 44582. 44583. 44584. 44585.]
[44586. 44587. 44588. 44589. 44590. 44591.]
[44592. 44593. 44594. 44595. 44596. 44597.]
[44598. 44599. 44600. 44601. 44602. 44603.]
[44604. 44605. 44606. 44607. 44608. 44609.]
[44610. 44611. 44612. 44613. 44614. 44615.]
[44616. 44617. 44618. 44619. 44620. 44621.]
[44622. 44623. 44624. 44625. 44626. 44627.]
[44628. 44629. 44630. 44631. 44632. 44633.]
[44634. 44635. 44636. 44637. 44638. 44639.]
[44640. 44641. 44642. 44643. 44644. 44645.]
[44646. 44647. 44648. 44649. 44650. 44651.]
[44652. 44653. 44654. 44655. 44656. 44657.]
[44658. 44659. 44660. 44661. 44662. 44663.]
[44664. 44665. 44666. 44667. 44668. 44669.]
[44670. 44671. 44672. 44673. 44674. 44675.]
[44676. 44677. 44678. 44679. 44680. 44681.]
[44682. 44683. 44684. 44685. 44686. 44687.]
[44688. 44689. 44690. 44691. 44692. 44693.]
[44694. 44695. 44696. 44697. 44698. 44699.]
[45000. 45001. 45002. 45003. 45004. 45005.]
[45006. 45007. 45008. 45009. 450

[55902. 55903. 55904. 55905. 55906. 55907.]
[55908. 55909. 55910. 55911. 55912. 55913.]
[55914. 55915. 55916. 55917. 55918. 55919.]
[55920. 55921. 55922. 55923. 55924. 55925.]
[55926. 55927. 55928. 55929. 55930. 55931.]
[55932. 55933. 55934. 55935. 55936. 55937.]
[55938. 55939. 55940. 55941. 55942. 55943.]
[55944. 55945. 55946. 55947. 55948. 55949.]
[55950. 55951. 55952. 55953. 55954. 55955.]
[55956. 55957. 55958. 55959. 55960. 55961.]
[55962. 55963. 55964. 55965. 55966. 55967.]
[55968. 55969. 55970. 55971. 55972. 55973.]
[55974. 55975. 55976. 55977. 55978. 55979.]
[55980. 55981. 55982. 55983. 55984. 55985.]
[55986. 55987. 55988. 55989. 55990. 55991.]
[55992. 55993. 55994. 55995. 55996. 55997.]
[55998. 55999. 56000. 56001. 56002. 56003.]
[56004. 56005. 56006. 56007. 56008. 56009.]
[56010. 56011. 56012. 56013. 56014. 56015.]
[56016. 56017. 56018. 56019. 56020. 56021.]
[56022. 56023. 56024. 56025. 56026. 56027.]
[56028. 56029. 56030. 56031. 56032. 56033.]
[56034. 56035. 56036. 56037. 560

[13146. 13147. 13148. 13149. 13150. 13151.]
[13152. 13153. 13154. 13155. 13156. 13157.]
[13158. 13159. 13160. 13161. 13162. 13163.]
[13164. 13165. 13166. 13167. 13168. 13169.]
[13170. 13171. 13172. 13173. 13174. 13175.]
[13176. 13177. 13178. 13179. 13180. 13181.]
[13182. 13183. 13184. 13185. 13186. 13187.]
[13188. 13189. 13190. 13191. 13192. 13193.]
[13194. 13195. 13196. 13197. 13198. 13199.]
[13500. 13501. 13502. 13503. 13504. 13505.]
[13506. 13507. 13508. 13509. 13510. 13511.]
[13512. 13513. 13514. 13515. 13516. 13517.]
[13518. 13519. 13520. 13521. 13522. 13523.]
[13524. 13525. 13526. 13527. 13528. 13529.]
[13530. 13531. 13532. 13533. 13534. 13535.]
[13536. 13537. 13538. 13539. 13540. 13541.]
[13542. 13543. 13544. 13545. 13546. 13547.]
[13548. 13549. 13550. 13551. 13552. 13553.]
[13554. 13555. 13556. 13557. 13558. 13559.]
[13560. 13561. 13562. 13563. 13564. 13565.]
[13566. 13567. 13568. 13569. 13570. 13571.]
[13572. 13573. 13574. 13575. 13576. 13577.]
[13578. 13579. 13580. 13581. 135

[24498. 24499. 24500. 24501. 24502. 24503.]
[24504. 24505. 24506. 24507. 24508. 24509.]
[24510. 24511. 24512. 24513. 24514. 24515.]
[24516. 24517. 24518. 24519. 24520. 24521.]
[24522. 24523. 24524. 24525. 24526. 24527.]
[24528. 24529. 24530. 24531. 24532. 24533.]
[24534. 24535. 24536. 24537. 24538. 24539.]
[24540. 24541. 24542. 24543. 24544. 24545.]
[24546. 24547. 24548. 24549. 24550. 24551.]
[24552. 24553. 24554. 24555. 24556. 24557.]
[24558. 24559. 24560. 24561. 24562. 24563.]
[24564. 24565. 24566. 24567. 24568. 24569.]
[24570. 24571. 24572. 24573. 24574. 24575.]
[24576. 24577. 24578. 24579. 24580. 24581.]
[24582. 24583. 24584. 24585. 24586. 24587.]
[24588. 24589. 24590. 24591. 24592. 24593.]
[24594. 24595. 24596. 24597. 24598. 24599.]
[24900. 24901. 24902. 24903. 24904. 24905.]
[24906. 24907. 24908. 24909. 24910. 24911.]
[24912. 24913. 24914. 24915. 24916. 24917.]
[24918. 24919. 24920. 24921. 24922. 24923.]
[24924. 24925. 24926. 24927. 24928. 24929.]
[24930. 24931. 24932. 24933. 249

[35892. 35893. 35894. 35895. 35896. 35897.]
[35898. 35899. 35900. 35901. 35902. 35903.]
[35904. 35905. 35906. 35907. 35908. 35909.]
[35910. 35911. 35912. 35913. 35914. 35915.]
[35916. 35917. 35918. 35919. 35920. 35921.]
[35922. 35923. 35924. 35925. 35926. 35927.]
[35928. 35929. 35930. 35931. 35932. 35933.]
[35934. 35935. 35936. 35937. 35938. 35939.]
[35940. 35941. 35942. 35943. 35944. 35945.]
[35946. 35947. 35948. 35949. 35950. 35951.]
[35952. 35953. 35954. 35955. 35956. 35957.]
[35958. 35959. 35960. 35961. 35962. 35963.]
[35964. 35965. 35966. 35967. 35968. 35969.]
[35970. 35971. 35972. 35973. 35974. 35975.]
[35976. 35977. 35978. 35979. 35980. 35981.]
[35982. 35983. 35984. 35985. 35986. 35987.]
[35988. 35989. 35990. 35991. 35992. 35993.]
[35994. 35995. 35996. 35997. 35998. 35999.]
[36300. 36301. 36302. 36303. 36304. 36305.]
[36306. 36307. 36308. 36309. 36310. 36311.]
[36312. 36313. 36314. 36315. 36316. 36317.]
[36318. 36319. 36320. 36321. 36322. 36323.]
[36324. 36325. 36326. 36327. 363

[47154. 47155. 47156. 47157. 47158. 47159.]
[47160. 47161. 47162. 47163. 47164. 47165.]
[47166. 47167. 47168. 47169. 47170. 47171.]
[47172. 47173. 47174. 47175. 47176. 47177.]
[47178. 47179. 47180. 47181. 47182. 47183.]
[47184. 47185. 47186. 47187. 47188. 47189.]
[47190. 47191. 47192. 47193. 47194. 47195.]
[47196. 47197. 47198. 47199. 47200. 47201.]
[47202. 47203. 47204. 47205. 47206. 47207.]
[47208. 47209. 47210. 47211. 47212. 47213.]
[47214. 47215. 47216. 47217. 47218. 47219.]
[47220. 47221. 47222. 47223. 47224. 47225.]
[47226. 47227. 47228. 47229. 47230. 47231.]
[47232. 47233. 47234. 47235. 47236. 47237.]
[47238. 47239. 47240. 47241. 47242. 47243.]
[47244. 47245. 47246. 47247. 47248. 47249.]
[47250. 47251. 47252. 47253. 47254. 47255.]
[47256. 47257. 47258. 47259. 47260. 47261.]
[47262. 47263. 47264. 47265. 47266. 47267.]
[47268. 47269. 47270. 47271. 47272. 47273.]
[47274. 47275. 47276. 47277. 47278. 47279.]
[47280. 47281. 47282. 47283. 47284. 47285.]
[47286. 47287. 47288. 47289. 472

[58194. 58195. 58196. 58197. 58198. 58199.]
[58500. 58501. 58502. 58503. 58504. 58505.]
[58506. 58507. 58508. 58509. 58510. 58511.]
[58512. 58513. 58514. 58515. 58516. 58517.]
[58518. 58519. 58520. 58521. 58522. 58523.]
[58524. 58525. 58526. 58527. 58528. 58529.]
[58530. 58531. 58532. 58533. 58534. 58535.]
[58536. 58537. 58538. 58539. 58540. 58541.]
[58542. 58543. 58544. 58545. 58546. 58547.]
[58548. 58549. 58550. 58551. 58552. 58553.]
[58554. 58555. 58556. 58557. 58558. 58559.]
[58560. 58561. 58562. 58563. 58564. 58565.]
[58566. 58567. 58568. 58569. 58570. 58571.]
[58572. 58573. 58574. 58575. 58576. 58577.]
[58578. 58579. 58580. 58581. 58582. 58583.]
[58584. 58585. 58586. 58587. 58588. 58589.]
[58590. 58591. 58592. 58593. 58594. 58595.]
[58596. 58597. 58598. 58599. 58600. 58601.]
[58602. 58603. 58604. 58605. 58606. 58607.]
[58608. 58609. 58610. 58611. 58612. 58613.]
[58614. 58615. 58616. 58617. 58618. 58619.]
[58620. 58621. 58622. 58623. 58624. 58625.]
[58626. 58627. 58628. 58629. 586

In [4]:
# UTILS

def normalize(x, y):
        
    x_mean = np.mean(x, axis=0, dtype=np.float32)
    x_std = np.std(x, axis=0, dtype=np.float32)
        
    y_mean = np.mean(y, axis=0, dtype=np.float32)
    y_std = np.std(y, axis=0, dtype=np.float32)
        
    norm_x = ( x - x_mean ) / (x_std+1e-10)
    norm_y = ( y - y_mean ) / (y_std+1e-10)
        
    return norm_x, norm_y

def init_params(model):
    for p in model.parameters():
        if (p.dim() > 1):
            nn.init.xavier_normal_(p)
        else:
            nn.init.uniform_(p, 0.1, 0.2)

In [5]:
import torch
import torch.utils.data as td

class meanLoader(td.Dataset):
    def __init__(self, data_x, data_y, normalize = None):
        
        
        self.data_x = data_x
        self.data_y = data_y
        self.normalize = normalize
        
            
    def __len__(self):
        return len(self.data_x)
           
    def __getitem__(self, index):
        
        if self.normalize is not None:
            self.data_x, self.data_y = self.normalize(self.data_x, self.data_y)
        
        x = self.data_x[index]
        y = self.data_y[index]
            
        x = torch.from_numpy(x).float().cuda()
        y = torch.from_numpy(y).float().cuda()
                
        return x, y
    

In [6]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn.modules.utils import _single, _pair, _triple

def get_mean_model(mean_model_type, mean_hidden_dim, num_of_input, num_of_output):
    if mean_model_type == 'mlp':
        return mean_mlp_net(mean_hidden_dim=mean_hidden_dim, num_of_input=num_of_input, num_of_output=num_of_output)
    
class mean_mlp_net(nn.Module):
    def __init__(self, mean_hidden_dim, num_of_input, num_of_output):
        super(mean_mlp_net, self).__init__()
        self.fc1 = nn.Linear(num_of_input, mean_hidden_dim)
        self.fc2 = nn.Linear(mean_hidden_dim, mean_hidden_dim)
        self.fc3 = nn.Linear(mean_hidden_dim, mean_hidden_dim)
        self.fc4 = nn.Linear(mean_hidden_dim, num_of_output)
        
    def forward(self, x):
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=F.relu(self.fc3(x))
        #x=F.dropout(x, training=self.training)
        x=self.fc4(x)
        return x    
    
def get_gan_model(gan_type, gan_hidden_dim, noise_d, num_of_input, num_of_output):
    if gan_type == 'gan1':
        return gen1(noise_d+num_of_input, gan_hidden_dim, num_of_output), dis1(num_of_output+num_of_input, gan_hidden_dim)
    
class gen1(nn.Module):
    def __init__(self, d_noise_num_of_input, gan_hidden_dim, num_of_output):
        super(gen1, self).__init__()
        self.fc1 = nn.Linear(d_noise_num_of_input, gan_hidden_dim)
        self.fc2 = nn.Linear(gan_hidden_dim, gan_hidden_dim)
        self.fc3 = nn.Linear(gan_hidden_dim, num_of_output)
        
    def forward(self, noise, x):
        gen_input = torch.cat((noise, x), axis=1)
        r = F.relu(self.fc1(gen_input))
        r = F.relu(self.fc2(r))
        r = self.fc3(r)
        
        return r
        
class dis1(nn.Module):
    def __init__(self, num_of_output, gan_hidden_dim):
        super(dis1, self).__init__()
        self.fc1 = nn.Linear(num_of_output, gan_hidden_dim)
        self.fc2 = nn.Linear(gan_hidden_dim, gan_hidden_dim)
        self.fc3 = nn.Linear(gan_hidden_dim, 1)
        
    def forward(self, y, x):
        dis_input = torch.cat((y, x), axis=1)
        r = F.relu(self.fc1(dis_input))
        r = F.relu(self.fc2(r))
        r = torch.sigmoid(self.fc3(r))
        
        return r

In [7]:
# utils

def init_normal(m):
        if type(m) == nn.Linear:
            nn.init.kaiming_normal_(m.weight)

In [8]:
def get_mean_trainer(train_iterator, val_iterator, mean_model, mean_model_type, optimizer, exp_lr_scheduler):
    if mean_model_type == 'mlp':
        return mean_trainer(train_iterator, val_iterator, mean_model, optimizer, exp_lr_scheduler)
    
def get_gan_trainer(noise_trainer_iterator, noise_val_iterator, gan_model_type, generator, discriminator, optimizer_g, optimizer_d, exp_gan_lr_scheduler, noise_d):
    if gan_model_type == 'gan1':
        return gan1_trainer(noise_trainer_iterator, noise_val_iterator, generator, discriminator, optimizer_g, optimizer_d, exp_gan_lr_scheduler, noise_d)

class mean_GenericTrainer:
    """
    Base class for mean trainer
    """
    def __init__(self, train_iterator, val_iterator, mean_model, optimizer, exp_lr_scheduler):
        self.train_iterator = train_iterator
        self.val_iterator = val_iterator
        self.model = mean_model
                
        self.optimizer = optimizer
        self.current_lr = None
        
        self.exp_lr_scheduler = exp_lr_scheduler
        
        self.loss = {'train_loss':[], 'val_loss':[]}
    
class mean_trainer(mean_GenericTrainer):
    def __init__(self, train_iterator, val_iterator, mean_model, optimizer, exp_lr_scheduler):
        super().__init__(train_iterator, val_iterator, mean_model, optimizer, exp_lr_scheduler)
            
        self.best_loss = np.inf
        self.best_model = None
            
    def train(self):
        
        train_loss_list = []
        train_loss = 0
        train_num = 0
        self.model.train()
        
        for i, data in enumerate(self.train_iterator):
            data_x, data_y = data
            data_x, data_y = data_x.cuda(), data_y.cuda()
            
            mini_batch_size = len(data_x)
            output = self.model(data_x)
            loss = F.mse_loss(output, data_y, reduction='mean')
            
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            train_loss += loss
            train_num += mini_batch_size
            
        train_loss /= train_num
        self.loss['train_loss'].append(train_loss)
            
        for param_group in self.optimizer.param_groups:
            self.current_lr = param_group['lr']
        self.exp_lr_scheduler.step()
        
        return train_loss
            
    def evaluate(self):

        val_loss_list = []
        val_loss = 0.0
        val_num = 0
        self.model.eval()
        
        true_arr = []
        pred_arr = []
        
        with torch.no_grad():
            for i, data in enumerate(self.val_iterator):
                data_x, data_y = data
                data_x, data_y = data_x.cuda(), data_y.cuda()
                
                mini_batch_size = len(data_x)
                
                output = self.model(data_x)
                loss = F.mse_loss(output, data_y, reduction='mean')
                
                val_loss += loss
                val_num += len(data_y)
                
                true_arr += (data_y.data.cpu().numpy()).tolist()
                pred_arr += (output.data.cpu().numpy()).tolist()
                
            val_loss /= val_num
            self.loss['val_loss'].append(val_loss)
            
        val_r2 = r2_score(true_arr, pred_arr)
        
        if val_loss < self.best_loss:
            self.best_model = self.model
            self.best_loss = val_loss
            self.best_mean = pred_arr
            
                
        return val_loss, val_r2      

class gan_GenericTrainer:
    def __init__(self, noise_trainer_iterator, noise_val_iterator, generator, discriminator, optimizer_g, optimizer_d, exp_gan_lr_scheduler, noise_d):
        self.train_iterator = noise_trainer_iterator
        self.val_iterator = noise_val_iterator
        
        self.G = generator
        self.D = discriminator
        
        self.optimizer_G = optimizer_g
        self.optimizer_D = optimizer_d
        
        self.exp_gan_lr_scheduler = exp_gan_lr_scheduler
        self.current_d_lr = None
        
        self.noise_d = noise_d
        
        self.prob = {'p_real_train':[], 'p_fake_train':[], 'p_real_val':[], 'p_fake_val':[]}
        
class gan1_trainer(gan_GenericTrainer):
    def __init__(self, noise_trainer_iterator, noise_val_iterator, generator, discriminator, optimizer_g, optimizer_d, exp_gan_lr_scheduler, noise_d):
        super().__init__(noise_trainer_iterator, noise_val_iterator, generator, discriminator, optimizer_g, optimizer_d, exp_gan_lr_scheduler, noise_d)
        
    def train(self):
        
        p_real_list = []
        p_fake_list = []
        
        self.G.train()
        self.D.train()
        
        for i, data in enumerate(self.train_iterator):
            
            data_x, data_y = data
            data_x, data_y = data_x.cuda(), data_y.cuda()
            mini_batch_size = len(data_x)
            
            # GENERATOR
            
            z = sample_z(mini_batch_size, self.noise_d)
                        
            gen_y = self.G(z, data_x)
            p_fake = self.D(gen_y, data_x)
            
            g_loss = -1*torch.log(p_fake).mean()
            
            self.optimizer_G.zero_grad()
            g_loss.backward(retain_graph=True)
            self.optimizer_G.step()
            
            # DISCRIMINATOR
            
            
            # Loss for real data
            p_real = self.D(data_y, data_x)
            d_real_loss = -1*torch.log(p_real).mean()
            
            # Loss for fake data
            p_fake = self.D(gen_y, data_x)
            d_fake_loss = -1*torch.log(1.-p_fake).mean()            
    
            d_loss = (d_real_loss + d_fake_loss)/2
            
            self.optimizer_D.zero_grad()
            d_loss.backward(retain_graph=True)
            self.optimizer_D.step()
            
        # print(p_real)
        self.prob['p_real_train'].append(p_real)
        # print(self.prob['p_real_train'])
        self.prob['p_fake_train'].append(p_fake)
            
        for param_group in self.optimizer_D.param_groups:
            self.current_d_lr = param_group['lr']
        self.exp_gan_lr_scheduler.step()
                    
    def evaluate(self):
        
        p_real_list = []
        p_fake_list = []
        
        p_real, p_fake = 0., 0.
        batch_num = 0
        
        self.G.eval()
        self.D.eval()
        
        for i, data in enumerate(self.val_iterator):
            
            data_x, data_y = data
            data_x, data_y = data_x.cuda(), data_y.cuda()
            
            mini_batch_size = len(data_x)
            
            z = sample_z(mini_batch_size, self.noise_d)
            
            with torch.autograd.no_grad():
                p_real += torch.sum(self.D(data_y, data_x)/mini_batch_size)
                
                gen_y = self.G(z, data_x)
                
                p_fake += torch.sum(self.D(gen_y, data_x)/mini_batch_size)
                
            batch_num += 1
            
        p_real /= batch_num
        p_fake /= batch_num
        
        self.prob['p_real_val'].append(p_real)
        self.prob['p_fake_val'].append(p_fake)
        
        return p_real, p_fake

In [9]:
import torch
import torch.utils.data as td

class SemiLoader(td.Dataset):
    def __init__(self, data_x, data_y, normalize = None):
        
# N,P를 고려 한 모델(1,0으로 넣은 것)
                   
        self.data_x = data_x
        self.data_y = data_y
        self.normalize = normalize
        
    def __len__(self):
        return len(self.data_x)
           
    def __getitem__(self, index):
            
        if self.normalize is not None:
            self.data_x, self.data_y = self.normalize(self.data_x, self.data_y)
            
        x = self.data_x[index]
        y = self.data_y[index]
            
        x = torch.from_numpy(x).float().cuda()
        y = torch.from_numpy(y).float().cuda()
                
        return x, y