In [None]:
import logging

# Root logging for the whole notebook: only WARNING and above are emitted, and every
# line is prefixed with a timestamp, the level name and the logger name.
logging.basicConfig(level=logging.WARNING,
                    datefmt='%m/%d/%Y %H:%M:%S',
                    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s')
# Notebook-wide logger; later cells report progress through logger.warning(...).
logger = logging.getLogger(__name__)

from transformers import (get_linear_schedule_with_warmup,
                          BertConfig, BertForMaskedLM, BertTokenizer,
                          GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
                          OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
                          RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer,
                          RobertaConfig, RobertaModel, RobertaTokenizer,
                          DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
import json
import sys
import argparse
import os
import random
import numpy as np
import torch
import wandb
from datasets import EventsDataset, MyConcatDataset, TextDataset
from sklearn.model_selection import KFold
from tqdm import tqdm
from models import get_model
from torch.utils.data import DataLoader, SubsetRandomSampler
import copy
import importlib

In [None]:
# Run configuration shared by every cell below. An ArgumentParser with no declared
# arguments, parsed against an explicit empty list, yields an empty Namespace; the
# empty list keeps the Jupyter kernel's own sys.argv (e.g. `-f kernel.json`) out of it.
# NOTE(review): later cells read attributes such as no_cuda, batch_size, seed,
# cache_dir, source_model, folds, run_fold and output_dir from `args`, none of which
# are set here — TODO populate them (parser arguments / config file) so the notebook
# survives Restart & Run All.
args = argparse.ArgumentParser().parse_args(args=[])


In [None]:
from msd import define_activation, get_multi_dataset, get_tokenizer, set_seed


# Pick the compute device: CUDA only when it is available AND the run has not
# opted out through `args.no_cuda`.
device = torch.device("cuda" if torch.cuda.is_available()
                      and not args.no_cuda else "cpu")
# Count GPUs only when we will actually run on CUDA. Previously a machine with GPUs
# but `no_cuda=True` still reported n_gpu > 0, so the per-GPU batch sizes below were
# divided across GPUs the run never uses.
args.n_gpu = torch.cuda.device_count() if device.type == "cuda" else 0

args.device = device
# Treat a CPU-only run as a single device so the per-device batch split below
# never divides by zero.
if args.n_gpu == 0:
    args.n_gpu = 1

args.eval_batch_size = args.batch_size
args.train_batch_size = args.batch_size

# Per-device share of the global batch (floor division).
args.per_gpu_train_batch_size = args.batch_size // args.n_gpu
args.per_gpu_eval_batch_size = args.batch_size // args.n_gpu
# Report the device and GPU count this run will use.
logger.warning(f"Device: {device}, n_gpu: {args.n_gpu}")

# Seed the RNGs via the project helper (msd.set_seed).
set_seed(args.seed)

args.start_epoch = 0
args.start_step = 0

# Note: model_cache_dir is only defined when cache_dir is truthy.
if args.cache_dir:
    args.model_cache_dir = os.path.join(args.cache_dir, "models")

# Logged before the activation names below are converted, so the log shows the
# configured values.
logger.warning("Training/evaluation parameters %s", args)

# Replace each configured activation with whatever msd.define_activation returns.
args.code_activation = define_activation(args.code_activation)
args.message_activation = define_activation(args.message_activation)
args.event_activation = define_activation(args.event_activation)


# Load the orchestrator manifest. Its keys are used as the sample ids for every
# dataset below and, in the cross-validation cell, are passed to set_hashes.
# NOTE(review): presumably maps sample hash -> per-sample record; confirm the schema.
with open(os.path.join(args.cache_dir, "orc", "orchestrator.json"), "r") as f:
    mall = json.load(f)

# Tokenizers stay None for modalities the chosen source model does not use;
# get_model receives both later.
code_tokenizer, message_tokenizer = None, None

if args.source_model == "Code":
    code_tokenizer = get_tokenizer(
        args, args.code_model_type, args.code_tokenizer_name)
    dataset = TextDataset(code_tokenizer, args, mall,
                          mall.keys(), args.code_embedding_type, balance=False)
    # Adopt the tokenizer the dataset ended up with; this is the one passed to the model.
    code_tokenizer = dataset.tokenizer
    args.return_class = True
    dataset = MyConcatDataset(args, code_dataset=dataset)

elif args.source_model == "Message":
    message_tokenizer = get_tokenizer(
        args, args.message_model_type, args.message_tokenizer_name)
    dataset = TextDataset(message_tokenizer, args, mall,
                          mall.keys(), args.message_embedding_type, balance=False)
    # Mirror the Code branch: the model must see the tokenizer the dataset uses.
    message_tokenizer = dataset.tokenizer
    args.return_class = True
    dataset = MyConcatDataset(args, message_dataset=dataset)

elif args.source_model == "Events":
    # NOTE(review): unlike the text branches this builds with balance=True — confirm intended.
    dataset = EventsDataset(args, mall, mall.keys(), balance=True)
    # Record the first two dimensions of the first sample's input. The sample is fetched
    # once instead of twice, since __getitem__ may be expensive.
    args.xshape1, args.xshape2 = dataset[0][0].shape[:2]
    args.return_class = True
    dataset = MyConcatDataset(args, events_dataset=dataset)

elif args.source_model == "Multi":
    dataset, code_tokenizer, message_tokenizer = get_multi_dataset(args, mall)
    args.return_class = True
else:
    raise NotImplementedError(
        f"Unsupported source_model {args.source_model!r}; "
        "expected one of 'Code', 'Message', 'Events', 'Multi'")



In [None]:

# K-fold cross-validation over the sample ids in `mall`: each selected fold gets a
# fresh model, its own W&B run, a train/test pass and a checkpoint artifact upload.
best_acc = 0
splits = KFold(n_splits=args.folds, shuffle=True, random_state=args.seed)

best_accs = []
mall_keys_list = np.array(list(mall.keys()))
for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(mall_keys_list)))):
    # run_fold == -1 runs every fold; any other value runs only that fold.
    if args.run_fold != -1 and args.run_fold != fold:
        continue

    logger.warning('Running Fold {}'.format(fold))
    dataset.set_hashes(mall_keys_list[train_idx], is_train=True)
    dataset.set_hashes(mall_keys_list[val_idx], is_train=False)

    with wandb.init(project="MSD4", tags=[args.source_model], config=args) as run:
        model = get_model(args, message_tokenizer=message_tokenizer, code_tokenizer=code_tokenizer)
        run.define_metric("epoch")

        # NOTE(review): `train` and `test` are not defined or imported in the cells
        # shown here; on a fresh kernel this raises NameError — TODO import them explicitly.
        best_acc = train(args, dataset, model, fold,
                         train_idx, run, eval_idx=val_idx)
        best_accs.append(best_acc)
        test(args, model, dataset, val_idx, fold=fold)
        # Each fold is its own W&B run, so its summary reports this fold's best
        # accuracy. The running maximum over the folds completed so far is kept under
        # a separate key; previously it overwrote the per-fold value.
        run.summary["best_acc"] = best_acc
        run.summary["best_acc_across_folds"] = max(best_accs)

        # Despite its name, output_dir is the checkpoint *file* path for this fold.
        model_dir = os.path.join(args.output_dir, 'checkpoint-best-acc')
        output_dir = os.path.join(model_dir, f'{args.source_model}_model_{fold}.bin')
        if os.path.isfile(output_dir):
            artifact = wandb.Artifact(f'{args.source_model}_model_{fold}.bin', type='model')
            artifact.add_file(output_dir)
            run.log_artifact(artifact)
        else:
            # A missing checkpoint should not abort the remaining folds.
            logger.warning("No checkpoint at %s; skipping artifact upload for fold %d",
                           output_dir, fold)

