# Most of the experiments are adapted from the following post:
- https://medium.com/daangn/pytorch-multi-gpu-%ED%95%99%EC%8A%B5-%EC%A0%9C%EB%8C%80%EB%A1%9C-%ED%95%98%EA%B8%B0-27270617936b

In [None]:
import os, sys, time, pickle
import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from collections import OrderedDict
import utils
import tqdm

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchvision import transforms

import dataset

In [None]:
# Hard-coded, machine-specific dataset directories (one per data category).
# Adjust these paths when running the notebook on another machine.
single_datadir = '/data/jehyuk/TEP/single_states/'
trans_datadir = '/data/jehyuk/TEP/transient_processes/'
attack_datadir = '/data/jehyuk/TEP/attacks/'

In [None]:
# Alphabetically sorted file names (names only, no directory prefix) found in
# each dataset directory.
single_datalist = os.listdir(single_datadir)
single_datalist.sort()
trans_datalist = os.listdir(trans_datadir)
trans_datalist.sort()
attack_datalist = os.listdir(attack_datadir)
attack_datalist.sort()

In [None]:
# Notebook display only: echo the sorted attack file names for inspection.
attack_datalist

In [None]:
# Column schema used when reading the CSV files (passed as `names=` to
# pd.read_csv). Each column name is written once, in its logical group, and
# the full header is assembled from the groups at the end.

# Measured variables (41 columns).
measured_columns = [
    'A Feed', 'D Feed', 'E Feed', 'A + C Feed',
    'Recycle flow', 'Reactor feed', 'Reactor pressure', 'Reactor level', 'Reactor temperature',
    'Purge rate', 'Seperator temperature', 'Seperator level', 'Seperator pressure', 'Seperator underflow',
    'Stripper level', 'Stripper pressure', 'Stripper underflow', 'Stripper temperature', 'Stripper Steam flow_meas',
    'Compressor work', 'Reactor cooling water temperature', 'Condensor cooling water temperature',
    'Feed %A', 'Feed %B', 'Feed %C', 'Feed %D', 'Feed %E', 'Feed %F',
    'Purge %A', 'Purge %B', 'Purge %C', 'Purge %D', 'Purge %E', 'Purge %F', 'Purge %G', 'Purge %H',
    'Product %D', 'Product %E', 'Product %F', 'Product %G', 'Product %H',
]

# manipulated vars == control vars (12 columns)
manipulated_columns = [
    'D feed flow', 'E feed flow', 'A feed flow', 'C feed flow',
    'Compressor recycle valve', 'Purge flow', 'Separator liquid flow', 'Stripper liquid product flow',
    'Stripper Steam flow_mv', 'Reactor cooling water flow', 'Condenser cooling water flow', 'Reactor Agitator speed',
]

attack_columns = ['is_mv_attack', 'is_meas_attack', 'is_sp_attack']
general_columns = ['Time', 'state', 'product_rate', 'hourly_cost']

# Full header in file order: 'Time', the measured variables, the manipulated
# variables, the attack columns, then the remaining general columns.
total_columns = (
    general_columns[:1]
    + measured_columns
    + manipulated_columns
    + attack_columns
    + general_columns[1:]
)

# mode0,1,2,3,4,5,6: Different data distribution

In [None]:
# Mode-0 subsets, selected by substring match on the file name: training files
# come from the single-state listing, test files from the attack listing
# (one list per attack type 21-24).
trn_mode0_datalist = sorted(fname for fname in single_datalist if 'mode_0' in fname)
tst_mode0_type21_datalist = sorted(fname for fname in attack_datalist if 'mode_0_type_21' in fname)
tst_mode0_type22_datalist = sorted(fname for fname in attack_datalist if 'mode_0_type_22' in fname)
tst_mode0_type23_datalist = sorted(fname for fname in attack_datalist if 'mode_0_type_23' in fname)
tst_mode0_type24_datalist = sorted(fname for fname in attack_datalist if 'mode_0_type_24' in fname)

In [None]:
# trn_ex1 = pd.read_csv(os.path.join(single_datadir, trn_mode0_datalist[0]), names = total_columns)
# trn_ex2 = pd.read_csv(os.path.join(single_datadir, trn_mode0_datalist[2]), names = total_columns)
# tst_ex1 = pd.read_csv(os.path.join(attack_datadir, tst_mode0_type21_datalist[0]), names = total_columns)
# tst_ex2 = pd.read_csv(os.path.join(attack_datadir, tst_mode0_type21_datalist[2]), names = total_columns)

In [None]:
# plt.plot(trn_ex1['Reactor temperature'][71250:71500])
# plt.plot(trn_ex2['Reactor temperature'][71250:71500])
# plt.plot(tst_ex1['Reactor temperature'][71250:71500])
# plt.plot(tst_ex2['Reactor temperature'][71250:71500])
# # plt.plot(110+tst_ex['is_meas_attack'][71250:71500])

In [None]:
def get_statdict(df, used_cols=None):
    """
    Collect per-column summary statistics of a DataFrame.

    :param df: pandas DataFrame containing every column listed in ``used_cols``
    :param used_cols: column names to summarize. ``None`` (default) means
        ``measured_columns + manipulated_columns``, resolved when the function
        is called rather than when it is defined, so the function does not
        depend on the order in which notebook cells were executed.
    :return: dict mapping each column name to a dict with the keys
        'mean', 'std', 'min' and 'max' (as computed by pandas)
    """
    if used_cols is None:
        used_cols = measured_columns + manipulated_columns
    stat_dict = {}
    for col in used_cols:
        series = df[col]
        stat_dict[col] = {
            'mean': series.mean(),
            'std': series.std(),
            'min': series.min(),
            'max': series.max(),
        }
    return stat_dict

In [None]:
def change_bool_to_float(df, used_cols):
    """
    Select columns of a DataFrame and return them as a float array.

    :param df: pandas DataFrame (left unmodified)
    :param used_cols: list of column names to keep
    :return: NumPy array with one column per entry of ``used_cols``, cast to float
    """
    selected = df[used_cols]
    return selected.astype(float).values

In [None]:
def normalize(df, stat_dict, used_cols, normalize_method = 'none'):
    """
    Return the selected columns of ``df`` as an array, optionally normalized.

    :param df: pandas DataFrame (left unmodified)
    :param stat_dict: dict mapping column name -> dict with 'mean', 'std',
        'min', 'max' (the structure produced by ``get_statdict``). Only read
        when a normalization is actually requested.
    :param used_cols: list of column names to keep, in output order
    :param normalize_method: 'none' (no scaling), 'z' (subtract mean, divide by
        std) or 'minmax' (subtract min, divide by max - min)
    :return: NumPy array of the selected (and possibly normalized) columns
    :raises ValueError: if ``normalize_method`` is not one of the three names
        above. Previously an unknown name silently returned unnormalized data.
    """
    if normalize_method not in ('none', 'z', 'minmax'):
        raise ValueError(
            "normalize_method must be 'none', 'z' or 'minmax', got {!r}".format(normalize_method))

    # Small constant keeps the division finite for constant columns.
    eps = 1e-7
    df_new = df[used_cols].copy()
    if normalize_method == 'none':
        return df_new.values

    for col in used_cols:
        stats = stat_dict[col]
        if normalize_method == 'z':
            df_new[col] = (df[col] - stats['mean']) / (stats['std'] + eps)
        else:  # 'minmax'
            df_new[col] = (df[col] - stats['min']) / (stats['max'] - stats['min'] + eps)
    return df_new.values

In [None]:
def make_twlist(tw=500, datadir=single_datadir, datalist=trn_mode0_datalist):
    """
    Read CSV files and cut each of them into sliding time windows.

    For a file with ``n`` rows, one window of ``tw`` consecutive rows is
    produced for every end index ``j`` in ``range(tw, n)`` (rows ``j-tw`` up to
    but excluding ``j``), so the last row of a file never appears in a window.
    Progress is printed once per file.

    :param tw: time window size (number of rows per window)
    :param datadir: directory containing the CSV files
    :param datalist: file names (relative to ``datadir``) to read
    :return: tuple ``(windows, df_total)``: the list of window DataFrames of
        all files, and the row-wise concatenation of all files
    """
    windows = []
    frames = []
    for file_idx, fname in enumerate(datalist):
        tic = time.time()
        frame = pd.read_csv(os.path.join(datadir, fname), names=total_columns)

        n_rows = frame.shape[0]
        file_windows = []
        for end in range(tw, n_rows):
            file_windows.append(frame[end - tw: end])

        frames.append(frame)
        windows.extend(file_windows)
        message = '{}/{} fname: {}, len(twlist): {}, elapsed_time: {:.2f}s'.format(
            file_idx + 1, len(datalist), fname[:-4], len(file_windows), time.time() - tic)
        print(message)
    df_total = pd.concat(frames, axis=0)

    return windows, df_total

In [None]:
# Build the training windows from the mode-0 single-state files, then compute
# the per-column statistics over the concatenated training data.
trn_twlist, df_total = make_twlist(datadir=single_datadir,datalist=trn_mode0_datalist)
stat_dict = get_statdict(df_total)


In [None]:
# The concatenated frame was only needed for stat_dict; it is not used again
# in the visible part of this notebook.
del df_total

In [None]:
# Test windows from the mode-0 / attack-type-21 files; the concatenated frame
# (second return value) is discarded.
tst_twlist, _ = make_twlist(datadir=attack_datadir, datalist=tst_mode0_type21_datalist)

In [None]:
# Per-sample pipeline: project preprocessing (min-max method, using the
# training statistics in stat_dict) followed by the project's ToTensor step.
# NOTE(review): dataset.Preprocessing / dataset.ToTensor live in the local
# `dataset` module, not shown here — confirm their exact behavior there.
transform_op = transforms.Compose([dataset.Preprocessing(used_cols = measured_columns + manipulated_columns, 
                                                        stat_dict = stat_dict, normalize_method='minmax'),
                                   dataset.ToTensor()])

In [None]:
# Wrap the window lists in the project's dataset class; train and test share
# the same transform (and therefore the same training statistics).
trn_dset = dataset.TWDataset(trn_twlist, transform=transform_op)
tst_dset = dataset.TWDataset(tst_twlist, transform=transform_op)

In [None]:
# np_list = []
# for i, df in tqdm.tqdm_notebook(enumerate(tst_twlist), total=len(tst_twlist)):
#     np_df = df.values
#     np_list.append(np_df)
#     if (i+1) % 10000 == 0:
#         time.sleep(0.1)

In [None]:
class EncoderConv1d(nn.Module):
    """
    Stack of 1-D convolutions used as the encoder of the auto-encoder.

    Every convolution except the last one is bias-free and followed by
    BatchNorm1d and the hidden activation; the number of channels starts at
    ``n_ch`` and doubles after each of these layers. The last convolution has
    a bias and no batch norm / activation. With ``use_fc=True`` the flattened
    convolution output is mapped to ``fc_size`` features by a linear layer.
    """
    def __init__(self, tw, n_vars, kernels, strides, paddings, n_ch,
                 actfn_name='relu', use_fc=False, fc_size=20):
        """
        :param tw: time window size (length of the input along the last axis)
        :param n_vars: number of input channels (variables)
        :param kernels: kernel size per convolution layer
        :param strides: stride per convolution layer
        :param paddings: padding per convolution layer
        :param n_ch: number of output channels of the first convolution
        :param actfn_name: name passed to ``utils.get_actfn`` for the hidden activation
        :param use_fc: append a linear layer on the flattened output
        :param fc_size: output size of that linear layer
        :raises ValueError: if ``kernels`` is empty
        """
        super(EncoderConv1d, self).__init__()
        if len(kernels) == 0:
            raise ValueError('kernels must contain at least one layer')
        self.tw = tw
        self.n_vars = n_vars
        self.kernels = kernels
        self.strides = strides
        self.paddings = paddings
        self.use_fc = use_fc
        self.n_ch = n_ch

        k, s, p = kernels, strides, paddings
        length = tw
        act_fn = utils.get_actfn(actfn_name)

        layers = OrderedDict()
        in_ch, out_ch = n_vars, self.n_ch
        # Index of the final layer is computed explicitly: relying on the loop
        # variable after the loop fails when there is only one layer (the loop
        # body never runs and the variable is never bound).
        last = len(k) - 1
        for i in range(last):
            layers[f'conv{i+1}'] = nn.Conv1d(in_ch, out_ch, k[i], s[i], p[i], bias=False)
            layers[f'bn{i+1}'] = nn.BatchNorm1d(num_features=out_ch)
            layers[f'act{i+1}'] = act_fn
            length = utils.conv1d_output_size(length, k[i], s[i], p[i])
            in_ch, out_ch = out_ch, out_ch*2
        # No batchnorm after the last convolution -> it keeps its bias.
        layers[f'conv{last+1}'] = nn.Conv1d(in_ch, out_ch, k[last], s[last], p[last], bias=True)
        length = utils.conv1d_output_size(length, k[last], s[last], p[last])
        self.layers = nn.Sequential(layers)
        if self.use_fc:
            # `length` is the remaining size of the time axis after all convolutions.
            self.fc = nn.Linear(length * out_ch, fc_size)

    def forward(self, x):
        """
        :param x: input tensor for the convolution stack
            (assumes layout (batch, n_vars, tw) — TODO confirm against the caller)
        :return: convolution output, or the linear-layer output when ``use_fc`` is set
        """
        out = self.layers(x)
        if self.use_fc:
            out = torch.flatten(out, start_dim=1)
            out = self.fc(out)
        return out

In [None]:
class DecoderConv1d(nn.Module):
    """
    Stack of 1-D transposed convolutions used as the decoder of the auto-encoder.

    The kernel / stride / padding lists are given in encoder order and are
    applied in reverse. The channel count starts at ``n_ch * 2**(n_layers-1)``
    and is halved after every layer except the last, which maps to ``n_vars``
    channels. All but the last layer are bias-free and followed by BatchNorm1d
    and the hidden activation. The output activation is applied in ``forward``
    when ``utils.get_actfn`` returned something other than ``None``.
    """
    def __init__(self, tw, n_vars, kernels, strides, paddings, n_ch,
                 embed_length, actfn_name='relu', outactfn_name='sigmoid', use_fc=False, fc_size=20):
        """
        :param tw: time window size (stored only; not used to size any layer)
        :param n_vars: number of output channels (variables)
        :param kernels: kernel size list, in encoder order
        :param strides: stride list, in encoder order
        :param paddings: padding list, in encoder order
        :param n_ch: base channel count (same value as given to the encoder)
        :param embed_length: length of the time axis of the embedding
        :param actfn_name: name passed to ``utils.get_actfn`` for the hidden activation
        :param outactfn_name: name passed to ``utils.get_actfn`` for the output activation
        :param use_fc: first map an ``fc_size`` vector back to the embedding shape
        :param fc_size: size of that input vector
        :raises ValueError: if ``kernels`` is empty
        """
        super(DecoderConv1d, self).__init__()
        if len(kernels) == 0:
            raise ValueError('kernels must contain at least one layer')
        self.tw = tw
        self.n_vars = n_vars
        self.kernels = list(reversed(kernels))
        self.strides = list(reversed(strides))
        self.paddings = list(reversed(paddings))
        self.n_ch = n_ch
        self.embed_length = embed_length
        self.use_fc = use_fc
        self.fc_size = fc_size

        k, s, p = self.kernels, self.strides, self.paddings

        in_ch = self.n_ch * (2**(len(k)-1))
        # Channel count of the embedding; forward() needs it to reshape the fc output.
        self.embed_ch = in_ch
        out_ch = in_ch // 2
        if self.use_fc:
            self.fc = nn.Linear(self.fc_size, embed_length * in_ch)
        act_fn = utils.get_actfn(actfn_name)
        outact_fn = utils.get_actfn(outactfn_name)
        layers = OrderedDict()
        # Explicit last index: using the loop variable after the loop fails
        # for a single-layer configuration, where the loop body never runs.
        last = len(k) - 1
        for i in range(last):
            layers[f'convtr{i+1}'] = nn.ConvTranspose1d(in_ch, out_ch, k[i], s[i], p[i], bias=False)
            layers[f'bn{i+1}'] = nn.BatchNorm1d(num_features=out_ch)
            layers[f'act{i+1}'] = act_fn
            in_ch, out_ch = out_ch, out_ch // 2
        layers[f'convtr{last+1}'] = nn.ConvTranspose1d(in_ch, n_vars, k[last], s[last], p[last], bias=True)
        self.layers = nn.Sequential(layers)
        self.outact_fn = outact_fn

    def forward(self, z):
        """
        :param z: embedding; an ``fc_size`` vector per sample when ``use_fc``
            is set, otherwise the input of the transposed-convolution stack
        :return: reconstruction after the (optional) output activation
        """
        if self.use_fc:
            z = self.fc(z)
            z = z.view(-1, self.embed_ch, self.embed_length)
        out = self.layers(z)
        if self.outact_fn is not None:
            out = self.outact_fn(out)
        return out



In [None]:
# enc = EncoderConv1d(tw=500, n_vars = len(measured_columns+manipulated_columns), 
#                     kernels=[5,5,5,5], strides=[1,1,1,1], paddings=[0,0,0,0], n_ch=64)
# dec = DecoderConv1d(tw=500, n_vars=len(measured_columns+manipulated_columns),
#                     kernels=[5,5,5,5], strides=[1,1,1,1], paddings=[0,0,0,0], n_ch=64,
#                     embed_length=5, actfn_name='relu', outactfn_name='sigmoid', use_fc=False, fc_size=20)

In [None]:
# for i, batch in enumerate(loader):
#     batch['x'] = torch.transpose(batch['x'], 1, 2)
#     batch['general'] = torch.transpose(batch['general'], 1, 2)
#     batch['attack'] = torch.transpose(batch['attack'], 1, 2)
#     if i ==0:
#         break

In [None]:
# z = enc(batch['x'])
# x_hat = dec(z)
# print(z.size(), batch['x'].size(), x_hat.size())

In [None]:
class arguments:
    """Plain container of the experiment settings used by the cells below."""

    def __init__(self):
        # Data locations and the mode to use.
        self.trn_datadir, self.tst_datadir = (
            '/data/jehyuk/TEP/single_states/',
            '/data/jehyuk/TEP/attacks/',
        )
        self.mode = 0

        # Model shape: window size, per-layer kernel/stride/padding, base channels.
        self.tw = 500
        self.k, self.s, self.p = [5, 5, 5, 5], [1, 1, 1, 1], [0, 0, 0, 0]
        self.n_ch = 128
        self.actfn_name, self.outactfn_name = 'relu', 'sigmoid'
        self.use_fc, self.fc_size = False, 20

        # Preprocessing and optimisation settings.
        self.normalize = 'minmax'
        self.lr = 0.0002
        self.n_workers = 5

        # Devices.
        self.device_num = 0
        self.whole_devices = [0, 1, 2, 3]

        # Training schedule.
        self.weight_decay = 0.01
        self.n_epoch = 1
        self.trn_batch_size, self.tst_batch_size = 1024, 128

In [None]:
# Settings object read by the model / trainer cells below.
args = arguments()

In [None]:
class ConvAE(nn.Module):
    """1-D convolutional auto-encoder: an EncoderConv1d followed by a DecoderConv1d."""

    def __init__(self, tw, used_cols, k, s, p, n_ch, use_fc=False, fc_size=20, 
                 actfn_name='relu', outactfn_name='sigmoid'):
        """
        :param tw: time window size
        :param used_cols: column list which is used in modeling (its length is the channel count)
        :param k: kernel size list
        :param s: stride size list
        :param p: padding size list
        :param n_ch: channel bunch (base channel count of the encoder)
        :param use_fc: forwarded to encoder and decoder. default=False
        :param fc_size: forwarded to encoder and decoder. default=20
        :param actfn_name: activation function in hidden layer. default='relu'
        :param outactfn_name: activation function in output layer. default='sigmoid'
        """
        super(ConvAE, self).__init__()
        self.tw, self.used_cols = tw, used_cols
        self.k, self.s, self.p = k, s, p
        self.n_ch = n_ch

        # Size of the time axis once every encoder convolution has been applied;
        # the decoder needs it to know the shape of the embedding.
        latent_len = tw
        for layer_idx, kernel in enumerate(k):
            latent_len = utils.conv1d_output_size(latent_len, kernel, s[layer_idx], p[layer_idx])

        n_vars = len(used_cols)
        self.enc = EncoderConv1d(tw, n_vars, k, s, p, n_ch,
                                 actfn_name=actfn_name, use_fc=use_fc, fc_size=fc_size)
        self.dec = DecoderConv1d(tw, n_vars, k, s, p, n_ch, latent_len,
                                 actfn_name=actfn_name, outactfn_name=outactfn_name,
                                 use_fc=use_fc, fc_size=fc_size)

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction."""
        return self.dec(self.enc(x))

In [None]:
# Main CUDA device, selected by args.device_num.
device = torch.device(f'cuda:{args.device_num}')

In [None]:
# Auto-encoder over all measured + manipulated columns; use_fc, fc_size and
# the activation names are left at ConvAE's defaults (the values in args are
# not passed here).
model = ConvAE(args.tw, measured_columns+manipulated_columns, args.k, args.s, args.p, args.n_ch)

# 1. Use DataParallel

- 모델을 각 GPU에 복사해서 할당 (replicate)
- 매 iteration마다 batch를 GPU수만큼 분배 (scatter)
- 각 GPU에서 forward를 진행 (parallel_apply)
- 각 GPU에서 모델이 출력을 내보내면 이 출력들을 하나의 GPU로 collect (gather)

In [None]:
# trn_loader = DataLoader(trn_dset, batch_size = args.trn_batch_size, num_workers=5, shuffle=True)
# tst_loader = DataLoader(tst_dset, batch_size = args.tst_batch_size, num_workers=0, shuffle=False)

In [None]:
# class AETrainer1:
#     def __init__(self, model, lr, weight_decay, device, whole_devices, check_every=10):
#         """
#         :param model: AE model to train
#         :param trn_loader: train data loader
#         :param tst_loader: test data loader
#         :param lr: learning rate
#         :param weight_decay: weight decay
#         :param device: torch.device('cuda:{}') where the main model is positioned
#         :param whole_device: list of cuda numbers where the parallel operation is done
#         :param check_every: logging frequency
#         """
#         self.model = model.to(device)
#         self.lr = lr
#         self.weight_decay = weight_decay
#         self.device = device
#         self.whole_devices = whole_devices
#         self.check_every = check_every
#         if device.type == 'cuda' and torch.cuda.device_count()>1 and len(whole_devices) > 1:
#             print("Using {} gpus for training AE".format(len(whole_devices)))
#             self.model = nn.DataParallel(self.model, device_ids=whole_devices)
#         self.optim = optim.Adam(self.model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
#         self.criterion = nn.MSELoss(reduction='sum')
#         self.check_every = check_every
        
#     def partial_fit(self, data_loader, epoch, train=True):
#         """
#         :param data_loader: torch.utils.data.DataLoader for iteration
#         :param train: boolean value whether it is train or test
#         """
#         str_code = 'train' if train else 'test'
#         data_iter = tqdm.tqdm_notebook(enumerate(data_loader), 
#                                        desc='epoch_{}:{}'.format(str_code, epoch), 
#                                        total=len(data_loader))
#         avg_loss = 0
#         for iter_num, batch in data_iter:
#             if train:
#                 self.model.train()
#             else:
#                 self.model.eval()
#             self.optim.zero_grad()
#             x=torch.transpose(batch['x'], 1, 2).to(self.device)
#             general = batch['general'].to(self.device)
#             attack = batch['attack'].to(self.device)
#             x_hat = self.model(x)
#             loss = self.criterion(x_hat, x)
#             avg_loss += loss.item()
#             loss.backward()
#             self.optim.step()
#             post_fix = {
#                 "epoch": epoch, 
#                 "iter": iter_num+1, 
#                 "avg_loss": avg_loss / (data_loader.batch_size*(iter_num+1)),
#                 "batch_loss": loss.item() / data_loader.batch_size
#             }
#             if (iter_num+1) % self.check_every == 0:
#                 data_iter.write(str(post_fix))
    
#     def train(self, data_loader, n_epoch, train=True):
#         for epoch in tqdm.tqdm_notebook(range(1, n_epoch+1)):
#             self.partial_fit(data_loader, epoch, train)
            
#     def test(self, data_loader, epoch=0, train=False):
#         self.partial_fit(data_loader, epoch, train)
    
#     def save_model(self, save_dir):
#         if device.type == 'cuda' and torch.cuda.device_count()>1 and len(self.whole_devices)>1:
#             torch.save(self.model.module.encoder.state_dict(), os.path.join(save_dir, 'encoder.pkl'))
#             torch.save(self.model.module.decoder.state_dict(), os.path.join(save_dir, 'decoder.pkl'))
#         else:
#             torch.save(self.model.encoder.state_dict(), os.path.join(save_dir, 'encoder.pkl'))
#             torch.save(self.model.decoder.state_dict(), os.path.join(save_dir, 'decoder.pkl'))
       
#     def load_model(self, save_dir):
#         if device.type == 'cuda' and torch.cuda.device_count()>1 and len(self.whole_devices)>1:
#             self.model.module.encoder.load_state_dict(torch.load(os.path.join(save_dir, 'encoder.pkl')))
#             self.model.module.decoder.load_state_dict(torch.load(os.path.join(save_dir, 'decoder.pkl')))
#         else:
#             self.model.encoder.load_state_dict(torch.load(os.path.join(save_dir, 'encoder.pkl')))
#             self.model.decoder.load_state_dict(torch.load(os.path.join(save_dir, 'decoder.pkl')))

In [None]:
# trainer = AETrainer1(model, args.lr, args.weight_decay, device, args.whole_devices, check_every=200)

In [None]:
# trainer.train(trn_loader, args.n_epoch)

## Result

# 2. Custom DataParallel 사용하기

- Pytorch의 nn.DataParallel을 사용하면 메모리 불균형이 생긴다.
    - 이는 위에서 언급한 gather 프로세스 때문에 이런 현상이 발생한다.
    - 왜 그래야 할까?
        - 모델은 DataParallel을 통해 병렬 연산이 가능해졌지만, loss function은 병렬화되지 않고 그대로이기 때문
        - 이러한 이유로 하나의 GPU로 출력값을 모아서 loss를 계산함


- Loss function도 병렬로 연산하도록 만들면 메모리 불균형 문제를 해결하는 것이 가능
- How?
    - Pytorch에서는 loss function도 하나의 모듈 --> loss function을 각 GPU에 replicate(*)
    - 정답 tensor를 각 GPU로 scatter
    - loss를 계산하기 위한 모델에서의 출력값, loss function, 정답 tensor 모두 각 GPU에 존재 --> 각 GPU에서 loss 계산 가능 --> loss backward
    
    
- How to replicate the loss function module?
    - target을 각 GPU로 scatter하고, 각 GPU에 replicate된 모듈에서 계산을 한다.
    - 계산된 output과 Reduce.apply를 통해 각 GPU에서 backward연산을 하도록 한다.
    
    
- Custom DataParallel class인 DataParallelModel을 사용한다.
    - nn.DataParallel은 기본적으로 하나의 GPU로 출력을 모은다.
    - Pytorch-Encoding package에서 parallel.py파일을 가져와서 학습코드에서 import하도록 하면 된다.
        - Source code: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/parallel.py
        - Issues: https://github.com/zhanghang1989/PyTorch-Encoding/issues/54

In [None]:
import parallel

In [None]:
class AETrainer2:
    """
    Trainer for an auto-encoder that parallelizes both the model and the loss
    over several GPUs (``parallel.DataParallelModel`` /
    ``parallel.DataParallelCriterion``) when more than one GPU is requested
    and available; otherwise it runs on the single given device.
    """
    def __init__(self, model, lr, weight_decay, device, whole_devices, check_every=10, max_iters=50):
        """
        :param model: AE model to train; for save_model / load_model it must
            expose its sub-modules as ``enc`` / ``dec`` (as ConvAE does) or as
            ``encoder`` / ``decoder``
        :param lr: learning rate
        :param weight_decay: weight decay
        :param device: torch.device where the main model is positioned
        :param whole_devices: list of cuda numbers where the parallel operation is done
        :param check_every: logging frequency (in iterations)
        :param max_iters: stop each pass over a loader after this many
            iterations; ``None`` runs the whole loader. default=50
        """
        torch.cuda.empty_cache()
        self.model = model.to(device)
        self.lr = lr
        self.weight_decay = weight_decay
        self.device = device
        self.whole_devices = whole_devices
        self.check_every = check_every
        self.max_iters = max_iters
        self.criterion = nn.MSELoss(reduction='sum')
        self.optim = optim.Adam(self.model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
        # Remember whether the model got wrapped, so save/load can unwrap it
        # without consulting any global state.
        self.is_parallel = (device.type == 'cuda'
                            and torch.cuda.device_count() > 1
                            and len(whole_devices) > 1)
        if self.is_parallel:
            print("Using {} gpus for training AE".format(len(whole_devices)))
            # Restrict both wrappers to the requested GPUs.
            self.model = parallel.DataParallelModel(self.model, device_ids=whole_devices).to(device)
            self.criterion = parallel.DataParallelCriterion(self.criterion, device_ids=whole_devices).to(device)

    def partial_fit(self, data_loader, epoch, train=True):
        """
        Run one (possibly truncated) pass over ``data_loader``.

        Weights are only updated when ``train`` is true; in test mode the model
        is put in eval mode and no gradients are computed.

        :param data_loader: torch.utils.data.DataLoader for iteration; every
            batch must be a dict with the key 'x'
        :param epoch: epoch number, used for logging only
        :param train: boolean value whether it is train or test
        :return: summed loss divided by the number of samples processed
            (0.0 if the loader yielded nothing)
        """
        str_code = 'train' if train else 'test'
        n_batches = len(data_loader)
        if self.max_iters is not None:
            n_batches = min(n_batches, self.max_iters)
        data_iter = tqdm.tqdm_notebook(enumerate(data_loader), 
                                       desc='epoch_{}:{}'.format(str_code, epoch), 
                                       total=n_batches)
        if train:
            self.model.train()
        else:
            self.model.eval()

        total_loss = 0.0
        n_seen = 0
        with torch.set_grad_enabled(train):
            for iter_num, batch in data_iter:
                x = torch.transpose(batch['x'], 1, 2).to(self.device)
                if train:
                    self.optim.zero_grad()
                x_hat = self.model(x)
                loss = self.criterion(x_hat, x)
                if train:
                    loss.backward()
                    self.optim.step()

                batch_loss = loss.item()
                batch_n = x.size(0)
                total_loss += batch_loss
                # Count real samples so a smaller final batch does not skew the average.
                n_seen += batch_n
                if (iter_num+1) % self.check_every == 0:
                    post_fix = {
                        "epoch": epoch, 
                        "iter": iter_num+1, 
                        "avg_loss": total_loss / n_seen,
                        "batch_loss": batch_loss / batch_n
                    }
                    data_iter.write(str(post_fix))

                if self.max_iters is not None and (iter_num+1) >= self.max_iters:
                    break
        return total_loss / n_seen if n_seen else 0.0

    def train(self, data_loader, n_epoch, train=True):
        """Call partial_fit on ``data_loader`` for epochs 1..n_epoch."""
        for epoch in tqdm.tqdm_notebook(range(1, n_epoch+1)):
            self.partial_fit(data_loader, epoch, train)
            
    def test(self, data_loader, epoch=0, train=False):
        """Run a single pass over ``data_loader`` (no weight updates by default)."""
        self.partial_fit(data_loader, epoch, train)

    def _enc_dec(self):
        """Return the (encoder, decoder) sub-modules of the unwrapped model."""
        core = self.model.module if self.is_parallel else self.model
        enc = core.enc if hasattr(core, 'enc') else core.encoder
        dec = core.dec if hasattr(core, 'dec') else core.decoder
        return enc, dec
    
    def save_model(self, save_dir):
        """Write the encoder / decoder state dicts to ``save_dir`` as 'encoder.pkl' / 'decoder.pkl'."""
        enc, dec = self._enc_dec()
        torch.save(enc.state_dict(), os.path.join(save_dir, 'encoder.pkl'))
        torch.save(dec.state_dict(), os.path.join(save_dir, 'decoder.pkl'))
       
    def load_model(self, save_dir):
        """Load the state dicts written by save_model from ``save_dir`` onto this trainer's device."""
        enc, dec = self._enc_dec()
        enc.load_state_dict(torch.load(os.path.join(save_dir, 'encoder.pkl'), map_location=self.device))
        dec.load_state_dict(torch.load(os.path.join(save_dir, 'decoder.pkl'), map_location=self.device))

In [None]:
# trainer = AETrainer2(model, args.lr, args.weight_decay, device, args.whole_devices, check_every=10)

In [None]:
# trn_loader = DataLoader(trn_dset, batch_size = args.trn_batch_size, num_workers=5, shuffle=True)
# tst_loader = DataLoader(tst_dset, batch_size = args.tst_batch_size, num_workers=0, shuffle=False)

In [None]:
# trainer.train(trn_loader, args.n_epoch, train=True)

# 3. Pytorch에서 Distributed Package 사용하기

- 예제 코드: https://github.com/pytorch/examples/blob/master/imagenet/main.py
- 분산 학습을 사용해서 Multi-GPU 학습 하기(예제 코드를 기반으로)
    - main.py를 실행하면, line 109의 mp.spawn(main_worker, ...)가 실행됨
        - main_worker에서는 4개의 GPU를 한개의 node로 간주하고, world_size를 결정(line 106)
        - mp.spawn은 4개의 GPU에서 따로 따로 main_worker을 multi processing으로 실행(line 109)
        
    - dist.init_process_group (line 129)
        - 각 GPU마다 분산학습을 위한 초기화
        - Pytorch의 document에서는 multi-GPU학습시, backend로 nccl을 사용하라고 함
        - init_method에서 FREEPORT에 사용 가능한 port를 적으면 됨
        - 초기화를 실행하고 나면, 분산학습이 가능해짐
        
    - torch.nn.parallel.DistributedDataParallel(DDP) (line 151, 156)
        - DDP는 module level에서 data parallelism을 실행함
        - torch.distributed 패키지 내의 communication collectives를 사용하여, 
          gradient, params, buffers를 synchronize한다.
        - Within process, across process가 둘 다 가능하다.
            - Within: nn.DataParallel과 유사
            - Across: DDP는 필요한 params들에 대해, forward에서 synchronization하고, backward에서 grad synchronization을 수행한다.
            
    - DistributedSampler (line 210)
        - dset을 DistributedSampler로 감싸주고, DataLoader에서 sampler에 인자로 넣어줌
        - DDP와 함께 사용해야 함
        - 작동 원리
            - 각 sampler는 전체 데이터를 GPU 갯수로 나눈 부분 데이터에서만 데이터를 샘플링
            - 부분데이터를 만들기 위해 전체 dset의 idx list를 무작위로 섞고, 
              해당 idx list를 쪼개서 GPU sampler에 할당
            - Epoch마다 idx list가 계속 달라지므로, train_sampler.set_epoch(epoch)를 매 epoch마다 학습전에 실행

## Does not work on this server!!

In [None]:
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.utils.data.distributed

In [None]:
class arguments:
    """Experiment settings, extended with the options for distributed training."""

    def __init__(self):
        # Data locations and the mode to use.
        self.trn_datadir, self.tst_datadir = (
            '/data/jehyuk/TEP/single_states/',
            '/data/jehyuk/TEP/attacks/',
        )
        self.mode = 0

        # Model shape: window size, per-layer kernel/stride/padding, base channels.
        self.tw = 500
        self.k, self.s, self.p = [5, 5, 5, 5], [1, 1, 1, 1], [0, 0, 0, 0]
        self.n_ch = 128
        self.actfn_name, self.outactfn_name = 'relu', 'sigmoid'
        self.use_fc, self.fc_size = False, 20

        # Preprocessing and optimisation settings.
        self.normalize = 'minmax'
        self.lr = 0.0002
        self.n_workers = 5

        # Devices.
        self.device_num = 0
        self.whole_devices = [0, 1, 2, 3]

        # Training schedule.
        self.weight_decay = 0.01
        self.n_epoch = 1
        self.trn_batch_size, self.tst_batch_size = 1024, 128

        # torch.distributed settings (rendezvous URL, backend, world size, rank).
        self.dist_url = 'tcp://147.46.178.58:54321'
        self.dist_backend = 'nccl'
        self.world_size = 1
        self.distributed = True
        self.rank = 0
        self.num_workers = 8

In [None]:
# Re-create the settings object from the class defined in the cell above
# (it adds the distributed-training options).
args = arguments()

In [None]:
# Fresh, untrained auto-encoder for the distributed experiment.
model = ConvAE(args.tw, measured_columns+manipulated_columns, args.k, args.s, args.p, args.n_ch)

In [None]:
# Main CUDA device, selected by args.device_num.
device = torch.device(f'cuda:{args.device_num}')

In [None]:
# Rebuild the preprocessing pipeline (min-max method with the training
# statistics) and the train / test datasets for the distributed experiment.
transform_op = transforms.Compose([dataset.Preprocessing(used_cols = measured_columns + manipulated_columns, 
                                                         stat_dict = stat_dict, normalize_method='minmax'),
                                   dataset.ToTensor()])
trn_dset = dataset.TWDataset(trn_twlist, transform=transform_op)
tst_dset = dataset.TWDataset(tst_twlist, transform=transform_op)

In [None]:
class AETrainer3:
    """
    Auto-encoder trainer based on torch.distributed: ``train`` spawns one
    worker process per entry of ``args.whole_devices``; every worker joins the
    process group, wraps the model in DistributedDataParallel and trains on
    its own shard of the dataset (DistributedSampler).

    NOTE(review): ``mp.spawn`` pickles this object and the dataset for every
    worker, so the class must be importable in the worker processes; classes
    defined only inside a notebook cell usually are not — confirm before use.
    """
    def __init__(self, model, args):
        """
        :param model: AE model to train
        :param args: settings object; read here: whole_devices, world_size
            (number of nodes), lr, weight_decay and, optionally, check_every
            (default 10) and max_iters (default 50)
        """
        self.args = args
        self.model = model
        self.ngpus_per_node = len(args.whole_devices)
        # Total number of processes = GPUs per node * number of nodes.
        self.world_size = self.ngpus_per_node * self.args.world_size
        self.criterion = nn.MSELoss()
        self.opt = optim.Adam(self.model.parameters(), lr = args.lr, weight_decay = args.weight_decay)
        self.check_every = getattr(args, 'check_every', 10)
        self.max_iters = getattr(args, 'max_iters', 50)
        
    def train_per_worker(self, device_num, trn_dset, ngpus_per_node, args):
        """
        Entry point of one worker process (called by ``mp.spawn``).

        :param device_num: index of this process within the node (0..nprocs-1)
        :param trn_dset: training dataset
        :param ngpus_per_node: number of worker processes / GPUs on this node
        :param args: settings object; ``args.world_size`` is the number of
            nodes and ``args.rank`` the rank of this node
        """
        gpu_id = args.whole_devices[device_num]
        device = torch.device(f'cuda:{gpu_id}')
        torch.cuda.set_device(device)
        # Global rank / world size are derived locally so that args is never
        # mutated (repeated train() calls would otherwise compound the values).
        rank = args.rank * ngpus_per_node + device_num
        world_size = ngpus_per_node * args.world_size
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 
                                world_size=world_size, rank=rank)
        try:
            self.model = torch.nn.parallel.DistributedDataParallel(self.model.to(device),
                                                                   device_ids=[gpu_id])
            # Rebuild the optimizer on the parameters that now live on this GPU.
            self.opt = optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
            n_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
            print('# of params of model is : {}'.format(n_params))

            print('>>Preparing Data...')
            # The configured batch size is the total over all GPUs of the node.
            batch_size = max(1, args.trn_batch_size // ngpus_per_node)
            trn_sampler = torch.utils.data.distributed.DistributedSampler(trn_dset)
            trn_loader = DataLoader(trn_dset, batch_size = batch_size, 
                                    shuffle=False,  # shuffling is done by the sampler
                                    num_workers = args.num_workers, 
                                    sampler = trn_sampler)

            for epoch in range(1, args.n_epoch+1):
                # Makes the sampler draw a different permutation every epoch.
                trn_sampler.set_epoch(epoch)
                self.train_epoch(trn_loader, train=True, device=device, epoch=epoch)
        finally:
            dist.destroy_process_group()
        
    def train_epoch(self, data_loader, train, device, epoch=0):
        """
        Run one (possibly truncated) pass over ``data_loader``.

        :param data_loader: DataLoader whose batches are dicts with the key 'x'
        :param train: update the weights when true; otherwise evaluate only
        :param device: device the batches are moved to
        :param epoch: epoch number, used for logging only
        :return: mean of the per-batch losses (0.0 if nothing was processed)
        """
        str_code = 'train' if train else 'test'
        n_batches = len(data_loader)
        if self.max_iters is not None:
            n_batches = min(n_batches, self.max_iters)
        # Plain tqdm: this runs in a spawned worker process, where no notebook
        # front end is attached.
        data_iter = tqdm.tqdm(enumerate(data_loader), 
                              desc='epoch_{}:{}'.format(str_code, epoch), 
                              total = n_batches)
        if train:
            self.model.train()
        else:
            self.model.eval()

        loss_sum = 0.0
        n_done = 0
        with torch.set_grad_enabled(train):
            for iter_num, batch in data_iter:
                x = torch.transpose(batch['x'], 1, 2).to(device)
                if train:
                    self.opt.zero_grad()
                x_hat = self.model(x)
                loss = self.criterion(x_hat, x)
                if train:
                    loss.backward()
                    self.opt.step()

                batch_loss = loss.item()
                loss_sum += batch_loss
                n_done = iter_num + 1
                if n_done % self.check_every == 0:
                    # The criterion already averages over the batch, so the
                    # running value is a plain mean of the batch losses.
                    post_fix = {
                        "epoch": epoch, 
                        "iter": n_done, 
                        "avg_loss": loss_sum / n_done,
                        "batch_loss": batch_loss
                    }
                    data_iter.write(str(post_fix))

                if self.max_iters is not None and n_done >= self.max_iters:
                    break
        return loss_sum / n_done if n_done else 0.0
                
    def train(self, trn_dset, train=True):
        """
        Spawn one worker per GPU and train on ``trn_dset``.

        :param trn_dset: training dataset handed to every worker
        :param train: unused; kept for interface compatibility
        """
        mp.spawn(self.train_per_worker, nprocs = self.ngpus_per_node,
                 args=(trn_dset, self.ngpus_per_node, self.args))
        
            
#     def test(self, data_loader, epoch=0, train=False):
#         self.partial_fit(data_loader, epoch, train)
    
#     def save_model(self, save_dir):
#         if device.type == 'cuda' and torch.cuda.device_count()>1 and len(self.whole_devices)>1:
#             torch.save(self.model.module.encoder.state_dict(), os.path.join(save_dir, 'encoder.pkl'))
#             torch.save(self.model.module.decoder.state_dict(), os.path.join(save_dir, 'decoder.pkl'))
#         else:
#             torch.save(self.model.encoder.state_dict(), os.path.join(save_dir, 'encoder.pkl'))
#             torch.save(self.model.decoder.state_dict(), os.path.join(save_dir, 'decoder.pkl'))
       
#     def load_model(self, save_dir):
#         if device.type == 'cuda' and torch.cuda.device_count()>1 and len(self.whole_devices)>1:
#             self.model.module.encoder.load_state_dict(torch.load(os.path.join(save_dir, 'encoder.pkl')))
#             self.model.module.decoder.load_state_dict(torch.load(os.path.join(save_dir, 'decoder.pkl')))
#         else:
#             self.model.encoder.load_state_dict(torch.load(os.path.join(save_dir, 'encoder.pkl')))
#             self.model.decoder.load_state_dict(torch.load(os.path.join(save_dir, 'decoder.pkl')))

In [None]:
# Distributed trainer built from the model and settings defined above.
trainer = AETrainer3(model, args)

In [None]:
# Launch distributed training on the training dataset.
trainer.train(trn_dset)