In [2]:
import os
import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
from utils import get_data, preprocess, split_dataset
from model import PRnet

In [3]:
device = torch.device('cuda:0') if torch.cuda.is_available else torch.device('cpu')
batch_size = 256
input_size = 1 
hidden_size = 32
num_heads = 2
lr = 0.001
epochs = 40
eval_every=1

## Data Loading and Preprocessing

In [6]:
data_list = get_data()
total_data = preprocess(data_list)
total_data =split_dataset(total_data)
datasets, loaders={},{}
for datatype in ['train', 'eval', 'test']:
    datasets[datatype] = data.TensorDataset(torch.from_numpy(total_data[datatype][:,:-1]),torch.from_numpy(total_data[datatype][:, -1]))
    loaders[datatype] = data.DataLoader(datasets[datatype], batch_size=batch_size, shuffle=True)
    num = len(datasets[datatype])
    print(f'{datatype} samples:  {num}')
print('Data Loaded')

train samples:  17213
eval samples:  2459
test samples:  4918
Data Loaded


## Model Loading Training

In [7]:
model = PRnet(input_size, hidden_size, num_heads).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_function = nn.CrossEntropyLoss()

In [12]:
model_save_path = 'checkpoints/prnet.pkl'
val_loss_best = np.inf
for ep in range(epochs):
    with tqdm(total=len(loaders['train'].dataset), desc=f"[Epoch {ep+1:3d}/{epochs}]") as pbar:
        running_loss=0
        model.train()
        for idx_batch, (x, y) in enumerate(loaders['train']):
            optimizer.zero_grad()
            x, y = x.unsqueeze(2).float().to(device), y.long().to(device)    
            pred = model(x)
            loss = loss_function(pred, y)
            
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix({'loss': running_loss/(idx_batch+1)})
            pbar.update(x.shape[0])
        train_loss = running_loss/len(loaders['train'])
        
        if ep % eval_every==0:
            running_loss = 0
            correct = 0
            model.eval()
            for idx_batch, (x, y) in enumerate(loaders['eval']):
                x, y = x.unsqueeze(2).float().to(device), y.long().to(device)     
                pred = model(x)
                loss = loss_function(pred, y)
                running_loss += loss.item()
                pred =  pred.argmax(dim=-1)
                correct += (pred==y).sum()

            val_loss = running_loss/len(loaders['eval'])
            val_acc = correct/len(datasets['eval'])
        
        pbar.set_postfix({'loss': train_loss, 'val_loss': val_loss, 'val acc': val_acc})
        
        if val_loss < val_loss_best:
            val_loss_best = val_loss
            torch.save(model.state_dict(), model_save_path) 
    

[Epoch   1/40]: 100%|██████████| 17213/17213 [00:02<00:00, 7841.66it/s, loss=0.988, val_loss=0.894, val acc=tensor(0.6145, device='cuda:0')]
[Epoch   2/40]: 100%|██████████| 17213/17213 [00:02<00:00, 7672.06it/s, loss=0.73, val_loss=0.667, val acc=tensor(0.8691, device='cuda:0')]
[Epoch   3/40]: 100%|██████████| 17213/17213 [00:02<00:00, 7950.95it/s, loss=0.606, val_loss=0.559, val acc=tensor(0.9947, device='cuda:0')]
[Epoch   4/40]: 100%|██████████| 17213/17213 [00:02<00:00, 8224.09it/s, loss=0.558, val_loss=0.555, val acc=tensor(0.9972, device='cuda:0')]
[Epoch   5/40]: 100%|██████████| 17213/17213 [00:02<00:00, 8073.94it/s, loss=0.556, val_loss=0.554, val acc=tensor(0.9980, device='cuda:0')]
[Epoch   6/40]: 100%|██████████| 17213/17213 [00:02<00:00, 8007.61it/s, loss=0.555, val_loss=0.554, val acc=tensor(0.9980, device='cuda:0')]
[Epoch   7/40]: 100%|██████████| 17213/17213 [00:02<00:00, 8207.56it/s, loss=0.555, val_loss=0.554, val acc=tensor(0.9980, device='cuda:0')]
[Epoch   8/40]

## Inference

In [10]:
model.load_state_dict(torch.load('checkpoints/prnet.pkl'))
model.eval()
correct = 0
for idx_batch, (x, y) in enumerate(loaders['test']):
    x, y = x.unsqueeze(2).float().to(device), y.long().to(device)     
    pred = model(x)
    pred =  pred.argmax(dim=-1)
    correct += (pred==y).sum()
test_acc = correct/len(datasets['test'])
print(f'test acc: {test_acc}')

test acc: 0.9973565936088562


In [186]:
print(chr(int('0xb8', 16)))

¸
