### Define the feature-extractor model

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import os
# Expose only the first two physical GPUs to this process. This takes effect
# only if CUDA has not been initialised yet in the kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
print(' # of GPUs available : ', torch.cuda.device_count())

# Emit a training log line every `log_interval` mini-batches.
log_interval = 500

 # of GPUs available :  2


In [None]:
class Net(nn.Module):
    """Three conv blocks followed by a two-layer classifier head.

    Each conv block is conv(3x3) -> BatchNorm -> ReLU -> 2x2 max-pool, so a
    1x28x28 input shrinks spatially 28 -> 13 -> 5 -> 1, leaving a 128-element
    vector per sample. The head maps that vector 128 -> 52 -> 10 logits.

    Attribute names (conv1, b1, ..., fcn2) are the ``state_dict`` keys used
    when a checkpoint is loaded, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()

        # Block 1: 1x28x28 --conv--> 64x26x26 --pool--> 64x13x13
        self.conv1 = nn.Conv2d(1, 64, 3)
        self.b1 = nn.BatchNorm2d(64)
        # A single parameter-free pooling module is reused by all three blocks.
        self.pool = nn.MaxPool2d(2, 2)

        # Block 2: 64x13x13 --conv--> 128x11x11 --pool--> 128x5x5
        self.conv2 = nn.Conv2d(64, 128, 3)
        self.b2 = nn.BatchNorm2d(128)

        # Block 3: 128x5x5 --conv--> 128x3x3 --pool--> 128x1x1
        self.conv3 = nn.Conv2d(128, 128, 3)
        self.b3 = nn.BatchNorm2d(128)

        # Classifier head (newly designed): 128 -> 52 -> 10
        self.fcn1 = nn.Linear(128 * 1 * 1, 52)
        self.bfn1 = nn.BatchNorm1d(52)
        self.fcn2 = nn.Linear(52, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for a (batch, 1, 28, 28) input."""
        conv_blocks = (
            (self.conv1, self.b1),
            (self.conv2, self.b2),
            (self.conv3, self.b3),
        )
        for conv, norm in conv_blocks:
            x = self.pool(F.relu(norm(conv(x))))

        flat = x.view(-1, self.num_flat_features(x))

        # Dropout is active only while the module is in training mode.
        hidden = F.dropout(F.relu(self.bfn1(self.fcn1(flat))), training=self.training)
        return self.fcn2(hidden)

    def num_flat_features(self, x):
        """Return the number of elements per sample (product of all non-batch dims)."""
        return x.size()[1:].numel()

### Load dataset

In [None]:
import torchvision
import torchvision.transforms as transforms
# MNIST images are single-channel, so Normalize needs exactly one mean and one std.
# A 3-tuple (RGB-style) here fails to broadcast in-place on 1x28x28 tensors in
# recent torchvision versions (older versions silently used only the first value).
# mean=0.5, std=0.5 maps ToTensor's [0, 1] output to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=16,
                                         shuffle=False, num_workers=2)

### Load Model

In [None]:
# Load the pretrained weights onto the GPU when one is available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net()

# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint_state = torch.load('./results/model.pth', map_location=device)

# strict=False tolerates missing/unexpected keys (presumably so a checkpoint with a
# different FC head can seed the conv layers -- confirm). Report them explicitly so
# layers left at their random initialisation do not go unnoticed.
incompatible = model.load_state_dict(checkpoint_state, strict=False)
print('missing keys   :', incompatible.missing_keys)
print('unexpected keys:', incompatible.unexpected_keys)

model.to(device)

In [None]:
# Summarise the loaded parameters/buffers by name and shape instead of dumping
# every tensor value into the notebook output.
for param_name, param_tensor in model.state_dict().items():
    print('%-30s %s' % (param_name, tuple(param_tensor.shape)))

### Sanity-check the loaded model (forward pass over the test set)

In [None]:
# Evaluate the loaded model once before resuming training. In eval mode dropout
# is disabled (forward passes training=self.training) and BatchNorm uses its
# running statistics.
model.eval()

correct_eval = 0
total_eval = 0
with torch.no_grad():
    for images_eval, labels_eval in testloader:
        images_eval, labels_eval = images_eval.to(device), labels_eval.to(device)
        outputs_eval = model(images_eval)
        # Net ends in fcn2 = Linear(52, 10), so each row holds 10 class logits
        # (not 128-d conv features).
        assert outputs_eval.shape == (images_eval.size(0), 10), outputs_eval.shape
        correct_eval += (outputs_eval.argmax(dim=1) == labels_eval).sum().item()
        total_eval += labels_eval.size(0)

if total_eval:
    print('Loaded-model test accuracy: %.4f (%d / %d)'
          % (correct_eval / total_eval, correct_eval, total_eval))

### Resume training process

### Define a loss function and optimizer

In [None]:
import torch.optim as optim

# Optimiser hyper-parameters used when resuming training.
LEARNING_RATE = 1e-3
MOMENTUM = 0.9

# Net.forward returns raw logits, which is what CrossEntropyLoss expects.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

### Define the training loop (one epoch per call)

In [None]:
def train(epoch, loader=None):
    """Run one training epoch and periodically log the running loss and accuracy.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer``, ``device``
    and ``log_interval``.

    Args:
        epoch: Epoch index; used only in the log line.
        loader: DataLoader yielding ``(inputs, labels)`` batches. Defaults to
            ``testloader`` for backward compatibility.
            NOTE(review): training on the test split leaks it into the model and
            invalidates any later test accuracy -- pass a training loader
            (e.g. built from ``MNIST(train=False)`` -> ``train=True``) instead.

    Returns:
        Tuple ``(mean_loss, accuracy)`` over the whole epoch, or ``(0.0, 0.0)``
        if the loader yielded no samples.
    """
    if loader is None:
        loader = testloader

    # Enable dropout / batch-statistics BatchNorm once for the whole epoch.
    model.train()

    running_loss = 0.0   # loss summed since the last log line
    running_batches = 0  # how many batches contributed to running_loss
    epoch_loss = 0.0
    epoch_batches = 0
    total = 0
    correct = 0

    for i, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Training accuracy, computed on a gradient-free view of the logits.
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        batch_loss = loss.item()
        running_loss += batch_loss
        running_batches += 1
        epoch_loss += batch_loss
        epoch_batches += 1

        # Log every `log_interval` mini-batches. Average over the batches actually
        # summed since the last log (the first log covers a single batch).
        if i % log_interval == 0:
            print('Train Epoch : %2d [%6d, %6d] loss: %.3f TrnAcc: %.3f' %
                  (epoch, total, len(loader.dataset),
                   running_loss / running_batches, correct / total))
            running_loss = 0.0
            running_batches = 0

    if total == 0:
        return 0.0, 0.0
    return epoch_loss / epoch_batches, correct / total

In [None]:
# NOTE(review): best_val_acc / best_epoch are initialised but never updated by
# this loop; kept because later cells may read them.
best_val_acc = 0.0
best_epoch = 0

# Resume training from the loaded weights for a fixed number of passes.
num_epochs = 5
for epoch in range(num_epochs):
    train(epoch)

print('Finished Training')