# RNN

In [None]:
import os

# Select the GPU for this notebook. CUDA_VISIBLE_DEVICES has to be set before
# torch initialises CUDA, which is why this cell runs before `import torch`.
# setdefault respects a value already exported by the launching shell instead of
# clobbering it; "7" is only this machine's fallback.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")

In [None]:
import pickle
from typing import List, Tuple, Union

import torch
from torch import nn, Tensor
from torch.nn.utils.rnn import PackedSequence, pack_sequence
from torch.utils.data import DataLoader, random_split

from datautils import SeqDataset
from model_house import LSTM, GRU
from trainutils import device, train, prediction

## Hyper-parameters

In [None]:
# Settings shared by the data loaders and the optimiser; trial_name names this
# run (it is passed to train() and used in the checkpoint path).
batch_size, learning_rate = 16, 1e-4
trial_name = "rnn-demo"


## Define Dataset

In [None]:
# Load the feature scaler saved to disk earlier. NOTE: pickle.load executes code
# embedded in the file -- only load a scaler.skl you produced yourself.
with open("./scaler.skl", "rb") as fp:
    scaler = pickle.load(fp)

train_dataset = SeqDataset(
    feat_dir="./data/libriphone/feat/train",
    split_filepath="./data/libriphone/train_split.txt",
    labels_filepath="./data/libriphone/train_labels.txt",
    scaler=scaler,
)
# The test split is built without labels_filepath (labels presumably unavailable).
test_dataset = SeqDataset(
    feat_dir="./data/libriphone/feat/test",
    split_filepath="./data/libriphone/test_split.txt",
    scaler=scaler,
)

# Hold out 20% of the labelled data for validation. A seeded generator makes the
# split identical on every re-run, so validation scores are comparable across runs.
split_seed = 0
train_len = int(len(train_dataset) * 0.8)
valid_len = len(train_dataset) - train_len
train_dataset, valid_dataset = random_split(
    train_dataset,
    [train_len, valid_len],
    generator=torch.Generator().manual_seed(split_seed),
)


In [None]:
def collate(
    batch: Union[List[Tuple[Tensor, Tensor]], List[Tensor]]
) -> Union[Tuple[PackedSequence, Tensor], PackedSequence]:
    """Merge a list of variable-length samples into one packed batch.

    Intended as a DataLoader ``collate_fn``.

    Args:
        batch: either ``(features, labels)`` pairs (labelled data) or bare
            ``features`` tensors (unlabelled data). ``pack_sequence`` treats
            dim 0 of each features tensor as its sequence length, so samples
            may differ in that dimension.

    Returns:
        For labelled input, ``(packed_features, labels)``, where ``labels`` is
        every sample's label tensor concatenated along dim 0 in batch order.
        For unlabelled input, just ``packed_features``.
    """
    if isinstance(batch[0], tuple):
        features, labels = zip(*batch)
        # enforce_sorted=False: samples need not arrive sorted by length;
        # PackedSequence records the permutation it applies internally.
        # NOTE(review): `labels` stays in batch order, while packed.data is in
        # length-sorted, time-major order -- the model must unpack (e.g.
        # pad_packed_sequence) before matching outputs to labels; confirm in model_house.
        return pack_sequence(features, enforce_sorted=False), torch.cat(labels)
    return pack_sequence(batch, enforce_sorted=False)


# The two loaders differ only in shuffling: the training loader reshuffles every
# epoch, while validation keeps a fixed order.
_loader_opts = {
    "batch_size": batch_size,
    "num_workers": 4,
    "pin_memory": True,
    "collate_fn": collate,
}
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
valid_dataloader = DataLoader(valid_dataset, shuffle=False, **_loader_opts)


## Define Model

In [None]:
# NOTE(review): embed_size presumably has to match the per-frame feature size
# produced by SeqDataset, and num_classes the number of label classes -- confirm.
model = GRU(
    embed_size=39,
    hidden_size=256,
    num_layers=3,
    num_classes=41,
)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()


In [None]:
# Run training via trainutils.train.
train(
    train_dataloader,
    valid_dataloader,
    model,
    criterion,
    optimizer,
    5000,  # NOTE(review): meaning set by trainutils.train (presumably an epoch/step budget) -- confirm
    100,  # NOTE(review): presumably a patience or logging interval -- confirm against trainutils.train
    trial_name,
)

## Prediction

In [None]:
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    # Keep dataset order (presumably needed so predictions line up with test_split.txt).
    shuffle=False,
    num_workers=4,
    collate_fn=collate,
    pin_memory=True,
)

# Restore this trial's checkpoint (presumably written by the training cell -- confirm).
# map_location remaps tensors saved on a GPU onto the current `device`, so loading
# still works when that GPU is absent.
model.load_state_dict(torch.load(f"./models/{trial_name}.ckpt", map_location=device))
# Disable train-mode behaviour (e.g. dropout) for inference; harmless if
# prediction() already switches the model to eval mode.
model.eval()
prediction(test_dataloader, model)
