In [1]:
import os
import torch
import pickle
import numpy as np
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import random_split, Dataset, DataLoader
# 'high' lets float32 matrix multiplications use faster reduced-precision
# internal algorithms (such as TF32) where the hardware supports them.
torch.set_float32_matmul_precision(precision='high')

In [2]:
class HandwritingDataset(Dataset):
    """In-memory dataset yielding aligned ``(x, y, z)`` samples.

    All three tensors are copied and detached from any autograd graph when the
    dataset is built, so later in-place edits to the caller's tensors do not
    leak in. Sample ``i`` is ``(x[i], y[i], z[i])``; the length comes from ``x``.
    """

    def __init__(self, x, y, z):
        # Private, graph-free copies of every input tensor.
        self.x, self.y, self.z = (t.clone().detach() for t in (x, y, z))

    def __getitem__(self, index):
        return tuple(t[index] for t in (self.x, self.y, self.z))

    def __len__(self):
        return len(self.x)

class HandwritingDataModule(pl.LightningDataModule):
    """Lightning data module serving the arrays stored in ``dataDict``.

    ``dataDict`` must provide ``'inputs'``, ``'charLabels'`` and ``'charStarts'``.
    Each is converted to a float tensor and wrapped in a ``HandwritingDataset``.

    Args:
        dataDict: mapping holding the three arrays above. Samples are presumably
            indexed along the first axis (TODO confirm against the data file).
        batch_size: samples per batch for every dataloader.
        num_workers: DataLoader worker processes (default 24, the value that
            was previously hard-coded).
        val_fraction: share of the data held out for validation in ``fit``.
        seed: seed for the train/validation split. Repeated ``setup`` calls
            therefore produce the same split: the notebook calls ``setup``
            manually and ``Trainer.fit`` calls it again.
    """

    def __init__(self, dataDict, batch_size: int = 32, num_workers: int = 24,
                 val_fraction: float = 0.1, seed: int = 0):
        super().__init__()
        self.dataDict = dataDict
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_fraction = val_fraction
        self.seed = seed

    def setup(self, stage: str):
        self.x = torch.tensor(self.dataDict['inputs'], dtype=torch.float)
        self.y = torch.tensor(self.dataDict['charLabels'], dtype=torch.float)
        self.z = torch.tensor(self.dataDict['charStarts'], dtype=torch.float)

        data_full = HandwritingDataset(self.x, self.y, self.z)

        # "validate" also needs val_dataset, so it shares the fit split.
        if stage in ("fit", "validate"):
            # A dedicated seeded generator makes the split reproducible and
            # independent of the global RNG state.
            generator = torch.Generator().manual_seed(self.seed)
            self.train_dataset, self.val_dataset = random_split(
                data_full,
                [1.0 - self.val_fraction, self.val_fraction],
                generator=generator,
            )

        if stage == "test":
            # NOTE(review): the test set is the full data, including the
            # train/val samples. Confirm a separate held-out file is not intended.
            self.test_dataset = data_full

    def _loader(self, dataset, shuffle: bool = False):
        """Build a DataLoader over ``dataset`` with this module's batch/worker settings."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=self.num_workers)

    def train_dataloader(self):
        # Reshuffle every epoch; without this, every epoch sees identical batches.
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_dataset)

    def test_dataloader(self):
        return self._loader(self.test_dataset)

In [3]:
# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# Only load handwriting.dat from a trusted source.
with open('handwriting.dat', 'rb') as file:
    dataDict = pickle.load(file)

# Fail here, rather than deep inside HandwritingDataModule.setup, if any array
# the data module reads is missing.
missing_keys = {'inputs', 'charLabels', 'charStarts'} - set(dataDict)
assert not missing_keys, f"handwriting.dat is missing keys: {sorted(missing_keys)}"

print(dataDict.keys())

dict_keys(['inputs', 'charLabels', 'charStarts'])


In [8]:
class HandwritingGRU(pl.LightningModule):
    """Two-layer GRU with a per-time-step character head and a char-start head.

    ``forward`` returns ``(yhat, zhat, h)``:
      * ``yhat``: softmax over the ``num_chars`` outputs of ``fc_y``, taken on
        the last axis.
      * ``zhat``: ``tanh`` of the single ``fc_z`` output.
      * ``h``: the GRU's final hidden state.

    The GRU is ``batch_first``, so inputs are presumably
    ``(batch, time, input_size)`` (TODO confirm against the data module).

    Args:
        input_size: features per time step fed to the GRU.
        hidden_size: hidden width of both GRU layers.
        num_chars: number of character outputs per time step.
        reg_strength: weight of the squared L2 penalty on ``gru.weight_hh_l0``.
        lr: Adam learning rate.
    """

    def __init__(self, input_size, hidden_size, num_chars, reg_strength, lr):
        super().__init__()
        # Store the constructor args in self.hparams and checkpoints so the
        # model can be rebuilt with load_from_checkpoint.
        self.save_hyperparameters()
        self.gru = nn.GRU(input_size, hidden_size, num_layers=2, batch_first=True)
        self.fc_y = nn.Linear(hidden_size, num_chars)
        self.fc_z = nn.Linear(hidden_size, 1)
        self.reg_strength = reg_strength
        self.learning_rate = lr

    def _heads(self, x, h=None):
        """Run the GRU and both linear heads, returning raw ``(y_logits, z_logits, h)``."""
        out, h = self.gru(x, h)
        return self.fc_y(out), self.fc_z(out), h

    def forward(self, x, h=None):
        y_logits, z_logits, h = self._heads(x, h)
        # Normalise over the character axis (last). dim=0 would normalise
        # across the samples of the batch instead.
        yhat = torch.softmax(y_logits, dim=-1)
        zhat = torch.tanh(z_logits)
        return yhat, zhat, h

    def _compute_losses(self, batch):
        """Return ``(loss, reg_loss, ce_loss, mse_loss)`` for one ``(x, y, z)`` batch."""
        x, y, z = batch
        y_logits, z_logits, _ = self._heads(x)
        zhat = torch.tanh(z_logits)
        # Squared Frobenius norm of the first layer's hidden-to-hidden weights.
        reg_loss = self.reg_strength * self.gru.weight_hh_l0.pow(2).sum()
        # F.cross_entropy applies log_softmax itself, so it must receive logits,
        # not softmax output. It also expects classes on dim 1, so (batch, time)
        # is flattened into a single axis.
        # Assumes y holds float per-time-step class probabilities shaped like
        # y_logits (TODO confirm against 'charLabels').
        num_chars = y_logits.shape[-1]
        ce_loss = F.cross_entropy(y_logits.reshape(-1, num_chars),
                                  y.reshape(-1, num_chars))
        # The char-start head is trained against z. The reshape drops fc_z's
        # trailing size-1 axis when z lacks it, instead of silently broadcasting.
        mse_loss = F.mse_loss(zhat.reshape(z.shape), z)
        loss = reg_loss + ce_loss + mse_loss
        return loss, reg_loss, ce_loss, mse_loss

    def _log_losses(self, prefix, losses, **log_kwargs):
        """Log the four loss terms under ``<prefix>_*`` keys."""
        loss, reg_loss, ce_loss, mse_loss = losses
        self.log(f'{prefix}_reg_loss', reg_loss, **log_kwargs)
        self.log(f'{prefix}_ce_loss', ce_loss, **log_kwargs)
        self.log(f'{prefix}_mse_loss', mse_loss, **log_kwargs)
        self.log(f'{prefix}_loss', loss, prog_bar=True, **log_kwargs)

    def training_step(self, batch, batch_idx):
        losses = self._compute_losses(batch)
        # Training metrics are logged as train_*; previously the total was
        # mislabelled 'validation_loss'.
        self._log_losses('train', losses, on_step=True)
        return losses[0]

    def validation_step(self, batch, batch_idx):
        losses = self._compute_losses(batch)
        self._log_losses('val', losses, on_epoch=True)
        return losses[0]

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [9]:
model = HandwritingGRU(
    input_size=192,      # GRU features per time step
    hidden_size=512,     # hidden width of both GRU layers
    num_chars=31,        # outputs of the character head per time step
    reg_strength=0.001,  # weight of the L2 penalty on gru.weight_hh_l0
    lr=0.01,             # Adam learning rate
)
dm = HandwritingDataModule(dataDict)
# Trainer.fit calls setup("fit") itself as well; this early call only makes the
# train/val splits available for inspection before training.
dm.setup(stage="fit")
print(model)

HandwritingGRU(
  (gru): GRU(192, 512, num_layers=2, batch_first=True)
  (fc_y): Linear(in_features=512, out_features=31, bias=True)
  (fc_z): Linear(in_features=512, out_features=1, bias=True)
)


In [10]:
# 'auto' selects the GPU when one is visible (as in the logged run below) and
# falls back to CPU otherwise. A hard-coded 'gpu' makes the cell crash on
# CPU-only machines.
trainer = pl.Trainer(accelerator='auto', devices=1, max_epochs=5)
trainer.fit(model, datamodule=dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | gru  | GRU    | 2.7 M 
1 | fc_y | Linear | 15.9 K
2 | fc_z | Linear | 513   
--------------------------------
2.7 M     Trainable params
0         Non-trainable params
2.7 M     Total params
10.707    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
