In [1]:
from ani2x import np, torch, OrderedDict, load_pretrained, get_consts_ani2x, init_params, torchani

In [2]:
def get_layer_sizes(max_size=256, n_steps=8):
    """Return linearly decreasing layer widths from ``max_size`` down to ``max_size / n_steps``.

    Entry k (k = 0 .. n_steps - 1) is ``max_size * (n_steps - k) // n_steps``.
    With the defaults this is ``[256, 224, 192, 160, 128, 96, 64, 32]``.

    Parameters
    ----------
    max_size : int, default 256
        Width of the first (largest) layer.
    n_steps : int, default 8
        Number of widths to produce; consecutive widths differ by ``max_size / n_steps``.

    Returns
    -------
    list of int

    Raises
    ------
    ValueError
        If ``n_steps`` is smaller than 1.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")
    # Integer arithmetic keeps every width exact: no running float product that
    # could round just below an integer before truncation.
    return [max_size * k // n_steps for k in range(n_steps, 0, -1)]

def load_pretrained_frozen(id_=0, n_layers_add=2):
    """Build an ANIModel whose per-species networks reuse a frozen pretrained trunk.

    For every species in ``get_consts_ani2x().species``:

    * the network from ``load_pretrained(id_=id_)`` has its final Linear layer
      dropped, and every remaining layer is frozen (``requires_grad=False``);
    * ``n_layers_add`` new trainable ``Linear -> CELU(alpha=0.1)`` layers are
      appended, followed by a new ``Linear(..., 1)`` output layer;
    * widths of the new layers are taken from ``get_layer_sizes()``, continuing
      from the width that fed the original output layer;
    * row 0 of the first new Linear layer is overwritten with the original
      output layer's weights and bias. That unit's pre-activation therefore
      equals the original network's output.

    Parameters
    ----------
    id_ : int, default 0
        Forwarded unchanged to ``load_pretrained``.
    n_layers_add : int, default 2
        Number of new hidden ``Linear -> CELU`` layers inserted before the new
        output layer.

    Returns
    -------
    torchani.ANIModel
        Built from an OrderedDict mapping each species to its new Sequential.

    Raises
    ------
    ValueError
        If ``n_layers_add`` is negative, or the width feeding a species'
        original output layer is not in ``get_layer_sizes()``, or that list has
        too few smaller widths left to build ``n_layers_add`` layers.
    """
    if n_layers_add < 0:
        raise ValueError(f"n_layers_add must be non-negative, got {n_layers_add}")
    model_orig = load_pretrained(id_=id_)
    model_new = OrderedDict()
    consts_ani2x = get_consts_ani2x()
    layer_sizes = get_layer_sizes()
    for species in consts_ani2x.species:
        nn_orig = model_orig[species]
        out_layer_orig = nn_orig[-1]
        # Everything except the original output layer is kept and frozen.
        # Slicing a Sequential shares the same module objects, so freezing the
        # slice freezes those modules.
        trunk = nn_orig[:-1]
        n_frozen = len(trunk)
        for parameter in trunk.parameters():
            parameter.requires_grad_(False)
        # The new layers continue the width schedule from the original output
        # layer's input width.
        last_size = out_layer_orig.in_features
        try:
            li = layer_sizes.index(last_size)
        except ValueError:
            raise ValueError(
                f"Species {species!r}: output-layer input width {last_size} "
                f"is not in layer sizes {layer_sizes}"
            ) from None
        sizes = layer_sizes[li:li + n_layers_add + 1]
        if len(sizes) != n_layers_add + 1:
            raise ValueError(
                f"Species {species!r}: cannot add {n_layers_add} layers after "
                f"width {last_size}; remaining sizes are {layer_sizes[li:]}"
            )
        # New head: (Linear -> CELU) per added layer, then a Linear to one output.
        # Modules are created in the same order as a hand-written Sequential would.
        head_layers = []
        for in_size, out_size in zip(sizes[:-1], sizes[1:]):
            head_layers.append(torch.nn.Linear(in_size, out_size))
            head_layers.append(torch.nn.CELU(alpha=0.1))
        head_layers.append(torch.nn.Linear(sizes[-1], 1))
        nn_add = torch.nn.Sequential(*head_layers)
        # Initialize parameters according to Meli's initialization (optional)
        nn_add.apply(init_params)
        nn_new = torch.nn.Sequential(*trunk, *nn_add)
        # Internal invariant: the reused trunk must not be trainable.
        assert not any(p.requires_grad for p in nn_new[:n_frozen].parameters())
        # Seed one unit of the first new Linear layer with the original output
        # layer, so the original prediction is carried forward as a feature.
        first_new = nn_new[n_frozen]
        with torch.no_grad():
            first_new.weight[0] = out_layer_orig.weight[0]
            first_new.bias[0] = out_layer_orig.bias[0]
        model_new[species] = nn_new
    return torchani.ANIModel(model_new)

In [3]:
# Assemble the frozen-trunk model (id_=0 is the function's default).
model = load_pretrained_frozen(id_=0)

In [4]:
# Display the assembled model's module tree as this cell's output.
model

ANIModel(
  (H): Sequential(
    (0): Linear(in_features=1008, out_features=256, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=256, out_features=192, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=192, out_features=160, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=160, out_features=128, bias=True)
    (7): CELU(alpha=0.1)
    (8): Linear(in_features=128, out_features=96, bias=True)
    (9): CELU(alpha=0.1)
    (10): Linear(in_features=96, out_features=1, bias=True)
  )
  (C): Sequential(
    (0): Linear(in_features=1008, out_features=224, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=224, out_features=192, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=192, out_features=160, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=160, out_features=128, bias=True)
    (7): CELU(alpha=0.1)
    (8): Linear(in_features=128, out_features=96, bias=True)
    (9): CELU(alpha=0.1)
    (10): Linear(i

In [None]:
# [f for f in model_pre['H'][:5].named_parameters()]
# model_pre['H'][:5].state_dict()
# layer = torch.nn.Linear(in_features=160, out_features=100)
# torch.nn.Sequential(model_pre['H'][:5], layer)
# [f for f in model_pre.named_parameters()]

In [None]:
# model_pre['H'][-1].weight.requires_grad = False
# model_pre['H'][-1].bias.requires_grad
# [f for f in model_pre.named_parameters()]
# model_pre.state_dict().keys()

# torch.load('./results_pre/best.pth', map_location=torch.device('cpu'))
