In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from sklearn.metrics import accuracy_score
from torchvision.models import resnet18, ResNet18_Weights

from models import ModelWrapper
from utils import get_train_dataset, get_eval_datasets, get_targets

In [2]:
batch_size = 128
lr = 0.0001
epochs_num = 10

In [3]:
weights = ResNet18_Weights.DEFAULT
preprocess = weights.transforms()
resnet = resnet18(weights=weights)

In [4]:
tmp_train_dataset = get_train_dataset(transform=preprocess)
tmp_val_dataset, test_dataset = get_eval_datasets(transform=preprocess)

In [5]:
resnet.fc = nn.Identity()

In [6]:
for param in resnet.parameters():
    param.requires_grad = False

In [7]:
# feed forward data through frozen layers only once to speed up training
def get_features_dataset(model, loader, device=None):
    if device is None:
        device = next(model.parameters()).device

    features = []
    labels = []

    model.eval()
    for x_batch, y_batch in loader:
        x_batch = x_batch.to(device)

        result = model(x_batch)
        features.append(result.detach().cpu())
        labels.append(y_batch.cpu())

    features_tensor = torch.cat(features)
    labels_tensor = torch.cat(labels)

    return torch.utils.data.TensorDataset(features_tensor, labels_tensor)

In [8]:
tmp_train_loader = torch.utils.data.DataLoader(tmp_train_dataset, batch_size=128)
tmp_val_loader = torch.utils.data.DataLoader(tmp_val_dataset, batch_size=128)

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [10]:
resnet.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [11]:
train_dataset = get_features_dataset(resnet, tmp_train_loader)
torch.save(train_dataset.tensors, 'transfer_learning_preprocessed_train.pth')

val_dataset = get_features_dataset(resnet, tmp_val_loader)
torch.save(val_dataset.tensors, 'transfer_learning_preprocessed_val.pth')

In [12]:
generator = torch.Generator().manual_seed(42)

targets = get_targets(train_dataset)
_, counts = targets.unique(return_counts=True)

weights = 1 / counts.float()
sample_weights = weights[targets.long()]

sampler = torch.utils.data.WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
    generator=generator
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, sampler=sampler)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size)

In [13]:
classificator = nn.Sequential(
    nn.Dropout(p=0.6),
    nn.Linear(512, 43))

In [14]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classificator.parameters(), lr=lr)

mw = ModelWrapper(classificator, criterion, optimizer)
mw.set_dataloaders(train_loader, val_loader)
# mw.set_writer(writer_filename, writer_subdirectory)

In [15]:
mw.train(100)

epoch: 0    loss: 3.5935698182959306   val_loss: 2.972222857475281
epoch: 1    loss: 2.876021574558824   val_loss: 2.4912251615524292
epoch: 2    loss: 2.431995736354846   val_loss: 2.18025194644928
epoch: 3    loss: 2.134095173133047   val_loss: 1.959123113155365
epoch: 4    loss: 1.9156980332004967   val_loss: 1.8026808404922485
epoch: 5    loss: 1.7522837414125507   val_loss: 1.6791580390930176
epoch: 6    loss: 1.6188222956999638   val_loss: 1.5871337676048278
epoch: 7    loss: 1.5187039557826576   val_loss: 1.5060655641555787
epoch: 8    loss: 1.4215254709481053   val_loss: 1.4446668457984924
epoch: 9    loss: 1.3508826956224214   val_loss: 1.3850774884223938
epoch: 10    loss: 1.3001562265688151   val_loss: 1.3477281069755553
epoch: 11    loss: 1.2430722439688358   val_loss: 1.3046664905548095
epoch: 12    loss: 1.2099242643876509   val_loss: 1.2745911502838134
epoch: 13    loss: 1.1546711257199922   val_loss: 1.2430265474319457
epoch: 14    loss: 1.1354784406543348   val_loss: 1

In [16]:
resnet.fc = classificator

In [17]:
resnet.eval()

y_true = torch.empty(0)
y_hat = torch.empty(0)

for x, y in test_loader:
    x = x.to(device)
    y_true = torch.cat([y_true, y])

    preds = torch.argmax(resnet(x), dim=1)
    y_hat = torch.cat([y_hat, preds.detach().cpu()])

print(f"Accuracy: {accuracy_score(y_true, y_hat)}")

Accuracy: 0.7388756927949327
