In [5]:
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    """Single fully connected layer used to demonstrate forward hooks.

    Args:
        in_features: size of each input sample (default 10).
        out_features: size of each output sample (default 1).
    """

    def __init__(self, in_features: int = 10, out_features: int = 1):
        super().__init__()
        # Kept under the attribute name ``fc`` because the hook is registered on ``model.fc``.
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer; ``x`` has shape (..., in_features)."""
        return self.fc(x)

# Seed torch's global RNG before creating the model so that both the Linear
# layer's initial weights and the random input drawn for the forward pass are
# reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Initialize the model
model = SimpleModel()

# Hook that reports what flows through the module it is attached to
def forward_hook(module, input, output):
    """Print the hooked module's class name, its positional inputs, and its output.

    Follows the ``hook(module, args, output)`` signature PyTorch uses for
    ``register_forward_hook``. Returns None, so the module's output is not replaced.
    """
    print(
        f"Inside {type(module).__name__}",
        f"Input: {input}",
        f"Output: {output}",
        sep="\n",
    )

# Attach the printing hook to the Linear layer only (not the whole model), so it
# fires once per forward pass with that layer's input tuple and output.
hook_handle = model.fc.register_forward_hook(forward_hook)
try:
    # One sample with 10 features, matching the model's Linear(10, 1) layer.
    x = torch.randn(1, 10)
    output = model(x)
finally:
    # Always detach the hook, even if the forward pass raises. Otherwise every
    # re-run of this cell would stack another hook on model.fc and print repeatedly.
    hook_handle.remove()


Inside Linear
Input: (tensor([[ 0.4914, -0.4579,  0.1160,  0.8799,  1.3486,  0.5487,  0.1196,  0.1384,
         -1.3128,  0.4260]]),)
Output: tensor([[-0.0956]], grad_fn=<AddmmBackward0>)


In [7]:
import dnnlib
from torch_utils.download_util import check_file_by_key
import pickle
import torch
import os

In [3]:
model_path, classifier_path = check_file_by_key('cifar10')
# Fall back to CPU so the cell still runs on machines without a CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): pickle.load executes arbitrary code embedded in the file;
# only load checkpoints from a trusted source.
with dnnlib.util.open_url(model_path) as f:
    # The checkpoint dict's 'ema' entry is the network used below
    # (presumably the EMA weights - TODO confirm).
    net = pickle.load(f)['ema'].to(device)
# Override the network's sigma range. These look like the EDM noise-level
# defaults, but that is not verifiable here - TODO confirm against the sampler config.
net.sigma_min = 0.002
net.sigma_max = 80.0

Model already exists: ../amed-solver-main/src/cifar10/edm-cifar10-32x32-uncond-vp.pkl


In [4]:
# Display the loaded network's module tree as the cell's rich output (long; truncated when saved).
net

EDMPrecond(
  (model): SongUNet(
    (map_noise): PositionalEmbedding()
    (map_augment): Linear()
    (map_layer0): Linear()
    (map_layer1): Linear()
    (enc): ModuleDict(
      (32x32_conv): Conv2d()
      (32x32_block0): UNetBlock(
        (norm0): GroupNorm()
        (conv0): Conv2d()
        (affine): Linear()
        (norm1): GroupNorm()
        (conv1): Conv2d()
        (skip): Conv2d()
      )
      (32x32_block1): UNetBlock(
        (norm0): GroupNorm()
        (conv0): Conv2d()
        (affine): Linear()
        (norm1): GroupNorm()
        (conv1): Conv2d()
      )
      (32x32_block2): UNetBlock(
        (norm0): GroupNorm()
        (conv0): Conv2d()
        (affine): Linear()
        (norm1): GroupNorm()
        (conv1): Conv2d()
      )
      (32x32_block3): UNetBlock(
        (norm0): GroupNorm()
        (conv0): Conv2d()
        (affine): Linear()
        (norm1): GroupNorm()
        (conv1): Conv2d()
      )
      (16x16_down): UNetBlock(
        (norm0): GroupNorm

In [11]:
def find_latest_snapshot(exp_number, exp_dir='exps'):
    """Return the path of the highest-numbered ``.pkl`` snapshot for an experiment.

    Args:
        exp_number: experiment number as a string (e.g. '1' or '00001'). It is
            zero-padded to 5 digits and compared with the part of each directory
            name in ``exp_dir`` before the first '-'.
        exp_dir: directory holding the experiment run directories.

    Returns:
        ``os.path.join(exp_dir, run_dir, snapshot)`` for the snapshot whose name
        carries the largest integer (the text after the last '-' and before the
        first '.', e.g. 'network-snapshot-000010.pkl' -> 10).

    Raises:
        FileNotFoundError: no run directory matches, or it contains no ``.pkl`` file.
    """
    prefix = exp_number.zfill(5)
    # sorted() keeps the choice deterministic if several directories share the prefix.
    run_dirs = sorted(d for d in os.listdir(exp_dir) if d.split('-')[0] == prefix)
    if not run_dirs:
        raise FileNotFoundError(f'No experiment directory starting with "{prefix}-" in "{exp_dir}"')
    run_dir = os.path.join(exp_dir, run_dirs[0])
    snapshots = [f for f in os.listdir(run_dir) if f.endswith("pkl")]
    if not snapshots:
        raise FileNotFoundError(f'No .pkl snapshot found in "{run_dir}"')
    latest = max(snapshots, key=lambda name: int(name.split("-")[-1].split(".")[0]))
    return os.path.join(run_dir, latest)


# NOTE(review): the recorded output of this cell shows a path under "./exp/",
# but EXP_DIR is 'exps' - confirm which directory is current.
EXP_DIR = 'exps'
predictor_path = '00001'
if not predictor_path.endswith('pkl'):      # an experiment number, not a file path/URL
    predictor_path = find_latest_snapshot(predictor_path, EXP_DIR)
print(f'Loading AMED predictor from "{predictor_path}"...')
# NOTE(review): pickle.load executes arbitrary code from the file; load trusted checkpoints only.
with dnnlib.util.open_url(predictor_path, verbose=True) as f:
    AMED_predictor = pickle.load(f)['model']

Loading AMED predictor from "./exp/00001-cifar10-4-5-amed-heun-1-uni1.0-afs/network-snapshot-000010.pkl"...


In [27]:
# Inspect the loaded predictor's `scale_dir` attribute (its meaning is not visible here - TODO confirm).
AMED_predictor.scale_dir


0.01

In [21]:
from torchinfo import summary

# Shapes of the three inputs torchinfo feeds to AMED_predictor's forward pass.
# What each input represents is not visible in this notebook - TODO confirm.
PREDICTOR_INPUT_SIZES = [(8, 8, 8), (1, 1, 1, 1), (1, 1, 1, 1)]
summary(AMED_predictor, input_size=PREDICTOR_INPUT_SIZES)

Layer (type:depth-idx)                   Output Shape              Param #
AMED_predictor                           [8, 1]                    --
├─PositionalEmbedding: 1-1               [1, 8]                    --
├─Linear: 1-2                            [1, 8]                    (72)
├─PositionalEmbedding: 1-3               [1, 8]                    --
├─Linear: 1-4                            [1, 8]                    (recursive)
├─Linear: 1-5                            [8, 128]                  (8,320)
├─Linear: 1-6                            [8, 4]                    (516)
├─Linear: 1-7                            [8, 1]                    (21)
├─Sigmoid: 1-8                           [8, 1]                    --
├─Linear: 1-9                            [8, 1]                    (21)
├─Sigmoid: 1-10                          [8, 1]                    --
Total params: 8,950
Trainable params: 0
Non-trainable params: 8,950
Total mult-adds (M): 0.07
Input size (MB): 0.00
Forward/backward