In [None]:
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/Capstone/proofpoint/notebooks

In [None]:
!pip install mlflow --quiet
!pip install pyngrok --quiet

In [1]:
from src_imdb import load_imdb
from src_imdb import HyperParams
from src_imdb import train_val
from src_imdb import distillation
from src_imdb import predict_batch

import sklearn
import mlflow
from collections import Counter
import numpy as np

import torch


import os

import nltk

# Download the NLTK "stopwords" corpus.
nltk.download("stopwords")

# Make sure the local "resources" directory exists (no error if it already does).
os.makedirs("resources", exist_ok=True)

# Directory the later cells read model checkpoints (*.pth) from.
CHECKPOINT_FOLDER = "./saved_model"

In [None]:
# Train and evaluate the model registered under the name "lstm_teacher".
# Start from the project's default hyper-parameters and override three of them.
org_hyperparams = HyperParams.HyperParams()
org_hyperparams.BIDIRECTIONAL = True
org_hyperparams.OPTIM = 'rmsprop'
org_hyperparams.LR = 1e-3

# The return value is not needed here; binding it to `_` keeps it out of the
# cell's displayed output.
_ = train_val.train_and_test_model_with_hparams(
    org_hyperparams,
    "lstm_teacher",
)

In [26]:
# Train and evaluate the baseline model registered under the name "lstm_old",
# using the same three hyper-parameter overrides as the teacher cell.
org_hyperparams = HyperParams.HyperParams()
org_hyperparams.BIDIRECTIONAL = True
org_hyperparams.OPTIM = 'rmsprop'
org_hyperparams.LR = 1e-3

# NOTE(review): `new_data=35_000` is presumably a sample count consumed by
# train_val — confirm its meaning there.
# The return value is discarded via `_` so nothing is echoed by the cell.
_ = train_val.train_and_test_model_with_hparams(
    org_hyperparams,
    "lstm_old",
    new_data=35_000,
)

In [45]:
# Imported explicitly (and locally, so this cell is self-contained): a bare
# `import sklearn` is not guaranteed to expose the `sklearn.metrics` submodule.
from sklearn.metrics import accuracy_score


def _safe_ratio(numerator, denominator):
  """Return numerator / denominator, or None when the denominator is zero."""
  if denominator == 0:
    return None
  return numerator / denominator


# One MLflow run that, for each distillation weight `alpha`:
#   1. trains a student model against the teacher,
#   2. logs the student model as an artifact,
#   3. logs three metrics comparing student, teacher and the "old" baseline:
#        churn_alpha_<a>        fraction of test predictions where student != teacher
#        churn_ratio_alpha_<a>  student churn / old-model churn (both vs. teacher)
#        win_loss_ratio_<a>     (# student right & teacher wrong) / (# student wrong & teacher right)
with mlflow.start_run(run_name="MLflow on Colab"):

  # NOTE(review): torch.load unpickles a full model object here — only load
  # checkpoints you trust. Recent PyTorch releases default to weights_only=True,
  # which rejects full-model pickles; confirm against the installed version.
  # NOTE(review): the teacher training cell registers its run as "lstm_teacher",
  # but this loads 'bi_lstm_teacher.pth' — confirm the file name train_val writes.
  t_model = torch.load(os.path.join(CHECKPOINT_FOLDER,'bi_lstm_teacher.pth'))
  t_model.eval()
  mlflow.pytorch.log_model(t_model, "teacher_model")

  # Index 0 of the prediction result is arg-maxed into class labels; index 1 is
  # compared against those labels below (i.e. it is used as the ground truth).
  t_test_pred = predict_batch.t_prediction_batch(t_model)
  t_test_label = np.argmax(t_test_pred[0],axis = 1)

  old_model = torch.load(os.path.join(CHECKPOINT_FOLDER,'lstm_old.pth'))
  old_model.eval()
  mlflow.pytorch.log_model(old_model, "old_model")

  old_test_pred = predict_batch.s_prediction_batch(old_model)
  old_test_label = np.argmax(old_test_pred[0],axis = 1)

  # Loop-invariant quantities, computed once:
  # disagreement rate between the old model and the teacher (churn-ratio denominator) ...
  churn_any = 1 - accuracy_score(t_test_label, old_test_label)
  # ... and the per-sample correctness mask of the teacher.
  t_vs_true_label = np.asarray(t_test_label == t_test_pred[1], dtype=bool)

  alphas = [0.2,0.4,0.6,0.8]

  for alpha in alphas:
    distill = distillation.distillation(t_model,alpha)

    # Fresh hyper-parameter object per student so nothing leaks between runs.
    org_hyperparams = HyperParams.HyperParams()
    org_hyperparams.OPTIM = 'rmsprop'
    org_hyperparams.BIDIRECTIONAL = True
    org_hyperparams.LR = 0.001

    train_val.train_and_test_model_with_hparams(org_hyperparams, f"lstm_student_alpha_{alpha}",
                                                    new_data= 35_000,distil=distill[0])

    s_model = torch.load(os.path.join(CHECKPOINT_FOLDER,f"lstm_student_alpha_{alpha}.pth"))

    mlflow.pytorch.log_model(s_model, f"student_model_alpha_{alpha}")
    s_model.eval()
    s_test_pred = predict_batch.s_prediction_batch(s_model)
    s_test_label = np.argmax(s_test_pred[0],axis = 1)

    # churn: kept unrounded for the ratio below; rounded to 3 decimals only
    # when logged, which matches the previously logged values.
    churn = 1 - accuracy_score(t_test_label, s_test_label)
    metrics = {f"churn_alpha_{alpha}": round(churn, 3)}

    # churn ratio: undefined when the old model never disagrees with the teacher.
    churn_ratio = _safe_ratio(churn, churn_any)
    if churn_ratio is None:
      print(f"alpha={alpha}: old-model churn is 0, churn ratio is undefined; metric skipped")
    else:
      metrics[f"churn_ratio_alpha_{alpha}"] = churn_ratio

    # win/loss: only samples where exactly one of student/teacher is correct count.
    s_vs_true_label = np.asarray(s_test_label == s_test_pred[1], dtype=bool)
    wins = int(np.sum(s_vs_true_label & ~t_vs_true_label))    # student right, teacher wrong
    losses = int(np.sum(~s_vs_true_label & t_vs_true_label))  # student wrong, teacher right

    win_loss_ratio = _safe_ratio(wins, losses)
    if win_loss_ratio is None:
      print(f"alpha={alpha}: student has no losses ({wins} wins), win/loss ratio is undefined; metric skipped")
    else:
      metrics[f"win_loss_ratio_{alpha}"] = win_loss_ratio

    # Single call per alpha instead of one call per metric.
    mlflow.log_metrics(metrics)

# Serve the MLflow tracking UI from this VM and expose it through an ngrok tunnel.
# Tunnel recipe borrowed from
# https://colab.research.google.com/github/alfozan/MLflow-GBRT-demo/blob/master/MLflow-GBRT-demo.ipynb#scrollTo=4h3bKHMYUIG6
import os
from getpass import getpass

from pyngrok import ngrok

# Single source of truth for the port shared by the UI process and the tunnel.
MLFLOW_UI_PORT = 5000

# Run the tracking UI in the background (trailing `&`) so the cell does not block.
get_ipython().system_raw(f"mlflow ui --port {MLFLOW_UI_PORT} &")

# Terminate open tunnels if any exist.
ngrok.kill()

# Setting the authtoken is optional. Never hardcode it in the notebook: read it
# from the NGROK_AUTH_TOKEN environment variable, or prompt for it without echo.
# Get your authtoken from https://dashboard.ngrok.com/auth
# NOTE(review): an auth token was previously committed in this cell; treat it as
# leaked and rotate it in the ngrok dashboard.
NGROK_AUTH_TOKEN = os.environ.get("NGROK_AUTH_TOKEN") or getpass(
    "ngrok auth token (leave empty to skip): "
)
if NGROK_AUTH_TOKEN:
  ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# Open an HTTPS tunnel to http://localhost:<MLFLOW_UI_PORT>.
ngrok_tunnel = ngrok.connect(addr=str(MLFLOW_UI_PORT), proto="http", bind_tls=True)
print("MLflow Tracking UI:", ngrok_tunnel.public_url)