# U-Net Depth (Treinamento)

In [None]:
# Repository data
repository_name = 'unet_depth'
branch = 'main'
# The clone URL ends with the repository name, so it is derived from it
repository_url = f'https://github.com/duraes-antonio/{repository_name}'

In [None]:
import random as native_random

from numpy import random as np_random
from tensorflow import random as tf_random

# Seed used by operations such as dataset shuffling, by Keras and by TF
seed = 42

# Apply the same seed to the native, NumPy and TensorFlow generators, in order
for apply_seed in (native_random.seed, np_random.seed, tf_random.set_seed):
    apply_seed(seed)

## Clonar e navegar para o repositório

In [None]:
import os

# Se o diretório pai for o repositório é porque está executando local
running_remote = not os.path.isdir('../unet_depth')

if running_remote:
    !git clone $repository_url
    !cd $repository_name
    os.chdir(repository_name)

In [None]:
# Switch to the configured branch, then pull its latest changes
!git checkout $branch
!git pull

## Instalar dependências

In [None]:
!pip install pymongo[srv] dnspython keras-unet-collection python-dotenv imutils
!pip install py-cpuinfo gpuinfo typing_extensions

if not running_remote:
    !pip install google-api-python-client google-auth-httplib2 google-auth-oauthlib

## Preparar variáveis de acordo com ambiente de execução

In [None]:
from infra.consts.environment import Environment

# Execution environment selected for this run
env = Environment.KAGGLE

In [None]:
# Google Drive is mounted only when the selected environment is Colab
if env is Environment.COLAB:
    # NOTE(review): presumably imported here because google.colab is only
    # importable inside Colab - confirm
    from google.colab import drive

    drive.mount('/content/drive')

In [None]:
from infra.util.environment_vars import get_env_vars
import json

# Load the variables for the selected environment and expose the database URL
# through the process environment
env_vars = get_env_vars(env)
os.environ["DATABASE_URL"] = env_vars['database_url']

# Dump the Google credentials and token to JSON files in the working directory:
# (output file, key in env_vars)
credential_files = (
    ('google_credentials.json', 'google_credentials'),
    ('token.json', 'google_token'),
)

for file_name, var_key in credential_files:
    with open(file_name, 'w', encoding='utf-8') as json_file:
        json.dump(env_vars[var_key], json_file, ensure_ascii=False, indent=4)

## Instanciar serviços para persistência de resultados e blob

In [None]:
from infra.util.mongodb import build_db_name
from domain.models.network import NetworkConfig

# Network configuration used for this run
# NOTE(review): presumably 'size' is the input resolution and
# 'filter_min'/'filter_max' bound the number of filters - confirm against
# NetworkConfig
config: NetworkConfig = {
    'size': 256,
    'filter_min': 64,
    'filter_max': 512,
    'pool': True,
    'unpool': True,
}
# The database name is built from the network configuration
db_name = build_db_name(config)

In [None]:
from infra.services.results_service_mongodb import ResultServiceMongoDB
from domain.services.results_service import ResultService
from infra.services.blob_storage.blob_storage_service_google_drive import GoogleDriveBlobStorageService
from infra.services.model_storage_service_google_drive import ModelStorageServiceGoogleDrive
from domain.services.model_storage_service import ModelStorageService
from infra.services.test_case_execution_service_mongodb import TestCaseExecutionServiceMongoDB
from infra.services.test_case_service_mongodb import TestCaseServiceMongoDB
from domain.services.test_case_execution_service import TestCaseExecutionService
from domain.services.blob_storage_service import BlobStorageService
from domain.services.test_case_service import TestCaseService

# Each variable is annotated with the domain-level service type and bound to
# a concrete MongoDB or Google Drive implementation

# Test case and test case execution services, both created with the database name
test_case_serv: TestCaseService = TestCaseServiceMongoDB(db_name)
execution_serv: TestCaseExecutionService = TestCaseExecutionServiceMongoDB(db_name)

# Blob storage is created with the database name; model storage is built on
# top of the blob storage service
blob_service: BlobStorageService = GoogleDriveBlobStorageService(db_name)
model_storage: ModelStorageService = ModelStorageServiceGoogleDrive(blob_service)

# Results service, also created with the database name
result_service: ResultService = ResultServiceMongoDB(db_name)

## Baixar dataset

In [None]:
# Clone the dataset repository into ./data unless a 'data' path already exists
if not os.path.exists('data'):
    !git clone "https://gitlab.com/siddinc/new_depth.git" "./data"

## Executar aplicação

In [None]:
import traceback

from infra.application_manager import ApplicationManager

# CSV files from the downloaded dataset used for training and testing
train_path = "./data/nyu2_train.csv"
test_path = "./data/nyu2_test.csv"
batch_size = 4
epochs = 70

try:
    # Wire the services and the network configuration into the application
    application = ApplicationManager(
        blob_service, model_storage, execution_serv,
        test_case_serv, result_service, config, epochs
    )
    # NOTE(review): the meaning of the positional `True` and `1` arguments is
    # not visible here - confirm against ApplicationManager
    application.prepare_train_data(train_path, batch_size, True, seed, 1)
    application.prepare_test_data(test_path, batch_size, True, seed)
    application.run()

except Exception:
    # Best-effort boundary: the error is reported without being re-raised, but
    # the full traceback is printed (not just the message) so that the failing
    # call can be located
    print(traceback.format_exc())