In [1]:
import torch
from torchvision.models import resnet18
from torch import nn, optim
from torch import tensor
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np
from torch.optim import lr_scheduler
from image_classification_simulation.data.office31_loader import Office31Loader
from torch.nn.functional import softmax

# Use the CUDA device when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# Training/validation data come from the Amazon images (data_dir); the DSLR
# images (eval_dir) are kept aside for the separate 'eval' split.
office_loader = Office31Loader(
    data_dir="../examples/data/domain_adaptation_images/amazon/images",
    eval_dir="../examples/data/domain_adaptation_images/dslr/images/",
    hyper_params={"num_workers": 2, "batch_size": 32},
)
office_loader.setup('fit')
train_loader, val_loader = office_loader.train_dataloader(), office_loader.val_dataloader()

image size set to: 224


In [3]:
from image_classification_simulation.models.resnet_baseline import Resnet

# Office-31 contains 31 object categories. The previous value (964) matches the
# Omniglot background-set class count and was most likely copied from another
# config; "num_classes" presumably sizes the classification head.
NUM_CLASSES = 31

hparams = {
    # NOTE(review): "size" is interpreted inside Resnet (not visible here); 964
    # may be the same copy-paste leftover -- confirm before relying on it.
    "size": 964,
    "loss": "CrossEntropyLoss",
    "pretrained": True,
    "num_classes": NUM_CLASSES,
}
model = Resnet(hparams).to(device)
model

Resnet(
  (loss_fn): CrossEntropyLoss()
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.

In [4]:
def evaluate(model, test_loader, device=None):
    """Compute the top-1 classification accuracy of ``model`` on ``test_loader``.

    Args:
        model: a ``torch.nn.Module`` mapping a batch of images to logits whose
            dimension 1 indexes the classes.
        test_loader: an iterable yielding ``(images, labels)`` batches, e.g. a
            ``torch.utils.data.DataLoader``.
        device: device to run inference on. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        Accuracy as a percentage in ``[0, 100]``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    correct, total = 0, 0
    try:
        # Inference only: no_grad avoids building an autograd graph per batch.
        with torch.no_grad():
            for batch_images, batch_labels in test_loader:
                logits = model(batch_images.to(device))
                # argmax over logits equals argmax over softmax(logits), so skip softmax.
                preds = logits.argmax(dim=1)
                correct += (preds == batch_labels.to(device)).sum().item()
                total += batch_images.size(0)
    finally:
        # Restore the caller's train/eval mode so evaluation has no lasting side effect.
        model.train(was_training)

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")
    return 100 * correct / total


In [6]:
# from tqdm import tqdm
from torch.optim import lr_scheduler
# ---- Optimisation setup ------------------------------------------------------
LEARNING_RATE = 1e-4
MOMENTUM = 0.9
LR_STEP_SIZE = 30  # epochs between learning-rate decays; StepLR needs an integer epoch count
LR_GAMMA = 0.1     # multiplicative factor applied to the LR every LR_STEP_SIZE epochs

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
# step_size used to be 0.9 (a float), which is not a valid number of epochs.
scheduler = lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

# ---- Training loop -----------------------------------------------------------
# NOTE(review): the saved run of this cell failed with an IndexError raised in a
# DataLoader worker (Subset index out of range). That points at the split/sampler
# built by Office31Loader rather than at this loop -- TODO investigate there.
log_update_frequency = 1  # run validation every N epochs
epochs = 100

for epoch in range(1, epochs + 1):  # inclusive: runs exactly `epochs` epochs
    model.train()  # evaluate() may leave the model in eval mode
    all_loss = []
    correct, seen = 0, 0

    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        all_loss.append(loss_value)
        # argmax over logits is enough for accuracy; softmax would not change it.
        correct += (logits.argmax(dim=1) == batch_labels).sum().item()
        seen += batch_labels.size(0)

    if seen == 0:
        raise RuntimeError("train_loader yielded no samples")

    train_accuracy = 100 * correct / seen
    mean_loss = sum(all_loss) / len(all_loss)
    print('end of epoch {} mean loss is {} train accuracy is {}.'.format(epoch, mean_loss, train_accuracy))

    # Decay the LR once per epoch, independently of how often we log.
    scheduler.step()

    if epoch % log_update_frequency == 0:
        val_accuracy = evaluate(model, val_loader)
        print('Last batch loss {} and validation accuracy {}'.format(loss_value, val_accuracy))
        print('learning rate is now: ', scheduler.get_last_lr())



IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/mila/s/sina.sarparast/.conda/envs/test/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/mila/s/sina.sarparast/.conda/envs/test/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/mila/s/sina.sarparast/.conda/envs/test/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/mila/s/sina.sarparast/.conda/envs/test/lib/python3.8/site-packages/torch/utils/data/dataset.py", line 471, in __getitem__
    return self.dataset[self.indices[idx]]
IndexError: list index out of range


In [12]:
# Accuracy (%) of the trained model on the loader's held-out 'test' split.
office_loader.setup('test')
test_loader = office_loader.test_dataloader()
test_accuracy = evaluate(model, test_loader)
test_accuracy

86.01895734597156

## Evaluation on DSLR subset

In [13]:
# Accuracy (%) on the 'eval' split, presumably built from eval_dir (the DSLR images).
office_loader.setup('eval')
eval_loader = office_loader.eval_dataloader()
eval_accuracy = evaluate(model, eval_loader)
eval_accuracy

58.63453815261044

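The model does not transfer well to the DSLR domain. Accuracy drops from about 86% on the Amazon test split to about 59% on the DSLR images. A likely cause is that the DSLR photos show objects against real-world backgrounds, whereas the Amazon images the model was trained on are product shots on plain backgrounds. (Caveat: in this saved run the training cell above ended with an `IndexError`, so these figures reflect whatever training had completed before that error.)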