In [1]:
import torch
from hydra import compose, initialize
import omegaconf

from model.builders import build_model
from utils.run_utils import configure_reproducibility

In [2]:
initialize(config_path="../conf/local", job_name="debug_dl")
cfg = compose(config_name="test_model_mnist")

In [3]:
cfg.run.wandb_mode='disabled'

In [4]:
print(omegaconf.OmegaConf.to_yaml(cfg))

run:
  device: cpu
  loglevel: INFO
  project_path: /media/shift97/MyPassport/Flavio/repos/seq-mnist
  codename: default_local
  seed: 2147483647
  wandb_mode: disabled
data:
  num_train: 3000
  num_test: 10000
  permute: true
train:
  batch_size: 4
  perc_valid: 0.1
model:
  name: dntm
  n_locations: 1000
  content_size: 8
  address_size: 8
  controller_input_size: 1
  controller_output_size: 10
  controller_hidden_state_size: 100
  ckpt: ${run.project_path}/models/checkpoints/dntm_trained_pmnist_02-06-22.pth



In [5]:
rng = configure_reproducibility(cfg.run.seed)

In [6]:
device = torch.device(cfg.run.device, 0)

In [7]:
model = build_model(cfg, device)

In [8]:
for name, param in model.named_parameters():
    print(name)
    print('-'*len(name))
    print(param.max().item())
    print(param.mean().item(), '+/-', param.std().item())
    print(param.min().item())
    print()

W_output
--------
1.713209867477417
-0.001367107848636806 +/- 0.37482231855392456
-1.5997358560562134

b_output
--------
0.144056037068367
0.10458298027515411 +/- 0.02648264914751053
0.05537264049053192

memory.memory_addresses
-----------------------
0.25264522433280945
-0.0022635748609900475 +/- 0.08650892227888107
-0.3241351842880249

memory.W_hat_hidden
-------------------
0.9974824786186218
0.0021850443445146084 +/- 0.16166022419929504
-0.9490581154823303

memory.W_query
--------------
0.5995498299598694
0.0016025464283302426 +/- 0.13604740798473358
-0.5335453748703003

memory.b_query
--------------
0.37084341049194336
-0.12415798008441925 +/- 0.2623516023159027
-0.6728257536888123

memory.u_sharpen
----------------
0.5539348721504211
-0.021196791902184486 +/- 0.19362500309944153
-0.44481608271598816

memory.b_sharpen
----------------
-0.05614190176129341
-0.05614190176129341 +/- nan
-0.05614190176129341

memory.b_lru
------------
0.1679871380329132
0.1679871380329132 +/- nan
0.16

In [9]:
from data.perm_seq_mnist import get_dataloaders
from torchmetrics.classification import Accuracy

In [10]:
train_dl, valid_data_loader = get_dataloaders(cfg, rng)

In [13]:
loss_fn = torch.nn.NLLLoss()

In [14]:
valid_accuracy = Accuracy().to(device)
valid_epoch_loss = 0
model.eval()
for batch_i, (mnist_images, targets) in enumerate(valid_data_loader):
    print("batch", batch_i)
    model.prepare_for_batch(mnist_images, device)

    mnist_images, targets = mnist_images.to(device), targets.to(device)

    _, outputs = model(mnist_images)
    output = outputs[-1, :, :]
    print(output.T.argmax(dim=1))
    loss_value = loss_fn(output.T, targets)
    valid_epoch_loss += loss_value.item() * mnist_images.size(0)
    print("loss:", loss_value)

    batch_accuracy = valid_accuracy(output.T, targets)
valid_accuracy_at_epoch = valid_accuracy.compute()
valid_epoch_loss /= len(valid_data_loader.sampler)
print(valid_accuracy_at_epoch)
print(valid_epoch_loss)

batch 0
tensor([1, 1, 1, 1])
loss: tensor(5.7175, grad_fn=<NllLossBackward0>)
batch 1
tensor([1, 1, 2, 1])
loss: tensor(6.6564, grad_fn=<NllLossBackward0>)
batch 2


KeyboardInterrupt: 