In [55]:
import torch
from torchvision.models import resnet18
from torch import nn, optim
from torch import tensor
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np
from torch.optim import lr_scheduler
from image_classification_simulation.data.office31_loader import Office31Loader
from torch.nn.functional import softmax

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [56]:
office_loader = Office31Loader(data_dir="../examples/data/amazon/images/", hyper_params={"num_workers": 2, 'batch_size': 32})
office_loader.setup('fit')
train_loader = office_loader.train_dataloader()
val_loader = office_loader.val_dataloader()

image size set to: 200


In [64]:
def dfs_freeze(model):
    for param in model.parameters():
      param.requires_grad = False

def dfs_unfreeze(model):
    for param in model.parameters():
      param.requires_grad = True

class Resnet(nn.Module):
    def __init__(self, feature_extractor: nn.Module):
        super(Resnet, self).__init__()
        self.feature_extractor = feature_extractor
        layers = list(self.feature_extractor.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)
        # dfs_freeze(self.feature_extractor)
        self.flatten = nn.Flatten()
        self.linear1 = torch.nn.Linear(512, 31)

    def forward(
        self,
        batch_images: torch.Tensor
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        # Extract the features of support and query images
        # self.feature_extractor.eval()
        # with torch.no_grad():
        z_x = self.feature_extractor.forward(batch_images)

        z_x = self.flatten(z_x)
        logits = self.linear1(z_x)

        return logits

model = None
convolutional_network = resnet18(pretrained=True)
model = Resnet(convolutional_network)
model = model.to(device)
print(model)

Resnet(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [61]:
def evaluate(
  model,
    test_loader):
    """
    Returns the number of correct predictions of query labels, and the total number of predictions.
    """
    preds, true  = [], []
    correct, size = 0 , 0
    model.eval()
    for batch_images, batch_labels in test_loader:
      logits = model(batch_images.to(device)).detach().data
      probs = softmax(logits,1)
      preds = torch.argmax(probs,1)
      correct+= torch.sum( preds == batch_labels.to(device) ).item() 
      size+=batch_images.size(0)

    return 100*correct / size


In [66]:
# from tqdm import tqdm
from torch.optim import lr_scheduler
criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)
# optimizer = optim.Adam(  model.parameters(), lr=2e-5)
optimizer = optim.SGD( model.parameters(), lr=0.0001, momentum=0.9)
# optimizer = optim.SGD( filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001, momentum=0.9)

scheduler = lr_scheduler.StepLR(optimizer, step_size=0.9, gamma=0.1)


# Train the model yourself with this cell
log_update_frequency = 1

all_loss = []
model.train()
epochs = 100
for epoch in range(1,epochs):
  preds, true  = [], []
  correct = 0
  t=0
  for batch_images,batch_labels in train_loader:
    model.train()
    optimizer.zero_grad()
    logits = model(
        batch_images.to(device)
    )

    loss = criterion(logits, batch_labels.to(device))
    loss.backward()
    optimizer.step()


    loss_value = loss.item()
    all_loss.append(loss_value)

    probs = softmax(logits,1)
    preds = torch.argmax(probs,1)#.tolist()
    correct+= torch.sum( preds == batch_labels.to(device) ).item() 
    t+=batch_labels.size(0)

  train_accuracy = 100 * correct/t
  print('end of epoch {} total loss is {} train accuracy is {}.'.format(epoch,np.array(all_loss).mean(), train_accuracy ) )
  all_loss = []
  correct=0
  t=0
  
  # if epoch == 3:
  #   dfs_unfreeze(model)
  #   print('weights are unfrozen!')

  if epoch % log_update_frequency == 0:
    print('Loss {} and validation accuracy {}: '.format(loss_value, evaluate(model,val_loader) ) )
    # scheduler.step()
    print('learning rate updated to : ',scheduler.get_last_lr())

end of epoch 1 total loss is 3.1899641543626784 train accuracy is 14.313880126182966.
Loss 3.1175014972686768 and validation accuracy 27.75800711743772: 
learning rate updated to :  [0.0001]
end of epoch 2 total loss is 2.9586816906929014 train accuracy is 33.12302839116719.
Loss 2.8326778411865234 and validation accuracy 41.637010676156585: 
learning rate updated to :  [0.0001]
end of epoch 3 total loss is 2.7381107926368715 train accuracy is 46.17507886435331.
Loss 2.59792160987854 and validation accuracy 53.736654804270465: 
learning rate updated to :  [0.0001]
end of epoch 4 total loss is 2.5267316937446593 train accuracy is 56.46687697160883.
Loss 2.389148473739624 and validation accuracy 59.430604982206404: 
learning rate updated to :  [0.0001]
end of epoch 5 total loss is 2.328317728638649 train accuracy is 62.973186119873816.
Loss 2.209064483642578 and validation accuracy 63.345195729537366: 
learning rate updated to :  [0.0001]
end of epoch 6 total loss is 2.1472518503665925 t

KeyboardInterrupt: 

: 

In [62]:
office_loader.setup('test')
test_loader = office_loader.test_dataloader()
evaluate(model,val_loader)

84.34163701067615