In [None]:
!pip install -Uqq fastbook
!pip list -v | grep fastai

[K     |████████████████████████████████| 727kB 9.3MB/s 
[K     |████████████████████████████████| 1.1MB 16.7MB/s 
[K     |████████████████████████████████| 194kB 34.7MB/s 
[K     |████████████████████████████████| 51kB 6.2MB/s 
[K     |████████████████████████████████| 61kB 5.9MB/s 
[?25h

In [None]:
import pandas as pd
import re

from datetime import datetime
from collections import Counter    

from fastbook import *

## Mount Drive

In [None]:
# Mount Google Drive: the exported learners (.pkl), the test images and the
# submission folder used later in this notebook live under /content/drive/MyDrive/...
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


## Functions

In [None]:
def prediction(pkl, test_path_str):
  """
  Returns tuple with:
    - list of tuples (probability classes for each obs)
    - list of predicted class for each obs
  """
  model = load_learner(str(pkl))
  test_dl = model.dls.test_dl(get_image_files(test_path_str))
  preds,y = model.get_preds(dl = test_dl)
  y = torch.argmax(preds, dim = 1)
  
  etiquetas_url = 'https://raw.githubusercontent.com/DesafiosAgTech/DesafioAgTech2020/master/dataset/Etiquetas.csv'
  etiquetas =  pd.read_csv(etiquetas_url, error_bad_lines=False)
  # Predicted to corresponding CultivoId + add global id for submit

  glob_list = [re.findall(r'(\d+).png', r) for r in [str(p) for p in test_dl.items]]
  globalids = [int(id) for sub in glob_list for id in sub] 
  prediction = [model.dls.vocab[p] for p in y.tolist()]
  
  return preds, prediction

In [None]:
# GlobalId = the digits right before a trailing ".png" (dot escaped, anchored at the end).
_GLOBALID_RE = re.compile(r'(\d+)\.png$')


def _globalid_from_path(path):
  """Return the integer GlobalId encoded in a file name ending in '<GlobalId>.png'."""
  match = _GLOBALID_RE.search(str(path))
  if match is None:
    raise ValueError(
        f"Cannot read a GlobalId from {str(path)!r}: expected a name ending in '<digits>.png'")
  return int(match.group(1))


def _read_etiquetas(url):
  """Read the class-name -> CultivoId table, skipping malformed lines."""
  try:
    # pandas >= 1.3; `error_bad_lines` was deprecated there and removed in pandas 2.0.
    return pd.read_csv(url, on_bad_lines='skip')
  except TypeError:
    # Older pandas that does not know `on_bad_lines`.
    return pd.read_csv(url, error_bad_lines=False)


def _votes_by_globalid(pkl, test_dir):
  """Run one model over `test_dir` and return {GlobalId: (predicted class, confidence)}."""
  probs, labels = prediction(pkl, test_dir)
  # `prediction` scores the images in the order `get_image_files(test_dir)` lists them,
  # so listing the same directory again tells which GlobalId each row belongs to
  # (assumes the directory is not modified in between).
  files = get_image_files(test_dir)
  if len(files) != len(labels):
    raise RuntimeError(
        f"{str(test_dir)!r}: {len(files)} image files but {len(labels)} predictions")
  ids = [_globalid_from_path(f) for f in files]
  if len(set(ids)) != len(ids):
    raise ValueError(f"{str(test_dir)!r}: duplicated GlobalIds in the file names")
  # Confidence = highest score of each row, i.e. the score of the predicted class.
  confidences = probs.max(dim=1).values.tolist()
  return dict(zip(ids, zip(labels, confidences)))


def _majority_vote(votes):
  """
  Pick one class from `votes`, a list of (class, confidence) with one entry per model.

  A class predicted by a strict majority of the models wins; otherwise the class
  of the most confident model is returned (the first such model on ties).
  """
  top_class, top_count = Counter(label for label, _ in votes).most_common(1)[0]
  if top_count > len(votes) // 2:
    return top_class
  return max(votes, key=lambda vote: vote[1])[0]


def generate_ensemble(
    test_path_str
    ,export_path_str='/content/drive/MyDrive/MACHINLEARNING/DesafioAgTech2020/submit/'
    ,models=None
    ,etiquetas_url='https://raw.githubusercontent.com/DesafiosAgTech/DesafioAgTech2020/master/dataset/Etiquetas.csv'):
  """
  Combine several fastai models by majority vote and write a submission CSV.

  For each test GlobalId every model predicts a class. A class predicted by a
  strict majority of the models wins; otherwise the prediction of the most
  confident model is used (confidence = its highest class score). Class names
  are mapped to `CultivoId` through the Etiquetas table and written as a
  header-less `globalid,CultivoId` CSV, sorted by globalid.

  Args:
    test_path_str: only used to name the output file; its last "/"-separated
      component is appended after the timestamp ("/ensemble" -> "...ensemble.csv").
      The test images themselves come from `models`.
    export_path_str: prefix of the output file, concatenated as-is
      (keep the trailing "/").
    models: sequence of (learner_pkl_path, test_image_dir) pairs; image file
      names must end in "<GlobalId>.png". Defaults to the three xresnet50 models.
    etiquetas_url: CSV with at least the columns `Cultivo` (class name) and `CultivoId`.

  Returns:
    The path of the written CSV file.

  Raises:
    ValueError: no models, models covering different GlobalIds, unreadable
      file names, or predicted classes missing from the Etiquetas table.
  """
  if models is None:
    models = [
        # model 1
        ("/content/drive/MyDrive/MACHINLEARNING/DesafioAgTech2020/xresnet50_heatmap.pkl",
         '/content/drive/MyDrive/MACHINLEARNING/DesafioAgTech2020/data/test_heatmap_ts'),
        # model 2
        ("/content/drive/MyDrive/MACHINLEARNING/DesafioAgTech2020/xresnet50_timeseries_fullbandas_conX_weight_dio069.pkl",
         '/content/drive/MyDrive/MACHINLEARNING/DesafioAgTech2020/data/test_ts_fullbandas_conX'),
        # model 3
        ("/content/drive/MyDrive/MACHINLEARNING/DesafioAgTech2020/xresnet50_timeseries_fullbandas_conX_weightCross.pkl",
         '/content/drive/MyDrive/MACHINLEARNING/DesafioAgTech2020/data/test_ts_fullbandas_conX'),
    ]
  models = list(models)
  if not models:
    raise ValueError("generate_ensemble needs at least one (pkl, test_dir) model")

  # Read the label table first so a download problem fails before the slow inference.
  etiquetas = _read_etiquetas(etiquetas_url)

  per_model = [_votes_by_globalid(pkl, test_dir) for pkl, test_dir in models]

  # The models read different test directories and `get_image_files` does not sort,
  # so row i of one model is not necessarily row i of another: align by GlobalId.
  globalids = sorted(per_model[0])
  expected_ids = set(globalids)
  for (pkl, test_dir), votes in zip(models, per_model):
    if set(votes) != expected_ids:
      raise ValueError(
          f"{str(test_dir)!r} (model {str(pkl)!r}) does not contain the same GlobalIds "
          f"as {str(models[0][1])!r}")

  winners = [_majority_vote([votes[gid] for votes in per_model]) for gid in globalids]

  prediction_df = pd.DataFrame({'globalid': globalids, 'clase': winners})
  unknown = sorted(set(prediction_df['clase']) - set(etiquetas['Cultivo']))
  if unknown:
    raise ValueError(f"Predicted classes missing from the Etiquetas table: {unknown}")
  submit = prediction_df.merge(
      etiquetas, how="left", left_on='clase', right_on='Cultivo')[['globalid','CultivoId']]

  # Export: <export_path_str><YYYYmmddHHMM><last component of test_path_str>.csv
  now = datetime.now()
  file_results = \
    str(export_path_str) + now.strftime('%Y%m%d%H%M') + str(test_path_str).split("/")[-1] + '.csv'
  submit.to_csv(file_results, header=False, index=False)
  return file_results

## Run

In [None]:
# Only the last "/"-separated part of `test_path_str` is used, as the suffix of the
# exported CSV file name; the test images come from the model list inside
# `generate_ensemble`. This argument is a leftover from the utils version of the function.
SUBMISSION_TAG = "/ensemble"
generate_ensemble(test_path_str=SUBMISSION_TAG)