In [1]:
import torch
import random
import numpy as np

from model import LSTM
from trainer import BenchmarkTrainer
from dataloader import GraphEnv, DataLoader

In [2]:
def set_random_seed(seed):
    """Seed every random number generator the notebook uses and pin cuDNN.

    Seeds Python's ``random`` module, NumPy's legacy global generator,
    torch's default generator (``torch.manual_seed``) and torch's CUDA
    generator (``torch.cuda.manual_seed``). It then asks cuDNN for
    deterministic kernels with autotuning disabled.

    Args:
        seed: integer seed handed to every generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)
    # cuDNN autotuning can choose different kernels from run to run.
    # Disable it and request deterministic kernels for repeatable results.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_random_seed(70)

In [3]:
# Grid environment configuration. The node count is derived from the grid
# shape, so the two values cannot drift apart.
grid_rows, grid_cols = 3, 3
n_nodes = grid_rows * grid_cols
# unique=True (below) assigns each state its own observation, so the number of
# observations equals the number of nodes.
n_obs = n_nodes
trajectory_length = 16  # number of node visits in a trajectory
num_desired_trajectories = 30

env = GraphEnv(
    n_items=n_nodes,                     # number of possible observations
    env='grid',
    batch_size=trajectory_length,        # GraphEnv receives the trajectory length via batch_size
    num_desired_trajectories=num_desired_trajectories,
    device=None,
    unique=True,                         # each state is assigned a unique observation if true
    args={"rows": grid_rows, "cols": grid_cols},
)

# Train and test sets come from two separate gen_dataset() calls on the same env.
# NOTE(review): nothing here prevents identical trajectories from appearing in
# both sets. Confirm whether that matters for evaluation.
train_dataset = env.gen_dataset()
test_dataset = env.gen_dataset()

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

In [4]:
hidden_dim = 10
epochs = 10

model = LSTM(env.n_items, env.n_actions, env.size, hidden_dim)

# Report how many trainable parameters the model has before training starts.
n_trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(n_trainable_params)

optimizer = torch.optim.Adam(model.parameters())  # Adam with its default hyper-parameters
criterion = torch.nn.CrossEntropyLoss()

trainer = BenchmarkTrainer(model, train_dataloader, optimizer, criterion)

# NOTE(review): whatever train() returns is bound to `lstm`. Confirm whether it
# is the trained model, a history, or something else.
lstm = trainer.train(epochs)

1499


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]