In [1]:
import torch
from torchvision.models import resnet18
from torch import nn, optim
from torch import tensor
# import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np
from torch.optim import lr_scheduler
from image_classification_simulation.data.office31_loader import Office31Loader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:

office_loader = Office31Loader(data_dir="../examples/data/amazon", hyper_params={"num_workers": 1, 'batch_size': 32})
office_loader.setup()
train_loader = office_loader.train_dataloader()
val_loader = office_loader.val_dataloader()
test_loader = office_loader.test_dataloader()

In [None]:
def dfs_freeze(model):
    for param in model.parameters():
      param.requires_grad = False

def dfs_unfreeze(model):
    for param in model.parameters():
      param.requires_grad = True

class Resnet(nn.Module):
    def __init__(self, feature_extractor: nn.Module):
        super(Resnet, self).__init__()
        self.feature_extractor = feature_extractor
        # dfs_freeze(self.feature_extractor)
        self.flatten = nn.Flatten()
        self.linear1 = torch.nn.Linear(1000, 512)
        self.linear2 = torch.nn.Linear(512, office_loader.num_unique_labels)
        self.activation = torch.nn.ReLU()

    def forward(
        self,
        batch_images: torch.Tensor
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        # Extract the features of support and query images
        z_x = self.feature_extractor.forward(batch_images)
        z_x = self.flatten(z_x)
        z_x = self.linear1(z_x)
        z_x = self.activation(z_x)
        logits = self.linear2(z_x)

        return logits

model = None
convolutional_network = resnet18(pretrained=False)
model = Resnet(convolutional_network)
model = model.to(device)
print(model)

In [5]:
from torch.nn.functional import softmax
def evaluate(
    test_loader):
    """
    Returns the number of correct predictions of query labels, and the total number of predictions.
    """
    preds, true  = [], []
    for batch_images, batch_labels in test_loader:
      logits = model(batch_images.to(device)).detach().data
      probs = softmax(logits,0)
      preds.extend( torch.argmax(probs,1).tolist() )
      true.extend(batch_labels.tolist())

    preds, true = np.array(preds), np.array(true)
    return 100*( preds == true ).sum().item() / len(true)



In [None]:
from tqdm import tqdm
from torch.optim import lr_scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001)
# optimizer = optim.SGD( model.parameters(), lr=0.0001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=0.9, gamma=0.1)


# Train the model yourself with this cell
log_update_frequency = 1

all_loss = []
model.train()
epochs = 100
for epoch in range(1,epochs):
  preds, true  = [], []
  for batch_images,batch_labels in train_loader:
    optimizer.zero_grad()
    logits = model(
        batch_images.to(device)
    )

    loss = criterion(logits, batch_labels.to(device))
    loss.backward()
    optimizer.step()


    loss_value = loss.item()
    all_loss.append(loss_value)

    probs = torch.nn.functional.softmax(logits,0)
    preds.extend( torch.argmax(probs,1).tolist() )
    true.extend(batch_labels.tolist())

  preds, true = np.array(preds), np.array(true)
  train_accuracy = 100*( preds == true ).sum().item() / len(true)
  print('end of epoch {} total loss is {} train accuracy is {}.'.format(epoch,np.array(all_loss).sum()/epochs, train_accuracy ) )

  # if epoch == 3:
  #   dfs_unfreeze(model)
  #   print('weights are unfrozen!')

  if epoch % log_update_frequency == 0:
    print('Loss {} and validation accuracy {}: '.format(loss_value, evaluate(val_loader) ) )
    scheduler.step()
    print('learning rate updated to : ',scheduler.get_last_lr())

In [7]:
evaluate(test_loader)

31.66489172878949