In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.models as models

import numpy as np
import random
import time

from useful_things.dataloader import *
from useful_things.networks import *
from useful_things.helper_functions import *

Next we load a pretrained model; see https://pytorch.org/docs/stable/torchvision/models.html for the list of pretrained models to choose from. For a different network, print the model first to find the number of input features X of its final fully connected layer, which is the first argument of the replacement nn.Linear(X, 3, bias=True).

In [34]:
# Build a 3-class classifier on top of a pretrained ResNet-18 whose first
# convolution accepts 4 input channels instead of 3.
model = models.resnet18(pretrained=True)

# Freeze every pretrained parameter. Modules constructed below are created
# with requires_grad=True, so only they will be trained.
for param in model.parameters():
    param.requires_grad = False

# Replace the classification head with a fresh 3-way linear layer. The input
# width is read from the existing head so this also works for other backbones
# that expose their head as `fc`.
model.fc = nn.Linear(model.fc.in_features, 3, bias=True)

# Pretrained first-layer filters, laid out as (out_channels, in_channels, kH, kW).
old_conv = model.conv1
old_conv_weights = old_conv.weight.detach().clone()
out_channels, _, kernel_h, kernel_w = old_conv_weights.shape
old_var = old_conv_weights.var().item()

# Filters for the extra input channel: Gaussian noise scaled to the same
# variance as the pretrained filters.
added_conv_weights = np.random.randn(out_channels, 1, kernel_h, kernel_w)
added_conv_weights *= np.sqrt(old_var)
added_conv_weights = torch.from_numpy(added_conv_weights).float()

# Pretrained filters first, the new channel appended along the input-channel axis.
new_conv_weights = torch.cat((old_conv_weights, added_conv_weights), dim=1)

# Same geometry as the original conv1, but with one more input channel.
model.conv1 = nn.Conv2d(
    new_conv_weights.shape[1],
    out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    bias=old_conv.bias is not None,
)

# Copy the weights into the live parameter. Assigning into
# `model.state_dict()[...]` only mutates a temporary dict and would leave the
# new layer with its random initialisation.
with torch.no_grad():
    model.conv1.weight.copy_(new_conv_weights)

# Sanity check: the pretrained filters really are in the new layer.
assert torch.equal(
    model.conv1.weight[:, : old_conv_weights.shape[1]], old_conv_weights
)

print(model)


torch.Size([64, 4, 7, 7])
