# Fine-tune Network on Danish Fungi 2020 Dataset

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch

from src.core import models, metrics, training, data, loss_functions
from src.utils import nb_setup
from src.dev import experiments as exp

# Root directory of the dataset (metadata CSVs are read from here) and the
# image sub-directory that is appended to it when building the data loaders.
DATA_DIR = 'data/danish_fungi_dataset/'
TRAIN_SET_DIR = 'train_resized'

# Seed handed to nb_setup.set_random_seed below.
SEED = 42

nb_setup.init()
nb_setup.set_random_seed(SEED)

# Use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

In [None]:
# create training configuration
_experiment_spec = dict(
    data='df2020',
    model='vit_base_224',
    loss='ce',
    opt='sgd',
    no_epochs=30,
    batch_size=64,
    total_batch_size=64,
    learning_rate=0.01,
    weight='none',
    dataset='mini',
    scheduler='reduce_lr_on_plateau',
    # note=''
)
config = exp.create_config(**_experiment_spec)

# Merge the model's own `pretrained_config` into the experiment config.
# A model is instantiated with pretrained=False only to read that attribute.
_model_config = models.get_model(config.model, pretrained=False).pretrained_config
config.update(_model_config)

# Save the config file under the data directory.
_config_path = DATA_DIR + config.specs_name
config.save(_config_path)

# Look up the loss, class-weighting, optimizer and scheduler entries that the
# config refers to by name.
loss_fn = loss_functions.LOSSES[config.loss]
weight_fn = loss_functions.WEIGHTING[config.weight]
opt_fn = training.OPTIMIZERS[config.opt]
sched_fn = training.SCHEDULERS[config.scheduler]

# Metadata files per dataset variant: (train CSV, validation CSV).
DATASETS = {
    'full': (
        'DF20-train_metadata_PROD.csv',
        'DF20-public_test_metadata_PROD.csv',
    ),
    'mini': (
        'DF20M-train_metadata_PROD.csv',
        'DF20M-public_test_metadata_PROD.csv',
    ),
}

print(config)

## Load the Data

In [None]:
# Load the train/validation metadata tables for the dataset variant selected
# in the config (index 0 = train CSV, index 1 = validation CSV).
train_df = pd.read_csv(DATA_DIR + DATASETS[config.dataset][0])
valid_df = pd.read_csv(DATA_DIR + DATASETS[config.dataset][1])

# Class vocabulary: sorted unique species names found in the training split.
classes = np.unique(train_df['scientificName'])
no_classes = len(classes)

# Both splits must contain exactly the same classes. Comparing only the
# counts would accept a validation split that has the same number of classes
# but different names, so the sorted unique names themselves are compared.
_valid_classes = np.unique(valid_df['scientificName'])
assert np.array_equal(classes, _valid_classes), (
    f'Class mismatch between splits: {no_classes} train classes vs '
    f'{len(_valid_classes)} validation classes')

print(f'No classes: {no_classes}')
print(f'Train set length: {len(train_df):,d}')
print(f'Validation set length: {len(valid_df):,d}')

In [None]:
# Create the train/validation transforms from the input size and the image
# mean/std stored in the config.
train_tfms, valid_tfms = data.get_transforms(
    size=config.input_size,
    mean=config.image_mean,
    std=config.image_std,
)

# Arguments shared by both loaders; only the metadata frame, the transforms
# and the shuffling flag differ between them.
_loader_kwargs = dict(
    img_path_col='image_path',
    label_col='scientificName',
    path=DATA_DIR + TRAIN_SET_DIR,
    labels=classes,
    batch_size=config.batch_size,
    num_workers=4,
)
trainloader = data.get_dataloader(
    train_df, transforms=train_tfms, shuffle=True, **_loader_kwargs)
validloader = data.get_dataloader(
    valid_df, transforms=valid_tfms, shuffle=False, **_loader_kwargs)

# Both datasets must map labels to the same ids.
assert trainloader.dataset._label2id == validloader.dataset._label2id

trainloader.dataset.show_items()

## Train the Model

In [None]:
# Create the model (pretrained weights requested) and make sure that every
# parameter is trainable.
model = models.get_model(config.model, no_classes, pretrained=True)
_trainable_flags = [p.requires_grad for p in model.parameters()]
assert np.all(_trainable_flags)

# Create the loss: per-class sample counts, indexed by the training dataset's
# label list, are passed through the configured weighting function.
_class_counts = train_df['scientificName'].value_counts()
freq = _class_counts[trainloader.dataset.labels].values
weights = weight_fn(freq)
if weights is None:
    # the weighting function produced no weights -> unweighted loss
    _loss_weight = None
else:
    _loss_weight = torch.Tensor(weights).to(device)
criterion = loss_fn(weight=_loss_weight)

# Create the trainer. Gradients are accumulated over
# total_batch_size // batch_size steps.
_accumulation_steps = config.total_batch_size // config.batch_size
trainer = training.Trainer(
    model,
    trainloader,
    criterion,
    opt_fn,
    sched_fn,
    validloader=validloader,
    accumulation_steps=_accumulation_steps,
    path=DATA_DIR,
    model_filename=config.model_name,
    history_filename=config.history_file,
    device=device,
)

In [None]:
# Train the model for the number of epochs and with the learning rate
# specified in the config.
trainer.train(no_epochs=config.no_epochs, lr=config.learning_rate)

In [None]:
# find learning rate (disabled; uncomment the line below to run it)
# lr_finder = trainer.lr_find()