In [68]:
%reload_ext autoreload 
%autoreload 2
import os
import re
import json
import torch
import random
import pickle
import numpy as np
import pandas as pd

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import data_parser

In [62]:
class CustomDataset(Dataset):
    """Map-style dataset pairing normalized line parameters with S-parameter targets.

    Each sample is one simulation run directory under ``../Data/<case>/``
    (relative to the current working directory). The input is that run's row
    of ``input_cols`` from ``data_parser.read_input_feature_xlsx``, min-max
    normalized with the bounds in ``../Data/Config/<line_model>.json``. The
    target is derived from ``<run>/RLGC/TransmissionLine.s<port>p``.

    Args:
        case: simulation case name, a sub-directory of ``../Data/``.
        port: number of ports in the line model.
        line_model: name of the line model in the TML simulation; selects the
            JSON config holding per-parameter ``min``/``max`` bounds.
        input_cols: input parameter columns, e.g. ``['W', 'Trap', 'Length']``.
        output_col: target spec string, one of ``'A(i,j)'`` (magnitude),
            ``'P(i,j)'`` (phase, returned as ``[sin, cos]``), ``'SR(i,j)'``
            (real part) or ``'SI(i,j)'`` (imaginary part); ``i``/``j`` are
            1-based port numbers.
        indices: index labels of the feature table to include; each label is
            also the name of the run's directory under ``../Data/<case>/``.
        device: device every returned tensor is moved to (default ``'cpu'``).
            NOTE: a CUDA device combined with ``DataLoader(num_workers > 0)``
            can fail, because CUDA cannot be re-initialized in forked workers.

    Raises:
        ValueError: if ``output_col`` is not one of the supported forms.
    """

    def __init__(self, case, port, line_model, input_cols, output_col, indices, device='cpu'):
        self.case = case
        self.port = port
        self.home_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
        self.result_dir = os.path.join(self.home_dir, 'Data', '%s' % self.case)
        self.config_dir = os.path.join(self.home_dir, 'Data', 'Config', '%s.json' % line_model)

        # Column names of the parsed .sNp table, in file order:
        # SR(1,1), SI(1,1), SR(1,2), SI(1,2), ..., SR(port,port), SI(port,port).
        self.snp_headers = [
            '%s(%d,%d)' % (part, i + 1, j + 1)
            for i in range(port)
            for j in range(port)
            for part in ('SR', 'SI')
        ]

        # Load the selected runs' inputs and min-max normalize each column.
        self.input_cols = input_cols
        with open(self.config_dir) as f:
            self.config = json.load(f)
        # .copy() makes the column assignments below act on an independent
        # frame instead of a possibly-chained slice of the parser's result.
        self.parameters_df = data_parser.read_input_feature_xlsx(case, port).loc[indices, input_cols].copy()
        for col in input_cols:
            lo = self.config[col]['min']
            span = self.config[col]['max'] - lo
            # A parameter held constant (max == min) would divide by zero and
            # give NaN/inf; like sklearn's MinMaxScaler, use a unit scale instead.
            if span == 0:
                span = 1
            self.parameters_df[col] = (self.parameters_df[col] - lo) / span

        # Parse the target spec into self.index = [kind, i, j], e.g. 'A(1,2)' -> ['A', 1, 2].
        self.output_col = output_col
        port_indices = [int(d) for d in re.findall(r"\d+", output_col)]
        # Same precedence as before; 'SR'/'SI' contain neither 'A' nor 'P'.
        for kind in ('A', 'P', 'SR', 'SI'):
            if kind in output_col:
                break
        else:
            raise ValueError('Output col must be A(i,j), P(i,j), SR(i,j) or SI(i,j), got %r!' % (output_col,))
        if len(port_indices) != 2:
            raise ValueError('Output col must name exactly two port indices, got %r!' % (output_col,))
        self.index = [kind] + port_indices

        self.device = device

    def __len__(self):
        """Number of runs in the dataset."""
        return len(self.parameters_df.index)

    def _column(self, part):
        """Name of the ``part`` ('SR' or 'SI') column for the target port pair."""
        return '%s(%d,%d)' % (part, self.index[1], self.index[2])

    def _read_snp(self, run_name):
        """Read ``<run_name>/RLGC/TransmissionLine.s<port>p`` into a frame with ``snp_headers`` columns."""
        path = os.path.join(self.result_dir, run_name, 'RLGC', 'TransmissionLine.s%dp' % self.port)
        # Skip the port + 3 header lines (this dataset's port, not a notebook
        # global), then drop column 0 (presumably frequency) so only the
        # SR/SI pairs remain.
        snp = pd.read_csv(path, skiprows=self.port + 3, sep=r'\s+', header=None).loc[:, 1:]
        snp.columns = self.snp_headers
        return snp

    def __getitem__(self, idx):
        """Return ``(parameters, output)`` float32 tensors for the ``idx``-th run.

        ``parameters`` holds the normalized ``input_cols`` values. ``output``
        has one value per data row of the .sNp file for ``A``/``SR``/``SI``
        targets, and shape ``(2, rows)`` -- ``[sin(phase), cos(phase)]`` --
        for ``P`` targets. Both tensors are moved to ``self.device``.
        """
        run_name = self.parameters_df.index[idx]
        snp = self._read_snp(run_name)

        # Positional row lookup; torch.tensor copies, so the returned tensor
        # never aliases the DataFrame's memory.
        parameters = torch.tensor(self.parameters_df.iloc[idx].to_numpy(dtype=np.float32), dtype=torch.float32)

        kind = self.index[0]
        if kind == 'A':
            target = np.hypot(snp[self._column('SR')].to_numpy(), snp[self._column('SI')].to_numpy())
        elif kind == 'P':
            phase = np.arctan2(snp[self._column('SI')].to_numpy(), snp[self._column('SR')].to_numpy())
            target = np.stack([np.sin(phase), np.cos(phase)])
        else:  # 'SR' or 'SI': the raw column
            target = snp[self._column(kind)].to_numpy()
        output = torch.tensor(target, dtype=torch.float32)
        return parameters.to(self.device), output.to(self.device)

def generate_indices(case, port):
    """List the run directories of ``case`` that contain an .sNp result file.

    A entry ``name`` under ``../Data/<case>/`` (relative to the current working
    directory) is kept when ``name/RLGC/TransmissionLine.s<port>p`` is a file.

    Args:
        case: simulation case name, a sub-directory of ``../Data/``.
        port: number of ports; selects the ``.s<port>p`` file extension.

    Returns:
        Sorted list of run directory names. Sorting makes the result independent
        of ``os.listdir`` order (which is arbitrary), so a seeded shuffle of it
        produces the same split on every machine.
    """
    home_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
    result_dir = os.path.join(home_dir, 'Data', '%s' % case)
    snp = 'TransmissionLine.s%dp' % port  # loop-invariant file name
    return sorted(
        name
        for name in os.listdir(result_dir)
        if os.path.isfile(os.path.join(result_dir, name, 'RLGC', snp))
    )

In [63]:
# Run configuration: which simulation case and line model to load.
case = 'Differential_Stripline_test'   # sub-directory of ../Data/
port = 4                               # number of ports in the line model
line_model = 'Stripline_Diff-Pair'     # selects ../Data/Config/<line_model>.json
input_cols = ['W', 'Trap', 'Length']   # parameter columns fed to the model
dataname = case                        # tag used in the saved index-split file name

In [64]:
keys = generate_indices(case=case, port=port)  # run dirs holding a TransmissionLine.s<port>p result

In [69]:
# Build (or reload) the train / validation / test split of run directories.
# Set read_idx = True to reuse a previously saved split instead of drawing a new one.
read_idx = False
index_file = '../Data/Indices/index_%s.pkl' % dataname
if read_idx:
    # NOTE: pickle.load can execute arbitrary code; only load index files you created.
    with open(index_file, 'rb') as f:
        indices = pickle.load(f)
else:
    # Shuffle a sorted copy with a dedicated seeded RNG. The split then does not
    # depend on directory-listing order or on earlier runs of this cell (keys is
    # left untouched), and the seed really governs the shuffle -- random.seed
    # has no effect on np.random.shuffle.
    shuffled_keys = sorted(keys)
    random.Random(42).shuffle(shuffled_keys)

    # Split the run directories 80% / 10% / 10% into train / validation / test.
    n_train = int(len(shuffled_keys) * 0.8)
    n_train_val = int(len(shuffled_keys) * 0.9)
    indices = {
        'train_idx': shuffled_keys[:n_train],
        'val_idx': shuffled_keys[n_train:n_train_val],
        'test_idx': shuffled_keys[n_train_val:],
    }

    os.makedirs(os.path.dirname(index_file), exist_ok=True)
    with open(index_file, 'wb') as f:
        pickle.dump(indices, f)

In [74]:
# Training set with |S(1,1)| as the target; CustomDataset requires an explicit device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_dataset = CustomDataset(case, port, line_model, input_cols, 'A(1,1)', indices['train_idx'], device)
parameters, output = train_dataset[0]  # smoke-test one sample

In [75]:
# Seeded generator so the shuffled batch order is reproducible on Restart & Run All.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)

In [77]:
# Sanity check: print the target tensor shape of every training batch.
for X, y in train_dataloader:
    print(y.shape)

torch.Size([64, 501])
torch.Size([64, 501])
torch.Size([64, 501])
torch.Size([64, 501])
torch.Size([64, 501])
torch.Size([64, 501])
torch.Size([64, 501])
torch.Size([64, 501])
torch.Size([64, 501])
torch.Size([64, 501])
torch.Size([64, 501])
torch.Size([64, 501])
torch.Size([32, 501])
