# LSTMによる非線形歪補償
時系列データを考慮できるLSTMによる補償

In [1]:
#import
import sys
import time
import datetime

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms

sys.path.append('../')
from pyopt.util import save_pickle, load_pickle

# 1. Preprocessing

## 1.1 データの整形

In [2]:
def data_shaping(input_signal, signal, max_tap, tap):
    x = np.zeros((len(input_signal) - (max_tap - 1), tap, 2), dtype=float)
    y = np.zeros((len(input_signal) - (max_tap - 1), 2), dtype=float)
    for i, j in enumerate(np.arange(int((max_tap - 1) / 2), len(input_signal) - int((max_tap - 1) / 2))):
        x[i, :, 0] = signal[j - int((tap - 1) / 2): j + int((tap - 1) / 2) + 1].real
        x[i, :, 1] = signal[j - int((tap - 1) / 2): j + int((tap - 1) / 2) + 1].imag
        y[i, 0] = input_signal[j].real
        y[i, 1] = input_signal[j].imag
    return x, y

In [51]:
def data_shaping2(input_signal, signal, max_tap, tap):
    x, y = data_shaping(input_signal, signal, max_tap, tap)
    x_flap = x[:, int((tap - 1) / 2):, :][:, ::-1, :]
    x = np.concatenate([x[:, :int((tap + 1) / 2), :], x_flap], axis=2)
    return x, y

In [52]:
#動作確認
tap = 9
max_tap = 11

df_dir = '../data/input/prbs.csv'
df = pd.read_csv(df_dir, index_col=0)  # dataframe読み込み
condition = (df['N']==13) & (df['itr']==1) & (df['form']=='RZ16QAM') & (df['n']==32) & (df['equalize']==False) & (df['baudrate']==28) & (df['PdBm']==1)
sgnl = load_pickle(df[condition].iloc[0]['data_path'])  # dataframeから条件と合う行を取得し,pickleの保存先(data_path)にアクセス
lc = sgnl.linear_compensation(500, sgnl.signal['x_500'])
x, y = data_shaping2(sgnl.signal['x_0'][16::32], lc[16::32], max_tap, tap)  # ANNに入力できるようにデータを整形

print('x size: ', x.shape)
print('y size: ', y.shape)

x size:  (2038, 5, 4)
y size:  (2038, 2)


## 1.2 平均,標準偏差の計算

In [53]:
mean = np.mean(x)
std = np.std(x)

print('mean: ', mean)
print('std: ', std)

mean:  1122.4889875685737
std:  52164.373260725646


# 2. Dataset定義

In [54]:
class Dataset(data.Dataset):
    def __init__(self, x, y, mean, std):
        self.x, self.y, self.mean, self.std = x, y, mean, std
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, index):
        x = self.x[index]
        y = self.y[index]
        
        x = (x - self.mean) / self.std
        y = (y - self.mean) / self.std
        return torch.Tensor(x), torch.Tensor(y)

In [55]:
#動作確認
train_dataset = Dataset(x=x, y=y, mean=mean, std=std)

index = 0
x_normalized, y_normalized = train_dataset.__getitem__(index)
x_array = x_normalized.detach().numpy()

print('mean: ', np.mean(x_array))
print('std: ', np.std(x_array))
print(x_normalized)
print(y_normalized)

mean:  0.41994858
std:  1.0200968
tensor([[ 0.0977, -0.5941,  1.8054,  0.5386],
        [ 0.1692, -0.6423, -0.6453,  1.8167],
        [ 0.5423, -1.8595,  1.7882,  0.5346],
        [ 1.4128, -0.2311, -0.6323, -0.2189],
        [ 1.7284,  0.5301,  1.7284,  0.5301]])
tensor([1.3295, 1.3295])


In [56]:
batch_size = 100

train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

dataloaders_dict = {'train': train_dataloader}

# 3. Model定義

In [57]:
class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(LSTM, self).__init__()
        self.rnn = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x, hidden0=None):
        x, (hidden, cell) = self.rnn(x, hidden0)
        x = self.fc(x[:, -1, :])  # int((x.shape[1] - 1) / 2)
        return x

In [58]:
#動作確認
hidden_dim = 100
model = LSTM(input_dim=4, hidden_dim=hidden_dim, output_dim=2)
for x, y in train_dataloader:
    output = model(x)
    print(output[:6])
    break

tensor([[ 0.0847, -0.0189],
        [ 0.0995, -0.0231],
        [ 0.0972, -0.0327],
        [ 0.0919, -0.0341],
        [ 0.1034, -0.0234],
        [ 0.0946, -0.0539]], grad_fn=<SliceBackward>)


# 4. Train定義

In [59]:
def evm_score(y_true, y_pred):
    if y_true.ndim == 2:
        y_true = y_true[:, 0] + 1j * y_true[:, 1]
        y_pred = y_pred[:, 0] + 1j * y_pred[:, 1]
    tmp = 0
    for i in range(len(y_pred)):
        tmp += abs(y_pred[i] - y_true[i]) ** 2 / abs(y_true[i]) ** 2
    evm = torch.sqrt(tmp / len(y_pred)) * 100
    return evm

In [60]:
def train_model(model, dataloaders_dict, criterion, optimizer, epochs):
    for epoch in range(epochs):
        start_time = time.time()
        
        for phase in dataloaders_dict.keys():
            if phase == 'train':
                model.train()
            else:
                model.eval()
            
            epoch_loss = 0.0
            epoch_evms = 0.0
            
            for x, y in dataloaders_dict[phase]:
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(x)
                    loss = criterion(outputs, y)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    
                    epoch_loss += loss.item() * x.size(0)
                    epoch_evms = (evm_score(y, outputs) / 100) ** 2 * x.size(0)
            
            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_evm = torch.sqrt(epoch_evms / len(dataloaders_dict[phase].dataset)) * 100
            
            duration = str(datetime.timedelta(seconds=time.time() - start_time))[:7]
            print('{} | Epoch: {}/{} | {} Loss: {:.4} | EVM: {:.4}'.format(duration, epoch + 1, epochs, phase, epoch_loss, epoch_evm))
    return model

In [61]:
#動作確認
epochs = 5
lr = 0.001

criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)

train_model(model=model, dataloaders_dict=dataloaders_dict, criterion=criterion, optimizer=optimizer, epochs=epochs);

0:00:00 | Epoch: 1/5 | train Loss: 0.908 | EVM: 12.29
0:00:00 | Epoch: 2/5 | train Loss: 0.6559 | EVM: 9.45
0:00:00 | Epoch: 3/5 | train Loss: 0.226 | EVM: 3.254
0:00:00 | Epoch: 4/5 | train Loss: 0.01303 | EVM: 1.621
0:00:00 | Epoch: 5/5 | train Loss: 0.005479 | EVM: 1.164


# 5. 実行

In [63]:
tap = 201
max_tap = 501
batch_size = 100
hidden_dim = 100
epochs = 500
lr = 0.001

In [64]:
df_dir = '../data/input/'
df0 = pd.read_csv(df_dir+'prbs.csv', index_col=0)

condition0 = (df0['N']==13) & (df0['itr']==1) & (df0['form']=='RZ16QAM') & (df0['n']==32) & (df0['equalize']==False) & (df0['baudrate']==28) & (df0['PdBm']==1)
sgnl0 = load_pickle(df0[condition0].iloc[0]['data_path'])
lc0 = sgnl0.linear_compensation(2500, sgnl0.signal['x_2500'])
x0, y0 = data_shaping2(sgnl0.signal['x_0'][16::32], lc0[16::32], max_tap, tap)

condition1 = (df0['N']==17) & (df0['itr']==1) & (df0['form']=='RZ16QAM') & (df0['n']==32) & (df0['equalize']==False) & (df0['baudrate']==28) & (df0['PdBm']==1)
sgnl1 = load_pickle(df0[condition1].iloc[0]['data_path'])
lc1 = sgnl1.linear_compensation(2500, sgnl1.signal['x_2500'])
x1, y1 = data_shaping2(sgnl1.signal['x_0'][16::32], lc1[16::32], max_tap, tap)

mean = np.mean(x0)
std = np.std(x0)

train_dataset = Dataset(x=x0, y=y0, mean=mean, std=std)
val_dataset = Dataset(x=x1, y=y1, mean=mean, std=std)

train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
dataloaders_dict = {'train': train_dataloader, 'val': val_dataloader}

model = LSTM(input_dim=4, hidden_dim=hidden_dim, output_dim=2)
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)

model = train_model(model=model, dataloaders_dict=dataloaders_dict, criterion=criterion, optimizer=optimizer, epochs=epochs)

0:00:02 | Epoch: 1/500 | train Loss: 0.9473 | EVM: 16.31
0:00:25 | Epoch: 1/500 | val Loss: 0.8707 | EVM: 4.214
0:00:02 | Epoch: 2/500 | train Loss: 0.7505 | EVM: 14.03
0:00:25 | Epoch: 2/500 | val Loss: 0.6066 | EVM: 3.448
0:00:02 | Epoch: 3/500 | train Loss: 0.4529 | EVM: 9.721
0:00:25 | Epoch: 3/500 | val Loss: 0.2724 | EVM: 2.258
0:00:02 | Epoch: 4/500 | train Loss: 0.1552 | EVM: 4.546
0:00:26 | Epoch: 4/500 | val Loss: 0.05851 | EVM: 1.044
0:00:02 | Epoch: 5/500 | train Loss: 0.02692 | EVM: 2.932
0:00:26 | Epoch: 5/500 | val Loss: 0.0197 | EVM: 0.8369
0:00:02 | Epoch: 6/500 | train Loss: 0.01524 | EVM: 3.299
0:00:27 | Epoch: 6/500 | val Loss: 0.01706 | EVM: 0.7958
0:00:02 | Epoch: 7/500 | train Loss: 0.01396 | EVM: 1.759
0:00:27 | Epoch: 7/500 | val Loss: 0.01701 | EVM: 0.7768
0:00:02 | Epoch: 8/500 | train Loss: 0.01338 | EVM: 3.089
0:00:27 | Epoch: 8/500 | val Loss: 0.01639 | EVM: 0.7735
0:00:02 | Epoch: 9/500 | train Loss: 0.01349 | EVM: 2.931
0:00:27 | Epoch: 9/500 | val Loss:

0:00:30 | Epoch: 71/500 | val Loss: 0.01691 | EVM: 0.7891
0:00:02 | Epoch: 72/500 | train Loss: 0.01226 | EVM: 3.312
0:00:31 | Epoch: 72/500 | val Loss: 0.01636 | EVM: 0.7971
0:00:02 | Epoch: 73/500 | train Loss: 0.01232 | EVM: 3.076
0:00:30 | Epoch: 73/500 | val Loss: 0.0167 | EVM: 0.8028
0:00:02 | Epoch: 74/500 | train Loss: 0.01218 | EVM: 2.579
0:00:31 | Epoch: 74/500 | val Loss: 0.01626 | EVM: 0.8035
0:00:02 | Epoch: 75/500 | train Loss: 0.01258 | EVM: 1.978
0:00:30 | Epoch: 75/500 | val Loss: 0.01747 | EVM: 0.795
0:00:02 | Epoch: 76/500 | train Loss: 0.0122 | EVM: 2.609
0:00:30 | Epoch: 76/500 | val Loss: 0.01638 | EVM: 0.7737
0:00:02 | Epoch: 77/500 | train Loss: 0.01218 | EVM: 2.65
0:00:30 | Epoch: 77/500 | val Loss: 0.01726 | EVM: 0.8157
0:00:02 | Epoch: 78/500 | train Loss: 0.0122 | EVM: 2.581
0:00:30 | Epoch: 78/500 | val Loss: 0.01655 | EVM: 0.8037
0:00:02 | Epoch: 79/500 | train Loss: 0.01265 | EVM: 2.243
0:00:30 | Epoch: 79/500 | val Loss: 0.0166 | EVM: 0.7923
0:00:02 | Ep

0:00:27 | Epoch: 141/500 | val Loss: 0.01848 | EVM: 0.8152
0:00:02 | Epoch: 142/500 | train Loss: 0.009951 | EVM: 2.713
0:00:27 | Epoch: 142/500 | val Loss: 0.01988 | EVM: 0.7919
0:00:02 | Epoch: 143/500 | train Loss: 0.009765 | EVM: 2.621
0:00:27 | Epoch: 143/500 | val Loss: 0.01851 | EVM: 0.7966
0:00:02 | Epoch: 144/500 | train Loss: 0.00961 | EVM: 2.704
0:00:27 | Epoch: 144/500 | val Loss: 0.01906 | EVM: 0.8302
0:00:02 | Epoch: 145/500 | train Loss: 0.009782 | EVM: 2.616
0:00:26 | Epoch: 145/500 | val Loss: 0.01826 | EVM: 0.7815
0:00:02 | Epoch: 146/500 | train Loss: 0.009278 | EVM: 2.276
0:00:29 | Epoch: 146/500 | val Loss: 0.01894 | EVM: 0.7782
0:00:02 | Epoch: 147/500 | train Loss: 0.009189 | EVM: 2.94
0:00:27 | Epoch: 147/500 | val Loss: 0.01919 | EVM: 0.8108
0:00:02 | Epoch: 148/500 | train Loss: 0.009318 | EVM: 2.439
0:00:29 | Epoch: 148/500 | val Loss: 0.01885 | EVM: 0.7718
0:00:02 | Epoch: 149/500 | train Loss: 0.009461 | EVM: 1.769
0:00:29 | Epoch: 149/500 | val Loss: 0.019

0:00:37 | Epoch: 210/500 | val Loss: 0.0263 | EVM: 0.9183
0:00:03 | Epoch: 211/500 | train Loss: 0.002504 | EVM: 1.05
0:00:37 | Epoch: 211/500 | val Loss: 0.02637 | EVM: 0.9343
0:00:03 | Epoch: 212/500 | train Loss: 0.002557 | EVM: 1.297
0:00:37 | Epoch: 212/500 | val Loss: 0.02644 | EVM: 0.9359
0:00:03 | Epoch: 213/500 | train Loss: 0.002407 | EVM: 1.291
0:00:37 | Epoch: 213/500 | val Loss: 0.02622 | EVM: 0.9404
0:00:03 | Epoch: 214/500 | train Loss: 0.002392 | EVM: 1.452
0:00:38 | Epoch: 214/500 | val Loss: 0.02694 | EVM: 0.9475
0:00:03 | Epoch: 215/500 | train Loss: 0.00241 | EVM: 1.341
0:00:37 | Epoch: 215/500 | val Loss: 0.0267 | EVM: 0.9103
0:00:03 | Epoch: 216/500 | train Loss: 0.002263 | EVM: 1.395
0:00:37 | Epoch: 216/500 | val Loss: 0.02691 | EVM: 0.9353
0:00:03 | Epoch: 217/500 | train Loss: 0.002029 | EVM: 1.006
0:00:30 | Epoch: 217/500 | val Loss: 0.02744 | EVM: 0.9606
0:00:02 | Epoch: 218/500 | train Loss: 0.002175 | EVM: 1.432
0:00:27 | Epoch: 218/500 | val Loss: 0.02684

0:00:27 | Epoch: 278/500 | val Loss: 0.03072 | EVM: 1.008
0:00:02 | Epoch: 279/500 | train Loss: 0.0002063 | EVM: 0.3822
0:00:27 | Epoch: 279/500 | val Loss: 0.03053 | EVM: 0.9862
0:00:02 | Epoch: 280/500 | train Loss: 0.0001902 | EVM: 0.3222
0:00:27 | Epoch: 280/500 | val Loss: 0.03066 | EVM: 0.9916
0:00:02 | Epoch: 281/500 | train Loss: 0.0001785 | EVM: 0.3424
0:00:27 | Epoch: 281/500 | val Loss: 0.03087 | EVM: 0.9949
0:00:02 | Epoch: 282/500 | train Loss: 0.0001799 | EVM: 0.3518
0:00:27 | Epoch: 282/500 | val Loss: 0.03081 | EVM: 0.9943
0:00:02 | Epoch: 283/500 | train Loss: 0.0001693 | EVM: 0.4039
0:00:26 | Epoch: 283/500 | val Loss: 0.03083 | EVM: 0.99
0:00:02 | Epoch: 284/500 | train Loss: 0.0001893 | EVM: 0.4397
0:00:26 | Epoch: 284/500 | val Loss: 0.03072 | EVM: 0.9924
0:00:02 | Epoch: 285/500 | train Loss: 0.0002003 | EVM: 0.3096
0:00:27 | Epoch: 285/500 | val Loss: 0.03084 | EVM: 0.9837
0:00:02 | Epoch: 286/500 | train Loss: 0.0001795 | EVM: 0.359
0:00:27 | Epoch: 286/500 | v

0:00:30 | Epoch: 346/500 | val Loss: 0.03072 | EVM: 0.9848
0:00:02 | Epoch: 347/500 | train Loss: 0.0001069 | EVM: 0.2204
0:00:30 | Epoch: 347/500 | val Loss: 0.03057 | EVM: 0.9886
0:00:02 | Epoch: 348/500 | train Loss: 8.568e-05 | EVM: 0.1499
0:00:30 | Epoch: 348/500 | val Loss: 0.03065 | EVM: 0.9886
0:00:02 | Epoch: 349/500 | train Loss: 6.085e-05 | EVM: 0.1685
0:00:30 | Epoch: 349/500 | val Loss: 0.03071 | EVM: 0.9974
0:00:02 | Epoch: 350/500 | train Loss: 5.098e-05 | EVM: 0.1693
0:00:30 | Epoch: 350/500 | val Loss: 0.03067 | EVM: 0.9892
0:00:02 | Epoch: 351/500 | train Loss: 5.267e-05 | EVM: 0.163
0:00:30 | Epoch: 351/500 | val Loss: 0.03077 | EVM: 0.9956
0:00:02 | Epoch: 352/500 | train Loss: 5.08e-05 | EVM: 0.1669
0:00:31 | Epoch: 352/500 | val Loss: 0.0307 | EVM: 0.9871
0:00:02 | Epoch: 353/500 | train Loss: 4.941e-05 | EVM: 0.1733
0:00:30 | Epoch: 353/500 | val Loss: 0.03078 | EVM: 0.9945
0:00:02 | Epoch: 354/500 | train Loss: 4.815e-05 | EVM: 0.2034
0:00:30 | Epoch: 354/500 | 

0:00:02 | Epoch: 414/500 | train Loss: 8.386e-05 | EVM: 0.2252
0:00:26 | Epoch: 414/500 | val Loss: 0.0302 | EVM: 0.978
0:00:02 | Epoch: 415/500 | train Loss: 8.706e-05 | EVM: 0.2195
0:00:26 | Epoch: 415/500 | val Loss: 0.03034 | EVM: 0.9723
0:00:02 | Epoch: 416/500 | train Loss: 0.0001023 | EVM: 0.2337
0:00:26 | Epoch: 416/500 | val Loss: 0.03033 | EVM: 0.9933
0:00:02 | Epoch: 417/500 | train Loss: 9.48e-05 | EVM: 0.2099
0:00:26 | Epoch: 417/500 | val Loss: 0.03026 | EVM: 0.9873
0:00:02 | Epoch: 418/500 | train Loss: 8.876e-05 | EVM: 0.2116
0:00:27 | Epoch: 418/500 | val Loss: 0.03007 | EVM: 0.977
0:00:02 | Epoch: 419/500 | train Loss: 0.0001133 | EVM: 0.2474
0:00:26 | Epoch: 419/500 | val Loss: 0.03019 | EVM: 0.9695
0:00:02 | Epoch: 420/500 | train Loss: 0.0001107 | EVM: 0.2119
0:00:27 | Epoch: 420/500 | val Loss: 0.03047 | EVM: 0.9898
0:00:02 | Epoch: 421/500 | train Loss: 0.0001033 | EVM: 0.234
0:00:27 | Epoch: 421/500 | val Loss: 0.03036 | EVM: 0.9889
0:00:02 | Epoch: 422/500 | tr

0:00:27 | Epoch: 481/500 | val Loss: 0.02972 | EVM: 0.9674
0:00:02 | Epoch: 482/500 | train Loss: 0.000125 | EVM: 0.2019
0:00:27 | Epoch: 482/500 | val Loss: 0.02956 | EVM: 0.9751
0:00:02 | Epoch: 483/500 | train Loss: 0.0001343 | EVM: 0.3228
0:00:28 | Epoch: 483/500 | val Loss: 0.02969 | EVM: 0.9641
0:00:02 | Epoch: 484/500 | train Loss: 0.0001073 | EVM: 0.3159
0:00:27 | Epoch: 484/500 | val Loss: 0.02963 | EVM: 0.9672
0:00:02 | Epoch: 485/500 | train Loss: 0.0001083 | EVM: 0.2587
0:00:27 | Epoch: 485/500 | val Loss: 0.02964 | EVM: 0.9611
0:00:02 | Epoch: 486/500 | train Loss: 0.0001114 | EVM: 0.234
0:00:26 | Epoch: 486/500 | val Loss: 0.02962 | EVM: 0.9739
0:00:02 | Epoch: 487/500 | train Loss: 9.853e-05 | EVM: 0.1979
0:00:27 | Epoch: 487/500 | val Loss: 0.02964 | EVM: 0.9661
0:00:02 | Epoch: 488/500 | train Loss: 9.962e-05 | EVM: 0.2766
0:00:27 | Epoch: 488/500 | val Loss: 0.02953 | EVM: 0.9726
0:00:02 | Epoch: 489/500 | train Loss: 9.77e-05 | EVM: 0.1986
0:00:27 | Epoch: 489/500 | 