# Imports

In [None]:



import pandas as pd
import numpy as np
import jsonlines
import seaborn as sns
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch_optimizer as optim


from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"
from importlib import reload
# pandas display options so long/wide frames are not truncated in notebook output.
pd.set_option('display.max_rows', 500)
pd.set_option('display.float_format', '{:0.3f}'.format)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
# NOTE(review): this immediately replaces the width of 1000 set on the previous line.
pd.options.display.width = 0
import warnings
import torchvision
# Silences every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# Project-wide settings, stored through the package's global registry.
# NOTE(review): "cache_dir" and "models_dir" are absolute, machine-specific paths;
# they must be changed before this notebook can run on another machine.
from facebook_hateful_memes_detector.utils.globals import set_global, get_global
set_global("cache_dir", "/home/ahemf/cache/cache")
set_global("dataloader_workers", 4)
set_global("use_autocast", True)
set_global("models_dir", "/home/ahemf/cache/")

from facebook_hateful_memes_detector.utils import read_json_lines_into_df, in_notebook, set_device, print_code, my_collate, load_stored_params
# Bare expression: echoes the cache_dir that was just set (visible because
# ast_node_interactivity is "all").
get_global("cache_dir")
from facebook_hateful_memes_detector.models import Fasttext1DCNNModel, MultiImageMultiTextAttentionEarlyFusionModel, VilBertVisualBertModel
from facebook_hateful_memes_detector.preprocessing import TextImageDataset, get_datasets, get_image2torchvision_transforms, TextAugment
from facebook_hateful_memes_detector.preprocessing import DefinedRotation, QuadrantCut, ImageAugment, DefinedAffine, DefinedColorJitter, DefinedRandomPerspective
from facebook_hateful_memes_detector.preprocessing import DefinedAffine, HalfSwap, get_image_transforms, get_transforms_for_bbox_methods
# Wildcard import: helpers used later without an explicit import (model_builder,
# train_validate_ntimes, group_wise_lr, ...) presumably come from here -- confirm.
from facebook_hateful_memes_detector.training import *
import facebook_hateful_memes_detector
# importlib.reload re-executes only the top-level package module; submodules that
# were already imported are not reloaded by this call.
reload(facebook_hateful_memes_detector)

# Use the GPU when torch can see one, otherwise the CPU, and pass that choice to the package.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_device(device)


In [None]:


def get_preprocess_text():
    """Build the text-augmentation callable used for the training split.

    Four ``TextAugment`` instances are created (character-, word-,
    sentence-level and gibberish/punctuation noise), each from a list of
    weights plus a mapping of augmentation name -> weight.  The meaning of
    the leading weight list is defined by ``TextAugment`` (presumably a
    distribution over how many augmentations to apply -- confirm there).

    Returns:
        A function ``process(text)`` that feeds ``text`` through the
        sentence-level, word-level, character-level and gibberish
        augmenters, in that order, and returns the result.
    """
    # Augmenters are constructed in the same order as before
    # (char, word, sentence, gibberish) even though they are applied in a
    # different order below.
    char_augmenter = TextAugment(
        [0.1, 0.4, 0.5],
        {
            "keyboard": 0.1,
            "char_substitute": 0.4,
            "char_insert": 0.2,
            "char_swap": 0.2,
            "ocr": 0.0,
            "char_delete": 0.1,
        },
    )
    word_augmenter = TextAugment(
        [0.1, 0.4, 0.5],
        {
            "fasttext": 0.0,
            "glove_twitter": 0.0,
            "glove_wiki": 0.0,
            "word2vec": 0.0,
            "split": 0.2,
            "stopword_insert": 0.0,
            "word_join": 0.2,
            "word_cutout": 0.8,
            "gibberish_insert": 0.0,
        },
    )
    sentence_augmenter = TextAugment(
        [0.75, 0.25],
        {
            "text_rotate": 0.0,
            "sentence_shuffle": 0.0,
            "one_third_cut": 0.3,
            "half_cut": 0.0,
            "part_select": 0.75,
        },
    )
    noise_augmenter = TextAugment(
        [0.75, 0.25],
        {
            "gibberish_insert": 0.25,
            "punctuation_insert": 0.75,
        },
    )

    # Application order: coarse (sentence) to fine (char), then noise.
    stages = (sentence_augmenter, word_augmenter, char_augmenter, noise_augmenter)

    def process(text):
        for stage in stages:
            text = stage(text)
        return text

    return process


# Build the augmentation callables once; they are reused for every training sample.
preprocess_text = get_preprocess_text()
transforms_for_bbox_methods = get_transforms_for_bbox_methods()


# Load the dataset splits. Only the training split gets text/image augmentation;
# both test-time transforms are None. Later cells index the result with
# "dev", "test" and "metadata".
# NOTE(review): the exact meaning of dev/test_dev/keep_* is defined by get_datasets -- confirm there.
data = get_datasets(data_dir="../data/", train_text_transform=preprocess_text, train_image_transform=transforms_for_bbox_methods, 
                    test_text_transform=None, test_image_transform=None, 
                    cache_images = True, use_images = True, dev=False, test_dev=True,
                    keep_original_text=False, keep_original_image=False, 
                    keep_processed_image=True, keep_torchvision_image=False,)

# ImageAugment([0.2, 0.5, 0.3])


In [None]:
# Catalogue of optimizer classes, each paired with a dict of constructor kwargs.
# Nothing is instantiated here; a later cell picks one (class, kwargs) pair.

sgd = torch.optim.SGD
sgd_params = {
    "lr": 2e-2,
    "momentum": 0.9,
    "dampening": 0,
    "weight_decay": 0,
    "nesterov": False,
}

rangerQH = optim.RangerQH
rangerQHparams = {
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    "nus": (.7, 1.0),
    "weight_decay": 0.0,
    "k": 6,
    "alpha": .5,
    "decouple_weight_decay": True,
    "eps": 1e-8,
}

adam = torch.optim.Adam
# `params` is bound to the very same dict object as `adam_params`.
adam_params = params = {"lr": 1e-3, "weight_decay": 1e-7}

adamw = torch.optim.AdamW
adamw_params = {
    "lr": 1e-4,
    "betas": (0.9, 0.999),
    "eps": 1e-08,
    "weight_decay": 1e-2,
}

novograd = optim.NovoGrad
novograd_params = {
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    "eps": 1e-8,
    "weight_decay": 0,
    "grad_averaging": False,
    "amsgrad": False,
}

qhadam = optim.QHAdam
qhadam_params = {
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    "nus": (1.0, 1.0),
    "weight_decay": 0,
    "decouple_weight_decay": False,
    "eps": 1e-8,
}

radam = optim.RAdam
radam_params = {
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    "eps": 1e-8,
    "weight_decay": 0,
}

yogi = optim.Yogi
yogi_params = {
    "lr": 1e-2,
    "betas": (0.9, 0.999),
    "eps": 1e-3,
    "initial_accumulator": 1e-6,
    "weight_decay": 0,
}

In [None]:
# Default run size; individual training cells below override both values.
batch_size = 96
epochs = 10

# Active optimizer: AdamW, re-declared here with betas=(0.9, 0.98).
adamw = torch.optim.AdamW
adamw_params = {
    "lr": 1e-4,
    "betas": (0.9, 0.98),
    "eps": 1e-08,
    "weight_decay": 1e-2,
}
optimizer_class, optimizer_params = adamw, adamw_params


# Both factories are called with their defaults; the returned objects are
# passed to the training helpers as scheduler_init_fn / model_call_back.
scheduler_init_fn = get_cosine_schedule_with_warmup()
reg_sched = get_regularizer_scheduler()




# Finetune Configs

In [None]:
# Layer-wise alternative strategy. The nested keys presumably mirror the model's
# module hierarchy (lxmert -> model -> bert -> encoder/pooler) -- confirm against
# the per_param_opts_fn handling in model_builder. It used to be assigned to
# `lxmert_strategy` and was overwritten by the very next statement; it now has
# its own name so it stays available instead of being silently discarded.
lxmert_layerwise_strategy = {
    "lxmert": {
        "model": {
            "bert": {
                "encoder": {
                    "x_layers": {
                        "lr": optimizer_params["lr"],
                        "finetune": True
                    },
                    "lr": optimizer_params["lr"],
                    "finetune": False
                },
                "pooler": {
                    "lr": optimizer_params["lr"],
                    "finetune": True
                },
            },
            "finetune": False
        }
    }
}

# Strategy actually passed to model_builder below: a single top-level
# "finetune": True entry (same final value as before).
lxmert_strategy = {
    "finetune": True
}




## LXMERT

In [None]:
from facebook_hateful_memes_detector.models.MultiModal.VilBertVisualBert import VilBertVisualBertModel

# Keyword arguments handed to model_builder for the LXMERT-backed model.
# Keys are listed in the same order as before.
model_params = {
    "model_name": {"lxmert": {"dropout": 0.2, "gaussian_noise": 0.01}},
    "num_classes": 2,
    "gaussian_noise": 0.01,
    "dropout": 0.2,
    "word_masking_proba": 0.2,
    "featurizer": "pass",
    "final_layer_builder": fb_1d_loss_builder,
    "internal_dims": 768,
    "classifier_dims": 768,
    "n_tokens_in": 96,
    "n_tokens_out": 96,
    "n_layers": 0,
    "attention_drop_proba": 0.0,
    "loss": "focal",
    "dice_loss_coef": 0.0,
    "auc_loss_coef": 0.0,
    "bbox_swaps": 5,
    "bbox_copies": 5,
    "bbox_gaussian_noise": 0.05,
    "finetune": False,
}

# Factory: calling model_fn() yields a (model, optimizer) pair (see the next cell).
model_fn = model_builder(
    VilBertVisualBertModel,
    model_params,
    per_param_opts_fn=lxmert_strategy,
    optimiser_class=optimizer_class,
    optimiser_params=optimizer_params,
)



In [None]:
# Instantiate a fresh model and its optimizer from the factory built by model_builder.
model, optimizer = model_fn()
# load_stored_params(model, "lxmert-smclr.pth")
# NOTE(review): get_device is not imported by name in this notebook; presumably it
# arrives through the wildcard import from the training module -- confirm.
model = model.to(get_device())


In [None]:
# Stage 1: train with "finetune" disabled at the top level and enabled only for
# the "model_heads" and "final_layer" groups (presumably: backbone frozen, heads
# trainable -- confirm in group_wise_finetune).
batch_size = 128
epochs = 6
kfold = False

lr_strategy = {
    "finetune": False,
    "model_heads": {"finetune": True},
    "final_layer": {"finetune": True},
}

# Apply the strategy to the model, then rebuild the optimizer from the
# resulting parameter groups using the currently selected class and kwargs.
_ = group_wise_finetune(model, lr_strategy)
params_conf, _ = group_wise_lr(model, lr_strategy)
optimizer = optimizer_class(params_conf, **optimizer_params)

results, prfs = train_validate_ntimes(
    (model, optimizer),
    data,
    batch_size,
    epochs,
    kfold=kfold,
    scheduler_init_fn=scheduler_init_fn,
    model_call_back=reg_sched,
    # Entries beyond `epochs` are listed too; kept exactly as before.
    validation_epochs=[2, 5, 7, 9, 11, 14, 17, 19, 24, 28],
    show_model_stats=False,
    sampling_policy="without_replacement",
    accumulation_steps=2,
)

# Keep the stage-1 metrics under their own names, then display both objects.
r1, p1 = results, prfs
results
prfs


In [None]:
# Stage 2: enable fine-tuning for the whole model and continue training with a
# lower learning rate (5e-5) on the same model instance.
lr_strategy = {
    "finetune": True
}
adamw = torch.optim.AdamW
adamw_params = dict(lr=5e-5, betas=(0.9, 0.98), eps=1e-08, weight_decay=1e-2)
optimizer_class = adamw
optimizer_params = adamw_params

# Apply the strategy to the model, then rebuild the optimizer from the
# resulting parameter groups.
_ = group_wise_finetune(model, lr_strategy)
params_conf, _ = group_wise_lr(model, lr_strategy)
optimizer = optimizer_class(params_conf, **optimizer_params)

batch_size=64
epochs = 16

kfold = False
results, prfs = train_validate_ntimes(
    (model, optimizer),
    data,
    batch_size,
    epochs,
    kfold=kfold,
    scheduler_init_fn=scheduler_init_fn,
    model_call_back=reg_sched,
    validation_epochs=[5, 7, 9, 11, 15, 17, 24, 28, 31],
    show_model_stats=False,
    sampling_policy="without_replacement",
    accumulation_steps=4,
)
# Store stage-2 metrics under r2/p2. This cell was copy-pasted from the stage-1
# cell and used to write to r1/p1, which discarded the stage-1 metrics.
r2, p2 = results, prfs
results
prfs


In [None]:
from collections import Counter
from sklearn.metrics import roc_auc_score, average_precision_score, classification_report
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# Run prediction on the dev split five times and combine the runs two ways:
# a majority vote over predicted labels and an average over class probabilities.
preds, probas = [], []
dataset = convert_dataframe_to_dataset(data["dev"], data["metadata"], True)
for i in range(5):
    proba_list, all_probas_list, predictions_list, labels_list = generate_predictions(model, 128, dataset)
    probas.append(all_probas_list)
    preds.append(predictions_list)

# zip(*preds) regroups the per-run predictions per example; the most frequent
# label across runs wins.
preds_voted = [Counter(votes).most_common(1)[0][0] for votes in zip(*preds)]

# Average the probabilities over the run axis, then take the best class per example.
probas_mean = torch.tensor(probas).mean(dim=0)
pred_probas = probas_mean.max(dim=1)[1]

# labels_list comes from the last run of the loop above.
print(accuracy_score(labels_list, preds_voted))
print(accuracy_score(labels_list, pred_probas))
print(roc_auc_score(labels_list, probas_mean[:, 1].tolist(), multi_class="ovo", average="macro"))


# Predict

In [None]:
# Model factory for the final train-and-predict runs: LXMERT-backed
# VilBertVisualBertModel optimised with Adam (lr=1e-4, weight_decay=1e-2).
adam = torch.optim.Adam
adam_params = params = dict(lr=1e-4, weight_decay=1e-2)
# NOTE: `optimizer` holds the optimizer *class* from here on; earlier cells bound
# the same name to an optimizer instance.
optimizer = adam
optimizer_params = adam_params

model_fn = model_builder(VilBertVisualBertModel,
                         dict(model_name={
                             "lxmert":
                             dict(finetune=True,
                                  dropout=0.1,
                                  gaussian_noise=0.2),
                            },
                              num_classes=2,
                              gaussian_noise=0.2,
                              dropout=0.25,
                              featurizer="pass",
                              final_layer_builder=fb_1d_loss_builder,
                              internal_dims=768,
                              classifier_dims=768,
                              n_tokens_in=96,
                              n_tokens_out=96,
                              n_layers=2,
                              loss="focal",
                              dice_loss_coef=0.0,
                              auc_loss_coef=0.0,
                              word_masking_proba=0.2),
                         # `combo_strategy` was referenced here but is never defined in
                         # this notebook, so a fresh-kernel run raised NameError. The
                         # strategy defined in the "Finetune Configs" section is used
                         # instead.
                         per_param_opts_fn=lxmert_strategy,
                         optimiser_class=optimizer,
                         optimiser_params=optimizer_params)

# model, opt = model_fn()
# model

## Experiment notes
MMBT Region, per-module regularization, word_masking_proba, reg_scheduling

## Next: accumulation_steps

In [None]:

# Train a model built by model_fn and write the returned submission frame to disk.

# Global switch: disables cuDNN for every later torch op in this kernel, not only this cell.
torch.backends.cudnn.enabled = False
batch_size = 4
epochs = 7

# batch_size=4 with accumulation_steps=16: presumably an effective batch of 64 --
# confirm how train_and_predict applies accumulation_steps.
submission, text_model = train_and_predict(
    model_fn,
    data,
    batch_size,
    epochs,
    scheduler_init_fn=scheduler_init_fn,
    accumulation_steps=16,
    model_call_back=reg_sched,
    sampling_policy="without_replacement") # "without_replacement"

submission.to_csv("submission.csv", index=False)
# Unseeded random sample of three rows, displayed as a quick sanity check.
submission.sample(3)


In [None]:
# Score the trained model on the labelled dev split by presenting dev as the
# "test" split to predict().
# This is done on a shallow copy: assigning data["test"] = data["dev"] directly
# would permanently replace the real test split, so any later train_and_predict
# call on `data` would build its submission from dev.
# (assumes `data` is a dict-like mapping, as its data["dev"] indexing suggests)
dev_as_test = dict(data)
dev_as_test["test"] = data["dev"]
sf, _ = predict(text_model, dev_as_test, batch_size)

print(sf.head())

from sklearn.metrics import roc_auc_score, average_precision_score, classification_report
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

labels_list = dev_as_test["test"].label
proba_list = sf.proba
predictions_list = sf.label

auc = roc_auc_score(labels_list, proba_list)
# p_micro, r_micro, f1_micro, _ = precision_recall_fscore_support(labels_list, predictions_list, average="micro")
prfs = precision_recall_fscore_support(labels_list, predictions_list, average=None, labels=[0, 1])
# Named mean_ap rather than `map` so the builtin map() is not shadowed for the
# rest of the session; the printed key stays "map".
mean_ap = average_precision_score(labels_list, proba_list)
acc = accuracy_score(labels_list, predictions_list)
validation_scores = [mean_ap, acc, auc]
print("scores = ", dict(zip(["map", "acc", "auc"], ["%.4f" % v for v in validation_scores])))


In [None]:
# Quick sanity check of the submission: ten random rows (unseeded, so they differ
# between runs) and the count of each predicted label.
submission.sample(10)
submission.label.value_counts()

In [None]:


# Second train-and-predict run with the same model_fn: larger batch (32), no
# gradient accumulation and sampling_policy=None; written to a separate file.
batch_size = 32
epochs = 7

submission, text_model = train_and_predict(
    model_fn,
    data,
    batch_size,
    epochs,
    scheduler_init_fn=scheduler_init_fn,
    accumulation_steps=1,
    model_call_back=reg_sched,
    sampling_policy=None) # "without_replacement"

# Rebinds `submission` and `text_model`, replacing the results of the previous run.
submission.to_csv("submission2.csv", index=False)
submission.sample(3)
