In [11]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch.utils.data import DataLoader

from etudelib.data.synthetic.synthetic import SyntheticDataset

from omegaconf import OmegaConf, DictConfig, ListConfig
from importlib import import_module
import os
import torch

In [83]:
qty_interactions = 1000
n_items = 500
max_seq_length = 15
qty_sessions = qty_interactions
batch_size = 32

args_model = 'CORE'
basedir = "../.."
config_path = os.path.join(basedir, f"etudelib/models/{args_model}/config.yaml".lower())
config = OmegaConf.load(config_path)
config['dataset'] = {}
config['dataset']['n_items'] = n_items
config['dataset']['max_seq_length'] = max_seq_length
config

{'model': {'name': 'CORE', 'embedding_size': 64, 'dnn_type': 'trm', 'sess_dropout': 0.2, 'item_dropout': 0.2, 'temperature': 0.07, 'n_layers': 2, 'n_heads': 2, 'inner_size': 256, 'hidden_dropout_prob': 0.5, 'attn_dropout_prob': 0.5, 'hidden_act': 'gelu', 'layer_norm_eps': 1e-12, 'initializer_range': 0.02}, 'optimizer': {'lr': 0.02}, 'trainer': {'accelerator': 'auto'}, 'dataset': {'n_items': 500, 'max_seq_length': 15}}

In [84]:
train_ds = SyntheticDataset(qty_interactions=qty_interactions,
                                qty_sessions=qty_sessions,
                                n_items=n_items,
                                max_seq_length=max_seq_length)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)


In [85]:
from etudelib.models.core.lightning_model import CORELightning

model = CORELightning(config)

In [86]:
print(model)

CORELightning(
  (model): COREModel(
    (sess_dropout): Dropout(p=0.2, inplace=False)
    (item_dropout): Dropout(p=0.2, inplace=False)
    (item_embedding): Embedding(500, 64, padding_idx=0)
    (net): TransNet(
      (position_embedding): Embedding(15, 64)
      (trm_encoder): TransformerEncoder(
        (layer): ModuleList(
          (0): TransformerLayer(
            (multi_head_attention): MultiHeadAttention(
              (query): Linear(in_features=64, out_features=64, bias=True)
              (key): Linear(in_features=64, out_features=64, bias=True)
              (value): Linear(in_features=64, out_features=64, bias=True)
              (softmax): Softmax(dim=-1)
              (attn_dropout): Dropout(p=0.5, inplace=False)
              (dense): Linear(in_features=64, out_features=64, bias=True)
              (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
              (out_dropout): Dropout(p=0.5, inplace=False)
            )
            (feed_forward): FeedF

In [87]:
trainer = Trainer(
        accelerator="auto",
        devices=None,
        max_epochs=1,
        callbacks=[TQDMProgressBar()],
    )

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [88]:
trainer.fit(model, train_loader)


  | Name  | Type      | Params
------------------------------------
0 | model | COREModel | 133 K 
------------------------------------
133 K     Trainable params
0         Non-trainable params
133 K     Total params
0.532     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

In [89]:
eager_model = model.get_backbone()
eager_model.eval()

COREModel(
  (sess_dropout): Dropout(p=0.2, inplace=False)
  (item_dropout): Dropout(p=0.2, inplace=False)
  (item_embedding): Embedding(500, 64, padding_idx=0)
  (net): TransNet(
    (position_embedding): Embedding(15, 64)
    (trm_encoder): TransformerEncoder(
      (layer): ModuleList(
        (0): TransformerLayer(
          (multi_head_attention): MultiHeadAttention(
            (query): Linear(in_features=64, out_features=64, bias=True)
            (key): Linear(in_features=64, out_features=64, bias=True)
            (value): Linear(in_features=64, out_features=64, bias=True)
            (softmax): Softmax(dim=-1)
            (attn_dropout): Dropout(p=0.5, inplace=False)
            (dense): Linear(in_features=64, out_features=64, bias=True)
            (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
            (out_dropout): Dropout(p=0.5, inplace=False)
          )
          (feed_forward): FeedForward(
            (dense_1): Linear(in_features=64, out_featur

In [90]:
train_iter = iter(train_loader)


In [91]:
(x1, x2, y) = next(train_iter)
x1, x2

eager_model.eval()
with torch.no_grad():
    r = eager_model.forward(x1, x2)
    print(r)

torch.onnx.export(
    eager_model,
    (x1,x2),
    'core.onnx',
    input_names=['a', 'b'],
    output_names=['output']
)

tensor([[ -8.8714,   8.1044,   7.0799,  ...,   7.6284,   6.2464,  11.3567],
        [-10.9483,   8.8637,   7.2447,  ...,   8.4734,   7.3837,  10.1889],
        [ 13.6563,  -7.1146,  -2.4503,  ...,  -3.2070,  -4.9812,  -1.3354],
        ...,
        [-10.3650,   9.4327,   4.9590,  ...,   6.4376,   5.7977,   9.5832],
        [  2.1728,   3.7296,   7.2709,  ...,   3.2604,   4.3988,   9.8371],
        [  2.8921,   1.7458,   5.4790,  ...,   3.0410,   2.2261,   9.8335]])


In [92]:
import onnxruntime as ort
onnx_execution_providers = ['CPUExecutionProvider']
ort_sess = ort.InferenceSession('core.onnx', providers=onnx_execution_providers)




NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Where(9) node with name '/net/Where_2'

In [93]:
import onnx

In [94]:
onnx.load('core.onnx')

ir_version: 7
producer_name: "pytorch"
producer_version: "1.13.1"
graph {
  node {
    input: "item_embedding.weight"
    input: "a"
    output: "/item_embedding/Gather_output_0"
    name: "/item_embedding/Gather"
    op_type: "Gather"
  }
  node {
    output: "/net/Constant_output_0"
    name: "/net/Constant"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 7
        raw_data: "\000\000\000\000\000\000\000\000"
      }
      type: TENSOR
    }
  }
  node {
    input: "a"
    input: "/net/Constant_output_0"
    output: "/net/Greater_output_0"
    name: "/net/Greater"
    op_type: "Greater"
  }
  node {
    output: "/net/Constant_1_output_0"
    name: "/net/Constant_1"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        dims: 15
        data_type: 7
        raw_data: "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\002\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000\004\000\000\000\000\000\000