In [None]:
! pip install sklearn nltk rouge

# Imports

In [None]:
from fastai.text import *
from statistics import mean, median, stdev

import sentencepiece as spm

In [None]:
import sys

# The project's helper notebooks-as-modules are not installed packages, so they are
# made importable by appending directories to sys.path. These paths are relative to
# the current working directory, so this cell presumably assumes the kernel was
# started from this notebook's folder — TODO confirm.
sys.path.append("../../")
# Evaluation helpers (presumably eval_vuln / eval_txt used below) — star import.
from eval.exp.nb_evaluation import *

sys.path.append("../../../")
# Processing / preparation helpers (presumably read_data, gen_lm_data, tag_task).
from src.proc.exp.nb_proc import *
from src.prep.exp.nb_prep import *


In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
import os

# Root of the data volume. Defaults to the container layout (/tf/data) but can be
# overridden with the TF_DATA_ROOT environment variable when running elsewhere.
DATA_ROOT = Path(os.environ.get("TF_DATA_ROOT", "/tf/data"))

# setup paths and model type
model_path = DATA_ROOT / "models"
data_path = DATA_ROOT / "datasets"

# Dataset the language model is trained on (a sub-directory of data_path).
task_type = "merged"

In [None]:
# SentencePiece tokenizer trained on the merged corpus; it is handed to every
# evaluation helper below via `sp=sp`.
sp_model_file = data_path / "merged" / "model.model"
sp = spm.SentencePieceProcessor()
sp.Load(str(sp_model_file))

# Load Data

In [None]:
# Train / validation / test splits for the language-model corpus.
lm_dir = data_path / task_type
df_trn, df_val, df_tst = read_data(lm_dir)

In [None]:
# Mini-batch size; passed to both gen_lm_data and load_data below.
bs = 8

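Choose how much of the data to use. The next cell builds the language-model data from the full train/validation splits and saves it as `data_lm_100pct.pkl`. The cell after it reloads a saved DataBunch by file name; a smaller sampled file (e.g. `data_lm_10pct.pkl`) can be loaded instead for quicker experiments.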

In [None]:
# Build the language-model DataBunch from the train/validation frames and persist
# it (relative to the DataBunch's own path) so later runs can reload it.
data = gen_lm_data(df_trn, df_val, task_type, data_path, bs=bs)
data.save(f"{task_type}/data_lm_100pct.pkl")

In [None]:
# Reload the DataBunch saved by the previous cell. The file name must match the one
# given to data.save() there ('data_lm_100pct.pkl'); loading a different file makes
# a fresh Restart & Run All fail, or silently train on a stale sample.
# NOTE(review): assumes gen_lm_data roots the DataBunch at data_path, so the saved
# file lands in data_path/task_type — confirm.
data = load_data(data_path/task_type, 'data_lm_100pct.pkl', bs=bs)

In [None]:
# (number of training items, number of validation items)
tuple(len(split) for split in (data.train_ds, data.valid_ds))

# Model Setup

In [None]:
# Passed to language_model_learner below; False means no pretrained weights are
# loaded, so the AWD-LSTM starts from random initialisation on this corpus.
pretrained = False

In [None]:
# AWD-LSTM language-model learner over `data`. Pretrained weights are used only if
# `pretrained` is True; dropout is scaled by drop_mult, and accuracy is reported.
learn = language_model_learner(
    data,
    AWD_LSTM,
    drop_mult=0.3,
    pretrained=pretrained,
    metrics=[accuracy],
)

In [None]:
# Learning-rate range test; use the loss-vs-LR plot to pick max_lr below.
learn.lr_find()
# Read the recorder only after lr_find, since fitting attaches a fresh one.
lr_recorder = learn.recorder
lr_recorder.plot()

# Model Training

In [None]:
# One-cycle training hyperparameters.
max_lr   = 1e-2          # peak learning rate
moms     = (0.5, 0.75)   # momentum pair, intended for fit_one_cycle's `moms`
pct_strt = 0.02          # intended for fit_one_cycle's `pct_start`
a_epochs = 10            # number of epochs in the cycle
# NOTE(review): fastai's default moms is (0.95, 0.85) (decreasing); this pair
# increases — confirm that is intended.

In [None]:
# Training callbacks (instances, despite the variable name):
# - SaveModelCallback checkpoints the weights each time valid_loss improves. Its
#   `name` must match the checkpoint reloaded with learn.load() after training
#   ('awd_lstm_save_model'); it was previously 'transformer_save_model', so the best
#   weights were never restored.
# - EarlyStoppingCallback stops after 3 epochs without a valid_loss improvement
#   of at least 0.01.
callback_fns = [
    callbacks.SaveModelCallback(
        learn, every='improvement',
        monitor='valid_loss', name='awd_lstm_save_model'
    ),
    callbacks.EarlyStoppingCallback(
        learn, monitor='valid_loss', min_delta=0.01,
        patience=3
    ),
]

In [None]:
# Train with the one-cycle policy using every hyperparameter defined above;
# moms / pct_strt must be passed explicitly, otherwise fastai's defaults apply.
learn.fit_one_cycle(
    a_epochs,
    max_lr,
    moms=moms,
    pct_start=pct_strt,
    callbacks=callback_fns,
)

In [None]:
# Restore the checkpoint named below; it must match the `name` given to
# SaveModelCallback for the best weights to be picked up.
best_ckpt = 'awd_lstm_save_model'
learn.load(best_ckpt)

In [None]:
# Train/validation loss curves from the last fit, kept as a Figure for saving.
fit_recorder = learn.recorder
figure_plot = fit_recorder.plot_losses(return_fig=True)

In [None]:
# Write the loss plot as PNG into the current working directory.
figure_plot.savefig("awd_lstm_plot_losses.png", format="png")

In [None]:
from PIL import Image

# Reopen the loss plot saved by the previous cell. savefig() wrote it relative to the
# current working directory, so read it from the same relative path rather than a
# hard-coded absolute path that only exists on one machine.
Image.open("awd_lstm_plot_losses.png")

# Model Evaluation

### Vulnerability Classification

In [None]:
# Vulnerability-classification data. Note: this rebinds the global task_type,
# which the following tag_task call reads.
task_type = "buggy"
vuln_dir = data_path / task_type
vuln_trn, vuln_val, vuln_tst = read_data(vuln_dir)

In [None]:
# Tag the validation rows with the current task_type (presumably the task marker
# the LM conditions on — confirm in tag_task).
# NOTE(review): rebinds vuln_val, so re-running this cell applies tag_task twice
# unless tag_task is idempotent.
vuln_val = tag_task(vuln_val, task_type)

In [None]:
# Evaluate on the first 100 validation rows; by their names the results are
# accuracy, precision and recall.
vuln_sample = vuln_val[:100]
acc, prec, recal = eval_vuln(learn, vuln_sample, sp=sp)

In [None]:
# (accuracy, precision, recall)
(acc, prec, recal)

### Comment Generation

In [None]:
# Method/comment pairs for comment generation; rebinds the global task_type.
task_type = "mthds_cmts"
cmt_dir = data_path / task_type
cmt_trn, cmt_val, cmt_tst = read_data(cmt_dir)

In [None]:
# Tag the validation rows with the current task_type (presumably the task marker
# the LM conditions on — confirm in tag_task).
# NOTE(review): rebinds cmt_val; re-running applies tag_task twice unless idempotent.
cmt_val = tag_task(cmt_val, task_type)

In [None]:
# Generate text for the first 10 validation rows. Returns per-sample scores
# (presumably BLEU-1..4 and METEOR) plus the predictions themselves.
cmt_sample = cmt_val[:10]
b1, b2, b3, b4, meteor, preds = eval_txt(learn, cmt_sample, sp=sp)

In [None]:
# Average of each of the four n-gram score lists.
tuple(mean(scores) for scores in (b1, b2, b3, b4))

In [None]:
# Average METEOR score over the evaluated sample.
mean(meteor)

In [None]:
# Inspect one generated prediction (position 9 of the 10-row sample).
preds[9]

In [None]:
# Query and reference for the same row as preds[9]. Positional .iloc keeps the row
# aligned with preds even if the frame's index is not 0..n-1 (plain [9] is a label
# lookup and can pick a different row or raise KeyError).
cmt_val['query'].iloc[9], cmt_val['res'].iloc[9]

In [None]:
from rouge import Rouge

# eval_txt's outputs (unpacked above) include no ROUGE scores, so `rouge_l` was
# undefined here. Compute per-sample ROUGE-L F1 against the reference texts instead.
# NOTE(review): assumes `preds` are detokenized strings aligned with the first
# len(preds) rows of cmt_val — confirm against eval_txt.
cmt_refs = list(cmt_val['res'].iloc[:len(preds)])
rouge_l = [
    score['rouge-l']['f']
    for score in Rouge().get_scores(preds, cmt_refs, ignore_empty=True)
]
mean(rouge_l)

### StackOverflow QA

In [None]:
# StackOverflow question/answer posts; rebinds the global task_type.
task_type = "so_posts"
so_dir = data_path / task_type
so_trn, so_val, so_tst = read_data(so_dir)

In [None]:
# Tag the validation rows with the current task_type (presumably the task marker
# the LM conditions on — confirm in tag_task).
# NOTE(review): rebinds so_val; re-running applies tag_task twice unless idempotent.
so_val = tag_task(so_val, task_type)

In [None]:
# Generate answers for the first 10 validation posts. Returns per-sample scores
# (presumably BLEU-1..4 and METEOR) plus the predictions themselves.
so_sample = so_val[:10]
b1, b2, b3, b4, meteor, preds = eval_txt(learn, so_sample, sp=sp)

In [None]:
# Average of each of the four n-gram score lists.
tuple(mean(scores) for scores in (b1, b2, b3, b4))

In [None]:
# Average METEOR score over the evaluated sample.
mean(meteor)

In [None]:
# Inspect one generated prediction (position 9 of the 10-row sample).
preds[9]

In [None]:
# Query and reference for the same row as preds[9]. Positional .iloc keeps the row
# aligned with preds even if the frame's index is not 0..n-1 (plain [9] is a label
# lookup and can pick a different row or raise KeyError).
so_val['query'].iloc[9], so_val['res'].iloc[9]

In [None]:
from rouge import Rouge

# eval_txt's outputs (unpacked above) include no ROUGE scores, so `rouge_l` was
# undefined here. Compute per-sample ROUGE-L F1 against the reference texts instead.
# NOTE(review): assumes `preds` are detokenized strings aligned with the first
# len(preds) rows of so_val — confirm against eval_txt.
so_refs = list(so_val['res'].iloc[:len(preds)])
rouge_l = [
    score['rouge-l']['f']
    for score in Rouge().get_scores(preds, so_refs, ignore_empty=True)
]
mean(rouge_l)