In [34]:
%load_ext autoreload
%autoreload 2

from fastai.text import *
from fastai.text.data import DataBunch
from fastai.datasets import *
from pathlib import Path
import pandas as pd
from fastai.metrics import *
from fastai.train import *
from fastai.vision import *
from fastai.imports import nn, torch
from sklearn import metrics
from fastai.callbacks import *
from fastai.basic_train import get_preds

import sacred
from sacred import Experiment
from sacred.observers import MongoObserver

import sklearn.metrics
import datetime
import news_utils
from pathlib import Path

import fastai
# Index of the CUDA device all training below runs on.
GPU_ID = 2
torch.cuda.set_device(GPU_ID)

# Last expression of the cell, so the fastai version is actually displayed
# (a bare expression in the middle of a cell produces no output).
fastai.__version__

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Root of the 20k replication run: holds the cached LM DataBunch and the
# "models" directory with the fine-tuned language-model checkpoints.
EX_PA = Path('/mnt/data/group07/johannes') / 'ynacc_proc' / 'replicate' / '20k'
# Checkpoint of the fine-tuned LM to build classifiers on (best without overfitting).
model_id = '2018_11_25_23_51_55_431399'
# MongoDB database the sacred observer logs classifier runs to.
db_name = '20k_class_no_over'

In [4]:
# Load the language-model DataBunch previously saved under EX_PA. Its vocab is
# reused by setup_data so classifier token ids match the LM the encoder came from.
data_lm = TextLMDataBunch.load(EX_PA)

In [3]:
# Rebuild the LM learner on data_lm, restore the selected fine-tuned checkpoint
# onto the CPU, then export just its encoder so classifiers can load it by name.
learn_lm = language_model_learner(data_lm)
learn_lm.load(EX_PA / "models" / model_id, device="cpu")
learn_lm.save_encoder(f'encoder_{model_id}')

In [5]:
def _load_split(csv_path, clas):
    """Read one split CSV and return only its ``clas`` and ``text_proc`` columns.

    Rows missing either value are dropped. The label is cast to ``int`` after
    ``dropna`` because pandas parses an integer column containing NaN as float.
    Chaining (instead of assigning into a filtered frame) avoids pandas'
    SettingWithCopyWarning.
    """
    return (
        pd.read_csv(csv_path)
        [[clas, 'text_proc']]
        .dropna()
        .astype({clas: int})
    )


def setup_data(clas, split_path='~/data/ynacc_proc/replicate/split', bs=64):
    """Build the classification DataBunch for the label column ``clas``.

    Args:
        clas: name of the label column to predict (e.g. 'clcontroversial').
        split_path: directory containing ``train_proc_with_ner.csv`` and
            ``val_proc_with_ner.csv``; a leading ``~`` is expanded explicitly
            (pathlib does not do this on its own).
        bs: batch size of the returned DataBunch.

    Returns:
        A ``TextClasDataBunch`` rooted at ``EX_PA`` whose vocabulary is the
        language model's (``data_lm.train_ds.vocab``), so token ids line up with
        the encoder exported from ``learn_lm``.
    """
    split_dir = Path(split_path).expanduser()
    train_df = _load_split(split_dir / 'train_proc_with_ner.csv', clas)
    val_df = _load_split(split_dir / 'val_proc_with_ner.csv', clas)
    return TextClasDataBunch.from_df(
        EX_PA, train_df, val_df,
        vocab=data_lm.train_ds.vocab, bs=bs,
        text_cols=['text_proc'], label_cols=[clas],
    )

In [28]:
def run_for_class(clas, it=5, drop_mult=1):
    """Find a learning rate, then train and log a text classifier ``it`` times.

    Builds the DataBunch for ``clas`` and picks a learning rate with
    ``news_utils.fastai.get_optimal_lr`` on a classifier that has the fine-tuned
    LM encoder loaded. It then launches a sacred ``Experiment`` (logged to the
    MongoDB database ``db_name``) ``it`` times. Each run trains a fresh
    classifier with gradual unfreezing and checkpoints it via
    ``SaveModelCallback`` under its ``exp_id``.

    Args:
        clas: label column to train on (e.g. 'clcontroversial').
        it: number of independent sacred runs.
        drop_mult: dropout multiplier passed to ``text_classifier_learner``.
    """
    data_clas = setup_data(clas)
    encoder_name = 'encoder_' + model_id

    # Search the LR on the same setup that is trained below, i.e. with the
    # pretrained encoder loaded, and keep a reference to that learner.
    lr_learn = text_classifier_learner(data_clas, drop_mult=drop_mult)
    lr_learn.load_encoder(encoder_name)
    optim_lr = float(news_utils.fastai.get_optimal_lr(lr_learn))
    del lr_learn  # not needed any more; release its model before training

    # A notebook has no __file__, so sacred needs interactive=True here.
    ex = Experiment(db_name + '_' + clas, interactive=True)
    ex.observers.append(MongoObserver.create(db_name=db_name))

    # Sacred re-executes config-scope bodies against module globals only, so
    # values computed in this function must be injected explicitly.
    ex.add_config(drop_mult=drop_mult, lr=optim_lr)

    @ex.config
    def my_config():
        # Re-evaluated for every run, so each run gets its own id / checkpoint
        # name. Same pattern as model_id.
        exp_id = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
        factor = 3
        wd = 1e-7
        moms = (0.8, 0.7)
        full_epochs = 10

    @ex.main
    def run_exp(exp_id, drop_mult, lr, moms, wd, factor, full_epochs):
        # Discriminative LRs: the last entry is ``lr``, each earlier one is
        # ``factor`` times smaller. Presumably one entry per classifier layer
        # group (5 entries) — confirm against the model's layer_groups.
        lrs = np.array([lr / (factor ** (4 - x)) for x in range(4)] + [lr])

        learn = text_classifier_learner(data_clas, drop_mult=drop_mult)
        learn.load_encoder(encoder_name)

        learn.metrics += [news_utils.fastai.F1Macro(), news_utils.fastai.F1Weighted()]
        learn.callbacks += [SaveModelCallback(learn, name=exp_id)]

        # Gradual unfreezing: 1 epoch training the last group, 1 epoch the last
        # two groups, then ``full_epochs`` with everything unfrozen.
        for n_groups, epochs in ((1, 1), (2, 1), (None, full_epochs)):
            if n_groups is None:
                learn.unfreeze()
            else:
                learn.freeze_to(-n_groups)
            learn.fit_one_cycle(epochs, lrs, wd=wd, moms=tuple(moms))

    for _ in range(it):
        ex.run()

Total time: 04:37
epoch  train_loss  valid_loss  accuracy  F1_macro  F1_weighted
1      0.439330    0.400339    0.845626  0.516530  0.907203     (00:29)
2      0.406791    0.405685    0.845626  0.516530  0.907203     (00:27)
3      0.445035    0.409734    0.849057  0.535492  0.906663     (00:29)
4      0.429370    0.402804    0.845626  0.524935  0.904542     (00:29)
5      0.427848    0.403102    0.845626  0.524935  0.904542     (00:27)
6      0.408138    0.395928    0.842196  0.522642  0.899809     (00:23)
7      0.391929    0.402716    0.843911  0.515449  0.904813     (00:29)
8      0.413590    0.415070    0.847341  0.542091  0.901850     (00:26)
9      0.424379    0.406475    0.849057  0.543396  0.904165     (00:29)
10     0.394992    0.403925    0.849057  0.543396  0.904165     (00:25)



In [None]:
# Train and log the classifier for the "clcontroversial" label (5 sacred runs).
run_for_class(clas='clcontroversial', it=5)