In [1]:
import os
import time
import copy


import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import pytorch_lightning as pl


import tableprint as tp
import torchmetrics

In [2]:
# Image roots in torchvision ImageFolder layout (one sub-directory per class).
TRAIN_PATH = './data/train'
TEST_PATH = './data/validation'

# Preprocessing shared by both splits: fixed 224x224 resize, conversion to a
# float tensor, then per-channel normalisation. The mean/std values are the
# widely used ImageNet channel statistics.
img_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Build the training split first, then the validation split, with identical preprocessing.
train_data, val_data = (datasets.ImageFolder(split_root, transform=img_transforms)
                        for split_root in (TRAIN_PATH, TEST_PATH))

In [3]:
# DataLoader settings shared by the training and validation loaders.
num_workers = 4
batch_size = 32
# Shuffle the training set: ImageFolder lists its samples grouped by class, so an
# unshuffled loader would feed the optimiser long single-class runs of batches.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           num_workers=num_workers, shuffle=True)
# Validation order does not affect the metrics; keep it fixed for reproducible evaluation.
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size,
                                         num_workers=num_workers, shuffle=False)

In [4]:
class CNN(nn.Module):
    """Small 3-conv / 2-dense image classifier.

    The flattened feature size 2187 (= 3 * 27 * 27) assumes 3-channel 224x224
    inputs: each unpadded 3x3 stride-2 conv maps the side 224 -> 111 -> 55 -> 27.

    Args:
        num_classes: number of output logits produced by ``fc2`` (default 2).
        p_dropout: dropout probability applied before each dense layer (default 0.2).
    """

    def __init__(self, num_classes=2, p_dropout=0.2):
        super().__init__()
        self.p_dropout = p_dropout
        # conv layers (attribute names double as state_dict keys used when loading checkpoints)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=18, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(in_channels=18, out_channels=3, kernel_size=3, stride=2)

        # dense layers
        self.fc1 = nn.Linear(2187, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, X):
        """Map an (N, 3, 224, 224) batch to (N, num_classes) unnormalised logits."""
        X = F.relu(self.conv1(X))  # ReLU applied as a function rather than a layer/module
        X = F.relu(self.conv2(X))
        X = F.relu(self.conv3(X))
        # Flatten per sample. Unlike view(-1, 2187) this keeps the batch dimension,
        # so a wrongly sized input fails in fc1 instead of being silently re-batched.
        X = X.flatten(start_dim=1)
        # F.dropout defaults to training=True; tie it to the module's mode so that
        # model.eval() really disables dropout at inference time.
        X = F.dropout(X, p=self.p_dropout, training=self.training)
        X = F.relu(self.fc1(X))
        X = F.dropout(X, p=self.p_dropout, training=self.training)
        X = self.fc2(X)
        return X

# Loading the Lightning checkpoint into the plain CNN (without PyTorch Lightning)

In [9]:
# Load a PyTorch Lightning checkpoint from ./ckpt/ into a plain nn.Module CNN.
root = './ckpt/'
# os.listdir returns entries in arbitrary order and may include non-checkpoint
# entries (e.g. .ipynb_checkpoints), so pick the first *.ckpt file deterministically.
ckpt_files = sorted(name for name in os.listdir(root) if name.endswith('.ckpt'))
if not ckpt_files:
    raise FileNotFoundError(f'No .ckpt files found in {root!r}')
ckpt = ckpt_files[0]
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only machine;
# load_state_dict copies the values into the CPU model anyway.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
pre_trained_model = torch.load(os.path.join(root, ckpt), map_location='cpu')
base_model_new = CNN()
print(f'Initial State: {base_model_new.state_dict()["fc2.bias"]}')
# Drop the first 6 characters of every key -- presumably a 'model.' prefix added by
# the LightningModule that wrapped the CNN (TODO confirm against the training code).
# Loading this dict strictly makes a missing or unexpected key raise, instead of
# silently leaving the corresponding layer at its random initialisation.
my_model_kvpair = {key[6:]: value
                   for key, value in pre_trained_model['state_dict'].items()}
base_model_new.load_state_dict(my_model_kvpair)
print(f'After Loading: {base_model_new.state_dict()["fc2.bias"]}')

Initial State: tensor([ 0.0271, -0.0105])
After Loading: tensor([ 0.0071, -0.0110])
