In [41]:
import torch
import os
import numpy as np
from pytorch_utils.pytorch_utils import count_parameters

In [5]:
home_path = os.path.expanduser('~')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device {device}')

model_dir = f'{home_path}/projects/oads_fixation_crop_models'

Using device cpu


In [3]:
os.listdir(model_dir)

['random_30000_monotonic_1.0.pth',
 'random_5000_monotonic_1.0.pth',
 'fixation_30000_monotonic_0.1.pth',
 'random_30000_monotonic_0.1.pth',
 'fixation_5000_monotonic_0.1.pth',
 'fixation_15000_monotonic_1.0.pth',
 'fixation_30000_monotonic_1.0']

In [8]:
from pytorch_utils.resnet10 import ResNet10
import sys
sys.path.append('/home/nmuller/projects/oads-van/_code/')

In [16]:
from nn.architecture import ConvVarAutoencoder, VarResNet
from torch import nn

In [24]:
def convert_state_dict(state_dict):
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('module.'):
            name = k[7:]  # remove `module.`
        else:
            name = k
        new_state_dict[name] = v

    return new_state_dict

In [54]:
model = VarResNet()
model.decoder = nn.Sequential(
    nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.ConvTranspose2d(128, 64 , kernel_size=4, stride=2, padding=1, bias=False),
    nn.ConvTranspose2d(64 , 32 , kernel_size=4, stride=2, padding=1, bias=False),
    nn.ConvTranspose2d(32 , 3  , kernel_size=4, stride=2, padding=1, bias=False),
)
model = model.to(device)

In [55]:
weight_path = os.path.join(model_dir, f'random_30000_monotonic_1.0.pth')
state_dict = convert_state_dict(torch.load(weight_path, map_location=device)['model_state_dict'])
model.load_state_dict(state_dict)

<All keys matched successfully>

In [80]:
def new_forward(self, x):
    mu, logvar = self.encode(x)
    z = self.reparameterize(mu, logvar)
    x_recon = self.decode(z)
    return x_recon

In [56]:
# self.avgpool = nn.AdaptiveAvgPool2d(output_size=((1, 1)))
# self.fc = nn.Linear(in_features=512, out_features=n_output_channels, bias=True)
model.decoder = nn.Sequential(
    nn.AdaptiveAvgPool2d(output_size=((1, 1))),
    nn.Flatten(start_dim=1),
    nn.Linear(in_features=512, out_features=21, bias=True)
)

In [81]:
bound_method = new_forward.__get__(model, model.__class__)
setattr(model, 'forward', bound_method)

In [57]:
for child in model.children():
    for param in child.parameters():
        param.requires_grad = False

for child in model.decoder.children():
    for param in child.parameters():
        param.requires_grad = True

In [58]:
count_parameters(model)

10773

In [90]:
input = torch.tensor(np.random.rand(1, 3, 224, 224)).type(torch.DoubleTensor)
model = model.type(torch.DoubleTensor)

In [91]:
out = model(input)

In [92]:
len(out), type(out)

(1, torch.Tensor)

In [93]:
out.shape#, out[1].shape, out[2].shape

torch.Size([1, 21])