In [1]:
# Import necessary modules
import sys
import os
import torch 
# Set root folder to project root
os.chdir(os.path.dirname(os.getcwd()))

# Add root folder to path
sys.path.append(os.getcwd())

from src.utils.config import Hyperparameters
from src.utils.model_loading import load_model
from src.utils.maze_loading import load_mazes
from src.utils.testing import is_correct
from src.utils.analysis import plot_mazes

In [None]:
# Load model
dt_net_original = load_model('dt_net', pretrained='models/dt_net/original.pth')
dt_net = load_model('dt_net', pretrained='models/dt_net/2025-03-12_03:43:44/best.pth')
it_net = load_model('it_net', pretrained='models/it_net/2025-03-26_23:59:57/best.pth')
model = dt_net

# Load mazes
hyperparams = Hyperparameters()
hyperparams.iters = 30
hyperparams.num_mazes = 100
hyperparams.percolation = 0.5
hyperparams.maze_size = 9
inputs, solutions = load_mazes(hyperparams)

# Predict
# predictions = model.predict(inputs, iters=hyperparams.iters)
with torch.no_grad():
    model.eval()
    latents = model.input_to_latent(inputs)
    latents = model.latent_forward(latents, inputs, iters=hyperparams.iters)
    #latents = model.latent_forward_layer(torch.cat([latents, inputs], dim=1))
    outputs = model.latent_to_output(latents)
    predictions = model.output_to_prediction(outputs, inputs)

# Evaluate predictions
corrects = is_correct(inputs, predictions, solutions)
print(f'{corrects.sum()} out of {len(corrects)} predictions are correct.')
print(f'outputs/visuals/mazes/predictions_{model.name}_size-{hyperparams.maze_size}_iters-{hyperparams.iters}')
# Plot results
plot_mazes(
    inputs,
    solutions,
    predictions,
    file_name=f'outputs/visuals/mazes/predictions_{model.name}_size-{hyperparams.maze_size}_iters-{hyperparams.iters}'
)

2025-03-29 15:11:17,000 - src.utils.model_loading - INFO - Loaded dt_net to cuda:0
2025-03-29 15:11:17,011 - src.utils.model_loading - INFO - Loaded dt_net to cuda:0
2025-03-29 15:11:17,033 - src.utils.model_loading - INFO - Loaded it_net to cuda:0
2025-03-29 15:11:17,035 - src.utils.maze_loading - INFO - Attempting 100 mazes to generate 100 mazes with size: 9, percolation: 0.5, and deadend_start: True
2025-03-29 15:11:17,399 - src.utils.maze_loading - INFO - Attempting 200 mazes to generate 100 mazes with size: 9, percolation: 0.5, and deadend_start: True
2025-03-29 15:11:18,085 - src.utils.maze_loading - INFO - Loaded 100 mazes with size: 9, percolation: 0.5, and deadend_start: True


0.0998
0.2405
-0.8036
0.0833
100 out of 100 predictions are correct.
outputs/visuals/mazes/predictions_dt_net_size-9_iters-30


In [3]:
# Print summary of model modules
print("Model Modules:")
for name, module in model.named_modules():
    print(f"{name}: {module}")

# Print summary of model weights, including number of parameters and norm of weight vector
print("\nModel Weights Summary:")
total_params = 0
for name, param in model.named_parameters():
    if param.requires_grad:
        num_params = param.numel()
        total_params += num_params
        weight_norm = torch.norm(param).item()
        print(f"{name}: {num_params} parameters, weight norm: {weight_norm:.4f}")
print(f"\nTotal number of trainable parameters in the model: {total_params}")
print(f"Average number of parameters per layer: {total_params / len(list(model.named_parameters())):.2f}")

Model Modules:
: DTNet(
  (projection): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): ReLU()
  )
  (recur_block): Sequential(
    (0): Conv2d(131, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (gn1): Sequential()
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (gn2): Sequential()
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (gn1): Sequential()
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (gn2): Sequential()
        (shortcut): Sequential()
      )
    )
  )
  (head): Sequential(
    (0): Conv2d(128, 32, kernel_size=(3, 3