# Buzz Off: Mosquito recognition and classification

## Prototype


In [None]:
# Project-specific environment setup provided by src/helpers.py.
# NOTE(review): what it configures (data download, device selection, seeds, ...)
# is not visible from this notebook — TODO confirm and describe it here.
from src.helpers import setup_env

setup_env()

---
# Training, validation and testing

Let's train our transfer-learning model! We start by defining the hyperparameters:

In [2]:
# --- Data -------------------------------------------------------------------
batch_size = 64        # minibatch size fed to the optimizer (SGD or Adam)
valid_size = 0.2       # fraction of the training data held out for validation
# NOTE(review): valid_size is not passed to get_data_loaders in the cells shown — confirm it is used

# --- Model ------------------------------------------------------------------
num_classes = 2        # number of output classes

# --- Optimization -----------------------------------------------------------
num_epochs = 25        # number of passes over the training set
opt = 'adam'           # optimizer name: 'sgd' or 'adam'
learning_rate = 0.001  # optimizer step size
weight_decay = 0.0     # weight-decay regularization strength

In [None]:
from src.data import get_data_loaders
from src.optimization import get_optimizer, get_loss
from src.train import optimize
from src.transfer import get_model_transfer_learning

# Build the transfer-learning model with the configured number of classes.
# The test section rebuilds the model with n_classes=num_classes before loading
# the checkpoint, so training must use the same class count; relying on the
# helper's default could save a classifier head that no longer matches.
model_transfer = get_model_transfer_learning("densenet161", n_classes=num_classes)

# train the model
data_loaders = get_data_loaders(batch_size=batch_size)
optimizer = get_optimizer(
    model_transfer,
    learning_rate=learning_rate,
    optimizer=opt,
    weight_decay=weight_decay,
)
loss = get_loss()

# Checkpoint is written to save_path; the test section reloads it from the same path.
optimize(
    data_loaders,
    model_transfer,
    optimizer,
    loss,
    n_epochs=num_epochs,
    save_path="checkpoints/model_transfer.pt",
    interactive_tracking=True
)

In [None]:
!pip install livelossplot tqdm

---
## Test the Model



In [None]:
import torch
from src.train import one_epoch_test
from src.data import get_data_loaders
from src.optimization import get_optimizer, get_loss
from src.transfer import get_model_transfer_learning

# Rebuild the data loaders and loss so this section does not depend on training-cell state
data_loaders = get_data_loaders(batch_size=batch_size)
loss = get_loss()

# Recreate the architecture with the same class count, then load the saved weights.
# map_location="cpu" lets a checkpoint saved from a GPU session be loaded on a
# CPU-only machine; load_state_dict copies the tensors into the model's existing
# parameters, so the model's device is unaffected by where the file is mapped.
model_transfer = get_model_transfer_learning("densenet161", n_classes=num_classes)
state_dict = torch.load('checkpoints/model_transfer.pt', map_location="cpu")
model_transfer.load_state_dict(state_dict)

# Evaluate on the test split; the returned test loss is shown as the cell output
one_epoch_test(data_loaders['test'], model_transfer, loss)

Reusing cached mean and std
Dataset mean: tensor([0.4638, 0.4725, 0.4687]), std: tensor([0.2699, 0.2706, 0.3018])


Testing: 100%|██████████████████████████████████| 20/20 [00:19<00:00,  1.03it/s]

Test Loss: 1.413974


Test Accuracy: 62% (786/1250)





1.4139741092920302