In [None]:
import os
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, accuracy_score

from modules.run import load_config, Trainer, Metrics
from modules.data import DataManager, processing, util
from modules.models import pretrained_cnn, pretrained_cnn_multichannel

In [None]:
# Load the experiment configuration named "test" (project-specific loader).
config = load_config("test")

In [None]:
# Build the data manager from the loaded configuration.
data_manager = DataManager(config)

In [None]:
# Produce the Kenya training/validation generators and the backing dataframe.
(train_generator, val_generator, dataframe) = data_manager.generate_kenya()

In [None]:
# Instantiate the pretrained CNN sized by the configured image size and channel count.
convnet = pretrained_cnn(
    config,
    image_size=config["image_size"],
    n_channels=config["n_channels"],
)

In [None]:
# The Trainer supplies the loss, optimizer and callbacks used below.
trainer = Trainer(config)

In [None]:
# Compile with the loss/optimizer provided by the Trainer and the configured weighted metrics.
convnet.compile(
    loss=trainer.loss,
    optimizer=trainer.optimizer,
    weighted_metrics=config["weighted_metrics"],
)

# Optionally resume from saved weights before training (fill in the path):
# convnet.load_weights("")

# Derive integer step counts from the configured sample size and split.
# The previous `size * (1 - split) // batch_size + 1` produced a float (Keras
# expects an int) and added one extra step whenever the count divided evenly;
# a ceiling division covers every sample exactly once per epoch.
# NOTE(review): assumes DataManager splits `sample.size` using the same
# `validation_split` ratio — confirm against modules.data.
n_samples = config["sample"]["size"]
batch_size = config["batch_size"]
n_val_samples = round(n_samples * config["validation_split"])
n_train_samples = n_samples - n_val_samples
# max(1, ...) keeps at least one step, matching the old formula's lower bound.
train_steps = max(1, int(np.ceil(n_train_samples / batch_size)))
val_steps = max(1, int(np.ceil(n_val_samples / batch_size)))

# NOTE(review): `fit_generator` is deprecated since TF 2.1 in favour of `fit`,
# which accepts generators directly; kept pending a check of the pinned TF version.
# NOTE(review): `use_multiprocessing=True` is only safe if the generators are
# `keras.utils.Sequence` instances; plain Python generators may yield duplicated
# batches across workers — confirm the generator type.
history = convnet.fit_generator(
    train_generator,
    steps_per_epoch=train_steps,
    epochs=config["n_epochs"],
    callbacks=trainer.callbacks,
    validation_data=val_generator,
    validation_steps=val_steps,
    class_weight=data_manager.class_weight("kenya"),
    use_multiprocessing=True,
)