In [1]:
from intromlproject.utils.read_dataset import read_dataset, transform_dataset, get_data_loader

train_data, val_data, test_data = read_dataset("intromlproject/cub", transform_dataset())
train_loader, val_loader, test_loader = get_data_loader(train_data, val_data, test_data, batch_size=32)


  from .autonotebook import tqdm as notebook_tqdm


Read dataset successfully
Data loader successfully


In [2]:
dataset_sizes = {
    'train': len(train_loader.dataset),
    'val': len(val_loader.dataset),
    'test': len(test_loader.dataset)
}

dataloaders = {
    'train': train_loader,
    'val': val_loader,
    'test': test_loader
}

In [3]:
class_names = train_data.classes
print(class_names)

['001.Black_footed_Albatross', '002.Laysan_Albatross', '003.Sooty_Albatross', '004.Groove_billed_Ani', '005.Crested_Auklet', '006.Least_Auklet', '007.Parakeet_Auklet', '008.Rhinoceros_Auklet', '009.Brewer_Blackbird', '010.Red_winged_Blackbird', '011.Rusty_Blackbird', '012.Yellow_headed_Blackbird', '013.Bobolink', '014.Indigo_Bunting', '015.Lazuli_Bunting', '016.Painted_Bunting', '017.Cardinal', '018.Spotted_Catbird', '019.Gray_Catbird', '020.Yellow_breasted_Chat', '021.Eastern_Towhee', '022.Chuck_will_Widow', '023.Brandt_Cormorant', '024.Red_faced_Cormorant', '025.Pelagic_Cormorant', '026.Bronzed_Cowbird', '027.Shiny_Cowbird', '028.Brown_Creeper', '029.American_Crow', '030.Fish_Crow', '031.Black_billed_Cuckoo', '032.Mangrove_Cuckoo', '033.Yellow_billed_Cuckoo', '034.Gray_crowned_Rosy_Finch', '035.Purple_Finch', '036.Northern_Flicker', '037.Acadian_Flycatcher', '038.Great_Crested_Flycatcher', '039.Least_Flycatcher', '040.Olive_sided_Flycatcher', '041.Scissor_tailed_Flycatcher', '042.Ver

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import timm
import copy
import os


In [5]:
model_name = 'vit_base_patch16_224'
model = timm.create_model(model_name, pretrained=True)
model.head = nn.Linear(model.head.in_features, len(class_names))

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(device)

cuda:0


In [8]:
num_epochs = 5

In [9]:
def train_model(model, criterion, optimizer, num_epochs=10):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward
                # Track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    print(f'Best val Acc: {best_acc:4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [10]:
model_ft = train_model(model, criterion, optimizer, num_epochs=num_epochs)

Epoch 0/4
----------


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


train Loss: 2.2990 Acc: 0.4627
val Loss: 1.1035 Acc: 0.6927

Epoch 1/4
----------
train Loss: 0.6658 Acc: 0.8019
val Loss: 1.0830 Acc: 0.7244

Epoch 2/4
----------
train Loss: 0.3513 Acc: 0.8906
val Loss: 0.9705 Acc: 0.7527

Epoch 3/4
----------
train Loss: 0.2150 Acc: 0.9336
val Loss: 1.0239 Acc: 0.7434

Epoch 4/4
----------
train Loss: 0.1620 Acc: 0.9513
val Loss: 1.0758 Acc: 0.7527

Best val Acc: 0.752747


In [11]:
torch.save(model_ft.state_dict(), 'fine_tuned_vit.pth')