In [5]:
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    """Minimal example network: one fully-connected layer mapping 10 features to 1."""

    def __init__(self):
        super().__init__()
        # The only layer; exposed as ``fc`` so hooks can be attached to it directly.
        self.fc = nn.Linear(in_features=10, out_features=1)

    def forward(self, x):
        """Project ``x`` (last dimension of size 10) to a single output feature."""
        projected = self.fc(x)
        return projected

# Initialize the model. Seed torch's global RNG first so the random weight init
# (and the random input drawn later in this cell) is reproducible on re-run.
SEED = 0
torch.manual_seed(SEED)
model = SimpleModel()

# Define a hook function
def forward_hook(module, input, output):
    """Debug forward hook: print the hooked module's class name, its input and its output.

    PyTorch invokes it as ``hook(module, input, output)`` after ``module.forward``;
    ``input`` is the tuple of positional arguments. Returns None, so the module's
    output is left unchanged.
    """
    print(f"Inside {type(module).__name__}")
    # Each value is formatted only when its line is printed, preserving print order.
    for label, value in (("Input", input), ("Output", output)):
        print(f"{label}: {value}")

# Register the forward hook on the linear layer; the returned handle detaches it later.
hook_handle = model.fc.register_forward_hook(forward_hook)

# Perform one forward pass; the hook prints the layer's input/output as a side effect.
# try/finally guarantees the hook is removed even if the forward pass raises,
# so a failed run never leaves a stale hook attached to model.fc.
x = torch.randn(1, 10)
try:
    output = model(x)
finally:
    hook_handle.remove()


In [7]:
import dnnlib
from torch_utils.download_util import check_file_by_key
import pickle
import torch
import os

In [3]:
# Load the pretrained CIFAR-10 network (the 'ema' entry of the checkpoint pickle).
# NOTE: pickle.load can execute arbitrary code — only load checkpoints you trust.
# Fall back to CPU so this cell does not crash on machines without a CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path, classifier_path = check_file_by_key('cifar10')
with dnnlib.util.open_url(model_path) as f:
    net = pickle.load(f)['ema'].to(device)

# Noise-level range set on the loaded network; presumably the sigma range the
# sampler uses downstream — NOTE(review): confirm against the sampling code.
SIGMA_MIN, SIGMA_MAX = 0.002, 80.0
net.sigma_min = SIGMA_MIN
net.sigma_max = SIGMA_MAX

In [4]:
# Inspect the loaded network: as the cell's last expression its repr is displayed (may be long).
net

In [11]:
# Either an explicit checkpoint path ending in 'pkl', or an experiment number
# (zero-padded to 5 digits) whose folder under exps/ is searched for the latest checkpoint.
predictor_path = '00001'


def resolve_predictor_path(spec: str, exps_dir: str = 'exps') -> str:
    """Resolve ``spec`` to the path of a trained AMED predictor checkpoint.

    Args:
        spec: a path ending in ``'pkl'`` (returned unchanged) or an experiment
            number such as ``'1'`` or ``'00001'``.
        exps_dir: directory holding experiment folders named ``<NNNNN>-...``.

    Returns:
        ``spec`` itself, or ``<exps_dir>/<experiment folder>/<checkpoint>`` where the
        checkpoint is the ``*pkl`` file with the largest integer after its last ``-``.

    Raises:
        FileNotFoundError: no experiment folder matches ``spec``, or the matching
            folder contains no ``*pkl`` checkpoint.
        ValueError: a checkpoint name has no integer after its last ``-``.
    """
    if spec.endswith('pkl'):
        return spec
    exp_prefix = spec.zfill(5)
    exp_dirs = sorted(
        name for name in os.listdir(exps_dir)
        if name.split('-')[0] == exp_prefix
        and os.path.isdir(os.path.join(exps_dir, name))
    )
    if not exp_dirs:
        raise FileNotFoundError(f'No experiment folder for "{spec}" under "{exps_dir}"')
    exp_dir = os.path.join(exps_dir, exp_dirs[0])
    ckpts = [name for name in os.listdir(exp_dir) if name.endswith('pkl')]
    if not ckpts:
        raise FileNotFoundError(f'No .pkl checkpoint found in "{exp_dir}"')
    # e.g. 'network-snapshot-000100.pkl' -> 100; the highest index (latest snapshot) wins.
    latest = max(ckpts, key=lambda name: int(name.split('-')[-1].split('.')[0]))
    return os.path.join(exp_dir, latest)


predictor_path = resolve_predictor_path(predictor_path)
print(f'Loading AMED predictor from "{predictor_path}"...')
# NOTE: pickle.load can execute arbitrary code — only load checkpoints you trust.
with dnnlib.util.open_url(predictor_path, verbose=True) as f:
    AMED_predictor = pickle.load(f)['model']

In [27]:
# Display the predictor's `scale_dir` attribute (its meaning is not shown in this notebook — TODO document).
AMED_predictor.scale_dir


In [21]:
from torchinfo import summary
# Layer-by-layer torchinfo summary of the AMED predictor. Each tuple is the shape of a
# dummy input torchinfo feeds to forward(); NOTE(review): what each input represents
# (and whether a batch dimension is included) is not shown here — confirm against the model.
AMED_INPUT_SIZES = [(8, 8, 8), (1, 1, 1, 1), (1, 1, 1, 1)]
summary(AMED_predictor, input_size=AMED_INPUT_SIZES)