## Config: imports and experiment tracking (W&B)

In [None]:
import random

import numpy as np
import wandb
from sklearn.metrics import f1_score

# Open the Weights & Biases run; hyper-parameters are attached to it later via wandb.config.update.
wandb.init(
    entity="nuclear_foxes_team",
    project="NeuroWood2022-name",
)

In [2]:
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD, lr_scheduler
from torch.utils.data import DataLoader


# Seed every RNG the pipeline may touch so a Restart & Run All reproduces the
# same weight initialisation and shuffling. One constant instead of repeated literals.
SEED = 7575
random.seed(SEED)        # stdlib RNG — in case dataset/augmentation code uses `random` (not visible here)
np.random.seed(SEED)
torch.manual_seed(SEED)  # per PyTorch docs this seeds the RNG on all devices, CUDA included
torch.cuda.empty_cache()  # release cached GPU memory left over from a previous run in a live kernel

from config import Config
from source import train, WoodDataset, save_model
import models

## Helper functions: label mapping and metrics

In [3]:
def transform(x):
    """Collapse a label onto the set {0, 1, 3}.

    Values equal to 0 or 1 are returned unchanged (same object); any other
    value is replaced by the integer 3.

    NOTE(review): the meaning of the 0/1/3 mapping is not visible in this
    notebook (the function is not called here) — confirm against the
    dataset's label scheme.
    """
    return x if x in (0, 1) else 3


def calculate_accuracy(y_pred, y_true):
    """Count predictions whose arg-max over dim 1 matches the target label.

    Args:
        y_pred: tensor of per-class scores; the class axis is dim 1.
        y_true: tensor of integer class labels, one per row of ``y_pred``.

    Returns:
        A 0-d float tensor holding the *number* of correct predictions — not a
        fraction, despite the name. Normalisation (if any) is presumably done by
        the caller (``train`` in ``source``) — TODO confirm.
    """
    predicted_labels = y_pred.argmax(dim=1)
    is_correct = predicted_labels.eq(y_true)
    return is_correct.float().sum()


def calculate_f1_score(y_pred, y_true):
    """Macro-averaged F1 between arg-max predictions and target labels.

    Both tensors are detached and copied to host memory before scoring, so
    they may live on any device and may require grad.

    Args:
        y_pred: tensor of per-class scores; the class axis is dim 1.
        y_true: tensor of integer class labels.

    Returns:
        sklearn's macro-averaged F1 score for this batch.
    """
    target = y_true.detach().cpu().numpy()
    predicted = y_pred.detach().cpu().numpy().argmax(axis=1)
    return f1_score(target, predicted, average='macro')

## Data loading

In [4]:
train_dataset_full = WoodDataset()

# 80/20 train/validation split.
TRAIN_FRACTION = 0.8
train_set_size = int(len(train_dataset_full) * TRAIN_FRACTION)
valid_set_size = len(train_dataset_full) - train_set_size

# A dedicated, explicitly seeded generator makes the split identical every time
# this cell runs. The global torch RNG would otherwise make the split depend on
# how much randomness earlier cells consumed, and re-running this cell in a live
# kernel would reshuffle samples between train and validation.
split_generator = torch.Generator().manual_seed(7575)
train_dataset, val_dataset = torch.utils.data.random_split(
    train_dataset_full,
    [train_set_size, valid_set_size],
    generator=split_generator,
)

In [5]:
# Only the training data is shuffled. A shuffled validation loader gains nothing
# for evaluation, and every epoch its RandomSampler would draw from the global
# torch RNG, perturbing the training randomness.
train_dataloader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, pin_memory=True, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, pin_memory=True, shuffle=False)

In [6]:
model = models.Effnetb3()

# Hyper-parameters (also logged to W&B below).
lr = 8e-4     # initial Adam learning rate
gamma = 0.6   # ExponentialLR multiplies the lr by this on every scheduler step
              # (presumably once per epoch inside `train` — TODO confirm)

optimizer = Adam(model.parameters(), lr=lr)
parameters = dict(
    optimizer=optimizer,
    criterion=CrossEntropyLoss(),
    val_criterion=CrossEntropyLoss(),
    val_metric=calculate_accuracy,
    n_epochs=15,
    device=Config.DEVICE,
    scheduler=lr_scheduler.ExponentialLR(optimizer, gamma=gamma),
)

# Record the run's hyper-parameters in the active W&B run.
wandb_dict = {
    "Learning_rate": lr,
    "Gamma": gamma,
    "Epochs": parameters['n_epochs'],
    "Batch_size": Config.BATCH_SIZE,
    "Network": "Effnetb3 + Adam + ExponentialScheduler",
    "Full train dataset": False,
}
wandb.config.update(wandb_dict)

In [None]:
# Launch training: the model and loaders are passed explicitly, and every entry
# of `parameters` (optimizer, losses, metric, epochs, device, scheduler) is passed as a keyword.
train(
    model,
    train_dataloader,
    val_dataloader=val_dataloader,
    **parameters,
)