In [2]:
import torch

# Check for CUDA!
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Currently training on: {device}")

Currently training on: cuda


In [3]:
################################################################################
###########                      PARAMETERS                   ##################
################################################################################

BATCH_SIZE = 512
WINDOW_SIZE = 255

# Trainning parameters
DROPOUT = 0.00
CLIP = -1
N_EPOCHS = 10
LEARNING_RATE = 3e-4

# Network parameters
INPUT_SIZE = 1
OUTPUT_SIZE = 1
CHANNEL_SIZES = [32] * 4
KERNEL_SIZE = 16

In [5]:
################################################################################
###########                   Import Dataset(s)               ##################
################################################################################
from os import walk, path
from data.file_handler import filter_filenames, read_file
from data.processing import parse_str
from config import BASE, DATASET_DIR, DATA_REG, LABEL_REG
from torch.utils.data import DataLoader
from data.dataset import CustomDataset

# Get paths
filenames = next(walk(path.join(path.abspath('../'), "dataset")), (None, None, []))[2]  # extract all files from dataset folder
# file_paths = ["./dataset/data64QAM.txt", "./dataset/OSC_sync_291.txt", "./dataset/OSC_sync_292.txt", "./dataset/OSC_sync_293.txt"]
input_filenames = filter_filenames(filenames, DATA_REG) # filter the files, we only want output signals
input_filenames.sort()
label_filenames = filter_filenames(filenames, LABEL_REG)
label_filenames.sort()

array_data = []
array_labels = []

# We are specifally training on "OSC_sync_471.txt"
if input_filenames.count("OSC_sync_471.txt"):
  array_data = parse_str(
    read_file(path.join(path.abspath('../'), DATASET_DIR, "OSC_sync_471.txt"))
  )
else:
  print("OSC_sync_471.txt not found!")

if label_filenames.count("data64QAM.txt"):
  array_labels = parse_str(
    read_file(path.join(path.abspath('../'), DATASET_DIR, "data64QAM.txt")
    )
  )
else:
  print("data64QAM.txt not found!")

dataset = CustomDataset("../dataset", "OSC_sync_471.txt","data64QAM.txt", 255)

# Split into train and validation sets
train_size = int(0.6 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# dataset = CustomDataset_1D(array_data, array_labels, 255, train=True, train_ratio=0.6)
# dataloader = DataLoader(dataset, batch_size=512, shuffle=True)

# validation_dataset = CustomDataset_1D(
#   array_data, array_labels, 255, train=False, train_ratio=0.6
# )
# validation_dataloader = DataLoader(
#   validation_dataset, batch_size=512
# )

# train_count = 0
# flag = False
# print(f"Training dataloader iterations: {len(dataloader)}")
# print(f"Validation dataloader iterations: {len(validation_dataloader)}")
# for idx, data in enumerate(dataloader):
#     datas = data[0]
#     labels = data[1]
#     if not flag:
#       print("Data shape:", datas.shape)
#       print("Label shape:", labels.shape)
#       flag = True
#     train_count += datas.shape[0]
#     # break
# print(f"Training dataset count: {train_count}")
# test_count = 0
# flag = False
# for idx, data in enumerate(validation_dataloader):
#     datas = data[0]
#     labels = data[1]
#     if not flag:
#       print("Data shape:", datas.shape)
#       print("Label shape:", labels.shape)
#       flag = True
#     test_count += datas.shape[0]
#     # break
# print(f"Validation dataset count: {test_count}")
# print(f"Total count: {train_count + test_count}")


# Training

In [6]:
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

from nn.TCNN1 import BiTCN

model = BiTCN(
  INPUT_SIZE,
  OUTPUT_SIZE,
  CHANNEL_SIZES,
  KERNEL_SIZE,
  seq_len=WINDOW_SIZE,
  dropout=DROPOUT
)
model = model.to(device)

print(model)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
  model.parameters(),
  lr = LEARNING_RATE
)
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

def train(model, device,train_loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for X_batch, y_batch in tqdm(train_loader, desc="Training", leave=False):
        optimizer.zero_grad()
        output = model(X_batch.to(device))
        loss = criterion(output, y_batch.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        average_loss = total_loss / len(train_loader)
    return average_loss

def validate(model, device, val_loader, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for X_batch, y_batch in tqdm(val_loader, desc="Validation", leave=False):
            output = model(X_batch.to(device))
            loss = criterion(output, y_batch.to(device))
            total_loss += loss.item()
        average_loss = total_loss / len(val_loader)
    return average_loss

# Train the model
train_losses = []
val_losses = []

best_val_loss = float('inf')

for epoch in range(N_EPOCHS):
    print(f"Starting epoch {epoch + 1}/{N_EPOCHS}")

    train_loss = train(model, device, train_loader, optimizer, criterion)
    val_loss = validate(model, device, val_loader, criterion)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    scheduler.step()

    print(f"Epoch [{epoch + 1}/{N_EPOCHS}], Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # torch.save(model.state_dict(), 'BiTCN_best_model.pth')
        print(f"Saved model with validation loss: {best_val_loss:.4f}")



  WeightNorm.apply(module, name, dim)


BiTCN(
  (tcn): TemporalConvNet(
    (network): Sequential(
      (0): TemporalBlock(
        (conv1): Conv1d(1, 32, kernel_size=(16,), stride=(1,), padding=(15,))
        (chomp1): Chomp1D()
        (relu1): PReLU(num_parameters=1)
        (dropout1): Dropout(p=0.0, inplace=False)
        (conv2): Conv1d(32, 32, kernel_size=(16,), stride=(1,), padding=(15,))
        (chomp2): Chomp1D()
        (relu2): PReLU(num_parameters=1)
        (dropout2): Dropout(p=0.0, inplace=False)
        (net): Sequential(
          (0): Conv1d(1, 32, kernel_size=(16,), stride=(1,), padding=(15,))
          (1): Chomp1D()
          (2): PReLU(num_parameters=1)
          (3): Dropout(p=0.0, inplace=False)
          (4): Conv1d(32, 32, kernel_size=(16,), stride=(1,), padding=(15,))
          (5): Chomp1D()
          (6): PReLU(num_parameters=1)
          (7): Dropout(p=0.0, inplace=False)
        )
        (downsample): Conv1d(1, 32, kernel_size=(1,), stride=(1,))
        (relu): PReLU(num_parameters=1)
    

                                                             

Epoch [1/10], Train Loss: 0.4683, Validation Loss: 0.0366
Saved model with validation loss: 0.0366
Starting epoch 2/10


                                                             

Epoch [2/10], Train Loss: 0.0248, Validation Loss: 0.0211
Saved model with validation loss: 0.0211
Starting epoch 3/10


                                                             

Epoch [3/10], Train Loss: 0.0173, Validation Loss: 0.0154
Saved model with validation loss: 0.0154
Starting epoch 4/10


                                                             

Epoch [4/10], Train Loss: 0.0141, Validation Loss: 0.0142
Saved model with validation loss: 0.0142
Starting epoch 5/10


                                                             

Epoch [5/10], Train Loss: 0.0137, Validation Loss: 0.0139
Saved model with validation loss: 0.0139
Starting epoch 6/10


                                                             

Epoch [6/10], Train Loss: 0.0134, Validation Loss: 0.0136
Saved model with validation loss: 0.0136
Starting epoch 7/10


                                                             

Epoch [7/10], Train Loss: 0.0132, Validation Loss: 0.0135
Saved model with validation loss: 0.0135
Starting epoch 8/10


                                                           

KeyboardInterrupt: 