In [55]:
# Sets the CUDA_LAUNCH_BLOCKING environment variable to '1' for this process.
# NOTE(review): presumably a debugging aid (synchronous kernel launches, so CUDA
# errors point at the failing call) — confirm it is still wanted, since it is
# expected to slow training down. The recorded execution count (55) shows this
# cell was originally run long after the imports cell; TODO confirm the variable
# is set before CUDA is initialised when running top to bottom.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [1]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import pandas as pd
from bpemb import BPEmb

from fastai.callbacks import *
from fastai.datasets import *
from fastai.imports import nn, torch
from fastai.metrics import *
from fastai.text import *
from fastai.text.data import TextLMDataBunch
from fastai.text.transform import BaseTokenizer, Tokenizer, Vocab
from fastai.train import *
import fastai

import news_utils.fastai
import news_utils.clean.german
# Make CUDA device 0 the current device for everything that follows.
torch.cuda.set_device(0)

In [2]:
# Display the installed fastai version; the recorded run used '1.0.42'.
fastai.__version__

'1.0.42'

In [3]:
# German BPEmb subword model: vs=25000 is the vocabulary size (the next cell
# reports 25000 entries in `bpemb_de.words`); dim=300 is presumably the
# embedding dimension — TODO confirm against the BPEmb documentation.
bpemb_de = BPEmb(lang="de", vs=25000, dim=300)

In [59]:
# Sanity check: number of subword entries in the BPEmb vocabulary (recorded output: 25000).
len(bpemb_de.words)

25000

In [60]:
# Index -> token mapping: the BPEmb subwords keep their positions and the
# padding token 'xxpad' is appended at the end. With the 25000 subwords reported
# above, 'xxpad' therefore gets index 25000.
tokens_with_pad = bpemb_de.words + ['xxpad']
itos = {idx: token for idx, token in enumerate(tokens_with_pad)}

# NOTE(review): `Vocab` is given a dict here rather than a list of tokens —
# confirm that this is what the installed fastai version expects.
voc = Vocab(itos)

In [None]:
def build_text_cat(row):
    """Concatenate a post's headline and body and append the structure markers.

    Args:
        row: A DataFrame row (or any mapping) providing the keys 'Headline',
            'Body' and 'ID_Parent_Post'.

    Returns:
        str: '<headline> <body> xxp <marker>'. The marker is 'xxa' when
        'ID_Parent_Post' is missing and 'xxb' otherwise (presumably top-level
        post vs. reply). A missing headline or body contributes an empty string.
    """
    # pd.isna treats None and NaN alike; a plain `is None` test lets NaN through
    # and the string concatenation below then raises a TypeError.
    headline = '' if pd.isna(row['Headline']) else row['Headline']
    body = '' if pd.isna(row['Body']) else row['Body']
    marker = 'xxa' if pd.isna(row['ID_Parent_Post']) else 'xxb'
    return headline + ' ' + body + ' xxp ' + marker


def encode_text(text):
    """Clean German `text` and encode it with `bpemb_de.encode_ids_with_bos_eos`."""
    return bpemb_de.encode_ids_with_bos_eos(news_utils.clean.german.clean_german(text))


# Unlabelled posts used to fine-tune the language model.
df_all = pd.read_pickle('/mnt/data/group07/johannes/ompc/unla.pkl')
df_all['text_cat'] = df_all.apply(build_text_cat, axis=1)
df_all['text_ids'] = df_all['text_cat'].apply(encode_text)

# Split by article id so that all posts of one article land on the same side:
# articles below 11500 are used for training, the remaining ones for validation.
df_all_train = df_all[df_all['ID_Article'] < 11500]
df_all_val = df_all[df_all['ID_Article'] >= 11500]


In [None]:
# Build the language-model DataBunch directly from the pre-encoded id sequences
# of the unlabelled posts, using the BPEmb-based vocabulary.
data_lm_ft = TextLMDataBunch.from_ids(
    path='/mnt/data/group07/johannes/ompc/lmexp',
    vocab=voc,
    train_ids=df_all_train['text_ids'],
    valid_ids=df_all_val['text_ids'],
    bs=128,
    device='cuda:0',
)

In [None]:
# Cache the language-model DataBunch under the name 'whatever'; the next cell
# reloads it with the same cache name.
data_lm_ft.save('whatever')

In [4]:
# Reload the cached language-model DataBunch (cache name 'whatever').
# NOTE(review): the three cells above that build and save this cache carry no
# execution count, so the recorded run relied on a cache written in an earlier
# session — a fresh environment has to run those cells first.
data_lm_ft = TextLMDataBunch.load(path='/mnt/data/group07/johannes/ompc/lmexp', cache_name='whatever')

In [10]:
# Experiment selection. `cat` and `fold` pick the annotated data directory and,
# combined as '<cat>_<fold>', the experiment output directory.
cat = 'ArgumentsUsed'
# cat = 'PersonalStories'
fold = '1'

# Name of the fine-tuned language-model checkpoint to load further down.
# NOTE(review): the embedded space ('2019_ 4_...') looks unusual — presumably it
# matches the file name on disk; confirm before "fixing" it.
model_id = '2019_ 4_01_00_11_32_066215'

exp_path = f'/mnt/data/group07/johannes/ompc/exp/{cat}_{fold}'

In [11]:
# data_lm = TextLMDataBunch.load(Path('/mnt/data/group07/johannes/germanlm/exp_1'))

In [12]:
# data_lm.vocab = voc

In [13]:
# shutil.copy('/mnt/data/group07/johannes/ompc/exp/unlab/models/enc1.pth', '/mnt/data/group07/johannes/ompc/exp/' + cat + '_' + fold +'/models/enc1.pth')

In [14]:
def build_text_cat(row):
    """Concatenate a post's headline and body and append the structure markers.

    Args:
        row: A DataFrame row (or any mapping) providing the keys 'Headline',
            'Body' and 'ID_Parent_Post'.

    Returns:
        str: '<headline> <body> xxp <marker>'. The marker is 'xxa' when
        'ID_Parent_Post' is missing and 'xxb' otherwise (presumably top-level
        post vs. reply). A missing headline or body contributes an empty string.
    """
    # pd.isna treats None and NaN alike; a plain `is None` test lets NaN through
    # and the string concatenation below then raises a TypeError.
    headline = '' if pd.isna(row['Headline']) else row['Headline']
    body = '' if pd.isna(row['Body']) else row['Body']
    marker = 'xxa' if pd.isna(row['ID_Parent_Post']) else 'xxb'
    return headline + ' ' + body + ' xxp ' + marker


def encode_text(text):
    """Clean German `text` and encode it with `bpemb_de.encode_ids_with_bos_eos`."""
    return bpemb_de.encode_ids_with_bos_eos(news_utils.clean.german.clean_german(text))


def load_annotated_split(split_name):
    """Load one annotated split for the current `cat`/`fold` and add model inputs.

    Args:
        split_name: File stem of the split, i.e. 'train' or 'test'.

    Returns:
        pandas.DataFrame: The pickled frame with two added columns, 'text_cat'
        (concatenated text, see `build_text_cat`) and 'text_ids' (encoded ids,
        see `encode_text`).
    """
    split_df = pd.read_pickle(
        Path('/mnt/data/group07/johannes/ompc/data_ann')/cat/fold/(split_name + '.pkl'))
    split_df['text_cat'] = split_df.apply(build_text_cat, axis=1)
    split_df['text_ids'] = split_df['text_cat'].apply(encode_text)
    return split_df


# Rows with a missing Body are kept; the missing part becomes an empty string.
train_df = load_annotated_split('train')
test_df = load_annotated_split('test')

# Classification DataBunch built from the pre-encoded ids. The test split is
# used as the validation set. pad_idx=25000 is the position of 'xxpad' when it
# is appended after the 25000 BPEmb subwords.
data = TextClasDataBunch.from_ids(
    path=exp_path,
    vocab=data_lm_ft.vocab,
    classes=[0, 1],
    train_ids=train_df['text_ids'],
    train_lbls=train_df['Value'],
    valid_ids=test_df['text_ids'],
    valid_lbls=test_df['Value'],
    pad_idx=25000,
    bs=64,
)

In [15]:
# Class distribution of the training labels (column 'Value').
# Disabled experiment: swap the two labels (0 <-> 1).
# train_df['Value'] = train_df['Value'].apply(lambda x: 1 if x == 0 else 0)
train_df['Value'].value_counts()

In [16]:
# data_lm = TextLMDataBunch.from_ids(bs=64,path='/mnt/data/group07/johannes/ompc/lmexp', vocab=voc, train_ids=train_df['text_ids'], valid_ids=test_df['text_ids'])

In [17]:
# Create a language-model learner on the cached LM data and restore the
# checkpoint selected by `model_id` from the LM experiment's models directory.
lm_checkpoint = Path('/mnt/data/group07/johannes/ompc/lmexp/models/' + model_id)
learn_lm = language_model_learner(data_lm_ft).load(lm_checkpoint)

In [18]:
# Save the encoder of the fine-tuned language model under the name 'enc5'.
# Presumably this writes .../lmexp/models/enc5.pth, the file a later cell copies
# into the classifier's experiment folder — TODO confirm the output location.
learn_lm.save_encoder('enc5')

In [19]:
# data.save()

In [51]:
# Copy the fine-tuned LM encoder into this experiment's `models` directory,
# presumably where `learn.load_encoder` looks for it (the classification data
# was created with this experiment folder as its path) — TODO confirm.
encoder_name = 'enc5'
encoder_src = Path('/mnt/data/group07/johannes/ompc/lmexp/models') / (encoder_name + '.pth')
models_dir = Path('/mnt/data/group07/johannes/ompc/exp/' + cat + '_' + fold) / 'models'

# A fresh category/fold has no `models` directory yet; without this the copy
# fails with FileNotFoundError.
models_dir.mkdir(parents=True, exist_ok=True)
shutil.copy(str(encoder_src), str(models_dir / (encoder_name + '.pth')))

# Classifier on top of the pre-trained encoder; drop_mult scales the dropout
# values (see the printed hidden/input/embed/weight values in the output).
learn = text_classifier_learner(data, drop_mult=0.8)
learn.load_encoder(encoder_name)

hidden:  0.24
input:  0.32000000000000006
embed:  0.04000000000000001
weight:  0.4


In [45]:
import sklearn

In [48]:
# Balanced class weights for the labels 0 and 1, computed from the training
# labels (recorded output: roughly 0.70 for class 0 and 1.76 for class 1).
# `classes` and `y` are passed by keyword and `classes` as an ndarray: newer
# scikit-learn releases reject the positional / list form, while the keyword
# form is also accepted by older releases.
sklearn.utils.class_weight.compute_class_weight(
    class_weight='balanced', classes=np.array([0, 1]), y=train_df['Value'])

array([0.698146, 1.761697])

In [52]:
# Weight the cross-entropy loss by the balanced class weights so that the
# minority class counts more. `classes` and `y` are passed by keyword and
# `classes` as an ndarray: newer scikit-learn releases reject the positional /
# list form, while the keyword form is also accepted by older releases.
class_weights = sklearn.utils.class_weight.compute_class_weight(
    class_weight='balanced', classes=np.array([0, 1]), y=train_df['Value'])

# The weight tensor is moved to the GPU because the loss is evaluated there.
learn.loss_func = torch.nn.CrossEntropyLoss(torch.FloatTensor(class_weights).cuda())

In [None]:
# Five learning rates: lr / 2.6**4, lr / 2.6**3, lr / 2.6**2, lr / 2.6**1 and
# finally lr itself. They are handed to `fit` as one array (presumably one rate
# per layer group, lowest first — TODO confirm).
lr = 0.001
lrs = [lr / (2.6 ** power) for power in (4, 3, 2, 1)] + [lr]

# NOTE: `+=` appends to the learner's existing metrics, so re-running this cell
# adds the same four metrics a second time.
# learn.metrics += [KappaScore()]
learn.metrics += [
    KappaScore(),
    news_utils.fastai.F1Bin(),
    news_utils.fastai.PrecBin(),
    news_utils.fastai.RecaBin(),
]

# Disabled: checkpointing on kappa and experiment logging.
#         learn.callbacks += [
#             SaveModelCallback(learn, name=exp_id, monitor='kappa_score'),
#             news_utils.fastai.SacredLogger(learn, ex),
#         ]

# Training schedule as (stage, epochs): stages 1 and 2 call freeze_to(-1) and
# freeze_to(-2) and train for one epoch each; stage 3 unfreezes the whole model
# and trains for up to 100 epochs.
for i, epochs in ((1, 1), (2, 1), (3, 100)):
    if i < 3:
        learn.freeze_to(-i)
    else:
        learn.unfreeze()
    learn.fit(epochs, np.array(lrs))

epoch,train_loss,valid_loss,accuracy,kappa_score,F1_bin,prec_bin,reca_bin
1,0.690971,0.664428,0.523546,0.183945,0.514124,0.883495,0.362550


epoch,train_loss,valid_loss,accuracy,kappa_score,F1_bin,prec_bin,reca_bin
1,0.655169,0.690048,0.650970,0.011561,0.212500,0.165049,0.298246


epoch,train_loss,valid_loss,accuracy,kappa_score,F1_bin,prec_bin,reca_bin
1,0.662588,0.656431,0.656510,0.071055,0.287356,0.242718,0.352113
2,0.647084,0.657620,0.700831,0.123235,0.280000,0.203883,0.446809
3,0.620392,0.668299,0.695291,0.137077,0.312500,0.242718,0.438596
4,0.644261,0.615450,0.731302,0.198567,0.331034,0.233010,0.571429
5,0.649279,0.592113,0.756233,0.248948,0.352941,0.233010,0.727273
6,0.629625,0.582902,0.728532,0.163601,0.279412,0.184466,0.575758
7,0.635249,0.565803,0.731302,0.187031,0.312057,0.213592,0.578947
8,0.607298,0.568865,0.756233,0.327775,0.476190,0.388350,0.615385
9,0.623233,0.563429,0.728532,0.187423,0.319444,0.223301,0.560976
10,0.604618,0.539045,0.764543,0.352815,0.497041,0.407767,0.636364
11,0.622137,0.537871,0.764543,0.312350,0.437086,0.320388,0.687500
12,0.618599,0.514627,0.772853,0.389624,0.534091,0.456311,0.643836
13,0.600529,0.508278,0.770083,0.421038,0.578680,0.553398,0.606383
14,0.618451,0.500399,0.781163,0.413829,0.553672,0.475728,0.662162
15,0.614008,0.550348,0.736842,0.209483,0.335664,0.233010,0.600000
16,0.608270,0.508290,0.775623,0.406551,0.552486,0.485437,0.641026
17,0.582233,0.522470,0.764543,0.417554,0.581281,0.572816,0.590000
18,0.596711,0.523303,0.753463,0.400724,0.574163,0.582524,0.566038
19,0.594967,0.493597,0.786704,0.442868,0.583784,0.524272,0.658537
20,0.574749,0.466805,0.767313,0.454879,0.621622,0.669903,0.579832
21,0.575699,0.501927,0.742382,0.426818,0.614108,0.718447,0.536232
22,0.569675,0.482338,0.781163,0.464964,0.618357,0.621359,0.615385
23,0.583063,0.485491,0.781163,0.502573,0.660944,0.747573,0.592308
24,0.608845,0.467225,0.828255,0.550002,0.663043,0.592233,0.753086
25,0.579418,0.456940,0.750693,0.482297,0.664179,0.864078,0.539394
26,0.577523,0.472440,0.742382,0.441595,0.629482,0.766990,0.533784
27,0.558341,0.491406,0.728532,0.413086,0.611111,0.747573,0.516779
28,0.572791,0.481881,0.778393,0.489374,0.649123,0.718447,0.592000
29,0.577479,0.497146,0.728532,0.430654,0.628788,0.805825,0.515528
30,0.557037,0.452988,0.803324,0.504744,0.639594,0.611650,0.670213
31,0.517323,0.430668,0.806094,0.548250,0.687500,0.747573,0.636364
32,0.555999,0.495965,0.772853,0.452605,0.613208,0.631068,0.596330
33,0.544878,0.497064,0.761773,0.457029,0.629310,0.708738,0.565891
34,0.544971,0.477364,0.764543,0.467672,0.638298,0.728155,0.568182
35,0.551953,0.463389,0.795014,0.506010,0.650943,0.669903,0.633028
36,0.527961,0.452753,0.803324,0.524674,0.663507,0.679612,0.648148
37,0.538495,0.432466,0.811634,0.551208,0.685185,0.718447,0.654867
38,0.552509,0.470597,0.756233,0.464717,0.642276,0.766990,0.552448
39,0.538910,0.440467,0.764543,0.478882,0.650206,0.766990,0.564286
40,0.511986,0.427434,0.783934,0.502139,0.657895,0.728155,0.600000
41,0.540589,0.430239,0.783934,0.502139,0.657895,0.728155,0.600000
42,0.532264,0.434980,0.811634,0.532668,0.663366,0.650485,0.676768
43,0.513412,0.472791,0.728532,0.413086,0.611111,0.747573,0.516779
44,0.534699,0.468942,0.734072,0.425064,0.619048,0.757282,0.523490
45,0.519940,0.447469,0.742382,0.435776,0.623482,0.747573,0.534722
46,0.505029,0.446711,0.756233,0.459066,0.636364,0.747573,0.553957
47,0.512850,0.394536,0.803324,0.548058,0.689956,0.766990,0.626984
48,0.512651,0.443019,0.770083,0.471673,0.637555,0.708738,0.579365
49,0.532379,0.425611,0.797784,0.516840,0.660465,0.689320,0.633929
50,0.493028,0.429474,0.759003,0.472178,0.647773,0.776699,0.555556
51,0.505340,0.527157,0.722992,0.398012,0.600000,0.728155,0.510204
52,0.505123,0.431972,0.742382,0.452879,0.640927,0.805825,0.532051
53,0.514498,0.401999,0.772853,0.495947,0.661157,0.776699,0.575540
54,0.496089,0.441932,0.786704,0.507172,0.660793,0.728155,0.604839
55,0.514368,0.415878,0.800554,0.535343,0.678571,0.737864,0.628099
56,0.508925,0.409980,0.795014,0.527671,0.675439,0.747573,0.616000
57,0.488529,0.434399,0.770083,0.482978,0.649789,0.747573,0.574627
58,0.486918,0.431922,0.742382,0.447294,0.635294,0.786408,0.532895
59,0.489433,0.431426,0.753463,0.460044,0.639676,0.766990,0.548611
60,0.484264,0.396293,0.767313,0.507071,0.676923,0.854369,0.560510
61,0.484052,0.409174,0.750693,0.466485,0.648438,0.805825,0.542484
62,0.478360,0.463246,0.698061,0.409467,0.627986,0.893204,0.484211
63,0.448899,0.447139,0.698061,0.412201,0.630508,0.902913,0.484375
64,0.471448,0.427165,0.731302,0.451514,0.647273,0.864078,0.517442
65,0.492391,0.455893,0.775623,0.475808,0.636771,0.689320,0.591667
66,0.465611,0.481506,0.695291,0.394184,0.615385,0.854369,0.480874
67,0.473776,0.464043,0.747922,0.436160,0.619247,0.718447,0.544118
68,0.439667,0.429718,0.734072,0.430917,0.625000,0.776699,0.522876
69,0.458757,0.437421,0.772853,0.464876,0.627273,0.669903,0.589744
70,0.459302,0.442028,0.772853,0.470807,0.633929,0.689320,0.586777
71,0.431857,0.438612,0.759003,0.477621,0.653386,0.796117,0.554054
72,0.465551,0.442911,0.783934,0.490979,0.645455,0.689320,0.606838
73,0.424764,0.414101,0.778393,0.502994,0.663866,0.766990,0.585185
74,0.421585,0.409925,0.761773,0.495335,0.669231,0.844660,0.554140
75,0.404495,0.440517,0.767313,0.483653,0.652893,0.766990,0.568345
76,0.469607,0.460856,0.759003,0.463797,0.639004,0.747573,0.557971
77,0.438546,0.472322,0.736842,0.432512,0.624506,0.766990,0.526667
78,0.414933,0.483600,0.767313,0.475345,0.644068,0.737864,0.571429
79,0.389596,0.500581,0.781163,0.474139,0.629108,0.650485,0.609091
80,0.394932,0.490510,0.770083,0.482978,0.649789,0.747573,0.574627
81,0.376517,0.478496,0.770083,0.480197,0.646809,0.737864,0.575758
82,0.384107,0.451310,0.764543,0.458942,0.628821,0.699029,0.571429
83,0.448735,0.497451,0.731302,0.451514,0.647273,0.864078,0.517442
84,0.395205,0.442480,0.753463,0.462843,0.642570,0.776699,0.547945
85,0.409372,0.446293,0.772853,0.493273,0.658333,0.766990,0.576642
86,0.389597,0.463456,0.731302,0.443406,0.639405,0.834951,0.518072
87,0.372136,0.496325,0.742382,0.450101,0.638132,0.796117,0.532468
88,0.408274,0.477731,0.728532,0.439050,0.637037,0.834951,0.514970
89,0.377214,0.489750,0.767313,0.478144,0.647059,0.747573,0.570370
