# Dados

In [3]:
import os
from pathlib import Path
import pandas as pd
import numpy as np
import cv2
import shutil
from torchvision import models
from module import data_read, setup
from torch.utils.data import ConcatDataset
from module.LightningModule import ClassificationModule
from module.setup import *
from module.data_read import *

# Seed PyTorch's global RNG. `torch` and `random_state` are not imported by name
# in this cell, so they presumably arrive through the star imports above -- TODO confirm.
torch.manual_seed(random_state)

# Dataset roots, resolved against the directory the notebook is launched from.
_WORKING_DIR = Path.cwd()
DDR_DIR = _WORKING_DIR.joinpath('DDR-dataset/DR_grading')
IDRID_DIR = _WORKING_DIR.joinpath('IDRID-Classificacao')
# FGADR_DIR = _WORKING_DIR.joinpath('FGADR Dataset/Seg-set')  # currently disabled

## Binary

In [19]:
# Load IDRiD.
# Both official splits are read with the 'test' transforms and binary labels,
# then merged into one dataset (presumably used only for external evaluation --
# TODO confirm against the cells that consume `idrid_dataset`).
_idrid_images_dir = IDRID_DIR / '1. Original Images'
_idrid_labels_dir = IDRID_DIR / '2. Groundtruths'

train_idrid_dataset = IDRIDDataset(
    _idrid_images_dir / 'a. Training Set',
    _idrid_labels_dir / 'a. IDRiD_Disease Grading_Training Labels.csv',
    data_transforms['test'],
    convert_to_binary=True,
)
test_idrid_dataset = IDRIDDataset(
    _idrid_images_dir / 'b. Testing Set',
    _idrid_labels_dir / 'b. IDRiD_Disease Grading_Testing Labels.csv',
    data_transforms['test'],
    convert_to_binary=True,
)

# Single dataset covering the training split followed by the testing split.
idrid_dataset = ConcatDataset([train_idrid_dataset, test_idrid_dataset])

#### EfficientNet-b7 Instantiation

In [9]:
# EfficientNet-B7 initialised from torchvision's default pretrained weights.
effcb7_best_weights = models.EfficientNet_B7_Weights.DEFAULT
effcb7_model = models.efficientnet_b7(weights=effcb7_best_weights)

# Preprocessing transforms bundled with the selected weights.
effcb7_preprocess = effcb7_best_weights.transforms()

# Swap the last classifier layer for a fresh linear layer with `num_classes`
# outputs, keeping the same input width as the layer it replaces.
_effcb7_old_head = effcb7_model.classifier[-1]
num_ftrs = _effcb7_old_head.in_features
effcb7_model.classifier[-1] = nn.Linear(in_features=num_ftrs, out_features=num_classes)

In [12]:
# Display the model's class name. Uses the public `type(...).__name__` rather
# than the private `nn.Module._get_name()` helper, which is not part of the
# documented API.
type(effcb7_model).__name__

'EfficientNet'

#### Execution

In [None]:
from lightning.pytorch.loggers import TensorBoardLogger
from sklearn.model_selection import KFold


def _load_ddr_split(split_name):
    """Read the annotation file of one DDR split.

    Reads ``<split_name>.txt`` under ``DDR_DIR`` as a headerless, space-separated
    table and prefixes column 0 with ``'<split_name>/'`` (column 0 presumably
    holds image file names relative to the split folder -- TODO confirm).

    Args:
        split_name: Name of the split; one of 'train', 'valid' or 'test' here.

    Returns:
        pandas.DataFrame with integer column labels and the prefixed column 0.
    """
    annotations = pd.read_csv(DDR_DIR / f'{split_name}.txt', header=None, sep=' ')
    annotations[0] = f'{split_name}/' + annotations[0]
    return annotations


anno_train_ddr, anno_valid_ddr, anno_test_ddr = (
    _load_ddr_split(split_name) for split_name in ('train', 'valid', 'test')
)

# All DDR annotations stacked in train, valid, test order with a fresh 0..n-1 index.
anno_ddr = pd.concat(
    [anno_train_ddr, anno_valid_ddr, anno_test_ddr],
    ignore_index=True,
)

# Shuffled K-fold splitter; the fixed seed keeps the folds the same across runs.
kfold = KFold(n_splits=k_folds, shuffle=True, random_state=random_state)

# The IDRiD loader does not depend on the fold, so it is built once here
# instead of being re-created on every iteration of the loop below.
idridloader = torch.utils.data.DataLoader(
                  idrid_dataset,
                  batch_size=batch_size,
                  num_workers=5)

for fold, (train_ids, test_ids) in enumerate(kfold.split(anno_ddr)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    # Split the DDR annotations using the row positions of the current fold.
    # The training subset gets the 'train' transforms, the held-out subset the 'test' ones.
    train_dataset = DDRDatasetKFold(DDR_DIR, anno_ddr.iloc[train_ids], data_transforms['train'], convert_to_binary=True)
    test_dataset = DDRDatasetKFold(DDR_DIR, anno_ddr.iloc[test_ids], data_transforms['test'], convert_to_binary=True)

    # Data loaders for this fold; only the training loader is shuffled.
    trainloader = torch.utils.data.DataLoader(
                      train_dataset,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=5)
    testloader = torch.utils.data.DataLoader(
                      test_dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      num_workers=5)

    # DEBUG:
    # trainer = L.Trainer(fast_dev_run=4) # quick smoke test: runs only 4 batches of each stage
    # trainer = L.Trainer(limit_train_batches=30, max_epochs=1) # an int caps training at 30 batches per epoch; pass 0.3 to use 30% of the batches
    # trainer = L.Trainer(default_root_dir='checkpoints kfold/', accelerator='gpu', max_epochs=num_epochs)

    # NOTE(review): training and evaluation are commented out, so this loop currently
    # only builds the per-fold datasets and loaders. The snippet below references
    # `vgg16`, which is not defined in the visible cells -- TODO confirm the intended
    # model (the notebook instantiates `effcb7_model` above).

    # training
    # vgg16_model = ClassificationModule(vgg16, loss_function, optim.Adam)
    # trainer.fit(model=vgg16_model, train_dataloaders=trainloader)

    # testing
    # print('\n\n TESTE IN THE TESt FOLD \n\n')
    # trainer.test(model=vgg16_model, dataloaders=testloader) # evaluate on the held-out fold
    # print('\n\n TESTE IDRID (cros validation)\n\n')
    # trainer.test(model=vgg16_model, dataloaders=idridloader) # evaluate on IDRiD