In [1]:
# Colab setup: install pinned TensorFlow Text and TF Model Garden packages.
# NOTE(review): `%pip` is usually preferred over `!pip` in notebooks so the
# install targets the running kernel's environment — consider switching.
!pip install -q -U "tensorflow-text==2.8.*"
!pip install -q tf-models-official==2.7.0

# NOTE(review): os, shutil, np and pd are not referenced in the visible cells.
import os
import shutil
import numpy as np
import pandas as pd

import tensorflow as tf
import tensorflow_hub as hub
# `text` is never referenced by name in the visible cells; presumably it is
# imported for its side effect of registering the custom ops used by the
# TF Hub preprocessing model — confirm before removing.
import tensorflow_text as text
from official.nlp import optimization  # to create AdamW optimizer

# Only surface TensorFlow log records at ERROR level.
tf.get_logger().setLevel('ERROR')

[K     |████████████████████████████████| 4.9 MB 14.2 MB/s 
[K     |████████████████████████████████| 1.8 MB 15.0 MB/s 
[K     |████████████████████████████████| 238 kB 91.3 MB/s 
[K     |████████████████████████████████| 596 kB 89.6 MB/s 
[K     |████████████████████████████████| 1.1 MB 73.7 MB/s 
[K     |████████████████████████████████| 116 kB 92.0 MB/s 
[K     |████████████████████████████████| 43 kB 2.1 MB/s 
[K     |████████████████████████████████| 352 kB 81.6 MB/s 
[K     |████████████████████████████████| 1.3 MB 73.2 MB/s 
[K     |████████████████████████████████| 99 kB 10.8 MB/s 
[?25h  Building wheel for py-cpuinfo (setup.py) ... [?25l[?25hdone
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone


In [2]:
# Name of the encoder to fine-tune. It must be a key of BOTH lookup tables
# below; an unknown name raises KeyError at the lookups further down.
bert_model_name = 'bert_en_uncased_L-12_H-768_A-12'

# Preprocessing handle shared by every uncased-English entry in the tables.
_uncased_preprocess = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'

# The 24 "small BERT" variants: each layer count (outer loop) combined with
# each (hidden size, attention heads) pair (inner loop).
_small_bert_names = [
    f'small_bert/bert_en_uncased_L-{layers}_H-{hidden}_A-{heads}'
    for layers in (2, 4, 6, 8, 10, 12)
    for hidden, heads in ((128, 2), (256, 4), (512, 8), (768, 12))
]

# Model name -> TF Hub handle of the encoder.
map_name_to_handle = {
    'bert_en_uncased_L-12_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3',
    'bert_en_cased_L-12_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/3',
    'bert_multi_cased_L-12_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/3',
    # Small BERT handles all follow one URL pattern (version 1).
    **{name: f'https://tfhub.dev/tensorflow/{name}/1'
       for name in _small_bert_names},
    'albert_en_base':
        'https://tfhub.dev/tensorflow/albert_en_base/2',
    'electra_small':
        'https://tfhub.dev/google/electra_small/2',
    'electra_base':
        'https://tfhub.dev/google/electra_base/2',
    'experts_pubmed':
        'https://tfhub.dev/google/experts/bert/pubmed/2',
    'experts_wiki_books':
        'https://tfhub.dev/google/experts/bert/wiki_books/2',
    'talking-heads_base':
        'https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_base/1',
}

# Model name -> TF Hub handle of the matching text-preprocessing model.
map_model_to_preprocess = {
    'bert_en_uncased_L-12_H-768_A-12': _uncased_preprocess,
    'bert_en_cased_L-12_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_en_cased_preprocess/3',
    # Every small BERT variant uses the uncased English preprocessor.
    **dict.fromkeys(_small_bert_names, _uncased_preprocess),
    'bert_multi_cased_L-12_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3',
    'albert_en_base':
        'https://tfhub.dev/tensorflow/albert_en_preprocess/3',
    'electra_small': _uncased_preprocess,
    'electra_base': _uncased_preprocess,
    'experts_pubmed': _uncased_preprocess,
    'experts_wiki_books': _uncased_preprocess,
    'talking-heads_base': _uncased_preprocess,
}

# Resolve the handles for the selected model and report the choice.
tfhub_handle_encoder = map_name_to_handle[bert_model_name]
tfhub_handle_preprocess = map_model_to_preprocess[bert_model_name]

print(f'BERT model selected           : {tfhub_handle_encoder}',
      f'Preprocess model auto-selected: {tfhub_handle_preprocess}',
      sep='\n')

BERT model selected           : https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3
Preprocess model auto-selected: https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3


In [3]:
# Class label -> list of source names merged into that class. The 23 keys are
# the same 23 names the dataset loader later reports as `class_names`.
# The values are presumably subreddit / scrape-file names — TODO confirm.
groups = dict(
    general=[
        'MentalHealthSupport', 'mentalhealth', 'mental',
        'personalitydisorders', 'mentalillness', 'MentalHealthPH',
    ],
    control=[
        'askscience', 'askscience2',
        'LifeProTips', 'LifeProTips2',
        'AskReddit', 'AskReddit2',
        'answers', 'answers2',
        'AskScienceFiction', 'AskScienceFiction2',
        'TrueAskReddit', 'TrueAskReddit2',
    ],
    adhd=['ADHD', 'ADHD2'],
    # NOTE(review): 'aspergaers' looks like a typo of 'aspergers'; it is kept
    # as-is because it may have to match an on-disk name — verify.
    autism=[
        'aspergaers', 'autism2', 'AutisticQueers', 'AutismInWomen',
        'Aspergers_Elders', 'asperger', 'AutisticPride', 'autism',
        'AutismTranslated', 'aspergers_dating', 'aspergirls',
        'AutisticAdults',
    ],
    anxiety=['anxiety'],
    ocd=['OCD'],
    ptsd=['ptsd', 'CPTSD'],
    phobia=['Phobia', 'emetophobia', 'Agoraphobia'],
    socialanxiety=['socialanxiety', 'socialanxiety2'],
    depression=['depression1', 'depression2', 'depression3'],
    sadness=['sad11', 'sad22', 'sad33'],
    bipolar=['bipolar', 'BipolarReddit'],
    schizophrenia=[
        'schizophrenia', 'paranoidschizophrenia', 'schizoaffective',
        'Psychosis',
    ],
    cluster_a=[
        'Schizoid', 'Schizotypal', 'ParanoidPersonality',
        'Paranoid', 'ParanoiaCheck', 'Paranoia',
    ],
    cluster_b=[
        'BorderlinePDisorder', 'BPD', 'Borderline', 'hpd', 'NPD',
        'narcissism', 'sociopath', 'psychopath', 'Psychopathy', 'aspd',
    ],
    cluster_c=['OCPD', 'AvPD', 'Avoidant', 'DPD'],
    selfharm=[
        'selfharm', 'StopSelfHarm', 'AdultSelfHarm',
        'SuicideWatch11', 'SuicideWatch22', 'SuicideWatch33',
    ],
    addiction=['addiction', 'alcoholism'],
    eating=[
        'ARFID', 'bulimia', 'eating_disorders', 'EDAnonymous',
        'EatingDisorders',
    ],
    dpdr=['dpdr'],
    dysmorphic=['DysmorphicDisorder', 'BodyAcceptance'],
    tourettes=['Tourettes'],
    anger=['Anger'],
)


In [4]:
# All inputs live under the Drive folder below (a Colab-style mount path).
# NOTE(review): no cell visible here mounts Drive — confirm it is mounted
# before these paths are read.
drive_root = '/content/drive/MyDrive'

# Directories holding the CSV files.
train_path = f'{drive_root}/dataset/train/'
test_path = f'{drive_root}/dataset/test/'

# Directory of .txt files, read by the dataset loader.
txt_500 = f'{drive_root}/txt_500/'


In [5]:
AUTOTUNE = tf.data.AUTOTUNE
batch_size = 32
seed = 42


def _load_text_split(subset):
    """Load one side ('training' or 'validation') of the split of `txt_500`.

    Both sides use the same seed and the same 0.15 validation fraction;
    labels are returned one-hot (`label_mode='categorical'`).
    """
    return tf.keras.utils.text_dataset_from_directory(
        txt_500,
        subset=subset,
        validation_split=0.15,
        seed=seed,
        batch_size=batch_size,
        label_mode='categorical')


# Training split. `class_names` has to be read before cache()/prefetch(),
# because only the dataset returned by the loader carries that attribute.
raw_train_ds = _load_text_split('training')
class_names = raw_train_ds.class_names
# cache('cache') is a file-backed cache using the name 'cache', not memory.
train_ds = raw_train_ds.cache('cache').prefetch(buffer_size=AUTOTUNE)

# Validation split, cached under a separate file name.
val_ds = _load_text_split('validation').cache('cache2').prefetch(buffer_size=AUTOTUNE)

# NOTE: no test dataset is built here. The disabled draft loaded one from
# `txt_test` (a name not defined in the visible cells) with the same
# batch size, followed by cache().prefetch(buffer_size=AUTOTUNE).

Found 89192 files belonging to 23 classes.
Using 75814 files for training.
Found 89192 files belonging to 23 classes.
Using 13378 files for validation.


In [None]:
# bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess)
# bert_model = hub.KerasLayer(tfhub_handle_encoder)

In [6]:
def build_classifier_model(num_classes: int = 23):
    """Build the end-to-end text classifier: raw string -> class probabilities.

    The model chains the TF Hub preprocessing layer, the trainable TF Hub
    encoder (handles come from `tfhub_handle_preprocess` and
    `tfhub_handle_encoder`), and a single softmax Dense layer applied to the
    encoder's `pooled_output`.

    Args:
        num_classes: Width of the softmax output layer. Defaults to 23, the
            value that was previously hard-coded; pass `len(class_names)`
            to keep the head in sync with the loaded dataset.

    Returns:
        An uncompiled `tf.keras.Model` taking a batch of strings (input
        named 'text') and returning `num_classes` probabilities per example.
    """
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)
    # trainable=True: the encoder weights are fine-tuned, not frozen.
    encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
    outputs = encoder(encoder_inputs)
    net = outputs['pooled_output']
    # A Dropout(0.1) -> Dense(200) -> Dropout(0.1) stack was tried between the
    # pooled output and the classifier and is currently disabled.
    net = tf.keras.layers.Dense(num_classes, activation='softmax', name='classifier')(net)
    return tf.keras.Model(text_input, net)

In [7]:
# Cross-entropy over one-hot labels. from_logits=False because the model's
# final Dense layer already applies softmax; label smoothing is off.
loss = tf.keras.losses.CategoricalCrossentropy(
    name='categorical_crossentropy',
    label_smoothing=0.0,
    from_logits=False,
)

# Metric reported during training and validation.
metrics = tf.metrics.CategoricalAccuracy()

In [8]:
# Schedule length is derived from the data: batches per epoch x epochs,
# with the first 10% of all steps used for warm-up.
epochs = 2
steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()
num_train_steps = epochs * steps_per_epoch
num_warmup_steps = int(num_train_steps * 0.1)

# Initial learning rate handed to the optimizer factory.
init_lr = 3e-5

# AdamW optimizer from the Model Garden helper.
_optimizer_config = dict(
    init_lr=init_lr,
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps,
    optimizer_type='adamw',
)
optimizer = optimization.create_optimizer(**_optimizer_config)

In [9]:
# Instantiate the classifier and attach the loss, metric and optimizer
# defined in the preceding cells.
classifier_model = build_classifier_model()
classifier_model.compile(loss=loss, metrics=metrics, optimizer=optimizer)

In [10]:
# Where the checkpoint callback writes during training.
checkpoint_path = '/content/drive/MyDrive/model/running_bert6'

# save_weights_only=False: the callback saves the full model, not just its
# weights. No save frequency is given, so the Keras default applies.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    verbose=1,
    save_weights_only=False,
)

In [11]:
# Warm-start from the weights saved under 'running_bert5' (presumably the
# previous training run — confirm). The call stays the last expression so
# the returned load-status object is shown as the cell output.
resume_weights_path = '/content/drive/MyDrive/model/running_bert5'
classifier_model.load_weights(resume_weights_path)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f9bbd289fd0>

In [None]:
print(f'Training model with {tfhub_handle_encoder}')

# Arguments for fit(): train on train_ds, evaluate on val_ds, and invoke the
# checkpoint callback.
_fit_args = dict(
    x=train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[cp_callback],
)

# Pin the training ops to the first GPU.
with tf.device('/device:GPU:0'):
    history = classifier_model.fit(**_fit_args)

Training model with https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3
Epoch 1/2

In [None]:
# Smoke test: run one free-text sample through the model. The result holds
# one row of 23 class probabilities and is displayed as the cell output.
sample_texts = ['Examle text from Andy Salk who went shopping with his dog on a rainy day.']
classifier_model.predict(sample_texts)

array([[0.03456249, 0.11523219, 0.0136016 , 0.0953878 , 0.12140025,
        0.04896197, 0.01363547, 0.10414491, 0.00437579, 0.00562976,
        0.10758086, 0.01429516, 0.01259161, 0.01964488, 0.03791166,
        0.04362164, 0.01944822, 0.05352129, 0.01667299, 0.03547608,
        0.03159145, 0.04399469, 0.00671724]], dtype=float32)

In [None]:
# Persist the result twice: the full model, then a weights-only checkpoint.
# NOTE(review): '0.2644' in the names is presumably a metric value from this
# run — confirm and document which one.
full_model_path = '/content/drive/MyDrive/model/bert_0.2644'
weights_only_path = '/content/drive/MyDrive/model/bert_0.2644_weights_3'

classifier_model.save(full_model_path)
classifier_model.save_weights(weights_only_path)



In [None]:
# Display the label names captured from `raw_train_ds.class_names`.
# NOTE(review): presumably position i of the model's 23-way softmax output
# corresponds to class_names[i] — confirm against the dataset loader docs.
class_names

['addiction',
 'adhd',
 'anger',
 'anxiety',
 'autism',
 'bipolar',
 'cluster_a',
 'cluster_b',
 'cluster_c',
 'control',
 'depression',
 'dpdr',
 'dysmorphic',
 'eating',
 'general',
 'ocd',
 'phobia',
 'ptsd',
 'sadness',
 'schizophrenia',
 'selfharm',
 'socialanxiety',
 'tourettes']